Problem 743

completed March 5, 2021

Comments

I’ll start off by saying that this code runs in 4 minutes and 10 seconds with PyPy3 (v3.7.9), and I’m disappointed I didn’t get it down to under a minute. The slowest parts of the code are the three lines where I compute the modular inverse of a number. I tried two different algorithms: first the extended Euclidean algorithm, then ultimately settling on what I have below, which uses Euler’s theorem. I’m sure the optimization lies in how each successive value added to count is calculated. Since all the values are built from factorials or powers, they do not all need to be recalculated from scratch each time.

The important thing to notice is that in all examples given, n is a multiple of k. This matters because for each working matrix, the sums of the columns repeat with a period of k. This makes sense: as the window slides one step from left to right, the column sum lost on the left must be replaced by an equal column sum on the right, or the window total would no longer be k. Therefore, the problem reduces to one of matrices with k columns, repeated n / k times.
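For small inputs, the window condition can be checked directly by brute force. The following is only a sketch for illustration: `brute_force` is a hypothetical helper that is not part of the solution, and since it enumerates all 4^n matrices it is only usable for tiny n.

```python
from itertools import product

def brute_force(k, n):
    """Count 2 x n binary matrices in which every 2 x k window sums to k."""
    count = 0
    # Each column is one of the four possible (top, bottom) pairs.
    for cols in product(((0, 0), (0, 1), (1, 0), (1, 1)), repeat=n):
        sums = [top + bottom for top, bottom in cols]
        # Check every contiguous window of k columns.
        if all(sum(sums[i:i + k]) == k for i in range(n - k + 1)):
            count += 1
    return count

print(brute_force(2, 4))  # 18
print(brute_force(3, 9))  # 560
```

Both values agree with the small cases listed in the Tests section.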

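Each column can have a sum of either 0, 1, or 2. The first step is to find all ways in which to partition the number k into k parts using only 0’s, 1’s, and 2’s. Because k parts must add up to k, every column summing to 2 has to be balanced by a column summing to 0, so the number of 0’s always equals the number of 2’s.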
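As a small illustration, the possible counts of each column sum can be listed explicitly. This is a sketch with a hypothetical helper name (`column_sum_counts`). It relies on the fact that k column sums adding up to k force the number of 0-columns to equal the number of 2-columns.

```python
def column_sum_counts(k):
    """Yield (zeros, ones, twos): how many columns of a k-wide window sum to 0, 1, 2."""
    for twos in range(k // 2 + 1):
        zeros = twos            # each 2-column must be offset by a 0-column
        ones = k - 2 * twos     # the remaining columns each sum to 1
        yield zeros, ones, twos

print(list(column_sum_counts(4)))  # [(0, 4, 0), (1, 2, 1), (2, 0, 2)]
```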

From here, I found the number of permutations of each of these partitions. Lastly, we must note that for the columns that sum to 1, there are 2 ways in which to situate the 1: \(\begin{bmatrix}1\\0\end{bmatrix}\) or \(\begin{bmatrix}0\\1\end{bmatrix}\). Therefore, we must multiply the number of permutations by 2 raised to the power of the number of columns summing to 1 in the entire matrix, not just in the k-width window.
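Put together, these steps amount to a single sum. The formula below is reconstructed from the loop in solve(k, n); the notation is my own: A(k, n) is the count before reduction modulo 1000000007, and z is the number of columns summing to 0 (equal to the number summing to 2), so k - 2z columns sum to 1.

\[
A(k, n) = \sum_{z=0}^{\lfloor k/2 \rfloor} \frac{k!}{(z!)^{2}\,(k-2z)!} \cdot 2^{(k-2z)\,n/k}
\]

For example, with k = 3 and n = 9, the term for z = 1 is \(\frac{3!}{(1!)^{2}\,1!} \cdot 2^{3} = 48\) and the term for z = 0 is \(\frac{3!}{(0!)^{2}\,3!} \cdot 2^{9} = 512\), giving 560.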

Code

modulo = 1000000007

def inv(n):
    # modulo is prime, so by Euler's theorem n**(modulo - 2) is the inverse of n
    return pow(n, modulo - 2, modulo)

def fact(n):
    f = 1
    for i in range(1, n + 1):
        f *= i
        f %= modulo
    return f

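def solve(k, n):
    # The column sums repeat with period k, so the full matrix is
    # n // k copies of the first k-column window.
    repeats = n // k
    count = 0
    # Start with as few 1-columns as possible. The number of 2-columns
    # always equals the number of 0-columns, so only zeros is tracked.
    zeros = k // 2
    ones = k % 2
    fk = fact(k)
    fact_ones = 1
    fact_zeros = fact(zeros)
    while ones <= k:
        fz = inv(fact_zeros)
        fo = inv(fact_ones)
        # k! / (zeros! * twos! * ones!) arrangements of the column sums,
        # times 2 choices for every 1-column in the whole matrix.
        count += fk * fz * fz % modulo * fo % modulo * pow(2, ones * repeats, modulo)
        count %= modulo

        # zeros! / zeros = (zeros - 1)!. When zeros == 0, inv(0) is 0,
        # but the loop ends right after, so that value is never used.
        fact_zeros *= inv(zeros)
        fact_zeros %= modulo
        zeros -= 1

        ones += 2
        fact_ones *= ones * (ones - 1)
        fact_ones %= modulo

    return count % modulo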

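if __name__ == '__main__':
    import sys
    # eval lets the arguments be written as expressions such as 10**8;
    # only run this with arguments you typed yourself.
    k = eval(sys.argv[1])
    n = eval(sys.argv[2])
    print(solve(k, n))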
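One possible way to avoid calling pow for a modular inverse on every iteration is to update each term from the previous one. Moving from (zeros, ones) to (zeros - 1, ones + 2) multiplies the term by zeros^2 * 4^repeats / ((ones + 1)(ones + 2)), and the inverses of 1..k can be tabulated with the standard recurrence inv[i] = -(p // i) * inv[p mod i] (mod p). The sketch below is an assumption-laden variant, not the code I timed: `solve_incremental` is a hypothetical name, the table needs O(k) memory, and I have only checked it against the small test cases.

```python
MOD = 1_000_000_007

def solve_incremental(k, n):
    """Hypothetical variant of solve(k, n) that derives each term from the previous one."""
    repeats = n // k

    # Table of modular inverses of 1..k (assumes k < MOD, uses O(k) memory).
    inv = [0, 1] + [0] * (k - 1)
    for i in range(2, k + 1):
        inv[i] = (MOD - (MOD // i) * inv[MOD % i] % MOD) % MOD

    zeros = k // 2
    ones = k % 2

    # First term: k! / (zeros!^2 * ones!) * 2^(ones * repeats); ones! is 1 here.
    fk = 1
    fz = 1
    for i in range(1, k + 1):
        fk = fk * i % MOD
        if i == zeros:
            fz = fk  # remember zeros! on the way up to k!
    term = fk * pow(fz * fz % MOD, MOD - 2, MOD) % MOD
    term = term * pow(2, ones * repeats, MOD) % MOD

    step = pow(4, repeats, MOD)  # two extra 1-columns per window
    total = term
    while zeros > 0:
        term = term * zeros % MOD * zeros % MOD
        term = term * inv[ones + 1] % MOD * inv[ones + 2] % MOD
        term = term * step % MOD
        zeros -= 1
        ones += 2
        total = (total + term) % MOD
    return total

print(solve_incremental(3, 9))   # 560
print(solve_incremental(4, 16))  # 68614
```

Whether this is actually faster at k = 10**8 is untested; a table of that size may well be impractical, in which case the inverses would have to be produced in some other way.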

Tests

import pytest
from problem import solve

_test_solve = (
        (1, 1, 2),
        (2, 4, 18),
        (3, 9, 560),
        (4, 16, 68614),

        (10**1, 10**2, 641012066),
        (10**2, 10**4, 646548364),
)

@pytest.mark.parametrize('k,n,expect', _test_solve)
def test_solve(k, n, expect):
    assert expect == solve(k, n)