Skip to content

Commit c717024

Browse files
committed
Update CIFAR10 tutorial device selection for CUDA/MPS/CPU
1 parent 3406de7 commit c717024

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

beginner_source/blitz/cifar10_tutorial.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,28 +299,32 @@ def forward(self, x):
299299
# Just like how you transfer a Tensor onto the GPU, you transfer the neural
300300
# net onto the GPU.
301301
#
302-
# Let's first define our device as the first visible cuda device if we have
303-
# CUDA available:
302+
# Let's first select a device. Prefer CUDA when available, otherwise use MPS
303+
# (Apple Silicon), and fall back to CPU.
304+
if torch.cuda.is_available():
305+
device = torch.device("cuda:0")
306+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
307+
device = torch.device("mps")
308+
else:
309+
device = torch.device("cpu")
304310

305-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
306-
307-
# Assuming that we are on a CUDA machine, this should print a CUDA device:
311+
# This prints the selected device, e.g. "cuda:0", "mps", or "cpu".
308312

309313
print(device)
310314

311315
########################################################################
312-
# The rest of this section assumes that ``device`` is a CUDA device.
316+
# The rest of this section assumes that ``device`` is an accelerator device.
313317
#
314318
# Then these methods will recursively go over all modules and convert their
315-
# parameters and buffers to CUDA tensors:
319+
# parameters and buffers to tensors on ``device``:
316320
#
317321
# .. code:: python
318322
#
319323
# net.to(device)
320324
#
321325
#
322326
# Remember that you will have to send the inputs and targets at every step
323-
# to the GPU too:
327+
# to ``device`` too:
324328
#
325329
# .. code:: python
326330
#

0 commit comments

Comments
 (0)