Problem 518

completed October 28, 2020

Comments

I kept looking and looking at this one thinking I was missing something. I started by running through all values for a then b. That was slow. So I then tried running through all values for b then a. This was also slow. I wanted a better way to predict the second number. The third number was always going to be easy to find because of the geometric series.

I finally tried running through all values for a, then c. But instead of just looping through all primes for c, I used the fact that (a + 1)*(c + 1) had to be a square number (it equals (b + 1)^2 in a geometric sequence). I could get the prime factorization of a + 1, then use that to predict the possible values of c. This was much, much faster! I was able to get the solution down to 45 seconds on pypy3.

Code

def solve(n):
    """Return S(n) for Project Euler problem 518.

    S(n) is the sum of a + b + c over all triples of primes with
    a < b < c < n such that a + 1, b + 1, c + 1 form a geometric sequence.

    Approach: write a + 1 = s * r**2 with s square-free.  The product
    (a + 1) * (c + 1) must be the perfect square (b + 1)**2, which forces
    c + 1 = s * k**2 for some integer k > r, and then b + 1 = s * r * k.
    So for every prime a only the values of k need to be enumerated.

    Args:
        n: exclusive upper bound for the largest prime c of a triple.

    Returns:
        The sum of a + b + c over all qualifying triples (0 if none).
    """
    # Sieve of Eratosthenes over 0..n.  Indices 0 and 1 stay True but are
    # never queried: every b and c tested below is at least 3.
    sieve = [True] * (n + 1)
    primes = []
    factor_cache = {1: {}}
    for p in range(2, n + 1):
        if not sieve[p]:
            continue
        primes.append(p)
        factor_cache[p] = {p: 1}
        # Multiples below p * p were already crossed out by smaller primes.
        for multiple in range(p * p, n + 1, p):
            sieve[multiple] = False

    def factor(m):
        """Return the prime factorisation of m as {prime: exponent}."""
        cached = factor_cache.get(m)
        if cached is not None:
            return cached
        fact = {}
        rest = m
        for p in primes:
            if rest % p:
                continue
            exp = 0
            while rest % p == 0:
                exp += 1
                rest //= p
            fact[p] = exp
            # The cofactor is smaller, so recurse (and reuse the cache).
            fact.update(factor(rest))
            break
        else:
            # No sieved prime divides m, so m itself is prime.  This only
            # happens for m = n + 1, which lies outside the sieve.
            fact[m] = 1
        factor_cache[m] = fact
        return fact

    total = 0
    for a in primes:
        # Split a + 1 into square_free * root**2.
        square_free, root = 1, 1
        for p, exp in factor(a + 1).items():
            half, parity = divmod(exp, 2)
            root *= p ** half
            square_free *= p ** parity

        if square_free % 2:
            # c > a >= 2 makes c odd, so c + 1 = square_free * k**2 must be
            # even; with an odd square-free part only even k can work.
            step = 2
            k = root - root % 2
        else:
            step = 1
            k = root

        while True:
            k += step  # k > root guarantees a < b < c
            c = square_free * k * k - 1
            if c >= n:  # the problem requires c < n (strictly)
                break
            if sieve[c]:
                # b + 1 is the geometric mean of a + 1 and c + 1.
                b = square_free * root * k - 1
                if sieve[b]:
                    total += a + b + c

    return total

if __name__ == '__main__':
    # Script entry point: read n from the first command-line argument and
    # print solve(n).
    import sys
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line.  It lets callers pass expressions such as 10**8, but int() or a
    # restricted parser would be safer if the argument is ever untrusted.
    n = eval(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve

# Test cases as (n, expected value of solve(n)) pairs; consumed by the
# parametrized test_solve below.
_test_solve = (
        (100, 1035),
        (1000, 75019),
        (10000, 4225228),
        (100000, 249551109),
        (1000000, 17822459735),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that solve(n) produces the expected total for one test case."""
    result = solve(n)
    assert result == expect