Problem 479

completed December 27, 2020

Comments

It was fun to watch this problem become simpler and simpler as I dug into it. It was really nice how a crazy cubic equation could all of a sudden reduce to $(1-k^2)$.
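
As a sanity check on that reduction, here is a minimal brute-force sketch. It assumes the problem asks for the double sum of $(1-k^2)^p$ over $1 \le k \le n$ and $1 \le p \le n$ (my reading of the problem statement), and the name `brute_force` is just for illustration. It uses exact integers, so it is only practical for small n.

```python
def brute_force(n):
    """Sum (1 - k^2)^p over 1 <= k <= n and 1 <= p <= n, using exact integers."""
    return sum((1 - k * k) ** p
               for k in range(1, n + 1)
               for p in range(1, n + 1))


# Matches the n = 4 case in the tests further down.
print(brute_force(4))  # 51160
```

The k = 1 term is always zero, which is why the loop in the real solution can start at k = 2.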

The most challenging part was making it performant. The bottleneck was raising an already large number to the millionth power. I knew that passing the third argument to pow (the modulus) would speed things up. However, the formula includes a division, which makes straightforward modular arithmetic impossible. I dug through the internet and came across the idea of the modular inverse: instead of dividing by a number, multiply by its modular inverse, which is the number that, when multiplied by the original, yields 1 modulo your modulus. I used the extended Euclidean algorithm to find the inverse. After solving the problem and looking at the forums, I found there’s also an approach using Euler’s theorem, which may or may not be faster.
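
For comparison, here is a rough sketch of the Euler's theorem route. Because 1000000007 is prime, Euler's theorem reduces to Fermat's little theorem: for a not divisible by the modulus m, a^(m-1) is 1 mod m, so a^(m-2) is the inverse of a. The names `inv_fermat` and `div_mod` are made up for this example and are not part of the solution below.

```python
MOD = 1000000007  # prime, so Fermat's little theorem applies


def inv_fermat(a, mod=MOD):
    """Inverse of a modulo a prime mod, computed as a^(mod - 2)."""
    return pow(a, mod - 2, mod)


def div_mod(num, den, mod=MOD):
    """Compute num / den modulo mod by multiplying with the inverse of den."""
    return num * inv_fermat(den, mod) % mod


print(4 * inv_fermat(4) % MOD)  # 1
print(div_mod(20, 4))           # 5
print(inv_fermat(3, 7))         # 5, same as the extended Euclidean result
```

This only works when the modulus is prime and does not divide a. The extended Euclidean algorithm only needs the two numbers to be coprime. On Python 3.8 and later, `pow(a, -1, mod)` should give the same inverse directly.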

Runs in less than a second in PyPy3.

Code

modulo = 1000000007

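def inv(n, mod):
    # Modular inverse of n via the extended Euclidean algorithm.
    # Assumes gcd(n, mod) == 1.
    t, newt = 0, 1
    r, newr = mod, n
    while newr:
        q = r // newr
        t, newt = newt, t - q * newt
        r, newr = newr, r - q * newr
    # t can come out negative, so normalise it into the range [0, mod).
    return t % mod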

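def solve(n):
    # For each k, the sum over p = 1..n of (1 - k^2)^p is a geometric series
    # equal to (1 - (1 - k^2)^(n + 1)) / k^2 - 1. The k = 1 term is zero.
    # Keeping the sign inside pow makes this correct for odd n as well.
    total = 1 - n
    for k in range(2, n + 1):
        k_2 = k**2
        big_pow = pow(1 - k_2, n + 1, modulo)
        total += (1 - big_pow) * inv(k_2, modulo)
    return total % modulo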


if __name__ == '__main__':
    import sys
    n = int(sys.argv[1])  # pass a plain integer, e.g. 1000000
    print(solve(n))

Tests

import pytest
from problem import solve, inv

_test_inv = (
        (5, 7, 3),
        (3, 7, 5),
        (15, 26, 7),
        (7, 26, 15),
)

@pytest.mark.parametrize('n,mod,expect', _test_inv)
def test_inv(n, mod, expect):
    assert expect == inv(n, mod)

_test_solve = (
        (4, 51160),
        (10, 397585319),
        (100, 113956045),
        (1000, 373423977),
        (10000, 866508622),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)