10.8. Functional Currying

  • functools.partial()

  • functools.partialmethod()

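One of the most commonly used functions in the functools module is partial(). It creates a new callable from an existing function with some of its arguments already set. This is useful when you call a function repeatedly with the same arguments, or when you pass it to another function, such as map(), that supplies only the remaining arguments. Strictly speaking, this is partial application rather than currying: currying turns a function of several arguments into a chain of one-argument functions, whereas partial() fixes selected arguments in a single step.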

10.8.1. SetUp

>>> from functools import partial

10.8.2. Problem

>>> def add(a, b):
...     return a + b
>>>
>>> data = (1, 2, 3, 4, 5)
>>> result = map(lambda x: add(x, 10), data)
>>>
>>> tuple(result)
(11, 12, 13, 14, 15)

10.8.3. Solution

>>> def add(a, b):
...     return a + b
>>>
>>> data = (1, 2, 3, 4, 5)
>>> add10 = partial(add, b=10)
>>> result = map(add10, data)
>>>
>>> tuple(result)
(11, 12, 13, 14, 15)
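
The solution above binds b by keyword. As a minimal sketch of the other binding rules, reusing the same add() function (the names add_to_10 and add10 are only illustrative): positional arguments are bound from the left, and an argument bound by keyword behaves like a default that the caller can still override.

```python
from functools import partial


def add(a, b):
    return a + b


# Positional arguments are bound from the left, so this fixes a=10.
add_to_10 = partial(add, 10)
print(add_to_10(5))      # 15

# Keyword arguments are bound by name, so this fixes b=10.
add10 = partial(add, b=10)
print(add10(5))          # 15

# A keyword bound by partial() acts like a default and can be overridden.
print(add10(5, b=20))    # 25
```

Binding by keyword is usually the safer choice when the argument to fix is not the first one, as is the case with b here.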

10.8.4. Partial

  • Create an alias for a function with some of its arguments preset

  • Useful when you need to pass a function together with its arguments to, for example, map() or filter()

>>> from functools import partial
>>>
>>>
>>> basetwo = partial(int, base=2)
>>> basetwo.__doc__ = 'Convert base 2 string to an int.'
>>> basetwo('10010')
18
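
A short sketch of what a partial object holds and how it can be passed to map() and filter(). The basetwo alias comes from the example above; the is_longer_than() helper and the sample data are hypothetical and exist only for this illustration.

```python
from functools import partial

basetwo = partial(int, base=2)

# A partial object remembers the wrapped callable and the preset arguments.
print(basetwo.func)        # <class 'int'>
print(basetwo.args)        # ()
print(basetwo.keywords)    # {'base': 2}

# Because it is an ordinary callable, it can be passed straight to map().
binary = ('10', '11', '101')
decimal = tuple(map(basetwo, binary))
print(decimal)             # (2, 3, 5)


# Hypothetical predicate used to show the same idea with filter().
def is_longer_than(text, length):
    return len(text) > length


longer_than_two = partial(is_longer_than, length=2)
long_items = tuple(filter(longer_than_two, binary))
print(long_items)          # ('101',)
```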

10.8.5. Partialmethod

>>> from functools import partialmethod
>>>
>>>
>>> class Cell:
...     def __init__(self):
...         self._alive = False
...
...     @property
...     def alive(self):
...         return self._alive
...
...     def set_state(self, state):
...         self._alive = bool(state)
...
...     set_alive = partialmethod(set_state, True)
...     set_dead = partialmethod(set_state, False)
>>>
>>>
>>> c = Cell()
>>>
>>> c.alive
False
>>>
>>> c.set_alive()
>>> c.alive
True
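
The same pattern works for any method that takes a leading "mode" argument. The sketch below uses a hypothetical Logger class (not part of the standard library or of the example above) to show that partialmethod() presets the arguments after self, while self is still supplied by the instance at call time.

```python
from functools import partialmethod


class Logger:
    """Hypothetical logger that stores formatted messages in a list."""

    def __init__(self):
        self.records = []

    def log(self, level, message):
        self.records.append(f'{level}: {message}')

    # Each alias presets `level`; `self` is still supplied by the instance.
    info = partialmethod(log, 'INFO')
    error = partialmethod(log, 'ERROR')


logger = Logger()
logger.info('started')
logger.error('disk full')
print(logger.records)    # ['INFO: started', 'ERROR: disk full']
```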

10.8.6. Use Case - 1

  • We want to round to two decimal places

Problem:

>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(round, data)
>>>
>>> print(tuple(result))
(1, 2, 3, 4)

Function:

>>> def round2(x):
...     return round(x, ndigits=2)
>>>
>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(round2, data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)

Lambda:

>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(lambda x: round(x, ndigits=2), data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)

Partial:

>>> from functools import partial
>>>
>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> round2 = partial(round, ndigits=2)
>>> result = map(round2, data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)
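
The lambda and partial variants above give the same result, but they differ in when the preset value is captured. The sketch below, which uses a hypothetical power() function, shows one case where that matters: a lambda created in a loop looks up the loop variable when it is called, whereas partial() stores the value when the object is created.

```python
from functools import partial


def power(base, exponent):
    return base ** exponent


exponents = (1, 2, 3)

# Every lambda refers to the same variable `n`, which ends up as 3.
with_lambda = [lambda x: power(x, n) for n in exponents]

# Each partial stores the value `n` had when the partial was created.
with_partial = [partial(power, exponent=n) for n in exponents]

lambda_results = [f(2) for f in with_lambda]
partial_results = [f(2) for f in with_partial]

print(lambda_results)     # [8, 8, 8]
print(partial_results)    # [2, 4, 8]
```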

10.8.7. Use Case - 2

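>>> from functools import partial
>>> import pandas as pd
>>>
>>> # Assumes `df` is an existing DataFrame with the columns used below.
>>> # pd.DataFrame.plot is an accessor rather than a plain function, so the
>>> # call is wrapped in a small helper (plot_series is an illustrative name).
>>> def plot_series(series, **kwargs):
...     return series.plot(**kwargs)
>>>
>>> plot = partial(plot_series, kind='line', xlabel='time', ylabel='value', title='value in time')
>>>
>>> plot(df.temperature)  # doctest: +SKIP
>>> plot(df.humidity)  # doctest: +SKIP
>>> plot(df.co2)  # doctest: +SKIP
>>> plot(df.noise)  # doctest: +SKIP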