Problem 216

completed February 7, 2021

Comments

This problem finally taught me not to be afraid of modular arithmetic. With this I am now looking forward to other problems where I can dig in and finally start to really grok modular arithmetic. I already had a method saved for finding a modular inverse, which is what makes modular division possible. What I needed in order to solve this problem was an inverse for modular exponentiation; specifically, a modular square root.

I did a bit of googling and came across the Tonelli-Shanks algorithm for finding the roots \(r\) of \(r^2 \equiv n \pmod{p}\). The algorithm was pretty long and definitely complex, not something I would have been able to deduce on my own. I did, however, need to learn about quadratic residues in order to be able to implement it correctly. I feel okay about using an algorithm I found on Wikipedia because, firstly, I never would have figured it out myself, and secondly, since it’s been named after a couple of humans, I assume the idea is to share it widely.

I also used a couple of lemmas from the webpage “Prime sieving for the polynomial \(f(n)=2n^2-1\).” These showed that if \(p \mid f(n)\) then also \(p \mid f(n+p)\), and that if \(p \mid f(n)\) then also \(p \mid f(-n)\). From this knowledge and the algorithm above, I was able to create a sieve to find the final answer.

I wish it ran a bit faster, but it completes in about 2 min 13 sec on PyPy3.

Code

def modinv(n, p):
    """Return the multiplicative inverse of ``n`` modulo ``p``.

    Uses the extended Euclidean algorithm, tracking only the Bezout
    coefficient that belongs to ``n``.

    Args:
        n: Integer to invert.  Any sign or magnitude is accepted; it is
            reduced modulo ``p`` before the algorithm starts.
        p: Modulus, a positive integer.

    Returns:
        The integer ``x`` with ``0 <= x < p`` and ``n * x % p == 1 % p``.

    Raises:
        ValueError: If ``p`` is not positive, or if ``n`` and ``p`` share a
            common factor so that no inverse exists.
    """
    if p <= 0:
        raise ValueError('modulus must be positive, got {}'.format(p))

    # Reduce n first so every remainder below is non-negative.  Without
    # this a negative n finishes with r == -1 and t has the wrong sign.
    t, newt = 0, 1
    r, newr = p, n % p

    # Invariant: n * t == r (mod p) and n * newt == newr (mod p).
    while newr:
        q = r // newr
        t, newt = newt, t - q * newt
        r, newr = newr, r - q * newr

    # r now holds gcd(n, p); n * t == r (mod p) is an inverse only if r == 1.
    if r != 1:
        raise ValueError('{} has no inverse modulo {}'.format(n, p))
    return t % p

