I started by implementing a fairly brute-force method of calculating the answer. Since that would mean going through 10**16 iterations, it was never going to be feasible, but it let me write some tests and do some analysis.
I printed out the values of f(n, k), with n running across the columns and one row per k, and noticed an easy pattern:
 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
       1  1  2  2  3  3  4  4  5  5  6  6  7  7  8  8  9  9
                1  1  1  2  2  2  3  3  3  4  4  4  5  5  5
                            1  1  1  1  2  2  2  2  3  3  3
                                           1  1  1  1  1  2
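Read off the printed values, the pattern seems to be this: row k starts in column k(k+1)/2 (the k-th triangular number), and its values are 1 repeated k times, then 2 repeated k times, and so on. The sketch below writes that down directly. It is an assumption taken from the table, not the original definition of f(n, k), and the names `pattern_value` and `brute_total` are made up for illustration.

```python
def pattern_value(n, k):
    """Entry in column n of row k, as read off the printed table (an assumption)."""
    start = k * (k + 1) // 2  # row k begins at the k-th triangular number
    if n < start:
        return 0  # blank cell: row k has no entry in this column
    return (n - start) // k + 1  # each value is repeated k times


def brute_total(n):
    """Sum every table entry in columns 1..n by direct enumeration."""
    total = 0
    for col in range(1, n + 1):
        k = 1
        while k * (k + 1) // 2 <= col:
            total += pattern_value(col, k)
            k += 1
    return total


# Reprint the first five rows for n = 1..20 (0 marks a blank cell).
for k in range(1, 6):
    print([pattern_value(n, k) for n in range(1, 21)])
```

Summing the columns this way is only practical for small n, but it gives an independent check on the expected values used in the tests.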
The solution was just to put this pattern into code, and since I’d already written some good tests, that work was pretty easy.
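# problem.py
modulo = 1000000007


def solve(num):
    # addr is the row number k; tri is the k-th triangular number,
    # which is the column where row k starts.
    total = tri = addr = 0
    while True:
        addr += 1
        tri += addr
        if tri > num:
            break
        # Row k has (num - tri + 1) entries up to column num:
        # div full groups of addr equal values, plus mod leftover entries.
        div, mod = divmod(num - tri + 1, addr)
        total += div * (div + 1) // 2 * addr + mod * (div + 1)
    return total % modulo


if __name__ == '__main__':
    import sys
    # eval lets the argument be an expression such as 10**16;
    # only use it with trusted input.
    n = eval(sys.argv[1])
    print(solve(n))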
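The expression added to `total` inside the loop is a closed form for the sum of the first m entries of one row. As a rough check, the sketch below compares that closed form with a direct summation. The names `row_sum_direct` and `row_sum_closed` are hypothetical, and m and k stand for `num - tri + 1` and `addr` in `solve`.

```python
def row_sum_direct(m, k):
    """Sum the first m entries of a row whose values each repeat k times."""
    return sum(j // k + 1 for j in range(m))


def row_sum_closed(m, k):
    """Same sum in constant time: q full groups of k entries, then r leftovers."""
    q, r = divmod(m, k)
    # Full groups contribute k * (1 + 2 + ... + q); each leftover equals q + 1.
    return k * q * (q + 1) // 2 + r * (q + 1)


# Compare the two on a small grid of row widths and lengths.
for k in range(1, 8):
    for m in range(0, 60):
        assert row_sum_direct(m, k) == row_sum_closed(m, k)
```

Because row k only exists once k(k+1)/2 is at most n, the loop in `solve` runs roughly sqrt(2n) times, which is about 1.4 * 10**8 iterations for n = 10**16 rather than 10**16.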
# Tests, run with pytest; they import solve from problem.py.
import pytest

from problem import solve

# (n, expected result) pairs; the small cases were checked against the brute-force version.
_test_solve = (
    (1, 1),
    (2, 3),
    (3, 7),
    (4, 12),
    (5, 19),
    (6, 28),
    (7, 39),
    (8, 51),
    (9, 66),
    (10**1, 83),
    (10**2, 12656),
    (10**3, 1817711),
    (10**4, 238998218),
    (10**5, 651877630),
)


@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)