I never really grokked this one, but I did end up getting the right answer. I was hoping to just be able to do some algebra on the equation and get something more simplified, but I never did. Instead, I printed out the values of M(n) for n from 1 to 100 and took a close look at them. The first thing that stood out was that for any number with only a single prime in its prime factorization (that is, any prime power), the answer was 1. The next thing I noticed was that n and the answer always shared a divisor. I started finding more and more patterns based on this and tried out more and more things. I knew I had it when I’d eliminated an entire for loop.
However, it took me an extra day to get the answer because, at the very beginning of working on this problem, I’d made the mistake of thinking that M(1) = 1. Well, it doesn’t: M(1) = 0. So I was off by one.
I also learned from this problem that the sympy sieve for generating primes is way, way slower than the primes iterator that I wrote.
# Shared, growing cache of primes. Invariant: sorted, duplicate-free, and
# complete up to its last element. Only primes_iterator appends to it.
_primes = [2, 3, 5, 7]


def _divisible_by_cached_prime(candidate):
    """Return True if a cached prime q with q*q <= candidate divides candidate.

    Intended for odd candidates lying between the largest cached prime and
    the next prime: every possible prime divisor of such a candidate is
    already in the cache, so a False result means the candidate is prime.
    """
    for q in _primes:
        if q * q > candidate:
            # Integer comparison instead of a float square root, so there
            # is no rounding error for large candidates.
            return False
        if candidate % q == 0:
            return True
    return False


def primes_iterator(top):
    """Yield every prime p with p <= top, in increasing order.

    Primes are served from the module-level cache ``_primes`` first; once
    the cache is exhausted it is extended one prime at a time by trial
    division against the primes already cached.

    Unlike a plain ``for p in _primes`` walk, this never yields a prime
    larger than ``top`` (so ``top < 2`` yields nothing), and the cache is
    read by index so that several generators may be consumed in an
    interleaved fashion without appending the same prime twice.

    Args:
        top: inclusive upper bound for the yielded primes.

    Yields:
        The primes 2, 3, 5, ... up to and including ``top`` if it is prime.
    """
    position = 0
    while True:
        if position < len(_primes):
            prime = _primes[position]
        else:
            # Cache exhausted: look for the next prime after the largest
            # cached one. The cache ends in an odd prime, so stepping by 2
            # visits odd numbers only.
            prime = _primes[-1] + 2
            while prime <= top and _divisible_by_cached_prime(prime):
                prime += 2
            if prime <= top:
                _primes.append(prime)
        if prime > top:
            return
        yield prime
        position += 1
# Memo of already computed factorizations: n -> {prime: exponent}.
# 1 has the empty factorization.
_prime_factorization = {1: {}}


def prime_factorization(n):
    """Return the prime factorization of n as a ``{prime: exponent}`` dict.

    Results are memoized in the module-level ``_prime_factorization`` dict,
    and the cached dict object itself is returned, so callers should treat
    it as read-only.

    The smallest prime factor is stripped by trial division; the remaining
    cofactor is factored recursively, which reuses the memo when the
    cofactor is already known and computes it otherwise (instead of
    assuming it was cached by an earlier call).

    Args:
        n: positive integer to factor.

    Returns:
        Dict mapping each prime factor of n to its exponent.

    Raises:
        ValueError: if n is smaller than 1.
    """
    if n in _prime_factorization:
        return _prime_factorization[n]
    if n < 1:
        raise ValueError('n must be a positive integer, got {!r}'.format(n))
    for p in primes_iterator(n):
        if p * p > n:
            # No prime up to sqrt(n) divides n, so n itself is prime.
            fact = {n: 1}
            _prime_factorization[n] = fact
            return fact
        exp = 0
        cofactor = n
        while cofactor % p == 0:
            exp += 1
            cofactor //= p
        if exp > 0:
            fact = {p: exp}
            # The cofactor is smaller than n and free of p; recursing fills
            # the memo on demand rather than raising KeyError.
            fact.update(prime_factorization(cofactor))
            _prime_factorization[n] = fact
            return fact
