Generating batch data for PyTorch | by Sam | Nov, 2020
Deep Learning in Practice
Creating custom data loaders for PyTorch — MADE EASY!
I was in the middle of creating a custom PyTorch training module that overcomplicated things, especially when it came to generating batches for training and ensuring that those batches weren’t repeated during the training epoch. “This is a solved problem” I thought to myself as I furiously coded away in the depths of the lab.
There are reasons you don’t want to simply increment indices as you select items from your dataset. 1) It doesn’t scale out to multiple workers. 2) You need to randomize the order of your samples to maximize training performance.
This is where Torch’s data utilities ( torch.utils.data ) come in handy. You should never create a batch generator from scratch.
You can take two approaches. 1) Move all the preprocessing before you create a dataset, and just use the dataset to generate items, or 2) perform all the preprocessing (scaling, shifting, reshaping, etc.) in the initialization step of your dataset. If you’re only using Torch, method #2 makes sense. I am using multiple backends, so I’m going with method #1.
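Here is a minimal sketch of method #2, assuming simple tabular data; the class name and the standardization scheme are illustrative choices, not from the article:

```python
import torch
from torch.utils.data import Dataset

class PreprocessedDataset(Dataset):
    """Method #2: do all the preprocessing once, in __init__."""
    def __init__(self, features, targets):
        features = torch.as_tensor(features, dtype=torch.float32)
        # Standardize to zero mean / unit variance at construction time.
        self.features = (features - features.mean(dim=0)) / features.std(dim=0)
        self.targets = torch.as_tensor(targets, dtype=torch.float32)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # No per-item work left to do -- just index into the tensors.
        return self.features[idx], self.targets[idx]

# Toy usage with three samples of two features each.
ds = PreprocessedDataset([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0.0, 1.0, 0.0])
```

With method #1, by contrast, `__init__` would just store already-clean tensors handed in from an external preprocessing step.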
- Create a custom dataset class. You override the __len__ and __getitem__ methods of torch.utils.data.Dataset .
- Create an iterator that uses torch.utils.data.DataLoader to batch and shuffle your items.
- Use this iterator in your training loop.
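The three steps above can be sketched as follows. The data, model, and hyperparameters here are toy placeholders (a built-in TensorDataset stands in for the custom dataset class), not the article’s attached code:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 100 samples, 8 features, one regression target.
X = torch.randn(100, 8)
y = torch.randn(100, 1)

# Step 1: a Dataset (custom subclass or a built-in like TensorDataset).
dataset = TensorDataset(X, y)

# Step 2: a DataLoader iterator that batches and shuffles without repeats.
loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Step 3: use the iterator in the training loop.
for epoch in range(2):
    for batch_X, batch_y in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(batch_X), batch_y)
        loss.backward()
        optimizer.step()
```

Each pass over the loader visits every sample exactly once, in a fresh random order per epoch, which is exactly the no-repeats guarantee described above.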
See the attached code for an example of how this is used.
In this article, we reviewed the recommended method for feeding data to a PyTorch training loop. This opens up a number of interesting data access patterns that facilitate easier and faster training, such as:
- Using multiple processes to read data
- More standardized data preprocessing
- Using dataloaders across multiple machines for distributed training
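A brief sketch of those patterns, assuming a toy TensorDataset. The num_workers flag spawns reader processes; the commented DistributedSampler lines are illustrative and need an initialized process group to actually run:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))

# Multiple worker processes load and collate batches in parallel.
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)

# For distributed training, DistributedSampler shards the dataset across
# processes/machines (requires torch.distributed.init_process_group first):
# from torch.utils.data.distributed import DistributedSampler
# sampler = DistributedSampler(dataset)
# loader = DataLoader(dataset, batch_size=8, sampler=sampler)

batches = list(loader)
```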
Thanks for reading!