Problem 196

completed March 13, 2021

Comments

Another late night working on a problem. Fridays I always tell myself that I’ll go to bed on time, only to get deep into one of these problems! A part of me is almost sad that we’re nearing the tail end of this pandemic because it means I’ll have less time for math.

This problem wasn’t too complex. Just search through the rows starting two before the target row and ending two after it. Find the primes and recurse through their neighboring primes. If a group of neighboring primes has 3 or more members, add the members that lie in the target row to the sum.

Other than figuring out the proper data structure for the triangle, the toughest part of this problem was determining the primality of these large numbers. I ended up using a sieve that only covers the small range of numbers that is required.

Runs in about 24 seconds with PyPy3. I think I could get it faster by figuring out a more suitable data structure for storing the primes and the triangle. But 24 seconds came in under the recommended 1 minute, so I just let it be. Also, it’s 2:30am and I’m tired.

Code

def memoize(fn):
    """Return a caching wrapper around ``fn``.

    Results are stored in a dictionary keyed by the tuple of positional
    arguments, so every argument must be hashable.  The cache is unbounded
    and lives as long as the returned wrapper.

    Args:
        fn: The function to wrap.  It is called with positional arguments
            only.

    Returns:
        A function that calls ``fn`` at most once per distinct argument
        tuple and otherwise returns the stored result.  The wrapper keeps
        ``fn``'s name and docstring.
    """
    import functools  # local import: keeps the decorator self-contained

    _cache = {}

    @functools.wraps(fn)
    def wrap(*args):
        if args not in _cache:
            _cache[args] = fn(*args)
        return _cache[args]

    return wrap

def triangle(n):
    """Return the ``n``-th triangular number, ``n * (n + 1) // 2``.

    The value is a closed-form O(1) expression, so it is computed directly
    rather than cached: a cache lookup costs about as much as the
    arithmetic and would grow without bound.

    ``solve`` also evaluates this for small negative arguments (it asks for
    ``triangle(n - 3)``); the formula gives ``triangle(-1) == 0`` and
    ``triangle(-2) == 1``.

    Args:
        n: Row number (an integer).

    Returns:
        The sum ``1 + 2 + ... + n`` for non-negative ``n``, which is also the
        last number in row ``n`` of the number triangle.
    """
    return n * (n + 1) // 2

def solve(n):
    """Sum the primes of row ``n`` that belong to a group of 3+ adjacent primes.

    The positive integers are written in rows, row ``r`` holding ``r``
    numbers and ending at ``triangle(r)``.  Two numbers are neighbours when
    they are adjacent horizontally, vertically or diagonally in the
    left-aligned layout.  Primes in rows ``n - 2`` .. ``n + 2`` are grouped
    into connected clusters of neighbouring primes; every prime of row ``n``
    whose cluster has at least three members contributes to the result.

    Args:
        n: Row number (an integer, rows are numbered from 1).

    Returns:
        The sum of the qualifying primes in row ``n``; 0 when there are none
        or when ``n < 1``.
    """
    if n < 1:
        # For n < 1 the scanned range below is empty, so the sum is 0.
        return 0

    # Only rows n - 2 .. n + 2 can hold a member of a cluster that touches
    # row n closely enough to matter.
    min_neighbor = triangle(n - 3) + 1  # first number of row n - 2
    max_neighbor = triangle(n + 2)      # last number of row n + 2

    # Segmented sieve: find the primes up to sqrt(limit) with a small sieve
    # and use each of them to cross out its multiples inside
    # [min_neighbor, max_neighbor].  One byte per number in the window keeps
    # memory far below a dict holding every composite.
    top = int(triangle(n + 3) ** 0.5)
    small_sieve = [False, False] + [True] * (top - 1)
    composite = bytearray(max_neighbor - min_neighbor + 1)
    for p, p_is_prime in enumerate(small_sieve):
        if not p_is_prime:
            continue
        for multiple in range(p * p, top + 1, p):
            small_sieve[multiple] = False
        # First multiple of p inside the window; never below p * p, so p
        # itself is not crossed out.
        first = max(p * p, (min_neighbor + p - 1) // p * p)
        for multiple in range(first, max_neighbor + 1, p):
            composite[multiple - min_neighbor] = 1

    def is_prime(k):
        """Return True when ``k`` is a prime inside the sieved window."""
        return (
            k > 1
            and min_neighbor <= k <= max_neighbor
            and not composite[k - min_neighbor]
        )

    def row_of(k):
        """Return the row containing ``k`` (``k`` must be in row n - 2 or later)."""
        row = n - 2
        while k > triangle(row):
            row += 1
        return row

    def neighbors_of(k):
        """Yield the (up to eight) numbers adjacent to ``k`` in the triangle."""
        row = row_of(k)
        row_first = triangle(row - 1) + 1
        row_last = triangle(row)
        index = k - triangle(row - 1)      # 1-based position inside the row
        above = triangle(row - 2) + index  # same position, one row up
        below = row_last + index           # same position, one row down

        if k > row_first:
            # Not at the left edge: the whole left column exists.
            yield above - 1
            yield k - 1
            yield below - 1
        if k < row_last:
            # Not the last number: the row above (one shorter) reaches here.
            yield above
            yield k + 1
        if k < row_last - 1:
            yield above + 1
        # The row below is one longer, so these two always exist.
        yield below
        yield below + 1

    row_start = triangle(n) - n + 1
    row_end = triangle(n)

    visited = set()
    total = 0

    for k in range(min_neighbor, max_neighbor + 1):
        if k in visited or not is_prime(k):
            continue

        # Flood-fill the cluster of neighbouring primes that contains k.
        visited.add(k)
        cluster = [k]
        pending = [k]
        while pending:
            current = pending.pop()
            for candidate in neighbors_of(current):
                if candidate not in visited and is_prime(candidate):
                    visited.add(candidate)
                    cluster.append(candidate)
                    pending.append(candidate)

        if len(cluster) > 2:
            total += sum(m for m in cluster if row_start <= m <= row_end)

    return total

if __name__ == '__main__':
    # Command-line entry point: print solve(n) + solve(m) for the two row
    # numbers given as arguments.
    import sys

    if len(sys.argv) < 3:
        sys.exit('usage: python problem.py N M')

    # NOTE(security): eval() runs arbitrary Python taken from the command
    # line.  It lets arguments such as 10**5 work, but it must only ever be
    # fed trusted input; switch to int() if expressions are not needed.
    n = eval(sys.argv[1])
    m = eval(sys.argv[2])
    print(solve(n) + solve(m))

Tests

import pytest
from problem import solve

# Table of (n, expect) pairs fed to test_solve: each row number n is paired
# with the value solve(n) is expected to return.
_test_solve = (
        (1, 0),
        (2, 5),
        (3, 5),
        (4, 7),
        (5, 24),
        (6, 36),
        (7, 23),
        (8, 60),
        (9, 37),
        (10, 47),
        (11, 120),
        (12, 144),
        (13, 172),
        (14, 301),
        (15, 0),
        (16, 0),
        (17, 300),
        (18, 330),
        (19, 533),
        (20, 780),
        (10**2, 9938),
        (10**3, 3500211),
        (10**4, 950007619),
        (10**5, 549999566882),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that solve(n) returns the expected sum for row n."""
    actual = solve(n)
    assert actual == expect