Problem 158

completed January 31, 2021

Comments

The fastest I could get this to run was 2 min 30 sec. That is frustratingly slow, but not so slow that it feels unreasonable.

My solution was another recursive one. The memoization is what tripped me up, because simply adding my @memoize decorator wasn’t enough. First, sharing one cache across every value of k sent to P(k) didn’t work; the values had to be cached separately for each k. That still produced a solution that was impossibly slow. I then experimented to see which arguments in the signature of the inner function p were actually required for memoization. It turns out that not all 4 of them are needed: the used argument alone is enough as the cache key. It also turns out that memoization is only required once the first and only character that comes lexicographically after its left neighbor has been placed, that is, once afters is set. Together these changes cut the execution time significantly.
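
To illustrate why the cache has to be kept per length, here is a toy sketch. It is not the solution itself: count_arrangements and extend are made-up names, and the function only counts ordered selections of k distinct letters out of n, keyed on a used bitmask. Because the key does not include k, a cache that is reused across different k hands back stale values.

```python
def count_arrangements(n, k, cache=None):
    """Count ordered selections of k distinct letters out of n (toy example)."""
    if cache is None:
        cache = {}  # fresh cache for this particular k

    def extend(used):
        # The key is the bitmask alone, so it is only valid for one k.
        if used in cache:
            return cache[used]
        if bin(used).count("1") == k:
            total = 1
        else:
            total = sum(
                extend(used | (1 << l))
                for l in range(n)
                if not used & (1 << l)
            )
        cache[used] = total
        return total

    return extend(0)


# A fresh cache per k gives the expected 4*3 and 4*3*2.
fresh_2 = count_arrangements(4, 2)  # 12
fresh_3 = count_arrangements(4, 3)  # 24

# Sharing one cache across k reuses the k = 2 entry for used == 0.
shared = {}
count_arrangements(4, 2, shared)
stale_3 = count_arrangements(4, 3, shared)  # 12, which is wrong for k = 3
```

In the code below the same effect comes from applying the decorator inside P, so each call starts with an empty cache.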

The thing I’m most proud of in this code is how I implemented used. I needed a way to track which letters had been used and which hadn’t. For this I used a bitmask, which was both easy to compute and easy to use as a dictionary key. I’m pleased enough with it that I’d like to find more uses for bitmasks as I continue working on these problems.
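
A minimal sketch of the bitmask idea, separate from the solution code: bit is a hypothetical helper that maps a lowercase letter to a single bit, which is an assumption made for readability (the solution works with letter indices directly). It shows marking letters as used, testing membership, and using the resulting integer as a dictionary key.

```python
def bit(letter):
    """Single-bit mask for a lowercase letter (hypothetical helper)."""
    return 1 << (ord(letter) - ord("a"))


used = 0
for ch in "cab":
    used |= bit(ch)  # mark the letter as used

has_b = (used & bit("b")) != 0  # True: 'b' was placed
has_d = (used & bit("d")) != 0  # False: 'd' was not

# The mask is a plain int, so it is hashable and order-independent:
# "cab" and "abc" produce the same key.
seen = {used: "a, b and c placed"}
same_key = (bit("a") | bit("b") | bit("c")) in seen  # True
```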

Code

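def memoize(fn):
    # Each decorated function gets its own cache, keyed on `used` (u) alone.
    _cache = {}
    def wrap(i, a, l, u):
        # Bypass the cache while `afters` (a) is still falsy, i.e. before
        # the one allowed ascent has been placed.
        if not a:
            return fn(i, a, l, u)
        if u in _cache:
            return _cache[u]
        ret = fn(i, a, l, u)
        _cache[u] = ret
        return ret
    return wrap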

# mask[i] is the single-bit mask for letter index i.
mask = {i: 1 << i for i in range(26)}

def solve(n):

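    def P(k):
        # Count the strings of length k. The decorator is applied inside P,
        # so every k gets a fresh cache.
        @memoize
        def p(index, afters, last, used):
            # index: position being filled; afters: truthy once the one
            # allowed ascent has been placed; last: previous letter (n as a
            # sentinel at the start); used: bitmask of letters placed so far.
            count = 0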
            if index + 1 == k:
                for l in range(last if afters else n):
                    m = mask[l]
                    if used & m == m:
                        continue
                    count += l > last or afters
            else:
                for l in range(last if afters else n):
                    m = mask[l]
                    if used & m == m:
                        continue
                    count += p(index + 1, l > last or afters, l, used ^ m)
            return count
        return p(0, 0, n, 0)

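    # Scan k upward from about 3n/4 and stop at the first decrease. This
    # relies on the counts rising to a single peak at or after that k.
    top = float('-inf')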
    for k in range(3*n//4 - 1, n + 1):
        pk = P(k)
        if pk >= top:
            top = pk
        else:
            return top
    return top

if __name__ == '__main__':
    import sys
    n = int(sys.argv[1])
    print(solve(n))
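
As an independent cross-check, and not the approach used above, the same numbers can be obtained from a closed form. The k letters that appear can be chosen in C(n, k) ways, and a fixed set of k distinct letters has 2^k - k - 1 orderings with exactly one ascent (the Eulerian number A(k, 1)). The name closed_form is hypothetical, and the sketch assumes solve(n) is meant to return the largest such count over all lengths k. math.comb needs Python 3.8 or later.

```python
from math import comb


def closed_form(n):
    """Largest count, over all lengths k, of strings with exactly one ascent."""
    return max(comb(n, k) * (2**k - k - 1) for k in range(1, n + 1))
```

This agrees with the expected values in the test table below, for example closed_form(13) gives 358930, so it may be useful for checking results at sizes where the recursive version gets slow.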

Tests

import pytest
from problem import solve

_test_solve = (
        (1, 0),
        (2, 1),
        (3, 4),
        (4, 16),
        (5, 55),
        (6, 165),
        (7, 546),
        (8, 1596),
        (9, 4788),
        (10, 14400),
        (11, 40755),
        (12, 122265),
        (13, 358930),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)
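
The expected values above can be double-checked for small n with a brute-force count. This sketch assumes the problem asks for strings of distinct letters in which exactly one character comes lexicographically after its left neighbor, maximized over the string length; brute_force is a hypothetical name and not part of the solution module. It is only practical for small n because it enumerates every arrangement.

```python
from itertools import permutations


def brute_force(n):
    """Max over k of the number of length-k strings with exactly one ascent."""
    best = 0
    for k in range(1, n + 1):
        count = 0
        for word in permutations(range(n), k):
            # An ascent is a position whose letter is larger than the one before it.
            ascents = sum(1 for a, b in zip(word, word[1:]) if b > a)
            if ascents == 1:
                count += 1
        best = max(best, count)
    return best
```

For n up to 6 this reproduces the first rows of _test_solve (0, 1, 4, 16, 55, 165).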