diff options
-rw-r--r-- | test/test_cuda.py | 15 | ||||
-rw-r--r-- | torch/cuda/__init__.py | 39 | ||||
-rw-r--r-- | torch/cuda/__init__.pyi | 1 |
3 files changed, 39 insertions, 16 deletions
diff --git a/test/test_cuda.py b/test/test_cuda.py index ad52d9b950..88ae5b790d 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1497,6 +1497,21 @@ class TestCuda(TestCase): def test_cuda_synchronize(self): torch.cuda.synchronize() + torch.cuda.synchronize('cuda') + torch.cuda.synchronize('cuda:0') + torch.cuda.synchronize(0) + torch.cuda.synchronize(torch.device('cuda:0')) + + if TEST_MULTIGPU: + torch.cuda.synchronize('cuda:1') + torch.cuda.synchronize(1) + torch.cuda.synchronize(torch.device('cuda:1')) + + with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"): + torch.cuda.synchronize(torch.device("cpu")) + + with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"): + torch.cuda.synchronize("cpu") @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_current_stream(self): diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index c9e4076e16..94fa9b6b4f 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -271,7 +271,7 @@ def get_device_name(device=None): Arguments: device (torch.device or int, optional): device for which to return the name. This function is a no-op if this argument is a negative - integer. Uses the current device, given by :meth:`~torch.cuda.current_device`, + integer. It uses the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). """ return get_device_properties(device).name @@ -283,8 +283,8 @@ def get_device_capability(device=None): Arguments: device (torch.device or int, optional): device for which to return the device capability. This function is a no-op if this argument is - a negative integer. Uses the current device, given by - :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + a negative integer. It uses the current device, given by + :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). Returns: @@ -339,7 +339,7 @@ def stream(stream): def device_count(): - """Returns the number of GPUs available.""" + r"""Returns the number of GPUs available.""" if is_available(): return torch._C._cuda_getDeviceCount() else: @@ -352,10 +352,17 @@ def current_device(): return torch._C._cuda_getDevice() -def synchronize(): - r"""Waits for all kernels in all streams on current device to complete.""" +def synchronize(device=None): + r"""Waits for all kernels in all streams on a CUDA device to complete. + + Arguments: + device (torch.device or int, optional): device for which to synchronize. + It uses the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + """ _lazy_init() - return torch._C._cuda_synchronize() + with torch.cuda.device(device): + return torch._C._cuda_synchronize() def ipc_collect(): @@ -377,7 +384,7 @@ def current_stream(device=None): Arguments: device (torch.device or int, optional): selected device. Returns the currently selected :class:`Stream` for the current device, given - by :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). """ _lazy_init() @@ -391,7 +398,7 @@ def default_stream(device=None): Arguments: device (torch.device or int, optional): selected device. Returns the default :class:`Stream` for the current device, given by - :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). """ _lazy_init() @@ -411,7 +418,7 @@ def empty_cache(): `nvidia-smi`. .. note:: - :meth:`~torch.cuda.empty_cache` doesn't increase the amount of GPU + :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU memory available for PyTorch. See :ref:`cuda-memory-management` for more details about GPU memory management. """ @@ -425,7 +432,7 @@ def memory_allocated(device=None): Arguments: device (torch.device or int, optional): selected device. Returns - statistic for the current device, given by :meth:`~torch.cuda.current_device`, + statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). .. note:: @@ -450,7 +457,7 @@ def max_memory_allocated(device=None): Arguments: device (torch.device or int, optional): selected device. Returns - statistic for the current device, given by :meth:`~torch.cuda.current_device`, + statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). .. note:: @@ -469,7 +476,7 @@ def reset_max_memory_allocated(device=None): Arguments: device (torch.device or int, optional): selected device. Returns - statistic for the current device, given by :meth:`~torch.cuda.current_device`, + statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). .. note:: @@ -486,7 +493,7 @@ def memory_cached(device=None): Arguments: device (torch.device or int, optional): selected device. Returns - statistic for the current device, given by :meth:`~torch.cuda.current_device`, + statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). .. note:: @@ -509,7 +516,7 @@ def max_memory_cached(device=None): Arguments: device (torch.device or int, optional): selected device. Returns - statistic for the current device, given by :meth:`~torch.cuda.current_device`, + statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). .. note:: @@ -528,7 +535,7 @@ def reset_max_memory_cached(device=None): Arguments: device (torch.device or int, optional): selected device. Returns - statistic for the current device, given by :meth:`~torch.cuda.current_device`, + statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). .. note:: diff --git a/torch/cuda/__init__.pyi b/torch/cuda/__init__.pyi index be85475e56..03da71119f 100644 --- a/torch/cuda/__init__.pyi +++ b/torch/cuda/__init__.pyi @@ -26,6 +26,7 @@ _device_t = Union[_device, int] def check_error(res: int) -> None: ... def device_count() -> int: ... def empty_cache() -> None: ... +def synchronize(device: _device_t) -> None: ... def set_device(device: _device_t) -> None: ... def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ... def get_device_name(device: Optional[_device_t]=...) -> str: ... |