Problem 193

completed February 27, 2021

Comments

I keep reading about the Möbius function, but I don't think I had actually used it in a problem before this one.

This was the third iteration of my code. I started with a naive brute-force approach that found the prime factors of each number. It was slow, but it gave me a set of tests to check the next iteration against. The second iteration used a sieve to find all the squareful numbers (those divisible by the square of a prime). This was inefficient because any number with more than one prime factor of exponent > 1 was sieved out once for each of those prime factors.
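To make the repeated sieving concrete, here is a minimal sketch. It is not the original second iteration (which isn't shown in this post), and the helper names `is_squarefree` and `sieve_marks` are made up for illustration. It compares how many marks a sieve over prime squares makes with how many distinct squareful numbers there actually are.

```python
def is_squarefree(m):
    """Brute force: m is squarefree if no square d*d > 1 divides it."""
    d = 2
    while d * d <= m:
        if m % (d * d) == 0:
            return False
        d += 1
    return True


def sieve_marks(n):
    """Return (total marks, distinct squareful numbers) for 1..n.

    Every multiple of p*p is marked once per prime p, so a number
    such as 36 = 2**2 * 3**2 is marked twice.
    """
    marked = [False] * (n + 1)
    marks = 0
    p = 2
    while p * p <= n:
        # Trial division is enough for a small illustration.
        if all(p % q for q in range(2, int(p ** 0.5) + 1)):
            for multiple in range(p * p, n + 1, p * p):
                marked[multiple] = True
                marks += 1
        p += 1
    return marks, sum(marked)


marks, squareful = sieve_marks(100)
squarefree = sum(1 for m in range(1, 101) if is_squarefree(m))
print(marks, squareful, squarefree)  # 42 39 61
```

Up to 100 the sieve makes 42 marks for only 39 squareful numbers: 36 and 72 are each hit by both 4 and 9, and 100 by both 4 and 25.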

Instead, I needed to determine which of these squareful numbers would be counted multiple times. This is where the Möbius function came in. I found a formula for this on MathWorld and wanted to try it out to see if it would be more efficient. And it was! I then found a way to optimize the Möbius function computation and was able to get the solution in about 28 seconds using PyPy3.
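The sum that `solve` evaluates in the Code section is Q(n) = sum over d from 1 to floor(sqrt(n)) of mu(d) * floor(n / d^2). The sketch below lists the individual terms for n = 12 so the sign pattern is visible. `mobius_slow` and `squarefree_terms` are hypothetical helpers written for this illustration, using plain trial division rather than the optimized version.

```python
def mobius_slow(d):
    """Möbius function by trial division (fine for tiny d)."""
    result = 1
    p = 2
    while p * p <= d:
        if d % p == 0:
            d //= p
            if d % p == 0:
                return 0      # p*p divided the original d
            result = -result  # one more distinct prime factor
        p += 1
    if d > 1:
        result = -result      # leftover prime factor
    return result


def squarefree_terms(n):
    """List the (d, mu(d), n // d**2) triples with d*d <= n."""
    terms = []
    d = 1
    while d * d <= n:
        terms.append((d, mobius_slow(d), n // (d * d)))
        d += 1
    return terms


terms = squarefree_terms(12)
print(terms)                              # [(1, 1, 12), (2, -1, 3), (3, -1, 1)]
print(sum(mu * q for _, mu, q in terms))  # 8
```

For n = 12 this reads as: start with all 12 numbers, subtract the 3 multiples of 4 and the 1 multiple of 9, leaving 8 squarefree numbers, which matches the first test case.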

Code

def memoize(fn):
    _cache = {}
    def wrap(n):
        if n in _cache:
            return _cache[n]
        ret = fn(n)
        _cache[n] = ret
        return ret
    return wrap

def solve(n):
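    # Only d <= sqrt(n) contribute, since n // d**2 is 0 beyond that.
    # n ** 0.5 is a float: exact for 2**50, but it can be off by one for very large n.
    top = int(n ** 0.5)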
    sieve = [False, False] + [True] * (top - 1)
    primes = []
    for p, isprime in enumerate(sieve):
        if not isprime:
            continue
        primes.append(p)
        num = p + p
        while num <= top:
            sieve[num] = False
            num += p

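    @memoize
    def mobius(n):
        # Note: this n shadows the outer n; here it is the value being factored.
        if n == 1:
            return 1
        stop = n ** 0.5
        for p in primes:
            if n % p == 0:
                # p is the smallest prime factor of n.
                n //= p
                if n % p == 0:
                    # p**2 divides the original value, so it is not squarefree.
                    return 0
                # One distinct prime removed: flip the sign of the cofactor's value.
                return -mobius(n)
            if p > stop:
                # No prime factor up to sqrt(n), so n itself is prime.
                return -1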

    count = 0
    for d in range(1, top + 1):
        count += mobius(d) * (n // d**2)
    return count

if __name__ == '__main__':
    import sys
    # eval() lets the argument be an expression such as 2**50; only pass trusted input.
    n = eval(sys.argv[1])
    print(solve(n))
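
As an aside, one alternative to the memoized recursive `mobius` above is to compute every value up to `top` in a single sieve pass, storing the results in a list instead of a dictionary. This is a sketch of a different technique from the one used in the solution, with no timing claimed; `mobius_sieve` and `solve_with_sieve` are names made up for it, and it assumes Python 3.8+ for `math.isqrt`.

```python
from math import isqrt  # Python 3.8+


def mobius_sieve(limit):
    """Return a list mu where mu[d] is the Möbius function of d, 0 <= d <= limit."""
    mu = [1] * (limit + 1)
    mu[0] = 0
    is_prime = [True] * (limit + 1)
    for p in range(2, limit + 1):
        if not is_prime[p]:
            continue
        # Flip the sign once for every multiple of the prime p.
        for multiple in range(p, limit + 1, p):
            if multiple > p:
                is_prime[multiple] = False
            mu[multiple] = -mu[multiple]
        # Anything divisible by p*p is not squarefree.
        for multiple in range(p * p, limit + 1, p * p):
            mu[multiple] = 0
    return mu


def solve_with_sieve(n):
    """Count squarefree numbers in 1..n using the sieved Möbius values."""
    top = isqrt(n)
    mu = mobius_sieve(top)
    return sum(mu[d] * (n // (d * d)) for d in range(1, top + 1))


print(solve_with_sieve(12))       # 8
print(solve_with_sieve(2 ** 10))  # 624
```

Both printed values agree with the expected results in the test table below.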

Tests

import pytest
from problem import solve

_test_solve = (
        (12, 8),
        (2**0, 1),
        (2**1, 2),
        (2**2, 3),
        (2**3, 6),
        (2**4, 11),
        (2**5, 20),
        (2**6, 39),
        (2**7, 78),
        (2**8, 157),
        (2**9, 314),
        (2**10, 624),
        (2**11, 1245),
        (2**12, 2491),
        (2**13, 4982),
        (2**14, 9962),
        (2**15, 19920),
        (2**16, 39844),
        (2**17, 79688),
        (2**18, 159360),
        (2**19, 318725),
        (2**20, 637461),
        (2**21, 1274918),
        (2**22, 2549834),
        (2**23, 5099650),
        (2**24, 10199301),
        (2**25, 20398664),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)