|
30 | 30 | from torchvision import datasets |
31 | 31 | from torchvision.transforms import ToTensor, Lambda |
32 | 32 |
|
| 33 | +target_transform = Lambda( |
| 34 | + lambda y: torch.zeros(10, dtype=torch.float32).scatter_( |
| 35 | + dim=0, index=torch.tensor(y), value=1 |
| 36 | + ) |
| 37 | +) |
| 38 | + |
33 | 39 | ds = datasets.FashionMNIST( |
34 | 40 | root="data", |
35 | 41 | train=True, |
36 | 42 | download=True, |
37 | 43 | transform=ToTensor(), |
38 | | - target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) |
| 44 | + target_transform=target_transform, |
39 | 45 | ) |
40 | 46 |
|
41 | 47 | ################################################# |
42 | 48 | # ToTensor() |
43 | 49 | # ------------------------------- |
44 | 50 | # |
45 | 51 | # `ToTensor <https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor>`_ |
46 | | -# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor``. and scales |
| 52 | +# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor`` and scales |
47 | 53 | # the image's pixel intensity values in the range [0., 1.] |
48 | 54 | # |
49 | 55 |
|
50 | 56 | ############################################## |
51 | 57 | # Lambda Transforms |
52 | 58 | # ------------------------------- |
53 | 59 | # |
54 | | -# Lambda transforms apply any user-defined lambda function. Here, we define a function |
55 | | -# to turn the integer into a one-hot encoded tensor. |
56 | | -# It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls |
57 | | -# `scatter_ <https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html>`_ which assigns a |
58 | | -# ``value=1`` on the index as given by the label ``y``. |
59 | | - |
60 | | -target_transform = Lambda(lambda y: torch.zeros( |
61 | | - 10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)) |
| 60 | +# Lambda transforms apply any user-defined callable. Above, ``target_transform`` wraps a |
| 61 | +# small lambda that turns each label ``y`` into a one-hot encoded tensor: it allocates a |
| 62 | +# length-10 zero vector and uses |
| 63 | +# `scatter_ <https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html>`_ to write |
| 64 | +# ``value=1`` at index ``y``. |
62 | 65 |
|
63 | 66 | ###################################################################### |
64 | 67 | # -------------- |
|
0 commit comments