author	Kai Arulkumaran <Kaixhin@users.noreply.github.com>	2017-10-25 21:38:17 +0100
committer	Adam Paszke <adam.paszke@gmail.com>	2017-10-25 22:38:17 +0200
commit	a7c5be1d454cd686c399a307f6127cf2e1b8d4f6 (patch)
tree	afe3e189ee7799ec017fca1a1ce2f96693cdb9de /docs/source/notes
parent	837f933cac6c55366752fe14a793d0010eef0d89 (diff)
Document CUDA best practices (#3227)
Diffstat (limited to 'docs/source/notes')
-rw-r--r--	docs/source/notes/cuda.rst	76
1 file changed, 76 insertions(+), 0 deletions(-)
diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst
index 33d440a819..41e27aff49 100644
--- a/docs/source/notes/cuda.rst
+++ b/docs/source/notes/cuda.rst
@@ -44,6 +44,82 @@ Below you can find a small example showcasing this::
Best practices
--------------
+Device-agnostic code
+^^^^^^^^^^^^^^^^^^^^
+
+Due to the structure of PyTorch, you may need to explicitly write
+device-agnostic (CPU or GPU) code; a typical example is creating a new tensor
+for the initial hidden state of a recurrent neural network.
+
+The first step is to determine whether the GPU should be used or not. A common
+pattern is to use Python's `argparse` module to read in user arguments, and
+have a flag that can be used to disable CUDA, in combination with
+`torch.cuda.is_available()`. In the following, `args.cuda` is a flag that can
+be used to cast tensors and modules to CUDA if desired::
+
+ import argparse
+ import torch
+
+ parser = argparse.ArgumentParser(description='PyTorch Example')
+ parser.add_argument('--disable-cuda', action='store_true',
+ help='Disable CUDA')
+ args = parser.parse_args()
+ args.cuda = not args.disable_cuda and torch.cuda.is_available()
+
+If modules or tensors need to be sent to the GPU, `args.cuda` can be used as
+follows. Note that `.cuda()` moves a module's parameters in place, whereas on
+a tensor it returns a new copy on the GPU, so the result must be reassigned::
+
+ x = torch.Tensor(8, 42)
+ net = Network()
+ if args.cuda:
+ x = x.cuda()
+ net.cuda()
+
+When creating tensors, an alternative to the if statement is to define a
+default datatype once and cast all tensors using it. An example using a
+dataloader would be as follows::
+
+    from torch.autograd import Variable
+
+    dtype = torch.cuda.FloatTensor
+    for i, x in enumerate(train_loader):
+        x = Variable(x.type(dtype))  # cast each batch to the default datatype
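+
+To keep this pattern device-agnostic, the datatype itself can be derived from
+the `args.cuda` flag defined earlier; a minimal sketch (assuming the argparse
+setup above)::
+
+    from torch.autograd import Variable
+
+    # Pick the datatype once, based on the user-controlled flag, and reuse it
+    # wherever tensors are created.
+    dtype = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
+
+    for i, x in enumerate(train_loader):
+        x = Variable(x.type(dtype))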
+
+When working with multiple GPUs on a system, you can use the
+`CUDA_VISIBLE_DEVICES` environment variable to manage which GPUs are available
+to PyTorch. To manually control which GPU a tensor is created on, the best
+practice is to use the `torch.cuda.device()` context manager::
+
+    print("Outside device is 0")        # current device is 0 (the default)
+    with torch.cuda.device(1):
+        print("Inside device is 1")     # current device is 1
+    print("Outside device is still 0")  # current device is 0 again
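+
+The current device determines where new CUDA tensors are allocated and where
+`.cuda()` transfers to; as a small sketch (assuming a machine with at least
+two GPUs)::
+
+    with torch.cuda.device(1):
+        a = torch.cuda.FloatTensor(1)    # allocated on GPU 1
+        b = torch.FloatTensor(1).cuda()  # transferred to GPU 1
+    c = torch.cuda.FloatTensor(1)        # allocated on GPU 0 again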
+
+If you have a tensor and would like to create a new tensor of the same type on
+the same device, you can use the `.new()` method, which acts like a normal
+tensor constructor. While the previously mentioned methods depend on the
+current GPU context, `new()` preserves the device of the original tensor.
+
+This is the recommended practice when creating modules in which new
+tensors/variables need to be created internally during the forward pass::
+
+    x_cpu = torch.FloatTensor(1)
+    x_gpu = torch.cuda.FloatTensor(1)
+    x_cpu_long = torch.LongTensor(1)
+
+    y_cpu = x_cpu.new(8, 10, 10).fill_(0.3)    # CPU FloatTensor
+    y_gpu = x_gpu.new(x_gpu.size()).fill_(-5)  # FloatTensor on x_gpu's device
+    y_cpu_long = x_cpu_long.new([[1, 2, 3]])   # CPU LongTensor
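+
+As a concrete sketch of this pattern (the module and sizes here are
+illustrative, not part of this patch), a recurrent module can build its
+initial hidden state from its input with `.new()`, so the state is created on
+the same device as the input::
+
+    import torch.nn as nn
+    from torch.autograd import Variable
+
+    class RNNWrapper(nn.Module):
+        def __init__(self, input_size, hidden_size):
+            super(RNNWrapper, self).__init__()
+            self.hidden_size = hidden_size
+            self.rnn = nn.RNN(input_size, hidden_size)
+
+        def forward(self, x):
+            # x is (seq_len, batch, input_size); x.data.new(...) allocates the
+            # hidden state on the same device (CPU or GPU) as the input.
+            h0 = Variable(x.data.new(1, x.size(1), self.hidden_size).zero_())
+            return self.rnn(x, h0)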
+
+If you want to create a tensor of the same type and size as another tensor,
+and fill it with either ones or zeros, `torch.ones_like()` or
+`torch.zeros_like()` are provided as more convenient functions (which also
+preserve the device)::
+
+ x_cpu = torch.FloatTensor(1)
+ x_gpu = torch.cuda.FloatTensor(1)
+
+ y_cpu = torch.ones_like(x_cpu)
+ y_gpu = torch.zeros_like(x_gpu)
+
+
Use pinned memory buffers
^^^^^^^^^^^^^^^^^^^^^^^^^