This problem finally taught me not to be afraid of modular arithmetic. With this I am now looking forward to other problems where I can dig in and finally start to really grok modular arithmetic. I already had a method saved for finding an inverse for modular division. What I needed in order to solve this problem was an inverse for modular exponentiation.
I did a bit of googling and came across the Tonelli-Shanks algorithm for finding the roots of \(r^2 \equiv n \pmod{p}\). The algorithm was pretty long and definitely complex, not something I would have been able to deduce on my own. I did, however, need to learn about quadratic residues in order to implement it correctly. I feel okay about using an algorithm I found on Wikipedia because, first, I never would have figured it out myself, and second, since it’s been named after a couple of humans, I assume the idea is to share it widely.
I also used a couple of lemmas from the webpage “Prime sieving for the polynomial \(f(n)=2n^2-1\).” These showed that if \(p \mid f(n)\) then also \(p \mid f(n+p)\), and if \(p \mid f(n)\) then also \(p \mid f(-n)\). From this knowledge and the algorithm above, I was able to create a sieve to find the final answer.
I wish it ran a bit faster, but it completes in about 2 min 13 sec on PyPy3.
def modinv(n, p):
    """Return the multiplicative inverse of ``n`` modulo ``p``.

    Uses the extended Euclidean algorithm, tracking only the Bezout
    coefficient that belongs to ``n``.

    Args:
        n: The integer to invert. It is reduced modulo ``p`` first, so
            negative values and values >= ``p`` are accepted.
        p: The modulus, a positive integer. It does not have to be prime,
            but ``n`` and ``p`` must be coprime.

    Returns:
        The integer ``x`` in ``range(p)`` with ``(n * x) % p == 1 % p``.

    Raises:
        ValueError: If ``n`` has no inverse modulo ``p``, i.e. the two
            share a common factor.
    """
    # Reduce first so negative inputs yield the correct residue class.
    n %= p
    t, newt = 0, 1
    r, newr = p, n
    while newr:
        q = r // newr
        t, newt = newt, t - q * newt
        r, newr = newr, r - q * newr
    # r now holds gcd(n, p); an inverse exists only when that gcd is 1.
    if r != 1:
        raise ValueError(
            '{} has no inverse modulo {} (gcd is {})'.format(n, p, r)
        )
    return t % p
