Problem 159

completed April 24, 2021

Comments

This problem felt pretty straightforward, yet I kept getting the wrong answer. I searched and searched and searched for what I was doing wrong – as can be seen in the huge number of tests that I wrote – but still couldn’t find it. After 5 days of struggle, I wrote tests for my digital_root function and finally found my mistake. My original implementation of this function was:

def digital_root(n):  # first attempt, quoted verbatim for the discussion below
    total = 0
    while n:
        n, mod = divmod(n, 10)  # peel off the last decimal digit
        total += mod  # accumulate the digit sum
    if total > 10:
        return digital_root(total)  # recurse on the digit sum
    return total

Can you spot the error? Well, the comparison on the 6th line should have been a >=: when the digit sum is exactly 10 (for example, for 19) the function returns 10, which is not a true digital root. Once I made that small change – and fixed all the tests that had been written with the old code (not very good tests, were they?) – I was able to get the answer immediately. The program runs in about 1.7 seconds with PyPy3.

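The algorithm here is recursive. To find the maximal digital root sum (MDRS) of a number, start with the digital root of the number itself, which corresponds to the trivial factorization. Then split the number into every pair of factors a × b with 1 < a ≤ b, recursively compute the MDRS of a and of b, and add the two. The MDRS of the number is the largest value found. This works because any factorization of the number can be obtained by factoring a and b further, so the best digital root sum for a given split is the sum of the maximal DRS of its two parts.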

Code

def memoize(fn):
    """Return a caching wrapper around the one-argument function ``fn``.

    The first call with a given argument evaluates ``fn`` and stores the
    result; later calls with an equal argument return the stored result
    without calling ``fn`` again.  The argument must be hashable because it
    is used as a dictionary key.  The cache is never cleared.

    The wrapper keeps ``fn``'s name, docstring and other metadata so that
    decorated functions remain recognisable in tracebacks and test reports.
    """
    import functools  # local import: the module has no top-level imports

    _cache = {}

    @functools.wraps(fn)
    def wrap(n):
        if n in _cache:
            return _cache[n]
        ret = fn(n)
        _cache[n] = ret
        return ret
    return wrap

def digital_root(n):
    """Return the digital root of the non-negative integer ``n``.

    The digital root is the single digit obtained by repeatedly summing the
    decimal digits of ``n``.  Instead of summing digits, this uses the fact
    that a positive integer and its digit sum leave the same remainder
    modulo 9, with a remainder of 0 standing for the digit 9.

    ``digital_root(0)`` is 0: zero is the only number whose digits sum to 0.
    """
    if n == 0:
        return 0
    # For n > 0 a remainder of 0 means the digits reduce to 9, never to 0.
    return n % 9 or 9

def sieve(n):
    """Build a primality / prime-factor lookup table for ``0..n``.

    The returned list ``t`` is indexed by the integer itself:

    * ``t[0]`` and ``t[1]`` are ``False``;
    * ``t[k] is True`` when ``k`` is prime;
    * for a composite ``k``, ``t[k]`` is a prime factor of ``k`` -- namely
      the largest prime ``p`` dividing ``k`` with ``p * p <= k``, because
      larger primes overwrite the marks left by smaller ones.
    """
    table = [False, False]
    table.extend([True] * (n - 1))
    for candidate in range(len(table)):
        # Only entries still marked True are primes; everything else is
        # either 0/1 or already carries a prime factor.
        if table[candidate] is not True:
            continue
        for multiple in range(candidate * candidate, n + 1, candidate):
            table[multiple] = candidate
    return table
# Lookup table for 0..10**6, built once at import time: primes[k] is True
# when k is prime, and holds a prime factor of k when k is composite.
primes = sieve(10**6)

@memoize
def divisors(n):
    """Return every positive divisor of ``n`` as an ascending list.

    One prime factor of ``n`` is read from the ``primes`` table, all of its
    copies are divided out, and the divisors of the remaining cofactor are
    multiplied by each power of that prime.  ``n`` must be within the range
    covered by ``primes``.
    """
    if n == 1:
        return [1]

    factor = primes[n]
    if factor is True:
        # n is prime, so its only divisors are 1 and itself.
        return [1, n]

    # Strip every copy of `factor` from n, counting how many there were.
    cofactor = n
    exponent = 0
    while cofactor % factor == 0:
        cofactor //= factor
        exponent += 1

    base = divisors(cofactor)
    result = []
    power = 1
    for _ in range(exponent + 1):
        result.extend(d * power for d in base)
        power *= factor
    result.sort()
    return result

@memoize
def maximum_digital_root_sum(n):
    """Return the maximal digital root sum (MDRS) of ``n``.

    The MDRS is the largest value obtainable by writing ``n`` as a product
    of factors greater than 1 and adding the digital roots of those factors.
    The candidates are the digital root of ``n`` itself (the trivial
    factorization) and, for every split ``n == a * b`` with
    ``1 < a <= b``, the sum of the MDRS of ``a`` and the MDRS of ``b``.

    For ``n == 1`` there is no proper split, so its digital root is returned.
    """
    top = digital_root(n)
    # divisors(n) is ascending and starts with 1; skip that trivial divisor.
    for a in divisors(n)[1:]:
        # Past the square root every split has already been seen with the
        # roles of a and b swapped.  Integer comparison avoids float sqrt.
        if a * a > n:
            break
        b = n // a
        new_top = maximum_digital_root_sum(a) + maximum_digital_root_sum(b)
        if new_top > top:
            top = new_top
    # Always return the best value, including when the loop body never ran.
    return top

def solve(n):
    """Return the sum of the maximal digital root sums of all k, 1 < k < n."""
    return sum(maximum_digital_root_sum(k) for k in range(2, n))

if __name__ == '__main__':
    import sys
    # NOTE(review): eval() executes whatever expression is given on the
    # command line.  It lets the limit be written as e.g. 10**6, but it must
    # only ever be run with trusted arguments.
    n = eval(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import (
        solve, digital_root, maximum_digital_root_sum, divisors, sieve
)

# (n, expected result of sieve(n)) pairs.  In each expected table the entry
# at index k is False for 0 and 1, True for a prime k, and a prime factor of
# k for a composite k.
_test_sieve = (
        (1, [False, False]),
        (2, [False, False, True]),
        (3, [False, False, True, True]),
        (4, [False, False, True, True, 2]),
        (5, [False, False, True, True, 2, True]),
        (6, [False, False, True, True, 2, True, 2]),
        (7, [False, False, True, True, 2, True, 2, True]),
        (8, [False, False, True, True, 2, True, 2, True, 2]),
        (9, [False, False, True, True, 2, True, 2, True, 2, 3]),
        (10, [False, False, True, True, 2, True, 2, True, 2, 3, 2]),
        (11, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True]),
        (12, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3]),
        (13, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True]),
        (14, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2]),
        (15, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3]),
        (16, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3, 2]),
        (17, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3, 2, True]),
        (18, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3, 2, True, 3]),
        (19, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3, 2, True, 3, True]),
        (20, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3, 2, True, 3, True, 2]),
        (21, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3, 2, True, 3, True, 2, 3]),
        (22, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3, 2, True, 3, True, 2, 3, 2]),
        (23, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
            True, 2, 3, 2, True, 3, True, 2, 3, 2, True]),
)

@pytest.mark.parametrize('n,expect', _test_sieve)
def test_sieve(n, expect):
    """sieve(n) must reproduce the expected lookup table exactly."""
    actual = sieve(n)
    assert actual == expect

# (n, expected result of divisors(n)) pairs: every positive divisor of n in
# ascending order.
_test_divisors = (
        (1, [1]),
        (2, [1, 2]),
        (3, [1, 3]),
        (4, [1, 2, 4]),
        (5, [1, 5]),
        (6, [1, 2, 3, 6]),
        (7, [1, 7]),
        (8, [1, 2, 4, 8]),
        (9, [1, 3, 9]),
        (10, [1, 2, 5, 10]),
        (11, [1, 11]),
        (12, [1, 2, 3, 4, 6, 12]),
        (13, [1, 13]),
        (14, [1, 2, 7, 14]),
        (15, [1, 3, 5, 15]),
        (16, [1, 2, 4, 8, 16]),
        (17, [1, 17]),
        (18, [1, 2, 3, 6, 9, 18]),
        (19, [1, 19]),
        (20, [1, 2, 4, 5, 10, 20]),
        (21, [1, 3, 7, 21]),
        (22, [1, 2, 11, 22]),
        (23, [1, 23]),
        (24, [1, 2, 3, 4, 6, 8, 12, 24]),
        (25, [1, 5, 25]),
)

@pytest.mark.parametrize('n,expect', _test_divisors)
def test_divisors(n, expect):
    """divisors(n) must return the expected ascending divisor list."""
    actual = divisors(n)
    assert actual == expect

# (n, expected digital root of n) pairs: every n from 1 to 100, followed by
# two larger spot checks.
_test_digital_root = (
        (1, 1),
        (2, 2),
        (3, 3),
        (4, 4),
        (5, 5),
        (6, 6),
        (7, 7),
        (8, 8),
        (9, 9),
        (10, 1),
        (11, 2),
        (12, 3),
        (13, 4),
        (14, 5),
        (15, 6),
        (16, 7),
        (17, 8),
        (18, 9),
        (19, 1),
        (20, 2),
        (21, 3),
        (22, 4),
        (23, 5),
        (24, 6),
        (25, 7),
        (26, 8),
        (27, 9),
        (28, 1),
        (29, 2),
        (30, 3),
        (31, 4),
        (32, 5),
        (33, 6),
        (34, 7),
        (35, 8),
        (36, 9),
        (37, 1),
        (38, 2),
        (39, 3),
        (40, 4),
        (41, 5),
        (42, 6),
        (43, 7),
        (44, 8),
        (45, 9),
        (46, 1),
        (47, 2),
        (48, 3),
        (49, 4),
        (50, 5),
        (51, 6),
        (52, 7),
        (53, 8),
        (54, 9),
        (55, 1),
        (56, 2),
        (57, 3),
        (58, 4),
        (59, 5),
        (60, 6),
        (61, 7),
        (62, 8),
        (63, 9),
        (64, 1),
        (65, 2),
        (66, 3),
        (67, 4),
        (68, 5),
        (69, 6),
        (70, 7),
        (71, 8),
        (72, 9),
        (73, 1),
        (74, 2),
        (75, 3),
        (76, 4),
        (77, 5),
        (78, 6),
        (79, 7),
        (80, 8),
        (81, 9),
        (82, 1),
        (83, 2),
        (84, 3),
        (85, 4),
        (86, 5),
        (87, 6),
        (88, 7),
        (89, 8),
        (90, 9),
        (91, 1),
        (92, 2),
        (93, 3),
        (94, 4),
        (95, 5),
        (96, 6),
        (97, 7),
        (98, 8),
        (99, 9),
        (100, 1),
        (467, 8),
        (999999994, 4),
)

@pytest.mark.parametrize('n,expect', _test_digital_root)
def test_digital_root(n, expect):
    """digital_root(n) must equal the tabulated single-digit value."""
    actual = digital_root(n)
    assert actual == expect

# (n, expected maximal digital root sum of n) pairs for every n from 2 to 50.
_test_maximum_digital_root_sum = (
        (2, 2),
        (3, 3),
        (4, 4),
        (5, 5),
        (6, 6),
        (7, 7),
        (8, 8),
        (9, 9),
        (10, 7),
        (11, 2),
        (12, 8),
        (13, 4),
        (14, 9),
        (15, 8),
        (16, 10),
        (17, 8),
        (18, 11),
        (19, 1),
        (20, 9),
        (21, 10),
        (22, 4),
        (23, 5),
        (24, 11),
        (25, 10),
        (26, 8),
        (27, 12),
        (28, 11),
        (29, 2),
        (30, 11),
        (31, 4),
        (32, 12),
        (33, 6),
        (34, 10),
        (35, 12),
        (36, 13),
        (37, 1),
        (38, 3),
        (39, 7),
        (40, 13),
        (41, 5),
        (42, 13),
        (43, 7),
        (44, 8),
        (45, 14),
        (46, 7),
        (47, 2),
        (48, 14),
        (49, 14),
        (50, 12),
)

@pytest.mark.parametrize('n,expect', _test_maximum_digital_root_sum)
def test_maximum_digital_root_sum(n, expect):
    """maximum_digital_root_sum(n) must equal the tabulated value."""
    actual = maximum_digital_root_sum(n)
    assert actual == expect

# (limit, value that solve(limit) must NOT return) pairs.
# NOTE(review): presumably these are earlier wrong answers kept as
# regression guards -- confirm with the author.
_test_not_solve = (
        (10**6, 45999954),
        (10**6, 45999908),
        (10**6, 17469850),
)

@pytest.mark.parametrize('n,expect', _test_not_solve)
def test_not_solve(n, expect):
    """solve(n) must differ from each listed value."""
    actual = solve(n)
    assert actual != expect