Skip to content

Commit 39ab47d

Browse files
committed
Fix Transforms tutorial: unify target_transform and ToTensor prose
1 parent 3406de7 commit 39ab47d

1 file changed

Lines changed: 13 additions & 10 deletions

File tree

beginner_source/basics/transforms_tutorial.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,35 +30,38 @@
3030
from torchvision import datasets
3131
from torchvision.transforms import ToTensor, Lambda
3232

33+
target_transform = Lambda(
34+
lambda y: torch.zeros(10, dtype=torch.float32).scatter_(
35+
dim=0, index=torch.tensor(y), value=1
36+
)
37+
)
38+
3339
ds = datasets.FashionMNIST(
3440
root="data",
3541
train=True,
3642
download=True,
3743
transform=ToTensor(),
38-
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
44+
target_transform=target_transform,
3945
)
4046

4147
#################################################
4248
# ToTensor()
4349
# -------------------------------
4450
#
4551
# `ToTensor <https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor>`_
46-
# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor``. and scales
52+
# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor`` and scales
4753
# the image's pixel intensity values in the range [0., 1.]
4854
#
4955

5056
##############################################
5157
# Lambda Transforms
5258
# -------------------------------
5359
#
54-
# Lambda transforms apply any user-defined lambda function. Here, we define a function
55-
# to turn the integer into a one-hot encoded tensor.
56-
# It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls
57-
# `scatter_ <https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html>`_ which assigns a
58-
# ``value=1`` on the index as given by the label ``y``.
59-
60-
target_transform = Lambda(lambda y: torch.zeros(
61-
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
60+
# Lambda transforms apply any user-defined callable. Above, ``target_transform`` wraps a
61+
# small lambda that turns each label ``y`` into a one-hot encoded tensor: it allocates a
62+
# length-10 zero vector and uses
63+
# `scatter_ <https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html>`_ to write
64+
# ``value=1`` at index ``y``.
6265

6366
######################################################################
6467
# --------------

0 commit comments

Comments
 (0)