summaryrefslogtreecommitdiff
path: root/include/caffe/common.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/caffe/common.hpp')
-rw-r--r--include/caffe/common.hpp17
1 files changed, 13 insertions, 4 deletions
diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp
index bd4e39f1..0ee49f36 100644
--- a/include/caffe/common.hpp
+++ b/include/caffe/common.hpp
@@ -4,12 +4,10 @@
#define CAFFE_COMMON_HPP_
#include <boost/shared_ptr.hpp>
-#include <cublas_v2.h>
-#include <cuda.h>
-#include <curand.h>
-#include <driver_types.h> // cuda driver types
#include <glog/logging.h>
+#include "caffe/util/device_alternate.hpp"
+
// Disable the copy and assignment operator for a class.
#define DISABLE_COPY_AND_ASSIGN(classname) \
private:\
@@ -25,6 +23,8 @@ private:\
// is executed we will see a fatal log.
#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
+#ifndef CPU_ONLY
+
// CUDA: various checks for different function calls.
#define CUDA_CHECK(condition) \
/* Code block avoids redefinition of cudaError_t error */ \
@@ -56,6 +56,8 @@ private:\
// CUDA: check for error after kernel execution and exit loudly if there is one.
#define CUDA_POST_KERNEL_CHECK CUDA_CHECK(cudaPeekAtLastError())
+#endif // CPU_ONLY
+
namespace caffe {
// We will use the boost shared_ptr instead of the new C++11 one mainly
@@ -99,10 +101,12 @@ class Caffe {
}
return *(Get().random_generator_);
}
+#ifndef CPU_ONLY
inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
inline static curandGenerator_t curand_generator() {
return Get().curand_generator_;
}
+#endif
// Returns the mode: running on CPU or GPU.
inline static Brew mode() { return Get().mode_; }
@@ -125,8 +129,10 @@ class Caffe {
static void DeviceQuery();
protected:
+#ifndef CPU_ONLY
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
+#endif
shared_ptr<RNG> random_generator_;
Brew mode_;
@@ -140,6 +146,8 @@ class Caffe {
DISABLE_COPY_AND_ASSIGN(Caffe);
};
+#ifndef CPU_ONLY
+
// NVIDIA_CUDA-5.5_Samples/common/inc/helper_cuda.h
const char* cublasGetErrorString(cublasStatus_t error);
const char* curandGetErrorString(curandStatus_t error);
@@ -158,6 +166,7 @@ inline int CAFFE_GET_BLOCKS(const int N) {
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
+#endif // CPU_ONLY
} // namespace caffe