diff options
author | Kai Li <kaili_kloud@163.com> | 2014-07-10 07:50:31 +0800 |
---|---|---|
committer | Kai Li <kaili_kloud@163.com> | 2014-07-10 08:03:22 +0800 |
commit | 904c2ce69b728c988004cef6796fbdf07ecb4c1e (patch) | |
tree | aaaec9d94a70dfd06552f169ee0fc6952649c2dc /src | |
parent | ac0dd39252812aba1f67b9b0b0e18e62ea1742e6 (diff) | |
download | caffeonacl-904c2ce69b728c988004cef6796fbdf07ecb4c1e.tar.gz caffeonacl-904c2ce69b728c988004cef6796fbdf07ecb4c1e.tar.bz2 caffeonacl-904c2ce69b728c988004cef6796fbdf07ecb4c1e.zip |
Replace cudaMemcpy with caffe_gpu_memcpy in SyncedMemory per @longjon
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/syncedmem.cpp | 4 | ||||
-rw-r--r-- | src/caffe/test/test_syncedmem.cpp | 4 | ||||
-rw-r--r-- | src/caffe/util/math_functions.cpp | 8 |
3 files changed, 6 insertions, 10 deletions
diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp
index 3f9a3be9..77dfe7a4 100644
--- a/src/caffe/syncedmem.cpp
+++ b/src/caffe/syncedmem.cpp
@@ -33,7 +33,7 @@ inline void SyncedMemory::to_cpu() {
       CaffeMallocHost(&cpu_ptr_, size_);
       own_cpu_data_ = true;
     }
-    CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault));
+    caffe_gpu_memcpy(size_, gpu_ptr_, cpu_ptr_);
     head_ = SYNCED;
     break;
   case HEAD_AT_CPU:
@@ -53,7 +53,7 @@ inline void SyncedMemory::to_gpu() {
     if (gpu_ptr_ == NULL) {
       CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
     }
-    CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyDefault));
+    caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_);
     head_ = SYNCED;
     break;
   case HEAD_AT_GPU:
diff --git a/src/caffe/test/test_syncedmem.cpp b/src/caffe/test/test_syncedmem.cpp
index 3aaeafc3..3a757088 100644
--- a/src/caffe/test/test_syncedmem.cpp
+++ b/src/caffe/test/test_syncedmem.cpp
@@ -58,7 +58,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
   char* recovered_value = new char[10];
-  caffe_memcpy(10, gpu_data, recovered_value);
+  caffe_gpu_memcpy(10, gpu_data, recovered_value);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 1);
   }
@@ -72,7 +72,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
   gpu_data = mem.gpu_data();
   EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
   // check if values are the same
-  caffe_memcpy(10, gpu_data, recovered_value);
+  caffe_gpu_memcpy(10, gpu_data, recovered_value);
   for (int i = 0; i < mem.size(); ++i) {
     EXPECT_EQ((reinterpret_cast<char*>(recovered_value))[i], 2);
   }
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 9311a398..b989ca2a 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -166,13 +166,9 @@ template void caffe_copy<unsigned int>(const int N, const unsigned int* X,
 template void caffe_copy<float>(const int N, const float* X, float* Y);
 template void caffe_copy<double>(const int N, const double* X, double* Y);
-void caffe_memcpy(const size_t N, const void* X, void* Y) {
+void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) {
   if (X != Y) {
-    if (Caffe::mode() == Caffe::GPU) {
-      CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
-    } else {
-      memcpy(Y, X, N);
-    }
+    CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
   }
 }