def solve(top):
    """Return the sum of M(n) for every n from 1 to top inclusive.

    For each n the value added is the largest k found below n whose square
    leaves remainder k modulo n (with 0 for n = 1).

    Shortcuts taken from observed patterns:
      * n = 1 contributes 0;
      * a prime power contributes 1;
      * twice a prime power q contributes q + 1 (entries for powers of two
        written this way are overwritten by the prime-power rule on the
        next pass of the inner loop).

    Every other n is searched downwards over multiples of the prime power
    belonging to its largest prime factor, testing each multiple k and its
    successor k + 1.

    Side effect: seeds the shared ``_prime_factorization`` cache with every
    prime power not exceeding ``top``.
    """
    known = {1: 0}
    for prime in primes_iterator(top):
        power, exponent = prime, 1
        while power <= top:
            known[power] = 1
            known[2 * power] = power + 1
            _prime_factorization[power] = {prime: exponent}
            power *= prime
            exponent += 1

    running_sum = 0
    for n in range(1, top + 1):
        shortcut = known.get(n)
        if shortcut is not None:
            running_sum += shortcut
            continue

        # Largest prime factor of n together with its exponent.
        base, exponent = max(prime_factorization(n).items())
        step = base ** exponent
        for multiple in range(n - step, step, -step):
            square = (multiple * multiple) % n
            if square == multiple:
                running_sum += multiple
                break
            # square + 2*multiple + 1 is (multiple + 1) squared.
            if (square + 2 * multiple + 1) % n == multiple + 1:
                running_sum += multiple + 1
                break
    return running_sum
if __name__ == '__main__':
    import sys

    # The first command-line argument is the inclusive upper bound passed
    # to solve(); fail with a readable message instead of a traceback when
    # it is missing or not an integer.
    if len(sys.argv) < 2:
        sys.exit('usage: {} TOP'.format(sys.argv[0]))
    try:
        limit = int(sys.argv[1])
    except ValueError:
        sys.exit('TOP must be an integer, got {!r}'.format(sys.argv[1]))
    print(solve(limit))
import pytest
from problem import solve
# Regression table for solve(): each entry is (n, expect), where expect is
# the value solve(n) must return. Fed to pytest.mark.parametrize for
# test_solve. Small n are listed densely; larger n are sampled more sparsely.
_test_solve = (
(1, 0),
(2, 1),
(3, 2),
(4, 3),
(5, 4),
(6, 8),
(7, 9),
(8, 10),
(9, 11),
(10, 17),
(11, 18),
(12, 27),
(13, 28),
(14, 36),
(15, 46),
(16, 47),
(17, 48),
(18, 58),
(19, 59),
(20, 75),
(21, 90),
(22, 102),
(24, 119),
(26, 134),
(28, 156),
(30, 182),
(33, 206),
(34, 224),
(35, 245),
(36, 273),
(38, 294),
(39, 321),
(40, 346),
(42, 383),
(44, 417),
(45, 453),
(46, 477),
(48, 511),
(50, 538),
(51, 572),
(52, 612),
(54, 641),
(55, 686),
(56, 735),
(57, 774),
(58, 804),
(60, 850),
(62, 883),
(63, 919),
(65, 960),
(66, 1015),
(68, 1068),
(69, 1114),
(70, 1170),
(72, 1235),
(74, 1274),
(75, 1325),
(76, 1382),
(77, 1438),
(78, 1504),
(80, 1570),
(82, 1613),
(84, 1678),
(85, 1729),
(86, 1773),
(87, 1831),
(88, 1887),
(90, 1969),
(91, 2047),
(92, 2116),
(93, 2179),
(94, 2227),
(95, 2303),
(96, 2367),
(98, 2418),
(99, 2473),
(100, 2549),
(200, 10800),
(300, 25579),
(400, 46670),
(500, 74840),
(600, 109104),
(700, 149868),
(800, 198742),
(900, 252296),
(1000, 314034),
(10000, 34981569),
(100000, 3717852515),
)
@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that solve(n) returns the precomputed total for this n."""
    actual = solve(n)
    assert actual == expect