-rw-r--r--  aten/src/THC/THCCachingAllocator.cpp |  6
-rw-r--r--  aten/src/THC/THCCachingAllocator.h   |  1
-rw-r--r--  docs/source/cuda.rst                 |  4
-rw-r--r--  docs/source/notes/cuda.rst           | 12
-rw-r--r--  torch/csrc/cuda/Module.cpp           | 10
-rw-r--r--  torch/cuda/__init__.py               |  7
6 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/aten/src/THC/THCCachingAllocator.cpp b/aten/src/THC/THCCachingAllocator.cpp
index 11d1467201..68e8875f4e 100644
--- a/aten/src/THC/THCCachingAllocator.cpp
+++ b/aten/src/THC/THCCachingAllocator.cpp
@@ -495,3 +495,9 @@ THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex()
{
return &caching_allocator.cuda_free_mutex;
}
+
+THC_API cudaError_t THCCachingAllocator_emptyCache(void)
+{
+ return caching_allocator.emptyCache();
+}
+
diff --git a/aten/src/THC/THCCachingAllocator.h b/aten/src/THC/THCCachingAllocator.h
index 741c2109d3..ffb2dcfcc2 100644
--- a/aten/src/THC/THCCachingAllocator.h
+++ b/aten/src/THC/THCCachingAllocator.h
@@ -11,6 +11,7 @@
THC_API THCDeviceAllocator* THCCachingAllocator_get(void);
THC_API void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size);
THC_API void THCCachingAllocator_recordStream(void *ptr, THCStream* stream);
+THC_API cudaError_t THCCachingAllocator_emptyCache(void);
#if (__cplusplus >= 201103L) || (defined(_MSC_VER) && defined(__cplusplus))
THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst
index b3441bfb5a..fbe87f8306 100644
--- a/docs/source/cuda.rst
+++ b/docs/source/cuda.rst
@@ -37,6 +37,10 @@ Streams and events
.. autoclass:: Event
:members:
+Memory management
+-----------------
+.. autofunction:: empty_cache
+
NVIDIA Tools Extension (NVTX)
-----------------------------
diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst
index 0831f6424d..305d6d4f75 100644
--- a/docs/source/notes/cuda.rst
+++ b/docs/source/notes/cuda.rst
@@ -42,6 +42,16 @@ Below you can find a small example showcasing this::
d = torch.randn(2).cuda(2)
# d.get_device() == 2
+Memory management
+-----------------
+
+PyTorch uses a caching memory allocator to speed up memory allocations. This
+allows fast memory deallocation without device synchronizations. However, the
+unused memory managed by the allocator will still show up as used in
+`nvidia-smi`. Calling :func:`~torch.cuda.empty_cache` releases all unused
+cached memory from PyTorch so that it can be used by other GPU applications.
+
+
Best practices
--------------
@@ -50,7 +60,7 @@ Device-agnostic code
Due to the structure of PyTorch, you may need to explicitly write
device-agnostic (CPU or GPU) code; an example may be creating a new tensor as
-the initial hidden state of a recurrent neural network.
+the initial hidden state of a recurrent neural network.
The first step is to determine whether the GPU should be used or not. A common
pattern is to use Python's ``argparse`` module to read in user arguments, and
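The behaviour described in the new note can be exercised from Python. A minimal
sketch (hypothetical usage, assuming a CUDA-enabled build with at least one
visible GPU):

    import torch

    # Allocate a tensor on the GPU; nvidia-smi now reports its memory as used.
    x = torch.randn(1024, 1024, 64).cuda()

    # Dropping the last reference returns the block to the caching allocator,
    # but nvidia-smi still shows the memory as used because it stays cached.
    del x

    # Release all unused cached blocks back to the driver so that other GPU
    # applications (and nvidia-smi) see the memory as free again.
    torch.cuda.empty_cache()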
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 2353c4701c..857610db80 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -297,6 +297,15 @@ PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
Py_RETURN_NONE;
}
+PyObject * THCPModule_emptyCache(PyObject *_unused)
+{
+ HANDLE_TH_ERRORS
+ auto device_allocator = THCState_getDeviceAllocator(state);
+ THCudaCheck(device_allocator->emptyCache(device_allocator->state));
+ END_HANDLE_TH_ERRORS
+ Py_RETURN_NONE;
+}
+
////////////////////////////////////////////////////////////////////////////////
// Cuda module initialization
////////////////////////////////////////////////////////////////////////////////
@@ -376,6 +385,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
{"_cuda_getDriverVersion", (PyCFunction)THCPModule_getDriverVersion, METH_NOARGS, NULL},
{"_cuda_getRNGState", (PyCFunction)THCPModule_getRNGState, METH_NOARGS, NULL},
{"_cuda_setRNGState", (PyCFunction)THCPModule_setRNGState, METH_O, NULL},
+ {"_cuda_emptyCache", (PyCFunction) THCPModule_emptyCache, METH_NOARGS, NULL},
{"_cuda_manualSeed", (PyCFunction)THCPModule_manualSeed, METH_O, NULL},
{"_cuda_manualSeedAll", (PyCFunction)THCPModule_manualSeedAll, METH_O, NULL},
{"_cuda_seed", (PyCFunction)THCPModule_seed, METH_NOARGS, NULL},
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index a5b67a5aa6..131e0611fe 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -267,6 +267,13 @@ def current_blas_handle():
return torch._C._cuda_getCurrentBlasHandle()
+def empty_cache():
+    """Releases all unoccupied cached memory currently held by the caching
+    allocator so that it can be used by other GPU applications and is visible
+    in `nvidia-smi`."""
+    return torch._C._cuda_emptyCache()
+
+
def _host_allocator():
    _lazy_init()
    return torch._C._cuda_cudaHostAllocator()
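As the docstring says, only unoccupied cached memory is released; blocks still
backing live tensors are unaffected. A short illustration (hypothetical usage,
assuming a CUDA device is available):

    import torch

    x = torch.cuda.FloatTensor(1000, 1000)
    torch.cuda.empty_cache()   # x is still alive, so its block is not released
    del x
    torch.cuda.empty_cache()   # the now-unused cached block can be released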