It had been a while since I last completed one of these, so I figured I might as well check whether there were any relatively easy problems that I could do. This one had a difficulty rating of 5%, so I figured I could finish it in a day or so, which is exactly what happened.
My algorithm is to sort the points into buckets based on their $x$ value. For each point, this reduces the number of other points it has to be checked against: I only need to check points whose $x$ values are within a distance less than the smallest distance found so far.
This method increased the initialization time because I had to create a large list to store all of the available points. But once it gets to the searching portion, the time required is $O(n \log n)$. With Python v3.10.8, I was able to get the answer in roughly 16 seconds. Not my best showing, but good enough.
# Modulus of the point-generating sequence; every term produced by soint()
# other than the seed is reduced modulo this value.
modulo = 50_515_093
def memoize(fn):
    """Return a caching wrapper around the one-argument function *fn*.

    The first call with a given argument runs *fn* and stores the result;
    later calls with an equal argument return the stored result without
    calling *fn* again.  The argument is used as a dictionary key, so it
    must be hashable.  The cache is private to the returned wrapper and is
    never cleared.

    The wrapper keeps the name and docstring of *fn* (via
    ``functools.wraps``) so decorated functions stay recognisable in
    tracebacks and ``help()``.
    """
    import functools

    _cache = {}

    @functools.wraps(fn)
    def wrap(n):
        # Membership test rather than .get(): a stored result may be None
        # or otherwise falsy and must still count as cached.
        if n in _cache:
            return _cache[n]
        _cache[n] = ret = fn(n)
        return ret

    return wrap
@memoize
def soint(n):
    """Return term *n* of the sequence the point coordinates are taken from.

    Term 0 is the seed 290797; every later term is the square of the
    previous term reduced modulo ``modulo``.  Each term is obtained through
    a recursive call for ``n - 1``, and the ``memoize`` decorator keeps the
    terms that have already been computed.
    """
    if n == 0:
        return 290797
    previous = soint(n - 1)
    return pow(previous, 2, modulo)
def point(n):
    """Return point number *n* as an ``(x, y)`` tuple.

    The x coordinate is sequence term ``2 * n`` and the y coordinate is the
    term that follows it, ``2 * n + 1``.
    """
    first_term = 2 * n
    x = soint(first_term)
    y = soint(first_term + 1)
    return x, y
def dist(ax, ay, bx, by):
    """Return the squared Euclidean distance between (ax, ay) and (bx, by).

    No square root is taken, so integer inputs give an exact integer.
    """
    dx = ax - bx
    dy = ay - by
    return dx ** 2 + dy ** 2
def solve(n):
    """Return the smallest distance between two of the first *n* points.

    The points are ``point(0)`` .. ``point(n - 1)``.  Points with identical
    coordinates are collapsed into a single point before searching, so two
    generated points that coincide never produce a distance of 0.

    The search sorts the distinct points by x and, for each point, only
    looks at the points after it whose x gap alone is still smaller than
    the best distance found so far.  Storage is proportional to *n*; no
    table sized by ``modulo`` is needed.

    Returns:
        The distance rounded to 9 decimal places, or ``float('inf')`` when
        there are fewer than two distinct points.
    """
    # A set removes coincident points; sorting tuples orders them by x
    # first (then y), which is what the sweep below relies on.
    ordered = sorted({point(k) for k in range(n)})
    total = len(ordered)

    # Squared distance of the closest pair seen so far.
    shortest = float('inf')
    for i, (p1x, p1y) in enumerate(ordered):
        # Pairs are only examined "forwards" (j > i); the mirrored pair has
        # the same distance, so nothing is missed.
        for j in range(i + 1, total):
            p2x, p2y = ordered[j]
            gap = p2x - p1x
            # The list is sorted by x, so once the x gap alone reaches the
            # best distance no later point can be closer to this one.
            if gap * gap >= shortest:
                break
            d = dist(p1x, p1y, p2x, p2y)
            if d < shortest:
                shortest = d
    return round(shortest ** 0.5, 9)
if __name__ == '__main__':
    import sys

    # Script entry point: the first command-line argument is the number of
    # points to pass to solve(); the result is printed to stdout.
    # NOTE(review): eval() executes whatever Python expression is given on
    # the command line.  int() would be safer if only plain integers are
    # expected -- confirm before changing, since expressions such as
    # 2*10**6 are currently accepted.
    print(solve(eval(sys.argv[1])))
import pytest
from problem import solve
# Regression table for solve(): each entry is (n, expected), where expected
# is the value solve(n) has to return.  test_solve compares the two with
# exact equality, so the float literals must not be reformatted.
_test_solve = (
(2, 19823088.577278618),
(3, 13336599.368763387),
(4, 12403093.445842694),
(5, 10899267.179168515),
(6, 10899267.179168515),
(7, 10899267.179168515),
(8, 10899267.179168515),
(9, 9262015.547769556),
(10, 9262015.547769556),
(11, 7599576.500714826),
(12, 7599576.500714826),
(13, 6684832.935854493),
(14, 546446.466846479),
)
@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that ``solve(n)`` returns exactly the recorded value."""
    actual = solve(n)
    assert actual == expect