I keep reading about the Möbius Function, but I’m not sure I’ve actually used it in a problem before.
This was the third iteration of my code. I started with a very naive brute-force approach that found the prime factors of each number. This allowed me to get a set of tests in place for my next iteration, in which I used a sieve to find all the squareful numbers. This was inefficient because whenever a number had more than one prime factor with an exponent > 1, it was sieved out once for each of these prime factors.
Instead, I needed to determine which of these squareful numbers would be counted multiple times. This is where the Möbius Function came in. I found this formula on MathWorld — the number of squarefree integers up to $n$ is $Q(n) = \sum_{d=1}^{\lfloor\sqrt{n}\rfloor} \mu(d) \left\lfloor n/d^2 \right\rfloor$ — and wanted to try it out to see if it would be more efficient. And it was! I then found a way to optimize the Möbius Function and was able to get the solution in about 28 seconds using PyPy3.
def memoize(fn):
    """Return a caching wrapper around the one-argument function *fn*.

    The wrapper remembers every result in a private, unbounded dict keyed
    by the argument, so *fn* runs at most once per distinct argument.

    Args:
        fn: A callable taking a single hashable argument.

    Returns:
        A function with the same name and docstring as *fn* that serves
        repeated calls from the cache.
    """
    import functools

    _cache = {}

    # wraps() copies __name__, __doc__, etc. so the decorated function
    # still identifies itself correctly in tracebacks and help().
    @functools.wraps(fn)
    def wrap(n):
        # Membership test (rather than .get) so a cached None is honoured.
        if n in _cache:
            return _cache[n]
        ret = fn(n)
        _cache[n] = ret
        return ret

    return wrap
def solve(n):
    """Count the squarefree integers in the range 1..n (inclusive).

    Uses the identity

        Q(n) = sum(mu(d) * (n // d**2) for d in 1..isqrt(n))

    where mu is the Mobius function: mu(d) is 0 when d is divisible by the
    square of a prime, and otherwise (-1)**k with k the number of distinct
    primes dividing d.

    Args:
        n: Inclusive upper bound. Non-integer values are truncated with
            ``int()``; values below 1 yield 0.

    Returns:
        The number of positive integers <= n that are not divisible by any
        perfect square greater than 1.
    """
    from math import isqrt

    n = int(n)
    if n < 1:
        return 0

    # isqrt is exact for arbitrarily large ints, whereas int(n ** 0.5)
    # goes through a float and can be off by one for large n.
    top = isqrt(n)

    # Sieve the Mobius function for every d in 0..top in one pass instead
    # of factoring each d separately.
    mobius = [1] * (top + 1)
    sieved = bytearray(top + 1)  # non-zero once some prime has hit the index
    for p in range(2, top + 1):
        if sieved[p]:
            continue  # p has a smaller prime factor, so it is not prime
        # Each distinct prime factor flips the sign once.
        for multiple in range(p, top + 1, p):
            sieved[multiple] = 1
            mobius[multiple] = -mobius[multiple]
        # Anything divisible by p**2 is not squarefree: mu is 0 for good
        # (later sign flips keep it at 0).
        square = p * p
        for multiple in range(square, top + 1, square):
            mobius[multiple] = 0

    return sum(mobius[d] * (n // (d * d)) for d in range(1, top + 1))
if __name__ == '__main__':
    # Command-line entry point: evaluate the first argument, pass it to
    # solve() and print the result.
    import sys

    if len(sys.argv) < 2:
        # Exit with a readable message instead of an IndexError traceback.
        sys.exit('usage: %s N  (N may be an expression such as 2**50)'
                 % sys.argv[0])

    # SECURITY NOTE(review): eval() executes arbitrary Python taken from the
    # command line. It is kept so that expressions like "2**50" keep
    # working; never feed this script untrusted input.
    n = eval(sys.argv[1])
    print(solve(n))
import pytest
from problem import solve
_test_solve = (
(12, 8),
(2**0, 1),
(2**1, 2),
(2**2, 3),
(2**3, 6),
(2**4, 11),
(2**5, 20),
(2**6, 39),
(2**7, 78),
(2**8, 157),
(2**9, 314),
(2**10, 624),
(2**11, 1245),
(2**12, 2491),
(2**13, 4982),
(2**14, 9962),
(2**15, 19920),
(2**16, 39844),
(2**17, 79688),
(2**18, 159360),
(2**19, 318725),
(2**20, 637461),
(2**21, 1274918),
(2**22, 2549834),
(2**23, 5099650),
(2**24, 10199301),
(2**25, 20398664),
)
@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that solve(n) equals the tabulated expected value."""
    actual = solve(n)
    assert actual == expect