import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from torch import tensor

sns.set_context(context="talk")
Fabrizio Damicelli
August 4, 2022
seaborn makes our life easy when it comes to slicing and plotting data in Python.
That generous buffet of well-balanced aesthetics and practical functionality, served through an ergonomic API, comes with a few caveats to consider, though.
Here’s one of them, which shows up when plotting data that contains PyTorch tensors.
Let’s create 2 simple numpy arrays simulating the values of two variables, x and y.
x = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
y = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
x[:5], y[:5]
(array([-1.41, 3.97, 3.2 , 4.63, 2.99]),
array([-0.58, 1.07, -0.13, 0.79, 3.84]))
We can plot them using seaborn’s scatterplot:
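The plotting cell itself did not survive in this copy of the post. Below is a minimal sketch of what it plausibly looked like. It regenerates `x` and `y` with the recipe above so the snippet runs on its own. It also assumes a headless `Agg` backend and an explicit `Axes`, which are my additions and not the author's setup.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Same recipe as above: a noisy diagonal
x = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
y = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)

fig, ax = plt.subplots()
sns.scatterplot(x=x, y=y, ax=ax)  # float64 arrays are treated as numeric
```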
We can achieve the same using lists:
([3.93, 1.6400000000000001, 2.91, 2.75, 0.3900000000000001],
[-1.08, 0.91, 4.970000000000001, 3.64, 8.129999999999999])
Again, we can plot them using seaborn’s scatterplot:
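The cells that build and plot the lists are not shown either. One way to get such lists is `ndarray.tolist()`. This is an assumption on my part, and the names `x_list`/`y_list` are mine. The same `scatterplot` call then accepts them unchanged:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

x = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
y = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)

# Hypothetical names: tolist() yields plain Python lists of Python floats
x_list, y_list = x.tolist(), y.tolist()

fig, ax = plt.subplots()
sns.scatterplot(x=x_list, y=y_list, ax=ax)  # lists of floats: still numeric
```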
Observe that the lists are made up of Python floats, while the numpy arrays contain numpy.float64:
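The type-inspection cell is missing. A sketch of the check follows, again assuming the lists come from `tolist()`. One detail worth noting: `numpy.float64` is a subclass of Python's `float`, so to most code the two look alike.

```python
import numpy as np

x = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
x_list = x.tolist()  # hypothetical name, as in the earlier sketch

print(type(x_list[0]))           # <class 'float'>
print(type(x[0]))                # <class 'numpy.float64'>
print(isinstance(x[0], float))   # True: numpy.float64 subclasses float
```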
So far so good.
Now, what happens if the individual elements are PyTorch zero-dimensional tensors (i.e. scalars), like these:
([tensor(3.9300, dtype=torch.float64),
tensor(1.6400, dtype=torch.float64),
tensor(2.9100, dtype=torch.float64),
tensor(2.7500, dtype=torch.float64),
tensor(0.3900, dtype=torch.float64)],
[tensor(-1.0800, dtype=torch.float64),
tensor(0.9100, dtype=torch.float64),
tensor(4.9700, dtype=torch.float64),
tensor(3.6400, dtype=torch.float64),
tensor(8.1300, dtype=torch.float64)])
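How the tensors were built is not shown. A plausible sketch wraps each Python float in its own zero-dimensional tensor, assuming they were derived element by element from the lists. `dtype=torch.float64` is passed explicitly to match the printed output, because `tensor()` on a Python float defaults to `float32`.

```python
import numpy as np
import torch
from torch import tensor

x = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
y = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
x_list, y_list = x.tolist(), y.tolist()  # hypothetical names

# One zero-dimensional tensor per element
x_pt = [tensor(v, dtype=torch.float64) for v in x_list]
y_pt = [tensor(v, dtype=torch.float64) for v in y_list]

print(x_pt[0].ndim, x_pt[0].shape)  # 0 torch.Size([])
```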
At first glance, it all looks like it should be more or less the same. Indeed, some comparisons still work the way we (I?) expect, for example element-wise equality:
array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True])
array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True])
array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True])
This works because numpy, under the hood, first casts the operands to arrays and only then compares them element-wise.
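A small sketch of that casting step, assuming the list-of-tensors setup from above. `np.asarray` unwraps each 0-d tensor, which works because tensors implement `__array__`. The comparison against an ndarray is therefore element-wise and returns a boolean array.

```python
import numpy as np
import torch

x = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
x_pt = [torch.tensor(v, dtype=torch.float64) for v in x.tolist()]

# The list of 0-d tensors becomes an ordinary float64 array
x_pt_arr = np.asarray(x_pt)
print(x_pt_arr.dtype, x_pt_arr.shape)  # float64 (20,)

# ndarray == list: numpy converts the list first, then compares element by element
eq = x == x_pt
print(eq.all())  # True
```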
When comparing the Python list directly to the list of tensors, we get this:
That is not quite what I expected (i.e. an element-wise comparison), but it is still in line with our belief that these array-like structures (list, array, tensor) are made of equivalent scalar elements.
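The single value comes from how Python compares lists. `list == list` checks the items pairwise, taking the truth value of each `float == tensor` result (a `tensor(True)`), and then collapses everything into one `bool`. A minimal sketch with a few hand-picked values:

```python
import torch

x_list = [3.93, 1.64, 2.91]
x_pt = [torch.tensor(v, dtype=torch.float64) for v in x_list]

# Pairwise comparison under the hood, but a single bool comes out
result = x_list == x_pt
print(result, type(result))  # True <class 'bool'>
```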
So let’s plot the tensors, like we did with the lists and the arrays:
Oops, that doesn’t look good: the y-axis is flipped!
After going a bit down the rabbit hole of seaborn and pandas error traces, we see that under the hood seaborn infers the data type of the values and (surprise!) PyTorch tensors seem to be interpreted as categorical.
This becomes more explicit if we try to plot the data with the pointplot function instead:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_418681/3817459137.py in <module>
----> 1 sns.pointplot(x=x_pt, y=y_pt);

~/miniconda3/envs/myenv39/lib/python3.9/site-packages/seaborn/_decorators.py in inner_f(*args, **kwargs)
     44             )
     45         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 46         return f(**kwargs)
     47     return inner_f
     48 

~/miniconda3/envs/myenv39/lib/python3.9/site-packages/seaborn/categorical.py in pointplot(x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, seed, markers, linestyles, dodge, join, scale, orient, color, palette, errwidth, capsize, ax, **kwargs)
   3373 ):
   3374 
-> 3375     plotter = _PointPlotter(x, y, hue, data, order, hue_order,
   3376                             estimator, ci, n_boot, units, seed,
   3377                             markers, linestyles, dodge, join, scale,

~/miniconda3/envs/myenv39/lib/python3.9/site-packages/seaborn/categorical.py in __init__(self, x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, seed, markers, linestyles, dodge, join, scale, orient, color, palette, errwidth, capsize)
   1653                  orient, color, palette, errwidth=None, capsize=None):
   1654         """Initialize the plotter."""
-> 1655         self.establish_variables(x, y, hue, data, orient,
   1656                                  order, hue_order, units)
   1657         self.establish_colors(color, palette, 1)

~/miniconda3/envs/myenv39/lib/python3.9/site-packages/seaborn/categorical.py in establish_variables(self, x, y, hue, data, orient, order, hue_order, units)
    154 
    155             # Figure out the plotting orientation
--> 156             orient = infer_orient(
    157                 x, y, orient, require_numeric=self.require_numeric
    158             )

~/miniconda3/envs/myenv39/lib/python3.9/site-packages/seaborn/_core.py in infer_orient(x, y, orient, require_numeric)
   1350     elif require_numeric and "numeric" not in (x_type, y_type):
   1351         err = "Neither the `x` nor `y` variable appears to be numeric."
-> 1352         raise TypeError(err)
   1353 
   1354     else:

TypeError: Neither the `x` nor `y` variable appears to be numeric.
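The last frame shows `infer_orient` in `seaborn/_core.py` giving up because neither variable was classified as numeric. I have only checked this against seaborn 0.11 and it may differ in later versions, but that module's type inference appears to fall back to checking whether every element is an instance of `numbers.Number` when the input has no numeric dtype, as with a plain list. A quick sketch of that condition for the three element types in this post:

```python
from numbers import Number

import numpy as np
import torch

# One representative element of each container type used above
checks = {
    "Python float": 3.93,
    "numpy.float64": np.float64(3.93),
    "0-d torch tensor": torch.tensor(3.93, dtype=torch.float64),
}
for name, value in checks.items():
    # float and numpy.float64 count as numbers; a tensor is not registered as one
    print(name, isinstance(value, Number))
```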
To dig a bit deeper into this behaviour, you can read the section “Categorical plots will always be categorical” of this article by seaborn’s creator, Michael Waskom, himself.
There are a couple of options to fix our plots.
We can flip the y-axis:
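The original cell is missing. Matplotlib's `Axes.invert_yaxis()` toggles the direction of the y-axis. The helper name `ensure_y_increasing` is hypothetical and not part of seaborn or matplotlib. In the post's setting you would call it on the Axes returned by `sns.scatterplot(x=x_pt, y=y_pt)`. Here the inversion is simulated with plain matplotlib, so the sketch does not depend on how a particular seaborn version treats tensors.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def ensure_y_increasing(ax):
    """Hypothetical helper: undo an inverted y-axis, leave a normal one alone."""
    if ax.yaxis_inverted():
        ax.invert_yaxis()  # invert_yaxis toggles, so only call it when needed
    return ax


# Stand-in for the flipped tensor plot
fig, ax = plt.subplots()
ax.scatter([0, 1, 2], [0, 1, 2])
ax.invert_yaxis()
ensure_y_increasing(ax)
```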
We can cast the data, for example to a numpy array:
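A sketch of the numpy cast, assuming `x_pt`/`y_pt` are lists of 0-d float64 tensors as above. `np.array` unwraps the tensors into one float64 array, which seaborn then sees as numeric. `torch.stack(x_pt).numpy()` would be an equivalent route.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

x = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
y = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
x_pt = [torch.tensor(v, dtype=torch.float64) for v in x.tolist()]
y_pt = [torch.tensor(v, dtype=torch.float64) for v in y.tolist()]

# Cast the lists of 0-d tensors to plain float64 arrays
x_np, y_np = np.array(x_pt), np.array(y_pt)

fig, ax = plt.subplots()
sns.scatterplot(x=x_np, y=y_np, ax=ax)  # numeric data again
```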
Or to a tensor:
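One way to do the tensor cast, assuming the same list of 0-d tensors, is `torch.stack`. It glues the scalars into a single 1-D tensor. The post reports that seaborn plots such a tensor correctly with the version used here. The sketch below only checks the conversion itself, since plotting behaviour may vary across seaborn versions.

```python
import numpy as np
import torch

x = np.arange(20) + np.random.normal(scale=1.7, size=20).round(2)
x_pt = [torch.tensor(v, dtype=torch.float64) for v in x.tolist()]

# One 1-D tensor of shape (20,) instead of twenty 0-d tensors
x_t = torch.stack(x_pt)
print(x_t.shape, x_t.dtype)  # torch.Size([20]) torch.float64
```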
Maybe that will change in the future.
PS: Consider starring the seaborn project on GitHub.
/Fin
Any bugs, questions, comments, suggestions? Ping me on twitter or drop me an e-mail (fabridamicelli at gmail).