summaryrefslogtreecommitdiff
path: root/src/caffe/util/math_functions.cpp
diff options
context:
space:
mode:
authorKai Li <kaili_kloud@163.com>2014-01-12 13:55:26 +0800
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-03-21 13:52:34 -0700
commit788f070d063e3f3e5fc8eb0faa53411e966898f6 (patch)
tree6ab300d209d51f0905f529360eceb4bad847285b /src/caffe/util/math_functions.cpp
parent38457e1c1f0d5bb9765896c3d5a43eaf19534ec9 (diff)
downloadcaffeonacl-788f070d063e3f3e5fc8eb0faa53411e966898f6.tar.gz
caffeonacl-788f070d063e3f3e5fc8eb0faa53411e966898f6.tar.bz2
caffeonacl-788f070d063e3f3e5fc8eb0faa53411e966898f6.zip
Fix math funcs, add tests, change Eigen Map to unaligned for lrn_layer
[shelhamer: removed math function tests, since they were merged via other branches]
Diffstat (limited to 'src/caffe/util/math_functions.cpp')
-rw-r--r--src/caffe/util/math_functions.cpp322
1 files changed, 200 insertions, 122 deletions
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 850a408f..46c82dbd 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -13,11 +13,22 @@
namespace caffe {
-const int data_alignment = Eigen::Aligned; // how is data allocated ?
-typedef Eigen::Map<const Eigen::VectorXf, data_alignment> const_map_vector_float_t;
-typedef Eigen::Map<Eigen::VectorXf, data_alignment> map_vector_float_t;
-typedef Eigen::Map<const Eigen::VectorXd, data_alignment> const_map_vector_double_t;
-typedef Eigen::Map<Eigen::VectorXd, data_alignment> map_vector_double_t;
+// Operations on aligned memory are faster than on unaligned memory.
+// But unfortunately, the pointers passed in are not always aligned.
+// Therefore, the memory-aligned Eigen::Map objects that wrap them
+// cannot be assigned to. This happens in lrn_layer and makes
+// test_lrn_layer crash with a segmentation fault.
+// TODO: Use aligned Eigen::Map when the pointer to be wrapped is aligned.
+
+// Though the default map option is unaligned, making it explicit does no harm.
+//const int data_alignment = Eigen::Aligned; // how is data allocated ?
+const int data_alignment = Eigen::Unaligned;
+typedef Eigen::Array<float, 1, Eigen::Dynamic> float_array_t;
+typedef Eigen::Map<const float_array_t, data_alignment> const_map_vector_float_t;
+typedef Eigen::Map<float_array_t, data_alignment> map_vector_float_t;
+typedef Eigen::Array<double, 1, Eigen::Dynamic> double_array_t;
+typedef Eigen::Map<const double_array_t, data_alignment> const_map_vector_double_t;
+typedef Eigen::Map<double_array_t, data_alignment> map_vector_double_t;
template<>
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
@@ -129,25 +140,6 @@ void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
}
template <>
-void caffe_axpby<float>(const int N, const float alpha, const float* X,
- const float beta, float* Y) {
- // y := a*x + b*y
- //cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
- map_vector_float_t(Y, N) *= beta;
- map_vector_float_t(Y, N) += (alpha * const_map_vector_float_t(X, N));
-
-}
-
-template <>
-void caffe_axpby<double>(const int N, const double alpha, const double* X,
- const double beta, double* Y) {
- // y := a*x + b*y
- //cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
- map_vector_double_t(Y, N) *= beta;
- map_vector_double_t(Y, N) += (alpha * const_map_vector_double_t(X, N));
-}
-
-template <>
void caffe_copy<float>(const int N, const float* X, float* Y) {
cblas_scopy(N, X, 1, Y, 1);
}
@@ -202,189 +194,275 @@ void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
}
template <>
-void caffe_sqr<float>(const int n, const float* a, float* y) {
- //vsSqr(n, a, y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().sqrt();
+void caffe_axpby<float>(const int N, const float alpha, const float* X,
+ const float beta, float* Y) {
+ // y := a*x + b*y
+ //cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
+ CHECK_GE(N, 0);
+ CHECK(X);
+ CHECK(Y);
+ map_vector_float_t y_map(Y, N);
+ // Eigen produces optimized code using lazy evaluation
+ // http://eigen.tuxfamily.org/dox/TopicLazyEvaluation.html
+ y_map = const_map_vector_float_t(X, N) * alpha + y_map * beta;
}
template <>
-void caffe_sqr<double>(const int n, const double* a, double* y) {
- //vdSqr(n, a, y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().sqrt();
+void caffe_axpby<double>(const int N, const double alpha, const double* X,
+ const double beta, double* Y) {
+ // y := a*x + b*y
+ //cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
+ CHECK_GE(N, 0);
+ CHECK(X);
+ CHECK(Y);
+ map_vector_double_t y_map(Y, N);
+ y_map = const_map_vector_double_t(X, N) * alpha + y_map * beta;
}
template <>
void caffe_add<float>(const int n, const float* a, const float* b,
float* y) {
- //vsAdd(n, a, b, y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n) + const_map_vector_float_t(b, n);
+ //vsAdd(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(b);
+ CHECK(y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n) +
+ const_map_vector_float_t(b, n);
}
template <>
void caffe_add<double>(const int n, const double* a, const double* b,
double* y) {
- //vdAdd(n, a, b, y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n) + const_map_vector_double_t(b, n);
+ //vdAdd(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(b);
+ CHECK(y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n) +
+ const_map_vector_double_t(b, n);
}
template <>
void caffe_sub<float>(const int n, const float* a, const float* b,
float* y) {
- //vsSub(n, a, b, y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n) - const_map_vector_float_t(b, n);
+ //vsSub(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(b);
+ CHECK(y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n) -
+ const_map_vector_float_t(b, n);
}
template <>
void caffe_sub<double>(const int n, const double* a, const double* b,
double* y) {
- //vdSub(n, a, b, y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n) - const_map_vector_double_t(b, n);
+ //vdSub(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(b);
+ CHECK(y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n) -
+ const_map_vector_double_t(b, n);
}
template <>
void caffe_mul<float>(const int n, const float* a, const float* b,
float* y) {
- //vsMul(n, a, b, y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() * const_map_vector_float_t(b, n).array();
+ //vsMul(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(b);
+ CHECK(y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n) *
+ const_map_vector_float_t(b, n);
}
template <>
void caffe_mul<double>(const int n, const double* a, const double* b,
double* y) {
- //vdMul(n, a, b, y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() * const_map_vector_double_t(b, n).array();
+ //vdMul(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(b);
+ CHECK(y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n) *
+ const_map_vector_double_t(b, n);
}
template <>
void caffe_div<float>(const int n, const float* a, const float* b,
float* y) {
- //vsDiv(n, a, b, y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() / const_map_vector_float_t(b, n).array();
+ //vsDiv(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(b);
+ CHECK(y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n) /
+ const_map_vector_float_t(b, n);
}
template <>
void caffe_div<double>(const int n, const double* a, const double* b,
double* y) {
- //vdDiv(n, a, b, y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() / const_map_vector_double_t(b, n).array();
+ //vdDiv(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(b);
+ CHECK(y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n) /
+ const_map_vector_double_t(b, n);
}
template <>
void caffe_powx<float>(const int n, const float* a, const float b,
float* y) {
- //vsPowx(n, a, b, y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().pow(b);
+ //vsPowx(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n).pow(b);
}
template <>
void caffe_powx<double>(const int n, const double* a, const double b,
double* y) {
- //vdPowx(n, a, b, y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().pow(b);
+ //vdPowx(n, a, b, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n).pow(b);
+}
+
+template <>
+void caffe_sqr<float>(const int n, const float* a, float* y) {
+ // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-F003F826-81BF-42EC-AE51-2EF624893133.htm
+ // v?Sqr Performs element by element squaring of the vector.
+ //vsSqr(n, a, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(y);
+ caffe_powx<float>(n, a, 2, y);
+ // TODO: which is faster?
+// map_vector_float_t(y, n) = const_map_vector_float_t(a, n) *
+// const_map_vector_float_t(a, n);
+}
+
+template <>
+void caffe_sqr<double>(const int n, const double* a, double* y) {
+ //vdSqr(n, a, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(y);
+ caffe_powx<double>(n, a, 2, y);
+}
+
+template <>
+void caffe_exp<float>(const int n, const float* a, float* y) {
+ //vsExp(n, a, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(y);
+ map_vector_float_t(y, n) = const_map_vector_float_t(a, n).exp();
+}
+
+template <>
+void caffe_exp<double>(const int n, const double* a, double* y) {
+ //vdExp(n, a, y);
+ CHECK_GE(n, 0);
+ CHECK(a);
+ CHECK(y);
+ map_vector_double_t(y, n) = const_map_vector_double_t(a, n).exp();
}
template <typename Dtype>
Dtype caffe_nextafter(const Dtype b) {
- return boost::math::nextafter<Dtype, Dtype>(b, std::numeric_limits<Dtype>::max());
+ return boost::math::nextafter<Dtype, Dtype>(
+ b, std::numeric_limits<Dtype>::max());
}
-template <>
-void caffe_vRngUniform<float>(const int n, float* r,
- const float a, const float b) {
+template
+float caffe_nextafter(const float b);
+
+template
+double caffe_nextafter(const double b);
+
+template <typename Dtype>
+void caffe_vRngUniform(const int n, Dtype* r,
+ const Dtype a, const Dtype b) {
+ CHECK_GE(n, 0);
+ CHECK(r);
+ CHECK_LE(a, b);
//VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
// n, r, a, b));
// FIXME check if boundaries are handled in the same way ?
- boost::random::uniform_real_distribution<float> random_distribution(
- a, caffe_nextafter<float>(b));
+ // Fixed by caffe_nextafter: it widens the half-open range so that b can be drawn.
+ boost::random::uniform_real_distribution<Dtype> random_distribution(
+ a, caffe_nextafter<Dtype>(b));
Caffe::random_generator_t &generator = Caffe::vsl_stream();
- for(int i = 0; i < n; i += 1)
- {
- r[i] = random_distribution(generator);
+ for(int i = 0; i < n; i += 1) {
+ r[i] = random_distribution(generator);
}
}
-template <>
+template
+void caffe_vRngUniform<float>(const int n, float* r,
+ const float a, const float b);
+template
void caffe_vRngUniform<double>(const int n, double* r,
- const double a, const double b) {
- //VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
- // n, r, a, b));
-
- // FIXME check if boundaries are handled in the same way ?
- boost::random::uniform_real_distribution<double> random_distribution(
- a, caffe_nextafter<double>(b));
- Caffe::random_generator_t &generator = Caffe::vsl_stream();
+ const double a, const double b);
- for(int i = 0; i < n; i += 1)
- {
- r[i] = random_distribution(generator);
- }
-}
-
-template <>
-void caffe_vRngGaussian<float>(const int n, float* r, const float a,
- const float sigma) {
- DCHECK(sigma > 0);
+template <typename Dtype>
+void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
+ const Dtype sigma) {
+ CHECK_GE(n, 0);
+ CHECK(r);
+ CHECK_GT(sigma, 0);
//VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
// Caffe::vsl_stream(), n, r, a, sigma));
// FIXME check if parameters are handled in the same way ?
- boost::normal_distribution<float> random_distribution(a, sigma);
- Caffe::random_generator_t &generator = Caffe::vsl_stream();
+ // http://www.boost.org/doc/libs/1_55_0/doc/html/boost/random/normal_distribution.html
+ // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-63196F25-5013-4038-8BCD-2613C4EF3DE4.htm
+ // According to the two documents above, the probability density functions differ.
+ // Yet the unit tests still pass: maybe the implementations are in fact the same, or
+ // the tests do not depend on the random numbers.
+ boost::normal_distribution<Dtype> random_distribution(a, sigma);
+ Caffe::random_generator_t &generator = Caffe::vsl_stream();
- for(int i = 0; i < n; i += 1)
- {
- r[i] = random_distribution(generator);
- }
+ for(int i = 0; i < n; i += 1) {
+ r[i] = random_distribution(generator);
+ }
}
+template
+void caffe_vRngGaussian<float>(const int n, float* r, const float a,
+ const float sigma);
-template <>
+template
void caffe_vRngGaussian<double>(const int n, double* r, const double a,
- const double sigma) {
- DCHECK(sigma > 0);
- //VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
- // Caffe::vsl_stream(), n, r, a, sigma));
-
- // FIXME check if parameters are handled in the same way ?
- boost::normal_distribution<double> random_distribution(a, sigma);
- Caffe::random_generator_t &generator = Caffe::vsl_stream();
-
- for(int i = 0; i < n; i += 1)
- {
- r[i] = random_distribution(generator);
- }
-}
-
+ const double sigma);
template <typename Dtype>
-void caffe_vRngBernoulli(const int n, Dtype* r, const double p)
-{
+void caffe_vRngBernoulli(const int n, Dtype* r, const double p) {
+ CHECK_GE(n, 0);
+ CHECK(r);
+ CHECK_GE(p, 0);
+ CHECK_LE(p, 1);
// FIXME check if parameters are handled in the same way ?
- boost::bernoulli_distribution<Dtype> random_distribution(p);
- Caffe::random_generator_t &generator = Caffe::vsl_stream();
-
- for(int i = 0; i < n; i += 1)
- {
- r[i] = random_distribution(generator);
- }
-}
-
-template void caffe_vRngBernoulli<int>(const int n, int* r, const double p);
-
+ boost::bernoulli_distribution<Dtype> random_distribution(p);
+ Caffe::random_generator_t &generator = Caffe::vsl_stream();
-template <>
-void caffe_exp<float>(const int n, const float* a, float* y) {
- //vsExp(n, a, y);
- map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().exp();
+ for(int i = 0; i < n; i += 1) {
+ r[i] = random_distribution(generator);
+ }
}
-template <>
-void caffe_exp<double>(const int n, const double* a, double* y) {
- //vdExp(n, a, y);
- map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().exp();
-}
+template
+void caffe_vRngBernoulli<int>(const int n, int* r, const double p);
template <>
float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {