import torch
from torch import tensor
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
Fabrizio Damicelli
September 13, 2023
This post contains the code behind this video explanation:
Imagine a supervised learning scenario of a classification task with sequential data as features and a binary target.
Let’s simulate a toy dataset and take a look at it:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.xs = [
            list(range(11, 13)),
            list(range(13, 16)),
            list(range(16, 21)),
            list(range(21, 24)),
            list(range(22, 25)),
            list(range(25, 30)),
        ]
        self.ys = [0, 0, 0, 1, 1, 1]
        assert len(self.xs) == len(self.ys)

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        return {
            "x": self.xs[idx],
            "y": self.ys[idx],
        }
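Iterating over the dataset and printing each item yields the output below. A minimal sketch of such a cell (the variable name dset is an assumption):

dset = CustomDataset()
# Walk through the dataset item by item
for i in range(len(dset)):
    print(dset[i])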
{'x': [11, 12], 'y': 0}
{'x': [13, 14, 15], 'y': 0}
{'x': [16, 17, 18, 19, 20], 'y': 0}
{'x': [21, 22, 23], 'y': 1}
{'x': [22, 23, 24], 'y': 1}
{'x': [25, 26, 27, 28, 29], 'y': 1}
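Now, if we wrap the dataset in a standard DataLoader and iterate over it, the default collate function chokes because the x sequences have different lengths. A sketch of the failing cell (batch_size=2 is an assumption, matching the batches shown later):

# Default collate_fn: tries to stack items of equal size
dloader = DataLoader(dset, batch_size=2)
for batch in dloader:
    print(batch)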
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 for batch in dloader:
      2     print(batch)

...

File /home/fabrizio/miniconda3/envs/olaf-training-py39/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:139, in collate(batch, collate_fn_map)
    137 elem_size = len(next(it))
    138 if not all(len(elem) == elem_size for elem in it):
--> 139     raise RuntimeError('each element in list of batch should be of equal size')
    140 transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.
    142 if isinstance(elem, tuple):

RuntimeError: each element in list of batch should be of equal size
We can refactor our dataset and make it generate items with x sequences that all have the same length (a parameter max_len that we define beforehand).
class CustomDatasetFixLen(torch.utils.data.Dataset):
    def __init__(self, max_len=10):
        self.max_len = max_len
        self.xs = [
            list(range(11, 13)),
            list(range(13, 16)),
            list(range(16, 21)),
            list(range(21, 24)),
            list(range(22, 25)),
            list(range(25, 30)),
        ]
        self.ys = [0, 0, 0, 1, 1, 1]
        assert len(self.xs) == len(self.ys)

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        x = self.xs[idx]
        pad_len = self.max_len - len(x)
        x = x + [0] * pad_len
        return {
            "x": np.array(x),
            "y": self.ys[idx],
        }
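As before, the output below comes from iterating over this fixed-length dataset and printing each item. A minimal sketch of such a cell (the variable name dset_fixlen is an assumption):

dset_fixlen = CustomDatasetFixLen()
# Every x is now padded with zeros up to max_len == 10
for i in range(len(dset_fixlen)):
    print(dset_fixlen[i])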
{'x': array([11, 12, 0, 0, 0, 0, 0, 0, 0, 0]), 'y': 0}
{'x': array([13, 14, 15, 0, 0, 0, 0, 0, 0, 0]), 'y': 0}
{'x': array([16, 17, 18, 19, 20, 0, 0, 0, 0, 0]), 'y': 0}
{'x': array([21, 22, 23, 0, 0, 0, 0, 0, 0, 0]), 'y': 1}
{'x': array([22, 23, 24, 0, 0, 0, 0, 0, 0, 0]), 'y': 1}
{'x': array([25, 26, 27, 28, 29, 0, 0, 0, 0, 0]), 'y': 1}
That works, but it is wasteful because we always pad to max_len = 10, even when we only need to pad to length 3 (for example, if the batch is formed by the first two items). That could limit the batch size we can work with, slowing down training, or even lead to unnecessary computations during the forward pass if we just pass our batches on without masking. So, ideally, we would like to pad only as much as we need in each batch. In other words, we want to adapt the padding dynamically, on a per-batch basis.
Let’s implement our own collate function, i.e. the logic that puts items together into a batch, which will allow us to do the padding on a per-batch basis (thus we call it dynamic_length_collate).
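Here is a sketch of what such a collate function can look like, using pad_sequence (imported at the top) to pad each batch only up to its longest sequence; the DataLoader then picks it up via the collate_fn argument (variable names and the float dtype are assumptions):

def dynamic_length_collate(batch):
    # `batch` is a list of item dicts as returned by CustomDataset.__getitem__
    xs = [torch.tensor(item["x"], dtype=torch.float) for item in batch]
    ys = torch.tensor([item["y"] for item in batch])
    return {
        # Pad with zeros only up to the longest sequence in *this* batch
        "x": pad_sequence(xs, batch_first=True),
        "y": ys,
    }

# Back to the original, unpadded dataset
dloader = DataLoader(dset, batch_size=2, collate_fn=dynamic_length_collate)
for batch in dloader:
    print(batch)

Note that pad_sequence pads with 0.0 by default and, with batch_first=True, returns a tensor of shape (batch_size, longest_length_in_batch).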
{'x': tensor([[11., 12., 0.],
[13., 14., 15.]]), 'y': tensor([0, 0])}
{'x': tensor([[16., 17., 18., 19., 20.],
[21., 22., 23., 0., 0.]]), 'y': tensor([0, 1])}
{'x': tensor([[22., 23., 24., 0., 0.],
[25., 26., 27., 28., 29.]]), 'y': tensor([1, 1])}
That works!
For the sake of completeness, let’s use our dataloader with the custom collate function and actually feed the data into a (toy) neural network.
# A very toy example of a neural network
model = torch.nn.LSTM(input_size=1, hidden_size=2, batch_first=True)

for batch in dloader:
    bs, seq_len = batch["x"].shape
    pred = model(batch["x"].reshape(bs, seq_len, 1))
    print(pred)
    break
(tensor([[[ 9.4607e-04,  4.0929e-03],
         [ 5.0468e-04,  5.7644e-03],
         [-1.5826e-01,  1.9474e-02]],

        [[ 2.6432e-04,  2.6775e-03],
         [ 1.3929e-04,  3.5764e-03],
         [ 7.2860e-05,  3.3866e-03]]], grad_fn=<TransposeBackward0>), (tensor([[[-1.5826e-01,  1.9474e-02],
         [ 7.2860e-05,  3.3866e-03]]], grad_fn=<StackBackward0>), tensor([[[-0.2420,  0.0843],
         [ 0.0071,  1.2391]]], grad_fn=<StackBackward0>)))
/Fin
Any bugs, questions, comments, suggestions? Ping me on twitter or drop me an e-mail (fabridamicelli at gmail).