10.5. Functional Memoization

10.5.1. SetUp

>>> from pprint import pprint

10.5.2. Problem

  • Calling the same function repeatedly with the same parameter

  • The result is computed again every time the function is called

>>> def factorial(n):
...     return 1 if n==0 else n*factorial(n-1)
>>> factorial(5)
120
>>> factorial(5)
120
>>> factorial(5)
120
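To make the repeated work visible, the sketch below counts how many times the function body actually runs. The `CALLS` counter is a hypothetical helper added only for illustration; it is not part of the original example.

```python
CALLS = 0  # hypothetical counter, used only to show repeated work

def factorial(n):
    global CALLS
    CALLS += 1  # count every invocation, including the recursive ones
    return 1 if n == 0 else n * factorial(n - 1)

for _ in range(3):
    factorial(5)

print(CALLS)  # 18: each factorial(5) recomputes 5, 4, 3, 2, 1 and 0
```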

10.5.3. Solution

  • Memoization: remembering the result of a function call for a given parameter

  • Store the results in an external dict

  • key: function parameter

  • value: function call result for given parameter

  • Dict lookups are fast (O(1) on average)

>>> RESULTS = {}
>>>
>>> def factorial(n):
...     if n not in RESULTS:
...          RESULTS[n] = 1 if n==0 else n*factorial(n-1)
...     return RESULTS[n]
>>> factorial(5)
120
>>>
>>> factorial(5)
120
>>>
>>> factorial(5)
120

Result:

>>> RESULTS  # doctest: +NORMALIZE_WHITESPACE
{0: 1,
 1: 1,
 2: 2,
 3: 6,
 4: 24,
 5: 120}

The values in RESULTS are computed only on the first call. Each subsequent call with the same argument reads the result from RESULTS instead of calculating it all over again. This speeds up execution and also leaves a trace that is useful for debugging.

10.5.4. Function Based Decorator

>>> def cache(func):
...     def wrapper(n):
...         if n not in wrapper.cache:
...             wrapper.cache[n] = func(n)
...         return wrapper.cache[n]
...     wrapper.cache = {}
...     return wrapper
>>>
>>> @cache
... def factorial(n):
...     return 1 if n==0 else n*factorial(n-1)

Usage:

>>> factorial(5)
120
>>>
>>> factorial(5)
120
>>>
>>> factorial(5)
120

Result:

>>> factorial.cache  # doctest: +NORMALIZE_WHITESPACE
{0: 1,
 1: 1,
 2: 2,
 3: 6,
 4: 24,
 5: 120}
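The decorator above accepts exactly one positional argument, and the decorated function reports its name as `wrapper`. A minimal sketch of a more general variant, assuming all arguments are positional and hashable, uses the tuple of arguments as the key and `functools.wraps` to keep the original name and docstring. The `power` function is a hypothetical example added for illustration.

```python
import functools

def cache(func):
    @functools.wraps(func)  # keep __name__ and __doc__ of the wrapped function
    def wrapper(*args):
        # The tuple of positional arguments is the key; all must be hashable
        if args not in wrapper.cache:
            wrapper.cache[args] = func(*args)
        return wrapper.cache[args]
    wrapper.cache = {}
    return wrapper

@cache
def factorial(n):
    """Return n! computed recursively."""
    return 1 if n == 0 else n * factorial(n - 1)

@cache
def power(base, exponent):  # hypothetical two-argument example
    return base ** exponent

factorial(5)
power(2, 10)
print(factorial.__name__)  # factorial
print(factorial.cache)     # {(0,): 1, (1,): 1, (2,): 2, (3,): 6, (4,): 24, (5,): 120}
print(power.cache)         # {(2, 10): 1024}
```

Keyword arguments are not supported by this sketch; calling `power(base=2, exponent=10)` raises a `TypeError`.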

10.5.5. Class Based Decorator

>>> class Cache(dict):
...     def __init__(self, func):
...         self.func = func
...
...     def __call__(self, n):
...         return self[n]
...
...     def __missing__(self, n):
...         self[n] = self.func(n)
...         return self[n]
>>>
>>>
>>> @Cache
... def factorial(n):
...     return 1 if n == 0 else n * factorial(n-1)

Usage:

>>> factorial(5)
120
>>>
>>> factorial(5)
120
>>>
>>> factorial(5)
120

Result:

>>> factorial  # doctest: +NORMALIZE_WHITESPACE
{0: 1,
 1: 1,
 2: 2,
 3: 6,
 4: 24,
 5: 120}
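The class-based version relies on `dict.__missing__`, which `dict` calls only from `__getitem__` (that is, `self[n]`) when the key is absent. Other lookups such as `in` and `get()` do not trigger it, so the cache can be inspected without computing anything. The sketch below repeats the `Cache` class from above to stay self-contained.

```python
class Cache(dict):
    def __init__(self, func):
        self.func = func

    def __call__(self, n):
        return self[n]  # dict.__getitem__ falls back to __missing__ on a miss

    def __missing__(self, n):
        self[n] = self.func(n)
        return self[n]

@Cache
def factorial(n):
    return 1 if n == 0 else n * factorial(n - 1)

factorial(3)
print(7 in factorial)    # False: a membership test does not call __missing__
print(factorial.get(7))  # None: get() does not call __missing__ either
print(factorial[7])      # 5040: indexing does, so the value is computed now
print(len(factorial))    # 8: keys 0 through 7 are now stored
```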

10.5.6. Functools Cache

  • Cache with unlimited size (added in Python 3.9)

>>> from functools import cache
>>>
>>> @cache
... def factorial(n):
...     return 1 if n==0 else n*factorial(n-1)

Usage:

>>> factorial(5)
120
>>>
>>> factorial(5)
120
>>>
>>> factorial(5)
120

Result:

>>> factorial.cache_info()
CacheInfo(hits=2, misses=6, maxsize=None, currsize=6)
>>>
>>> factorial.cache_parameters()
{'maxsize': None, 'typed': False}
>>>
>>> factorial.cache_clear()

10.5.7. Functools LRU Cache

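  • Least Recently Used: when the cache is full, the entry that was used longest ago is discarded

  • Cache with limited size (default maxsize=128)

  • from functools import lru_cache

  • @lru_cache(maxsize=None) removes the limit, which makes it equivalent to @cache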

>>> from functools import lru_cache
>>>
>>> @lru_cache
... def factorial(n):
...     return 1 if n==0 else n*factorial(n-1)

Usage:

>>> factorial(5)
120
>>>
>>> factorial(5)
120
>>>
>>> factorial(5)
120

Result:

>>> factorial.cache_info()
CacheInfo(hits=2, misses=6, maxsize=128, currsize=6)
>>>
>>> factorial.cache_parameters()
{'maxsize': 128, 'typed': False}
>>>
>>> factorial.cache_clear()
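The factorial example never fills the default 128 slots, so eviction is not visible there. The sketch below uses a hypothetical `square` function with `maxsize=2` to trace which entry the least-recently-used policy discards.

```python
from functools import lru_cache

@lru_cache(maxsize=2)
def square(n):  # hypothetical example function
    return n * n

square(1)  # miss: cache holds 1
square(2)  # miss: cache holds 1, 2
square(1)  # hit: 1 becomes the most recently used entry
square(3)  # miss: cache is full, so 2 (least recently used) is evicted
square(2)  # miss: 2 is computed again, and now 1 is evicted
print(square.cache_info())  # CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)
```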

10.5.8. Performance

  • Date: 2025-01-14

  • Python: 3.13.1

  • IPython: 8.31.0

  • System: macOS 15.2

  • Computer: MacBook M3 Max

  • CPU: 16 cores (12 performance and 4 efficiency) / 3nm

  • RAM: 128 GB LPDDR5

No cache:

>>> def factorial(n):
...     return 1 if n==0 else n*factorial(n-1)
>>> # doctest: +SKIP
... %%timeit -n 1000 -r 1000
... factorial(50)
...
1.94 μs ± 74.5 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
1.94 μs ± 74.5 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
1.94 μs ± 75.9 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
1.94 μs ± 85.7 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
1.94 μs ± 90.2 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)

Dict cache:

>>> CACHE = {}
>>>
>>> def factorial(n):
...     if n not in CACHE:
...         CACHE[n] = 1 if n==0 else n*factorial(n-1)
...     return CACHE[n]
>>> # doctest: +SKIP
... %%timeit -n 1000 -r 1000
... factorial(50)
...
35.7 ns ± 7.61 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
36.0 ns ± 9.51 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
36.2 ns ± 8.31 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
36.3 ns ± 8 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
37.2 ns ± 11.8 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)

Function based decorator:

>>> def cache(func):
...     def wrapper(n):
...         if n not in wrapper.cache:
...             wrapper.cache[n] = func(n)
...         return wrapper.cache[n]
...     wrapper.cache = {}
...     return wrapper
>>>
>>> @cache
... def factorial(n):
...     return 1 if n==0 else n*factorial(n-1)
>>> # doctest: +SKIP
... %%timeit -n 1000 -r 1000
... factorial(50)
...
71.2 ns ± 12.4 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
71.4 ns ± 14.9 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
71.8 ns ± 15.6 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
71.8 ns ± 16.8 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
72.2 ns ± 15.9 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)

Class based decorator:

>>> class Cache(dict):
...     def __init__(self, func):
...         self.func = func
...
...     def __call__(self, n):
...         return self[n]
...
...     def __missing__(self, n):
...         self[n] = self.func(n)
...         return self[n]
>>>
>>>
>>> @Cache
... def factorial(n):
...     return 1 if n == 0 else n * factorial(n-1)
>>> # doctest: +SKIP
... %%timeit -n 1000 -r 1000
... factorial(50)
...
56.1 ns ± 8.7 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
56.2 ns ± 8.7 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
56.5 ns ± 14 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
57.6 ns ± 12 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
57.9 ns ± 12 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)

Functools Cache:

>>> from functools import cache
>>>
>>> @cache
... def factorial(n):
...     return 1 if n==0 else n*factorial(n-1)
>>> # doctest: +SKIP
... %%timeit -n 1000 -r 1000
... factorial(50)
...
22.4 ns ± 2.92 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
22.7 ns ± 3.25 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
22.8 ns ± 4.28 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
22.8 ns ± 5.24 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
22.9 ns ± 2.88 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)

Functools LRU Cache:

>>> from functools import lru_cache
>>>
>>> @lru_cache
... def factorial(n):
...     return 1 if n==0 else n*factorial(n-1)
>>> # doctest: +SKIP
... %%timeit -n 1000 -r 1000
... factorial(50)
...
26.9 ns ± 8.77 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
27.2 ns ± 10.7 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
27.8 ns ± 8.75 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
28.3 ns ± 8.3 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
28.6 ns ± 10.4 ns per loop (mean ± std. dev. of 1000 runs, 1,000 loops each)
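`%%timeit` is an IPython cell magic and does not work in a plain Python script. A minimal sketch of a comparable measurement with the standard library `timeit` module follows; the helper `per_call_ns` and the loop counts are assumptions chosen for illustration. Keep in mind that for the cached variants every call after the first one is a cache hit, so these timings measure the cost of a lookup rather than the cost of computing `factorial(50)`.

```python
import timeit
from functools import cache

def factorial_plain(n):
    return 1 if n == 0 else n * factorial_plain(n - 1)

@cache
def factorial_cached(n):
    return 1 if n == 0 else n * factorial_cached(n - 1)

def per_call_ns(func, n=50, number=10_000, repeat=5):
    """Best-of-`repeat` time of a single call to func(n), in nanoseconds."""
    timings = timeit.repeat(
        "func(n)", globals={"func": func, "n": n}, number=number, repeat=repeat
    )
    # min() is the least noisy estimate; divide by the loop count per run
    return min(timings) / number * 1e9

if __name__ == "__main__":
    print(f"no cache:        {per_call_ns(factorial_plain):8.1f} ns")
    print(f"functools.cache: {per_call_ns(factorial_cached):8.1f} ns")
```

The absolute numbers depend on the machine and Python version, so only the relative difference between the variants is meaningful.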

10.5.9. Use Case - 1

>>> CACHE = {}
>>>
>>> def add(a,b):
...     if (a,b) not in CACHE:
...         CACHE[a,b] = a + b
...     return CACHE[a,b]

Args:

>>> add(1,2)
3
>>>
>>> add(3,2)
5
>>>
>>> add(3,5)
8
>>>
>>> pprint(CACHE)
{(1, 2): 3, (3, 2): 5, (3, 5): 8}

Kwargs:

>>> add(a=1, b=2)
3
>>>
>>> add(b=2, a=1)
3
>>>
>>> add(b=2, a=10)
12
>>>
>>> pprint(CACHE)
{(1, 2): 3, (3, 2): 5, (3, 5): 8, (10, 2): 12}
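The manual CACHE maps `add(1, 2)`, `add(a=1, b=2)` and `add(b=2, a=1)` to the same entry because the key is built from the parameter values inside the function. The `functools` documentation notes that `functools.cache` and `lru_cache` build the key from how the arguments were passed, so such calls may get separate entries. A minimal sketch of a hypothetical `cache_normalized` decorator that normalizes calls with `inspect.signature(...).bind()` follows; it assumes every argument value is hashable and that the function has no `**kwargs` parameter.

```python
import functools
import inspect

def cache_normalized(func):
    """Hypothetical memoizing decorator keyed on bound parameter values."""
    signature = inspect.signature(func)
    storage = {}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Bind the call to parameter names and fill in defaults, so that
        # positional and keyword calls produce the same key
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        key = tuple(bound.arguments.items())
        if key not in storage:
            storage[key] = func(*bound.args, **bound.kwargs)
        return storage[key]

    wrapper.cache = storage
    return wrapper

@cache_normalized
def add(a, b):
    return a + b

add(1, 2)
add(a=1, b=2)
add(b=2, a=1)
print(add.cache)  # {(('a', 1), ('b', 2)): 3}
```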