diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-06-27 21:25:36 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-07-03 17:14:12 -0700 |
commit | aaa9f24d369798cfa268f465ec5a1c35257ed500 (patch) | |
tree | 9da9a0672d8167ad2a79cf57a71a8b815102b08e /src/caffe/util/math_functions.cpp | |
parent | 803a7a94f3033e4d389de57c18524c98ac744b3f (diff) | |
download | caffeonacl-aaa9f24d369798cfa268f465ec5a1c35257ed500.tar.gz caffeonacl-aaa9f24d369798cfa268f465ec5a1c35257ed500.tar.bz2 caffeonacl-aaa9f24d369798cfa268f465ec5a1c35257ed500.zip |
do all caffe_copy() as UVA mem copy, and drop caffe_gpu_copy()
Do all memory copies by `cudaMemcpy` in UVA mode so that the same
`caffe_copy()` interface works for all transfers.
`cudaMemcpy()` is used in lieu of BLAS copies because the BLAS copy
routines do not understand UVA.
Drop the now unnecessary `caffe_gpu_copy()` since the location of the
pointers is now irrelevant to the interface.
Diffstat (limited to 'src/caffe/util/math_functions.cpp')
-rw-r--r-- | src/caffe/util/math_functions.cpp | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 90df5124..b1b62edb 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -150,30 +150,35 @@ void caffe_add_scalar(const int N, const double alpha, double* Y) { } template <> -void caffe_copy<float>(const int N, const float* X, float* Y) { +void caffe_copy<int>(const int N, const int* X, int* Y) { if (X != Y) { - cblas_scopy(N, X, 1, Y, 1); + CUDA_CHECK(cudaMemcpy(Y, X, sizeof(int) * N, cudaMemcpyDefault)); } } template <> -void caffe_copy<double>(const int N, const double* X, double* Y) { +void caffe_copy<unsigned int>(const int N, const unsigned int* X, + unsigned int* Y) { if (X != Y) { - cblas_dcopy(N, X, 1, Y, 1); + CUDA_CHECK(cudaMemcpy(Y, X, sizeof(unsigned int) * N, cudaMemcpyDefault)); } } template <> -void caffe_gpu_copy<float>(const int N, const float* X, float* Y) { +void caffe_copy<float>(const int N, const float* X, float* Y) { if (X != Y) { - CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), N, X, 1, Y, 1)); + CUDA_CHECK(cudaMemcpy(Y, X, sizeof(float) * N, cudaMemcpyDefault)); } } template <> -void caffe_gpu_copy<double>(const int N, const double* X, double* Y) { +void caffe_copy<double>(const int N, const double* X, double* Y) { + CUDA_CHECK(cudaMemcpy(Y, X, sizeof(double) * N, cudaMemcpyDefault)); +} + +void caffe_copy(const size_t N, const void* X, void* Y) { if (X != Y) { - CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), N, X, 1, Y, 1)); + CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault)); } } |