The fastest I could get this to run was 2 min 30 sec. That’s frustratingly slow, but not so slow that it feels unreasonable.
My solution was another recursive one. It was the memoization that tripped me
up, because simply adding my @memoize decorator wasn’t enough. Firstly,
sharing one cache across every value of k sent to P(k) didn’t work. The values
had to be cached on a per-k basis. That still produced a solution that was
impossibly slow. I then started experimenting to see which arguments of the
inner p(index, afters, last, used) signature were required for memoization and
which were not. It turns out that not all 4 of them are needed; all that’s
required is the used argument. It also turns out that memoization is only
needed once the one and only character that comes lexicographically after its
left neighbour has been placed. All of these things cut down on the execution
time significantly.
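To illustrate the per-k caching point, here is a minimal sketch (not the solution itself) of how applying a decorator inside a function gives each call its own cache. The names memoize_by_last_arg and make_scaler are hypothetical and exist only for this example.

```python
def memoize_by_last_arg(fn):
    """Cache results keyed only on the final positional argument."""
    cache = {}

    def wrap(*args):
        key = args[-1]
        if key not in cache:
            cache[key] = fn(*args)
        return cache[key]

    wrap.cache = cache  # exposed only so the example can inspect it
    return wrap


def make_scaler(k):
    # The decorator runs on every call to make_scaler,
    # so each k gets a fresh, independent cache.
    @memoize_by_last_arg
    def scale(x):
        return x * k
    return scale


scale2 = make_scaler(2)
scale3 = make_scaler(3)
print(scale2(5), scale3(5))          # 10 15
print(scale2.cache is scale3.cache)  # False
```

Had the two functions shared one cache keyed on x alone, the second call would have returned the stale value 10, which is the kind of clash a single cache across all k would cause.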
The thing I’m most proud of in this code is how I implemented used. I needed
a way to track which letters had been used and which hadn’t. For this I used a
bitmask, which was both super easy to calculate and very easy to use as a
dictionary key. I’m so pleased with it that I’d like to find more uses for
bitmasks as I continue working on these problems.
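As a small standalone sketch of the bitmask idea, assuming letters are numbered 0 to 25: each letter owns one bit, so the whole set of used letters fits in a single int that can be tested, updated, and used as a dictionary key. The helpers is_used and mark_used are hypothetical names for illustration; the solution does the same thing inline.

```python
# One bit per letter: letter 0 -> 0b1, letter 1 -> 0b10, and so on.
MASK = {i: 1 << i for i in range(26)}


def is_used(used, letter):
    """Return True if the bit for `letter` is set in `used`."""
    return used & MASK[letter] != 0


def mark_used(used, letter):
    """Return a new bitmask with the bit for `letter` set."""
    return used | MASK[letter]


used = 0
for letter in (0, 2, 5):
    used = mark_used(used, letter)

print(bin(used))                           # 0b100101
print(is_used(used, 2), is_used(used, 3))  # True False

# A plain int is hashable, so it works directly as a cache key.
cache = {used: "value cached for this set of letters"}
```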
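
def memoize(fn):
    # Results are keyed on the `used` bitmask alone.
    _cache = {}

    def wrap(i, a, l, u):
        # Skip the cache until the single "after" character has been placed.
        if not a:
            return fn(i, a, l, u)
        if u in _cache:
            return _cache[u]
        ret = fn(i, a, l, u)
        _cache[u] = ret
        return ret
    return wrap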
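
# One bit per letter of the alphabet.
mask = {i: 1 << i for i in range(26)}


def solve(n):
    def P(k):
        # Decorating inside P gives every k its own cache.
        @memoize
        def p(index, afters, last, used):
            count = 0
            if index + 1 == k:
                # Final position: count the string only if it contains the "after".
                for l in range(last if afters else n):
                    m = mask[l]
                    if used & m == m:
                        continue
                    count += l > last or afters
            else:
                # Once the "after" is placed, only letters below `last` are allowed.
                for l in range(last if afters else n):
                    m = mask[l]
                    if used & m == m:
                        continue
                    count += p(index + 1, l > last or afters, l, used ^ m)
            return count
        return p(0, 0, n, 0)

    # Start the search near 3n/4 and stop as soon as P(k) begins to fall.
    top = float('-inf')
    for k in range(3*n//4 - 1, n + 1):
        pk = P(k)
        if pk >= top:
            top = pk
        else:
            return top
    return top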


if __name__ == '__main__':
    import sys
    # int() instead of eval(): the argument is only ever an alphabet size.
    n = int(sys.argv[1])
    print(solve(n))

import pytest
from problem import solve

_test_solve = (
    (1, 0),
    (2, 1),
    (3, 4),
    (4, 16),
    (5, 55),
    (6, 165),
    (7, 546),
    (8, 1596),
    (9, 4788),
    (10, 14400),
    (11, 40755),
    (12, 122265),
    (13, 358930),
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)
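
As an independent cross-check on the expected values above (this is not the approach the solution takes), the counts also seem to follow a closed form: choose k of the n letters, then arrange them with exactly one ascent, which can be done in 2^k - k - 1 ways (the Eulerian number for one ascent). A minimal sketch under that assumption, with the hypothetical name solve_closed_form and requiring Python 3.8+ for math.comb:

```python
from math import comb


def solve_closed_form(n):
    """Max over k of C(n, k) * (2**k - k - 1) for an n-letter alphabet."""
    return max(comb(n, k) * (2**k - k - 1) for k in range(1, n + 1))


# Spot-check against a few rows of the test table.
for n, expect in ((1, 0), (3, 4), (5, 55), (13, 358930)):
    assert solve_closed_form(n) == expect
```

It agrees with every row of the test table, so it could serve as a fast oracle when extending the tests to larger n.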