
Tail Recursion in Python: Implementing a Tail Recursion Elimination Decorator
Tail recursion is a special form of recursion in which the recursive call is the last operation the function performs. Languages like Scheme and Scala optimize this by reusing the same stack frame for each recursive call, a technique known as tail call elimination. Unfortunately, Python does not optimize tail recursion, and likely never will, as explained by Guido van Rossum, Python’s creator (here and here).
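To make the limitation concrete, here is a minimal sketch of a plain tail-recursive function with no decorator involved. The function and the variable overflowed are only for illustration; the point is that CPython pushes a new frame for every call, even though the recursive call is in tail position.

```python
import sys

def factorial(n, acc=1):
    """Tail-recursive factorial: the recursive call is the last operation."""
    if n == 0:
        return acc
    return factorial(n - 1, n * acc)

# Small inputs work fine.
print(factorial(5))  # 120

# Going deeper than the interpreter's recursion limit raises RecursionError,
# because no stack frame is ever reused.
try:
    factorial(sys.getrecursionlimit() + 100)
    overflowed = False
except RecursionError:
    overflowed = True

print(overflowed)  # True
```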
Curious about the limitation, I decided to implement a Python decorator to simulate tail recursion elimination. Here’s the result. (I later found this blog post by Kay Schlühr with some nice ideas, so I tweaked my code to use some of them. Kay’s version is perfect for the Python features available at the time; I just added a few newer things.)
from functools import update_wrapper


class tail_recursion:
    def __init__(self, func):
        self.func = func
        update_wrapper(self, func)
        self.recurse_sentinel = object()
        self.kwargs = None

    def __call__(self, *args, **kwargs):
        if self.kwargs is None:
            # Original call: run the function in a loop until it returns a real value.
            self.args, self.kwargs = args, kwargs
            try:
                while (ret := self.func(*self.args, **self.kwargs)) is self.recurse_sentinel:
                    pass
                return ret
            finally:
                self.kwargs = None
        else:
            # Recursive call: store the new arguments and hand control back to the loop.
            self.args, self.kwargs = args, kwargs
            return self.recurse_sentinel
Usage Example:
@tail_recursion
def factorial(n, acc=1):
    '''Calculate factorial of n'''
    if n == 0:
        return acc
    else:
        return factorial(n - 1, n * acc)

print(factorial(1000))
In case you are curious, the result is:
How It Works
In Python we can create decorators with functions or classes. In this case, it would be useful to store some things in the decorator itself, so I decided to go with a class. We start by registering the decorator instance as a wrapper over func. If you want to know more about it, read the docs on update_wrapper.
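As a small illustration of what update_wrapper does for a class-based decorator, here is a sketch with a hypothetical do-nothing decorator called passthrough (the name and the greet function are made up for this example). It copies metadata such as the name and docstring onto the instance and records the original function as __wrapped__.

```python
from functools import update_wrapper

class passthrough:
    """Hypothetical do-nothing class decorator, used only for illustration."""

    def __init__(self, func):
        self.func = func
        # Copy __name__, __doc__ and friends from func onto this instance.
        update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

@passthrough
def greet(name):
    """Return a greeting."""
    return f"Hello, {name}!"

print(greet.__name__)                   # greet
print(greet.__doc__)                    # Return a greeting.
print(greet.__wrapped__ is greet.func)  # True
```

Without the update_wrapper call, the instance would have no __name__ of its own and greet.__doc__ would show the decorator class's docstring instead.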
The key to eliminating the recursion stack is the recurse_sentinel. We use this object’s identity as a sentinel to tell our loop that the function has been called recursively. We also store the arguments passed to the current function call. Since this is a tail recursion, we expect the arguments to keep changing until we reach a base case.
The original function call
When the function is first called, the value of self.kwargs is None, so the if block is executed. (In every subsequent recursive call the value will be a dict, even if no keyword argument is passed.) We initialize the attributes that store the arguments with the original arguments, and then we start calling the function in a loop.
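The "None versus dict" check works because Python always builds a dict for **kwargs, even when the caller passes no keyword arguments. A quick sketch, using a throwaway function named show:

```python
def show(*args, **kwargs):
    """Return the collected positional and keyword arguments unchanged."""
    return args, kwargs

print(show(1, 2))          # ((1, 2), {})
print(show()[1] is None)   # False: an empty dict, never None
```

That is why None is a safe marker for "no call in progress": it can never be confused with the kwargs of a real call.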
The loop
I used the walrus operator to store the returned value in ret. The value only matters for the final answer: in every earlier iteration, ret is just recurse_sentinel.
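Here is the same loop shape in isolation, as a rough sketch. The iterator steps stands in for the repeated function calls and the string "again" stands in for the sentinel; both are made up for the example.

```python
# Stand-in for successive return values: two "keep going" markers, then a result.
steps = iter(["again", "again", "done"])

# Assign and test in one expression; the loop body has nothing left to do.
while (ret := next(steps)) == "again":
    pass

print(ret)  # done
```

Without the walrus operator, the same logic needs a "while True" loop with an explicit assignment and a break.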
If you are not familiar with *args and **kwargs, take a look at this post by Real Python.
So, what happens here? When we call self.func, its code will be executed. But when it calls itself recursively, it is our decorator that gets called first, because the function’s name is now bound to the decorator instance. Since this is not the original call, we will always end up in the else block. It will simply store the arguments of this new recursive call and return the sentinel.
So we are back at the loop, but now self.args and self.kwargs are updated. Now we can perform the actual function call again. We don’t need to do anything inside the loop, just keep calling the function iteratively until we get to a base case.
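One way to check that the stack really stays flat is to record the stack depth on every call. The sketch below repeats the decorator so it runs on its own; the countdown function and the depths set are made up for the experiment, and inspect.stack is used only for measuring.

```python
import inspect
from functools import update_wrapper

class tail_recursion:
    def __init__(self, func):
        self.func = func
        update_wrapper(self, func)
        self.recurse_sentinel = object()
        self.kwargs = None

    def __call__(self, *args, **kwargs):
        if self.kwargs is None:
            self.args, self.kwargs = args, kwargs
            try:
                while (ret := self.func(*self.args, **self.kwargs)) is self.recurse_sentinel:
                    pass
                return ret
            finally:
                self.kwargs = None
        else:
            self.args, self.kwargs = args, kwargs
            return self.recurse_sentinel

depths = set()  # distinct stack depths observed inside countdown

@tail_recursion
def countdown(n):
    # context=0 skips reading source lines, which keeps this cheap.
    depths.add(len(inspect.stack(0)))
    if n == 0:
        return "done"
    return countdown(n - 1)

result = countdown(5000)  # far deeper than the default recursion limit
print(result)             # done
print(len(depths))        # 1: every call ran at the same depth
```

Each call to the real function is made from the same loop in the original __call__, so the depth never grows.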
The finally
The last thing we are missing is resetting self.kwargs to None. If we don’t do this, the next call to the function will be mistaken for a recursive call and will just return the sentinel. That’s why we need to make sure it is reset even if an error occurs. This is also related to one of the reasons Guido van Rossum thinks adding tail recursion elimination is a bad idea: the stack trace gets messy.
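To see the finally block doing its job, here is a small sketch in which the decorated function raises on bad input. The decorator is repeated so the example is self-contained; countdown and its ValueError are made up for the demonstration.

```python
from functools import update_wrapper

class tail_recursion:
    def __init__(self, func):
        self.func = func
        update_wrapper(self, func)
        self.recurse_sentinel = object()
        self.kwargs = None

    def __call__(self, *args, **kwargs):
        if self.kwargs is None:
            self.args, self.kwargs = args, kwargs
            try:
                while (ret := self.func(*self.args, **self.kwargs)) is self.recurse_sentinel:
                    pass
                return ret
            finally:
                self.kwargs = None
        else:
            self.args, self.kwargs = args, kwargs
            return self.recurse_sentinel

@tail_recursion
def countdown(n):
    if n < 0:
        raise ValueError("n must be non-negative")
    if n == 0:
        return "done"
    return countdown(n - 1)

# The first call fails, but finally still resets self.kwargs to None.
try:
    countdown(-1)
except ValueError as exc:
    print(exc)  # n must be non-negative

# So the next call is treated as a fresh original call and works normally.
print(countdown(3))  # done
```

If the reset were skipped after the exception, the second call would land in the else block and hand back the sentinel object instead of "done".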
And that’s it for my decorator. It is not really useful, but it was a fun afternoon project. Let me know if you have ideas to improve this decorator!
That was a bit of a flex…
Yeah, sorry. I tried to write a more readable version:
from functools import update_wrapper


class tail_recursion:
    def __init__(self, function):
        self.function = function
        update_wrapper(self, function)
        self.recurse_sentinel = object()
        self._reset_wrapper()

    def __call__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs
        if self.original_call:
            # Force following calls to go to the else block
            self.original_call = False
            try:
                return self._apply_recursion_iteratively()
            finally:
                self._reset_wrapper()
        else:
            return self.recurse_sentinel

    def _apply_recursion_iteratively(self):
        retval = self.recurse_sentinel
        while self._continue_recursion(retval):
            retval = self.function(*self.args, **self.kwargs)
        return retval

    def _continue_recursion(self, returned_value):
        return returned_value is self.recurse_sentinel

    def _reset_wrapper(self):
        self.original_call = True
        self.args = None
        self.kwargs = None
As you can see, it’s still not great. The problem is that we are breaking the way Python processes the stack. So a lot is going on behind the scenes. Can you refactor my code to make it easier to read? Let me know in the comments!
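As one possible starting point, here is a rough sketch of a function-based variant that keeps the state in a closure instead of on an instance. The name tail_recursive and the pending variable are hypothetical, chosen only for this sketch; the behavior is meant to mirror the class versions above, not to improve on them.

```python
from functools import wraps

def tail_recursive(func):
    """Hypothetical closure-based variant of the class decorators above."""
    sentinel = object()
    pending = None  # (args, kwargs) of the next call, or None when idle

    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal pending
        if pending is not None:
            # Recursive call: record the arguments and unwind to the loop.
            pending = (args, kwargs)
            return sentinel
        # Original call: keep calling func until it returns a real value.
        pending = (args, kwargs)
        try:
            while True:
                call_args, call_kwargs = pending
                result = func(*call_args, **call_kwargs)
                if result is not sentinel:
                    return result
        finally:
            pending = None  # always leave the wrapper ready for the next call

    return wrapper

@tail_recursive
def factorial(n, acc=1):
    """Calculate factorial of n"""
    if n == 0:
        return acc
    return factorial(n - 1, n * acc)

print(factorial(5))                   # 120
print(factorial(5000) > 0)            # True, with no RecursionError
```

Whether this reads better than the class is a matter of taste. It trades the instance attributes for a nonlocal variable, but the control flow is the same.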