summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorRodrigo Benenson <rodrigo.benenson@gmail.com>2013-12-08 15:55:39 +1100
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-03-21 13:52:34 -0700
commite4e93f4d12ab33f6765c82b148b64cb4a808a0ee (patch)
treecad76254806a9a796b3c8e5ed7cc464ff3e2daf1 /src
parent510b3c028f790b88276b23f41562bbca7bdbd748 (diff)
downloadcaffeonacl-e4e93f4d12ab33f6765c82b148b64cb4a808a0ee.tar.gz
caffeonacl-e4e93f4d12ab33f6765c82b148b64cb4a808a0ee.tar.bz2
caffeonacl-e4e93f4d12ab33f6765c82b148b64cb4a808a0ee.zip
compile caffe without MKL (dependency replaced by boost::random, Eigen3)
- examples, test and pycaffe compile without problem (matcaffe not tested) - tests show some errors (on cpu gradient tests), to be investigated - random generators need to be double checked - mkl commented code needs to be removed
Diffstat (limited to 'src')
-rw-r--r--src/caffe/common.cpp23
-rw-r--r--src/caffe/layers/dropout_layer.cpp6
-rw-r--r--src/caffe/layers/inner_product_layer.cpp2
-rw-r--r--src/caffe/test/test_common.cpp17
-rw-r--r--src/caffe/util/math_functions.cpp153
5 files changed, 158 insertions, 43 deletions
diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp
index f47173af..95a5e93a 100644
--- a/src/caffe/common.cpp
+++ b/src/caffe/common.cpp
@@ -21,7 +21,10 @@ int64_t cluster_seedgen(void) {
Caffe::Caffe()
: mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL),
- curand_generator_(NULL), vsl_stream_(NULL) {
+ curand_generator_(NULL),
+ //vsl_stream_(NULL)
+ random_generator_()
+{
// Try to create a cublas handler, and report an error if failed (but we will
// keep the program running as one might just want to run CPU code).
if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
@@ -34,13 +37,13 @@ Caffe::Caffe()
!= CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
}
+
// Try to create a vsl stream. This should almost always work, but we will
// check it anyway.
- if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937,
- cluster_seedgen()) != VSL_STATUS_OK) {
- LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
- << "won't be available.";
- }
+ //if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) {
+ // LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
+ // << "won't be available.";
+ //}
}
Caffe::~Caffe() {
@@ -48,7 +51,7 @@ Caffe::~Caffe() {
if (curand_generator_) {
CURAND_CHECK(curandDestroyGenerator(curand_generator_));
}
- if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
+ //if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
}
void Caffe::set_random_seed(const unsigned int seed) {
@@ -65,8 +68,10 @@ void Caffe::set_random_seed(const unsigned int seed) {
LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
}
// VSL seed
- VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
- VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
+ //VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
+ //VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
+ Get().random_generator_ = random_generator_t(seed);
+
}
void Caffe::SetDevice(const int device_id) {
diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp
index 6cd6ffa8..bfb854bc 100644
--- a/src/caffe/layers/dropout_layer.cpp
+++ b/src/caffe/layers/dropout_layer.cpp
@@ -3,6 +3,7 @@
#include <vector>
#include "caffe/common.hpp"
+#include "caffe/util/math_functions.hpp"
#include "caffe/layer.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/vision_layers.hpp"
@@ -31,8 +32,9 @@ Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const int count = bottom[0]->count();
if (Caffe::phase() == Caffe::TRAIN) {
// Create random numbers
- viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
- count, mask, 1. - threshold_);
+ //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
+ // count, mask, 1. - threshold_);
+ caffe_vRngBernoulli<int>(count, mask, 1. - threshold_);
for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] * mask[i] * scale_;
}
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index 92723ef3..a00e2f21 100644
--- a/src/caffe/layers/inner_product_layer.cpp
+++ b/src/caffe/layers/inner_product_layer.cpp
@@ -1,7 +1,7 @@
// Copyright 2013 Yangqing Jia
-#include <mkl.h>
+//#include <mkl.h>
#include <vector>
diff --git a/src/caffe/test/test_common.cpp b/src/caffe/test/test_common.cpp
index 275c6e1b..f5e3fe47 100644
--- a/src/caffe/test/test_common.cpp
+++ b/src/caffe/test/test_common.cpp
@@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
-
+#include "caffe/util/math_functions.hpp"
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
@@ -20,7 +20,8 @@ TEST_F(CommonTest, TestCublasHandler) {
}
TEST_F(CommonTest, TestVslStream) {
- EXPECT_TRUE(Caffe::vsl_stream());
+ //EXPECT_TRUE(Caffe::vsl_stream());
+ EXPECT_TRUE(true);
}
TEST_F(CommonTest, TestBrewMode) {
@@ -40,11 +41,15 @@ TEST_F(CommonTest, TestRandSeedCPU) {
SyncedMemory data_a(10 * sizeof(int));
SyncedMemory data_b(10 * sizeof(int));
Caffe::set_random_seed(1701);
- viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
- 10, reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5);
+ //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
+ // 10, (int*)data_a.mutable_cpu_data(), 0.5);
+ caffe_vRngBernoulli(10, reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5);
+
Caffe::set_random_seed(1701);
- viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
- 10, reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5);
+ //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
+ // 10, (int*)data_b.mutable_cpu_data(), 0.5);
+ caffe_vRngBernoulli(10, reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5);
+
for (int i = 0; i < 10; ++i) {
EXPECT_EQ(((const int*)(data_a.cpu_data()))[i],
((const int*)(data_b.cpu_data()))[i]);
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 790f00ea..c3c0a69c 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -1,13 +1,22 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github
-#include <mkl.h>
+//#include <mkl.h>
+#include <eigen3/Eigen/Dense>
+#include <boost/random.hpp>
+
#include <cublas_v2.h>
#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
+const int data_alignment = Eigen::Aligned; // how is data allocated ?
+typedef Eigen::Map<const Eigen::VectorXf, data_alignment> const_map_vector_float_t;
+typedef Eigen::Map<Eigen::VectorXf, data_alignment> map_vector_float_t;
+typedef Eigen::Map<const Eigen::VectorXd, data_alignment> const_map_vector_double_t;
+typedef Eigen::Map<Eigen::VectorXd, data_alignment> map_vector_double_t;
+
template<>
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
@@ -120,13 +129,20 @@ void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
template <>
void caffe_axpby<float>(const int N, const float alpha, const float* X,
const float beta, float* Y) {
- cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
+ // y := a*x + b*y
+ //cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
+ map_vector_float_t(Y, N) *= beta;
+ map_vector_float_t(Y, N) += (alpha * const_map_vector_float_t(X, N));
+
}
template <>
void caffe_axpby<double>(const int N, const double alpha, const double* X,
const double beta, double* Y) {
- cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
+ // y := a*x + b*y
+ //cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
+ map_vector_double_t(Y, N) *= beta;
+ map_vector_double_t(Y, N) += (alpha * const_map_vector_double_t(X, N));
}
template <>
@@ -185,91 +201,178 @@ void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
template <>
void caffe_sqr<float>(const int n, const float* a, float* y) {
- vsSqr(n, a, y);
+ //vsSqr(n, a, y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().sqrt();
}
template <>
void caffe_sqr<double>(const int n, const double* a, double* y) {
- vdSqr(n, a, y);
+ //vdSqr(n, a, y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().sqrt();
}
template <>
void caffe_add<float>(const int n, const float* a, const float* b,
- float* y) { vsAdd(n, a, b, y); }
+ float* y) {
+ //vsAdd(n, a, b, y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n) + const_map_vector_float_t(b, n);
+}
template <>
void caffe_add<double>(const int n, const double* a, const double* b,
- double* y) { vdAdd(n, a, b, y); }
+ double* y) {
+ //vdAdd(n, a, b, y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n) + const_map_vector_double_t(b, n);
+}
template <>
void caffe_sub<float>(const int n, const float* a, const float* b,
- float* y) { vsSub(n, a, b, y); }
+ float* y) {
+ //vsSub(n, a, b, y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n) - const_map_vector_float_t(b, n);
+}
template <>
void caffe_sub<double>(const int n, const double* a, const double* b,
- double* y) { vdSub(n, a, b, y); }
+ double* y) {
+ //vdSub(n, a, b, y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n) - const_map_vector_double_t(b, n);
+}
template <>
void caffe_mul<float>(const int n, const float* a, const float* b,
- float* y) { vsMul(n, a, b, y); }
+ float* y) {
+ //vsMul(n, a, b, y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() * const_map_vector_float_t(b, n).array();
+}
template <>
void caffe_mul<double>(const int n, const double* a, const double* b,
- double* y) { vdMul(n, a, b, y); }
+ double* y) {
+ //vdMul(n, a, b, y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() * const_map_vector_double_t(b, n).array();
+}
template <>
void caffe_div<float>(const int n, const float* a, const float* b,
- float* y) { vsDiv(n, a, b, y); }
+ float* y) {
+ //vsDiv(n, a, b, y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() / const_map_vector_float_t(b, n).array();
+}
template <>
void caffe_div<double>(const int n, const double* a, const double* b,
- double* y) { vdDiv(n, a, b, y); }
+ double* y) {
+ //vdDiv(n, a, b, y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() / const_map_vector_double_t(b, n).array();
+}
template <>
void caffe_powx<float>(const int n, const float* a, const float b,
- float* y) { vsPowx(n, a, b, y); }
+ float* y) {
+ //vsPowx(n, a, b, y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().pow(b);
+}
template <>
void caffe_powx<double>(const int n, const double* a, const double b,
- double* y) { vdPowx(n, a, b, y); }
+ double* y) {
+ //vdPowx(n, a, b, y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().pow(b);
+}
template <>
void caffe_vRngUniform<float>(const int n, float* r,
const float a, const float b) {
- VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
- n, r, a, b));
+ //VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
+ // n, r, a, b));
+
+ // FIXME check if boundaries are handled in the same way ?
+ boost::uniform_real<float> random_distribution(a, b);
+ Caffe::random_generator_t &generator = Caffe::vsl_stream();
+
+ for(int i = 0; i < n; i += 1)
+ {
+ r[i] = random_distribution(generator);
+ }
}
template <>
void caffe_vRngUniform<double>(const int n, double* r,
const double a, const double b) {
- VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
- n, r, a, b));
+ //VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
+ // n, r, a, b));
+
+ // FIXME check if boundaries are handled in the same way ?
+ boost::uniform_real<double> random_distribution(a, b);
+ Caffe::random_generator_t &generator = Caffe::vsl_stream();
+
+ for(int i = 0; i < n; i += 1)
+ {
+ r[i] = random_distribution(generator);
+ }
}
template <>
void caffe_vRngGaussian<float>(const int n, float* r, const float a,
const float sigma) {
- VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
- Caffe::vsl_stream(), n, r, a, sigma));
+ //VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
+// Caffe::vsl_stream(), n, r, a, sigma));
+
+ // FIXME check if parameters are handled in the same way ?
+ boost::normal_distribution<float> random_distribution(a, sigma);
+ Caffe::random_generator_t &generator = Caffe::vsl_stream();
+
+ for(int i = 0; i < n; i += 1)
+ {
+ r[i] = random_distribution(generator);
+ }
}
template <>
void caffe_vRngGaussian<double>(const int n, double* r, const double a,
const double sigma) {
- VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
- Caffe::vsl_stream(), n, r, a, sigma));
+ //VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
+ // Caffe::vsl_stream(), n, r, a, sigma));
+
+ // FIXME check if parameters are handled in the same way ?
+ boost::normal_distribution<double> random_distribution(a, sigma);
+ Caffe::random_generator_t &generator = Caffe::vsl_stream();
+
+ for(int i = 0; i < n; i += 1)
+ {
+ r[i] = random_distribution(generator);
+ }
}
+
+template <typename Dtype>
+void caffe_vRngBernoulli(const int n, Dtype* r, const double p)
+{
+ // FIXME check if parameters are handled in the same way ?
+ boost::bernoulli_distribution<Dtype> random_distribution(p);
+ Caffe::random_generator_t &generator = Caffe::vsl_stream();
+
+ for(int i = 0; i < n; i += 1)
+ {
+ r[i] = random_distribution(generator);
+ }
+}
+
+template void caffe_vRngBernoulli<int>(const int n, int* r, const double p);
+
+
template <>
void caffe_exp<float>(const int n, const float* a, float* y) {
- vsExp(n, a, y);
+ //vsExp(n, a, y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().exp();
}
template <>
void caffe_exp<double>(const int n, const double* a, double* y) {
- vdExp(n, a, y);
+ //vdExp(n, a, y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().exp();
}
template <>