This problem was incredibly similar to Problem 250, which I finished yesterday evening. I reused most of the same code and the same approach here. I was unable to remove the innermost loop, however, which was possible in Problem 250, but that wasn’t an issue here because of the much smaller set size.
After reading through the forums, I feel I could have done this more efficiently: this solution takes about 3.5 minutes on PyPy3, whereas others were able to get theirs down to about 30 seconds.
modulo = 10**16  # results are reported modulo this value (last 16 digits)


def solve(n):
    """Count the subsets of the primes <= n whose element sum is prime.

    The empty subset is never counted, because its sum (0) is not prime.

    Args:
        n: Inclusive upper bound for the primes that form the base set.

    Returns:
        The number of qualifying subsets, reduced modulo ``modulo``.
        Returns 0 when there is no prime <= n (that is, n < 2).
    """
    if n < 2:
        # No primes at all, so no subset can have a prime sum.
        return 0

    # n * (n + 2) // 2 >= 1 + 2 + ... + n, so it bounds every subset sum.
    top = (n + 2) * n // 2

    # Sieve of Eratosthenes over 0..top; sieve[k] is 1 exactly when k is prime.
    sieve = bytearray([1]) * (top + 1)
    sieve[0] = sieve[1] = 0
    factor = 2
    while factor * factor <= top:
        if sieve[factor]:
            first = factor * factor
            # Clear every multiple of `factor` from factor**2 upwards at once.
            sieve[first::factor] = bytes(len(range(first, top + 1, factor)))
        factor += 1

    primes = [k for k in range(2, n + 1) if sieve[k]]
    total = sum(primes)
    mod = modulo  # local alias keeps the hot loop free of global lookups

    # ways[s] = number of subsets (of the primes processed so far) summing
    # to s, kept modulo `mod`. Only the empty subset sums to 0 initially.
    ways = [0] * (total + 1)
    ways[0] = 1
    reach = 0  # largest sum reachable with the primes processed so far
    for prime in primes:
        reach += prime
        # Walk downwards so each prime is used at most once per subset.
        for s in range(reach, prime - 1, -1):
            ways[s] = (ways[s] + ways[s - prime]) % mod

    # total <= top, so every index used below is inside the sieve.
    return sum(ways[s] for s in range(2, total + 1) if sieve[s]) % mod
if __name__ == '__main__':
    import sys

    # Command-line entry point: the first argument is the limit passed to
    # solve(), and the result is printed to stdout.
    # NOTE(review): eval() executes any Python expression given on the
    # command line. It is kept here so that expression arguments continue to
    # work; int() would be the safer choice if only plain integers are needed.
    print(solve(eval(sys.argv[1])))
import pytest
from problem import solve
# Expected results for solve(), stored as (n, expected) pairs and fed to the
# parametrized test below.
_test_solve = (
    # Every limit from 2 to 31. In this table the expected value changes
    # only when n itself is prime.
    (2, 1),
    (3, 3),
    (4, 3),
    (5, 5),
    (6, 5),
    (7, 7),
    (8, 7),
    (9, 7),
    (10, 7),
    (11, 12),
    (12, 12),
    (13, 20),
    (14, 20),
    (15, 20),
    (16, 20),
    (17, 35),
    (18, 35),
    (19, 65),
    (20, 65),
    (21, 65),
    (22, 65),
    (23, 122),
    (24, 122),
    (25, 122),
    (26, 122),
    (27, 122),
    (28, 122),
    (29, 237),
    (30, 237),
    (31, 448),
    # Larger limits in steps of 100.
    # NOTE(review): the 16-digit values are presumably already reduced by
    # the modulus used inside solve() -- confirm against the implementation.
    (100, 5253640),
    (200, 9070770715773),
    (300, 9053434012630419),
    (400, 8338447649846224),
    (500, 5012491561571871),
    (600, 6450938646305359),
    (700, 1089141014850720),
    (800, 2743807595555872),
    (900, 5301872277445669),
    (1000, 5725053962252706),
)
@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that solve(n) matches the tabulated expected value."""
    actual = solve(n)
    assert actual == expect