Problem 250

completed March 29, 2021

Comments

This problem took me quite a while to complete. I got stuck for some time, but after sleeping on it and trying different approaches, I finally got the correct answer.

The first solution I entered was incorrect, yet all my tests were passing. This was puzzling and quite frustrating. What I ended up doing was rolling back to an earlier commit and starting over. The problem with my initial implementation was that I had removed every value with \(x^x \equiv 0 \pmod{250}\) from the superset and used combinatorics to calculate the total number of ways to choose among those values. I’m not exactly sure how this went wrong, but it did. Including all values in the superset led me to the correct answer.

Runs in about 8.4 seconds on PyPy3.

Code

# Subset counts are reported modulo this value.
modulo = 10**16
# Subset sums are classified by their residue modulo this value.
small = 250


def solve(n):
    """Count non-empty subsets of {1**1, ..., n**n} with sum divisible by `small`.

    Each term ``i**i`` is a distinct element, so two terms that share the
    same residue modulo ``small`` are still chosen independently.

    Args:
        n: Largest base ``i`` to include. For ``n < 1`` there are no terms
            and the result is 0.

    Returns:
        The number of qualifying subsets, reduced modulo ``modulo``.
    """
    # ways[r] = number of subsets (the empty one included) of the terms
    # processed so far whose sum is congruent to r modulo ``small``.
    ways = [1] + [0] * (small - 1)

    for i in range(1, n + 1):
        # Adding a term with residue v moves every subset from class r to
        # class (r + v) % small, i.e. rotates ``ways`` right by v places.
        # Cutting at ``small - v`` performs that rotation and also covers
        # v == 0, where the rotation is the identity.
        cut = small - pow(i, i, small)
        with_term = ways[cut:] + ways[:cut]
        # Every subset either skips the new term or includes it. Reducing at
        # each step keeps the entries bounded without changing the result.
        ways = [(skip + take) % modulo for skip, take in zip(ways, with_term)]

    # Drop the empty subset, which always lands in residue class 0.
    return (ways[0] - 1) % modulo

if __name__ == '__main__':
    import sys

    # Command-line entry point: the first argument is n, the result of
    # solve(n) is printed to stdout.
    if len(sys.argv) < 2:
        # Fail with a readable message instead of an IndexError traceback.
        sys.exit('usage: {} N'.format(sys.argv[0]))

    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line. It is kept so that expressions such as 10**3 keep working; switch
    # to int() if only plain integer literals need to be supported.
    n = eval(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve

_test_solve = (
        (1, 0),
        (2, 0),
        (3, 0),
        (4, 0),
        (5, 0),
        (6, 0),
        (7, 0),
        (8, 1),
        (9, 2),
        (10, 5),
        (11, 9),
        (12, 17),
        (13, 31),
        (14, 51),
        (15, 103),
        (16, 259),
        (17, 523),
        (18, 1099),
        (19, 2119),
        (20, 4239),
        (100, 2919523543220223),
        (200, 229002094182399),
        (300, 2163528616771583),
        (400, 1921290045947903),
        (500, 1161003119476735),
        (1000, 9471080450686975),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """solve(n) must equal the tabulated value for every (n, expect) pair."""
    actual = solve(n)
    assert actual == expect

_test_wrong_solve = (
        (250250, 8147698406785023),
        (250250, 8147698406785022),
        (250250, 8147698406785024),
)

@pytest.mark.parametrize('n,expect', _test_wrong_solve)
def test_wrong_solve(n, expect):
    """solve(n) must differ from each listed known-wrong value."""
    actual = solve(n)
    assert actual != expect