Problem 675

completed December 29, 2020

Comments

I’ve done so many of these now that they all seem to be running together. There wasn’t anything that really stood out about this one. As usual, it came down to actually getting out pen and paper and doing the math. Once that happened I was able to work incrementally toward a working solution. It doesn’t run super fast; it takes about 37 seconds to get to the answer.

Code

# Modulus under which solve() reduces its running product and final sum.
modulo = 1_000_000_087

def memoize(fn):
    """Return a version of ``fn`` that caches results by positional arguments.

    The first call with a given argument tuple runs ``fn`` and stores the
    result; later calls with an equal tuple return the stored object without
    calling ``fn`` again. The cache is unbounded and lives as long as the
    returned wrapper. Because the stored object itself is returned, callers
    must not mutate a mutable result.

    Args:
        fn: Callable taking only hashable positional arguments.

    Returns:
        A wrapper with the same name and docstring as ``fn``.
    """
    # Local import keeps this block self-contained.
    from functools import wraps

    _cache = {}

    @wraps(fn)  # keep fn's __name__/__doc__ instead of exposing 'wrap'
    def wrap(*args):
        if args not in _cache:
            _cache[args] = fn(*args)
        return _cache[args]
    return wrap

# Shared cache of every prime found so far, kept in increasing order.
_primes = [2, 3]

def _has_cached_divisor(candidate):
    """Return True if a cached prime no larger than sqrt(candidate) divides it."""
    for q in _primes:
        if q * q > candidate:
            return False
        if candidate % q == 0:
            return True
    return False

def primes():
    """Yield the primes 2, 3, 5, ... in increasing order, without end.

    All generators share the module-level ``_primes`` cache. Each generator
    walks the cache by index and only appends a new prime when it has reached
    the end, so any number of generators may be interleaved without adding
    duplicate or out-of-order entries.

    Yields:
        int: The next prime.
    """
    index = 0
    while True:
        if index == len(_primes):
            # Extend from the shared cache's last entry, not from a value
            # private to this generator: another generator may have grown
            # the cache since this one last ran.
            candidate = _primes[-1] + 2
            while _has_cached_divisor(candidate):
                candidate += 2
            _primes.append(candidate)
        yield _primes[index]
        index += 1

def is_prime(n):
    """Return True if ``n`` is a prime number, using trial division.

    Args:
        n: Integer to test.

    Returns:
        bool: True when ``n`` is prime; False for composites and for every
        ``n`` below 2.
    """
    if n < 2:
        return False
    if n % 2 == 0:
        # 2 is the only even prime; the odd-only loop below would never
        # reject an even composite.
        return n == 2
    divisor = 3
    # Integer bound check avoids the rounding error of a float square root.
    while divisor * divisor <= n:
        if n % divisor == 0:
            return False
        divisor += 2
    return True

@memoize
def factors(n):
    """Return the prime factorisation of ``n`` as ``{prime: exponent}``.

    The smallest prime factor is divided out completely and the remaining
    cofactor is factorised by a recursive (memoised) call. ``factors(1)``
    is an empty dict. The result is cached by ``memoize`` and the same dict
    object is handed back on later calls, so callers should treat it as
    read-only.
    """
    # Bound comes from the original n; n only changes once a factor is found,
    # and the function returns right after that.
    limit = n ** 0.5
    for prime in primes():
        if prime > limit:
            break
        if n % prime:
            continue
        exponent = 0
        while n % prime == 0:
            n //= prime
            exponent += 1
        result = {prime: exponent}
        result.update(factors(n))
        return result
    # No prime up to sqrt(n) divides n: n is 1 or is itself prime.
    return {n: 1} if n != 1 else {}

def inv(n, mod):
    """Return the multiplicative inverse of ``n`` modulo ``mod``.

    Computed with the extended Euclidean algorithm.

    Args:
        n: Integer to invert. It is reduced modulo ``mod`` first, so negative
            values and values of at least ``mod`` are accepted.
        mod: Modulus; must be a positive integer.

    Returns:
        int: The value ``x`` in ``range(mod)`` with
        ``n * x % mod == 1 % mod``.

    Raises:
        ValueError: If ``mod`` is not positive, or if ``n`` and ``mod`` are
            not coprime (no inverse exists).
    """
    if mod <= 0:
        raise ValueError('modulus must be positive, got %r' % (mod,))
    t, newt = 0, 1
    # Reducing n keeps every remainder non-negative, so the loop ends with
    # r == gcd(n, mod) rather than possibly its negation.
    r, newr = mod, n % mod
    while newr:
        q = r // newr
        t, newt = newt, t - q * newt
        r, newr = newr, r - q * newr
    if r != 1:
        raise ValueError('%r has no inverse modulo %r' % (n, mod))
    # The Bezout coefficient may be negative; return the canonical residue.
    return t % mod

def solve(n):
    """Return the sum, for i = 2..n, of a running product, modulo ``modulo``.

    For every prime p the loop keeps ``1 + 2 * a_p``, where ``a_p`` is the
    total exponent of p collected from the factorisations of 2..i (that is,
    the exponent of p in i!). The running product of those values, reduced
    modulo ``modulo``, is added to the answer once per i.

    Args:
        n: Upper bound of the range, inclusive. Values below 2 give 0.

    Returns:
        int: The accumulated sum reduced modulo ``modulo``.
    """
    weights = {}   # prime -> current 1 + 2 * exponent
    running = 1    # product of all weights, modulo `modulo`
    answer = 0
    for i in range(2, n + 1):
        for prime, exponent in factors(i).items():
            before = weights.get(prime, 1)
            after = before + 2 * exponent
            weights[prime] = after
            # Swap the prime's old weight for the new one inside the product:
            # multiply by the new weight and by the inverse of the old one.
            running = running * (after * inv(before, modulo)) % modulo
        answer += running
    return answer % modulo

if __name__ == '__main__':
    # Command-line entry point: print solve(N) for the N given as argv[1].
    import sys
    if len(sys.argv) < 2:
        # Give a usage message instead of a bare IndexError traceback.
        sys.exit('usage: %s N' % sys.argv[0])
    # NOTE(security): eval() lets N be written as an expression such as
    # 10**7, but it executes arbitrary code taken from the command line.
    # Only run this with arguments you trust.
    n = eval(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve

# Cases for test_solve: (n, value solve(n) is expected to return).
_test_solve = (
    (10, 4821),
    (100, 930751395),
    (1000, 822391759),
    (10000, 979435692),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check solve(n) against the recorded value for each case."""
    result = solve(n)
    assert result == expect