This problem took me quite a while to complete. I got stuck for some time, but after sleeping on it and trying different approaches, I finally got the correct answer.
The first solution I entered was incorrect, yet all my tests were passing. This was puzzling and quite frustrating. What I ended up doing was rolling back to an earlier commit and starting over. The problem with my initial implementation was that I had removed every value where \(x^x \equiv 0 \pmod{250}\) and used combinatorics to calculate the total number of ways to choose these values. I’m not exactly sure how this went wrong, but it did. Including all values in the superset led me to the correct answer.
Runs in about 8.4 seconds on PyPy3.
# Results are reported modulo this value (i.e. the last 16 decimal digits).
modulo = 10**16
# Subset sums are tracked by their residue modulo this value.
small = 250


def solve(n):
    """Count subsets of {1^1, 2^2, ..., n^n} whose sum is divisible by ``small``.

    Only non-empty subsets are counted, and the count is reduced modulo
    ``modulo``.

    Args:
        n: Largest base ``i`` whose term ``i**i`` belongs to the set.

    Returns:
        The number of qualifying subsets modulo ``modulo``; 0 when ``n < 1``
        because the set is then empty.
    """
    # ways[r] = number of subsets (the empty one included) of the terms
    # processed so far whose sum is congruent to r modulo `small`.
    ways = [1] + [0] * (small - 1)
    for i in range(1, n + 1):
        residue = pow(i, i, small)
        # Rotating right by `residue` puts ways[(r - residue) % small] at
        # position r: the subsets that reach r once the new term is added.
        # A residue of 0 yields an unchanged copy (ways[-0:] is the full list).
        shifted = ways[-residue:] + ways[:-residue]
        # Every subset either skips the new term or takes it; reducing at each
        # step keeps the integers below `modulo`.
        ways = [(skip + take) % modulo for skip, take in zip(ways, shifted)]
    # ways[0] still includes the empty subset, which must not be counted.
    return (ways[0] - 1) % modulo
if __name__ == '__main__':
    import sys

    # Command-line entry point: print solve(N) for the N given as first argument.
    # Exit with a usage hint instead of a bare IndexError when N is missing.
    if len(sys.argv) < 2:
        sys.exit('usage: {} N'.format(sys.argv[0]))
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line. It is kept so that expression arguments such as 10**5 keep
    # working; switch to int() if only plain integers need to be accepted.
    n = eval(sys.argv[1])
    print(solve(n))
import pytest
from problem import solve
# (n, expect) rows consumed by test_solve: solve(n) must equal expect.
_test_solve = (
    (1, 0),
    (2, 0),
    (3, 0),
    (4, 0),
    (5, 0),
    (6, 0),
    (7, 0),
    (8, 1),
    (9, 2),
    (10, 5),
    (11, 9),
    (12, 17),
    (13, 31),
    (14, 51),
    (15, 103),
    (16, 259),
    (17, 523),
    (18, 1099),
    (19, 2119),
    (20, 4239),
    (100, 2919523543220223),
    (200, 229002094182399),
    (300, 2163528616771583),
    (400, 1921290045947903),
    (500, 1161003119476735),
    (1000, 9471080450686975),
)
@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that solve(n) reproduces the recorded value of each table row."""
    actual = solve(n)
    assert actual == expect
# (n, expect) rows consumed by test_wrong_solve: solve(n) must NOT equal expect.
# NOTE(review): presumably a previously submitted wrong answer and its two
# neighbours -- confirm with the author.
_test_wrong_solve = (
    (250250, 8147698406785023),
    (250250, 8147698406785022),
    (250250, 8147698406785024),
)
@pytest.mark.parametrize('n,expect', _test_wrong_solve)
def test_wrong_solve(n, expect):
    """Check that solve(n) differs from each value recorded as wrong."""
    actual = solve(n)
    assert actual != expect