summaryrefslogtreecommitdiff
path: root/src/caffe/common.cpp
diff options
context:
space:
mode:
authorYangqing Jia <jiayq84@gmail.com>2013-10-31 16:52:22 -0700
committerYangqing Jia <jiayq84@gmail.com>2013-10-31 16:52:22 -0700
commit82b912be849a7bec9bfee92a8e5d81182f4130f2 (patch)
treecbe089cc26dda1ca8d752f3216063382ec25241a /src/caffe/common.cpp
parent25a865cd8ba0995c89907990fedaa357282b9a64 (diff)
downloadcaffeonacl-82b912be849a7bec9bfee92a8e5d81182f4130f2.tar.gz
caffeonacl-82b912be849a7bec9bfee92a8e5d81182f4130f2.tar.bz2
caffeonacl-82b912be849a7bec9bfee92a8e5d81182f4130f2.zip
solver restructuring: now all prototxt are specified in the solver protocol buffer
Diffstat (limited to 'src/caffe/common.cpp')
-rw-r--r--src/caffe/common.cpp18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp
index aecdc6e1..a70a2808 100644
--- a/src/caffe/common.cpp
+++ b/src/caffe/common.cpp
@@ -74,6 +74,24 @@ void Caffe::set_random_seed(const unsigned int seed) {
VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
}
+void Caffe::SetDevice(const int device_id) {
+ int current_device;
+ CUDA_CHECK(cudaGetDevice(&current_device));
+ if (current_device == device_id) {
+ return;
+ }
+ if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_));
+ if (Get().curand_generator_) {
+ CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_));
+ }
+ CUDA_CHECK(cudaSetDevice(device_id));
+ CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_));
+ CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
+ CURAND_RNG_PSEUDO_DEFAULT));
+ CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
+ time(NULL)));
+}
+
void Caffe::DeviceQuery() {
cudaDeviceProp prop;
int device;