Problem 128

completed January 27, 2021

Comments

I enjoyed this problem. I suspected there would be an easy way to solve it, and my final solution bears that out. To get there, though, I first implemented a brute force solution (code for this is not included). With that in place, I could see that only the first and last hexes of each ring ever had 3 prime differences, which greatly reduced the amount of searching needed. From there, it was just a matter of working out formulas for the differences between those hexes and their neighbors.
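The neighbor differences for those two hexes can be checked with a small sketch. It assumes ring 0 is the single tile 1 and that ring r (r >= 1) holds 6r tiles numbered anticlockwise from the top, so ring r runs from 3r(r-1)+2 to 3r(r+1)+1. The helper names ring_first, ring_last, first_tile_diffs and last_tile_diffs are made up for illustration and are not part of the solution code.

```python
def ring_first(r):
    """First (top) tile of ring r, for r >= 1."""
    return 3 * r * (r - 1) + 2


def ring_last(r):
    """Last tile of ring r, for r >= 1."""
    return 3 * r * (r + 1) + 1


def first_tile_diffs(r):
    """Sorted differences between the first tile of ring r (r >= 2)
    and its six neighbours."""
    t = ring_first(r)
    neighbours = [
        ring_first(r + 1),      # directly outward
        ring_first(r + 1) + 1,  # outward, one step anticlockwise
        t + 1,                  # next tile in the same ring
        ring_first(r - 1),      # directly inward
        ring_last(r),           # previous tile in the same ring (wraps)
        ring_last(r + 1),       # outward, end of the next ring
    ]
    return sorted(abs(n - t) for n in neighbours)


def last_tile_diffs(r):
    """Sorted differences between the last tile of ring r (r >= 2)
    and its six neighbours."""
    t = ring_last(r)
    neighbours = [
        ring_last(r + 1),      # outward, end of the next ring
        ring_last(r + 1) - 1,  # outward, one step clockwise
        t - 1,                 # previous tile in the same ring
        ring_last(r - 1),      # inward, end of the previous ring
        ring_first(r - 1),     # inward, start of the previous ring
        ring_first(r),         # next tile in the same ring (wraps)
    ]
    return sorted(abs(n - t) for n in neighbours)


print(first_tile_diffs(2))  # [1, 6, 11, 12, 13, 29]  (tile 8)
print(last_tile_diffs(2))   # [1, 11, 12, 17, 17, 18] (tile 19)
```

In general the first tile of ring r has differences 1, 6r-6, 6r, 6r-1, 6r+1 and 12r+5, and the last tile has 1, 6r, 6r+6, 6r-1, 6r+5 and 12r-7. Since 1 and multiples of 6 are never prime, only three candidates per tile need a primality test.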

What I liked most about this problem was that the brute force solution forced me to find a way to represent a hexagonal tiling in memory. I followed the “cube coordinates” method from https://www.redblobgames.com/grids/hexagons/#coordinates-cube. This method aligns three axes with the grid and gives each hex three coordinates that together identify its location. That website describes plenty of other coordinate systems I could have used, but I chose this one because it seemed the best fit for working in rings.
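For illustration, here is a minimal sketch of what a cube-coordinate brute force might look like. It is not the original brute force code, which is not included here. The direction ordering, the spiral walk, and the names DIRS, add, build_grid, is_prime, pd and tiles_with_pd3 are all assumptions chosen for this example.

```python
# Neighbour offsets in cube coordinates (q, r, s) with q + r + s == 0,
# listed anticlockwise starting from the tile directly above.
DIRS = [(0, -1, 1), (-1, 0, 1), (-1, 1, 0), (0, 1, -1), (1, 0, -1), (1, -1, 0)]


def add(a, b):
    """Component-wise sum of two cube coordinates."""
    return (a[0] + b[0], a[1] + b[1], a[2] + b[2])


def build_grid(rings):
    """Map cube coordinate -> tile number for rings 0..rings,
    numbering each ring anticlockwise from its top tile."""
    grid = {(0, 0, 0): 1}
    number = 2
    for k in range(1, rings + 1):
        pos = (0, -k, k)  # top tile of ring k
        for side in range(6):
            step = DIRS[(side + 2) % 6]  # direction of travel along this side
            for _ in range(k):
                grid[pos] = number
                number += 1
                pos = add(pos, step)
    return grid


def is_prime(n):
    """Trial division; adequate for the small differences seen here."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def pd(grid, pos):
    """Count the prime differences between the tile at pos and its neighbours."""
    here = grid[pos]
    return sum(is_prime(abs(grid[add(pos, d)] - here)) for d in DIRS)


def tiles_with_pd3(rings):
    """Tile numbers in rings 0..rings whose prime-difference count is 3."""
    grid = build_grid(rings + 1)  # one extra ring so every neighbour exists
    return sorted(
        n for pos, n in grid.items()
        if max(map(abs, pos)) <= rings and pd(grid, pos) == 3
    )


print(tiles_with_pd3(9))  # [1, 2, 8, 19, 20, 37, 61, 128, 217, 271]
```

The ring of a hex is the largest absolute value among its three coordinates, which is what makes this representation convenient for ring-by-ring work. Apart from tile 1 at the centre, every tile in the output is the first or last tile of its ring.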

Code

def memoize(fn):
    _cache = {}
    def wrap(n):
        if n in _cache:
            return _cache[n]
        ret = fn(n)
        _cache[n] = ret
        return ret
    return wrap

@memoize
def isprime(n):
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    for p in range(3, int(n**0.5) + 1, 2):
        if n % p == 0:
            return False
    return True

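def solve(n):
    # The loop below starts at the last tile of ring 2, so the three
    # smaller answers are returned directly.
    if n == 1:
        return 1
    if n == 2:
        return 2
    if n == 3:
        return 8
    count = 3
    r = 2
    while True:
        # Last tile of ring r: the only differences that can be prime
        # are 6r - 1, 12r - 7 and 6r + 5.
        if all((isprime(6*r-1), isprime(12*r-7), isprime(6*(r+1)-1))):
            count += 1
            if count == n:
                return 3 * r * (r + 1) + 1

        # First tile of ring r + 1: the only differences that can be prime
        # are 6(r + 1) + 1, 6(r + 1) - 1 and 12(r + 1) + 5.
        if all((isprime(6*(r+1)+1), isprime(6*(r+1)-1), isprime(12*r+17))):
            count += 1
            if count == n:
                return 3 * r * (r + 1) + 2

        r += 1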

if __name__ == '__main__':
    import sys
    n = int(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve

_test_solve = (
        (1, 1),
        (2, 2),
        (3, 8),
        (4, 19),
        (5, 20),
        (6, 37),
        (7, 61),
        (8, 128),
        (9, 217),
        (10, 271),
        (20, 5677),
        (30, 26227),
        (40, 75367),
        (50, 169220),
        (100, 1790270),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)