Problem 571

completed May 6, 2021

Comments

I enjoy problems like this one: easy enough that it's clear on the first read what needs to be done, yet challenging enough to require a good bit of thought. When I choose which problem to work on next, I look for ones where I immediately have an idea of how to solve them, at least by brute force. This was definitely one of those problems.

My first brute-force attempt simply looked through all the integers starting at 1 and checked whether each one was $n$-super-pandigital. From this I realized that a number needs at least $n$ base-$n$ digits if it is going to be $n$-pandigital.
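The digit-count observation gives a concrete lower bound for the search. As a minimal sketch (the helper name `smallest_pandigital` is hypothetical and is not part of the solution code further down), the smallest base-$n$ pandigital number has exactly $n$ digits: a leading 1, then 0, then the remaining digits in ascending order.

```python
def smallest_pandigital(b):
    """Return the smallest integer whose base-b digits include every digit 0..b-1.

    Hypothetical helper for illustration only; assumes b >= 2.
    """
    # The leading digit cannot be 0, so use 1, then 0, then 2, 3, ..., b-1.
    digits = [1, 0] + list(range(2, b))
    n = 0
    for d in digits:
        n = n * b + d  # shift one base-b place to the left, then add the digit
    return n


print(smallest_pandigital(10))  # 1023456789
```

Anything below this value cannot be pandigital in base $n$, so a search starting from 1 spends a long time on numbers that can never qualify.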

I soon realized that I would have to be much more efficient in my search. Since we know what all the base-$n$ pandigital numbers look like, I wrote an iterator that yields them in increasing order. From there it was straightforward. One small optimization was needed: check for super-pandigital-ness starting with the largest base and working back toward 2, since the larger bases are the ones most likely to fail. Then I was able to get the answer.
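The iterator idea can be sketched on its own. This is a simplified illustration rather than the exact code in the solution: the generator name `pandigitals` is hypothetical, and it only yields pandigital numbers with exactly $b$ digits, on the assumption (which the solution also makes) that enough of them turn out to be super-pandigital. `itertools.permutations` emits tuples in lexicographic order when its input is sorted, so the values come out in increasing order.

```python
import itertools


def pandigitals(b):
    """Yield, in increasing order, every b-digit number that is pandigital in base b.

    Hypothetical generator for illustration only; assumes b >= 2.
    """
    for digits in itertools.permutations(range(b)):
        if digits[0] == 0:
            continue  # a leading zero would leave only b - 1 significant digits
        n = 0
        for d in digits:
            n = n * b + d  # build the value from most to least significant digit
        yield n


print(list(pandigitals(3)))  # [11, 15, 19, 21]
```

Because the candidates arrive in increasing order, the first ones that pass the remaining base checks are also the smallest ones, so the search can stop as soon as enough have been found.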

Runs in 1 minute 52 seconds with PyPy3.

Code

import itertools

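def is_pandigital(n, b):
    """Return True if n uses every digit 0..b-1 when written in base b."""
    digits = set()
    while n:
        n, d = divmod(n, b)
        digits.add(d)
        # Stop as soon as all b distinct digits have been seen.
        if len(digits) == b:
            return True
    return False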

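def is_super_pandigital(n, k):
    """Return True if n is pandigital in every base from 2 through k."""
    # Check the largest base first and work back toward 2.
    for b in range(k, 1, -1):
        if not is_pandigital(n, b):
            return False
    return True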

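def solve(b, c=10):
    """Sum the c smallest b-super-pandigital numbers that have exactly b base-b digits."""
    total = count = 0
    # Place values of a b-digit number, most significant first.
    powers = tuple(b**i for i in range(b)[::-1])
    # Permutations arrive in lexicographic order, so n increases from one candidate to the next.
    for digits in itertools.permutations(range(b)):
        if digits[0] == 0:
            continue
        n = sum(d*p for d, p in zip(digits, powers))
        # n is pandigital in base b by construction, so only bases 2..b-1 need checking.
        if is_super_pandigital(n, b - 1):
            total += n
            count += 1
            if count == c:
                break
    else:
        raise AssertionError('not enough super-pandigitals found')
    return total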

if __name__ == '__main__':
    import sys
    # Parse the arguments as integers instead of passing them to eval().
    n = int(sys.argv[1])
    c = int(sys.argv[2])
    print(solve(n, c))

Tests

import pytest
from problem import solve, is_pandigital, is_super_pandigital

_test_is_pandigital = (
        (978, 2, True),
        (978, 3, True),
        (978, 4, True),
        (978, 5, True),
        (978, 6, False),
        (1093265784, 2, True),
        (1093265784, 3, True),
        (1093265784, 4, True),
        (1093265784, 5, True),
        (1093265784, 6, True),
        (1093265784, 7, True),
        (1093265784, 8, True),
        (1093265784, 9, True),
        (1093265784, 10, True),
        (1093265784, 11, False),
)

@pytest.mark.parametrize('n,b,expect', _test_is_pandigital)
def test_is_pandigital(n, b, expect):
    assert expect == is_pandigital(n, b)

_test_is_super_pandigital = (
        (978, 2, True),
        (978, 3, True),
        (978, 4, True),
        (978, 5, True),
        (978, 6, False),
)

@pytest.mark.parametrize('n,k,expect', _test_is_super_pandigital)
def test_is_super_pandigital(n, k, expect):
    assert expect == is_super_pandigital(n, k)

_test_solve = (
        (4, 1443),
        (5, 12928),
        (6, 105265),
        (7, 2170722),
        (8, 28525441),
        (9, 693716760),
        (10, 20319792309),
)

@pytest.mark.parametrize('n,expect', _test_solve)
def test_solve(n, expect):
    assert expect == solve(n)