I’ll start off by saying that this code runs in 4 minutes and 10 seconds with
PyPy3 (v3.7.9), and I’m disappointed I didn’t get it down to under a minute.
The slowest parts of the code are the three lines where I’m getting the modular
inverse of a number. I tried this with two different algorithms: first the
Euclidean algorithm, then ultimately settling on what I have below, which uses
Euler’s theorem. I’m sure the optimization lies in how to calculate each
successive value that I add to count with each loop. Since all the values are
based on either factorials or powers, they do not all need to be recalculated
each time.
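For reference, here is a minimal sketch of the two modular-inverse approaches mentioned above. The names `inv_fermat` and `inv_euclid` are made up for this illustration, and the Euclidean version is only one way to write it, not necessarily the variant I timed. Both assume the modulus is prime and the argument is not a multiple of it.

```python
MOD = 1_000_000_007  # prime modulus used throughout this post


def inv_fermat(a):
    # Euler's theorem (Fermat's little theorem for a prime modulus):
    # a^(MOD-1) = 1 (mod MOD), so a^(MOD-2) is the inverse of a.
    return pow(a, MOD - 2, MOD)


def inv_euclid(a):
    # Iterative extended Euclidean algorithm.
    # Invariants: old_s * a = old_r (mod MOD) and s * a = r (mod MOD).
    old_r, r = a % MOD, MOD
    old_s, s = 1, 0
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    if old_r != 1:
        raise ValueError("a has no inverse modulo MOD")
    return old_s % MOD


if __name__ == "__main__":
    for a in (2, 3, 10**8):
        assert inv_fermat(a) == inv_euclid(a)
        assert a * inv_fermat(a) % MOD == 1
```

Python 3.8 and later also accept `pow(a, -1, MOD)`, but that form is not available on a 3.7-level interpreter such as the PyPy version used here.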
The important thing to notice is that in all examples given, \(n\) is a
multiple of \(k\). This matters because for each working matrix, the sums of
the columns repeat with a period of \(k\). This makes sense because as the
window slides from left to right, the column sum lost on the left must be
replaced by an equal column sum on the right. Therefore, the problem reduces
to one of matrices with \(k\) columns.
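The periodicity claim can be checked by brute force on small cases. The sketch below assumes my reading of the problem: a \(2 \times n\) matrix of 0's and 1's in which every window of \(k\) adjacent columns has entries summing to \(k\). The function names are invented for this illustration, and the approach only works for tiny inputs since it enumerates all \(4^n\) matrices.

```python
from itertools import product

# A column is encoded as 0..3 (two bits); this maps the code to its column sum.
COLUMN_SUM = (0, 1, 1, 2)


def valid_matrices(k, n):
    # Yield the list of column sums of every 2 x n matrix whose
    # k-wide windows all sum to k.
    for cols in product(range(4), repeat=n):
        sums = [COLUMN_SUM[c] for c in cols]
        if all(sum(sums[i:i + k]) == k for i in range(n - k + 1)):
            yield sums


def brute_force(k, n):
    # Count the working matrices directly.
    return sum(1 for _ in valid_matrices(k, n))


def sums_are_periodic(k, n):
    # True if, in every working matrix, column i and column i + k
    # have the same sum.
    return all(s[i] == s[i + k]
               for s in valid_matrices(k, n)
               for i in range(n - k))


if __name__ == "__main__":
    print(brute_force(2, 4))        # 18, matching the test table
    print(brute_force(3, 9))        # 560, matching the test table
    print(sums_are_periodic(3, 9))  # True
```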
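Each column can have a sum of either 0, 1, or 2. The first step is to find all
ways in which to partition the number \(k\) into \(k\) parts using only 0’s,
1’s, and 2’s. Because the \(k\) parts must sum to \(k\), each such partition
has exactly as many 2’s as 0’s.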
From here, I found the number of permutations of each of these partitions. Lastly, we must note that for the columns that sum to 1, there are 2 ways in which to place the 1: \(\begin{bmatrix}1\\0\end{bmatrix}\) or \(\begin{bmatrix}0\\1\end{bmatrix}\). Therefore, we must multiply the number of permutations by 2 raised to the power of the number of columns summing to 1 in the entire matrix, not just in the \(k\)-width window.
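Putting these steps together gives a closed form, written here as a sketch. The notation \(A(k, n)\) for the count and \(z\) for the number of columns summing to 0 are my own labels. With \(z\) columns summing to 0, \(z\) summing to 2, and \(k - 2z\) summing to 1, and the window pattern repeated \(n/k\) times:

\[
A(k, n) \equiv \sum_{z=0}^{\lfloor k/2 \rfloor} \frac{k!}{(z!)^2\,(k-2z)!} \cdot 2^{(k-2z)\,n/k} \pmod{1000000007}
\]

For \(k = 3\), \(n = 9\) the two terms are \(\frac{3!}{(0!)^2\,3!} \cdot 2^{9} = 512\) for \(z = 0\) and \(\frac{3!}{(1!)^2\,1!} \cdot 2^{3} = 48\) for \(z = 1\), which add up to 560, the value in the test table.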
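modulo = 1000000007


def inv(n):
    # Modular inverse via Euler's theorem; valid because modulo is prime.
    return pow(n, modulo - 2, modulo)


def fact(n):
    # n! reduced modulo `modulo`.
    f = 1
    for i in range(1, n + 1):
        f *= i
        f %= modulo
    return f


def solve(k, n):
    repeats = n // k  # the k-wide pattern of column sums repeats n // k times
    count = 0
    zeros = k // 2  # columns summing to 0 (and, equally, columns summing to 2)
    ones = k % 2  # columns summing to 1
    fk = fact(k)
    fact_ones = 1
    fact_zeros = fact(zeros)
    while ones <= k:
        fz = inv(fact_zeros)
        fo = inv(fact_ones)
        # k! / (zeros!^2 * ones!) arrangements, times 2 choices per 1-column.
        count += fk * fz**2 * fo * pow(2, ones * repeats, modulo)
        # Move to the next split: one fewer 0/2 pair, two more 1-columns.
        fact_zeros *= inv(zeros)
        fact_zeros %= modulo
        zeros -= 1
        ones += 2
        fact_ones *= ones * (ones - 1)
        fact_ones %= modulo
    return count % modulo


if __name__ == '__main__':
    import sys
    # eval lets the arguments be written as expressions such as 10**8.
    n = eval(sys.argv[1])
    m = eval(sys.argv[2])
    print(solve(n, m))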
import pytest
from problem import solve

_test_solve = (
    (1, 1, 2),
    (2, 4, 18),
    (3, 9, 560),
    (4, 16, 68614),
    (10**1, 10**2, 641012066),
    (10**2, 10**4, 646548364),
)


@pytest.mark.parametrize('n,m,expect', _test_solve)
def test_solve(n, m, expect):
    assert expect == solve(n, m)
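Coming back to the optimization mentioned at the start: one possible way to avoid calling a modular inverse inside the loop is to update each term from the previous one. Moving from \(z\) zero-columns and \(o\) one-columns to \(z - 1\) and \(o + 2\) multiplies the term by \(z^2 \cdot 4^{n/k} / ((o+1)(o+2))\). The sketch below is not the code I timed. `inverse_table` and `solve_incremental` are hypothetical names, and the table of inverses uses the standard linear-time recurrence for a prime modulus. It assumes \(k\) is smaller than the modulus, and the table holds \(k + 1\) integers, which is a lot of memory for very large \(k\). I have not measured whether it is actually faster.

```python
MOD = 1_000_000_007


def inverse_table(limit):
    # inv[i] is the modular inverse of i for 1 <= i <= limit (limit < MOD).
    # Uses inv[i] = -(MOD // i) * inv[MOD % i] (mod MOD), with MOD % i < i.
    inv = [0, 1] + [0] * (limit - 1)
    for i in range(2, limit + 1):
        inv[i] = (MOD - MOD // i) * inv[MOD % i] % MOD
    return inv


def solve_incremental(k, n):
    repeats = n // k
    zeros, ones = k // 2, k % 2
    inv = inverse_table(max(k, 1))

    # First term: k! / (zeros!^2 * ones!) * 2^(ones * repeats).
    # Here ones is 0 or 1, so ones! = 1.
    term = 1
    for i in range(1, k + 1):
        term = term * i % MOD
    for i in range(1, zeros + 1):
        term = term * inv[i] % MOD * inv[i] % MOD
    term = term * pow(2, ones * repeats, MOD) % MOD

    step = pow(4, repeats, MOD)  # 2^(2 * repeats): two more 1-columns per step
    total = 0
    while True:
        total = (total + term) % MOD
        if zeros == 0:
            break
        # Ratio between consecutive terms: zeros^2 * 4^repeats / ((ones+1)(ones+2)).
        term = term * zeros % MOD * zeros % MOD
        term = term * inv[ones + 1] % MOD * inv[ones + 2] % MOD
        term = term * step % MOD
        zeros -= 1
        ones += 2
    return total


if __name__ == "__main__":
    print(solve_incremental(3, 9))   # 560
    print(solve_incremental(4, 16))  # 68614
```

The small cases agree with the test table above, which is the only validation I would claim for this version.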