Tail Recursion Elimination in Python
Tail recursion is a special case of recursion where the recursive call is the last thing executed by the function. Some programming languages like Scheme and Scala optimize tail recursion by reusing the same stack frame for each recursive call. This optimization is called tail call elimination.
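Here is a minimal sketch of the distinction, using two factorial functions with hypothetical names, fact_plain and fact_tail. Only the second is tail recursive: its recursive call is the last thing it does, and it returns that call's result unchanged.

```python
def fact_plain(n):
    # Not tail recursive: n * (...) must run after the recursive call
    # returns, so each frame has to stay alive until the one below finishes.
    if n == 0:
        return 1
    return n * fact_plain(n - 1)


def fact_tail(n, acc=1):
    # Tail recursive: the recursive call is the final action, and its result
    # is returned as-is, so the current frame has nothing left to do.
    if n == 0:
        return acc
    return fact_tail(n - 1, n * acc)


print(fact_plain(5), fact_tail(5))  # 120 120
```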
Long story short, Python does not optimize tail recursion, and it probably never will: Guido van Rossum, the creator of Python, has argued against it more than once.
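You can see this directly on a standard CPython build, where sys.getrecursionlimit() is typically 1000. This sketch uses a tail-recursive function with a hypothetical name, count_down, that still fails once the call depth exceeds the limit:

```python
import sys


def count_down(n):
    # Tail recursive, but CPython still pushes a new frame for every call.
    if n == 0:
        return "done"
    return count_down(n - 1)


limit = sys.getrecursionlimit()
print(count_down(100))  # done (well within the limit)

try:
    count_down(limit + 100)  # deeper than the interpreter allows
except RecursionError as exc:
    print("RecursionError:", exc)
```

Being in tail position does not help, because every call still gets its own frame.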
The other day my colleagues and I were discussing functional programming and recursion. One of them mentioned that Python does not optimize tail recursion, and I thought it would be fun to try to implement a tail recursion elimination decorator in Python. So here it is. (I later found a blog post by Kay Schlühr with some nice ideas and tweaked my code to use some of them. Kay’s version was perfect for the Python features available at the time; I just added a few new things.)
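```python
from functools import update_wrapper


class tail_recursion:
    def __init__(self, func):
        self.func = func
        # Copy __name__, __doc__, etc. from func onto this instance.
        update_wrapper(self, func)
        # A unique object that recursive calls return to mean "call again".
        self.recurse_sentinel = object()
        # None means no call is in progress; otherwise the pending kwargs.
        self.kwargs = None

    def __call__(self, *args, **kwargs):
        if self.kwargs is None:
            # Original (outermost) call: drive the function in a loop.
            self.args, self.kwargs = args, kwargs
            try:
                while (ret := self.func(*self.args, **self.kwargs)) is self.recurse_sentinel:
                    pass
                return ret
            finally:
                # Reset for the next top-level call, even if func raised.
                self.kwargs = None
        else:
            # Recursive call: record the new arguments and return at once.
            self.args, self.kwargs = args, kwargs
            return self.recurse_sentinel
```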
Here’s how it can be used:
```python
@tail_recursion
def factorial(n, acc=1):
    '''Calculate factorial of n'''
    if n == 0:
        return acc
    else:
        return factorial(n - 1, n * acc)


print(factorial(1000))
```
In case you are curious, the result is 1000!, a number with 2,568 digits.
Explanation
In Python we can create decorators with either functions or classes. Here it is useful to store some state in the decorator itself, so I went with a class. We start by calling update_wrapper, which makes the decorator instance look like func: it copies attributes such as __name__ and __doc__ from func and sets __wrapped__ to point at it. If you want to know more, read the docs on update_wrapper.
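Here is a minimal sketch of what update_wrapper does for a class-based decorator. It uses a hypothetical pass-through wrapper, plain_wrapper, that only forwards calls:

```python
from functools import update_wrapper


class plain_wrapper:
    # Hypothetical minimal class-based decorator, used only to show
    # what update_wrapper copies onto the instance.
    def __init__(self, func):
        self.func = func
        update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


@plain_wrapper
def greet(name):
    '''Return a greeting.'''
    return f"Hello, {name}!"


print(greet.__name__)                   # greet
print(greet.__doc__)                    # Return a greeting.
print(greet.__wrapped__ is greet.func)  # True
print(greet("Ada"))                     # Hello, Ada!
```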
The key to eliminating the recursion stack is the recurse_sentinel. We use this object’s identity as a sentinel to tell our loop that the function has been called recursively. We also store the arguments passed to the current function call. Since this is a tail recursion, we expect the arguments to keep changing until we reach a base case.
The original function call
When the function is first called, self.kwargs is None, so the if block is executed. (In every subsequent recursive call, self.kwargs will be a dict, even if no keyword arguments are passed, because **kwargs always collects keyword arguments into a dict.) We initialize the attributes that store the arguments with the original arguments, and then we start calling the function in a loop.
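This small sketch shows why self.kwargs can only be None before the first call. It uses a hypothetical helper, show: **kwargs always produces a dict, even an empty one.

```python
def show(*args, **kwargs):
    # **kwargs always collects keyword arguments into a new dict,
    # even when the caller passes none.
    return args, kwargs


print(show(1, 2))         # ((1, 2), {})
print(show(1, acc=5))     # ((1,), {'acc': 5})
print(show()[1] is None)  # False: an empty dict, never None
```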
The loop
I used the walrus operator (available since Python 3.8) to store the returned value in ret. It is only used for the final answer; in all previous calls, ret will be recurse_sentinel.
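If the walrus operator is new to you, this sketch compares the decorator's loop with an equivalent loop written without it. call_once and _SENTINEL are hypothetical stand-ins for self.func and self.recurse_sentinel:

```python
_SENTINEL = object()


def call_once(queue):
    # Hypothetical stand-in for self.func: returns the sentinel until the
    # queue is exhausted, then returns a final answer.
    if queue:
        queue.pop()
        return _SENTINEL
    return "final answer"


# With the walrus operator: call, bind, and test in the loop condition.
queue = [1, 2, 3]
while (ret := call_once(queue)) is _SENTINEL:
    pass
print(ret)  # final answer

# The same loop without the walrus operator.
queue = [1, 2, 3]
ret = call_once(queue)
while ret is _SENTINEL:
    ret = call_once(queue)
print(ret)  # final answer
```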
If you are not familiar with *args and **kwargs, take a look at Real Python’s article on them.
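The decorator packs a call's arguments into a tuple and a dict, then unpacks them later. Here is a minimal sketch of that round trip, using a hypothetical power function:

```python
def power(base, exp=2):
    return base ** exp


# Store the arguments of a call, the way the decorator keeps
# self.args and self.kwargs between iterations.
saved_args, saved_kwargs = (3,), {"exp": 4}

# Unpack them again to replay the call later.
print(power(*saved_args, **saved_kwargs))  # 81
```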
So, what happens here? When we call self.func, its code is executed. But when it calls itself recursively, the decorated name (factorial, in our example) refers to our decorator instance, not the original function, so our __call__ runs first. Since this is not the original call, self.kwargs is not None, and we always end up in the else block. That block simply stores the arguments of the new recursive call and returns the sentinel.
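The crucial detail is that a recursive call looks up the function's name at call time, and after decoration that name refers to the wrapper. A sketch with a hypothetical recording wrapper, spy, makes the routing visible:

```python
class spy:
    # Hypothetical wrapper that records every call routed through it.
    def __init__(self, func):
        self.func = func
        self.calls = []

    def __call__(self, *args):
        self.calls.append(args)
        return self.func(*args)


@spy
def countdown(n):
    # "countdown" is looked up when this line runs; after decoration it
    # names the spy instance, so the recursion goes through spy.__call__.
    return n if n == 0 else countdown(n - 1)


countdown(3)
print(countdown.calls)  # [(3,), (2,), (1,), (0,)]
```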
Because the recursive call is the last thing the function does, self.func returns that sentinel unchanged. So we are back at the loop, but now self.args and self.kwargs are updated, and we can perform the actual function call again. We don’t need to do anything inside the loop; we just keep calling the function iteratively until we reach a base case.
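For factorial, the net effect is roughly the loop you might write by hand. Each "recursive call" becomes a rebinding of the arguments followed by another pass. Here is a sketch of that equivalent (factorial_loop is a hypothetical name, not part of the decorator):

```python
def factorial_loop(n, acc=1):
    # Each pass corresponds to one "recursive call": instead of calling
    # factorial(n - 1, n * acc), we rebind the arguments and loop again.
    while n != 0:
        n, acc = n - 1, n * acc
    return acc


print(factorial_loop(5))  # 120
```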
The finally
The last thing we are missing is resetting self.kwargs to None. If we don’t do this, the next top-level call will find self.kwargs already set, take the else branch, and return the sentinel instead of a result. That’s why the reset lives in a finally block: it runs even if an error occurs. Errors also bring up one of the reasons Guido van Rossum thinks adding tail recursion elimination is a bad idea: the stack trace gets messy, because the intermediate calls no longer have their own frames.
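To see what the finally block protects against, here is a sketch of a deliberately broken variant, tail_recursion_no_finally (a hypothetical name). It resets self.kwargs only after a successful loop. After one call raises, the next top-level call is mistaken for a recursive one:

```python
from functools import update_wrapper


class tail_recursion_no_finally:
    # Hypothetical, deliberately broken variant: the reset happens only
    # after a successful loop, not in a finally block.
    def __init__(self, func):
        self.func = func
        update_wrapper(self, func)
        self.recurse_sentinel = object()
        self.kwargs = None

    def __call__(self, *args, **kwargs):
        if self.kwargs is None:
            self.args, self.kwargs = args, kwargs
            while (ret := self.func(*self.args, **self.kwargs)) is self.recurse_sentinel:
                pass
            self.kwargs = None  # never reached if self.func raises
            return ret
        self.args, self.kwargs = args, kwargs
        return self.recurse_sentinel


@tail_recursion_no_finally
def countdown(n):
    if n < 0:
        raise ValueError("negative input")
    if n == 0:
        return "done"
    return countdown(n - 1)


try:
    countdown(-1)
except ValueError:
    pass

# self.kwargs was never reset, so this top-level call takes the else
# branch and returns the sentinel instead of "done".
result = countdown(3)
print(result == "done")                      # False
print(result is countdown.recurse_sentinel)  # True
```

With the try/finally version from the post, the reset runs even when the function raises, so the second call would return "done".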
And that’s it for my decorator. It is not really useful, but it was a fun afternoon project. Let me know if you have ideas to improve it!