Problem 816

completed February 15, 2023

Comments

I figured enough time had passed since I last completed one of these, so I checked whether there were any relatively easy problems I could do. This one was rated at 5% difficulty, so I expected to finish it in a day or so, which is exactly what happened.

My algorithm sorts the points into buckets based on their $x$ value. For each point, this reduces the number of other points I have to check it against: I only need to check points whose $x$ values are within the smallest distance found so far.
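The pruning idea can be sketched in a few lines. This is a simplified variant rather than the solution listed under Code: it assumes a sorted list of points instead of one bucket per coordinate, and the names `first_points` and `closest_pair_sq` are made up for the illustration.

```python
def first_points(n, seed=290797, modulo=50515093):
    """Generate the first n points iteratively (no recursion or cache)."""
    points = []
    s = seed
    for _ in range(n):
        x = s
        y = (x * x) % modulo
        s = (y * y) % modulo  # next even-indexed term of the sequence
        points.append((x, y))
    return points


def closest_pair_sq(points):
    """Smallest squared distance between two entries of points."""
    pts = sorted(points)  # order by x so the scan can stop early
    best = float("inf")
    for i in range(len(pts)):
        ax, ay = pts[i]
        for j in range(i + 1, len(pts)):
            bx, by = pts[j]
            dx = bx - ax
            if dx * dx >= best:
                break  # every later point is at least this far away in x alone
            d = dx * dx + (by - ay) ** 2
            if d < best:
                best = d
    return best


if __name__ == "__main__":
    print(round(closest_pair_sq(first_points(14)) ** 0.5, 9))
```

The early `break` plays the same role as the shrinking $x$ window: once the gap in $x$ alone reaches the best distance so far, no point further along can improve on it.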

This method increases initialization time because I have to create a large list with one slot for every possible $x$ value. But once it gets to the searching portion, I estimate the time required at roughly $O(n \log n)$. With Python v3.10.8, I was able to get the answer in roughly 16 seconds. Not my best showing, but good enough.
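One way the initialization cost might be reduced is to store only the occupied $x$ values. The sketch below is an assumption on my part, not what the solution under Code does: it uses a dictionary of buckets plus a sorted key list, with `bisect` to find the occupied buckets inside a window. The helper names `build_buckets` and `occupied_in_window` are hypothetical.

```python
from bisect import bisect_left, bisect_right
from collections import defaultdict


def build_buckets(points):
    """Group y values by x, keeping only the x values that actually occur."""
    buckets = defaultdict(list)
    for x, y in points:
        buckets[x].append(y)
    xs = sorted(buckets)  # occupied x values in increasing order
    return buckets, xs


def occupied_in_window(xs, center, radius):
    """Return the occupied x values in [center - radius, center + radius]."""
    lo = bisect_left(xs, center - radius)
    hi = bisect_right(xs, center + radius)
    return xs[lo:hi]
```

With this layout the memory used grows with the number of points rather than with the modulus, at the price of a binary search for each window lookup.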

Code

modulo = 50515093

def memoize(fn):
    _cache = {}
    def wrap(n):
        if n in _cache:
            return _cache[n]
        _cache[n] = ret = fn(n)
        return ret
    return wrap

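@memoize
def soint(n):
    # s_0 = 290797, s_n = s_{n-1}^2 mod 50515093.
    # The cache keeps the recursion shallow as long as values are
    # requested in increasing order of n.
    if n == 0:
        return 290797
    return (soint(n - 1) ** 2) % modulo

def point(n):
    # P_n = (s_{2n}, s_{2n+1})
    return soint(2 * n), soint(2 * n + 1)

def dist(ax, ay, bx, by):
    # Squared Euclidean distance; the square root is only taken at the end.
    return (ax - bx) ** 2 + (ay - by) ** 2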

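def solve(n):
    # Bucket the points by x: points_x[x] holds the y values of every point with that x.
    points_x = [None] * modulo
    for k in range(n):
        px, py = point(k)
        if not points_x[px]:
            points_x[px] = []
        points_x[px].append(py)

    # shortest is the best squared distance so far; max_dist bounds the x window to scan.
    max_dist = shortest = float('inf')
    for p1x in range(modulo):
        if not points_x[p1x]:
            continue
        for p1y in points_x[p1x]:
            # While max_dist is still infinite, max() and min() clamp the window to the full range.
            for p2x in range(max(0, p1x - max_dist), min(modulo, p1x + max_dist)):
                if not points_x[p2x]:
                    continue
                for p2y in points_x[p2x]:
                    if p1x == p2x and p1y == p2y:
                        continue  # skip comparing a point with itself
                    d = dist(p1x, p1y, p2x, p2y)
                    if d < shortest:
                        shortest = d
                        # Shrink the window to just above the new best distance.
                        max_dist = int(shortest ** 0.5) + 1
    return round(shortest ** 0.5, 9)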

if __name__ == '__main__':
    import sys
    n = int(sys.argv[1])  # parse the point count directly rather than eval() on user input
    print(solve(n))

Tests

import pytest
from problem import solve

_test_solve = (
        (2, 19823088.577278618),
        (3, 13336599.368763387),
        (4, 12403093.445842694),
        (5, 10899267.179168515),
        (6, 10899267.179168515),
        (7, 10899267.179168515),
        (8, 10899267.179168515),
        (9, 9262015.547769556),
        (10, 9262015.547769556),
        (11, 7599576.500714826),
        (12, 7599576.500714826),
        (13, 6684832.935854493),
        (14, 546446.466846479),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)