Problem 725

completed November 22, 2020

Comments

Of course, this was another problem that required you to stop searching for answers and instead create them.

I first found all of the integer partitions of the numbers 1 to 9. Each partition, together with the number it is a partition of, serves as the basis for a family of digit-sum numbers. There is then a pretty simple formula for calculating the sum over each of these partitions. The key to getting it to run quickly was making sure to use a summing algorithm that was linear time instead of exponential time.

Code

import more_itertools as miter

# Modulus applied by search() to every total it returns (keeps the low 16 digits).
modulo = 10**16

# Cache of already-computed factorials, keyed by n; seeded with 0! and 1!.
_fact = {0: 1, 1: 1}
def factorial(n):
    """Return n!, caching every intermediate value in ``_fact``.

    The cache is extended iteratively, so a cold call with a large ``n``
    does not depend on the interpreter's recursion limit.

    Args:
        n: non-negative whole number.

    Returns:
        The product 1 * 2 * ... * n (1 when n is 0).

    Raises:
        ValueError: if ``n`` is negative or not a whole number.
    """
    if n in _fact:
        return _fact[n]
    if n < 0 or n % 1:
        raise ValueError("factorial() requires a non-negative whole number")

    # Walk down to the closest cached value; 0 and 1 are always present,
    # so this terminates for every valid n.
    k = n - 1
    while k not in _fact:
        k -= 1

    # Climb back up, recording each new factorial along the way.
    result = _fact[k]
    while k < n:
        k += 1
        result *= k
        _fact[k] = result
    return result

def sum_permutations(arr, n):
    """Sum all distinct arrangements of the digits in ``arr`` read as numbers.

    Args:
        arr: iterable of digits (repeats allowed).
        n: number of positions; callers are expected to pass ``len(arr)``,
            since ``n!`` is used as the count of raw orderings.

    Returns:
        The (unreduced) integer sum over every distinct permutation.
    """
    # Multiplicity of each distinct digit.
    tally = {}
    for d in arr:
        if d in tally:
            tally[d] += 1
        else:
            tally[d] = 1

    # Multinomial coefficient: n! divided by the factorial of each multiplicity.
    arrangements = factorial(n)
    for multiplicity in tally.values():
        arrangements //= factorial(multiplicity)

    # Contribution of a single decimal position: each digit occupies that
    # position in arrangements * multiplicity / n of the permutations.
    column = sum(
        digit * arrangements * multiplicity // n
        for digit, multiplicity in tally.items()
    )

    # Every position contributes the same amount, so scale by 11...1 (n ones).
    repunit = (10**n - 1) // 9
    return column * repunit

def search(arr, index, left, summ):
    """Enumerate non-decreasing digit choices and total their permutation sums.

    Fills ``arr[index]`` with every digit that is at least the previous digit
    and keeps the running digit sum at 9 or below. While ``left`` is non-zero
    it recurses on the next slot; otherwise it stores the running sum in the
    last slot of ``arr`` and adds ``sum_permutations`` for the finished list.

    Args:
        arr: working list of digits, mutated in place.
        index: slot currently being filled.
        left: number of further slots to fill after this one.
        summ: sum of the digits chosen so far.

    Returns:
        The accumulated total reduced modulo ``modulo``.
    """
    lowest = arr[index - 1] if index else 0
    # The digit may not exceed 9, nor push the running sum past 9.
    highest = min(9, 9 - summ)

    running = 0
    for digit in range(lowest, highest + 1):
        arr[index] = digit
        partial = summ + digit
        if left:
            running += search(arr, index + 1, left - 1, partial)
            continue
        # Last free slot filled: the final digit is the sum of the others.
        arr[-1] = partial
        running += sum_permutations(arr, index + 2)
    return running % modulo

def solve(n):
    """Return the total produced by ``search`` for numbers of up to ``n`` digits.

    Args:
        n: maximum number of digits.

    Returns:
        The total reduced modulo ``modulo``; 0 when ``n`` is below 2.
    """
    # search() needs at least one free digit plus the slot holding their sum.
    # With fewer than two digits the only candidate is 0, so the total is 0;
    # without this guard n == 1 indexes past the end of the working list.
    if n < 2:
        return 0
    return search([0] * n, 0, n - 2, 0)

if __name__ == '__main__':
    import sys
    # NOTE(review): eval() runs arbitrary code taken from the command line. It
    # allows arguments such as 10**3, but int() would be the safe choice if
    # plain integers are all that is needed.
    n = eval(sys.argv[1])
    # search() recurses once per digit, so large n needs a higher limit.
    # Never lower it: for small n, n * n is below the interpreter default and
    # would make the run itself fail with RecursionError.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), n * n))
    print(solve(n))

Tests

import pytest
from problem import solve

# (n, expected) pairs fed to test_solve: solve(n) must equal the second value.
_test_solve = (
        # The n = 1 case is disabled.
        # NOTE(review): presumably solve(1) is not supported — confirm before enabling.
        #(1, 0),
        (2, 495),
        (3, 63270),
        (4, 3149685),
        (5, 112398876),
        (6, 3311663355),
        (7, 85499991450),
        (8, 1998499980015),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """solve(n) must reproduce the reference total listed for n."""
    actual = solve(n)
    assert actual == expect