For some reason I kept being pretty confused by this one. I think it was the backwardness of looking for the lowest number m such that m! is divisible by n. I kept getting those two numbers swapped around in my head.
The problem itself was pretty easy. I quickly figured out that, given the prime factorization n = p1**e1 * p2**e2 * p3**e3 …, I just needed to find the maximum value of s(p**e) over all the prime powers p**e in that factorization. I finally settled on a recursive solution, which allows me to memoize answers.
While the solution takes maybe 30 minutes to calculate, it does get the correct answer and I’m fine with that.
Okay, I just ran the code again using pypy3 and it takes about 4 min to run. That’s so much better, I think I’m going to switch to using pypy for all my future problems now!
from sympy import sieve, isprime
# Memo table for num_of_primes, keyed by the (n, p) argument pair.
_prime_fact = {}


def num_of_primes(n, p):
    """Return how many times ``p`` divides ``n`` exactly.

    The result is the largest ``e`` such that ``p ** e`` divides ``n``.
    Answers are memoized in the module-level ``_prime_fact`` dict.

    Args:
        n: Non-zero integer to strip factors of ``p`` from.
        p: The divisor to count; must not be 1 or -1.

    Returns:
        The number of exact divisions of ``n`` by ``p``.

    Raises:
        ValueError: If ``n`` is 0, or ``p`` is 1 or -1. The division loop
            could never terminate for those inputs.
        ZeroDivisionError: If ``p`` is 0.
    """
    # 0 is divisible by everything and every integer is divisible by +/-1,
    # so the stripping loop below would spin forever on these inputs.
    if p != 0 and (n == 0 or abs(p) == 1):
        raise ValueError(
            "num_of_primes requires n != 0 and abs(p) != 1, "
            "got n={!r}, p={!r}".format(n, p)
        )

    key = (n, p)
    try:
        return _prime_fact[key]
    except KeyError:
        pass

    exp = 0
    while n % p == 0:
        exp += 1
        n //= p
    _prime_fact[key] = exp
    return exp
# Memo table for s, keyed by n.
_s = {}


def s(n):
    """Return the smallest m such that m! is divisible by n.

    For a prime n the answer is n itself.  Otherwise the first prime p
    yielded by ``sieve`` that divides n is split off together with its full
    exponent e, s(p**e) is found by walking the multiples of p, and the
    result is the larger of that value and s() of the remaining cofactor.
    Answers are memoized in the module-level ``_s`` dict.

    Args:
        n: Positive integer.

    Returns:
        The smallest m whose factorial is a multiple of n.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError("s(n) requires n >= 1, got {!r}".format(n))
    if n in _s:
        return _s[n]

    if n == 1:
        # 1! is divisible by 1.  Without this case the scan over the
        # unbounded sieve below would never find a divisor and never stop.
        result = 1
    elif isprime(n):
        result = n
    else:
        p = next(q for q in sieve if n % q == 0)
        exp = num_of_primes(n, p)
        cofactor = n // p ** exp

        # Walk the multiples of p, adding up the factors of p each one
        # contributes to the factorial, until at least exp are collected.
        mults = num = 0
        while mults < exp:
            num += p
            mults += num_of_primes(num, p)

        result = num if cofactor == 1 else max(num, s(cofactor))

    _s[n] = result
    return result
def S(n):
    """Return the sum of s(k) for every integer k from 2 through n inclusive.

    Args:
        n: Inclusive upper bound of the summation.

    Returns:
        The accumulated total; 0 when ``n`` is below 2 because the range
        is then empty.
    """
    return sum(s(k) for k in range(2, n + 1))
if __name__ == '__main__':
    # Script entry point: the first command-line argument is the inclusive
    # upper bound passed to S; the resulting sum is printed.
    import sys
    print(S(int(sys.argv[1])))
import pytest
from problem import s, S
# (n, expected s(n)) pairs for n = 2..60.  Only the expected values are
# listed, in order; enumerate() attaches the matching n to each of them.
_test_s = tuple(enumerate(
    (
        2, 3, 4, 5, 3, 7, 4, 6, 5, 11,          # n = 2..11
        4, 13, 7, 5, 6, 17, 6, 19, 5, 7,        # n = 12..21
        11, 23, 4, 10, 13, 9, 7, 29, 5, 31,     # n = 22..31
        8, 11, 17, 7, 6, 37, 19, 13, 5, 41,     # n = 32..41
        7, 43, 11, 6, 23, 47, 6, 14, 10, 17,    # n = 42..51
        13, 53, 9, 11, 7, 19, 29, 59, 5,        # n = 52..60
    ),
    start=2,
))
@pytest.mark.parametrize('n,expect', _test_s)
def test_s(n, expect):
    """Check s(n) against the tabulated expected value for one case."""
    actual = s(n)
    assert actual == expect
# (n, expected S(n)) pairs, built by pairing each input with its expected
# sum: every n from 2 to 10, then the powers of ten up to 10 ** 4.
_test_S = tuple(zip(
    (2, 3, 4, 5, 6, 7, 8, 9, 10, 100, 1000, 10000),
    (2, 5, 9, 14, 17, 24, 28, 34, 39, 2012, 136817, 10125843),
))
@pytest.mark.parametrize('n,expect', _test_S)
def test_S(n, expect):
    """Check S(n) against the tabulated expected sum for one case."""
    actual = S(n)
    assert actual == expect