import matplotlib.pyplot as plt
import numpy as np
16 Plotting
16.1 Matplotlib
16.1.1 Anatomy of a Figure
The most common plotting library in Python is matplotlib. It is very comprehensive and makes many idiosyncratic design choices that can make it cumbersome to wrap our heads around, so don’t worry if you struggle with it a bit at first.
We will look at the simplest forms of plotting to understand the basic building blocks. To go beyond that you will need to look things up, but the good news is that matplotlib is a fairly old and very mature project, which makes it easy to find information online, and its official documentation is good too. To discover functionality and search for information, it helps to know the names of the components of a plot:
Here’s the code behind this figure in the matplotlib documentation.
Take a look at the examples gallery for inspiration and guidance on plotting your data.
Also, even free Large Language Models (LLMs), such as the copilots built into bing.com or Brave, are pretty good at generating plotting code, especially when given the precise terminology for the parts of a plot. Give them a try if you feel stuck making a figure.
16.1.2 Install matplotlib
uv add matplotlib
16.1.3 Basic Plotting
The official documentation has a nice quick intro to matplotlib concepts; let’s take a look at it here.
You can find some nice cheatsheets here.
One quick way to plot is calling functions from the pyplot module.
These are some of the most commonly used functions:
x = np.linspace(0, 10)
y = x * 2 + 1
z = x * 3 + 1
plt.plot(x, y)
plt.scatter(x, z)
plt.hist(x, bins=15);
That is fine and can take us far, especially for rapid exploration, but it has some limitations. There is a more powerful and extensive way to plot, using matplotlib’s explicit, object-oriented API:
API means Application Programming Interface. More details in the glossary.
fig, ax = plt.subplots() # Create a figure containing a single Axes
ax.plot(x, y) # Plot on that Axes
plt.polar(x, y)
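To make the difference between the two styles concrete, here is a minimal sketch that draws the same line twice: once through pyplot functions, which act on an implicit "current" Axes, and once through explicit Figure and Axes objects. The variable name implicit_ax and the axis labels are illustrative choices, not part of the original example.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10)
y = x * 2 + 1

# Implicit style: pyplot keeps track of a "current" figure and Axes for us
plt.figure()
plt.plot(x, y)
implicit_ax = plt.gca()  # the current Axes that plt.plot just drew on

# Explicit style: we hold on to the Figure and Axes objects ourselves
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel("x")  # illustrative labels
ax.set_ylabel("y")
```

Both approaches end up creating Axes objects. The explicit style simply makes it clear which Axes each call modifies, which matters as soon as a figure has more than one.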
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 3.5)) # Create a figure containing 2 Axes
ax1.plot(x, y)
ax2.scatter(x, y)
fig.tight_layout()
Notice that fig also contains the axes, and that axes is a numpy array!
fig, axes = plt.subplots(ncols=2);
fig.delaxes(axes[0])
fig.delaxes(axes[1])
type(axes)
numpy.ndarray
<Figure size 640x480 with 0 Axes>
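A small sketch of how the figure and the array of axes relate. It inspects the array returned by plt.subplots and compares it with the list stored in fig.axes. The squeeze=False variant at the end is an addition not used in the text above; it asks matplotlib to always return a 2D array, even for a single Axes.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(ncols=2)

print(type(axes))   # <class 'numpy.ndarray'>
print(axes.shape)   # (2,) because a single row is squeezed to one dimension

# The figure keeps its own list of the very same Axes objects
same_objects = all(a is b for a, b in zip(fig.axes, axes))
print(same_objects)  # True

# With squeeze=False the result is always 2D, which keeps indexing uniform
fig2, axes2 = plt.subplots(squeeze=False)
print(axes2.shape)  # (1, 1)
```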
This numpy array generalizes to grids of any shape, and these objects compose nicely with the Python language in general, for example through tuple unpacking and loops:
fig, axes = plt.subplots(ncols=2, nrows=2)
((ax1, ax2), (ax3, ax4)) = axes
ax1.plot(x, y, color="red")
ax2.scatter(x, y, color="red", marker=".")
ax3.plot(x, y)
ax4.scatter(x, y, marker=".")
for i, ax in enumerate(axes.flatten(), start=1):
    ax.set_title(f"This is plot number {i}")
    ax.set_xlabel("x-variable")
    ax.set_ylabel("y-variable")
fig.suptitle("What a Figure!") # Yeah, set_suptitle would be nicer
fig.tight_layout()
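As a further illustration of how the axes array composes with plain Python, the sketch below pairs each Axes with an entry of a dictionary using zip. The dictionary funcs and the choice of functions are hypothetical and only serve the example.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10)

# Hypothetical mapping from panel title to the function plotted in that panel
funcs = {"sin": np.sin, "cos": np.cos, "sqrt": np.sqrt, "log1p": np.log1p}

fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True)

# axes.flat iterates over the 2x2 grid row by row
for ax, (name, func) in zip(axes.flat, funcs.items()):
    ax.plot(x, func(x))
    ax.set_title(name)

fig.tight_layout()
titles = [ax.get_title() for ax in axes.flat]
```

Because the loop only relies on iteration order, the same code works unchanged for a grid of any shape, as long as there are enough Axes for the entries.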
data = np.outer(x, x)

def plot_images(data):
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 5))
    img1 = ax1.imshow(data)
    img2 = ax2.imshow(data, origin="lower", cmap="Oranges")
    ax2.set_xticks([])
    ax2.set_yticks([])
    # Attach each colorbar explicitly to the Axes it describes
    plt.colorbar(img1, ax=ax1, shrink=.5)
    plt.colorbar(img2, ax=ax2, shrink=.5)
    fig.tight_layout()

plot_images(data)
16.1.4 Saving Figures
We can save the figure in different formats (the file extension determines the format):
fig.savefig("fig.png")
fig.savefig("fig.svg") # Vector graphics are supported out of the box!
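A slightly fuller sketch of saving, under a few assumptions: the output directory is a temporary folder created just for this example, and the dpi and bbox_inches arguments are optional extras not used in the text above. It writes a raster file and a vector file and then lists what was created.

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10)
fig, ax = plt.subplots()
ax.plot(x, x * 2 + 1)

# Hypothetical output location: a fresh temporary directory
out_dir = Path(tempfile.mkdtemp())
png_path = out_dir / "fig.png"
pdf_path = out_dir / "fig.pdf"

# Raster output: dpi sets the resolution, bbox_inches="tight" trims margins
fig.savefig(png_path, dpi=300, bbox_inches="tight")
# Vector output: the format is again inferred from the file extension
fig.savefig(pdf_path)

print(sorted(p.name for p in out_dir.iterdir()))
```

For raster formats such as PNG, a higher dpi gives a sharper but larger file. Vector formats such as PDF or SVG scale without loss, so dpi matters much less there.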
If you want a more in-depth introduction to the matplotlib library, take a look at this talk.
16.2 Seaborn
seaborn is a data visualization library built on top of matplotlib.
It implements a high-level interface for plotting statistical graphics. Seaborn integrates very well with pandas dataframes as input data and abstracts away some of the common data pre-processing steps.
Install seaborn:
uv add seaborn
import seaborn as sns
For the sake of the example we can repeat the figure from above, but first we execute one line of code that sets up some defaults for us:
sns.set_theme(style="dark", font_scale=1.4) # this has global effects
df = sns.load_dataset("fmri")
df.head()
 | subject | timepoint | event | region | signal |
---|---|---|---|---|---|
0 | s13 | 18 | stim | parietal | -0.017552 |
1 | s5 | 14 | stim | parietal | -0.080883 |
2 | s12 | 18 | stim | parietal | -0.081033 |
3 | s11 | 18 | stim | parietal | -0.046134 |
4 | s10 | 18 | stim | parietal | -0.037970 |
sns.lineplot(df, x="timepoint", y="signal", marker=".", hue="region")
plot_images(data)
There are many great plot examples in the example gallery.
Here are a few interesting ones; click on a figure to open the documentation page with its code.
Multiple Regression
Time Series
Heat Scatter of Brain Networks Correlations
Annotated Heatmap
Small multiple time series
Scatterplot with categorical variables
16.3 Exercises
Use numpy to create 4 arrays of numbers drawn from 4 different distributions: uniform, normal, lognormal and exponential. Each array should have 100000 samples. Use matplotlib to plot a grid of 2x2 subplots, with one histogram on each subplot. Each histogram should have a different color. Hint: Look at numpy’s submodule random.
Pick one interesting example from the Seaborn gallery and reproduce it on your computer. Change 1 or 2 parameters of the plot, for example, some color or the order of variables, or remove/add a variable. If you have your own data with a similar shape, plot those!