Problem 166

completed February 1, 2021

Comments

I’ve started including tests for any wrong answers I get as well. That makes the test suite more complete, and it lets me keep track of my mistakes and check against them.

My solution to this problem was just brute force. No fancy equations here, though those would have been really nice.
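
For reference, the most literal kind of brute force looks like the sketch below: enumerate every grid and test every sum. This is not the solution in the Code section; the function name `count_exhaustive` is made up for this example, and it assumes n >= 2. It is only practical for n = 2 (10^4 grids), but that is enough to serve as an independent check on small cases.

```python
from itertools import product

def count_exhaustive(n):
    """Count n x n digit grids whose rows, columns and both diagonals share one sum."""
    total = 0
    for cells in product(range(10), repeat=n * n):
        rows = [sum(cells[r * n:(r + 1) * n]) for r in range(n)]
        cols = [sum(cells[c::n]) for c in range(n)]
        diag = sum(cells[::n + 1])                  # top-left to bottom-right
        anti = sum(cells[n - 1:n * n - 1:n - 1])    # top-right to bottom-left
        # every line sum must be identical
        if len(set(rows + cols + [diag, anti])) == 1:
            total += 1
    return total

print(count_exhaustive(2))
```

For n = 2 the constraints force all four cells to be equal, so the count agrees with the (2, 10) case in the tests.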

I think if I had to do this again, I wouldn’t use recursion. I’d just write one big nested for loop. I checked the forums and at least one person did it that way. I tend to shy away from that approach because it’s much harder to write tests for.

What I really needed to get the run time down was access, inside the search function, to the reference sum (the sum of the first row) that I compute in the check function. With it I could have restricted the range of digits tried in each recursive call, so I wouldn’t waste as much time on grids that could never produce a viable solution.
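
As a rough illustration of that idea, here is a minimal nested-loop sketch for the 3x3 case, where it runs instantly. It is neither the forum solution nor the code below; the function name `count_3x3` and the cell letters a to i (row by row) are made up for this example. Once the first row fixes the reference sum, several cells are forced outright instead of searched, which is a stronger form of restricting the range.

```python
def count_3x3():
    """Count 3x3 digit grids with equal row, column and diagonal sums."""
    digits = range(10)
    count = 0
    for a in digits:
        for b in digits:
            for c in digits:
                ref = a + b + c          # reference sum, fixed by row 1
                for d in digits:
                    g = ref - a - d      # column 1 forces the bottom-left cell
                    if not 0 <= g <= 9:
                        continue
                    for e in digits:
                        f = ref - d - e  # row 2 forces its last cell
                        h = ref - b - e  # column 2 forces the bottom-middle cell
                        i = ref - c - f  # column 3 forces the bottom-right cell
                        if not all(0 <= x <= 9 for x in (f, h, i)):
                            continue
                        # remaining constraints: row 3 and both diagonals
                        if g + h + i == ref and a + e + i == ref and c + e + g == ref:
                            count += 1
    return count

print(count_3x3())
```

The result agrees with the (3, 170) case in the tests. The same pattern extends to 4x4 with more free cells and more forced ones, at the cost of a deeper loop nest.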

Runs in just over 6 minutes with PyPy3.

Code

def solve(n):

    nsqr = n ** 2
    nsqrn = nsqr - n
    grid = [0] * nsqr

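    def check(index):
        # Returns True if the grid is still viable, False if a sum is too
        # small or mismatched (try the next digit), and None if a sum is
        # already too large (larger digits cannot help).
        if index < n - 1:
            return True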
        g = grid[:index+1]

        # reference sum
        reference = sum(g[:n])

        # row
        if (index + 1) % n == 0:
            seqsum = sum(g[index-n+1:index+1])
            if seqsum > reference:
                return None
            if seqsum < reference:
                return False

        # first column and first diagonal both end at cell nsqrn, so their
        # partial sums above that cell must already agree
        if index == nsqrn - n + 1:
            if sum(g[:nsqrn-n+1:n]) != sum(g[n-1:nsqrn+1:n-1]):
                return False

        # first diagonal
        if index == nsqrn and index > 1:
            seqsum = sum(g[n-1:nsqrn+1:n-1])
            if seqsum > reference:
                return None
            if seqsum < reference:
                return False

        # column
        if index >= nsqrn:
            seqsum = sum(g[index%n:index+1:n])
            if seqsum > reference:
                return None
            if seqsum < reference:
                return False

        # last diagonal
        if index == nsqr - 1:
            seqsum = sum(g[::n+1])
            if seqsum > reference:
                return None
            if seqsum < reference:
                return False

        return True

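    def search(index):
        count = 0
        # Replacing every digit d with 9 - d maps valid grids to valid
        # grids, so cell 1 only tries 0-4 and the count is doubled below.
        end = 5 if index == 1 else 10
        for digit in range(end):
            grid[index] = digit
            result = check(index)
            if index == nsqr - 1 and result:
                return 1
            elif result:
                count += search(index + 1)
            elif result is None:
                # a sum is already too large; bigger digits only make it worse
                return count
        if index == 1:
            count *= 2
        return count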

    return search(0)

if __name__ == '__main__':
    import sys
    n = int(sys.argv[1])
    print(solve(n))

Tests

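from functools import lru_cache

import pytest
from problem import solve

# solve(4) takes minutes, so compute each n at most once per test session
cached_solve = lru_cache(maxsize=None)(solve)

_test_solve = (
        (2, 10),
        (3, 170),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == cached_solve(n)

# wrong answers I got for n = 4, kept as regression checks
_test_not_solve = (
        (4, 1000),
        (4, 10000),
        (4, 2549976),
        (4, 12003319),
        (4, 12003309),
        (4, 12003320),
)

@pytest.mark.parametrize('n,expect', _test_not_solve)
def test_not_solve(n, expect):
    assert expect != cached_solve(n)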