Problem 745

completed March 6, 2021

Comments

For this problem I settled on an approach and, instead of turning back when things got hard, kept at it until I had a working implementation. I knew that in theory my algorithm would be efficient; it was just a matter of figuring out how to implement it.

My first implementation was brute force: run through every number up to \(n\) and find the largest square that divides it. From there I wrote some tests, then implemented a sieving algorithm that sieved for primes and marked all multiples of each prime's square using a second sieve. Both were far too slow because they meant looping approximately \(n\) times.
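The brute-force idea can be sketched as follows. This is a minimal illustration rather than the original first attempt; the names `max_square_divisor` and `brute_force` are made up for this example. It is only practical for small \(n\), but it is handy as a reference when testing faster versions.

```python
from math import isqrt

def max_square_divisor(m):
    """Largest perfect square dividing m, found by trying d from the top down."""
    for d in range(isqrt(m), 0, -1):
        if m % (d * d) == 0:
            return d * d  # d = 1 always divides, so this is always reached

def brute_force(n):
    """Sum of the maximal square divisor of every integer in 1..n."""
    return sum(max_square_divisor(m) for m in range(1, n + 1))

print(brute_force(10))   # 24
print(brute_force(100))  # 767
```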

Rather than finding the maximal square for each number, I set out to count the numbers whose maximal square divisor is \(d^2\). This meant I only needed to loop approximately \(\sqrt{n}\) times.
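Written as a formula, the regrouping looks like this. The symbols are notation introduced here for illustration: \(g(m)\) is the maximal square divisor of \(m\), \(S(n)\) is the required sum, and \(C_d(n)\) is the number of \(m \le n\) with \(g(m) = d^2\).

\[
S(n) = \sum_{m=1}^{n} g(m) = \sum_{d=1}^{\lfloor \sqrt{n} \rfloor} d^2 \, C_d(n)
\]

Every \(m \le n\) has exactly one maximal square divisor, so each \(m\) is counted in exactly one \(C_d(n)\), and \(d\) never exceeds \(\sqrt{n}\).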

It was easy to count the integers that have \(d^2\) as a divisor. It was much harder to count those for which \(d^2\) is the maximal square divisor. To do this, I started with the easy count of multiples of \(d^2\), then subtracted the numbers that have \(d^2\) as a divisor but also a larger square divisor, namely the multiples of \(d^2 p^2\) for each prime \(p\). Since some numbers get removed twice (like 144, for instance), those need to be added back. Then there are the ones that are added back once too often (like 3600, for instance), which need to be removed again. This led to the recursive solution below.
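The same inclusion-exclusion can also be written with the Möbius function, which carries the alternating signs. The sketch below uses that swapped-in formulation rather than the recursive one from this write-up; `mobius_upto` and `count_max_square` are hypothetical helper names, and the code is meant for small inputs only.

```python
from math import isqrt

def mobius_upto(limit):
    """Return a list mu where mu[k] is the Moebius function of k (index 0 unused)."""
    mu = [1] * (limit + 1)
    is_prime = [True] * (limit + 1)
    for p in range(2, limit + 1):
        if not is_prime[p]:
            continue
        for multiple in range(p, limit + 1, p):
            if multiple > p:
                is_prime[multiple] = False
            mu[multiple] = -mu[multiple]   # one more distinct prime factor
        for multiple in range(p * p, limit + 1, p * p):
            mu[multiple] = 0               # divisible by a prime square
    return mu

def count_max_square(n, d, mu):
    """Count m <= n whose maximal square divisor is exactly d*d.

    Equivalent to counting squarefree integers up to n // d**2.
    mu must cover indices up to isqrt(n // d**2).
    """
    x = n // (d * d)
    return sum(mu[k] * (x // (k * k)) for k in range(1, isqrt(x) + 1))

mu = mobius_upto(10)
print(count_max_square(100, 1, mu))  # 61 squarefree numbers up to 100
print(count_max_square(100, 2, mu))  # 16 numbers with maximal square 4
```

Each term with \(\mu(k) = -1\) is a removal, each with \(\mu(k) = +1\) is an add-back, and \(k\) with a repeated prime factor drops out.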

The last step was to add 1 for each squarefree number, which I did by keeping a count of the numbers already handled and subtracting it from \(n\). Runs in about 5 seconds on PyPy3.
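As a small worked check of that bookkeeping, take \(n = 10\). The only values of \(d \ge 2\) with \(d^2 \le 10\) are 2 and 3. The numbers with maximal square 4 are 4 and 8, and the only one with maximal square 9 is 9 itself, so:

\[
\text{total} = 4 \cdot 2 + 9 \cdot 1 = 17, \qquad
\text{count} = 2 + 1 = 3, \qquad
\text{total} + (n - \text{count}) = 17 + 7 = 24
\]

The 7 remaining numbers (1, 2, 3, 5, 6, 7, 10) are squarefree and contribute 1 each, which gives the 24 expected by the first test case.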

Code

modulo = 1000000007

def solve(n):

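    # Sieve of Eratosthenes for the primes up to sqrt(n); larger primes
    # have squares above n and never contribute.
    top = int(n ** 0.5)
    sieve = [False, False] + [True] * (top - 1)
    primes = []
    for p, isprime in enumerate(sieve):
        if not isprime:
            continue
        primes.append(p)
        # Cross out every multiple of p above p itself.
        num = p + p
        while num <= top:
            sieve[num] = False
            num += p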

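    def doubles(p, psq, sign=1):
        # Inclusion-exclusion correction for multiples of psq that contain
        # further prime squares q**2 with q < p. Restricting to q < p visits
        # each set of distinct primes exactly once; the sign flips at each
        # level of recursion.
        count = 0
        for q in primes:
            if q >= p:
                break
            pqsq = psq * q**2
            if pqsq > n:
                break
            count += sign * (n // pqsq)
            count += doubles(q, pqsq, -sign)
        return count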

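    # total: sum of d**2 over numbers whose maximal square divisor is d**2, d >= 2.
    # count: how many such numbers there are.
    total = count = 0
    for d in range(2, top + 1):
        dsq = d ** 2
        # Start with every multiple of d**2 ...
        d_count = n // dsq
        for p in primes:
            dpsq = dsq * p**2
            if dpsq > n:
                break
            # ... remove those also divisible by d**2 * p**2, then correct
            # for the numbers removed more than once.
            d_count -= n // dpsq
            d_count += doubles(p, dpsq)

        total += d_count * dsq
        count += d_count

    # The remaining n - count numbers are squarefree and contribute 1 each.
    return (total + n - count) % modulo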

if __name__ == '__main__':
    import sys
    # eval lets the argument be an expression such as 10**14;
    # only pass trusted input.
    n = eval(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve

_test_solve = (
        (10**1, 24),
        (10**2, 767),
        (256,   3096),
        (10**3, 22606),
        (3600,  156752),
        (10**4, 722592),
        (10**5, 22910120),
        (10**6, 725086120),
        (10**7, 910324294),
        (10**8, 475275084),
        (10**9, 428763483),
        (10**10, 590978380),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)