What is starmap?

python
itertools
functional-programming
python-standard-library
A quick explanation of this useful functionality from the Python itertools module.
Author

Fabrizio Damicelli

Published

September 18, 2021

Let’s look at a common pattern in Python code:

numbers = range(10)
list(numbers)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def square(x):
    return x**2
results = []
for n in numbers:
    results.append(square(n))
results
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

That’s fine. But Pythonistas often prefer list comprehensions like this:

results = [square(n) for n in numbers]

results
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Equivalently, we can do that with the built-in map function:

list(map(square, numbers))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Nice. But sometimes we want to do something less trivial:

results = [
    n**2 if n % 2 == 1 else n
    for n in numbers
]
results
[0, 1, 2, 9, 4, 25, 6, 49, 8, 81]

Compact and nice, but the cognitive load starts growing. Arguably not what we want. We’d rather have a little function:

def square_if_odd(x):
    if x % 2 == 1:
        return x**2
    return x
results = [square_if_odd(n) for n in numbers]

results
[0, 1, 2, 9, 4, 25, 6, 49, 8, 81]

Equivalently:

results = list(map(square_if_odd, numbers))

results
[0, 1, 2, 9, 4, 25, 6, 49, 8, 81]

Reading a for loop triggers a voice inside our heads that spells out the operation step by step. Sometimes that works great. But I often find that map reduces that cognitive load and improves readability.

Let’s go one step further to see what I mean:

numbers2 = range(5, 15)
numbers3 = range(10, 20)
list(numbers2), list(numbers3)
([5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

Say we want to do something combining inputs like this:

def add_and_exp(a, b, c):
    return (a + b) ** c
results = [add_and_exp(a, b, c) for a, b, c in zip(numbers, numbers2, numbers3)]
results
[9765625,
 1977326743,
 282429536481,
 34522712143931,
 3937376385699289,
 437893890380859375,
 48661191875666868481,
 5480386857784802185939,
 630880792396715529789561,
 74615470927590710561908487]
Warning

Beware that zip stops as soon as the shortest of its inputs is exhausted, silently dropping the remaining items of the longer ones.
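
To make the warning concrete, here is a small sketch with two made-up lists of different lengths (`short` and `long` are just illustrative names, not part of the example above). It also shows `itertools.zip_longest`, one possible alternative when you would rather pad than truncate:

```python
from itertools import zip_longest

# Illustrative inputs of different lengths (assumed for this sketch)
short = [1, 2, 3]
long = [10, 20, 30, 40, 50]

# zip stops when `short` runs out; 40 and 50 are dropped without any error
truncated = list(zip(short, long))

# zip_longest keeps going and fills the missing values with `fillvalue`
padded = list(zip_longest(short, long, fillvalue=0))

print(truncated)  # [(1, 10), (2, 20), (3, 30)]
print(padded)     # [(1, 10), (2, 20), (3, 30), (0, 40), (0, 50)]
```

If a length mismatch should count as a bug, Python 3.10+ also accepts `zip(short, long, strict=True)`, which raises a ValueError instead of truncating.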

This is the map version we would like to write:

results = map(add_and_exp, zip(numbers, numbers2, numbers3))

Cleaner, right? :)

Now, map is lazy so nothing will happen until we evaluate it:

list(results)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_71298/3044707654.py in <module>
----> 1 list(results)

TypeError: add_and_exp() missing 2 required positional arguments: 'b' and 'c'

Oops!

zip gives us tuples of the form (a, b, c), where a, b and c come from numbers, numbers2 and numbers3, respectively:

list(zip(numbers, numbers2, numbers3))
[(0, 5, 10),
 (1, 6, 11),
 (2, 7, 12),
 (3, 8, 13),
 (4, 9, 14),
 (5, 10, 15),
 (6, 11, 16),
 (7, 12, 17),
 (8, 13, 18),
 (9, 14, 19)]
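
A quick way to see what map actually hands over to the function is a tiny helper that counts its positional arguments. `show_args` is a hypothetical helper written only for this sketch:

```python
def show_args(*args):
    # Hypothetical helper: report how many positional arguments were received
    return len(args)

triples = [(0, 5, 10), (1, 6, 11)]

# map calls show_args(item) once per item, so each tuple arrives as ONE argument
counts_with_map = list(map(show_args, triples))

# Calling the helper by hand with three separate values gives three arguments
count_by_hand = show_args(0, 5, 10)

print(counts_with_map)  # [1, 1]
print(count_by_hand)    # 3
```

That single-argument call is where the TypeError above comes from: the whole tuple lands in `a`, and `b` and `c` are left empty.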

map hands each of those tuples to add_and_exp as a single argument, but our function takes 3 separate arguments. So we would like to unpack each tuple before passing it to add_and_exp. We could modify the function to handle that, but we don’t have to, because starmap does exactly that for us:

from itertools import starmap
results = starmap(add_and_exp, zip(numbers, numbers2, numbers3))
list(results)
[9765625,
 1977326743,
 282429536481,
 34522712143931,
 3937376385699289,
 437893890380859375,
 48661191875666868481,
 5480386857784802185939,
 630880792396715529789561,
 74615470927590710561908487]
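
If you are wondering what starmap does under the hood, the sketch below compares it with a comprehension that unpacks each tuple using `*`, which is roughly the equivalent given in the itertools documentation. It also shows that, for this particular example, map can do the job too, since it accepts several iterables and takes one item from each per call. The names `with_starmap`, `with_unpacking` and `with_multi_map` are just labels for this sketch:

```python
from itertools import starmap

def add_and_exp(a, b, c):
    return (a + b) ** c

numbers = range(10)
numbers2 = range(5, 15)
numbers3 = range(10, 20)

# Materialise the tuples once so they can be reused below
triples = list(zip(numbers, numbers2, numbers3))

with_starmap = list(starmap(add_and_exp, triples))

# Roughly what starmap does: unpack each tuple with * before calling
with_unpacking = [add_and_exp(*args) for args in triples]

# map with several iterables passes one item from each as separate arguments
with_multi_map = list(map(add_and_exp, numbers, numbers2, numbers3))

print(with_starmap == with_unpacking == with_multi_map)  # True
print(with_starmap[:3])  # [9765625, 1977326743, 282429536481]
```

As a rule of thumb, starmap fits best when the arguments already arrive grouped as tuples (rows of a table, pairs of coordinates), while map with several iterables is handy when the inputs live in separate sequences.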

Now go out and write some beautiful functional Python :)

Here’s the video version of this tutorial:

References:

- Python itertools documentation

/Fin

Any bugs, questions, comments, suggestions? Ping me on twitter or drop me an e-mail (fabridamicelli at gmail).