It’s unfortunate to have 4 nested for loops. I’m sure I could have been a bit
smarter in piecing together all of the possible triangles. But since this only
needs to run through 10**8 possibilities, I figured this was okay. The thing
that brought it down to a manageable running time was caching the values of
`num_below`. Once I did that, the program was able to run in about 35 sec
with pypy3.
_num_below = {}
def num_below(a, b):
key = (a, b)
if key in _num_below:
return _num_below[key]
total = 0
for i in range(a):
div, mod = divmod(b * (a - i), a)
total += div
if mod == 0:
total -= 1
_num_below[key] = total
return total
def solve(n):
    """Count quadruples (a, b, c, d), each in 1..n, whose total is a perfect square.

    For every ordered quadruple the total is
    ``1 + num_below(a, b) + num_below(b, c) + num_below(c, d) + num_below(d, a)``.

    Args:
        n: Inclusive upper bound for each of a, b, c and d.

    Returns:
        The number of quadruples whose total is a perfect square. This is 0
        when n < 1, because there are then no quadruples to examine.
    """
    sides = range(1, n + 1)

    # below[x][y] == num_below(x, y). Row and column 0 are padding so the
    # table can be indexed directly with the 1-based side lengths. Building
    # it up front costs n**2 calls instead of four calls per quadruple.
    below = [[0] * (n + 1)]
    for x in sides:
        below.append([0] + [num_below(x, y) for y in sides])

    # No total can exceed 1 + 4 * (largest table entry), so the perfect
    # squares up to that bound can be enumerated once. Membership in this set
    # is an exact integer test, avoiding a float square root.
    limit = 1 + 4 * max((value for row in below for value in row), default=0)
    squares = set()
    root = 1
    while root * root <= limit:
        squares.add(root * root)
        root += 1

    count = 0
    for a in sides:
        row_a = below[a]
        # closing[d] == num_below(d, a): the term that wraps back around to a.
        closing = [row[a] for row in below]
        for b in sides:
            row_b = below[b]
            partial_ab = 1 + row_a[b]  # invariant across the c and d loops
            for c in sides:
                row_c = below[c]
                partial_abc = partial_ab + row_b[c]  # invariant across d
                for d in sides:
                    if partial_abc + row_c[d] + closing[d] in squares:
                        count += 1
    return count
if __name__ == '__main__':
    import sys

    # The first command-line argument is the upper bound passed to solve().
    print(solve(int(sys.argv[1])))
import pytest
from problem import solve
_test_solve = (
(1, 1),
(2, 5),
(3, 13),
(4, 42),
(5, 88),
(6, 156),
(7, 220),
(8, 376),
(9, 566),
(10, 862),
)
@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    """Check that solve(n) matches the recorded expected value for each case."""
    actual = solve(n)
    assert actual == expect