This problem felt pretty straightforward, yet I kept getting the wrong answer.
I searched and searched and searched for what I was doing wrong – as can be
seen in the huge number of tests that I wrote – but still couldn’t find it.
After 5 days of struggle, I finally wrote tests for my digital_root function
and found my mistake. My original implementation of this function was:
# Original digit-summing implementation, quoted unchanged for illustration.
# It contains the mistake discussed in the surrounding text.
def digital_root(n):
total = 0
while n:
n, mod = divmod(n, 10)  # peel off the last decimal digit
total += mod
if total > 10:  # intentional bug (see text): a total of exactly 10 is returned as-is
return digital_root(total)
return total
Can you spot the error? The comparison on the 6th line should have been `>=`
rather than `>`, since returning 10 would not be a true digital root (a digital
root is always a single digit). Once I made that small change (and fixed all
the tests that had been written with the old code – not very good tests, were
they?) I was able to get the answer immediately. The program runs in about
1.7 seconds with PyPy3.
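The algorithm here is recursive. To find the maximal digital root sum (MDRS) of a number, start with the digital root of the number itself, which is the digital root sum (DRS) of the trivial one-factor factorization. Then split the number into each pair of divisors a × b (with 1 < a ≤ b) and recursively calculate the MDRS of both parts along the way. Since the DRS of a factorization is the sum of the digital roots of its factors, the best factorization that refines the split a × b has DRS equal to MDRS(a) + MDRS(b); the MDRS of the number is therefore the largest of these sums and its own digital root.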
def memoize(fn):
    """Return a caching wrapper around the one-argument function ``fn``.

    The first call with a given argument invokes ``fn`` and stores the
    result; later calls with an equal argument return the stored object
    without calling ``fn`` again. The argument must be hashable because it
    is used as a dictionary key. The cache is unbounded and lives as long
    as the returned wrapper.

    The wrapper carries over ``fn``'s name and docstring so that decorated
    functions stay recognisable in tracebacks and test reports.
    """
    import functools  # local import keeps this block self-contained

    _cache = {}

    @functools.wraps(fn)
    def wrap(n):
        # A membership test (rather than .get) keeps falsy or None results
        # cached correctly.
        if n not in _cache:
            _cache[n] = fn(n)
        return _cache[n]

    return wrap
def digital_root(n):
    """Return the digital root of the non-negative integer ``n``.

    The digital root is the single digit obtained by repeatedly summing the
    decimal digits of ``n``. For positive ``n`` it equals ``n % 9``, except
    that multiples of 9 map to 9 rather than 0.
    """
    # Zero is the only value whose digital root is 0; without this guard the
    # "multiple of 9" rule below would report 9.
    if n == 0:
        return 0
    return n % 9 or 9
def sieve(n):
    """Build a factor lookup table for the integers ``0..n``.

    The returned list is indexed by integer. Entries 0 and 1 are ``False``,
    prime indices hold ``True``, and every composite index holds one of its
    prime factors (the last prime ``p`` with ``p * p <= index`` that divides
    it, since later primes overwrite earlier ones).
    """
    table = [False, False]
    table.extend([True] * (n - 1))
    for candidate in range(2, len(table)):
        # Anything already overwritten with a factor (or False) is not prime.
        if table[candidate] is not True:
            continue
        # Multiples below candidate**2 have a smaller prime factor and were
        # marked while processing that smaller prime.
        for multiple in range(candidate * candidate, n + 1, candidate):
            table[multiple] = candidate
    return table
# Factor lookup table for 0..10**6, built once at import time. Despite the
# name it is not a list of primes: primes[k] is True when k is prime, a prime
# factor of k when k is composite, and False for 0 and 1.
primes = sieve(10**6)
@memoize
def divisors(n):
    """Return the sorted list of all positive divisors of ``n``.

    ``n`` must be a valid index into the module-level ``primes`` table. The
    result is memoized, so callers receive the cached list object itself and
    must not mutate it.
    """
    if n == 1:
        return [1]

    factor = primes[n]
    if factor is True:
        # n is prime: its only divisors are 1 and itself.
        return [1, n]

    # Strip every copy of `factor` out of n, recording each power of it.
    cofactor = n // factor
    powers = [1, factor]
    while cofactor % factor == 0:
        cofactor //= factor
        powers.append(powers[-1] * factor)

    # Every divisor of n is a divisor of the cofactor times a power of factor.
    base = divisors(cofactor)
    result = [d * power for power in powers for d in base]
    result.sort()
    return result
@memoize
def maximum_digital_root_sum(n):
    """Return the largest digital root sum over all factorizations of ``n``.

    The candidates are the digital root of ``n`` itself and, for every split
    ``n == a * b`` with ``1 < a <= b``, the sum of the maximum digital root
    sums of ``a`` and ``b``. Results are memoized.
    """
    best = digital_root(n)
    # Skip the leading divisor 1: the split 1 * n makes no progress.
    for a in divisors(n)[1:]:
        # Divisors are sorted, so once a exceeds sqrt(n) every remaining
        # split has already been seen with its factors swapped. Comparing
        # a * a with n stays in exact integer arithmetic, unlike n ** 0.5.
        if a * a > n:
            break
        b = n // a
        candidate = maximum_digital_root_sum(a) + maximum_digital_root_sum(b)
        if candidate > best:
            best = candidate
    # Returning here also covers n == 1, whose divisor list has nothing past
    # the leading 1; previously that case fell off the end and gave None.
    return best
def solve(n):
    """Return the sum of ``maximum_digital_root_sum(k)`` for ``2 <= k < n``."""
    return sum(maximum_digital_root_sum(k) for k in range(2, n))
if __name__ == '__main__':
    import sys

    # Fail with a readable message instead of an IndexError traceback when
    # the limit argument is missing.
    if len(sys.argv) < 2:
        sys.exit('usage: %s LIMIT   (LIMIT may be an expression, e.g. 10**6)'
                 % sys.argv[0])

    # SECURITY NOTE: eval() executes arbitrary Python taken from the command
    # line. It is kept so that expressions such as 10**6 keep working, but
    # this script must never be fed untrusted input.
    n = eval(sys.argv[1])
    print(solve(n))
import pytest
from problem import (
solve, digital_root, maximum_digital_root_sum, divisors, sieve
)
# (n, expected table) pairs for n = 1..23, consumed by test_sieve. In each
# table, indices 0 and 1 hold False, prime indices hold True, and composite
# indices hold one of their prime factors.
_test_sieve = (
(1, [False, False]),
(2, [False, False, True]),
(3, [False, False, True, True]),
(4, [False, False, True, True, 2]),
(5, [False, False, True, True, 2, True]),
(6, [False, False, True, True, 2, True, 2]),
(7, [False, False, True, True, 2, True, 2, True]),
(8, [False, False, True, True, 2, True, 2, True, 2]),
(9, [False, False, True, True, 2, True, 2, True, 2, 3]),
(10, [False, False, True, True, 2, True, 2, True, 2, 3, 2]),
(11, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True]),
(12, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3]),
(13, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True]),
(14, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2]),
(15, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3]),
(16, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3, 2]),
(17, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3, 2, True]),
(18, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3, 2, True, 3]),
(19, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3, 2, True, 3, True]),
(20, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3, 2, True, 3, True, 2]),
(21, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3, 2, True, 3, True, 2, 3]),
(22, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3, 2, True, 3, True, 2, 3, 2]),
(23, [False, False, True, True, 2, True, 2, True, 2, 3, 2, True, 3,
True, 2, 3, 2, True, 3, True, 2, 3, 2, True]),
)
@pytest.mark.parametrize('n,expect', _test_sieve)
def test_sieve(n, expect):
    """``sieve(n)`` must reproduce the hand-written table for ``n``."""
    actual = sieve(n)
    assert actual == expect
# (n, expected sorted divisor list) pairs for n = 1..25, consumed by
# test_divisors.
_test_divisors = (
(1, [1]),
(2, [1, 2]),
(3, [1, 3]),
(4, [1, 2, 4]),
(5, [1, 5]),
(6, [1, 2, 3, 6]),
(7, [1, 7]),
(8, [1, 2, 4, 8]),
(9, [1, 3, 9]),
(10, [1, 2, 5, 10]),
(11, [1, 11]),
(12, [1, 2, 3, 4, 6, 12]),
(13, [1, 13]),
(14, [1, 2, 7, 14]),
(15, [1, 3, 5, 15]),
(16, [1, 2, 4, 8, 16]),
(17, [1, 17]),
(18, [1, 2, 3, 6, 9, 18]),
(19, [1, 19]),
(20, [1, 2, 4, 5, 10, 20]),
(21, [1, 3, 7, 21]),
(22, [1, 2, 11, 22]),
(23, [1, 23]),
(24, [1, 2, 3, 4, 6, 8, 12, 24]),
(25, [1, 5, 25]),
)
@pytest.mark.parametrize('n,expect', _test_divisors)
def test_divisors(n, expect):
    """``divisors(n)`` must return the expected sorted divisor list."""
    actual = divisors(n)
    assert actual == expect
# (n, expected digital root) pairs consumed by test_digital_root: every n in
# 1..100 plus two larger spot checks.
_test_digital_root = (
(1, 1),
(2, 2),
(3, 3),
(4, 4),
(5, 5),
(6, 6),
(7, 7),
(8, 8),
(9, 9),
(10, 1),
(11, 2),
(12, 3),
(13, 4),
(14, 5),
(15, 6),
(16, 7),
(17, 8),
(18, 9),
(19, 1),
(20, 2),
(21, 3),
(22, 4),
(23, 5),
(24, 6),
(25, 7),
(26, 8),
(27, 9),
(28, 1),
(29, 2),
(30, 3),
(31, 4),
(32, 5),
(33, 6),
(34, 7),
(35, 8),
(36, 9),
(37, 1),
(38, 2),
(39, 3),
(40, 4),
(41, 5),
(42, 6),
(43, 7),
(44, 8),
(45, 9),
(46, 1),
(47, 2),
(48, 3),
(49, 4),
(50, 5),
(51, 6),
(52, 7),
(53, 8),
(54, 9),
(55, 1),
(56, 2),
(57, 3),
(58, 4),
(59, 5),
(60, 6),
(61, 7),
(62, 8),
(63, 9),
(64, 1),
(65, 2),
(66, 3),
(67, 4),
(68, 5),
(69, 6),
(70, 7),
(71, 8),
(72, 9),
(73, 1),
(74, 2),
(75, 3),
(76, 4),
(77, 5),
(78, 6),
(79, 7),
(80, 8),
(81, 9),
(82, 1),
(83, 2),
(84, 3),
(85, 4),
(86, 5),
(87, 6),
(88, 7),
(89, 8),
(90, 9),
(91, 1),
(92, 2),
(93, 3),
(94, 4),
(95, 5),
(96, 6),
(97, 7),
(98, 8),
(99, 9),
(100, 1),
(467, 8),
(999999994, 4),
)
@pytest.mark.parametrize('n,expect', _test_digital_root)
def test_digital_root(n, expect):
    """``digital_root(n)`` must match the tabulated digital root."""
    actual = digital_root(n)
    assert actual == expect
# (n, expected maximum digital root sum) pairs for n = 2..50, consumed by
# test_maximum_digital_root_sum.
_test_maximum_digital_root_sum = (
(2, 2),
(3, 3),
(4, 4),
(5, 5),
(6, 6),
(7, 7),
(8, 8),
(9, 9),
(10, 7),
(11, 2),
(12, 8),
(13, 4),
(14, 9),
(15, 8),
(16, 10),
(17, 8),
(18, 11),
(19, 1),
(20, 9),
(21, 10),
(22, 4),
(23, 5),
(24, 11),
(25, 10),
(26, 8),
(27, 12),
(28, 11),
(29, 2),
(30, 11),
(31, 4),
(32, 12),
(33, 6),
(34, 10),
(35, 12),
(36, 13),
(37, 1),
(38, 3),
(39, 7),
(40, 13),
(41, 5),
(42, 13),
(43, 7),
(44, 8),
(45, 14),
(46, 7),
(47, 2),
(48, 14),
(49, 14),
(50, 12),
)
@pytest.mark.parametrize('n,expect', _test_maximum_digital_root_sum)
def test_maximum_digital_root_sum(n, expect):
    """``maximum_digital_root_sum(n)`` must match the tabulated value."""
    actual = maximum_digital_root_sum(n)
    assert actual == expect
# (n, value) pairs that solve(n) must NOT return; test_not_solve asserts
# inequality for each.
# NOTE(review): presumably these are wrong answers produced while debugging,
# kept as regression guards - confirm.
_test_not_solve = (
(10**6, 45999954),
(10**6, 45999908),
(10**6, 17469850),
)
@pytest.mark.parametrize('n,expect', _test_not_solve)
def test_not_solve(n, expect):
    """``solve(n)`` must differ from each listed value."""
    actual = solve(n)
    assert actual != expect