| author | Kai Li <kaili_kloud@163.com> | 2014-01-12 13:55:26 +0800 |
|---|---|---|
| committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-03-21 13:52:34 -0700 |
| commit | 788f070d063e3f3e5fc8eb0faa53411e966898f6 (patch) | |
| tree | 6ab300d209d51f0905f529360eceb4bad847285b /src/caffe/util/math_functions.cpp | |
| parent | 38457e1c1f0d5bb9765896c3d5a43eaf19534ec9 (diff) | |
Fix math funcs, add tests, change Eigen Map to unaligned for lrn_layer
[shelhamer: removed math function tests, since they were merged via
other branches]
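
The alignment issue the message refers to is easy to reproduce outside of Caffe. Below is a minimal standalone sketch (not part of this commit; the buffer, offset, and `main` are purely illustrative): an `Eigen::Map` declared with `Eigen::Aligned` assumes a 16-byte-aligned pointer, which the pointers handed to the math functions do not guarantee, while an `Eigen::Unaligned` map can safely wrap and assign through any pointer.

```cpp
// Sketch only: demonstrates why the maps in this commit switch to Eigen::Unaligned.
#include <Eigen/Dense>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> buffer(17, 1.0f);
  // Offsetting by one float almost certainly breaks 16-byte alignment.
  float* unaligned_ptr = buffer.data() + 1;

  typedef Eigen::Array<float, 1, Eigen::Dynamic> float_array_t;

  // Safe: an unaligned map makes no assumption about the wrapped pointer.
  Eigen::Map<float_array_t, Eigen::Unaligned> y(unaligned_ptr, 16);
  y = y * 2.0f + 1.0f;  // element-wise, like the caffe_* helpers in this file

  // Undefined behavior, may segfault once vectorized loads/stores kick in:
  // Eigen::Map<float_array_t, Eigen::Aligned> bad(unaligned_ptr, 16);

  std::cout << y.sum() << std::endl;  // prints 48
  return 0;
}
```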
Diffstat (limited to 'src/caffe/util/math_functions.cpp')
-rw-r--r-- | src/caffe/util/math_functions.cpp | 322 |
1 file changed, 200 insertions(+), 122 deletions(-)
```diff
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 850a408f..46c82dbd 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -13,11 +13,22 @@ namespace caffe {
-const int data_alignment = Eigen::Aligned; // how is data allocated ?
-typedef Eigen::Map<const Eigen::VectorXf, data_alignment> const_map_vector_float_t;
-typedef Eigen::Map<Eigen::VectorXf, data_alignment> map_vector_float_t;
-typedef Eigen::Map<const Eigen::VectorXd, data_alignment> const_map_vector_double_t;
-typedef Eigen::Map<Eigen::VectorXd, data_alignment> map_vector_double_t;
+// Operations on aligned memory are faster than on unaligned memory.
+// But unfortunately, the pointers passed in are not always aligned.
+// Therefore, the memory-aligned Eigen::Map objects that wrap them
+// cannot be assigned to. This happens in lrn_layer and makes
+// test_lrn_layer crash with segmentation fault.
+// TODO: Use aligned Eigen::Map when the pointer to be wrapped is aligned.
+
+// Though the default map option is unaligned, making it explicit is no harm.
+//const int data_alignment = Eigen::Aligned; // how is data allocated ?
+const int data_alignment = Eigen::Unaligned;
+typedef Eigen::Array<float, 1, Eigen::Dynamic> float_array_t;
+typedef Eigen::Map<const float_array_t, data_alignment> const_map_vector_float_t;
+typedef Eigen::Map<float_array_t, data_alignment> map_vector_float_t;
+typedef Eigen::Array<double, 1, Eigen::Dynamic> double_array_t;
+typedef Eigen::Map<const double_array_t, data_alignment> const_map_vector_double_t;
+typedef Eigen::Map<double_array_t, data_alignment> map_vector_double_t;
 
 template<>
 void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
@@ -129,25 +140,6 @@ void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
 }
 
 template <>
-void caffe_axpby<float>(const int N, const float alpha, const float* X,
-    const float beta, float* Y) {
-  // y := a*x + b*y
-  //cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
-  map_vector_float_t(Y, N) *= beta;
-  map_vector_float_t(Y, N) += (alpha * const_map_vector_float_t(X, N));
-
-}
-
-template <>
-void caffe_axpby<double>(const int N, const double alpha, const double* X,
-    const double beta, double* Y) {
-  // y := a*x + b*y
-  //cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
-  map_vector_double_t(Y, N) *= beta;
-  map_vector_double_t(Y, N) += (alpha * const_map_vector_double_t(X, N));
-}
-
-template <>
 void caffe_copy<float>(const int N, const float* X, float* Y) {
   cblas_scopy(N, X, 1, Y, 1);
 }
@@ -202,189 +194,275 @@ void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
 }
 
 template <>
-void caffe_sqr<float>(const int n, const float* a, float* y) {
-  //vsSqr(n, a, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().sqrt();
+void caffe_axpby<float>(const int N, const float alpha, const float* X,
+    const float beta, float* Y) {
+  // y := a*x + b*y
+  //cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
+  CHECK_GE(N, 0);
+  CHECK(X);
+  CHECK(Y);
+  map_vector_float_t y_map(Y, N);
+  // Eigen produces optimized code using lasy evaluation
+  // http://eigen.tuxfamily.org/dox/TopicLazyEvaluation.html
+  y_map = const_map_vector_float_t(X, N) * alpha + y_map * beta;
 }
 
 template <>
-void caffe_sqr<double>(const int n, const double* a, double* y) {
-  //vdSqr(n, a, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().sqrt();
+void caffe_axpby<double>(const int N, const double alpha, const double* X,
+    const double beta, double* Y) {
+  // y := a*x + b*y
+  //cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
+  CHECK_GE(N, 0);
+  CHECK(X);
+  CHECK(Y);
+  map_vector_double_t y_map(Y, N);
+  y_map = const_map_vector_double_t(X, N) * alpha + y_map * beta;
 }
 
 template <>
 void caffe_add<float>(const int n, const float* a, const float* b,
     float* y) {
-  //vsAdd(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) + const_map_vector_float_t(b, n);
+  //vsAdd(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) +
+      const_map_vector_float_t(b, n);
 }
 
 template <>
 void caffe_add<double>(const int n, const double* a, const double* b,
     double* y) {
-  //vdAdd(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) + const_map_vector_double_t(b, n);
+  //vdAdd(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) +
+      const_map_vector_double_t(b, n);
 }
 
 template <>
 void caffe_sub<float>(const int n, const float* a, const float* b,
     float* y) {
-  //vsSub(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) - const_map_vector_float_t(b, n);
+  //vsSub(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) -
+      const_map_vector_float_t(b, n);
 }
 
 template <>
 void caffe_sub<double>(const int n, const double* a, const double* b,
     double* y) {
-  //vdSub(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) - const_map_vector_double_t(b, n);
+  //vdSub(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) -
+      const_map_vector_double_t(b, n);
}
 
 template <>
 void caffe_mul<float>(const int n, const float* a, const float* b,
     float* y) {
-  //vsMul(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() * const_map_vector_float_t(b, n).array();
+  //vsMul(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) *
+      const_map_vector_float_t(b, n);
 }
 
 template <>
 void caffe_mul<double>(const int n, const double* a, const double* b,
     double* y) {
-  //vdMul(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() * const_map_vector_double_t(b, n).array();
+  //vdMul(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) *
+      const_map_vector_double_t(b, n);
 }
 
 template <>
 void caffe_div<float>(const int n, const float* a, const float* b,
     float* y) {
-  //vsDiv(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() / const_map_vector_float_t(b, n).array();
+  //vsDiv(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) /
+      const_map_vector_float_t(b, n);
 }
 
 template <>
 void caffe_div<double>(const int n, const double* a, const double* b,
     double* y) {
-  //vdDiv(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() / const_map_vector_double_t(b, n).array();
+  //vdDiv(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) /
+      const_map_vector_double_t(b, n);
 }
 
 template <>
 void caffe_powx<float>(const int n, const float* a, const float b,
     float* y) {
-  //vsPowx(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().pow(b);
+  //vsPowx(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).pow(b);
 }
 
 template <>
 void caffe_powx<double>(const int n, const double* a, const double b,
     double* y) {
-  //vdPowx(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().pow(b);
+  //vdPowx(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).pow(b);
+}
+
+template <>
+void caffe_sqr<float>(const int n, const float* a, float* y) {
+  // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-F003F826-81BF-42EC-AE51-2EF624893133.htm
+  // v?Sqr Performs element by element squaring of the vector.
+  //vsSqr(n, a, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  caffe_powx<float>(n, a, 2, y);
+  // TODO: which is faster?
+//  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) *
+//      const_map_vector_float_t(a, n);
+}
+
+template <>
+void caffe_sqr<double>(const int n, const double* a, double* y) {
+  //vdSqr(n, a, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  caffe_powx<double>(n, a, 2, y);
+}
+
+template <>
+void caffe_exp<float>(const int n, const float* a, float* y) {
+  //vsExp(n, a, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).exp();
+}
+
+template <>
+void caffe_exp<double>(const int n, const double* a, double* y) {
+  //vdExp(n, a, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).exp();
 }
 
 template <typename Dtype>
 Dtype caffe_nextafter(const Dtype b) {
-  return boost::math::nextafter<Dtype, Dtype>(b, std::numeric_limits<Dtype>::max());
+  return boost::math::nextafter<Dtype, Dtype>(
+      b, std::numeric_limits<Dtype>::max());
 }
 
-template <>
-void caffe_vRngUniform<float>(const int n, float* r,
-    const float a, const float b) {
+template
+float caffe_nextafter(const float b);
+
+template
+double caffe_nextafter(const double b);
+
+template <typename Dtype>
+void caffe_vRngUniform(const int n, Dtype* r,
+    const Dtype a, const Dtype b) {
+  CHECK_GE(n, 0);
+  CHECK(r);
+  CHECK_LE(a, b);
   //VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
   //    n, r, a, b));
 
   // FIXME check if boundaries are handled in the same way ?
-  boost::random::uniform_real_distribution<float> random_distribution(
-      a, caffe_nextafter<float>(b));
+  // Fixed by caffe_nextafter
+  boost::random::uniform_real_distribution<Dtype> random_distribution(
+      a, caffe_nextafter<Dtype>(b));
   Caffe::random_generator_t &generator = Caffe::vsl_stream();
 
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
+  for(int i = 0; i < n; i += 1) {
+    r[i] = random_distribution(generator);
   }
 }
 
-template <>
+template
+void caffe_vRngUniform<float>(const int n, float* r,
+    const float a, const float b);
+template
 void caffe_vRngUniform<double>(const int n, double* r,
-    const double a, const double b) {
-  //VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
-  //    n, r, a, b));
-
-  // FIXME check if boundaries are handled in the same way ?
-  boost::random::uniform_real_distribution<double> random_distribution(
-      a, caffe_nextafter<double>(b));
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
+    const double a, const double b);
 
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
-  }
-}
-
-template <>
-void caffe_vRngGaussian<float>(const int n, float* r, const float a,
-    const float sigma) {
-  DCHECK(sigma > 0);
+template <typename Dtype>
+void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
+    const Dtype sigma) {
+  CHECK_GE(n, 0);
+  CHECK(r);
+  CHECK_GT(sigma, 0);
   //VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
   //            Caffe::vsl_stream(), n, r, a, sigma));
 
   // FIXME check if parameters are handled in the same way ?
-  boost::normal_distribution<float> random_distribution(a, sigma);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
+  // http://www.boost.org/doc/libs/1_55_0/doc/html/boost/random/normal_distribution.html
+  // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-63196F25-5013-4038-8BCD-2613C4EF3DE4.htm
+  // The above two documents show that the probability density functions are different.
+  // But the unit tests still pass. Maybe their codes are the same or
+  // the tests are irrelevant to the random numbers.
+  boost::normal_distribution<Dtype> random_distribution(a, sigma);
+  Caffe::random_generator_t &generator = Caffe::vsl_stream();
 
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
-  }
+  for(int i = 0; i < n; i += 1) {
+    r[i] = random_distribution(generator);
+  }
 }
 
+template
+void caffe_vRngGaussian<float>(const int n, float* r, const float a,
+    const float sigma);
 
-template <>
+template
 void caffe_vRngGaussian<double>(const int n, double* r, const double a,
-    const double sigma) {
-  DCHECK(sigma > 0);
-  //VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
-  //            Caffe::vsl_stream(), n, r, a, sigma));
-
-  // FIXME check if parameters are handled in the same way ?
-  boost::normal_distribution<double> random_distribution(a, sigma);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
-
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
-  }
-}
-
+    const double sigma);
 
 template <typename Dtype>
-void caffe_vRngBernoulli(const int n, Dtype* r, const double p)
-{
+void caffe_vRngBernoulli(const int n, Dtype* r, const double p) {
+  CHECK_GE(n, 0);
+  CHECK(r);
+  CHECK_GE(p, 0);
+  CHECK_LE(p, 1);
   // FIXME check if parameters are handled in the same way ?
-  boost::bernoulli_distribution<Dtype> random_distribution(p);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
-
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
-  }
-}
-
-template void caffe_vRngBernoulli<int>(const int n, int* r, const double p);
-
+  boost::bernoulli_distribution<Dtype> random_distribution(p);
+  Caffe::random_generator_t &generator = Caffe::vsl_stream();
 
-template <>
-void caffe_exp<float>(const int n, const float* a, float* y) {
-  //vsExp(n, a, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().exp();
+  for(int i = 0; i < n; i += 1) {
+    r[i] = random_distribution(generator);
+  }
 }
 
-template <>
-void caffe_exp<double>(const int n, const double* a, double* y) {
-  //vdExp(n, a, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().exp();
-}
+template
+void caffe_vRngBernoulli<int>(const int n, int* r, const double p);
 
 template <>
 float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {
```
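
A side note on the RNG hunk above: `boost::random::uniform_real_distribution` samples from the half-open interval [a, b), which is why the rewritten `caffe_vRngUniform` passes `caffe_nextafter(b)` as the upper bound ("Fixed by caffe_nextafter"). The same trick in a self-contained sketch (plain Boost only, no Caffe types; the seed and bounds are arbitrary):

```cpp
// Sketch only: widening the upper bound so that b itself becomes reachable.
#include <boost/math/special_functions/next.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_real_distribution.hpp>
#include <iostream>
#include <limits>

int main() {
  const float a = 0.0f, b = 1.0f;
  // uniform_real_distribution draws from [min, max); passing nextafter(b)
  // makes b a possible (if unlikely) outcome, mirroring caffe_nextafter.
  const float b_closed = boost::math::nextafter<float, float>(
      b, std::numeric_limits<float>::max());

  boost::random::mt19937 generator(1701);
  boost::random::uniform_real_distribution<float> random_distribution(a, b_closed);
  for (int i = 0; i < 3; ++i) {
    std::cout << random_distribution(generator) << std::endl;  // values in [0, 1]
  }
  return 0;
}
```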