def modroot(n, p):
    """Return both square roots of ``n`` modulo the prime ``p``.

    Implements the Tonelli-Shanks algorithm for solving
    ``r**2 == n (mod p)``.

    Args:
        n: The value whose square root is wanted; reduced modulo ``p``.
        p: A prime modulus. Primality is not verified; a composite ``p``
            may raise ``ValueError`` or fail to terminate while searching
            for a non-residue.

    Returns:
        A tuple ``(r, -r % p)`` holding the two roots. The first element
        is the root produced by the algorithm when it uses the smallest
        quadratic non-residue >= 2 as its helper value.

    Raises:
        ValueError: If ``p < 2``, or if ``n`` is a quadratic non-residue
            modulo ``p`` and therefore has no square root.
    """
    if p < 2:
        raise ValueError('modulus must be a prime, got {}'.format(p))
    n %= p
    if n == 0:
        return 0, 0
    if p == 2:
        # The only non-zero residue modulo 2 is 1, which is its own root.
        return 1, 1

    half = (p - 1) // 2
    # Euler's criterion: n is a square modulo p iff n**((p-1)/2) == 1.
    # Checked explicitly (not via assert) so it still runs under -O.
    if pow(n, half, p) != 1:
        raise ValueError(
            '{} is not a quadratic residue modulo {}'.format(n, p)
        )

    # Write p - 1 = q * 2**s with q odd.
    q, s = p - 1, 0
    while q % 2 == 0:
        q //= 2
        s += 1

    # Find z, the smallest quadratic non-residue modulo p.
    z = 2
    while pow(z, half, p) != p - 1:
        z += 1

    m = s
    c = pow(z, q, p)
    t = pow(n, q, p)
    r = pow(n, (q + 1) // 2, p)
    while t != 1:
        # Find the least i in 1..m-1 with t**(2**i) == 1 by repeated
        # squaring.
        i, t_pow = 1, t * t % p
        while t_pow != 1 and i < m:
            t_pow = t_pow * t_pow % p
            i += 1
        if i >= m:
            # Cannot happen for prime p once Euler's criterion passed.
            raise ValueError(
                'no square root of {} found modulo {}; is the modulus '
                'prime?'.format(n, p)
            )
        b = pow(c, 1 << (m - i - 1), p)
        m = i
        c = b * b % p
        t = t * c % p
        r = r * b % p
    return r, -r % p
def solve_for_n(p):
    """Return the pair of residues t modulo ``p`` with 2*t**2 == 1 (mod p).

    The congruence is rewritten as t**2 == 2**-1 (mod p): the inverse of 2
    comes from ``modinv`` and its two square roots from ``modroot``.
    """
    inverse_of_two = modinv(2, p)
    return modroot(inverse_of_two, p)
def isprime(n):
    """Return True if ``n`` is prime, using trial division.

    Args:
        n: The integer to test.

    Returns:
        ``True`` when ``n`` is a prime number, ``False`` otherwise
        (including every ``n < 2``).
    """
    if n < 2:
        return False
    if n < 4:
        return True
    # Only odd divisors are tried below, so even numbers must be rejected
    # here; otherwise 4, 8, 10, ... would be reported as prime.
    if n % 2 == 0:
        return False
    # Start from the float square root and correct it in exact integer
    # arithmetic, so rounding for large n cannot drop the last divisor.
    limit = int(n**0.5)
    while limit * limit > n:
        limit -= 1
    while (limit + 1) * (limit + 1) <= n:
        limit += 1
    return all(n % d for d in range(3, limit + 1, 2))
def solve(n):
    """Count the t in 2..n that survive a sieve on f(t) = 2*t**2 - 1.

    An index t is struck out when some prime p examined below divides
    f(t) without being equal to f(t). Per the accompanying write-up, the
    survivors are the t for which f(t) is prime.

    Only primes p with p % 8 in (1, 7) are examined. For each one, the two
    roots of 2*t**2 == 1 (mod p) are obtained from ``solve_for_n``, and
    every index congruent to +root or -root modulo p is struck out.
    """
    alive = [False, False] + [True] * (n - 1)
    for p in range(3, 2 * n**2, 2):
        if p % 8 not in (1, 7):
            continue
        if not isprime(p):
            continue
        roots = solve_for_n(p)
        # The two roots add up to p, so both exceed n only once p > 2*n.
        # By then every prime that can be the smallest factor of a
        # composite f(t) with t <= n (which is below sqrt(2)*n) is done.
        if min(roots) > n:
            break
        for root in roots:
            # Skip roots outside the table and the root for which f(root)
            # is the prime p itself (root == 2 is that case for p == 7).
            if root > n or root == 2 or 2 * root**2 - 1 == p:
                continue
            alive[root] = False
            # Smallest index above root in the -root residue class.
            mirror = p - root
            if mirror < root:
                mirror += p
            for first in (p + root, mirror):
                for index in range(first, n + 1, p):
                    alive[index] = False
    return sum(alive)
if __name__ == '__main__':
    import sys

    # Fail with a readable message instead of an IndexError traceback when
    # the limit argument is missing.
    if len(sys.argv) < 2:
        sys.exit('usage: {} N'.format(sys.argv[0]))
    # NOTE(review): eval() lets the limit be written as an expression such
    # as 5*10**7, but it executes arbitrary code from the command line.
    # Only run this script with arguments you trust.
    n = eval(sys.argv[1])
    print(solve(n))
import pytest
from problem import solve, modroot, solve_for_n
# Each case is (modulus passed to solve_for_n, expected pair of roots).
_test_solve_for_n = (
    (7, (2, 5)),
    (23, (9, 14)),
)


@pytest.mark.parametrize('n,expect', _test_solve_for_n)
def test_solve_for_n(n, expect):
    """solve_for_n yields the expected root pair for each modulus."""
    actual = solve_for_n(n)
    assert actual == expect
# Each case is (n, p, expected pair returned by modroot(n, p)).
_test_modroot = (
    (5, 41, (28, 13)),
)


@pytest.mark.parametrize('n,p,expect', _test_modroot)
def test_modroot(n, p, expect):
    """modroot returns the expected pair of square roots of n modulo p."""
    actual = modroot(n, p)
    assert actual == expect
# Expected results for solve(): each entry is (n, expected solve(n)).
# Covers every n from 1 to 50, then powers of ten from 100 up to 10**6.
_test_solve = (
(1, 0),
(2, 1),
(3, 2),
(4, 3),
(5, 3),
(6, 4),
(7, 5),
(8, 6),
(9, 6),
(10, 7),
(11, 8),
(12, 8),
(13, 9),
(14, 9),
(15, 10),
(16, 10),
(17, 11),
(18, 12),
(19, 12),
(20, 12),
(21, 13),
(22, 14),
(23, 14),
(24, 15),
(25, 16),
(26, 16),
(27, 16),
(28, 17),
(29, 17),
(30, 17),
(31, 17),
(32, 17),
(33, 17),
(34, 18),
(35, 18),
(36, 19),
(37, 19),
(38, 20),
(39, 21),
(40, 21),
(41, 22),
(42, 23),
(43, 24),
(44, 24),
(45, 25),
(46, 26),
(47, 26),
(48, 26),
(49, 27),
(50, 28),
(100, 45),
(1000, 303),
(10000, 2202),
(100000, 17185),
(1000000, 141444),
)
@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """solve(n) matches the tabulated expected count for each limit n."""
    actual = solve(n)
    assert actual == expect