Problem 706

completed December 12, 2020

Comments

Slowly but surely it seems I can eventually get to the answer on these problems if I stare at them long enough.

I reduced the problem by working only with the residues 0, 1, and 2 (the code labels residue 0 as 3). Since everything is taken modulo three in the end anyway, there is no point in distinguishing digits any further: each residue just carries a count of how many digits map to it. This has the added bonus of making the cache more effective. With fewer distinct states to store, each lookup is more likely to be a cache hit.
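As a small illustration of that reduction, the sketch below groups the decimal digits by residue and counts how many fall into each group. The helper name residue_counts is made up for this example; it follows the code's convention of labelling residue 0 as 3.

```python
from collections import Counter


def residue_counts(digits):
    """Group digits by residue mod 3, labelling residue 0 as 3."""
    # `d % 3 or 3` maps residues 1 and 2 to themselves and residue 0 to 3.
    counts = Counter(d % 3 or 3 for d in digits)
    return tuple(sorted(counts.items()))


# The leading digit cannot be 0, so it ranges over 1-9.
first_digits = residue_counts(range(1, 10))
# Every later digit ranges over 0-9.
other_digits = residue_counts(range(10))

print(first_digits)  # ((1, 3), (2, 3), (3, 3))
print(other_digits)  # ((1, 3), (2, 3), (3, 4))
```

The only difference between the two tables is the extra digit 0, which lands in the residue labelled 3.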

Ultimately, I ended up hitting Python’s recursion limit, which defaults to 1000. My implementation needs a recursion depth on the order of n, where n in this case is 100,000. I used sys.setrecursionlimit to raise the limit but consistently got segfaults. I looked up what was going wrong and found that there’s a hard limit of 65,000 and some change on my operating system. So, what I did to get around this was “prime” the cache by first finding the answer for n=60,000. This meant that I never needed to hit the recursion limit.
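One way to sidestep the recursion depth problem altogether is to run the same state transitions forwards in a loop. The following is a minimal sketch of that alternative, not the approach used in the solution itself; solve_iterative, FIRST_DIGITS, and OTHER_DIGITS are names made up for the illustration, and the transition is copied from the recursive search.

```python
MODULO = 1_000_000_007

# (residue label, number of digits with that residue); residue 0 is labelled 3.
FIRST_DIGITS = ((1, 3), (2, 3), (3, 3))  # leading digit: 1-9
OTHER_DIGITS = ((1, 3), (2, 3), (3, 4))  # later digits: 0-9


def solve_iterative(n):
    """Count n-digit numbers whose divisible-substring count is 0 mod 3."""
    # Map each (parts, threes) state to the number of prefixes reaching it.
    states = {((1, 0, 0), 0): 1}
    for position in range(n):
        digits = FIRST_DIGITS if position == 0 else OTHER_DIGITS
        next_states = {}
        for (parts, threes), ways in states.items():
            for digit, count in digits:
                # Same transition as the recursive version.
                next_threes = (threes + parts[3 - digit]) % 3
                next_parts = ((parts[3 - digit] + 1) % 3,
                              parts[1 - digit], parts[2 - digit])
                key = (next_parts, next_threes)
                next_states[key] = (
                    next_states.get(key, 0) + ways * count) % MODULO
        states = next_states
    # Keep only the states whose substring count is divisible by 3.
    return sum(
        ways for (_, threes), ways in states.items() if threes == 0
    ) % MODULO


print(solve_iterative(1))  # 6
print(solve_iterative(2))  # 30
```

Because the loop only ever holds the states for one digit position, it needs neither a raised recursion limit nor a primed cache.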

Code

modulo = 1000000007

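# (residue label, number of digits with that residue); residue 0 is labelled 3.
# The leading digit is 1-9; every later digit is 0-9, which adds 0 to residue 0.
_first_digits = ((1, 3), (2, 3), (3, 3))
_digits = ((1, 3), (2, 3), (3, 4))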

def solve(n):
    # Shortcut: for n % 3 == 2 this code uses the closed form 3 * 10**(n-1).
    if n % 3 == 2:
        return (3 * pow(10, n - 1, modulo)) % modulo

    def cache(fn):
        _cache = {}
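        def wrap(*args, **kwargs):
            # Key on keyword arguments too, so a call using the default
            # `digits` never shares an entry with a `digits=_digits` call.
            key = (args, tuple(sorted(kwargs.items())))
            if key in _cache:
                return _cache[key]
            ret = fn(*args, **kwargs)
            _cache[key] = ret
            return ret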
        return wrap

    @cache
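    def search(parts, threes, left, digits=_first_digits):
        # parts:  counts (mod 3) of the prefix sums seen so far, indexed by
        #         residue relative to the current prefix sum.
        # threes: number of substrings divisible by 3 so far, mod 3.
        # left:   number of digits still to place.
        left -= 1
        total = 0

        for digit, count in digits:
            # Add the substrings ending at this digit that are divisible by 3.
            next_threes = (threes + parts[3 - digit]) % 3

            if left == 0:
                total += count * (next_threes == 0)

            else:
                # Rotate the counts by `digit` and record the new prefix sum.
                next_parts = ((parts[3-digit]+1)%3, parts[1-digit], parts[2-digit])
                total += count * search(
                        next_parts, next_threes, left, digits=_digits)

        return total % modulo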

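    parts = (1, 0, 0)
    # Prime the cache only for large n; for small n the priming call itself
    # would exceed the default recursion limit.
    if n > 60000:
        search(parts, 0, 60000)
    return search(parts, 0, n)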

if __name__ == '__main__':
    import sys
    n = eval(sys.argv[1])
    sys.setrecursionlimit(max(1000, n*2))
    print(solve(n))

Tests

import pytest
from problem import solve

_test_solve = (
        (1, 6),
        (4, 3534),
        (7, 3023178),
        (10, 987680024),
        (13, 621045014),
        (16, 383955427),
        (19, 47568507),
        (22, 340128871),
        (25, 124780296),
        (28, 119157965),
        (31, 513633884),
        (34, 445101666),
        (37, 128623947),
        (40, 567174137),
        (43, 821069620),
        (46, 498287604),
        (49, 242640031),
        (100, 131387729),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)
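
The smallest expected values can be cross-checked against a brute-force count. This is only a sketch, assuming the problem's definition that a number is counted when the number of its non-empty substrings divisible by 3 is itself divisible by 3; the function names are made up for the example, and it is only practical for small n.

```python
def count_divisible_substrings(number):
    """Count the non-empty substrings of `number` that are divisible by 3."""
    digits = str(number)
    return sum(
        1
        for start in range(len(digits))
        for end in range(start + 1, len(digits) + 1)
        if int(digits[start:end]) % 3 == 0
    )


def brute_force(n):
    """Count the n-digit numbers whose substring count is divisible by 3."""
    return sum(
        1
        for number in range(10 ** (n - 1), 10 ** n)
        if count_divisible_substrings(number) % 3 == 0
    )


print(brute_force(1))  # 6
print(brute_force(2))  # 30
```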