This problem was pretty fun! I was very grateful for the problem’s name. I knew it was a hint, so once I got a working brute force algorithm, I looked it up. Lo and behold, it was all about the Chinese Remainder Theorem. I feel like I still struggle to really grok problems that involve modular arithmetic, like this one. But one article led me to another and another, and eventually I found myself implementing the extended Euclidean algorithm for the greatest common divisor, combining that with Bézout’s identity, and getting myself the answer.
It feels like there’s a lot more code here than needed. Most of it went toward calculating Euler’s totient: generating primes, a memoization wrapper, and the actual totient function.
Runs in about 5 seconds with PyPy3.
# Largest trial divisor needed: the integer square root of the hard-coded
# upper bound 1005000 (the same value solve() uses as its default ``top``).
_stop = int(1005000 ** 0.5)

# _sieve[i] is True while i is still a prime candidate. The list covers exactly
# the indices 0.._stop; a longer list would leave the extra slot un-sieved and
# let a composite (1003 == 17 * 59) slip into _primes.
_sieve = [False, False] + [True] * (_stop - 1)

# All primes <= _stop, in increasing order (sieve of Eratosthenes).
_primes = []
for p, is_pr in enumerate(_sieve):
    if not is_pr:
        continue
    _primes.append(p)
    # Multiples below p * p were already cleared by a smaller prime factor.
    for multiple in range(p * p, _stop + 1, p):
        _sieve[multiple] = False
def cache(fn):
    """Memoize ``fn`` on its positional arguments.

    The returned wrapper stores every result in a dict keyed by the argument
    tuple, so arguments must be hashable. Keyword arguments are not supported.
    The cache is unbounded and lives as long as the wrapper does.

    The wrapper keeps ``fn``'s name and docstring via ``functools.wraps``.
    """
    from functools import wraps

    _cache = {}

    @wraps(fn)
    def wrap(*args):
        if args not in _cache:
            _cache[args] = fn(*args)
        return _cache[args]

    return wrap
@cache
def totient(n):
    """Return Euler's totient of ``n`` as an exact integer.

    Strips the smallest prime factor ``p`` (with multiplicity ``e``) from
    ``n`` by trial division over the module-level ``_primes`` list, then
    recurses on the cofactor using phi(p**e * k) == phi(p**e) * phi(k) for
    gcd(p, k) == 1.

    If no listed prime with p * p <= n divides ``n``, ``n`` is treated as
    prime and ``n - 1`` is returned. That conclusion is only valid while
    ``_primes`` reaches at least sqrt(n).
    """
    if n == 1:
        return 1
    for p in _primes:
        # Integer bound check; avoids a float square root.
        if p * p > n:
            break
        if n % p == 0:
            n //= p
            e = 1
            while n % p == 0:
                n //= p
                e += 1
            # phi(p**e) == p**(e - 1) * (p - 1). Staying in integers avoids
            # the float rounding of p**e * (1 - 1/p), which a caller's int()
            # could truncate to the wrong value.
            return p ** (e - 1) * (p - 1) * totient(n)
    return n - 1
def g(a, n, b, m):
    """Combine the congruences x == a (mod n) and x == b (mod m).

    Uses the Bezout coefficients returned by ``gcd(n, m)``. When ``a`` and
    ``b`` leave different remainders modulo gcd(n, m) the system has no
    solution and 0 is returned. Otherwise the combined value is reduced
    modulo n * m and divided by the gcd.
    """
    divisor, coef_n, coef_m = gcd(n, m)
    if a % divisor == b % divisor:
        # coef_n * n + coef_m * m == divisor, so this sum is a multiple of
        # the divisor that matches a modulo n and b modulo m after division.
        combined = a * coef_m * m + b * coef_n * n
        return combined % (n * m) // divisor
    return 0
def gcd(a, b):
    """Extended Euclidean algorithm.

    Returns a tuple ``(d, x, y)`` where ``d`` is the greatest common divisor
    of ``a`` and ``b`` and ``x``, ``y`` are Bezout coefficients satisfying
    ``a * x + b * y == d``.
    """
    prev_rem, rem = a, b
    prev_x, x = 1, 0
    prev_y, y = 0, 1
    while rem != 0:
        q, next_rem = divmod(prev_rem, rem)
        prev_rem, rem = rem, next_rem
        # Apply the same update to both coefficient sequences.
        prev_x, x = x, prev_x - q * x
        prev_y, y = y, prev_y - q * y
    return prev_rem, prev_x, prev_y
def f(n, m):
    """Return g(totient(n), n, totient(m), m).

    Both totients are converted with ``int`` before being passed to ``g`` as
    the residues for the moduli ``n`` and ``m``.
    """
    residue_n = int(totient(n))
    residue_m = int(totient(m))
    return g(residue_n, n, residue_m, m)
def solve(bottom=1000000, top=1005000):
    """Sum f(n, m) over every pair with bottom <= n < m < top.

    Returns the total as an int.
    """
    pair_values = (
        f(n, m)
        for n in range(bottom, top)
        for m in range(n + 1, top)
    )
    return int(sum(pair_values))
if __name__ == '__main__':
    import sys

    # Usage: script [bottom [top]]. Omitted bounds fall back to solve()'s
    # defaults instead of failing with an IndexError.
    # NOTE(review): eval() runs arbitrary expressions taken from the command
    # line. It allows inputs such as 10**6 but is unsafe for untrusted input;
    # switch to int() if expression support is not needed.
    bounds = [eval(arg) for arg in sys.argv[1:3]]
    print(solve(*bounds))
import pytest
from problem import g, f
# Cases for g: ((a, n, b, m), expected result of g(a, n, b, m)).
_test_g = (
    ((2, 4, 4, 6), 10),
    ((3, 4, 4, 6), 0),  # 3 and 4 differ modulo gcd(4, 6) == 2, so g yields 0
    ((1, 3, 3, 4), 7),
    ((1, 5, 5, 9), 41),
)


@pytest.mark.parametrize('n,expect', _test_g)
def test_g(n, expect):
    """g, called with the unpacked argument tuple, returns the expected value."""
    actual = g(*n)
    assert expect == actual
# Cases for f: ((n, m), expected result of f(n, m)).
_test_f = (
    ((1, 2), 1),
    ((2, 3), 5),
    ((2, 5), 9),
    ((4, 6), 2),
    ((5, 10), 4),
    ((10, 15), 0),
)


@pytest.mark.parametrize('n,expect', _test_f)
def test_f(n, expect):
    """f, called with the unpacked (n, m) pair, returns the expected value."""
    actual = f(*n)
    assert expect == actual