Problem 329

completed January 1, 2021

Comments

First problem of the year! So happy to have made it safely through 2020 and into a new year. These problems have really been a wonderful distraction from the stress the coronavirus pandemic has brought to my life. Just about every evening I’ll snuggle up in bed and grab a problem to work on, diving deep into learning something new. It’s relaxing, and I feel I’ve really learned a lot from it this year. I hope to continue working on them through 2021 and beyond.

As for this specific problem, the code could really use some DRY’ing up (DRY: don’t repeat yourself), but it works. I made one code decision that wasn’t as performant but really improved readability: I added a probability class to keep track of the number of found and total words. I considered going one step further and implementing __add__ so I could just use +=. I feel like that would probably look a lot cleaner, but it would also be more confusing and more error-prone.

It’s a bit slow and could probably stand for some optimization, but all in all, it takes just under a minute to run.

Code

def memoize(fn):
    """Return a version of ``fn`` that caches results by positional arguments.

    Each distinct ``args`` tuple is computed once and stored for the lifetime
    of the returned wrapper; later calls with equal arguments return the
    stored object itself (not a copy).

    Args:
        fn: The callable to wrap. It is only ever called positionally, and its
            arguments must be hashable because they form the cache key.

    Returns:
        A wrapper accepting ``*args`` that carries ``fn``'s name and docstring.
    """
    import functools

    _cache = {}

    # functools.wraps keeps __name__/__doc__ of the decorated function so
    # tracebacks and introspection still point at the real function.
    @functools.wraps(fn)
    def wrap(*args):
        if args not in _cache:
            _cache[args] = fn(*args)
        return _cache[args]

    return wrap

def gcd(a, b):
    """Return the greatest common divisor of ``a`` and ``b`` (Euclid's method).

    The pair is repeatedly replaced by ``(b, a % b)``; once the second value
    is zero the first value is the result. ``gcd(0, 0)`` is ``0``.
    """
    while True:
        if not b:
            return a
        remainder = a % b
        a = b
        b = remainder

class probability(object):
    """Mutable tally of matching outcomes (``found``) out of ``total`` outcomes."""

    def __init__(self, found=0, total=0):
        """Start the tally at the given counts (both default to zero)."""
        self.found, self.total = found, total

    def add(self, other):
        """Fold ``other``'s counts into this tally in place; returns ``None``."""
        self.total = self.total + other.total
        self.found = self.found + other.found

    def ratio(self):
        """Return ``(found, total)`` with their greatest common divisor removed."""
        common = gcd(self.found, self.total)
        reduced = (self.found // common, self.total // common)
        return reduced

def solve(n):
    """Return the chance of hearing the first ``n`` croaks of PPPPNNPPPNPPNPN.

    Model implemented here: the frog starts on one of the squares 1..500,
    each equally likely. On every square it croaks and then hops to a
    neighbouring square; both directions are counted equally, and at either
    end of the row both counted hops go to the single neighbour. On a prime
    square the croak 'P' is weighted 2 and 'N' is weighted 1; on a non-prime
    square the weights are swapped. Every (croak, hop) step therefore has six
    equally weighted outcomes.

    Instead of enumerating every croak sequence, the count of matching
    outcomes is built backwards from the last croak, one pass over the
    squares per croak.

    Args:
        n: How many leading croaks of the 15-croak sequence to match (0..15).

    Returns:
        A ``(numerator, denominator)`` tuple: the probability in lowest terms.

    Raises:
        ValueError: If ``n`` is outside 0..15.
    """
    from fractions import Fraction

    top = 500
    pattern = 0b111100111011010  # PPPPNNPPPNPPNPN with P = 1, first croak is the top bit
    pattern_length = 15

    if not 0 <= n <= pattern_length:
        raise ValueError(
            'n must be between 0 and %d, got %r' % (pattern_length, n))

    # Sieve of Eratosthenes over 0..top; sieve[k] is True exactly for primes.
    sieve = [False, False] + [True] * (top - 1)
    for p in range(2, int(top ** 0.5) + 1):
        if sieve[p]:
            for multiple in range(p * p, top + 1, p):
                sieve[multiple] = False

    # ways[loc] = weighted number of outcomes that start on square `loc` and
    # match every croak from the current step onward. With no croaks left to
    # match there is exactly one (empty) outcome. Index 0 is unused padding so
    # squares can be addressed by their 1-based number.
    ways = [1] * (top + 1)
    for step in reversed(range(n)):
        wants_p = bool((pattern >> (pattern_length - 1 - step)) & 1)
        later = ways
        ways = [0] * (top + 1)
        for loc in range(1, top + 1):
            # At the ends of the row the only neighbour is counted twice.
            left = loc - 1 if loc != 1 else loc + 1
            right = loc + 1 if loc != top else loc - 1
            # The croak a square favours counts double, the other one once.
            weight = 2 if sieve[loc] == wants_p else 1
            ways[loc] = weight * (later[left] + later[right])

    found = sum(ways[1:])
    total = top * 6 ** n  # six weighted outcomes per step, `top` start squares

    reduced = Fraction(found, total)
    return reduced.numerator, reduced.denominator

# Command-line entry point: the single argument is the number of croaks to match.
if __name__ == '__main__':
    import sys
    # NOTE: this used to be eval(sys.argv[1]), which executes arbitrary
    # expressions typed on the command line. The argument is a croak count,
    # so it is parsed strictly as an integer instead.
    n = int(sys.argv[1])
    print(solve(n))

Tests

import pytest
from problem import solve

# Each case pairs a croak count ``n`` with the reduced (numerator, denominator)
# that solve(n) is expected to return.
_test_solve = (
    (1, (119, 300)),
    (2, (173, 1125)),
    (3, (283, 4500)),
    (4, (449, 18000)),
    (5, (157, 10800)),
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """solve(n) must return the expected reduced fraction for each case."""
    actual = solve(n)
    assert actual == expect