I enjoyed this problem. I suspected there would be an easy way to solve it, and my final solution bears that out. To get there, though, I first implemented a brute-force solution (code not included). With that in place, I could see that only the first and last hexes of each ring ever had three prime differences. That greatly reduced the amount of searching needed. From there, it was just a matter of working out formulas for the differences between those hexes and their neighbors.
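As a rough sketch of where those formulas come from: ring r (for r >= 2) holds 6r tiles, running from 3r(r-1)+2 to 3r(r+1)+1. Subtracting neighboring tile numbers for the first and last tile of a ring leaves only three differences per tile that can be prime; the others are 1 or multiples of 6. The helper names below (`first_tile`, `last_tile`, `candidate_differences`) are made up for illustration and are not part of the solution code.

```python
def first_tile(r):
    """Number of the first tile in ring r (r >= 1)."""
    return 3 * r * (r - 1) + 2


def last_tile(r):
    """Number of the last tile in ring r (r >= 1)."""
    return 3 * r * (r + 1) + 1


def candidate_differences(r):
    """Differences that can be prime for the first and last tile of ring r.

    Assumes r >= 2; the centre tile and ring 1 are special cases.
    """
    # First tile: neighbors differ by 1, 6(r-1), 6r, 6r-1, 6r+1 and 12r+5.
    first = (6 * r - 1, 6 * r + 1, 12 * r + 5)
    # Last tile: neighbors differ by 1, 6r, 6r+6, 6r-1, 6r+5 and 12r-7.
    last = (6 * r - 1, 6 * r + 5, 12 * r - 7)
    return first, last


if __name__ == '__main__':
    for r in range(2, 5):
        print(r, first_tile(r), last_tile(r), candidate_differences(r))
```

For ring 2, for example, tile 8 has the candidates 11, 13 and 29, and tile 19 has 11, 17 and 17, so both tiles have three prime differences.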
What I liked most about this problem was that the brute-force solution forced me to find a way to represent a hexagonal tiling in memory. For this I followed the “cube coordinates” method from https://www.redblobgames.com/grids/hexagons/#coordinates-cube. The method aligns three axes with the grid, so each hex is identified by three numbers. That site describes many other coordinate systems I could have used, but I chose this one because it seemed the best fit for working in rings.
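Since the brute-force code is not included, here is a minimal sketch of how such a search might look with cube coordinates. It is a reconstruction for illustration, not the original code: every name in it (`DIRECTIONS`, `spiral`, `prime_difference_count`, `tiles_with_three`) is an assumption. Each hex is a tuple (x, y, z) with x + y + z = 0, and its ring index is max(|x|, |y|, |z|).

```python
# The six cube-coordinate neighbor offsets, in rotational order.
DIRECTIONS = [(1, -1, 0), (1, 0, -1), (0, 1, -1),
              (-1, 1, 0), (-1, 0, 1), (0, -1, 1)]


def is_prime(n):
    """Plain trial division; enough for a small brute-force check."""
    if n < 2:
        return False
    return all(n % p for p in range(2, int(n ** 0.5) + 1))


def spiral(rings):
    """Map cube coordinates to tile numbers for the centre plus `rings` rings."""
    numbers = {(0, 0, 0): 1}
    n = 2
    for r in range(1, rings + 1):
        # Every ring starts at the same corner, r steps out along DIRECTIONS[4].
        x, y, z = (r * c for c in DIRECTIONS[4])
        for dx, dy, dz in DIRECTIONS:
            # Walk r tiles along each of the six sides of the ring.
            for _ in range(r):
                numbers[(x, y, z)] = n
                n += 1
                x, y, z = x + dx, y + dy, z + dz
    return numbers


def prime_difference_count(numbers, pos):
    """Count neighbors of `pos` whose tile number differs by a prime."""
    x, y, z = pos
    here = numbers[pos]
    return sum(is_prime(abs(here - numbers[(x + dx, y + dy, z + dz)]))
               for dx, dy, dz in DIRECTIONS)


def tiles_with_three(rings):
    """Tiles with three prime differences, skipping the outermost ring."""
    numbers = spiral(rings)
    return sorted(n for pos, n in numbers.items()
                  if max(abs(c) for c in pos) < rings
                  and prime_difference_count(numbers, pos) == 3)


if __name__ == '__main__':
    print(tiles_with_three(10))
```

The outermost ring is skipped because its tiles have neighbors that were never numbered. Which corner the spiral starts from, and which way it turns, does not affect the counts, as long as every ring starts on the same ray.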
def memoize(fn):
    _cache = {}

    def wrap(n):
        if n in _cache:
            return _cache[n]
        ret = fn(n)
        _cache[n] = ret
        return ret
    return wrap


@memoize
def isprime(n):
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    for p in range(3, int(n**0.5) + 1, 2):
        if n % p == 0:
            return False
    return True

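
def solve(n):
    # The first three tiles with three prime differences (1, 2 and 8)
    # are hard-coded; the loop below starts counting after them.
    if n == 1:
        return 1
    if n == 2:
        return 2
    if n == 3:
        return 8
    count = 3
    r = 2
    while True:
        # Last tile of ring r: 3r(r+1) + 1.
        if all((isprime(6*r-1), isprime(12*r-7), isprime(6*(r+1)-1))):
            count += 1
            if count == n:
                return 3 * r * (r + 1) + 1
        # First tile of ring r+1: 3r(r+1) + 2.
        if all((isprime(6*(r+1)+1), isprime(6*(r+1)-1), isprime(12*r+17))):
            count += 1
            if count == n:
                return 3 * r * (r + 1) + 2
        r += 1
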

if __name__ == '__main__':
    import sys
    # int() instead of eval(): the argument is only ever a tile index.
    n = int(sys.argv[1])
    print(solve(n))

import pytest
from problem import solve

_test_solve = (
    (1, 1),
    (2, 2),
    (3, 8),
    (4, 19),
    (5, 20),
    (6, 37),
    (7, 61),
    (8, 128),
    (9, 217),
    (10, 271),
    (20, 5677),
    (30, 26227),
    (40, 75367),
    (50, 169220),
    (100, 1790270),
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)