PyTorch DataLoader: Understand and implement a custom collate function

pytorch
python
data
neural-nets
Understand DataLoader’s inner workings and bring your data pipeline to the next level.
Author

Fabrizio Damicelli

Published

September 13, 2023

This post contains the code behind this video explanation.

Code
import torch
from torch import tensor
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

Imagine a supervised learning scenario: a classification task with variable-length sequences as features and a binary target.

Let’s simulate a toy dataset and take a look at it:

Code
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # Variable-length sequences of integers as features
        self.xs = [
            list(range(11, 13)),
            list(range(13, 16)),
            list(range(16, 21)),
            list(range(21, 24)),
            list(range(22, 25)),
            list(range(25, 30)),
        ]
        # One binary target per sequence
        self.ys = [0, 0, 0, 1, 1, 1]
        assert len(self.xs) == len(self.ys)
    def __len__(self):
        return len(self.xs)
    def __getitem__(self, idx):
        # Each item is a dict holding one sequence and its label
        return {
            "x": self.xs[idx],
            "y": self.ys[idx],
        }
dset = CustomDataset()

for item in dset:
    print(item)
{'x': [11, 12], 'y': 0}
{'x': [13, 14, 15], 'y': 0}
{'x': [16, 17, 18, 19, 20], 'y': 0}
{'x': [21, 22, 23], 'y': 1}
{'x': [22, 23, 24], 'y': 1}
{'x': [25, 26, 27, 28, 29], 'y': 1}
dloader = DataLoader(dset, batch_size=2, shuffle=False)
for batch in dloader:
    print(batch)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 for batch in dloader:
      2     print(batch)

File /home/fabrizio/miniconda3/envs/olaf-training-py39/lib/python3.9/site-packages/torch/utils/data/dataloader.py:628, in _BaseDataLoaderIter.__next__(self)
    625 if self._sampler_iter is None:
    626     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627     self._reset()  # type: ignore[call-arg]
--> 628 data = self._next_data()
    629 self._num_yielded += 1
    630 if self._dataset_kind == _DatasetKind.Iterable and \
    631         self._IterableDataset_len_called is not None and \
    632         self._num_yielded > self._IterableDataset_len_called:

File /home/fabrizio/miniconda3/envs/olaf-training-py39/lib/python3.9/site-packages/torch/utils/data/dataloader.py:671, in _SingleProcessDataLoaderIter._next_data(self)
    669 def _next_data(self):
    670     index = self._next_index()  # may raise StopIteration
--> 671     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    672     if self._pin_memory:
    673         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /home/fabrizio/miniconda3/envs/olaf-training-py39/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:61, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     59 else:
     60     data = self.dataset[possibly_batched_index]
---> 61 return self.collate_fn(data)

File /home/fabrizio/miniconda3/envs/olaf-training-py39/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:265, in default_collate(batch)
    204 def default_collate(batch):
    205     r"""
    206         Function that takes in a batch of data and puts the elements within the batch
    207         into a tensor with an additional outer dimension - batch size. The exact output type can be
   (...)
    263             >>> default_collate(batch)  # Handle `CustomType` automatically
    264     """
--> 265     return collate(batch, collate_fn_map=default_collate_fn_map)

File /home/fabrizio/miniconda3/envs/olaf-training-py39/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:128, in collate(batch, collate_fn_map)
    126 if isinstance(elem, collections.abc.Mapping):
    127     try:
--> 128         return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
    129     except TypeError:
    130         # The mapping type may not support `__init__(iterable)`.
    131         return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}

File /home/fabrizio/miniconda3/envs/olaf-training-py39/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:128, in <dictcomp>(.0)
    126 if isinstance(elem, collections.abc.Mapping):
    127     try:
--> 128         return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
    129     except TypeError:
    130         # The mapping type may not support `__init__(iterable)`.
    131         return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}

File /home/fabrizio/miniconda3/envs/olaf-training-py39/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:139, in collate(batch, collate_fn_map)
    137 elem_size = len(next(it))
    138 if not all(len(elem) == elem_size for elem in it):
--> 139     raise RuntimeError('each element in list of batch should be of equal size')
    140 transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.
    142 if isinstance(elem, tuple):

RuntimeError: each element in list of batch should be of equal size
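
The default collate function fails here because, for every key in the item dicts, it tries to combine the values of the whole batch into a single tensor, and lists of different lengths cannot be stacked together. We can see that in a minimal sketch by calling torch's default_collate directly on two of our items:

from torch.utils.data import default_collate

# Two items with x sequences of different lengths, just like our dataset produces
items = [{"x": [11, 12], "y": 0}, {"x": [13, 14, 15], "y": 0}]

try:
    default_collate(items)
except RuntimeError as err:
    print(err)  # each element in list of batch should be of equal size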

A first solution attempt

We can refactor our dataset so that it generates items whose x sequences all have the same length (a parameter max_len that we define beforehand).

class CustomDatasetFixLen(torch.utils.data.Dataset):
    def __init__(self, max_len=10):
        self.max_len = max_len
        self.xs = [
            list(range(11, 13)),
            list(range(13, 16)),
            list(range(16, 21)),
            list(range(21, 24)),
            list(range(22, 25)),
            list(range(25, 30)),
        ]
        self.ys = [0, 0, 0, 1, 1, 1]
        assert len(self.xs) == len(self.ys)
    def __len__(self): 
        return len(self.xs)
    def __getitem__(self, idx):
        x = self.xs[idx]
        # Pad every sequence with zeros up to max_len
        pad_len = self.max_len - len(x)
        x = x + [0]*pad_len
        return {
            "x": np.array(x),
            "y": self.ys[idx],
        }
dset = CustomDatasetFixLen(max_len=10)
for item in dset:
    print(item)
{'x': array([11, 12,  0,  0,  0,  0,  0,  0,  0,  0]), 'y': 0}
{'x': array([13, 14, 15,  0,  0,  0,  0,  0,  0,  0]), 'y': 0}
{'x': array([16, 17, 18, 19, 20,  0,  0,  0,  0,  0]), 'y': 0}
{'x': array([21, 22, 23,  0,  0,  0,  0,  0,  0,  0]), 'y': 1}
{'x': array([22, 23, 24,  0,  0,  0,  0,  0,  0,  0]), 'y': 1}
{'x': array([25, 26, 27, 28, 29,  0,  0,  0,  0,  0]), 'y': 1}
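
As a quick sanity check: now that all items have the same length, the default collate function can stack them and the plain DataLoader works out of the box:

dloader_fixed = DataLoader(dset, batch_size=2, shuffle=False)
for batch in dloader_fixed:
    # Each batch["x"] now has shape (batch_size, max_len), i.e. (2, 10)
    print(batch["x"].shape, batch["y"])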

That works, but it is wasteful: we always pad to max_len = 10, even when we only need to pad to length 3 (for example, when a batch is made of the first two items). That can limit the batch size we can work with, slowing down training, and even lead to unnecessary computations during the forward pass if we pass our batches through the model without masking. So, ideally, we would like to pad only as much as we need in each batch. In other words, we want to adapt the padding dynamically, on a per-batch basis.

There must be a better way

Let’s implement our own collate function, i.e. the logic that puts the items together into a batch. It will allow us to do the padding on a per-batch basis (thus we call it dynamic_length_collate):

def dynamic_length_collate(batch):
    # The longest sequence in the batch determines how much we pad
    max_len = max(len(item["x"]) for item in batch)
    batch_x = []
    for item in batch:
        pad_len = max_len - len(item["x"])
        batch_x.append(item["x"] + [0]*pad_len)
    return {
        "x": tensor(batch_x).type(torch.float),
        "y": tensor([item["y"] for item in batch])
    }
dset = CustomDataset()  # Use our original dataset, without fixed max_len
dloader = DataLoader(dset, batch_size=2, shuffle=False,
                     collate_fn=dynamic_length_collate)
for batch in dloader:
    print(batch)
{'x': tensor([[11., 12.,  0.],
        [13., 14., 15.]]), 'y': tensor([0, 0])}
{'x': tensor([[16., 17., 18., 19., 20.],
        [21., 22., 23.,  0.,  0.]]), 'y': tensor([0, 1])}
{'x': tensor([[22., 23., 24.,  0.,  0.],
        [25., 26., 27., 28., 29.]]), 'y': tensor([1, 1])}

That works!
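
As a side note, PyTorch ships a helper that does exactly this kind of per-batch padding: pad_sequence (already imported at the top). A sketch of an equivalent collate function built on top of it (same zero-padding as above, only the inputs are turned into tensors first):

def pad_sequence_collate(batch):
    # pad_sequence pads a list of 1D tensors to the length of the longest one
    xs = [tensor(item["x"], dtype=torch.float) for item in batch]
    return {
        "x": pad_sequence(xs, batch_first=True, padding_value=0.0),
        "y": tensor([item["y"] for item in batch]),
    }

dloader_padseq = DataLoader(dset, batch_size=2, shuffle=False,
                            collate_fn=pad_sequence_collate)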

For the sake of completeness, let’s use our dataloader with the custom collate function and actually feed the data into a (toy) neural network.

# A very toy example of a neural network
model = torch.nn.LSTM(input_size=1, hidden_size=2, batch_first=True)

for batch in dloader:
    bs, seq_len = batch["x"].shape
    # batch_first=True: the LSTM expects input of shape (batch, seq_len, input_size)
    pred = model(batch["x"].reshape(bs, seq_len, 1))
    print(pred)
    break
(tensor([[[ 9.4607e-04,  4.0929e-03],
         [ 5.0468e-04,  5.7644e-03],
         [-1.5826e-01,  1.9474e-02]],

        [[ 2.6432e-04,  2.6775e-03],
         [ 1.3929e-04,  3.5764e-03],
         [ 7.2860e-05,  3.3866e-03]]], grad_fn=<TransposeBackward0>), (tensor([[[-1.5826e-01,  1.9474e-02],
         [ 7.2860e-05,  3.3866e-03]]], grad_fn=<StackBackward0>), tensor([[[-0.2420,  0.0843],
         [ 0.0071,  1.2391]]], grad_fn=<StackBackward0>)))
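
One last note, related to the unnecessary computations mentioned earlier: if the collate function also returns the original sequence lengths, we can wrap the padded batch with pack_padded_sequence so the LSTM skips the padded positions altogether. A sketch of that idea (the extra lengths key and the function name are just illustrative choices):

from torch.nn.utils.rnn import pack_padded_sequence

def dynamic_length_collate_with_lengths(batch):
    max_len = max(len(item["x"]) for item in batch)
    xs, lengths = [], []
    for item in batch:
        lengths.append(len(item["x"]))
        xs.append(item["x"] + [0] * (max_len - len(item["x"])))
    return {
        "x": tensor(xs).type(torch.float),
        "lengths": tensor(lengths),  # needed by pack_padded_sequence
        "y": tensor([item["y"] for item in batch]),
    }

dloader_packed = DataLoader(dset, batch_size=2, shuffle=False,
                            collate_fn=dynamic_length_collate_with_lengths)
for batch in dloader_packed:
    bs, seq_len = batch["x"].shape
    packed = pack_padded_sequence(
        batch["x"].reshape(bs, seq_len, 1),
        batch["lengths"], batch_first=True, enforce_sorted=False,
    )
    out, (h_n, c_n) = model(packed)  # the LSTM ignores the padded steps
    break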

/Fin

Any bugs, questions, comments, suggestions? Ping me on twitter or drop me an e-mail (fabridamicelli at gmail).