I kept looking at this one, thinking I was missing something. I started by looping over all values of a and then b; that was slow. I then tried looping over all values of b and then a, which was also slow. I wanted a better way to predict the second number. The third number was always going to be easy to find because the terms form a geometric progression.
I finally tried looping over all values of a and then c. Instead of looping through every prime for c, I used the fact that (a + 1) * (c + 1) has to be a perfect square, since it must equal (b + 1)^2. From the prime factorization of a + 1 I can predict the possible values of c. This was much faster: I was able to get the solution down to 45 seconds on pypy3.
def _sieve_primes(limit):
    """Return ``(flags, primes)`` for the integers ``0..limit``.

    ``flags[i]`` is 1 when ``i`` is prime and 0 otherwise; ``primes`` is the
    ascending list of primes ``<= limit``.  ``limit`` must be at least 1.
    """
    flags = bytearray([1]) * (limit + 1)
    flags[0] = flags[1] = 0
    primes = []
    for p in range(2, limit + 1):
        if not flags[p]:
            continue
        primes.append(p)
        # Every composite below p * p has a smaller prime factor and has
        # therefore already been crossed off.
        for multiple in range(p * p, limit + 1, p):
            flags[multiple] = 0
    return flags, primes


def _square_split(m, primes, flags):
    """Return ``(core, root)`` with ``m == core * root ** 2``, core squarefree.

    ``primes`` and ``flags`` come from ``_sieve_primes(limit)``; ``m`` must be
    at least 1 and at most ``limit + 1`` so that every cofactor looked up in
    ``flags`` is in range.
    """
    core, root = 1, 1
    for p in primes:
        if p * p > m:
            break  # what is left has no factor <= sqrt(m): it is 1 or prime
        if m % p:
            continue
        exp = 0
        while m % p == 0:
            m //= p
            exp += 1
        half, parity = divmod(exp, 2)
        root *= p ** half
        core *= p ** parity
        # After at least one division m <= limit, so the lookup is in range.
        # A prime remainder needs no further trial division.
        if flags[m]:
            break
    if m > 1:
        core *= m  # leftover prime appears exactly once
    return core, root


def solve(n):
    """Sum ``a + b + c`` over prime triples whose successors are geometric.

    A triple is counted when ``a < b < c <= n`` are all prime and
    ``(b + 1) ** 2 == (a + 1) * (c + 1)``, i.e. ``a + 1, b + 1, c + 1`` form a
    geometric progression.

    For each prime ``a`` write ``a + 1 = core * root ** 2`` with ``core``
    squarefree.  The product ``(a + 1) * (c + 1)`` is a perfect square exactly
    when ``c + 1 = core * k ** 2`` for some integer ``k``, and then
    ``b + 1 = core * root * k``.  Taking ``k > root`` gives ``a < b < c``.

    Args:
        n: inclusive upper bound for the primes in a triple.

    Returns:
        The total of ``a + b + c`` over all qualifying triples; 0 when
        ``n < 2``.
    """
    if n < 2:
        return 0

    flags, primes = _sieve_primes(n)
    total = 0
    for a in primes:
        core, root = _square_split(a + 1, primes, flags)
        if core % 2:
            # c is an odd prime, so c + 1 = core * k ** 2 must be even; with
            # an odd core that forces k to be even.  Start from the largest
            # even number <= root and step by 2.
            step = 2
            k = root - root % 2
        else:
            step = 1
            k = root
        while True:
            k += step
            c = core * k * k - 1
            if c > n:
                break
            if flags[c]:
                # b + 1 is the square root of (a + 1) * (c + 1); b < c <= n
                # so the index is always inside the sieve.
                b = core * root * k - 1
                if flags[b]:
                    total += a + b + c
    return total
# Command-line entry point: evaluate the first argument as the limit and
# print solve(limit).
if __name__ == '__main__':
    import sys

    # Fail with a readable message instead of an IndexError traceback when
    # the limit argument is missing.
    if len(sys.argv) < 2:
        sys.exit('usage: {} LIMIT'.format(sys.argv[0]))
    # NOTE(review): eval() lets the limit be given as an expression such as
    # 10**8, but it executes arbitrary code taken from the command line.
    # Only run this with arguments you trust.
    n = eval(sys.argv[1])
    print(solve(n))
import pytest
from problem import solve
# Pairs of (limit, expected total) that solve() is checked against.
_test_solve = (
    (10 ** 2, 1035),
    (10 ** 3, 75019),
    (10 ** 4, 4225228),
    (10 ** 5, 249551109),
    (10 ** 6, 17822459735),
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """solve(n) must reproduce the recorded total for each limit."""
    actual = solve(n)
    assert actual == expect