Problem 749

completed March 7, 2021

Comments

Searching through all numbers below \(10^{16}\) to see if they are near power sums is unreasonable. Instead, I searched through all multisets of 16 digits 0-9, that is, sorted digit sequences in which each combination of digits appears once regardless of order (the `permutations` function in the code, despite its name). This was much more manageable at just over 2 million multisets. For each of these, I summed the kth powers of the digits for k = 1, 2, ... and checked whether that sum plus or minus 1 has exactly the same digits as the multiset.
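As a rough illustration of the search-space size, the sketch below enumerates digit multisets with the standard library and counts them with a binomial coefficient. The helper names `digit_multisets` and `count_digit_multisets` are hypothetical and are not part of the solution; the solution uses its own recursive generator.

```python
from itertools import combinations_with_replacement
from math import comb


def digit_multisets(n):
    """Yield every non-decreasing tuple of n digits (one per multiset)."""
    return combinations_with_replacement(range(10), n)


def count_digit_multisets(n):
    """Number of size-n multisets drawn from 10 digits: C(n + 9, 9)."""
    return comb(n + 9, 9)


# Small sanity check: enumerate the n = 2 case and compare with the formula.
small = list(digit_multisets(2))
assert len(small) == count_digit_multisets(2)

# The n = 16 count is computed from the formula only, without enumerating.
total_16 = count_digit_multisets(16)
```

Here `total_16` comes out to 2,042,975, which matches the "just over 2 million" figure.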

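The toughest part of this was dealing with answers and sums that end in either 0 or 9. For these, adding or subtracting 1 carries or borrows, so more than one digit differs between the power sum and the candidate number. I tried several algorithms for this and ultimately settled on what you see below: build the digits of the sum plus 1 and of the sum minus 1, sort them, and compare each against the sorted multiset.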
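The carry issue can be seen in a small sketch. The helpers `differing_positions` and `near_power_sum` are hypothetical names written for this illustration; they use string-based digit extraction rather than the `divmod` loop in the solution.

```python
def differing_positions(a, b, width):
    """Count positions where a and b differ when zero-padded to `width` digits."""
    sa, sb = str(a).zfill(width), str(b).zfill(width)
    return sum(x != y for x, y in zip(sa, sb))


def near_power_sum(digits, k):
    """Return m if m = s + 1 or m = s - 1 has the same digits as `digits`,
    where s is the sum of k-th powers of `digits`; otherwise return 0."""
    width = len(digits)
    key = sorted(str(d) for d in digits)
    s = sum(d ** k for d in digits)
    for m in (s + 1, s - 1):
        # zfill pads short numbers with leading zeros; a number with more
        # than `width` digits yields a longer list and so cannot match.
        if m >= 0 and sorted(str(m).zfill(width)) == key:
            return m
    return 0


# 34 -> 35 changes one digit, but 199 -> 200 changes all three.
one_digit = differing_positions(34, 35, 2)
three_digits = differing_positions(199, 200, 3)

# 3^2 + 5^2 = 34 = 35 - 1 and 5^2 + 7^2 = 74 = 75 - 1.
found = near_power_sum((3, 5), 2) + near_power_sum((5, 7), 2)
```

Comparing whole sorted digit lists avoids reasoning about which individual digits changed, at the cost of recomputing the digits of two numbers for every sum.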

It’s pretty slow; most of the time is spent in the digits method. It runs in just under 50 seconds on PyPy3.
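One possible way to lighten the digit comparison, sketched here under the assumption that sorting is a meaningful part of the cost, is to compare digit histograms instead of sorted lists. The function `digit_counts` is a hypothetical helper, and no timing was measured for it, so treat it as an idea to try rather than a confirmed speedup.

```python
def digit_counts(m, width):
    """Histogram of the decimal digits of m, zero-padded to `width` digits.

    Returns None when m needs more than `width` digits.
    """
    counts = [0] * 10
    used = 0
    while m:
        m, d = divmod(m, 10)
        counts[d] += 1
        used += 1
    if used > width:
        return None
    counts[0] += width - used  # account for leading zeros
    return counts


# The multiset's histogram would be built once, then compared per candidate.
target = digit_counts(35, 2)
same = digit_counts(53, 2) == target
```

The histogram of a multiset only needs to be computed once per multiset, whereas each candidate sum still needs one pass over its digits.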

Code

def solve(n):

    def permutations(n):
        arr = [0] * n

        def _perm(depth=0):
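            if depth == n:
                # Skip multisets made only of 0s and 1s: their power sums
                # are the same for every k, so the search would never end.
                if any(d > 1 for d in arr):
                    yield arr
                return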
            start = arr[depth-1] if depth else 0
            for i in range(start, 10):
                arr[depth] = i
                yield from _perm(depth+1)

        yield from _perm()

    def digits(summ):
        digs = []
        while summ:
            summ, d = divmod(summ, 10)
            digs.append(d)
        return digs + [0] * (n - len(digs))

    def nearpower(arr, summ):
        sorted_arr = tuple(sorted(arr))
        # summ + 1
        sum_digs = digits(summ + 1)
        if tuple(sorted(sum_digs)) == sorted_arr:
            return summ + 1

        # summ - 1
        sum_digs = digits(summ - 1)
        if tuple(sorted(sum_digs)) == sorted_arr:
            return summ - 1

        return 0

    top = 10 ** n
    total = 0

    for p in permutations(n):
        k = 1
        while True:
            summ = sum(d**k for d in p)
            if summ > top:
                break
            num = nearpower(p, summ)
            if num:
                total += num
                break
            k += 1

    return total

if __name__ == '__main__':
    import sys
    n = int(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve

_test_solve = (
        (2, 110),
        (3, 110),
        (4, 110),
        (5, 110),
        (6, 2562701),
        (7, 2562701),
        (8, 381784152),
        (9, 1144909106),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)