def modroot(n, p):
    """Return both square roots of ``n`` modulo the prime ``p``.

    Implements the Tonelli-Shanks algorithm.

    Args:
        n: Integer whose modular square root is wanted.
        p: Prime modulus.  Primality is assumed, not verified.

    Returns:
        A tuple ``(r, -r % p)`` where ``r * r % p == n % p``.  When ``n`` is
        a multiple of ``p`` both entries are 0.

    Raises:
        ValueError: If ``p < 2``, if no quadratic non-residue exists below
            ``p`` (so ``p`` cannot be prime), or if ``n`` has no square root
            modulo ``p``.
    """
    if p < 2:
        raise ValueError('modulus must be a prime >= 2, got {}'.format(p))

    # Write p - 1 = q * 2**s with q odd.
    q, s = p - 1, 0
    while q % 2 == 0:
        q //= 2
        s += 1

    # Find z, the smallest quadratic non-residue: Euler's criterion gives
    # z**((p - 1) / 2) == -1 (mod p) for such a z.
    z = 2
    while pow(z, (p - 1) // 2, p) != p - 1:
        z += 1
        if z >= p:
            # Every odd prime has a non-residue below it; bail out instead
            # of searching forever on a bad modulus.
            raise ValueError('{} is not prime'.format(p))

    m = s
    c = pow(z, q, p)
    t = pow(n, q, p)
    r = pow(n, (q + 1) // 2, p)

    # Invariant: r * r == t * n (mod p).  Once t reaches 1 (or 0 when n is a
    # multiple of p), r is the wanted root.
    while t != 0 and t != 1:
        # Least i >= 1 with t**(2**i) == 1 (mod p), by repeated squaring.
        i = 0
        t_pow = t
        while t_pow != 1:
            i += 1
            if i >= m:
                # Raised explicitly (not asserted) so the check survives
                # ``python -O`` and cannot turn into an endless loop.
                raise ValueError(
                    '{} is not a quadratic residue modulo {}'.format(n, p))
            t_pow = t_pow * t_pow % p

        b = pow(c, 1 << (m - i - 1), p)
        m = i
        c = b * b % p
        t = t * c % p
        r = r * b % p

    return r, -r % p

def solve_for_n(p):
    """Return the pair of residues t modulo ``p`` with 2*t*t == 1 (mod p).

    The inverse of 2 modulo ``p`` is computed first, and its modular square
    roots are then returned exactly as ``modroot`` produces them.
    """
    half = modinv(2, p)
    roots = modroot(half, p)
    return roots

def isprime(n):
    """Return True if the integer ``n`` is prime, False otherwise.

    Uses trial division by 2 and then by every odd number up to the integer
    square root of ``n``.
    """
    if n < 2:
        return False
    if n < 4:
        return True
    # Even numbers must be rejected here: the loop below only tries odd
    # divisors, so without this check 4, 6, 8, ... would be reported prime.
    if n % 2 == 0:
        return False

    # Start from the float estimate of sqrt(n) and correct it with exact
    # integer arithmetic, since the float can be off by one for large n.
    limit = int(n ** 0.5)
    while limit * limit > n:
        limit -= 1
    while (limit + 1) * (limit + 1) <= n:
        limit += 1

    for d in range(3, limit + 1, 2):
        if n % d == 0:
            return False
    return True

def solve(n):
    """Count the t in 2..n that survive a sieve on f(t) = 2*t*t - 1.

    Every t from 2 to ``n`` starts out flagged.  For each prime p with
    p % 8 in (1, 7), the two residues t satisfying 2*t*t == 1 (mod p) are
    obtained from ``solve_for_n``; every index up to ``n`` congruent to
    +/- such a residue modulo p is unflagged, except an index t for which
    f(t) equals p itself.  The number of indices still flagged is returned.
    """
    flags = [False, False] + [True] * (n - 1)

    for p in range(3, 2 * n**2, 2):
        # Only primes congruent to +/-1 modulo 8 are considered.
        if p % 8 not in (1, 7):
            continue
        if not isprime(p):
            continue

        root_a, root_b = solve_for_n(p)
        # Stop at the first prime whose two roots both lie beyond n.
        if min(root_a, root_b) > n:
            break

        for root in (root_a, root_b):
            if root > n or root == 2:
                continue
            # Here f(root) is p itself, so this index must stay flagged.
            if 2 * root**2 - 1 == p:
                continue

            # Clear root, root + p, root + 2p, ...
            for idx in range(root, n + 1, p):
                flags[idx] = False

            # Clear the mirrored class p - root; when the mirror is the
            # smaller root, start one period later and leave it untouched.
            mirror = p - root
            start = mirror + p if mirror < root else mirror
            for idx in range(start, n + 1, p):
                flags[idx] = False

    return sum(flags)

if __name__ == '__main__':
    # Script entry point: take the limit from the first command-line
    # argument and print solve() of it.
    import sys
    # NOTE(review): eval() runs whatever Python expression is passed on the
    # command line (presumably so limits like 5*10**7 can be typed) and
    # evaluates it inside this module's namespace.  Fine for personal use;
    # switch to int() if the argument could ever come from an untrusted
    # source.
    n = eval(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve, modroot, solve_for_n

# Each case is (modulus, pair that solve_for_n is expected to return).
_test_solve_for_n = ((7, (2, 5)), (23, (9, 14)))


@pytest.mark.parametrize('n,expect', _test_solve_for_n)
def test_solve_for_n(n, expect):
    """solve_for_n(n) must return exactly the expected pair."""
    actual = solve_for_n(n)
    assert expect == actual

# Each case is (n, p, pair that modroot(n, p) is expected to return).
_test_modroot = ((5, 41, (28, 13)),)


@pytest.mark.parametrize('n,p,expect', _test_modroot)
def test_modroot(n, p, expect):
    """modroot(n, p) must return exactly the expected pair."""
    actual = modroot(n, p)
    assert expect == actual

# Each case is (n, value that solve(n) is expected to return): every n from
# 1 to 50, followed by the powers of ten up to 1,000,000.
_test_solve = (
        (1, 0),
        (2, 1),
        (3, 2),
        (4, 3),
        (5, 3),
        (6, 4),
        (7, 5),
        (8, 6),
        (9, 6),
        (10, 7),
        (11, 8),
        (12, 8),
        (13, 9),
        (14, 9),
        (15, 10),
        (16, 10),
        (17, 11),
        (18, 12),
        (19, 12),
        (20, 12),
        (21, 13),
        (22, 14),
        (23, 14),
        (24, 15),
        (25, 16),
        (26, 16),
        (27, 16),
        (28, 17),
        (29, 17),
        (30, 17),
        (31, 17),
        (32, 17),
        (33, 17),
        (34, 18),
        (35, 18),
        (36, 19),
        (37, 19),
        (38, 20),
        (39, 21),
        (40, 21),
        (41, 22),
        (42, 23),
        (43, 24),
        (44, 24),
        (45, 25),
        (46, 26),
        (47, 26),
        (48, 26),
        (49, 27),
        (50, 28),
        (100, 45),
        (1000, 303),
        (10000, 2202),
        (100000, 17185),
        (1000000, 141444),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """solve(n) must return exactly the tabulated count."""
    actual = solve(n)
    assert expect == actual