I’m sure there’s got to be a way to either reduce this big sum or calculate the answer directly, but I don’t know how to do it, so I used binary search instead. It requires the decimal module because floating point arithmetic isn’t precise enough for these numbers. Runs in about 9 seconds on PyPy3.
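For reference, the sum can be reduced: it splits into a plain geometric series and a "k times r^(k-1)" series, both of which have standard closed forms when r is not 1. The sketch below is not part of the original solution. The names `s_direct` and `s_closed` are made up for illustration, and `Fraction` is used only so the two versions can be compared exactly.

```python
from fractions import Fraction


def s_direct(n, r):
    """Sum (900 - 3k) * r**(k-1) for k = 1..n, term by term."""
    return sum((900 - 3 * k) * r ** (k - 1) for k in range(1, n + 1))


def s_closed(n, r):
    """Same sum via closed forms; only valid for r != 1."""
    # G = sum of r**(k-1) = (r**n - 1) / (r - 1)
    geometric = (r ** n - 1) / (r - 1)
    # H = sum of k * r**(k-1) = (n*r**(n+1) - (n+1)*r**n + 1) / (r - 1)**2
    weighted = (n * r ** (n + 1) - (n + 1) * r ** n + 1) / (r - 1) ** 2
    return 900 * geometric - 3 * weighted


r = Fraction(3, 2)
print(s_direct(10, r) == s_closed(10, r))  # True
```

This only makes each evaluation of the sum cheap. The equation s(n, r) = a still has to be solved for r numerically, and the (r - 1)**2 in the denominator amplifies rounding error when r is close to 1, so extra precision is still worth having.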
After reading through the forum posts once I finished the problem, I realized
that I can vastly improve the performance of this code. The slowest part is
taking the large powers r**(k-1). This can be avoided by storing the power of
r in a variable and multiplying it by r on each pass through the loop. That
takes the time down to about 700 ms.
r_k = 1
for k in range(1, n + 1):
    total += (900 - 3*k) * r_k
    r_k *= r
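Dropped into a complete function, the running-power idea looks roughly like this. It is a sketch, with `s_fast` as a made-up name, meant to replace the loop in `s` rather than reproduce the exact code that gave the 700 ms timing.

```python
import decimal


def s_fast(n, r):
    """Sum (900 - 3k) * r**(k-1) for k = 1..n without calling ** in the loop."""
    total = 0
    r_k = 1  # holds r**(k-1); starts at r**0 == 1 for the first term
    for k in range(1, n + 1):
        total += (900 - 3 * k) * r_k
        r_k *= r  # advance to r**k for the next iteration
    return total


print(s_fast(3, 2))                   # 897 + 894*2 + 891*4 = 6249
print(s_fast(3, decimal.Decimal(0)))  # 897
```

One side effect is that r = 0 needs no special case any more, since the loop never evaluates 0**0.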
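import decimal

decimal.getcontext().prec = 15


def s(n, r):
    """Sum of (900 - 3k) * r**(k-1) for k = 1..n."""
    if r == 0:
        # Only the k = 1 term survives when r is 0 (for n >= 1),
        # and this also avoids evaluating 0**0.
        return 900 - 3
    total = 0
    for k in range(1, n + 1):
        total += (900 - 3*k) * r**(k-1)
    return total


def solve(n, a):
    """Binary search for the r with s(n, r) == a, rounded to 12 places."""
    top, bottom = 10, -10
    last = 0
    while True:
        r = decimal.Decimal((top + bottom) / 2)
        num = s(n, r)
        # Stop once the sum no longer changes at the working precision.
        if num == last:
            return round(r, 12)
        # The sum is too large, so r has to move up.
        if num > a:
            bottom = r
        else:
            top = r
        last = num


if __name__ == '__main__':
    import sys
    # eval so that the arguments can be written as expressions
    n = eval(sys.argv[1])
    m = eval(sys.argv[2])
    print(solve(n, m))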
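The search above stops when two consecutive sums come out equal. A possible alternative is to stop on the width of the bracket instead, which is sketched below. The function name `bisect_decreasing`, the `tol` default, and the toy function in the example are all assumptions for illustration. Like `solve`, it assumes the function decreases as its argument grows.

```python
from decimal import Decimal, getcontext

getcontext().prec = 15


def bisect_decreasing(f, target, lo, hi, tol=Decimal("1e-13")):
    """Find x in [lo, hi] with f(x) close to target, assuming f is decreasing."""
    lo, hi = Decimal(lo), Decimal(hi)
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if mid == lo or mid == hi:
            break  # the bracket cannot shrink further at this precision
        if f(mid) > target:
            lo = mid  # f is still too large, so the root lies to the right
        else:
            hi = mid
    return (lo + hi) / 2


# Toy check: -x**3 == -8 has the root x = 2.
root = bisect_decreasing(lambda x: -x ** 3, Decimal(-8), 0, 10)
print(round(root, 9))  # 2.000000000
```

With the code above this would be called as `bisect_decreasing(lambda r: s(n, r), a, -10, 10)`, followed by `round(..., 12)` on the result.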