import torch
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Lambda
Why do we need transforms?
Data comes in many different formats, but PyTorch models operate on a single data type: the tensor. Transforms convert raw data, such as images, into tensors. In this post, we will look at how to transform images. I will assume that you are familiar with PyTorch Datasets. If you are not, I recommend reading this post before you continue.
How do PyTorch transforms work?
All built-in datasets from the torchvision module take the parameters transform and target_transform. Each takes a function that converts its input into a tensor, following predefined steps: transform is applied to the input data (here, the image) and target_transform is applied to the label. To avoid having to write these functions ourselves, the torchvision.transforms module comes with an image-to-tensor transform, called ToTensor, out of the box.
Let’s look at an example using the FashionMNIST dataset.
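def our_own_transformation(target):
    """
    Transforms a target label into a one-hot tensor.
    example:
    >>> our_own_transformation(3)
    tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
    """
    # Start from a vector of 10 zeros, one entry per FashionMNIST class
    zeros_list = torch.zeros(10, dtype=torch.float)
    # The label is the position that should be set to 1
    one_hot_index = torch.tensor(target)
    # Write a 1 at that position (in place) and keep the result
    one_hot_tensor = zeros_list.scatter_(0, one_hot_index, value=1)
    return one_hot_tensor
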
ds_train = FashionMNIST(
    root="data",
    train=True,  # load the training split
    download=True,
    transform=ToTensor(),  # image -> tensor
    target_transform=Lambda(our_own_transformation),  # label -> one-hot tensor
)

ds_test = FashionMNIST(
    root="data",
    train=False,  # load the test split
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(our_own_transformation),
)
In this code, we specified that we want to convert the images to tensors using the ToTensor transform, and the target labels to one-hot tensors using our_own_transformation.
Further reading
There are many more things we can do with transforms. We can rotate images, shift images, or chain transformations together to create a preprocessing pipeline. Since those use cases are too advanced for us at the moment, I will not cover them in this post. However, if you are curious or already more experienced, I recommend that you check out the example section on the PyTorch website!