summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/caffe/syncedmem.hpp15
-rw-r--r--src/caffe/syncedmem.cpp8
2 files changed, 14 insertions, 9 deletions
diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp
index 62aadef4..3d92a0ea 100644
--- a/include/caffe/syncedmem.hpp
+++ b/include/caffe/syncedmem.hpp
@@ -13,20 +13,22 @@ namespace caffe {
// The improvement in performance seems negligible in the single GPU case,
// but might be more significant for parallel training. Most importantly,
// it improved stability for large models on many GPUs.
-inline void CaffeMallocHost(void** ptr, size_t size) {
+inline void CaffeMallocHost(void** ptr, size_t size, bool* use_cuda) {
#ifndef CPU_ONLY
if (Caffe::mode() == Caffe::GPU) {
CUDA_CHECK(cudaMallocHost(ptr, size));
+ *use_cuda = true;
return;
}
#endif
*ptr = malloc(size);
+ *use_cuda = false;
CHECK(*ptr) << "host allocation of size " << size << " failed";
}
-inline void CaffeFreeHost(void* ptr) {
+inline void CaffeFreeHost(void* ptr, bool use_cuda) {
#ifndef CPU_ONLY
- if (Caffe::mode() == Caffe::GPU) {
+ if (use_cuda) {
CUDA_CHECK(cudaFreeHost(ptr));
return;
}
@@ -45,10 +47,12 @@ class SyncedMemory {
public:
SyncedMemory()
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED),
- own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {}
+ own_cpu_data_(false), cpu_malloc_use_cuda_(false), own_gpu_data_(false),
+ gpu_device_(-1) {}
explicit SyncedMemory(size_t size)
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED),
- own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {}
+ own_cpu_data_(false), cpu_malloc_use_cuda_(false), own_gpu_data_(false),
+ gpu_device_(-1) {}
~SyncedMemory();
const void* cpu_data();
void set_cpu_data(void* data);
@@ -72,6 +76,7 @@ class SyncedMemory {
size_t size_;
SyncedHead head_;
bool own_cpu_data_;
+ bool cpu_malloc_use_cuda_;
bool own_gpu_data_;
int gpu_device_;
diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp
index a667a867..632bf1f1 100644
--- a/src/caffe/syncedmem.cpp
+++ b/src/caffe/syncedmem.cpp
@@ -8,7 +8,7 @@ namespace caffe {
SyncedMemory::~SyncedMemory() {
if (cpu_ptr_ && own_cpu_data_) {
- CaffeFreeHost(cpu_ptr_);
+ CaffeFreeHost(cpu_ptr_, cpu_malloc_use_cuda_);
}
#ifndef CPU_ONLY
@@ -27,7 +27,7 @@ SyncedMemory::~SyncedMemory() {
inline void SyncedMemory::to_cpu() {
switch (head_) {
case UNINITIALIZED:
- CaffeMallocHost(&cpu_ptr_, size_);
+ CaffeMallocHost(&cpu_ptr_, size_, &cpu_malloc_use_cuda_);
caffe_memset(size_, 0, cpu_ptr_);
head_ = HEAD_AT_CPU;
own_cpu_data_ = true;
@@ -35,7 +35,7 @@ inline void SyncedMemory::to_cpu() {
case HEAD_AT_GPU:
#ifndef CPU_ONLY
if (cpu_ptr_ == NULL) {
- CaffeMallocHost(&cpu_ptr_, size_);
+ CaffeMallocHost(&cpu_ptr_, size_, &cpu_malloc_use_cuda_);
own_cpu_data_ = true;
}
caffe_gpu_memcpy(size_, gpu_ptr_, cpu_ptr_);
@@ -86,7 +86,7 @@ const void* SyncedMemory::cpu_data() {
void SyncedMemory::set_cpu_data(void* data) {
CHECK(data);
if (own_cpu_data_) {
- CaffeFreeHost(cpu_ptr_);
+ CaffeFreeHost(cpu_ptr_, cpu_malloc_use_cuda_);
}
cpu_ptr_ = data;
head_ = HEAD_AT_CPU;