Slowly but surely it seems I can eventually get to the answer on these problems if I stare at them long enough.
I reduced the problem by working only with the residues 0, 1, and 2. Since everything ends up modulo three anyway, there is no point in distinguishing any other digits. This has the added bonus of making the cache more effective: with fewer distinct arguments to store, each call is more likely to find its result already cached.
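To make the reduction concrete, here is a small sketch that groups the decimal digits by their residue modulo three. The helper name `residue_counts` is made up for illustration. The resulting counts appear to be what the `_first_digits` and `_digits` tables in the listing encode, with the key 3 seemingly standing in for residue 0.

```python
from collections import Counter


def residue_counts(digits):
    """Count how many of the given digits fall into each residue class mod 3."""
    counts = Counter(d % 3 for d in digits)
    return tuple(counts[r] for r in range(3))


# Leading digit: 1-9 (no leading zero). Every other position: 0-9.
leading = residue_counts(range(1, 10))
others = residue_counts(range(10))
print(leading)  # (3, 3, 3)
print(others)   # (4, 3, 3)
```

Only the residue and its multiplicity matter from here on, so each position has three branches instead of ten.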
Ultimately, I ended up hitting Python’s default recursion limit of 1000. My implementation needs a recursion depth of n, where n in this case is 100,000. I used sys.setrecursionlimit to raise the limit but consistently got segfaults. I looked up what was going wrong and found that there’s a hard limit of 65,000 and some change on my operating system. To get around this, I “prime” the cache by first finding the answer for n=60,000. The call for n=100,000 then only has to recurse about 40,000 levels before it reaches cached results, so neither call goes deep enough to crash.
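The priming trick works for any memoized recursion in which each call depends on a slightly smaller one. Below is a minimal sketch using a toy recurrence rather than the real search. The names `chain`, `chain_primed`, and `step` are invented for this example. It primes in many small slices instead of one large one, so it runs under the default recursion limit.

```python
MOD = 1_000_000_007
_memo = {0: 1}  # base case of the toy recurrence


def chain(k):
    """Toy memoized recurrence: chain(k) needs chain(k - 1), so depth grows with k."""
    if k not in _memo:
        _memo[k] = (3 * chain(k - 1) + 1) % MOD
    return _memo[k]


def chain_primed(n, step=200):
    """Fill the memo in slices of `step` so no call recurses much deeper than `step`."""
    for k in range(step, n, step):
        chain(k)  # result discarded; the call only populates _memo
    return chain(n)


print(chain_primed(100_000))
```

Calling `chain(100_000)` directly would raise a RecursionError under the default limit, whereas each primed call stops as soon as it reaches an entry stored by the previous slice.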
modulo = 1000000007

# (residue class, number of digits in that class); 3 stands in for residue 0.
# The leading digit is 1-9, every later digit is 0-9.
_first_digits = ((1, 3), (2, 3), (3, 3))
_digits = ((1, 3), (2, 3), (3, 4))


def solve(n):
    if n % 3 == 2:
        return (3 * 10 ** (n-1)) % modulo

    def cache(fn):
        _cache = {}

        def wrap(*args, **kwargs):
            # Key on the keyword arguments too: search is called both with and
            # without digits=, and those calls must not share an entry.
            key = (args, tuple(sorted(kwargs.items())))
            if key in _cache:
                return _cache[key]
            ret = fn(*args, **kwargs)
            _cache[key] = ret
            return ret
        return wrap

    @cache
    def search(parts, threes, left, digits=_first_digits):
        left -= 1
        total = 0
        for digit, count in digits:
            next_threes = (threes + parts[3 - digit]) % 3
            if left == 0:
                total += count * (next_threes == 0)
            else:
                next_parts = ((parts[3-digit]+1)%3, parts[1-digit], parts[2-digit])
                total += count * search(
                    next_parts, next_threes, left, digits=_digits)
        return total % modulo

    parts = (1, 0, 0)
    if n > 60000:
        # Prime the cache so the real call stops recursing once it reaches
        # these entries. Small n does not need it.
        search(parts, 0, 60000)
    return search(parts, 0, n)


if __name__ == '__main__':
    import sys
    n = eval(sys.argv[1])
    sys.setrecursionlimit(max(1000, n*2))
    print(solve(n))
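Another way around the recursion limit is to drop the recursion and sweep the states forward one digit at a time. This is a swapped-in technique, not what the listing above does. The sketch assumes the recurrence counts n-digit numbers whose number of substrings divisible by three is itself a multiple of three, which is a reading of `search` rather than something stated in the post. The names `count_iterative`, `LEADING`, and `OTHERS` are hypothetical.

```python
from collections import defaultdict

MOD = 1_000_000_007
# Digits per residue class mod 3: index r holds how many digits have residue r.
LEADING = (3, 3, 3)  # digits 1-9
OTHERS = (4, 3, 3)   # digits 0-9


def count_iterative(n):
    # State: (prefix-sum residue counts mod 3, current prefix residue,
    # number of matching substrings mod 3). The empty prefix has residue 0,
    # hence the initial counts (1, 0, 0).
    states = {((1, 0, 0), 0, 0): 1}
    for position in range(n):
        digit_counts = LEADING if position == 0 else OTHERS
        next_states = defaultdict(int)
        for (seen, residue, matches), ways in states.items():
            for d, multiplicity in enumerate(digit_counts):
                r = (residue + d) % 3
                # Each earlier prefix with the same residue closes a substring
                # whose digit sum is divisible by three.
                new_matches = (matches + seen[r]) % 3
                new_seen = tuple(
                    (c + 1) % 3 if i == r else c for i, c in enumerate(seen))
                key = (new_seen, r, new_matches)
                next_states[key] = (next_states[key] + ways * multiplicity) % MOD
        states = next_states
    return sum(w for (_, _, m), w in states.items() if m == 0) % MOD


print(count_iterative(1))  # 6
print(count_iterative(2))  # 30
```

There are at most 243 states per position, so the loop is linear in n and never touches the call stack. The two printed values agree with the first row of the test table and with the `n % 3 == 2` shortcut.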
import pytest
from problem import solve

_test_solve = (
    (1, 6),
    (4, 3534),
    (7, 3023178),
    (10, 987680024),
    (13, 621045014),
    (16, 383955427),
    (19, 47568507),
    (22, 340128871),
    (25, 124780296),
    (28, 119157965),
    (31, 513633884),
    (34, 445101666),
    (37, 128623947),
    (40, 567174137),
    (43, 821069620),
    (46, 498287604),
    (49, 242640031),
    (100, 131387729),
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)
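The smallest rows of a table like this can be cross-checked with a brute-force count. The sketch below assumes the problem is to count n-digit numbers whose number of substrings divisible by three is itself a multiple of three. That is inferred from the code, not stated in the post, and `brute_force` is a made-up name. It is only practical for a handful of digits.

```python
def brute_force(n):
    """Count n-digit numbers whose number of substrings divisible by 3 is a multiple of 3."""
    total = 0
    for number in range(10 ** (n - 1), 10 ** n):
        text = str(number)
        hits = sum(
            1
            for start in range(n)
            for end in range(start + 1, n + 1)
            if int(text[start:end]) % 3 == 0
        )
        if hits % 3 == 0:
            total += 1
    return total


print(brute_force(1))  # 6
print(brute_force(2))  # 30
```

The work grows tenfold with every extra digit, so this is a sanity check for small n rather than a replacement for the cached search.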