Problem 407

completed October 24, 2020

Comments

I never really grokked this one, but I did end up getting the right answer. I was hoping to do some algebra on the equation and get something simpler, but I never did. Instead, I printed out M(n) for n from 1 to 100 and took a close look at the values. The first thing that stood out was that for any n with only a single prime in its prime factorization (a prime power), the answer was 1. The next thing I noticed was that, for every other n, n and the answer shared a divisor greater than 1. I kept finding patterns based on this and trying things out. I knew I had it when I’d eliminated an entire for loop.
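A minimal sketch of that exploration, assuming the problem's definition that M(n) is the largest a < n with a^2 ≡ a (mod n). The names brute_force_m and distinct_prime_count are hypothetical helpers written for this illustration, not part of the solution below, and this is not necessarily the script that was originally used.

```python
def brute_force_m(n):
    """Largest a with 0 <= a < n and a*a % n == a (a = 0 always qualifies)."""
    return max(a for a in range(n) if a * a % n == a)


def distinct_prime_count(n):
    """Number of distinct primes dividing n, by plain trial division."""
    count, d = 0, 2
    while d * d <= n:
        if n % d == 0:
            count += 1
            while n % d == 0:
                n //= d
        d += 1
    # Whatever is left above 1 is a single remaining prime factor.
    return count + (1 if n > 1 else 0)


if __name__ == "__main__":
    # Print n, M(n), and a marker for prime powers to make patterns visible.
    for n in range(1, 101):
        tag = "prime power" if distinct_prime_count(n) == 1 else ""
        print(n, brute_force_m(n), tag)
```

The brute force is quadratic overall, so it is only useful for small n, but it is a handy reference for checking any faster version.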

However, it took me an extra day to get the answer because, at the very beginning of working on this problem, I’d made the mistake of thinking that M(1) = 1. It isn’t: M(1) = 0, so my total was off by one.

I also learned from this problem that the sympy sieve for generating primes is much slower than the primes iterator that I wrote.
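A rough way to check a claim like this, sketched under some assumptions: trial_division_primes is a simplified stand-in for the iterator in the Code section, time_generator is a hypothetical helper, and sympy is treated as an optional dependency. Both sympy's sieve and a cached iterator get faster after the first run, so the numbers depend on how the timing is set up and will not reproduce the original measurement.

```python
import timeit


def trial_division_primes(top):
    """Yield primes <= top by odd trial division (stand-in for the post's iterator)."""
    if top >= 2:
        yield 2
    for p in range(3, top + 1, 2):
        if all(p % d for d in range(3, int(p ** 0.5) + 1, 2)):
            yield p


def time_generator(make_primes, top, repeat=3):
    """Best-of-`repeat` wall time, in seconds, to exhaust make_primes(top)."""
    return min(timeit.repeat(lambda: list(make_primes(top)), number=1, repeat=repeat))


if __name__ == "__main__":
    top = 100_000
    print("trial division:", time_generator(trial_division_primes, top))
    try:
        from sympy import sieve  # optional; skipped if sympy is not installed
    except ImportError:
        print("sympy not installed; skipping")
    else:
        # sieve.primerange(a, b) yields the primes in the half-open range [a, b).
        print("sympy sieve:", time_generator(lambda t: sieve.primerange(2, t + 1), top))
```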

Code

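_primes = [2, 3, 5, 7]  # cache of known primes, shared by every call
def primes_iterator(top):
    """Yield the primes <= top in increasing order, growing the cache as needed."""
    # Replay the cached primes first, stopping once they exceed top.
    for p in _primes:
        if p > top:
            return
        yield p

    # Then extend the cache by trial division on odd candidates.
    while p + 2 <= top:
        p += 2
        for n in range(3, int(p ** 0.5) + 1, 2):
            if p % n == 0:
                break
        else:
            _primes.append(p)
            yield p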

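_prime_factorization = {1: {}}  # cache: n -> {prime: exponent}
def prime_factorization(n):
    """Return the prime factorization of n as a {prime: exponent} dict (cached)."""
    orig_n = n
    if orig_n in _prime_factorization:
        return _prime_factorization[orig_n]

    fact = {}
    for p in primes_iterator(n):
        exp = 0
        while n % p == 0:
            exp += 1
            n //= p
        if exp > 0:
            # p is the smallest prime factor; the cofactor n is factored
            # recursively so this also works when it is not cached yet.
            fact[p] = exp
            fact.update(prime_factorization(n))
            _prime_factorization[orig_n] = fact
            return fact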

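def solve(top):
    """Return M(1) + M(2) + ... + M(top)."""
    total = 0
    calculated = {1: 0}  # M(1) = 0, not 1
    for p in primes_iterator(top):
        num = p
        exp = 1
        while num <= top:
            # Observed pattern: M(p^k) = 1 for every prime power.
            calculated[num] = 1
            # Observed pattern: M(2 * p^k) = p^k + 1 for odd p.
            if p > 2 and num * 2 <= top:
                calculated[num * 2] = num + 1
            _prime_factorization[num] = {p: exp}
            num *= p
            exp += 1

    for n in range(1, top + 1):
        if n in calculated:
            total += calculated[n]
        else:
            # max() over (prime, exponent) pairs picks the largest prime factor.
            p, exp = max(prime_factorization(n).items())
            prime_power = p ** exp
            # Walk down the multiples k of that prime power; the first k or
            # k + 1 that squares to itself mod n is M(n).
            for k in range(n - prime_power, prime_power, -prime_power):
                k_2 = (k ** 2) % n
                if k_2 == k:
                    total += k
                    break
                elif (k_2 + 2 * k + 1) % n == k + 1:
                    total += k + 1
                    break
    return total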

if __name__ == '__main__':
    import sys
    n = int(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve

_test_solve = (
        (1, 0),
        (2, 1),
        (3, 2),
        (4, 3),
        (5, 4),
        (6, 8),
        (7, 9),
        (8, 10),
        (9, 11),
        (10, 17),
        (11, 18),
        (12, 27),
        (13, 28),
        (14, 36),
        (15, 46),
        (16, 47),
        (17, 48),
        (18, 58),
        (19, 59),
        (20, 75),
        (21, 90),
        (22, 102),
        (24, 119),
        (26, 134),
        (28, 156),
        (30, 182),
        (33, 206),
        (34, 224),
        (35, 245),
        (36, 273),
        (38, 294),
        (39, 321),
        (40, 346),
        (42, 383),
        (44, 417),
        (45, 453),
        (46, 477),
        (48, 511),
        (50, 538),
        (51, 572),
        (52, 612),
        (54, 641),
        (55, 686),
        (56, 735),
        (57, 774),
        (58, 804),
        (60, 850),
        (62, 883),
        (63, 919),
        (65, 960),
        (66, 1015),
        (68, 1068),
        (69, 1114),
        (70, 1170),
        (72, 1235),
        (74, 1274),
        (75, 1325),
        (76, 1382),
        (77, 1438),
        (78, 1504),
        (80, 1570),
        (82, 1613),
        (84, 1678),
        (85, 1729),
        (86, 1773),
        (87, 1831),
        (88, 1887),
        (90, 1969),
        (91, 2047),
        (92, 2116),
        (93, 2179),
        (94, 2227),
        (95, 2303),
        (96, 2367),
        (98, 2418),
        (99, 2473),
        (100, 2549),
        (200, 10800),
        (300, 25579),
        (400, 46670),
        (500, 74840),
        (600, 109104),
        (700, 149868),
        (800, 198742),
        (900, 252296),
        (1000, 314034),
        (10000, 34981569),
        (100000, 3717852515),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)