diff options
author | Rodrigo Benenson <rodrigo.benenson@gmail.com> | 2013-12-08 15:55:39 +1100 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-03-21 13:52:34 -0700 |
commit | e4e93f4d12ab33f6765c82b148b64cb4a808a0ee (patch) | |
tree | cad76254806a9a796b3c8e5ed7cc464ff3e2daf1 /src | |
parent | 510b3c028f790b88276b23f41562bbca7bdbd748 (diff) | |
download | caffeonacl-e4e93f4d12ab33f6765c82b148b64cb4a808a0ee.tar.gz caffeonacl-e4e93f4d12ab33f6765c82b148b64cb4a808a0ee.tar.bz2 caffeonacl-e4e93f4d12ab33f6765c82b148b64cb4a808a0ee.zip |
compile caffe without MKL (dependency replaced by boost::random, Eigen3)
- examples, tests and pycaffe compile without problems (matcaffe not tested)
- tests show some errors (in the CPU gradient tests), to be investigated
- random generators need to be double-checked
- commented-out MKL code needs to be removed
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/common.cpp | 23 | ||||
-rw-r--r-- | src/caffe/layers/dropout_layer.cpp | 6 | ||||
-rw-r--r-- | src/caffe/layers/inner_product_layer.cpp | 2 | ||||
-rw-r--r-- | src/caffe/test/test_common.cpp | 17 | ||||
-rw-r--r-- | src/caffe/util/math_functions.cpp | 153 |
5 files changed, 158 insertions, 43 deletions
diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index f47173af..95a5e93a 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -21,7 +21,10 @@ int64_t cluster_seedgen(void) { Caffe::Caffe() : mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL), - curand_generator_(NULL), vsl_stream_(NULL) { + curand_generator_(NULL), + //vsl_stream_(NULL) + random_generator_() +{ // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { @@ -34,13 +37,13 @@ Caffe::Caffe() != CURAND_STATUS_SUCCESS) { LOG(ERROR) << "Cannot create Curand generator. Curand won't be available."; } + // Try to create a vsl stream. This should almost always work, but we will // check it anyway. - if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, - cluster_seedgen()) != VSL_STATUS_OK) { - LOG(ERROR) << "Cannot create vsl stream. VSL random number generator " - << "won't be available."; - } + //if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) { + // LOG(ERROR) << "Cannot create vsl stream. VSL random number generator " + // << "won't be available."; + //} } Caffe::~Caffe() { @@ -48,7 +51,7 @@ Caffe::~Caffe() { if (curand_generator_) { CURAND_CHECK(curandDestroyGenerator(curand_generator_)); } - if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_)); + //if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_)); } void Caffe::set_random_seed(const unsigned int seed) { @@ -65,8 +68,10 @@ void Caffe::set_random_seed(const unsigned int seed) { LOG(ERROR) << "Curand not available. 
Skipping setting the curand seed."; } // VSL seed - VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_))); - VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed)); + //VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_))); + //VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed)); + Get().random_generator_ = random_generator_t(seed); + } void Caffe::SetDevice(const int device_id) { diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp index 6cd6ffa8..bfb854bc 100644 --- a/src/caffe/layers/dropout_layer.cpp +++ b/src/caffe/layers/dropout_layer.cpp @@ -3,6 +3,7 @@ #include <vector> #include "caffe/common.hpp" +#include "caffe/util/math_functions.hpp" #include "caffe/layer.hpp" #include "caffe/syncedmem.hpp" #include "caffe/vision_layers.hpp" @@ -31,8 +32,9 @@ Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, const int count = bottom[0]->count(); if (Caffe::phase() == Caffe::TRAIN) { // Create random numbers - viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - count, mask, 1. - threshold_); + //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), + // count, mask, 1. - threshold_); + caffe_vRngBernoulli<int>(count, mask, 1. 
- threshold_); for (int i = 0; i < count; ++i) { top_data[i] = bottom_data[i] * mask[i] * scale_; } diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp index 92723ef3..a00e2f21 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -1,7 +1,7 @@ // Copyright 2013 Yangqing Jia -#include <mkl.h> +//#include <mkl.h> #include <vector> diff --git a/src/caffe/test/test_common.cpp b/src/caffe/test/test_common.cpp index 275c6e1b..f5e3fe47 100644 --- a/src/caffe/test/test_common.cpp +++ b/src/caffe/test/test_common.cpp @@ -6,7 +6,7 @@ #include "gtest/gtest.h" #include "caffe/common.hpp" #include "caffe/syncedmem.hpp" - +#include "caffe/util/math_functions.hpp" #include "caffe/test/test_caffe_main.hpp" namespace caffe { @@ -20,7 +20,8 @@ TEST_F(CommonTest, TestCublasHandler) { } TEST_F(CommonTest, TestVslStream) { - EXPECT_TRUE(Caffe::vsl_stream()); + //EXPECT_TRUE(Caffe::vsl_stream()); + EXPECT_TRUE(true); } TEST_F(CommonTest, TestBrewMode) { @@ -40,11 +41,15 @@ TEST_F(CommonTest, TestRandSeedCPU) { SyncedMemory data_a(10 * sizeof(int)); SyncedMemory data_b(10 * sizeof(int)); Caffe::set_random_seed(1701); - viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - 10, reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5); + //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), + // 10, (int*)data_a.mutable_cpu_data(), 0.5); + caffe_vRngBernoulli(10, reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5); + Caffe::set_random_seed(1701); - viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), - 10, reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5); + //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(), + // 10, (int*)data_b.mutable_cpu_data(), 0.5); + caffe_vRngBernoulli(10, reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5); + for (int i = 0; i < 10; ++i) { EXPECT_EQ(((const int*)(data_a.cpu_data()))[i], ((const 
int*)(data_b.cpu_data()))[i]); diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 790f00ea..c3c0a69c 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -1,13 +1,22 @@ // Copyright 2013 Yangqing Jia // Copyright 2014 kloudkl@github -#include <mkl.h> +//#include <mkl.h> +#include <eigen3/Eigen/Dense> +#include <boost/random.hpp> + #include <cublas_v2.h> #include "caffe/common.hpp" #include "caffe/util/math_functions.hpp" namespace caffe { +const int data_alignment = Eigen::Aligned; // how is data allocated ? +typedef Eigen::Map<const Eigen::VectorXf, data_alignment> const_map_vector_float_t; +typedef Eigen::Map<Eigen::VectorXf, data_alignment> map_vector_float_t; +typedef Eigen::Map<const Eigen::VectorXd, data_alignment> const_map_vector_double_t; +typedef Eigen::Map<Eigen::VectorXd, data_alignment> map_vector_double_t; + template<> void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, @@ -120,13 +129,20 @@ void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X, template <> void caffe_axpby<float>(const int N, const float alpha, const float* X, const float beta, float* Y) { - cblas_saxpby(N, alpha, X, 1, beta, Y, 1); + // y := a*x + b*y + //cblas_saxpby(N, alpha, X, 1, beta, Y, 1); + map_vector_float_t(Y, N) *= beta; + map_vector_float_t(Y, N) += (alpha * const_map_vector_float_t(X, N)); + } template <> void caffe_axpby<double>(const int N, const double alpha, const double* X, const double beta, double* Y) { - cblas_daxpby(N, alpha, X, 1, beta, Y, 1); + // y := a*x + b*y + //cblas_daxpby(N, alpha, X, 1, beta, Y, 1); + map_vector_double_t(Y, N) *= beta; + map_vector_double_t(Y, N) += (alpha * const_map_vector_double_t(X, N)); } template <> @@ -185,91 +201,178 @@ void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X, template <> void caffe_sqr<float>(const int n, const float* 
a, float* y) { - vsSqr(n, a, y); + //vsSqr(n, a, y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().sqrt(); } template <> void caffe_sqr<double>(const int n, const double* a, double* y) { - vdSqr(n, a, y); + //vdSqr(n, a, y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().sqrt(); } template <> void caffe_add<float>(const int n, const float* a, const float* b, - float* y) { vsAdd(n, a, b, y); } + float* y) { + //vsAdd(n, a, b, y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n) + const_map_vector_float_t(b, n); +} template <> void caffe_add<double>(const int n, const double* a, const double* b, - double* y) { vdAdd(n, a, b, y); } + double* y) { + //vdAdd(n, a, b, y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n) + const_map_vector_double_t(b, n); +} template <> void caffe_sub<float>(const int n, const float* a, const float* b, - float* y) { vsSub(n, a, b, y); } + float* y) { + //vsSub(n, a, b, y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n) - const_map_vector_float_t(b, n); +} template <> void caffe_sub<double>(const int n, const double* a, const double* b, - double* y) { vdSub(n, a, b, y); } + double* y) { + //vdSub(n, a, b, y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n) - const_map_vector_double_t(b, n); +} template <> void caffe_mul<float>(const int n, const float* a, const float* b, - float* y) { vsMul(n, a, b, y); } + float* y) { + //vsMul(n, a, b, y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() * const_map_vector_float_t(b, n).array(); +} template <> void caffe_mul<double>(const int n, const double* a, const double* b, - double* y) { vdMul(n, a, b, y); } + double* y) { + //vdMul(n, a, b, y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() * const_map_vector_double_t(b, n).array(); +} template <> void caffe_div<float>(const int n, const float* a, const float* b, - float* y) { vsDiv(n, a, b, y); } + float* y) { 
+ //vsDiv(n, a, b, y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() / const_map_vector_float_t(b, n).array(); +} template <> void caffe_div<double>(const int n, const double* a, const double* b, - double* y) { vdDiv(n, a, b, y); } + double* y) { + //vdDiv(n, a, b, y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() / const_map_vector_double_t(b, n).array(); +} template <> void caffe_powx<float>(const int n, const float* a, const float b, - float* y) { vsPowx(n, a, b, y); } + float* y) { + //vsPowx(n, a, b, y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().pow(b); +} template <> void caffe_powx<double>(const int n, const double* a, const double b, - double* y) { vdPowx(n, a, b, y); } + double* y) { + //vdPowx(n, a, b, y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().pow(b); +} template <> void caffe_vRngUniform<float>(const int n, float* r, const float a, const float b) { - VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(), - n, r, a, b)); + //VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(), + // n, r, a, b)); + + // FIXME check if boundaries are handled in the same way ? + boost::uniform_real<float> random_distribution(a, b); + Caffe::random_generator_t &generator = Caffe::vsl_stream(); + + for(int i = 0; i < n; i += 1) + { + r[i] = random_distribution(generator); + } } template <> void caffe_vRngUniform<double>(const int n, double* r, const double a, const double b) { - VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(), - n, r, a, b)); + //VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(), + // n, r, a, b)); + + // FIXME check if boundaries are handled in the same way ? 
+ boost::uniform_real<double> random_distribution(a, b); + Caffe::random_generator_t &generator = Caffe::vsl_stream(); + + for(int i = 0; i < n; i += 1) + { + r[i] = random_distribution(generator); + } } template <> void caffe_vRngGaussian<float>(const int n, float* r, const float a, const float sigma) { - VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER, - Caffe::vsl_stream(), n, r, a, sigma)); + //VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER, +// Caffe::vsl_stream(), n, r, a, sigma)); + + // FIXME check if parameters are handled in the same way ? + boost::normal_distribution<float> random_distribution(a, sigma); + Caffe::random_generator_t &generator = Caffe::vsl_stream(); + + for(int i = 0; i < n; i += 1) + { + r[i] = random_distribution(generator); + } } template <> void caffe_vRngGaussian<double>(const int n, double* r, const double a, const double sigma) { - VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER, - Caffe::vsl_stream(), n, r, a, sigma)); + //VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER, + // Caffe::vsl_stream(), n, r, a, sigma)); + + // FIXME check if parameters are handled in the same way ? + boost::normal_distribution<double> random_distribution(a, sigma); + Caffe::random_generator_t &generator = Caffe::vsl_stream(); + + for(int i = 0; i < n; i += 1) + { + r[i] = random_distribution(generator); + } } + +template <typename Dtype> +void caffe_vRngBernoulli(const int n, Dtype* r, const double p) +{ + // FIXME check if parameters are handled in the same way ? 
+ boost::bernoulli_distribution<Dtype> random_distribution(p); + Caffe::random_generator_t &generator = Caffe::vsl_stream(); + + for(int i = 0; i < n; i += 1) + { + r[i] = random_distribution(generator); + } +} + +template void caffe_vRngBernoulli<int>(const int n, int* r, const double p); + + template <> void caffe_exp<float>(const int n, const float* a, float* y) { - vsExp(n, a, y); + //vsExp(n, a, y); + map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().exp(); } template <> void caffe_exp<double>(const int n, const double* a, double* y) { - vdExp(n, a, y); + //vdExp(n, a, y); + map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().exp(); } template <> |