-rw-r--r-- | aten/src/THC/THCCachingAllocator.cpp | 12
-rw-r--r-- | aten/src/THC/THCCachingAllocator.h   |  2
-rw-r--r-- | docs/source/cuda.rst                 |  2
-rw-r--r-- | docs/source/notes/cuda.rst           |  8
-rw-r--r-- | test/test_cuda.py                    | 33
-rw-r--r-- | torch/csrc/cuda/Module.cpp           | 22
-rw-r--r-- | torch/cuda/__init__.py               | 54
7 files changed, 121 insertions, 12 deletions
diff --git a/aten/src/THC/THCCachingAllocator.cpp b/aten/src/THC/THCCachingAllocator.cpp
index 44ebac193e..6bf92e14a9 100644
--- a/aten/src/THC/THCCachingAllocator.cpp
+++ b/aten/src/THC/THCCachingAllocator.cpp
@@ -578,6 +578,12 @@ THC_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device) {
   return caching_allocator.get_stats_for_device(device).max_amount_allocated;
 }
 
+THC_API void THCCachingAllocator_resetMaxMemoryAllocated(int device) {
+  assertValidDevice(device);
+  DeviceStats& stats = caching_allocator.get_stats_for_device(device);
+  stats.max_amount_allocated = stats.amount_allocated;
+}
+
 THC_API uint64_t THCCachingAllocator_currentMemoryCached(int device)
 {
   assertValidDevice(device);
@@ -589,6 +595,12 @@ THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device) {
   return caching_allocator.get_stats_for_device(device).max_amount_cached;
 }
 
+THC_API void THCCachingAllocator_resetMaxMemoryCached(int device) {
+  assertValidDevice(device);
+  DeviceStats& stats = caching_allocator.get_stats_for_device(device);
+  stats.max_amount_cached = stats.amount_cached;
+}
+
 //
 // In CUDA IPC, sender sends a tensor to receiver, THCCaching_CUDAIpcDevptr
 // is called by the receiving process to map the CUDA memory from the sending
diff --git a/aten/src/THC/THCCachingAllocator.h b/aten/src/THC/THCCachingAllocator.h
index 626694a8ad..491562ad81 100644
--- a/aten/src/THC/THCCachingAllocator.h
+++ b/aten/src/THC/THCCachingAllocator.h
@@ -21,8 +21,10 @@ THC_API void THCCachingAllocator_recordStream(void *ptr, at::cuda::CUDAStream st
 #endif
 THC_API uint64_t THCCachingAllocator_currentMemoryAllocated(int device);
 THC_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device);
+THC_API void THCCachingAllocator_resetMaxMemoryAllocated(int device);
 THC_API uint64_t THCCachingAllocator_currentMemoryCached(int device);
 THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device);
+THC_API void THCCachingAllocator_resetMaxMemoryCached(int device);
 
 #if (__cplusplus >= 201103L) || (defined(_MSC_VER) && defined(__cplusplus))
 THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst
index 6da20ce68e..462967461c 100644
--- a/docs/source/cuda.rst
+++ b/docs/source/cuda.rst
@@ -46,8 +46,10 @@ Memory management
 .. autofunction:: empty_cache
 .. autofunction:: memory_allocated
 .. autofunction:: max_memory_allocated
+.. autofunction:: reset_max_memory_allocated
 .. autofunction:: memory_cached
 .. autofunction:: max_memory_cached
+.. autofunction:: reset_max_memory_cached
 
 NVIDIA Tools Extension (NVTX)
 -----------------------------
diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst
index 212f68e694..7cf2fe6ad3 100644
--- a/docs/source/notes/cuda.rst
+++ b/docs/source/notes/cuda.rst
@@ -74,9 +74,9 @@ You can force synchronous computation by setting environment variable
 operation is actually executed, so the stack trace does not show where it was
 requested.)
 
-As an exception, several functions such as :meth:`~torch.Tensor.to` and
-:meth:`~torch.Tensor.copy_` admit an explicit :attr:`non_blocking` argument,
-which lets the caller bypass synchronization when it is unnecessary.
+As an exception, several functions such as :meth:`~torch.Tensor.to` and
+:meth:`~torch.Tensor.copy_` admit an explicit :attr:`non_blocking` argument,
+which lets the caller bypass synchronization when it is unnecessary.
 Another exception is CUDA streams, explained below.
 
 CUDA streams
@@ -118,7 +118,7 @@ unused memory managed by the allocator will still show as if used in
 :meth:`~torch.cuda.max_memory_allocated` to monitor memory occupied by
 tensors, and use :meth:`~torch.cuda.memory_cached` and
 :meth:`~torch.cuda.max_memory_cached` to monitor memory managed by the caching
-allocator. Calling :meth:`~torch.cuda.empty_cache` can release all **unused**
+allocator. Calling :meth:`~torch.cuda.empty_cache` releases all **unused**
 cached memory from PyTorch so that those can be used by other GPU applications.
 However, the occupied GPU memory by tensors will not be freed so it can not
 increase the amount of GPU memory available for PyTorch.
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 26eab2ab37..5130d23ffe 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -667,7 +667,7 @@ class TestCuda(TestCase):
             # memory checks below to fail.
             return torch.cuda.FloatTensor(*size)
 
-        def assert_change(comp=1, empty_cache=False):
+        def assert_change(comp=1, empty_cache=False, reset_max_alloc=False, reset_max_cached=False):
             # comp > 0: increased
             # comp = 0: equal
             # comp < 0: decreased
@@ -702,7 +702,26 @@ class TestCuda(TestCase):
                 self.assertEqual(new_max_c, max_c_arr[0])
             last_c_arr[0] = new_c
 
+            if reset_max_alloc:
+                torch.cuda.reset_max_memory_allocated(device)
+                self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
+                self.assertEqual(torch.cuda.max_memory_allocated(device), last_m_arr[0])
+                max_m_arr[0] = last_m_arr[0]
+                self.assertEqual(torch.cuda.memory_cached(device), last_c_arr[0])
+                self.assertEqual(torch.cuda.max_memory_cached(device), max_c_arr[0])
+
+            if reset_max_cached:
+                torch.cuda.reset_max_memory_cached(device)
+                self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
+                self.assertEqual(torch.cuda.max_memory_allocated(device), max_m_arr[0])
+                self.assertEqual(torch.cuda.memory_cached(device), last_c_arr[0])
+                self.assertEqual(torch.cuda.max_memory_cached(device), last_c_arr[0])
+                max_c_arr[0] = last_c_arr[0]
+
         assert_change(0)
+        assert_change(0, reset_max_alloc=True)
+        assert_change(0, empty_cache=True)
+        assert_change(0, reset_max_cached=True)
         assert_change(0)
         yield
@@ -722,7 +741,7 @@ class TestCuda(TestCase):
         for i in range(5, int(N / 2) + 5):
             # large ones
             tensors2.append(alloc(i, i * 7, i * 9, i * 11))
-            assert_change(1)
+            assert_change(1, reset_max_alloc=(i % 2 == 0), reset_max_cached=(i % 2 == 1))
             yield
 
         tensors2.append(alloc(0, 0, 0))
@@ -742,7 +761,7 @@ class TestCuda(TestCase):
         assert_change(0)
         yield
         del permute
-        assert_change(0)
+        assert_change(0, reset_max_alloc=True)
         yield
 
         for i in range(int(N / 2)):
@@ -757,17 +776,19 @@ class TestCuda(TestCase):
         yield
 
         del tensors2
-        assert_change(-1)
+        assert_change(-1, reset_max_cached=True)
         assert_change(0)
         self.assertEqual(torch.cuda.memory_allocated(device), m1)
         yield True
 
         del tensors1
-        assert_change(-1)
+        assert_change(-1, reset_max_alloc=True)
         self.assertEqual(torch.cuda.memory_allocated(device), m0)
 
-        # test empty_cache
+        # test empty_cache and reset_max_memory_*
         assert_change(0, empty_cache=True)
+        assert_change(0, reset_max_cached=True)
+        assert_change(0, reset_max_alloc=True)
 
     def test_memory_stats(self):
         torch.cuda.empty_cache()
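[Usage note, not part of the patch] The docs/source/notes/cuda.rst paragraph touched above describes how the memory stats are meant to be read together. A minimal illustrative sketch of that pattern, assuming a CUDA-enabled build; the device index and tensor size are arbitrary placeholders:

import torch

device = torch.device('cuda:0')                  # placeholder device index
x = torch.empty(1024, 1024, device=device)       # ~4 MB of float32

print(torch.cuda.memory_allocated(device))       # bytes currently occupied by tensors
print(torch.cuda.max_memory_allocated(device))   # peak tensor usage so far
print(torch.cuda.memory_cached(device))          # bytes held by the caching allocator

del x
torch.cuda.empty_cache()                         # frees only cached blocks that are unused
print(torch.cuda.memory_cached(device))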
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index bc7729819b..5b18d63798 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -269,6 +269,16 @@ PyObject * THCPModule_maxMemoryAllocated(PyObject *_unused, PyObject *arg)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject * THCPModule_resetMaxMemoryAllocated(PyObject *_unused, PyObject *arg)
+{
+  HANDLE_TH_ERRORS
+  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_allocated");
+  int device = (int) THPUtils_unpackLong(arg);
+  THCCachingAllocator_resetMaxMemoryAllocated(device);
+  END_HANDLE_TH_ERRORS
+  Py_RETURN_NONE;
+}
+
 PyObject * THCPModule_memoryCached(PyObject *_unused, PyObject *arg)
 {
   HANDLE_TH_ERRORS
@@ -289,6 +299,16 @@ PyObject * THCPModule_maxMemoryCached(PyObject *_unused, PyObject *arg)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject * THCPModule_resetMaxMemoryCached(PyObject *_unused, PyObject *arg)
+{
+  HANDLE_TH_ERRORS
+  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_cached");
+  int device = (int) THPUtils_unpackLong(arg);
+  THCCachingAllocator_resetMaxMemoryCached(device);
+  END_HANDLE_TH_ERRORS
+  Py_RETURN_NONE;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Cuda module initialization
 ////////////////////////////////////////////////////////////////////////////////
@@ -397,8 +417,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
   {"_cuda_emptyCache", (PyCFunction) THCPModule_emptyCache, METH_NOARGS, nullptr},
   {"_cuda_memoryAllocated", (PyCFunction) THCPModule_memoryAllocated, METH_O, nullptr},
   {"_cuda_maxMemoryAllocated", (PyCFunction) THCPModule_maxMemoryAllocated, METH_O, nullptr},
+  {"_cuda_resetMaxMemoryAllocated", (PyCFunction) THCPModule_resetMaxMemoryAllocated, METH_O, nullptr},
   {"_cuda_memoryCached", (PyCFunction) THCPModule_memoryCached, METH_O, nullptr},
   {"_cuda_maxMemoryCached", (PyCFunction) THCPModule_maxMemoryCached, METH_O, nullptr},
+  {"_cuda_resetMaxMemoryCached", (PyCFunction) THCPModule_resetMaxMemoryCached, METH_O, nullptr},
   {"_cuda_manualSeed", (PyCFunction)THCPModule_manualSeed, METH_O, nullptr},
   {"_cuda_manualSeedAll", (PyCFunction)THCPModule_manualSeedAll, METH_O, nullptr},
   {"_cuda_seed", (PyCFunction)THCPModule_seed, METH_NOARGS, nullptr},
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index c6abfc0251..6534446983 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -375,7 +375,7 @@ def empty_cache():
 
 
 def memory_allocated(device=None):
-    r"""Returns the current GPU memory usage by tensors in bytes for a given
+    r"""Returns the current GPU memory occupied by tensors in bytes for a given
     device.
 
     Arguments:
@@ -394,9 +394,15 @@ def memory_allocated(device=None):
 
 
 def max_memory_allocated(device=None):
-    r"""Returns the maximum GPU memory usage by tensors in bytes for a given
+    r"""Returns the maximum GPU memory occupied by tensors in bytes for a given
     device.
 
+    By default, this returns the peak allocated memory since the beginning of
+    this program. :func:`~torch.cuda.reset_max_memory_allocated` can be used to
+    reset the starting point in tracking this metric. For example, these two
+    functions can measure the peak allocated memory usage of each iteration in a
+    training loop.
+
     Arguments:
         device (torch.device or int, optional): selected device. Returns
             statistic for the current device, given by :meth:`~torch.cuda.current_device`,
@@ -410,6 +416,25 @@ def max_memory_allocated(device=None):
     return torch._C._cuda_maxMemoryAllocated(device)
 
 
+def reset_max_memory_allocated(device=None):
+    r"""Resets the starting point in tracking maximum GPU memory occupied by
+    tensors for a given device.
+
+    See :func:`~torch.cuda.max_memory_allocated` for details.
+
+    Arguments:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_resetMaxMemoryAllocated(device)
+
+
 def memory_cached(device=None):
     r"""Returns the current GPU memory managed by the caching allocator in bytes
     for a given device.
@@ -431,6 +456,12 @@ def max_memory_cached(device=None):
     r"""Returns the maximum GPU memory managed by the caching allocator in bytes
     for a given device.
 
+    By default, this returns the peak cached memory since the beginning of this
+    program. :func:`~torch.cuda.reset_max_memory_cached` can be used to reset
+    the starting point in tracking this metric. For example, these two functions
+    can measure the peak cached memory amount of each iteration in a training
+    loop.
+
     Arguments:
         device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
@@ -444,6 +475,25 @@ def max_memory_cached(device=None):
     return torch._C._cuda_maxMemoryCached(device)
 
 
+def reset_max_memory_cached(device=None):
+    r"""Resets the starting point in tracking maximum GPU memory managed by the
+    caching allocator for a given device.
+
+    See :func:`~torch.cuda.max_memory_cached` for details.
+
+    Arguments:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_resetMaxMemoryCached(device)
+
+
 def _host_allocator():
     _lazy_init()
     return torch._C._cuda_cudaHostAllocator()
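[Usage note, not part of the patch] The new docstrings describe measuring per-iteration peaks by resetting the tracking window each step. A rough sketch of that usage, assuming a CUDA-enabled build; the model, batch shape, and device index below are placeholders, not anything defined by this change:

import torch

device = torch.device('cuda:0')                    # placeholder device index
model = torch.nn.Linear(1024, 1024).to(device)     # stand-in workload

for step in range(3):
    batch = torch.randn(64, 1024, device=device)   # stand-in input
    loss = model(batch).sum()
    loss.backward()
    model.zero_grad()

    print(step,
          torch.cuda.max_memory_allocated(device),  # peak tensor memory in this window
          torch.cuda.max_memory_cached(device))     # peak allocator cache in this window

    # start a fresh tracking window for the next iteration
    torch.cuda.reset_max_memory_allocated(device)
    torch.cuda.reset_max_memory_cached(device)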