I never got this one to run in a performant manner: this solution takes just under an hour to run. The part that made it so slow was finding all the “prime relatives” of each prime number, which required me to check each prime against every other prime. I see in the forum thread that other people did this instead by first generating all the “relatives” of a number and then keeping only those that are prime. Since I already have my sieve, it would be very easy to check whether a number is prime. I think I’ll try that out and come back here to report my findings!
Update: It is in fact faster! And this makes sense, because even for the largest primes you are checking against fewer than 50 other numbers, whereas in my previous version you were checking against good lord knows how many other primes (and mostly coming up empty, too!).
def cache(fn):
    """Wrap ``fn`` in a memo with a re-entrancy guard and a "proven" shortcut.

    The returned wrapper takes positional arguments only and behaves as
    follows, in this order:

    1. If any earlier call whose *first* argument equals ``args[0]`` produced
       a truthy result, return ``True`` immediately (the remaining arguments
       are ignored for this shortcut).
    2. If this exact argument tuple already has a stored result, return it.
    3. If this exact argument tuple has been started before but has no stored
       result yet (i.e. ``fn`` re-entered itself with the same arguments),
       store and return ``False`` so the recursion cannot loop forever.
    4. Otherwise call ``fn``, store its result, and remember the first
       argument as "proven" when the result is truthy.

    All arguments must be hashable, and at least one argument is required.
    """
    finished = {}     # argument tuple -> stored result
    started = set()   # argument tuples whose evaluation has begun
    proven = set()    # first arguments that have produced a truthy result

    def wrap(*args):
        key = args[0]
        if key in proven:
            return True
        if args in finished:
            return finished[args]
        if args in started:
            # Re-entered while still being evaluated: treat as a dead end.
            finished[args] = False
            return False
        started.add(args)
        outcome = finished[args] = fn(*args)
        if outcome:
            proven.add(key)
        return outcome

    return wrap
def is_connected(a, b):
    """Return True if the non-negative integers *a* and *b* are connected.

    Two numbers are connected when either

    * they have the same number of decimal digits and differ in exactly one
      digit position, or
    * the larger one is the smaller one with exactly one (non-zero) digit
      written in front of it, e.g. 7 and 17 -- but not 1 and 101, because
      that would need two extra digits.

    The relation is symmetric, so the argument order does not matter.
    Equal values are reported as connected; callers that care are expected
    to exclude ``a == b`` themselves.
    """
    # The digit walk below consumes the digits of ``a`` and then inspects what
    # is left of ``b``, so it is only valid when ``a`` is the smaller number.
    # Without this swap a call such as is_connected(101, 1) would implicitly
    # zero-pad the smaller value and wrongly report a connection.
    if a > b:
        a, b = b, a

    while a:
        a, a_digit = divmod(a, 10)
        b, b_digit = divmod(b, 10)
        if a_digit != b_digit:
            # First mismatch from the right: it must be the only one, so all
            # remaining higher-order digits have to be identical.
            return a == b

    # Every digit of ``a`` matched the low digits of ``b``; ``b`` may carry at
    # most one additional leading digit.
    return b < 10
