Adding tail call optimization to Python
July 2, 2023
Tail call optimization is a great feature. It’s obviously easier to solve recursive problems recursively rather than iteratively. It’s a shame it’s not available in Python (or many other languages), so let’s add it.
TL;DR
The following function raises a RecursionError once its argument gets to around 1000, which is CPython's default recursion limit:
def my_fn(target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return my_fn(target_iters, iteration+1)
>>> my_fn(1000)
...
RecursionError: maximum recursion depth exceeded in comparison
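To see where the limit comes from, here is a small sketch that reads the interpreter's recursion limit instead of hard-coding 1000. The variable names `limit`, `overflowed` and `small_result` are only for illustration.

```python
import sys


def my_fn(target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return my_fn(target_iters, iteration + 1)


# The limit is 1000 by default in CPython, but it can be changed,
# so we ask the interpreter rather than assume a number.
limit = sys.getrecursionlimit()

try:
    # Needs more nested frames than the limit allows.
    my_fn(limit)
    overflowed = False
except RecursionError:
    overflowed = True

# Shallow recursion is unaffected.
small_result = my_fn(10)
```

Raising the limit with `sys.setrecursionlimit` only moves the problem, since every call still occupies a stack frame.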
With some extra code, we can add tail call optimization:
def my_fn2(my_fn2, target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return my_fn2(target_iters, iteration+1)
optimized_fn = tailcall_optimize(my_fn2)
>>> optimized_fn(10000)
10000
A very contrived example, I admit. But not too bad - the only thing that changed is the signature of the function, which now takes itself as its first argument. Everything else is the same.
How it works
Pretty straightforward actually! The key is - as you may already expect - the my_fn2 passed in as a function argument.
We’re making use of the fact that only the leaf returns any data. That means none of the calls leading to the leaf need to keep any state around once they have made the recursive call.
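A quick illustration of what "only the leaf returns data" means in practice. This is a plain-recursion sketch with made-up names (`factorial`, `factorial_acc`), not code from the optimizer itself: the first version still has work to do after the recursive call returns, the second threads the partial result through an accumulator argument so that the leaf alone produces the answer.

```python
def factorial(n):
    # NOT a tail call: the multiplication happens after the
    # recursive call returns, so every frame must stay alive.
    if n <= 1:
        return 1
    return n * factorial(n - 1)


def factorial_acc(n, acc=1):
    # Tail call: the recursive call's result is returned unchanged,
    # so the calling frame has nothing left to do.
    if n <= 1:
        return acc
    return factorial_acc(n - 1, acc * n)
```

Only functions shaped like the second one are candidates for the trick described here.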
Let’s look at some pseudocode
args = init_args
while true:
    return_value = my_fn(lambda new_args: update_args(new_args), *args)
    if return_value is not None:
        return return_value
Here it also becomes obvious why only the leaf can return any data. If other nodes returned data while recursing, we’d need a strategy to merge those values back together in order to emulate the behavior of a regular recursive function.
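For reference, the pseudocode above could be made runnable roughly like this. It is a minimal sketch under one big assumption: a return value of None always means "a recursive call was requested", so a function whose leaf legitimately returns None would loop forever. The names `tailcall_optimize_none` and `count_up` are invented for this example.

```python
def tailcall_optimize_none(fn):
    def recursion_wrapper(*args, **kwargs):
        next_call = (args, kwargs)

        def fn_wrapper(*new_args, **new_kwargs):
            # Record the arguments for the next iteration instead of recursing.
            # Implicitly returns None, which the loop reads as "keep going".
            nonlocal next_call
            next_call = (new_args, new_kwargs)

        while True:
            result = fn(fn_wrapper, *next_call[0], **next_call[1])
            # Assumption: only the leaf returns something other than None.
            if result is not None:
                return result

    return recursion_wrapper


def count_up(recurse, target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return recurse(target_iters, iteration + 1)


optimized = tailcall_optimize_none(count_up)
deep_result = optimized(10000)
```

The loop reuses a single stack frame per step, which is why a depth of 10000 is no longer a problem.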
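The Python code I used is the following:
def tailcall_optimize(fn):
    def recursion_wrapper(*args, **kwargs):
        is_called = True
        result = None
        # Arguments for the next "recursive" step
        val = (args, kwargs)

        def fn_wrapper(*args, **kwargs):
            # Stands in for the recursive call: it only records the new
            # arguments and flags that another iteration is needed
            nonlocal val, is_called
            val = (args, kwargs)
            is_called = True

        # Keep calling fn until a step finishes without "recursing"
        while is_called:
            is_called = False
            result = fn(fn_wrapper, *val[0], **val[1])
        return result
    return recursion_wrapper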
Instead of relying on the fact that a non-leaf node returns None, I’m tracking whether the fn_wrapper passed in is called or not - but either works!
This certainly requires a few more checks to bring it into production, but conceptually this is all we need. This could also be easily ported to other languages.
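As one example of such a check, here is a hedged sketch of a variant that tries to catch functions that are not actually tail recursive. The name `tailcall_optimize_checked`, the choice of RuntimeError and the example functions are all my own assumptions for illustration. It is not exhaustive: a non-tail call whose caller then returns None would still slip through.

```python
def tailcall_optimize_checked(fn):
    def recursion_wrapper(*args, **kwargs):
        is_called = True
        result = None
        val = (args, kwargs)

        def fn_wrapper(*new_args, **new_kwargs):
            nonlocal val, is_called
            if is_called:
                # A second recursive call in the same step (e.g. fib-style)
                raise RuntimeError("recursive call made more than once per step")
            val = (new_args, new_kwargs)
            is_called = True

        while is_called:
            is_called = False
            result = fn(fn_wrapper, *val[0], **val[1])
            if is_called and result is not None:
                # fn recursed but returned something else, so the
                # recursive call cannot have been in tail position
                raise RuntimeError("recursive call was not in tail position")
        return result

    return recursion_wrapper


def sum_to(recurse, n, acc=0):
    # Proper tail recursion: accepted
    if n == 0:
        return acc
    return recurse(n - 1, acc + n)


def not_tail(recurse, n):
    # Ignores the recursive result: rejected
    if n == 0:
        return 0
    recurse(n - 1)
    return n


def two_calls(recurse, n):
    # Tree-shaped recursion: rejected
    if n <= 1:
        return n
    recurse(n - 1)
    return recurse(n - 2)
```

Failing loudly here seems preferable to silently returning a wrong answer, which is what the unchecked wrapper would do for `not_tail`.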
This was all good fun, but let’s address the elephant in the room: how to make this work with a broader set of recursive functions.
Naturally I looked into that as well…