Diffstat (limited to 'src/caffe')
-rw-r--r--  src/caffe/blob.cpp | 7
-rw-r--r--  src/caffe/common.cpp | 49
-rw-r--r--  src/caffe/layers/image_data_layer.cu | 1
-rw-r--r--  src/caffe/layers/inner_product_layer.cu | 2
-rw-r--r--  src/caffe/solver.cpp | 4
-rw-r--r--  src/caffe/syncedmem.cpp | 20
-rw-r--r--  src/caffe/test/test_accuracy_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_argmax_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_benchmark.cpp | 3
-rw-r--r--  src/caffe/test/test_blob.cpp | 1
-rw-r--r--  src/caffe/test/test_caffe_main.cpp | 6
-rw-r--r--  src/caffe/test/test_common.cpp | 9
-rw-r--r--  src/caffe/test/test_concat_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_convolution_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_data_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_dummy_data_layer.cpp | 2
-rw-r--r--  src/caffe/test/test_eltwise_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_euclidean_loss_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_filler.cpp | 1
-rw-r--r--  src/caffe/test/test_flatten_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_hdf5_output_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_hdf5data_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_hinge_loss_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_im2col_kernel.cu | 1
-rw-r--r--  src/caffe/test/test_im2col_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_image_data_layer.cpp | 4
-rw-r--r--  src/caffe/test/test_inner_product_layer.cpp | 15
-rw-r--r--  src/caffe/test/test_lrn_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_maxpool_dropout_layers.cpp | 2
-rw-r--r--  src/caffe/test/test_multinomial_logistic_loss_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_neuron_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_platform.cpp | 5
-rw-r--r--  src/caffe/test/test_pooling_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_power_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_random_number_generator.cpp | 1
-rw-r--r--  src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp | 2
-rw-r--r--  src/caffe/test/test_softmax_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_softmax_with_loss_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_split_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_stochastic_pooling.cpp | 3
-rw-r--r--  src/caffe/test/test_syncedmem.cpp | 14
-rw-r--r--  src/caffe/test/test_tanh_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_threshold_layer.cpp | 3
-rw-r--r--  src/caffe/test/test_upgrade_proto.cpp | 1
-rw-r--r--  src/caffe/test/test_util_blas.cpp | 8
-rw-r--r--  src/caffe/util/benchmark.cpp | 21
-rw-r--r--  src/caffe/util/math_functions.cpp | 135
-rw-r--r--  src/caffe/util/math_functions.cu | 131
48 files changed, 276 insertions(+), 238 deletions(-)
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index 8df46323..1051eaa1 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -1,8 +1,5 @@
// Copyright 2014 BVLC and contributors.
-#include <cuda_runtime.h>
-#include <cublas_v2.h>
-
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
@@ -126,10 +123,14 @@ void Blob<Dtype>::Update() {
break;
case SyncedMemory::HEAD_AT_GPU:
case SyncedMemory::SYNCED:
+#ifndef CPU_ONLY
// perform computation on GPU
caffe_gpu_axpy<Dtype>(count_, Dtype(-1),
static_cast<const Dtype*>(diff_->gpu_data()),
static_cast<Dtype*>(data_->mutable_gpu_data()));
+#else
+ NO_GPU;
+#endif
break;
default:
LOG(FATAL) << "Syncedmem not initialized.";
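
The #ifndef CPU_ONLY / NO_GPU guard above is the pattern this commit applies throughout. As a minimal sketch of what it presumes: the macros live in include/caffe/util/device_alternate.hpp (which the test files below start including), and NO_GPU expands to a fatal log; the exact message text here is an assumption, not quoted from this diff.

// Sketch of the guard macro this commit leans on; wording is illustrative.
#ifdef CPU_ONLY
#define NO_GPU LOG(FATAL) << "Cannot use GPU in CPU-only Caffe: check mode."
#endif
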
diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp
index 631c8afd..e8765eeb 100644
--- a/src/caffe/common.cpp
+++ b/src/caffe/common.cpp
@@ -10,8 +10,7 @@ namespace caffe {
shared_ptr<Caffe> Caffe::singleton_;
-
-// curand seeding
+// random seeding
int64_t cluster_seedgen(void) {
int64_t s, seed, pid;
pid = getpid();
@@ -20,6 +19,50 @@ int64_t cluster_seedgen(void) {
return seed;
}
+#ifdef CPU_ONLY // CPU-only Caffe.
+
+Caffe::Caffe()
+ : random_generator_(), mode_(Caffe::CPU), phase_(Caffe::TRAIN) { }
+
+Caffe::~Caffe() { }
+
+void Caffe::set_random_seed(const unsigned int seed) {
+ // RNG seed
+ Get().random_generator_.reset(new RNG(seed));
+}
+
+void Caffe::SetDevice(const int device_id) {
+ NO_GPU;
+}
+
+void Caffe::DeviceQuery() {
+ NO_GPU;
+}
+
+
+class Caffe::RNG::Generator {
+ public:
+ Generator() : rng_(new caffe::rng_t(cluster_seedgen())) {}
+ explicit Generator(unsigned int seed) : rng_(new caffe::rng_t(seed)) {}
+ caffe::rng_t* rng() { return rng_.get(); }
+ private:
+ shared_ptr<caffe::rng_t> rng_;
+};
+
+Caffe::RNG::RNG() : generator_(new Generator()) { }
+
+Caffe::RNG::RNG(unsigned int seed) : generator_(new Generator(seed)) { }
+
+Caffe::RNG& Caffe::RNG::operator=(const RNG& other) {
+ generator_.reset(other.generator_.get());
+ return *this;
+}
+
+void* Caffe::RNG::generator() {
+ return static_cast<void*>(generator_->rng());
+}
+
+#else // Normal GPU + CPU Caffe.
Caffe::Caffe()
: cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
@@ -201,4 +244,6 @@ const char* curandGetErrorString(curandStatus_t error) {
return "Unknown curand status";
}
+#endif // CPU_ONLY
+
} // namespace caffe
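
For orientation, a hedged sketch of how CPU-only code would draw samples through the new Caffe::RNG wrapper added above. It assumes caffe::rng_t is boost::mt19937 (from caffe/util/rng.hpp) and that Caffe::rng_stream() returns the singleton RNG, as declared in caffe/common.hpp; neither declaration is part of this diff.

#include <boost/random/uniform_int.hpp>
#include "caffe/common.hpp"
#include "caffe/util/rng.hpp"

int draw_digit() {
  // generator() returns void*; the caller casts to the concrete RNG type.
  caffe::rng_t* rng =
      static_cast<caffe::rng_t*>(caffe::Caffe::rng_stream().generator());
  boost::uniform_int<int> dist(0, 9);
  return dist(*rng);  // uniform integer in [0, 9]
}
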
diff --git a/src/caffe/layers/image_data_layer.cu b/src/caffe/layers/image_data_layer.cu
index dd5bdbc2..40267091 100644
--- a/src/caffe/layers/image_data_layer.cu
+++ b/src/caffe/layers/image_data_layer.cu
@@ -1,6 +1,5 @@
// Copyright 2014 BVLC and contributors.
-#include <cuda_runtime.h>
#include <stdint.h>
#include <leveldb/db.h>
#include <pthread.h>
diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu
index 453593cf..6ef75788 100644
--- a/src/caffe/layers/inner_product_layer.cu
+++ b/src/caffe/layers/inner_product_layer.cu
@@ -1,7 +1,5 @@
// Copyright 2014 BVLC and contributors.
-#include <cublas_v2.h>
-
#include <vector>
#include "caffe/blob.hpp"
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index ca1d9252..edfa9c00 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -295,6 +295,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
}
break;
case Caffe::GPU:
+#ifndef CPU_ONLY
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
Dtype local_rate = rate * net_params_lr[param_id];
@@ -314,6 +315,9 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
history_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
}
+#else
+ NO_GPU;
+#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp
index 77dfe7a4..844f639e 100644
--- a/src/caffe/syncedmem.cpp
+++ b/src/caffe/syncedmem.cpp
@@ -1,7 +1,5 @@
// Copyright 2014 BVLC and contributors.
-#include <cuda_runtime.h>
-
#include <cstring>
#include "caffe/common.hpp"
@@ -15,9 +13,11 @@ SyncedMemory::~SyncedMemory() {
CaffeFreeHost(cpu_ptr_);
}
+#ifndef CPU_ONLY
if (gpu_ptr_) {
CUDA_CHECK(cudaFree(gpu_ptr_));
}
+#endif // CPU_ONLY
}
inline void SyncedMemory::to_cpu() {
@@ -29,12 +29,16 @@ inline void SyncedMemory::to_cpu() {
own_cpu_data_ = true;
break;
case HEAD_AT_GPU:
+#ifndef CPU_ONLY
if (cpu_ptr_ == NULL) {
CaffeMallocHost(&cpu_ptr_, size_);
own_cpu_data_ = true;
}
caffe_gpu_memcpy(size_, gpu_ptr_, cpu_ptr_);
head_ = SYNCED;
+#else
+ NO_GPU;
+#endif
break;
case HEAD_AT_CPU:
case SYNCED:
@@ -43,6 +47,7 @@ inline void SyncedMemory::to_cpu() {
}
inline void SyncedMemory::to_gpu() {
+#ifndef CPU_ONLY
switch (head_) {
case UNINITIALIZED:
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
@@ -60,6 +65,9 @@ inline void SyncedMemory::to_gpu() {
case SYNCED:
break;
}
+#else
+ NO_GPU;
+#endif
}
const void* SyncedMemory::cpu_data() {
@@ -78,8 +86,12 @@ void SyncedMemory::set_cpu_data(void* data) {
}
const void* SyncedMemory::gpu_data() {
+#ifndef CPU_ONLY
to_gpu();
return (const void*)gpu_ptr_;
+#else
+ NO_GPU;
+#endif
}
void* SyncedMemory::mutable_cpu_data() {
@@ -89,9 +101,13 @@ void* SyncedMemory::mutable_cpu_data() {
}
void* SyncedMemory::mutable_gpu_data() {
+#ifndef CPU_ONLY
to_gpu();
head_ = HEAD_AT_GPU;
return gpu_ptr_;
+#else
+ NO_GPU;
+#endif
}
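
A small usage sketch of SyncedMemory under the new guards: in a CPU_ONLY build the CPU path behaves exactly as before, while any gpu_data() / mutable_gpu_data() call now aborts through NO_GPU instead of failing to compile.

#include "caffe/syncedmem.hpp"

void cpu_only_use() {
  caffe::SyncedMemory mem(10 * sizeof(float));
  float* data = static_cast<float*>(mem.mutable_cpu_data());
  data[0] = 1.0f;     // head_ moves to HEAD_AT_CPU
  // mem.gpu_data();  // in a CPU_ONLY build this would hit NO_GPU and abort
}
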
diff --git a/src/caffe/test/test_accuracy_layer.cpp b/src/caffe/test/test_accuracy_layer.cpp
index 355a36b3..40b08748 100644
--- a/src/caffe/test/test_accuracy_layer.cpp
+++ b/src/caffe/test/test_accuracy_layer.cpp
@@ -5,7 +5,6 @@
#include <cfloat>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -16,8 +15,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename Dtype>
class AccuracyLayerTest : public ::testing::Test {
protected:
diff --git a/src/caffe/test/test_argmax_layer.cpp b/src/caffe/test/test_argmax_layer.cpp
index 44a13b99..d3bdfbde 100644
--- a/src/caffe/test/test_argmax_layer.cpp
+++ b/src/caffe/test/test_argmax_layer.cpp
@@ -2,7 +2,6 @@
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -14,8 +13,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename Dtype>
class ArgMaxLayerTest : public ::testing::Test {
protected:
diff --git a/src/caffe/test/test_benchmark.cpp b/src/caffe/test/test_benchmark.cpp
index 82880088..b613a873 100644
--- a/src/caffe/test/test_benchmark.cpp
+++ b/src/caffe/test/test_benchmark.cpp
@@ -1,7 +1,6 @@
// Copyright 2014 BVLC and contributors.
#include <unistd.h> // for usleep
-#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include "caffe/common.hpp"
@@ -10,8 +9,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class BenchmarkTest : public MultiDeviceTest<TypeParam> {};
diff --git a/src/caffe/test/test_blob.cpp b/src/caffe/test/test_blob.cpp
index a5240940..aec88f32 100644
--- a/src/caffe/test/test_blob.cpp
+++ b/src/caffe/test/test_blob.cpp
@@ -2,7 +2,6 @@
#include <cstring>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/blob.hpp"
diff --git a/src/caffe/test/test_caffe_main.cpp b/src/caffe/test/test_caffe_main.cpp
index 07e6b8d5..bb5e6b46 100644
--- a/src/caffe/test/test_caffe_main.cpp
+++ b/src/caffe/test/test_caffe_main.cpp
@@ -6,14 +6,19 @@
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
+#ifndef CPU_ONLY
cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+#endif
}
+#ifndef CPU_ONLY
using caffe::CAFFE_TEST_CUDA_PROP;
+#endif
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
::google::InitGoogleLogging(argv[0]);
+#ifndef CPU_ONLY
// Before starting testing, let's first print out a few cuda device info.
int device;
cudaGetDeviceCount(&device);
@@ -27,6 +32,7 @@ int main(int argc, char** argv) {
cudaGetDevice(&device);
cout << "Current device id: " << device << endl;
cudaGetDeviceProperties(&CAFFE_TEST_CUDA_PROP, device);
+#endif
// invoke the test.
return RUN_ALL_TESTS();
}
diff --git a/src/caffe/test/test_common.cpp b/src/caffe/test/test_common.cpp
index a452b612..a8e2eb15 100644
--- a/src/caffe/test/test_common.cpp
+++ b/src/caffe/test/test_common.cpp
@@ -2,7 +2,6 @@
#include <cstring>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
@@ -13,12 +12,16 @@ namespace caffe {
class CommonTest : public ::testing::Test {};
+#ifndef CPU_ONLY // GPU Caffe singleton test.
+
TEST_F(CommonTest, TestCublasHandlerGPU) {
int cuda_device_id;
CUDA_CHECK(cudaGetDevice(&cuda_device_id));
EXPECT_TRUE(Caffe::cublas_handle());
}
+#endif
+
TEST_F(CommonTest, TestBrewMode) {
Caffe::set_mode(Caffe::CPU);
EXPECT_EQ(Caffe::mode(), Caffe::CPU);
@@ -48,6 +51,8 @@ TEST_F(CommonTest, TestRandSeedCPU) {
}
}
+#ifndef CPU_ONLY // GPU Caffe singleton test.
+
TEST_F(CommonTest, TestRandSeedGPU) {
SyncedMemory data_a(10 * sizeof(unsigned int));
SyncedMemory data_b(10 * sizeof(unsigned int));
@@ -63,4 +68,6 @@ TEST_F(CommonTest, TestRandSeedGPU) {
}
}
+#endif
+
} // namespace caffe
diff --git a/src/caffe/test/test_concat_layer.cpp b/src/caffe/test/test_concat_layer.cpp
index ff208a90..0550bb2d 100644
--- a/src/caffe/test/test_concat_layer.cpp
+++ b/src/caffe/test/test_concat_layer.cpp
@@ -3,7 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -15,8 +14,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class ConcatLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp
index 6f3e3146..d0522b13 100644
--- a/src/caffe/test/test_convolution_layer.cpp
+++ b/src/caffe/test/test_convolution_layer.cpp
@@ -3,7 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -15,8 +14,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class ConvolutionLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp
index 8cd157dd..ed7337c5 100644
--- a/src/caffe/test/test_data_layer.cpp
+++ b/src/caffe/test/test_data_layer.cpp
@@ -3,7 +3,6 @@
#include <string>
#include <vector>
-#include "cuda_runtime.h"
#include "leveldb/db.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
@@ -18,8 +17,6 @@ using std::stringstream;
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class DataLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_dummy_data_layer.cpp b/src/caffe/test/test_dummy_data_layer.cpp
index 3a83a797..5b5f2025 100644
--- a/src/caffe/test/test_dummy_data_layer.cpp
+++ b/src/caffe/test/test_dummy_data_layer.cpp
@@ -15,8 +15,6 @@ using std::stringstream;
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename Dtype>
class DummyDataLayerTest : public ::testing::Test {
protected:
diff --git a/src/caffe/test/test_eltwise_layer.cpp b/src/caffe/test/test_eltwise_layer.cpp
index 66490d2b..53bff96e 100644
--- a/src/caffe/test/test_eltwise_layer.cpp
+++ b/src/caffe/test/test_eltwise_layer.cpp
@@ -2,7 +2,6 @@
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -14,8 +13,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class EltwiseLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_euclidean_loss_layer.cpp b/src/caffe/test/test_euclidean_loss_layer.cpp
index 8c796945..dd27670d 100644
--- a/src/caffe/test/test_euclidean_loss_layer.cpp
+++ b/src/caffe/test/test_euclidean_loss_layer.cpp
@@ -5,7 +5,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -17,8 +16,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class EuclideanLossLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_filler.cpp b/src/caffe/test/test_filler.cpp
index 93eda7e8..1b145f24 100644
--- a/src/caffe/test/test_filler.cpp
+++ b/src/caffe/test/test_filler.cpp
@@ -2,7 +2,6 @@
#include <cstring>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/filler.hpp"
diff --git a/src/caffe/test/test_flatten_layer.cpp b/src/caffe/test/test_flatten_layer.cpp
index e6e777e6..bea099b1 100644
--- a/src/caffe/test/test_flatten_layer.cpp
+++ b/src/caffe/test/test_flatten_layer.cpp
@@ -3,7 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -15,8 +14,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class FlattenLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_hdf5_output_layer.cpp b/src/caffe/test/test_hdf5_output_layer.cpp
index 6fd9a2fc..221d62aa 100644
--- a/src/caffe/test/test_hdf5_output_layer.cpp
+++ b/src/caffe/test/test_hdf5_output_layer.cpp
@@ -1,6 +1,5 @@
// Copyright 2014 BVLC and contributors.
-#include <cuda_runtime.h>
#include <string>
#include <vector>
@@ -17,8 +16,6 @@ using std::vector;
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template<typename TypeParam>
class HDF5OutputLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_hdf5data_layer.cpp b/src/caffe/test/test_hdf5data_layer.cpp
index 3eb24210..e606e209 100644
--- a/src/caffe/test/test_hdf5data_layer.cpp
+++ b/src/caffe/test/test_hdf5data_layer.cpp
@@ -3,7 +3,6 @@
#include <string>
#include <vector>
-#include "cuda_runtime.h"
#include "leveldb/db.h"
#include "gtest/gtest.h"
@@ -18,8 +17,6 @@ using std::string;
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class HDF5DataLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_hinge_loss_layer.cpp b/src/caffe/test/test_hinge_loss_layer.cpp
index df6d8e25..84374e95 100644
--- a/src/caffe/test/test_hinge_loss_layer.cpp
+++ b/src/caffe/test/test_hinge_loss_layer.cpp
@@ -5,7 +5,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -17,8 +16,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class HingeLossLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu
index bd4404a0..5671968b 100644
--- a/src/caffe/test/test_im2col_kernel.cu
+++ b/src/caffe/test/test_im2col_kernel.cu
@@ -3,7 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp
index 5be19174..a40f59df 100644
--- a/src/caffe/test/test_im2col_layer.cpp
+++ b/src/caffe/test/test_im2col_layer.cpp
@@ -3,7 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -15,8 +14,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class Im2colLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_image_data_layer.cpp b/src/caffe/test/test_image_data_layer.cpp
index fbd4a1ca..ae9a2dbc 100644
--- a/src/caffe/test/test_image_data_layer.cpp
+++ b/src/caffe/test/test_image_data_layer.cpp
@@ -1,7 +1,5 @@
// Copyright 2014 BVLC and contributors.
-#include <cuda_runtime.h>
-
#include <iostream> // NOLINT(readability/streams)
#include <fstream> // NOLINT(readability/streams)
#include <map>
@@ -21,8 +19,6 @@ using std::string;
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class ImageDataLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_inner_product_layer.cpp b/src/caffe/test/test_inner_product_layer.cpp
index ad4783f9..de194f2d 100644
--- a/src/caffe/test/test_inner_product_layer.cpp
+++ b/src/caffe/test/test_inner_product_layer.cpp
@@ -3,7 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -15,7 +14,9 @@
namespace caffe {
+#ifndef CPU_ONLY
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+#endif
template <typename TypeParam>
class InnerProductLayerTest : public MultiDeviceTest<TypeParam> {
@@ -57,8 +58,12 @@ TYPED_TEST(InnerProductLayerTest, TestSetUp) {
TYPED_TEST(InnerProductLayerTest, TestForward) {
typedef typename TypeParam::Dtype Dtype;
+ bool IS_VALID_CUDA = false;
+#ifndef CPU_ONLY
+ IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2;
+#endif
if (Caffe::mode() == Caffe::CPU ||
- sizeof(Dtype) == 4 || CAFFE_TEST_CUDA_PROP.major >= 2) {
+ sizeof(Dtype) == 4 || IS_VALID_CUDA) {
LayerParameter layer_param;
InnerProductParameter* inner_product_param =
layer_param.mutable_inner_product_param();
@@ -83,8 +88,12 @@ TYPED_TEST(InnerProductLayerTest, TestForward) {
TYPED_TEST(InnerProductLayerTest, TestGradient) {
typedef typename TypeParam::Dtype Dtype;
+ bool IS_VALID_CUDA = false;
+#ifndef CPU_ONLY
+ IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2;
+#endif
if (Caffe::mode() == Caffe::CPU ||
- sizeof(Dtype) == 4 || CAFFE_TEST_CUDA_PROP.major >= 2) {
+ sizeof(Dtype) == 4 || IS_VALID_CUDA) {
LayerParameter layer_param;
InnerProductParameter* inner_product_param =
layer_param.mutable_inner_product_param();
diff --git a/src/caffe/test/test_lrn_layer.cpp b/src/caffe/test/test_lrn_layer.cpp
index a627c97f..5bd5533f 100644
--- a/src/caffe/test/test_lrn_layer.cpp
+++ b/src/caffe/test/test_lrn_layer.cpp
@@ -4,7 +4,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -19,8 +18,6 @@ using std::max;
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class LRNLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_maxpool_dropout_layers.cpp b/src/caffe/test/test_maxpool_dropout_layers.cpp
index eef375ac..733bbf4a 100644
--- a/src/caffe/test/test_maxpool_dropout_layers.cpp
+++ b/src/caffe/test/test_maxpool_dropout_layers.cpp
@@ -3,8 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
-
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
diff --git a/src/caffe/test/test_multinomial_logistic_loss_layer.cpp b/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
index d73347e8..ec53d40e 100644
--- a/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
+++ b/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
@@ -5,7 +5,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -17,8 +16,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename Dtype>
class MultinomialLogisticLossLayerTest : public ::testing::Test {
protected:
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index f4447184..246832d2 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -3,7 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -15,8 +14,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class NeuronLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_platform.cpp b/src/caffe/test/test_platform.cpp
index 7cf8306e..e1c4dafc 100644
--- a/src/caffe/test/test_platform.cpp
+++ b/src/caffe/test/test_platform.cpp
@@ -1,9 +1,10 @@
// Copyright 2014 BVLC and contributors.
+#ifndef CPU_ONLY
+
#include <cstdlib>
#include <cstdio>
-#include "cuda_runtime.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "caffe/test/test_caffe_main.hpp"
@@ -53,3 +54,5 @@ TEST_F(PlatformTest, TestInitialization) {
}
} // namespace caffe
+
+#endif // CPU_ONLY
diff --git a/src/caffe/test/test_pooling_layer.cpp b/src/caffe/test/test_pooling_layer.cpp
index b209d821..b9cec54a 100644
--- a/src/caffe/test/test_pooling_layer.cpp
+++ b/src/caffe/test/test_pooling_layer.cpp
@@ -3,7 +3,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -15,8 +14,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class PoolingLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_power_layer.cpp b/src/caffe/test/test_power_layer.cpp
index a1b716ad..c9992d57 100644
--- a/src/caffe/test/test_power_layer.cpp
+++ b/src/caffe/test/test_power_layer.cpp
@@ -3,7 +3,6 @@
#include <algorithm>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
@@ -18,8 +17,6 @@ using std::isnan;
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class PowerLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_random_number_generator.cpp b/src/caffe/test/test_random_number_generator.cpp
index 3cd77da9..abe21c25 100644
--- a/src/caffe/test/test_random_number_generator.cpp
+++ b/src/caffe/test/test_random_number_generator.cpp
@@ -1,6 +1,5 @@
// Copyright 2014 BVLC and contributors.
-#include <cuda_runtime.h>
#include <cmath>
#include <cstring>
diff --git a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
index 76bbfb48..a5388db0 100644
--- a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
+++ b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
@@ -16,8 +16,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_softmax_layer.cpp b/src/caffe/test/test_softmax_layer.cpp
index f0be279b..fa899f92 100644
--- a/src/caffe/test/test_softmax_layer.cpp
+++ b/src/caffe/test/test_softmax_layer.cpp
@@ -4,7 +4,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -16,8 +15,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class SoftmaxLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_softmax_with_loss_layer.cpp b/src/caffe/test/test_softmax_with_loss_layer.cpp
index efd6e33c..6f45c388 100644
--- a/src/caffe/test/test_softmax_with_loss_layer.cpp
+++ b/src/caffe/test/test_softmax_with_loss_layer.cpp
@@ -5,7 +5,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -17,8 +16,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class SoftmaxWithLossLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_split_layer.cpp b/src/caffe/test/test_split_layer.cpp
index 455fb59d..bbee6d28 100644
--- a/src/caffe/test/test_split_layer.cpp
+++ b/src/caffe/test/test_split_layer.cpp
@@ -4,7 +4,6 @@
#include <string>
#include <vector>
-#include "cuda_runtime.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
@@ -18,8 +17,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class SplitLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_stochastic_pooling.cpp b/src/caffe/test/test_stochastic_pooling.cpp
index 7a931d22..66e9b2d7 100644
--- a/src/caffe/test/test_stochastic_pooling.cpp
+++ b/src/caffe/test/test_stochastic_pooling.cpp
@@ -4,7 +4,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -18,8 +17,6 @@ using std::min;
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename Dtype>
class StochasticPoolingLayerTest : public ::testing::Test {
protected:
diff --git a/src/caffe/test/test_syncedmem.cpp b/src/caffe/test/test_syncedmem.cpp
index 20bd8613..f07682a2 100644
--- a/src/caffe/test/test_syncedmem.cpp
+++ b/src/caffe/test/test_syncedmem.cpp
@@ -3,10 +3,10 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
+#include "caffe/util/device_alternate.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/test/test_caffe_main.hpp"
@@ -24,6 +24,8 @@ TEST_F(SyncedMemoryTest, TestInitialization) {
delete p_mem;
}
+#ifndef CPU_ONLY // GPU test
+
TEST_F(SyncedMemoryTest, TestAllocationCPUGPU) {
SyncedMemory mem(10);
EXPECT_TRUE(mem.cpu_data());
@@ -32,18 +34,24 @@ TEST_F(SyncedMemoryTest, TestAllocationCPUGPU) {
EXPECT_TRUE(mem.mutable_gpu_data());
}
+#endif
+
TEST_F(SyncedMemoryTest, TestAllocationCPU) {
SyncedMemory mem(10);
EXPECT_TRUE(mem.cpu_data());
EXPECT_TRUE(mem.mutable_cpu_data());
}
+#ifndef CPU_ONLY // GPU test
+
TEST_F(SyncedMemoryTest, TestAllocationGPU) {
SyncedMemory mem(10);
EXPECT_TRUE(mem.gpu_data());
EXPECT_TRUE(mem.mutable_gpu_data());
}
+#endif
+
TEST_F(SyncedMemoryTest, TestCPUWrite) {
SyncedMemory mem(10);
void* cpu_data = mem.mutable_cpu_data();
@@ -61,6 +69,8 @@ TEST_F(SyncedMemoryTest, TestCPUWrite) {
}
}
+#ifndef CPU_ONLY // GPU test
+
TEST_F(SyncedMemoryTest, TestGPURead) {
SyncedMemory mem(10);
void* cpu_data = mem.mutable_cpu_data();
@@ -112,4 +122,6 @@ TEST_F(SyncedMemoryTest, TestGPUWrite) {
EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
}
+#endif
+
} // namespace caffe
diff --git a/src/caffe/test/test_tanh_layer.cpp b/src/caffe/test/test_tanh_layer.cpp
index 171eb4e4..7fc443f6 100644
--- a/src/caffe/test/test_tanh_layer.cpp
+++ b/src/caffe/test/test_tanh_layer.cpp
@@ -5,7 +5,6 @@
#include <cstring>
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -17,8 +16,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class TanHLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_threshold_layer.cpp b/src/caffe/test/test_threshold_layer.cpp
index 46519ff2..7006dd11 100644
--- a/src/caffe/test/test_threshold_layer.cpp
+++ b/src/caffe/test/test_threshold_layer.cpp
@@ -2,7 +2,6 @@
#include <vector>
-#include "cuda_runtime.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -14,8 +13,6 @@
namespace caffe {
-extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
template <typename TypeParam>
class ThresholdLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp
index 9203f558..72ee8389 100644
--- a/src/caffe/test/test_upgrade_proto.cpp
+++ b/src/caffe/test/test_upgrade_proto.cpp
@@ -4,7 +4,6 @@
#include <string>
#include <vector>
-#include "cuda_runtime.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
diff --git a/src/caffe/test/test_util_blas.cpp b/src/caffe/test/test_util_blas.cpp
index 725d24e2..d74f9f01 100644
--- a/src/caffe/test/test_util_blas.cpp
+++ b/src/caffe/test/test_util_blas.cpp
@@ -1,12 +1,12 @@
// Copyright 2014 BVLC and contributors.
-#include <cstring>
+#ifndef CPU_ONLY // CPU-GPU test
-#include "cuda_runtime.h"
-#include "cublas_v2.h"
+#include <cstring>
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
+#include "caffe/util/device_alternate.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/test/test_caffe_main.hpp"
@@ -131,3 +131,5 @@ TYPED_TEST(GemmTest, TestGemvCPUGPU) {
}
} // namespace caffe
+
+#endif // CPU_ONLY
diff --git a/src/caffe/util/benchmark.cpp b/src/caffe/util/benchmark.cpp
index 0bd85218..009b118a 100644
--- a/src/caffe/util/benchmark.cpp
+++ b/src/caffe/util/benchmark.cpp
@@ -1,7 +1,6 @@
// Copyright 2014 BVLC and contributors.
#include <boost/date_time/posix_time/posix_time.hpp>
-#include <cuda_runtime.h>
#include "caffe/common.hpp"
#include "caffe/util/benchmark.hpp"
@@ -17,15 +16,23 @@ Timer::Timer()
Timer::~Timer() {
if (Caffe::mode() == Caffe::GPU) {
+#ifndef CPU_ONLY
CUDA_CHECK(cudaEventDestroy(start_gpu_));
CUDA_CHECK(cudaEventDestroy(stop_gpu_));
+#else
+ NO_GPU;
+#endif
}
}
void Timer::Start() {
if (!running()) {
if (Caffe::mode() == Caffe::GPU) {
+#ifndef CPU_ONLY
CUDA_CHECK(cudaEventRecord(start_gpu_, 0));
+#else
+ NO_GPU;
+#endif
} else {
start_cpu_ = boost::posix_time::microsec_clock::local_time();
}
@@ -37,8 +44,12 @@ void Timer::Start() {
void Timer::Stop() {
if (running()) {
if (Caffe::mode() == Caffe::GPU) {
+#ifndef CPU_ONLY
CUDA_CHECK(cudaEventRecord(stop_gpu_, 0));
CUDA_CHECK(cudaEventSynchronize(stop_gpu_));
+#else
+ NO_GPU;
+#endif
} else {
stop_cpu_ = boost::posix_time::microsec_clock::local_time();
}
@@ -55,8 +66,12 @@ float Timer::MilliSeconds() {
Stop();
}
if (Caffe::mode() == Caffe::GPU) {
+#ifndef CPU_ONLY
CUDA_CHECK(cudaEventElapsedTime(&elapsed_milliseconds_, start_gpu_,
stop_gpu_));
+#else
+ NO_GPU;
+#endif
} else {
elapsed_milliseconds_ = (stop_cpu_ - start_cpu_).total_milliseconds();
}
@@ -70,8 +85,12 @@ float Timer::Seconds() {
void Timer::Init() {
if (!initted()) {
if (Caffe::mode() == Caffe::GPU) {
+#ifndef CPU_ONLY
CUDA_CHECK(cudaEventCreate(&start_gpu_));
CUDA_CHECK(cudaEventCreate(&stop_gpu_));
+#else
+ NO_GPU;
+#endif
}
initted_ = true;
}
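
Timer keeps working in both builds; a hedged usage sketch follows (CPU mode assumed, since in a CPU_ONLY build the GPU branches above all reduce to NO_GPU). Types come from caffe/util/benchmark.hpp and caffe/common.hpp as shown in this diff; the workload is a placeholder.

#include "caffe/common.hpp"
#include "caffe/util/benchmark.hpp"

void time_work() {
  caffe::Caffe::set_mode(caffe::Caffe::CPU);  // safe in CPU_ONLY builds
  caffe::Timer timer;
  timer.Start();
  // ... workload ...
  timer.Stop();
  LOG(INFO) << "elapsed: " << timer.MilliSeconds() << " ms";
}
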
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index b989ca2a..36d8877d 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -2,7 +2,6 @@
#include <boost/math/special_functions/next.hpp>
#include <boost/random.hpp>
-#include <cublas_v2.h>
#include <limits>
@@ -35,38 +34,6 @@ void caffe_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
}
template <>
-void caffe_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
- const float alpha, const float* A, const float* B, const float beta,
- float* C) {
- // Note that cublas follows fortran order.
- int lda = (TransA == CblasNoTrans) ? K : M;
- int ldb = (TransB == CblasNoTrans) ? N : K;
- cublasOperation_t cuTransA =
- (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
- cublasOperation_t cuTransB =
- (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
- CUBLAS_CHECK(cublasSgemm(Caffe::cublas_handle(), cuTransB, cuTransA,
- N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
-}
-
-template <>
-void caffe_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
- const double alpha, const double* A, const double* B, const double beta,
- double* C) {
- // Note that cublas follows fortran order.
- int lda = (TransA == CblasNoTrans) ? K : M;
- int ldb = (TransB == CblasNoTrans) ? N : K;
- cublasOperation_t cuTransA =
- (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
- cublasOperation_t cuTransB =
- (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
- CUBLAS_CHECK(cublasDgemm(Caffe::cublas_handle(), cuTransB, cuTransA,
- N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
-}
-
-template <>
void caffe_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
const int N, const float alpha, const float* A, const float* x,
const float beta, float* y) {
@@ -81,26 +48,6 @@ void caffe_cpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
}
template <>
-void caffe_gpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
- const int N, const float alpha, const float* A, const float* x,
- const float beta, float* y) {
- cublasOperation_t cuTransA =
- (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
- CUBLAS_CHECK(cublasSgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha,
- A, N, x, 1, &beta, y, 1));
-}
-
-template <>
-void caffe_gpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
- const int N, const double alpha, const double* A, const double* x,
- const double beta, double* y) {
- cublasOperation_t cuTransA =
- (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
- CUBLAS_CHECK(cublasDgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha,
- A, N, x, 1, &beta, y, 1));
-}
-
-template <>
void caffe_axpy<float>(const int N, const float alpha, const float* X,
float* Y) { cblas_saxpy(N, alpha, X, 1, Y, 1); }
@@ -108,18 +55,6 @@ template <>
void caffe_axpy<double>(const int N, const double alpha, const double* X,
double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
-template <>
-void caffe_gpu_axpy<float>(const int N, const float alpha, const float* X,
- float* Y) {
- CUBLAS_CHECK(cublasSaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
-}
-
-template <>
-void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
- double* Y) {
- CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
-}
-
template <typename Dtype>
void caffe_set(const int N, const Dtype alpha, Dtype* Y) {
if (alpha == 0) {
@@ -153,7 +88,11 @@ template <typename Dtype>
void caffe_copy(const int N, const Dtype* X, Dtype* Y) {
if (X != Y) {
if (Caffe::mode() == Caffe::GPU) {
+#ifndef CPU_ONLY
CUDA_CHECK(cudaMemcpy(Y, X, sizeof(Dtype) * N, cudaMemcpyDefault));
+#else
+ NO_GPU;
+#endif
} else {
memcpy(Y, X, sizeof(Dtype) * N);
}
@@ -166,12 +105,6 @@ template void caffe_copy<unsigned int>(const int N, const unsigned int* X,
template void caffe_copy<float>(const int N, const float* X, float* Y);
template void caffe_copy<double>(const int N, const double* X, double* Y);
-void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) {
- if (X != Y) {
- CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
- }
-}
-
template <>
void caffe_scal<float>(const int N, const float alpha, float *X) {
cblas_sscal(N, alpha, X, 1);
@@ -183,30 +116,6 @@ void caffe_scal<double>(const int N, const double alpha, double *X) {
}
template <>
-void caffe_gpu_scal<float>(const int N, const float alpha, float *X) {
- CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1));
-}
-
-template <>
-void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
- CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1));
-}
-
-template <>
-void caffe_gpu_axpby<float>(const int N, const float alpha, const float* X,
- const float beta, float* Y) {
- caffe_gpu_scal<float>(N, beta, Y);
- caffe_gpu_axpy<float>(N, alpha, X, Y);
-}
-
-template <>
-void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
- const double beta, double* Y) {
- caffe_gpu_scal<double>(N, beta, Y);
- caffe_gpu_axpy<double>(N, alpha, X, Y);
-}
-
-template <>
void caffe_cpu_axpby<float>(const int N, const float alpha, const float* X,
const float beta, float* Y) {
cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
@@ -408,18 +317,6 @@ double caffe_cpu_dot<double>(const int n, const double* x, const double* y) {
}
template <>
-void caffe_gpu_dot<float>(const int n, const float* x, const float* y,
- float* out) {
- CUBLAS_CHECK(cublasSdot(Caffe::cublas_handle(), n, x, 1, y, 1, out));
-}
-
-template <>
-void caffe_gpu_dot<double>(const int n, const double* x, const double* y,
- double * out) {
- CUBLAS_CHECK(cublasDdot(Caffe::cublas_handle(), n, x, 1, y, 1, out));
-}
-
-template <>
int caffe_cpu_hamming_distance<float>(const int n, const float* x,
const float* y) {
int dist = 0;
@@ -451,16 +348,6 @@ double caffe_cpu_asum<double>(const int n, const double* x) {
return cblas_dasum(n, x, 1);
}
-template <>
-void caffe_gpu_asum<float>(const int n, const float* x, float* y) {
- CUBLAS_CHECK(cublasSasum(Caffe::cublas_handle(), n, x, 1, y));
-}
-
-template <>
-void caffe_gpu_asum<double>(const int n, const double* x, double* y) {
- CUBLAS_CHECK(cublasDasum(Caffe::cublas_handle(), n, x, 1, y));
-}
-
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sign);
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sgnbit);
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(fabs);
@@ -479,18 +366,4 @@ void caffe_cpu_scale<double>(const int n, const double alpha, const double *x,
cblas_dscal(n, alpha, y, 1);
}
-template <>
-void caffe_gpu_scale<float>(const int n, const float alpha, const float *x,
- float* y) {
- CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), n, x, 1, y, 1));
- CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), n, &alpha, y, 1));
-}
-
-template <>
-void caffe_gpu_scale<double>(const int n, const double alpha, const double *x,
- double* y) {
- CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), n, x, 1, y, 1));
- CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), n, &alpha, y, 1));
-}
-
} // namespace caffe
diff --git a/src/caffe/util/math_functions.cu b/src/caffe/util/math_functions.cu
index 849e53b9..1e934931 100644
--- a/src/caffe/util/math_functions.cu
+++ b/src/caffe/util/math_functions.cu
@@ -4,6 +4,7 @@
#include <thrust/device_vector.h>
#include <thrust/functional.h> // thrust::plus
#include <thrust/reduce.h>
+
#include <cmath>
#include <cstdlib>
#include <cstring>
@@ -13,6 +14,136 @@
namespace caffe {
+template <>
+void caffe_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+ const float alpha, const float* A, const float* B, const float beta,
+ float* C) {
+ // Note that cublas follows fortran order.
+ int lda = (TransA == CblasNoTrans) ? K : M;
+ int ldb = (TransB == CblasNoTrans) ? N : K;
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ cublasOperation_t cuTransB =
+ (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ CUBLAS_CHECK(cublasSgemm(Caffe::cublas_handle(), cuTransB, cuTransA,
+ N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
+}
+
+template <>
+void caffe_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+ const double alpha, const double* A, const double* B, const double beta,
+ double* C) {
+ // Note that cublas follows fortran order.
+ int lda = (TransA == CblasNoTrans) ? K : M;
+ int ldb = (TransB == CblasNoTrans) ? N : K;
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ cublasOperation_t cuTransB =
+ (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ CUBLAS_CHECK(cublasDgemm(Caffe::cublas_handle(), cuTransB, cuTransA,
+ N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
+}
+
+template <>
+void caffe_gpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
+ const int N, const float alpha, const float* A, const float* x,
+ const float beta, float* y) {
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+ CUBLAS_CHECK(cublasSgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha,
+ A, N, x, 1, &beta, y, 1));
+}
+
+template <>
+void caffe_gpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
+ const int N, const double alpha, const double* A, const double* x,
+ const double beta, double* y) {
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+ CUBLAS_CHECK(cublasDgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha,
+ A, N, x, 1, &beta, y, 1));
+}
+
+template <>
+void caffe_gpu_axpy<float>(const int N, const float alpha, const float* X,
+ float* Y) {
+ CUBLAS_CHECK(cublasSaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
+}
+
+template <>
+void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
+ double* Y) {
+ CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
+}
+
+void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) {
+ if (X != Y) {
+ CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
+ }
+}
+
+template <>
+void caffe_gpu_scal<float>(const int N, const float alpha, float *X) {
+ CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+}
+
+template <>
+void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
+ CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+}
+
+template <>
+void caffe_gpu_axpby<float>(const int N, const float alpha, const float* X,
+ const float beta, float* Y) {
+ caffe_gpu_scal<float>(N, beta, Y);
+ caffe_gpu_axpy<float>(N, alpha, X, Y);
+}
+
+template <>
+void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
+ const double beta, double* Y) {
+ caffe_gpu_scal<double>(N, beta, Y);
+ caffe_gpu_axpy<double>(N, alpha, X, Y);
+}
+
+template <>
+void caffe_gpu_dot<float>(const int n, const float* x, const float* y,
+ float* out) {
+ CUBLAS_CHECK(cublasSdot(Caffe::cublas_handle(), n, x, 1, y, 1, out));
+}
+
+template <>
+void caffe_gpu_dot<double>(const int n, const double* x, const double* y,
+ double * out) {
+ CUBLAS_CHECK(cublasDdot(Caffe::cublas_handle(), n, x, 1, y, 1, out));
+}
+
+template <>
+void caffe_gpu_asum<float>(const int n, const float* x, float* y) {
+ CUBLAS_CHECK(cublasSasum(Caffe::cublas_handle(), n, x, 1, y));
+}
+
+template <>
+void caffe_gpu_asum<double>(const int n, const double* x, double* y) {
+ CUBLAS_CHECK(cublasDasum(Caffe::cublas_handle(), n, x, 1, y));
+}
+
+template <>
+void caffe_gpu_scale<float>(const int n, const float alpha, const float *x,
+ float* y) {
+ CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), n, x, 1, y, 1));
+ CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), n, &alpha, y, 1));
+}
+
+template <>
+void caffe_gpu_scale<double>(const int n, const double alpha, const double *x,
+ double* y) {
+ CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), n, x, 1, y, 1));
+ CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), n, &alpha, y, 1));
+}
+
template <typename Dtype>
__global__ void set_kernel(const int n, const Dtype alpha, Dtype* y) {
CUDA_KERNEL_LOOP(index, n) {