def _connected_below(p):
    """Yield every integer q < p that is connected to the positive integer p.

    "Connected" means: same number of digits with exactly one digit changed,
    or p with its leading digit removed when that removal shortens the number
    by exactly one digit. Candidates are not filtered for primality.
    """
    digits = len(str(p))
    if digits > 1:
        top_place = 10 ** (digits - 1)
        tail = p % top_place
        # Dropping the leading digit only counts when no leading zero is lost,
        # i.e. the remainder still has ``digits - 1`` digits.
        if tail >= top_place // 10:
            yield tail
    for position in range(digits):
        place = 10 ** position
        # p with the digit at ``position`` replaced by zero ...
        start = p - (p // place % 10) * place
        # ... except that a multi-digit number may not start with zero.
        if position == digits - 1 and digits > 1:
            start += place
        for q in range(start, p, place):
            yield q


def solve(n):
    """Return the sum of all primes ``p <= n`` that are not relatives of 2.

    A prime ``p`` is a relative of 2 when 2 can be reached from ``p`` through
    a chain of primes in which consecutive members are connected (one digit
    changed, or one digit added on the left) and no member exceeds ``p``.

    Primes are visited in increasing order and merged into disjoint sets with
    every smaller prime they are connected to. When ``p`` is visited, the sets
    describe exactly the graph on the primes ``<= p``, so ``p`` is a relative
    of 2 precisely when it ends up in the same set as 2. Smaller neighbours
    are generated from the digits of ``p`` and looked up in the sieve, which
    avoids comparing every pair of primes and needs no recursion.

    For ``n < 2`` there are no primes and the result is 0.
    """
    # Sieve of Eratosthenes over 0..n.
    is_prime = [False, False] + [True] * (n - 1)
    factor = 2
    while factor * factor <= n:
        if is_prime[factor]:
            for multiple in range(factor * factor, n + 1, factor):
                is_prime[multiple] = False
        factor += 1

    # Disjoint-set forest over 0..n; only prime indices are ever merged.
    parent = list(range(n + 1))

    def find(x):
        root = x
        while parent[root] != root:
            root = parent[root]
        # Path compression keeps later lookups short.
        while parent[x] != root:
            following = parent[x]
            parent[x] = root
            x = following
        return root

    total = 0
    for p in range(2, n + 1):
        if not is_prime[p]:
            continue
        for q in _connected_below(p):
            if is_prime[q]:
                parent[find(q)] = find(p)
        if find(p) != find(2):
            total += p
    return total
if __name__ == '__main__':
    # Command-line entry point: print solve(N) for the limit N given as the
    # first argument.
    import sys

    if len(sys.argv) < 2:
        # Fail with a usage hint instead of a bare IndexError traceback.
        sys.exit('usage: %s N   (N may be an expression such as 10**6)'
                 % sys.argv[0])

    # Raise the recursion ceiling before solving (harmless if the solver does
    # not recurse deeply).
    sys.setrecursionlimit(60000)

    # NOTE(security): eval() is what allows expressions like 10**6 on the
    # command line, but it executes arbitrary Python. Only run this script
    # with arguments you trust; switch to int() if plain integers suffice.
    n = eval(sys.argv[1])
    print(solve(n))
def relatives_below(p):
    """Return all integers smaller than *p* that are connected to *p*.

    The result contains, in this order:

    1. *p* with its leading digit removed, but only if that leaves a number
       with exactly one digit fewer (no leading zero is lost);
    2. for each digit position, starting at the units, every number obtained
       by replacing that digit with a smaller one. The leading digit of a
       multi-digit number is never replaced by zero.

    The candidates are not filtered for primality; the caller is expected to
    do that (e.g. with a sieve lookup).

    Digit counts are computed with exact integer/string arithmetic. The
    floating-point form ``int(math.log(x, 10)) + 1`` is off by one for some
    powers of ten (``math.log(1000, 10)`` is 2.9999999999999996), which made
    inputs such as 21000 lose their truncation 1000.

    Raises:
        ValueError: if *p* is not a positive integer.
    """
    if p < 1:
        raise ValueError('p must be a positive integer')

    digits = len(str(p))
    below = []

    if digits > 1:
        top_place = 10 ** (digits - 1)
        tail = p % top_place
        # ``tail`` has digits - 1 digits exactly when it reaches the next
        # lower power of ten.
        if tail >= top_place // 10:
            below.append(tail)

    for position in range(digits):
        place = 10 ** position
        # Zero out the digit at ``position`` ...
        candidate = p // (place * 10) * (place * 10) + p % place
        # ... but a multi-digit number must keep a non-zero leading digit.
        if position == digits - 1 and digits > 1:
            candidate += place
        # Step the digit upwards until the value would reach p itself.
        below.extend(range(candidate, p, place))

    return below
import pytest
from problem import solve, is_connected
# Each row is (a, b, expected value of is_connected(a, b)).
_test_is_connected = (
    (2, 5, True),      # single digits that differ
    (7, 17, True),     # one digit written in front of 7
    (7, 217, False),   # two digits in front of 7
    (1, 101, False),   # two digits in front of 1, one of them a zero
)


@pytest.mark.parametrize('a,b,expect', _test_is_connected)
def test_is_connected(a, b, expect):
    """is_connected must return exactly the expected bool object."""
    outcome = is_connected(a, b)
    assert outcome is expect
# Each row is (upper limit n, expected value of solve(n)).
_test_solve = (
    (10 ** 1, 0),
    (10 ** 2, 11),
    (10 ** 3, 431),
    (10 ** 4, 78728),
    (10 ** 5, 5134953),
    # Disabled case, kept for reference: (10**6, 532658671)
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """solve(n) must reproduce the known total for each limit."""
    result = solve(n)
    assert result == expect