Of course this was another problem that requires you to stop searching for answers and instead create them.
I first found all of the integer partitions of the numbers 1 to 9. Each partition, together with the number it is a partition of, serves as the basis for a family of digit sum numbers. There is then a fairly simple formula for calculating the sum of the numbers formed from each of these partitions. The key to getting it to run quickly was making sure to use a summing algorithm that runs in linear time instead of exponential time.
import more_itertools as miter
# Modulus applied to the totals returned by search(); reducing by 10**16
# keeps only the last 16 decimal digits of the result.
modulo = 10**16
# Cache of every factorial computed so far, keyed by n; seeded with 0! and 1!.
_fact = {0: 1, 1: 1}


def factorial(n):
    """Return n! for a non-negative integer n, caching every value computed.

    The result is built iteratively from the nearest cached factorial at or
    below n, so the call depth does not grow with n and large arguments do
    not depend on the interpreter's recursion limit.

    Args:
        n: Non-negative integer.

    Returns:
        The factorial of n.

    Raises:
        ValueError: If n is negative or not a whole number.
    """
    if n in _fact:
        return _fact[n]
    if n < 0 or n != int(n):
        raise ValueError("factorial() requires a non-negative integer")

    # Walk down to the closest argument whose factorial is already cached.
    k = n
    while k > 0 and k not in _fact:
        k -= 1
    result = _fact.get(k, 1)

    # Multiply back up to n, remembering each intermediate factorial.
    while k < n:
        k += 1
        result *= k
        _fact[k] = result
    return result
def sum_permutations(arr, n):
    """Return the sum of all distinct n-digit arrangements of the digits in arr.

    Every distinct ordering of the multiset ``arr`` is read as a decimal
    number (leading zeros allowed) and the numbers are added together.
    Nothing is enumerated: by symmetry each decimal position receives the
    same digit total, so that total is computed once and applied to every
    position.

    Args:
        arr: Iterable of digits (a multiset; order is irrelevant).
        n: Number of digit positions. The caller visible in this file
            passes the length of ``arr``.

    Returns:
        The exact (unreduced) integer sum.
    """
    # Tally how many times each digit value occurs.
    tally = {}
    for digit in arr:
        if digit in tally:
            tally[digit] += 1
        else:
            tally[digit] = 1

    # Number of distinct orderings: n! divided by each multiplicity's factorial.
    arrangements = factorial(n)
    for reps in tally.values():
        arrangements = arrangements // factorial(reps)

    # Total contributed to a single decimal position across all orderings:
    # a digit with multiplicity `reps` sits there in arrangements * reps / n
    # of the orderings.
    column = sum(
        digit * arrangements * reps // n
        for digit, reps in tally.items()
    )

    # Apply the same column total at each of the n place values.
    total = 0
    place = 1
    for _ in range(n):
        total += column * place
        place *= 10
    return total
def search(arr, index, left, summ):
    """Recursively fill arr with non-decreasing digits and total the results.

    Position ``index`` is tried with every digit that is at least the
    previous digit and keeps the running digit sum at 9 or below. Once no
    positions are left to choose, the running sum is written into the last
    slot of ``arr`` and the sum of all arrangements of the completed array
    is added to the total.

    Args:
        arr: Working list of digits; modified in place.
        index: Position in ``arr`` to fill on this call.
        left: Number of further positions still to be chosen after this one.
        summ: Sum of the digits chosen at earlier positions.

    Returns:
        The accumulated total reduced modulo ``modulo``.
    """
    # Digits are kept in non-decreasing order so each multiset appears once.
    lowest = arr[index - 1] if index != 0 else 0
    # Any digit of 10 - summ or more would push the running sum past 9.
    upper = min(10, 10 - summ)

    acc = 0
    for digit in range(lowest, upper):
        running = summ + digit
        arr[index] = digit
        if not left:
            # All free digits chosen: the final slot holds their sum.
            arr[-1] = running
            acc += sum_permutations(arr, index + 2)
            continue
        acc += search(arr, index + 1, left - 1, running)
    return acc % modulo
def solve(n):
    """Return the digit-sum-number total for digit strings of length n.

    The total covers every length-n digit string (leading zeros allowed) in
    which one digit equals the sum of the other n - 1 digits, reduced by the
    module-level modulus inside ``search``.

    Args:
        n: Number of digit positions.

    Returns:
        The total as an int. For ``n < 2`` the result is 0: with a single
        position the only qualifying string is "0", and with no positions
        there is nothing to add.

    Note:
        ``search`` recurses once per chosen position, so very large ``n``
        may require the caller to raise the interpreter's recursion limit.
    """
    # Fewer than two positions leaves no "other digits" slot for search()
    # to fill; it would index past the end of the working list.
    if n < 2:
        return 0
    return search([0] * n, 0, n - 2, 0)
if __name__ == '__main__':
    import sys

    if len(sys.argv) != 2:
        sys.exit('usage: {} N'.format(sys.argv[0]))

    # NOTE(security): eval() accepts expressions such as 10**3 but will also
    # execute arbitrary code given on the command line. Switch to int() if
    # only plain integer literals need to be supported.
    n = eval(sys.argv[1])

    # The solver recurses in proportion to n, so make room for it. Only ever
    # raise the limit: n * n is smaller than the default for small n, and
    # lowering the limit would make the run itself fail.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), n * n))
    print(solve(n))
import pytest
from problem import solve
# Table of (n, expected solve(n)) pairs used to parametrize test_solve.
# The n = 1 entry is commented out, so that case is not exercised.
_test_solve = (
#(1, 0),
(2, 495),
(3, 63270),
(4, 3149685),
(5, 112398876),
(6, 3311663355),
(7, 85499991450),
(8, 1998499980015),
)
@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that solve(n) reproduces the tabulated value for each n."""
    actual = solve(n)
    assert actual == expect