summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEdward Z. Yang <ezyang@fb.com>2017-10-03 14:17:01 -0700
committerSoumith Chintala <soumith@gmail.com>2017-10-05 15:27:49 -0400
commit2861638e8a779ccd334103026da7454f9db31fd8 (patch)
treef10e955e4db597c448a79097fb4c92564adaa8e2
parent539ae451d22df71810e0a7f54c9c3a688df4e3fc (diff)
downloadpytorch-2861638e8a779ccd334103026da7454f9db31fd8.tar.gz
pytorch-2861638e8a779ccd334103026da7454f9db31fd8.tar.bz2
pytorch-2861638e8a779ccd334103026da7454f9db31fd8.zip
Add torch.random.fork_rng, which forks the RNG temporarily.
There is a bit of nuance to this function. If one blindly charges in and initializes all GPUs, it is going to take a long time. 20sec for 8 GPUs on my dev machine. But to a user, it is non-obvious that fork_rng is going to hit all the GPUs by default (which it does by default for safety reasons.) So there is a nice warning when we notice we're hitting more than one GPU. There is a bit of extra generality which is going to be used by torch.jit in a subsequent commit.
-rw-r--r--torch/random.py71
1 files changed, 71 insertions, 0 deletions
diff --git a/torch/random.py b/torch/random.py
index 0f3baa9b77..34d38f5bf2 100644
--- a/torch/random.py
+++ b/torch/random.py
@@ -1,4 +1,6 @@
import torch
+import contextlib
+import warnings
from torch._C import default_generator
@@ -37,3 +39,72 @@ def initial_seed():
python `long`.
"""
return default_generator.initial_seed()
+
+
+_fork_rng_warned_already = False
+
+
+@contextlib.contextmanager
+def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
+ """
+ Forks the RNG, so that when you return, the RNG is reset
+ to the state that it was previously in.
+
+ Arguments:
+ devices (iterable of CUDA IDs): CUDA devices for which to fork
+ the RNG. CPU RNG state is always forked. By default, fork_rng operates
+ on all devices, but will emit a warning if your machine has a lot
+ of devices, since this function will run very slowly in that case.
+ If you explicitly specify devices, this warning will be supressed
+ enabled (bool): if False, the RNG is not forked. This is a convenience
+ argument for easily disabling the context manager without having
+ to reindent your Python code.
+ """
+
+ import torch.cuda
+ global _fork_rng_warned_already
+
+ # Internal arguments:
+ # _caller: the function which called fork_rng, which the user used
+ # _devices_kw: the devices keyword of _caller
+
+ if not enabled:
+ yield
+ return
+
+ if devices is None:
+ num_devices = torch.cuda.device_count()
+ if num_devices > 1 and not _fork_rng_warned_already:
+ warnings.warn(
+ ("CUDA reports that you have {num_devices} available devices, and you "
+ "have used {caller} without explicitly specifying which devices are being used. "
+ "For safety, we initialize *every* CUDA device by default, which "
+ "can be quite slow if you have a lot of GPUs. If you know that you are only "
+ "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
+ "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
+ "you are actually using. For example, if you are using CPU only, "
+ "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
+ "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
+ "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
+ "to `range(torch.cuda.device_count())`."
+ ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
+ _fork_rng_warned_already = True
+ devices = list(range(num_devices))
+ else:
+ # Protect against user passing us a generator; we need to traverse this
+ # multiple times but a generator will be exhausted upon first traversal
+ devices = list(devices)
+
+ cpu_rng_state = torch.get_rng_state()
+ gpu_rng_states = []
+ for device in devices:
+ with torch.cuda.device(device):
+ gpu_rng_states.append(torch.cuda.get_rng_state())
+
+ try:
+ yield
+ finally:
+ torch.set_rng_state(cpu_rng_state)
+ for device, gpu_rng_state in zip(devices, gpu_rng_states):
+ with torch.cuda.device(device):
+ torch.cuda.set_rng_state(gpu_rng_state)