-rw-r--r--  test/test_cuda.py        | 15
-rw-r--r--  torch/cuda/__init__.py   | 39
-rw-r--r--  torch/cuda/__init__.pyi  |  1
3 files changed, 39 insertions, 16 deletions
diff --git a/test/test_cuda.py b/test/test_cuda.py
index ad52d9b950..88ae5b790d 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1497,6 +1497,21 @@ class TestCuda(TestCase):
def test_cuda_synchronize(self):
torch.cuda.synchronize()
+ torch.cuda.synchronize('cuda')
+ torch.cuda.synchronize('cuda:0')
+ torch.cuda.synchronize(0)
+ torch.cuda.synchronize(torch.device('cuda:0'))
+
+ if TEST_MULTIGPU:
+ torch.cuda.synchronize('cuda:1')
+ torch.cuda.synchronize(1)
+ torch.cuda.synchronize(torch.device('cuda:1'))
+
+ with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
+ torch.cuda.synchronize(torch.device("cpu"))
+
+ with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
+ torch.cuda.synchronize("cpu")
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_current_stream(self):
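The two negative cases above check that a non-CUDA device spec is rejected before any synchronization happens. A minimal sketch of an equivalent check (the helper name _to_cuda_index is hypothetical and not part of this change; the actual validation is presumably done by the device-index resolution that torch.cuda.device relies on):

import torch

def _to_cuda_index(device):
    # Hypothetical helper mirroring the validation the tests exercise.
    if isinstance(device, int):
        return device
    dev = torch.device(device)
    if dev.type != 'cuda':
        raise ValueError("Expected a cuda device, but got: {}".format(dev))
    # A bare 'cuda' spec falls back to the currently selected device.
    return dev.index if dev.index is not None else torch.cuda.current_device()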
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index c9e4076e16..94fa9b6b4f 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -271,7 +271,7 @@ def get_device_name(device=None):
Arguments:
device (torch.device or int, optional): device for which to return the
name. This function is a no-op if this argument is a negative
- integer. Uses the current device, given by :meth:`~torch.cuda.current_device`,
+ integer. It uses the current device, given by :func:`~torch.cuda.current_device`,
if :attr:`device` is ``None`` (default).
"""
return get_device_properties(device).name
@@ -283,8 +283,8 @@ def get_device_capability(device=None):
Arguments:
device (torch.device or int, optional): device for which to return the
device capability. This function is a no-op if this argument is
- a negative integer. Uses the current device, given by
- :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+ a negative integer. It uses the current device, given by
+ :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
(default).
Returns:
@@ -339,7 +339,7 @@ def stream(stream):
def device_count():
- """Returns the number of GPUs available."""
+ r"""Returns the number of GPUs available."""
if is_available():
return torch._C._cuda_getDeviceCount()
else:
@@ -352,10 +352,17 @@ def current_device():
return torch._C._cuda_getDevice()
-def synchronize():
- r"""Waits for all kernels in all streams on current device to complete."""
+def synchronize(device=None):
+ r"""Waits for all kernels in all streams on a CUDA device to complete.
+
+ Arguments:
+ device (torch.device or int, optional): device for which to synchronize.
+ It uses the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+ """
_lazy_init()
- return torch._C._cuda_synchronize()
+ with torch.cuda.device(device):
+ return torch._C._cuda_synchronize()
def ipc_collect():
@@ -377,7 +384,7 @@ def current_stream(device=None):
Arguments:
device (torch.device or int, optional): selected device. Returns
the currently selected :class:`Stream` for the current device, given
- by :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+ by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
(default).
"""
_lazy_init()
@@ -391,7 +398,7 @@ def default_stream(device=None):
Arguments:
device (torch.device or int, optional): selected device. Returns
the default :class:`Stream` for the current device, given by
- :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+ :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
(default).
"""
_lazy_init()
@@ -411,7 +418,7 @@ def empty_cache():
`nvidia-smi`.
.. note::
- :meth:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
+ :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
memory available for PyTorch. See :ref:`cuda-memory-management` for
more details about GPU memory management.
"""
@@ -425,7 +432,7 @@ def memory_allocated(device=None):
Arguments:
device (torch.device or int, optional): selected device. Returns
- statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
if :attr:`device` is ``None`` (default).
.. note::
@@ -450,7 +457,7 @@ def max_memory_allocated(device=None):
Arguments:
device (torch.device or int, optional): selected device. Returns
- statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
if :attr:`device` is ``None`` (default).
.. note::
@@ -469,7 +476,7 @@ def reset_max_memory_allocated(device=None):
Arguments:
device (torch.device or int, optional): selected device. Returns
- statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
if :attr:`device` is ``None`` (default).
.. note::
@@ -486,7 +493,7 @@ def memory_cached(device=None):
Arguments:
device (torch.device or int, optional): selected device. Returns
- statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
if :attr:`device` is ``None`` (default).
.. note::
@@ -509,7 +516,7 @@ def max_memory_cached(device=None):
Arguments:
device (torch.device or int, optional): selected device. Returns
- statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
if :attr:`device` is ``None`` (default).
.. note::
@@ -528,7 +535,7 @@ def reset_max_memory_cached(device=None):
Arguments:
device (torch.device or int, optional): selected device. Returns
- statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
if :attr:`device` is ``None`` (default).
.. note::
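The new synchronize runs the underlying C call inside the torch.cuda.device context manager, so targeting another GPU does not change which device is current afterwards; the context manager restores the previously selected device on exit. A small sketch of that property, assuming at least two visible GPUs:

import torch

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    before = torch.cuda.current_device()
    # Synchronizing a non-current device temporarily switches to it via the
    # torch.cuda.device context manager, then switches back on exit.
    torch.cuda.synchronize('cuda:1')
    assert torch.cuda.current_device() == before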
diff --git a/torch/cuda/__init__.pyi b/torch/cuda/__init__.pyi
index be85475e56..03da71119f 100644
--- a/torch/cuda/__init__.pyi
+++ b/torch/cuda/__init__.pyi
@@ -26,6 +26,7 @@ _device_t = Union[_device, int]
def check_error(res: int) -> None: ...
def device_count() -> int: ...
def empty_cache() -> None: ...
+def synchronize(device: Optional[_device_t]=...) -> None: ...
def set_device(device: _device_t) -> None: ...
def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ...
def get_device_name(device: Optional[_device_t]=...) -> str: ...
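
With the stub in place, a typical use of the new argument is timing work on a specific device without switching the current device first. A minimal usage sketch (the tensor size and the use of time.perf_counter are illustrative, not taken from this change):

import time
import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device='cuda:0')
    start = time.perf_counter()
    y = x @ x
    # CUDA kernels launch asynchronously; wait for device 0 to finish
    # before reading the clock.
    torch.cuda.synchronize(torch.device('cuda:0'))
    print('matmul took {:.4f}s'.format(time.perf_counter() - start))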