So, I don’t know what it is, but this problem was also listed among the easiest (5% difficulty), and I still had to let my solution run for several hours just to get the answer. That feels really frustrating to me.
I did enjoy working on this problem, though. The math got simpler and simpler as I worked through it and kept reducing it, and it ended up being a pretty small, tight problem, which was really cool. The runtime still got slower and slower as the run went on — not exponentially so, but still far from ideal. I made some micro speed improvements, like adding a cache for the powers, but those do not fix the steady slowdown; they only start the race a bit faster.
# Prime modulus every result is reduced by.
MODULO = 1000000007

# Memo of divisor-sum factors: (prime, exponent) -> (1 + p + ... + p**e) % MODULO.
_exponents_cache = {}


def solve(max):
    """Return 1 plus the sum of the per-row divisor sums for rows 2..max, mod MODULO.

    A running prime factorisation ``factors`` is maintained.  Going from row
    ``i - 1`` to row ``i`` multiplies the tracked number by ``i ** (i - 1)``
    and divides it by ``(i - 1)!``.  This is the step-to-step ratio of the
    product of the binomial coefficients C(i, 0..i), which is presumably the
    quantity the problem is about (verify against the problem statement).
    Each row contributes the sum of the divisors of that number, computed as
    the product over primes of ``1 + p + ... + p**e``.

    All arithmetic is carried out modulo ``MODULO`` so that intermediate
    values stay small; the exact big-integer powers are never materialised.

    Args:
        max: Last row to include.  For ``max < 2`` the loop body never runs
            and the result is 1.

    Returns:
        The accumulated total reduced modulo ``MODULO``.
    """
    total = 1  # contribution of row 1
    factors = {}   # prime -> exponent in the current row's number
    fact_sub = {}  # prime -> exponent in (i - 1)!, i.e. what to divide out
    for i in range(2, max + 1):
        # Divide by (i - 1)!.
        for k, v in fact_sub.items():
            factors[k] -= v
        # Multiply by i ** (i - 1) and fold i into the running factorial.
        for k, v in prime_factors(i).items():
            factors[k] = factors.get(k, 0) + v * (i - 1)
            fact_sub[k] = fact_sub.get(k, 0) + v

        row_total = 1
        for cache_key in factors.items():
            prime_total = _exponents_cache.get(cache_key)
            # Compare against None: a legitimately cached 0 must count as a hit.
            if prime_total is None:
                prime, exponent = cache_key
                if (prime - 1) % MODULO == 0:
                    # prime == 1 (mod MODULO): every term of the series is 1
                    # and prime - 1 has no modular inverse.
                    prime_total = (exponent + 1) % MODULO
                else:
                    # Geometric series (p**(e+1) - 1) / (p - 1), with the
                    # division done via Fermat's little theorem because
                    # MODULO is prime.
                    numerator = pow(prime, exponent + 1, MODULO) - 1
                    inverse = pow(prime - 1, MODULO - 2, MODULO)
                    prime_total = numerator * inverse % MODULO
                _exponents_cache[cache_key] = prime_total
            row_total = row_total * prime_total % MODULO
        total = (total + row_total) % MODULO
    return total % MODULO
# Shared, ascending list of every prime discovered so far.  It is extended on
# demand and reused by every iterator created by primes_iterator().
_primes = [2, 3, 5, 7]


def _next_prime_after(last):
    """Return the smallest prime greater than ``last``.

    ``last`` must be the final (odd) entry of ``_primes``, so that every prime
    up to the square root of each candidate is already in the list.
    """
    candidate = last + 2
    while True:
        for divisor in _primes:
            if divisor * divisor > candidate:
                # No prime up to sqrt(candidate) divides it: it is prime.
                return candidate
            if candidate % divisor == 0:
                break
        else:
            # Known primes exhausted without finding a divisor.
            return candidate
        candidate += 2


def primes_iterator():
    """Yield the primes in ascending order, without end.

    Primes already stored in the module-level ``_primes`` list are replayed
    first; further primes are found by trial division and appended to that
    list so later iterators can reuse them.

    The list is walked by index and always extended from its last entry, so
    several iterators may be consumed in an interleaved fashion without the
    same prime being appended (and later yielded) twice.
    """
    index = 0
    while True:
        if index == len(_primes):
            _primes.append(_next_prime_after(_primes[-1]))
        yield _primes[index]
        index += 1
def prime_factors(n):
    """Return the prime factorisation of ``n`` as a ``{prime: exponent}`` dict.

    Primes are tried in ascending order, so the keys are inserted in ascending
    order.  ``prime_factors(1)`` returns an empty dict.

    Args:
        n: Positive integer to factorise.

    Returns:
        Dict mapping each prime divisor of ``n`` to its multiplicity.

    Raises:
        ValueError: If ``n`` is less than 1; the division loop could never
            terminate for such values.
    """
    if n < 1:
        raise ValueError('prime_factors() requires n >= 1, got %r' % (n,))
    factors = {}
    for prime in primes_iterator():
        # Once prime**2 exceeds what is left, the remainder is 1 or a prime,
        # so there is no need to walk every prime up to n itself.
        if prime * prime > n:
            break
        exponent = 0
        while n % prime == 0:
            n //= prime
            exponent += 1
        if exponent:
            factors[prime] = exponent
    if n > 1:
        # Whatever remains is a single prime factor larger than sqrt(original n).
        factors[n] = 1
    return factors
if __name__ == '__main__':
    import sys

    # Command-line entry point: the first argument is the upper bound handed
    # to solve(); the result is printed to stdout.
    if len(sys.argv) < 2:
        sys.exit('usage: %s MAX' % sys.argv[0])
    # Named `limit` rather than `max` so the builtin max() is not shadowed at
    # module scope.
    limit = int(sys.argv[1])
    print(solve(limit))
from problem import *
import pytest
# Table of (n, expected) pairs consumed by test_solve through
# pytest.mark.parametrize: each expected value is compared against solve(n).
_solve_tests = (
(1, 1),
(2, 4),
(3, 17),
(4, 269),
(5, 5736),
(6, 590892),
(7, 55905463),
(8, 896504917),
(9, 677137056),
# Written as the unreduced value so the reduction by MODULO stays visible.
(10, 141740594713218418 % MODULO),
(11, 150275157),
(12, 121376207),
(13, 783506633),
(14, 614553155),
(15, 198581112),
(16, 655492273),
(17, 566075),
(18, 79414290),
(19, 637651480),
(20, 131742826),
(30, 221500462),
(40, 430088683),
(50, 217147306),
(60, 479364147),
(70, 357992988),
(80, 70449148),
(90, 123501558),
(100, 332792866),
(125, 128858272),
(150, 168263766),
(175, 154528397),
(200, 271664942),
(300, 694978846),
(400, 53488473),
(500, 899393748),
(600, 289794383),
(700, 246415106),
(800, 796725764),
(900, 618880649),
(1000, 361160563),
# NOTE(review): the larger cases below are disabled, presumably because they
# were too slow to run routinely; confirm before re-enabling.
#(1500, 762946177),
#(2000, 939425731),
#(2500, 481019074),
#(3000, 665284696),
)
@pytest.mark.parametrize('n,expect', _solve_tests)
def test_solve(n, expect):
    """Check that solve(n) reproduces the recorded value for one table entry."""
    actual = solve(n)
    assert actual == expect