For this problem I decided on an approach and, instead of turning back when things got hard, I kept at it until I had it working. I knew that in theory my algorithm would be efficient; it was just a matter of figuring out how to implement it.
My first implementation was brute force, running through every number and finding the largest square that divides it. From there I wrote some tests, then implemented a sieving algorithm that sieved for primes and marked all multiples of each prime's square using a second sieve. Both were incredibly inefficient because they meant looping approximately \(n\) times.
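For reference, the brute-force idea can be sketched roughly as follows. This is an illustrative reconstruction, not the original code; the names `largest_square_divisor` and `brute_force` are hypothetical.

```python
def largest_square_divisor(k):
    """Return the largest perfect square dividing k (assumes k >= 1)."""
    best = 1
    d = 1
    while d * d <= k:
        if k % (d * d) == 0:
            best = d * d  # later hits are larger, so the last one wins
        d += 1
    return best


def brute_force(n):
    """Sum the largest square divisor of every integer from 1 to n."""
    return sum(largest_square_divisor(k) for k in range(1, n + 1))


print(brute_force(10))
```

It is only practical for small \(n\), but it is handy for checking a faster solution against.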
Instead of finding the maximal square for each number, I set out to count, for each \(d\), the numbers whose maximal square divisor is \(d^2\). This meant that the outer loop only needed to run approximately \(\sqrt{n}\) times.
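In symbols (\(S(n)\) and \(c_d(n)\) are shorthand introduced just for this illustration, not names from the code), write \(c_d(n)\) for the count of integers \(k \le n\) whose maximal square divisor is \(d^2\). The sum being computed is then

\[ S(n) = \sum_{d=1}^{\lfloor \sqrt{n} \rfloor} d^2 \, c_d(n). \]

For \(n = 10\): \(c_1 = 7\) (1, 2, 3, 5, 6, 7, 10), \(c_2 = 2\) (4, 8), and \(c_3 = 1\) (9), so \(S(10) = 1 \cdot 7 + 4 \cdot 2 + 9 \cdot 1 = 24\).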
It was easy to count the integers with \(d^2\) as a divisor: there are \(\lfloor n / d^2 \rfloor\) of them. However, it was much harder to count the ones where \(d^2\) is the maximal square divisor. To do this, I took the easy count, then subtracted the numbers that have \(d^2\) as a divisor but not as their largest square divisor, that is, the multiples of \(d^2 p^2\) for each prime \(p\). Since some numbers get removed twice (like 144 with \(d = 2\), which is a multiple of both \(4 \cdot 2^2\) and \(4 \cdot 3^2\)), those needed to be added back. And then there are numbers that get removed and added back too many times (like 3600 with \(d = 2\), which is removed three times and added back three times), so those needed to be removed once more. This led to the recursive solution below.
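The remove, add back, remove again pattern is inclusion-exclusion, and its signs are the values of the Möbius function \(\mu\): the count of integers up to \(n\) with maximal square divisor \(d^2\) equals \(\sum_k \mu(k) \lfloor n / (d^2 k^2) \rfloor\). A minimal non-recursive sketch under that reformulation is shown here. It is an alternative way to write the same counting, not the recursive code from this post, and `mobius_sieve`, `count_maximal`, and `total_by_mobius` are hypothetical names.

```python
from math import isqrt


def mobius_sieve(limit):
    """Return a list mu where mu[k] is the Möbius function of k, for 1 <= k <= limit."""
    mu = [1] * (limit + 1)
    is_prime = [True] * (limit + 1)
    for p in range(2, limit + 1):
        if not is_prime[p]:
            continue
        for m in range(p, limit + 1, p):
            if m != p:
                is_prime[m] = False
            mu[m] = -mu[m]  # one more distinct prime factor flips the sign
        for m in range(p * p, limit + 1, p * p):
            mu[m] = 0  # divisible by p^2, so not squarefree
    return mu


def count_maximal(n, d, mu):
    """Count integers <= n whose largest square divisor is exactly d^2."""
    x = n // (d * d)  # the easy count of multiples of d^2
    return sum(mu[k] * (x // (k * k)) for k in range(1, isqrt(x) + 1))


def total_by_mobius(n):
    """Sum of the largest square divisor of every integer from 1 to n (no modulus)."""
    top = isqrt(n)
    mu = mobius_sieve(top)
    return sum(d * d * count_maximal(n, d, mu) for d in range(1, top + 1))


print(total_by_mobius(100))
```

Because \(d = 1\) is included in the outer loop, the squarefree numbers are counted there and need no separate adjustment in this version.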
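The last step was to add 1 for each squarefree number, since its maximal square divisor is 1. I did this by keeping a count of the numbers already assigned a larger square and adding \(n\) minus that count. Runs in about 5 seconds on PyPy3.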
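from math import isqrt

modulo = 1000000007


def solve(n):
    # Sieve of Eratosthenes for the primes up to sqrt(n).
    # isqrt is exact; int(n ** 0.5) can be off by one for large n.
    top = isqrt(n)
    sieve = [False, False] + [True] * (top - 1)
    primes = []
    for p, isprime in enumerate(sieve):
        if not isprime:
            continue
        primes.append(p)
        num = p + p
        while num <= top:
            sieve[num] = False
            num += p

    def doubles(p, psq, sign=1):
        # Inclusion-exclusion correction: extend the product psq by the
        # square of each prime q < p, alternating the sign at each depth.
        count = 0
        for q in primes:
            if q >= p:
                break
            pqsq = psq * q**2
            if pqsq > n:
                break
            count += sign * (n // pqsq)
            count += doubles(q, pqsq, -sign)
        return count

    total = count = 0
    for d in range(2, top + 1):
        dsq = d ** 2
        # Start from the easy count: multiples of d^2 up to n.
        d_count = n // dsq
        for p in primes:
            dpsq = dsq * p**2
            if dpsq > n:
                break
            # Remove multiples of d^2 * p^2, then correct the over-removal.
            d_count -= n // dpsq
            d_count += doubles(p, dpsq)
        total += d_count * dsq
        count += d_count
    # The remaining n - count numbers are squarefree and contribute 1 each.
    return (total + n - count) % modulo


if __name__ == '__main__':
    import sys
    # eval lets the limit be given as an expression such as 10**10;
    # only use this with trusted input.
    n = eval(sys.argv[1])
    print(solve(n))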

# Tests (a separate file that imports solve from problem.py)
import pytest

from problem import solve

_test_solve = (
    (10**1, 24),
    (10**2, 767),
    (256, 3096),
    (10**3, 22606),
    (3600, 156752),
    (10**4, 722592),
    (10**5, 22910120),
    (10**6, 725086120),
    (10**7, 910324294),
    (10**8, 475275084),
    (10**9, 428763483),
    (10**10, 590978380),
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)