I’ve started including tests for any wrong answers I get as well. It makes the test suite more complete and lets me keep track of my mistakes and check against them.
This problem was just a brute-force solution. No fancy equations here, though that would have been really nice.
I think if I had to do this again, I wouldn’t use recursion. I’d just write a big old nested for loop. I checked the forums and at least one person did it that way. I tend to shy away from that method because it’s much harder to write tests for. What I really needed to get this down to a more reasonable running time was access, down in the search method, to the reference variable that I compute in the check method. That way I could have restricted the range of digits tried on each recursive call to search, and I wouldn’t have wasted as much time on grids that could never have produced a viable solution.
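The range restriction described above can be sketched as a small helper. This is only an illustration, not part of the original solution: the name digit_range and its parameters are made up here, and it assumes the target sum (reference) is already known when the next cell of a row, column or diagonal is chosen.

```python
def digit_range(partial_sum, cells_left, reference):
    """Return the digits worth trying for the next cell of a line.

    partial_sum: sum of the cells already filled in this row/column/diagonal
    cells_left:  unfilled cells in the line, including the one being chosen
    reference:   the target sum every line has to reach
    (All three names are hypothetical; they do not appear in the solution.)
    """
    remaining = reference - partial_sum
    # Even if every later cell holds a 9, this cell must cover the rest.
    lowest = max(0, remaining - 9 * (cells_left - 1))
    # This cell on its own must not overshoot the target.
    highest = min(9, remaining)
    return range(lowest, highest + 1)
```

For the last cell of a line (cells_left == 1) the range collapses to at most one digit, so the search would not need to loop over ten candidates there at all.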
Runs in just over 6 minutes with PyPy3.
def solve(n):
    """Count n x n digit grids whose rows, columns and diagonals share one sum."""
    nsqr = n ** 2
    nsqrn = nsqr - n  # index of the first cell in the last row
    grid = [0] * nsqr

    def check(index):
        # Returns True  -> grid is still viable up to index
        #         False -> this digit fails, a larger one might still work
        #         None  -> a sum already overshoots, larger digits fail too
        if index < n - 1:
            return True
        g = grid[:index+1]
        # reference sum
        reference = sum(g[:n])
        # row
        if (index + 1) % n == 0:
            seqsum = sum(g[index-n+1:index+1])
            if seqsum > reference:
                return None
            if seqsum < reference:
                return False
        # column and diagonal comparison
        if index == nsqrn - n + 1:
            if sum(g[:nsqrn-n+1:n]) != sum(g[n-1:nsqrn+1:n-1]):
                return False
        # first diagonal
        if index == nsqrn and index > 1:
            seqsum = sum(g[n-1:nsqrn+1:n-1])
            if seqsum > reference:
                return None
            if seqsum < reference:
                return False
        # column
        if index >= nsqrn:
            seqsum = sum(g[index%n:index+1:n])
            if seqsum > reference:
                return None
            if seqsum < reference:
                return False
        # last diagonal
        if index == nsqr - 1:
            seqsum = sum(g[::n+1])
            if seqsum > reference:
                return None
            if seqsum < reference:
                return False
        return True

    def search(index):
        count = 0
        # Swapping every digit d for 9 - d maps solutions to solutions, so
        # only 0-4 is tried in the second cell and the count is doubled.
        end = 5 if index == 1 else 10
        for digit in range(end):
            grid[index] = digit
            result = check(index)
            if index == nsqr - 1 and result:
                return 1
            elif result:
                count += search(index + 1)
            elif result is None:
                return count
        if index == 1:
            count *= 2
        return count

    return search(0)
if __name__ == '__main__':
    import sys
    # int() instead of eval(): the argument is just the grid size
    n = int(sys.argv[1])
    print(solve(n))
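For comparison with the recursive version, here is a rough sketch of the nested-loop alternative mentioned earlier. It is not the forum solution and not the code above: it is an assumption-laden illustration for the 3 x 3 case only (so it finishes quickly), and the name count_3x3 and the cell letters a to i are invented for this example. Five cells are looped over; the other four are derived from the target sum, which is the same idea as restricting the range using the reference value.

```python
def count_3x3():
    """Count 3x3 digit grids whose rows, columns and diagonals share one sum.

    Cell layout (hypothetical names):
        a b c
        d e f
        g h i
    """
    digits = range(10)
    count = 0
    for a in digits:
        for b in digits:
            for c in digits:
                target = a + b + c  # plays the role of `reference`
                for d in digits:
                    for e in digits:
                        f = target - d - e  # forced by the second row
                        g = target - a - d  # forced by the first column
                        h = target - b - e  # forced by the second column
                        i = target - c - f  # forced by the third column
                        if not all(0 <= x <= 9 for x in (f, g, h, i)):
                            continue
                        # Remaining constraints: last row and both diagonals.
                        if (g + h + i == target
                                and a + e + i == target
                                and c + e + g == target):
                            count += 1
    return count
```

Calling count_3x3() should agree with the (3, 170) case in the tests. A 4 x 4 version would follow the same pattern with more loops and more derived cells.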
import pytest

from problem import solve

_test_solve = (
    (2, 10),
    (3, 170),
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)


_test_not_solve = (
    (4, 1000),
    (4, 10000),
    (4, 2549976),
    (4, 12003319),
    (4, 12003309),
    (4, 12003320),
)


@pytest.mark.parametrize('n,expect', _test_not_solve)
def test_not_solve(n, expect):
    assert expect != solve(n)