| author | Aapo Kyrola <akyrola@fb.com> | 2017-04-17 21:23:45 -0700 |
| committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-04-17 21:31:20 -0700 |
| commit | 9ab077dc9d0bbe651348a498dd5472dc4d51f0af (patch) | |
| tree | 53060eb0ae8e9e03d765cf80d63e1f37e08ef1c8 | |
| parent | 391fd141150d250ee60027fad4acf8aecefbae51 (diff) | |
| download | pytorch-9ab077dc9d0bbe651348a498dd5472dc4d51f0af.tar.gz, pytorch-9ab077dc9d0bbe651348a498dd5472dc4d51f0af.tar.bz2, pytorch-9ab077dc9d0bbe651348a498dd5472dc4d51f0af.zip | |
Revert D4871248: [caffe2][PR] fp16 support for FullyConnected op
Summary: This reverts commit 6a991c2c993dcf0b1e18aa3f2ffbe19e693dbadd
Differential Revision: D4871248
fbshipit-source-id: b6d812d09a00c83e363432e84742c503abfed65b
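In short, this revert removes the fp16 (`float16`) code paths introduced by D4871248 and restores the earlier design in which the element type is an explicit template parameter: `FullyConnectedOp` and `FullyConnectedGradientOp` regain a leading `typename T` and lose the per-tensor `DoRunWithType<T_X, T_W, ...>` dispatch, the fp16/fp32 helper header `caffe2/utils/conversions.h` is deleted, and the `float16` specializations of `Gemm`, `Gemv`, `Dot`, `Scale`, `Axpy`, `Set`, and `Select` (plus the custom CUDA `Add` operator) are dropped, while `double` support is added in several places (`SparseToDense`, `SquareRootDivide`, the CPU/GPU math library, and a new `Gemv` test). The registration-level effect, assembled from the hunks below:

```cpp
// After the revert, the element type is fixed at registration time and only
// float is registered for FC; the runtime float/float16 dispatch is gone.
REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<float, CPUContext>);
REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(FCGradient, FullyConnectedGradientOp<float, CUDAContext>);
```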
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | caffe2/contrib/nervana/nervana_fc_op_gpu.cc | 9 |
| -rw-r--r-- | caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc | 2 |
| -rw-r--r-- | caffe2/contrib/nervana/nervana_math_gpu.cc | 16 |
| -rw-r--r-- | caffe2/operators/elementwise_op.cu | 58 |
| -rw-r--r-- | caffe2/operators/fully_connected_op.cc | 4 |
| -rw-r--r-- | caffe2/operators/fully_connected_op.h | 111 |
| -rw-r--r-- | caffe2/operators/fully_connected_op_gpu.cc | 57 |
| -rw-r--r-- | caffe2/operators/sparse_to_dense_op.h | 1 |
| -rw-r--r-- | caffe2/operators/square_root_divide_op.h | 2 |
| -rw-r--r-- | caffe2/operators/utility_ops.cc | 13 |
| -rw-r--r-- | caffe2/operators/utility_ops.h | 37 |
| -rw-r--r-- | caffe2/operators/utility_ops_gpu.cc | 28 |
| -rw-r--r-- | caffe2/utils/conversions.h | 182 |
| -rw-r--r-- | caffe2/utils/math-detail.h | 36 |
| -rw-r--r-- | caffe2/utils/math.h | 53 |
| -rw-r--r-- | caffe2/utils/math_cpu.cc | 119 |
| -rw-r--r-- | caffe2/utils/math_gpu.cu | 423 |
| -rw-r--r-- | caffe2/utils/math_gpu_test.cc | 57 |

18 files changed, 312 insertions(+), 896 deletions(-)
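The bulk of the deletions land in the math library. After the revert, scalar coefficients follow the data type `T` rather than always being `float`, and the trailing `TensorProto::DataType math_type` selector (which allowed fp16 storage with fp32 math) is gone from `Gemm` and `Gemv`. The post-revert declarations, as they appear in the `caffe2/utils/math.h` hunk below:

```cpp
template <typename T, class Context, class Engine = DefaultEngine>
void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
          const int M, const int N, const int K, const T alpha, const T* A,
          const T* B, const T beta, T* C, Context* context);

template <typename T, class Context, class Engine = DefaultEngine>
void Gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
          const T alpha, const T* A, const T* x, const T beta, T* y,
          Context* context);

template <typename T, class Context>
void Scale(const int N, const T alpha, const T* x, T* y, Context* context);

template <typename T, class Context>
void Axpy(const int N, const T alpha, const T* x, T* y, Context* context);
```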
diff --git a/caffe2/contrib/nervana/nervana_fc_op_gpu.cc b/caffe2/contrib/nervana/nervana_fc_op_gpu.cc index 8d33a7c20a..b2328500c7 100644 --- a/caffe2/contrib/nervana/nervana_fc_op_gpu.cc +++ b/caffe2/contrib/nervana/nervana_fc_op_gpu.cc @@ -5,11 +5,8 @@ namespace caffe2 { REGISTER_CUDA_OPERATOR_WITH_ENGINE( - FC, - NERVANA, - FullyConnectedOp<CUDAContext, NervanaEngine>); + FC, NERVANA, FullyConnectedOp<float, CUDAContext, NervanaEngine>); REGISTER_CUDA_OPERATOR_WITH_ENGINE( - FCGradient, - NERVANA, - FullyConnectedGradientOp<CUDAContext, NervanaEngine>); + FCGradient, NERVANA, + FullyConnectedGradientOp<float, CUDAContext, NervanaEngine>); } // namespace caffe2 diff --git a/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc b/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc index 3eb0fc3ace..a3ae3bb45a 100644 --- a/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc +++ b/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc @@ -49,7 +49,7 @@ TEST(NervanaFullyConnectedTest, Test) { AddConstInput(std::vector<int>{6, 10}, 1., "W", &ws); AddConstInput(std::vector<int>{6}, 0.1, "B", &ws); unique_ptr<OperatorBase> op( - new FullyConnectedOp<CUDAContext, NervanaEngine>(def, &ws)); + new FullyConnectedOp<float, CUDAContext, NervanaEngine>(def, &ws)); EXPECT_NE(nullptr, op.get()); EXPECT_TRUE(op->Run()); Blob* Yblob = ws.GetBlob("Y"); diff --git a/caffe2/contrib/nervana/nervana_math_gpu.cc b/caffe2/contrib/nervana/nervana_math_gpu.cc index 09c70e4343..f3010b9b95 100644 --- a/caffe2/contrib/nervana/nervana_math_gpu.cc +++ b/caffe2/contrib/nervana/nervana_math_gpu.cc @@ -11,18 +11,10 @@ namespace math { // limitation that the data has to be contiguous in memory. template <> void Gemm<float, CUDAContext, NervanaEngine>( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C, - CUDAContext* context, - TensorProto::DataType math_type) { + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const float alpha, const float* A, + const float* B, const float beta, float* C, CUDAContext* context) { + // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (TransA == CblasNoTrans) ? 
K : M; diff --git a/caffe2/operators/elementwise_op.cu b/caffe2/operators/elementwise_op.cu index c75bd23867..016a555310 100644 --- a/caffe2/operators/elementwise_op.cu +++ b/caffe2/operators/elementwise_op.cu @@ -5,7 +5,6 @@ #include "caffe2/core/common_gpu.h" #include "caffe2/core/context_gpu.h" #include "caffe2/operators/elementwise_op.h" -#include "caffe2/utils/conversions.h" namespace caffe2 { @@ -63,6 +62,9 @@ REGISTER_CUDA_OPERATOR( \ name, BinaryElementwiseOp< \ input_type, CUDAContext, Cuda##name##Functor, output_type>) +#define CUDA_ADD(x, y) ((x) + (y)) +CUDA_FUNCTOR(Add, CUDA_ADD, NumericTypes, SameTypeAsInput); +#undef CUDA_ADD #define CUDA_SUB(x, y) ((x) - (y)) CUDA_FUNCTOR(Sub, CUDA_SUB, NumericTypes, SameTypeAsInput); #undef CUDA_SUB @@ -262,58 +264,4 @@ bool SumReduceLikeOp<CUDAContext>::DoRunWithType() { REGISTER_CUDA_OPERATOR(SumReduceLike, SumReduceLikeOp<CUDAContext>); -namespace { - -template <typename T, typename M> -__global__ void binary_add_kernel(const int N, const T* a, const T* b, T* r){ - CUDA_1D_KERNEL_LOOP(idx, N){ - r[idx] = convert::To<M, T>( - convert::To<T, M>(a[idx]) + convert::To<T, M>(b[idx])); -} -}; -} -; // namespace - -// Actual Add operator, because the above macros are read-only. -class CUDAAddOp final : public Operator<CUDAContext> { - public: - CUDAAddOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<CUDAContext>(operator_def, ws){}; - ~CUDAAddOp() {} - - template <typename T, typename M> - bool DoRunWithType() { - auto& X0 = Input(0); - auto& X1 = Input(1); - auto* output = Output(0); - - output->ResizeLike(X0); - - binary_add_kernel<T, M><<< - CAFFE_GET_BLOCKS(X0.size()), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - X0.size(), - X0.template data<T>(), - X1.template data<T>(), - output->template mutable_data<T>()); - return true; - } - - bool RunOnDevice() override { - if (Input(0).IsType<float>()) { - return DoRunWithType<float, float>(); - } else if (Input(0).IsType<float16>()) { - return DoRunWithType<float16, float>(); - } else { - return false; - } - } -}; - -namespace { -REGISTER_CUDA_OPERATOR(Add, CUDAAddOp); -} // namespace - } // namespace caffe2 diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc index c00f199d5b..7a0e0bc5a5 100644 --- a/caffe2/operators/fully_connected_op.cc +++ b/caffe2/operators/fully_connected_op.cc @@ -3,8 +3,8 @@ namespace caffe2 { namespace { -REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<CPUContext>); -REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<CPUContext>); +REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<float, CPUContext>); +REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<float, CPUContext>); OPERATOR_SCHEMA(FC) .NumInputs(3) diff --git a/caffe2/operators/fully_connected_op.h b/caffe2/operators/fully_connected_op.h index 45adc0a4d2..6e24c91103 100644 --- a/caffe2/operators/fully_connected_op.h +++ b/caffe2/operators/fully_connected_op.h @@ -3,13 +3,12 @@ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" -#include "caffe2/utils/conversions.h" #include "caffe2/utils/math.h" namespace caffe2 { // This is Caffe's InnerProductOp, with a name that fits its purpose better. 
-template <class Context, class Engine = DefaultEngine> +template <typename T, class Context, class Engine = DefaultEngine> class FullyConnectedOp final : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; @@ -18,13 +17,7 @@ class FullyConnectedOp final : public Operator<Context> { axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {} ~FullyConnectedOp() {} - template < - typename T_X, - typename T_W, - typename T_B, - typename T_Y, - typename MATH> - bool DoRunWithType() { + bool RunOnDevice() override { const auto& X = Input(0); const auto& W = Input(1); const auto& b = Input(2); @@ -70,53 +63,44 @@ class FullyConnectedOp final : public Operator<Context> { Y->Resize(Y_shape_cache_); CAFFE_ENFORCE(M * N == Y->size(), dimErrorString()); - // W * x - math::Gemm<T_X, Context, Engine>( + // X * W^T + math::Gemm<T, Context, Engine>( CblasNoTrans, CblasTrans, M, N, K, 1, - X.template data<T_X>(), - W.template data<T_W>(), + X.template data<T>(), + W.template data<T>(), 0, - Y->template mutable_data<T_Y>(), + Y->template mutable_data<T>(), &context_); // Add bias term if (bias_multiplier_.size() != M) { // If the helper bias multiplier is not M, reshape and fill it with one. bias_multiplier_.Resize(M); - math::Set<T_B, Context>( + math::Set<T, Context>( M, - convert::To<float, T_B>(1), - bias_multiplier_.template mutable_data<T_B>(), + static_cast<T>(1), + bias_multiplier_.template mutable_data<T>(), &context_); } - math::Gemm<T_B, Context, Engine>( + math::Gemm<T, Context, Engine>( CblasNoTrans, CblasNoTrans, M, N, 1, 1, - bias_multiplier_.template data<T_B>(), - b.template data<T_B>(), + bias_multiplier_.template data<T>(), + b.template data<T>(), 1, - Y->template mutable_data<T_Y>(), + Y->template mutable_data<T>(), &context_); return true; } - bool RunOnDevice() override { - return DoRunWithType< - float, // X - float, // W - float, // B - float, // Y - float>(); // Math - } - protected: size_t axis_{1}; // A local vector to cache the output shape so we don't need to recreate @@ -125,7 +109,7 @@ class FullyConnectedOp final : public Operator<Context> { Tensor<Context> bias_multiplier_; }; -template <class Context, class Engine = DefaultEngine> +template <typename T, class Context, class Engine = DefaultEngine> class FullyConnectedGradientOp : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; @@ -134,16 +118,7 @@ class FullyConnectedGradientOp : public Operator<Context> { axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {} ~FullyConnectedGradientOp() {} - template < - typename T_X, - typename T_W, - typename T_DY, - typename T_B, - typename T_DX, - typename T_DW, - typename T_DB, - typename MATH> - bool DoRunWithType() { + bool RunOnDevice() override { const auto& X = Input(0); const auto& W = Input(1); const auto& dY = Input(2); @@ -162,72 +137,60 @@ class FullyConnectedGradientOp : public Operator<Context> { db->Resize(N); // Compute dW - math::Gemm<T_DY, Context, Engine>( + math::Gemm<T, Context, Engine>( CblasTrans, CblasNoTrans, N, K, M, - convert::To<float, MATH>(1), - dY.template data<T_DY>(), - X.template data<T_X>(), - convert::To<float, MATH>(0), - dW->template mutable_data<T_DW>(), + 1, + dY.template data<T>(), + X.template data<T>(), + 0, + dW->template mutable_data<T>(), &context_); if (bias_multiplier_.size() != M) { // If the helper bias multiplier is not M, reshape and fill it // with one. 
bias_multiplier_.Resize(M); - math::Set<T_B, Context>( + math::Set<T, Context>( M, - convert::To<float, T_B>(1), - bias_multiplier_.template mutable_data<T_B>(), + static_cast<T>(1), + bias_multiplier_.template mutable_data<T>(), &context_); } // Compute dB - math::Gemv<T_DY, Context>( + math::Gemv<T, Context>( CblasTrans, M, N, - convert::To<float, MATH>(1), - dY.template data<T_DY>(), - bias_multiplier_.template data<T_B>(), - convert::To<float, MATH>(0), - db->template mutable_data<T_DB>(), + 1, + dY.template data<T>(), + bias_multiplier_.template data<T>(), + 0, + db->template mutable_data<T>(), &context_); // Compute dX if (OutputSize() == 3) { auto* dX = Output(2); dX->ResizeLike(X); - math::Gemm<T_DX, Context, Engine>( + math::Gemm<T, Context, Engine>( CblasNoTrans, CblasNoTrans, M, K, N, - convert::To<float, MATH>(1), - dY.template data<T_DY>(), - W.template data<T_W>(), - convert::To<float, MATH>(0), - dX->template mutable_data<T_DX>(), + 1, + dY.template data<T>(), + W.template data<T>(), + 0, + dX->template mutable_data<T>(), &context_); } return true; } - bool RunOnDevice() override { - return DoRunWithType< - float, // X - float, // W - float, // dY - float, // B - float, // dX - float, // dW - float, // dB - float>(); // Math - } - protected: size_t axis_{1}; Tensor<Context> bias_multiplier_; diff --git a/caffe2/operators/fully_connected_op_gpu.cc b/caffe2/operators/fully_connected_op_gpu.cc index 07431862f8..8ee67acd0b 100644 --- a/caffe2/operators/fully_connected_op_gpu.cc +++ b/caffe2/operators/fully_connected_op_gpu.cc @@ -2,60 +2,9 @@ #include "caffe2/operators/fully_connected_op.h" namespace caffe2 { - -template <> -bool FullyConnectedOp<CUDAContext>::RunOnDevice() { - if (Input(0).IsType<float>()) { - return DoRunWithType< - float, // X - float, // W - float, // B - float, // Y - float>(); // Math - } else if (Input(0).IsType<float16>()) { - return DoRunWithType< - float16, // X - float16, // W - float16, // B - float16, // Y - float>(); // Math - } else { - CAFFE_THROW("Unsupported type"); - } - return false; -} - -template <> -bool FullyConnectedGradientOp<CUDAContext>::RunOnDevice() { - if (Input(0).IsType<float>()) { - return DoRunWithType< - float, // X - float, // W - float, // dY - float, // B - float, // dX - float, // dW - float, // dB - float>(); // Math - } else if (Input(0).IsType<float16>()) { - return DoRunWithType< - float16, // X - float16, // W - float16, // dY - float16, // B - float16, // dX - float16, // dW - float16, // dB - float>(); // Math - } else { - CAFFE_THROW("Unsupported type"); - } - return false; -} - namespace { - -REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<CUDAContext>); -REGISTER_CUDA_OPERATOR(FCGradient, FullyConnectedGradientOp<CUDAContext>); +REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<float, CUDAContext>); +REGISTER_CUDA_OPERATOR(FCGradient, + FullyConnectedGradientOp<float, CUDAContext>); } // namespace } // namespace caffe2 diff --git a/caffe2/operators/sparse_to_dense_op.h b/caffe2/operators/sparse_to_dense_op.h index 439d96ce65..d48b61755f 100644 --- a/caffe2/operators/sparse_to_dense_op.h +++ b/caffe2/operators/sparse_to_dense_op.h @@ -50,6 +50,7 @@ class SparseToDenseOp final : public Operator<Context> { return DispatchHelper< TensorTypes2< float, + double, int32_t, int64_t, GenericTensorImplementation>, diff --git a/caffe2/operators/square_root_divide_op.h b/caffe2/operators/square_root_divide_op.h index df018bf0d6..644c2bd96f 100644 --- a/caffe2/operators/square_root_divide_op.h +++ 
b/caffe2/operators/square_root_divide_op.h @@ -17,7 +17,7 @@ class SquareRootDivideOp final : public Operator<Context> { : Operator<Context>(operator_def, ws) {} bool RunOnDevice() override { - return DispatchHelper<TensorTypes<float>>::call(this, Input(DATA)); + return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA)); } private: diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index 99f0f20fa1..771da9c2e9 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -3,12 +3,6 @@ #include <cmath> namespace caffe2 { - -template <> -bool WeightedSumOp<CPUContext>::RunOnDevice() { - return DoRunWithType<float>(); -} - namespace { REGISTER_CPU_OPERATOR(WallClockTime, WallClockTimeOp<CPUContext>); @@ -18,9 +12,10 @@ REGISTER_CPU_OPERATOR(FlattenToVec, FlattenToVecOp<CPUContext>); REGISTER_CPU_OPERATOR(Alias, AliasOp<CPUContext>); REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp<CPUContext>); -REGISTER_CPU_OPERATOR(Sum, SumOp<CPUContext>); -REGISTER_CPU_OPERATOR(SumInt, SumOp<CPUContext>); -REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<CPUContext>); +REGISTER_CPU_OPERATOR(Sum, SumOp<float, CPUContext>); +REGISTER_CPU_OPERATOR(SumInt, SumOp<int, CPUContext>); + +REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<float, CPUContext>); REGISTER_CPU_OPERATOR( ScatterWeightedSum, ScatterWeightedSumOp<float, CPUContext>); diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index 85f08d2b99..95572317ef 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -250,14 +250,13 @@ class ResizeLikeOp : public Operator<Context> { } }; -template <class Context> +template <typename T, class Context> class SumOp : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(SumOp); - template <typename T, typename M> - bool DoRunWithType() { + bool RunOnDevice() override { auto& input0 = Input(0); auto* output = Output(0); if (InputSize() == 1) { @@ -298,16 +297,6 @@ class SumOp : public Operator<Context> { } return true; } - - bool RunOnDevice() override { - if (Input(0).template IsType<float>()) { - return DoRunWithType<float, float>(); - } else if (Input(0).template IsType<int>()) { - return DoRunWithType<int, int>(); - } else { - return false; - } - } }; // WeightedSumOp computes the weighted sum of several tensors. The input should @@ -315,14 +304,13 @@ class SumOp : public Operator<Context> { // shape, and weight_i are size 1 tensors that specifies the weight of each // vector. Note that if one wants to do in-place computation, it could only be // done with X_0 also as the output, but not other X_i. 
-template <class Context> +template <typename T, class Context> class WeightedSumOp : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(WeightedSumOp); - template <typename DstType> - bool DoRunWithType() { + bool RunOnDevice() override { DCHECK_EQ(InputSize() % 2, 0); auto& X0 = Input(0); auto& weight0 = Input(1); @@ -331,11 +319,11 @@ class WeightedSumOp : public Operator<Context> { int size = X0.size(); auto* output = Output(0); output->ResizeLike(X0); - math::Scale<DstType, Context>( + math::Scale<T, Context>( size, - weight0.template data<float>(), - X0.template data<DstType>(), - output->template mutable_data<DstType>(), + weight0.template data<T>(), + X0.template data<T>(), + output->template mutable_data<T>(), &context_); for (int i = 2; i < InputSize(); i += 2) { auto& X = Input(i); @@ -350,16 +338,15 @@ class WeightedSumOp : public Operator<Context> { auto& weight = Input(i + 1); DCHECK_EQ(X.size(), size); DCHECK_EQ(weight.size(), 1); - math::Axpy<DstType, Context>( + math::Axpy<T, Context>( size, - weight.template data<float>(), - X.template data<DstType>(), - output->template mutable_data<DstType>(), + weight.template data<T>(), + X.template data<T>(), + output->template mutable_data<T>(), &context_); } return true; } - bool RunOnDevice() override; }; /** diff --git a/caffe2/operators/utility_ops_gpu.cc b/caffe2/operators/utility_ops_gpu.cc index 0bb035328a..b3df226886 100644 --- a/caffe2/operators/utility_ops_gpu.cc +++ b/caffe2/operators/utility_ops_gpu.cc @@ -5,30 +5,6 @@ namespace caffe2 { template <> -bool WeightedSumOp<CUDAContext>::RunOnDevice() { - if (Input(0).IsType<float>()) { - return DoRunWithType<float>(); - } else if (Input(0).IsType<float16>()) { - return DoRunWithType<float16>(); - } else { - CAFFE_THROW("Unsupported inputs"); - } - return false; -} - -template <> -bool SumOp<CUDAContext>::RunOnDevice() { - if (Input(0).IsType<float>()) { - return DoRunWithType<float, float>(); - } else if (Input(0).IsType<float16>()) { - return DoRunWithType<float16, float16>(); - } else { - CAFFE_THROW("Unsupported inputs"); - } - return false; -} - -template <> class CopyOnDeviceLikeOp<CUDAContext, CUDAContext, CUDAContext> : public Operator<CUDAContext> { public: @@ -59,7 +35,9 @@ REGISTER_CUDA_OPERATOR(Squeeze, SqueezeOp<CUDAContext>); REGISTER_CUDA_OPERATOR(ExpandDims, ExpandDimsOp<CUDAContext>); REGISTER_CUDA_OPERATOR(Alias, AliasOp<CUDAContext>); REGISTER_CUDA_OPERATOR(ResizeLike, ResizeLikeOp<CUDAContext>); -REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<CUDAContext>); +REGISTER_CUDA_OPERATOR(Sum, SumOp<float, CUDAContext>); + +REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<float, CUDAContext>); REGISTER_CUDA_OPERATOR(Shape, ShapeOp<CUDAContext>); // From whatever the current context, ensure the output is TensorCPU REGISTER_CUDA_OPERATOR( diff --git a/caffe2/utils/conversions.h b/caffe2/utils/conversions.h deleted file mode 100644 index 0c6c3238f7..0000000000 --- a/caffe2/utils/conversions.h +++ /dev/null @@ -1,182 +0,0 @@ -#pragma once - -#include <caffe2/core/types.h> - -#ifdef __CUDA_ARCH__ -#include <cuda_fp16.h> -#endif - -#ifdef __CUDA_ARCH__ -#define CONVERSIONS_DECL __host__ __device__ inline -#else -#define CONVERSIONS_DECL inline -#endif - -namespace caffe2 { - -namespace convert { - -namespace { -inline float16 cpu_float2half_rn(float f) { - float16 ret; - - static_assert( - sizeof(unsigned int) == sizeof(float), - "Programming error sizeof(unsigned int) != sizeof(float)"); - - unsigned* xp = 
reinterpret_cast<unsigned int*>(&f); - unsigned x = *xp; - unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; - unsigned sign, exponent, mantissa; - - // Get rid of +NaN/-NaN case first. - if (u > 0x7f800000) { - ret.x = 0x7fffU; - return ret; - } - - sign = ((x >> 16) & 0x8000); - - // Get rid of +Inf/-Inf, +0/-0. - if (u > 0x477fefff) { - ret.x = sign | 0x7c00U; - return ret; - } - if (u < 0x33000001) { - ret.x = (sign | 0x0000); - return ret; - } - - exponent = ((u >> 23) & 0xff); - mantissa = (u & 0x7fffff); - - if (exponent > 0x70) { - shift = 13; - exponent -= 0x70; - } else { - shift = 0x7e - exponent; - exponent = 0; - mantissa |= 0x800000; - } - lsb = (1 << shift); - lsb_s1 = (lsb >> 1); - lsb_m1 = (lsb - 1); - - // Round to nearest even. - remainder = (mantissa & lsb_m1); - mantissa >>= shift; - if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { - ++mantissa; - if (!(mantissa & 0x3ff)) { - ++exponent; - mantissa = 0; - } - } - - ret.x = (sign | (exponent << 10) | mantissa); - - return ret; -} - -inline float cpu_half2float(float16 h) { - unsigned sign = ((h.x >> 15) & 1); - unsigned exponent = ((h.x >> 10) & 0x1f); - unsigned mantissa = ((h.x & 0x3ff) << 13); - - if (exponent == 0x1f) { /* NaN or Inf */ - mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); - exponent = 0xff; - } else if (!exponent) { /* Denorm or Zero */ - if (mantissa) { - unsigned int msb; - exponent = 0x71; - do { - msb = (mantissa & 0x400000); - mantissa <<= 1; /* normalize */ - --exponent; - } while (!msb); - mantissa &= 0x7fffff; /* 1.mantissa is implicit */ - } - } else { - exponent += 0x70; - } - - int temp = ((sign << 31) | (exponent << 23) | mantissa); - - unsigned* rp = reinterpret_cast<unsigned*>(&temp); - return *rp; -} - -}; // anonymous -// general version: defer to static_cast -template <typename IN, typename OUT> -CONVERSIONS_DECL OUT To(const IN in) { - return static_cast<OUT>(in); -} - -#if __CUDA_ARCH__ -__device__ __inline__ __half inf_clip(__half h) { - int isi = __hisinf(h); - if (isi > 0) { - // Exponent all ones except LSB (0x1e), mantissa is all ones (0x3ff) - h.x = 0x7bffU; - } else if (isi < 0) { - // As above, negated - h.x = 0x7bffU ^ 0x8000; - } - return h; -} -#endif - -// explicit for fp16 -template <> -CONVERSIONS_DECL float16 To(const float in) { -#if __CUDA_ARCH__ - // hacky interface between C2 fp16 and CUDA - float16 ret; - __half r; - // r.x = __float2half_rn(in); - // ret.x = inf_clip(r).x; - ret.x = __float2half(in).x; - return ret; -#else - return cpu_float2half_rn(in); -#endif -} - -template <> -CONVERSIONS_DECL float To(const float16 in) { -#if __CUDA_ARCH__ - __half tmp; - tmp.x = in.x; - return __half2float(tmp); -#else - return cpu_half2float(in); -#endif -}; - -template <> -CONVERSIONS_DECL float To(const float in) { - return in; -} - -template <typename OUT, typename IN> -CONVERSIONS_DECL OUT Get(IN x) { - return static_cast<OUT>(x); -} - -template <> -CONVERSIONS_DECL float Get(float16 x) { - return To<float16, float>(x); -} - -template <> -CONVERSIONS_DECL float16 Get(float x) { - return To<float, float16>(x); -} - -}; // namespace convert - -}; // namespace caffe2 - -#undef CONVERSIONS_DECL diff --git a/caffe2/utils/math-detail.h b/caffe2/utils/math-detail.h index 07a1f997d6..35a880a6d4 100644 --- a/caffe2/utils/math-detail.h +++ b/caffe2/utils/math-detail.h @@ -11,12 +11,8 @@ namespace detail { template<typename T, class Context, int FixedSize> struct ScaleImpl { - inline void operator()( - const int N, - const float alpha, 
- const T* x, - T* y, - Context* context) { + inline void + operator()(const int N, const T alpha, const T* x, T* y, Context* context) { Scale(N, alpha, x, y, context); } }; @@ -26,7 +22,7 @@ template<typename T> struct ScaleImpl<T, CPUContext, 1> { inline void operator()( const int N, - const float alpha, + const T alpha, const T* x, T* y, CPUContext* context) { @@ -37,12 +33,8 @@ struct ScaleImpl<T, CPUContext, 1> { template<typename T, class Context, int FixedSize> struct AxpyImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { + inline void + operator()(const int N, const T alpha, const T* x, T* y, Context* context) { Axpy(N, alpha, x, y, context); } }; @@ -52,7 +44,7 @@ template<typename T> struct AxpyImpl<T, CPUContext, 1> { inline void operator()( const int N, - const float alpha, + const T alpha, const T* x, T* y, CPUContext* context) { @@ -65,22 +57,14 @@ struct AxpyImpl<T, CPUContext, 1> { } // namespace detail template <typename T, class Context, int FixedSize> -inline void ScaleFixedSize( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { +inline void +ScaleFixedSize(const int N, const T alpha, const T* x, T* y, Context* context) { detail::ScaleImpl<T, Context, FixedSize>()(N, alpha, x, y, context); } template <typename T, class Context, int FixedSize> -inline void AxpyFixedSize( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { +inline void +AxpyFixedSize(const int N, const T alpha, const T* x, T* y, Context* context) { detail::AxpyImpl<T, Context, FixedSize>()(N, alpha, x, y, context); } diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h index 105cb19733..a2472c0d33 100644 --- a/caffe2/utils/math.h +++ b/caffe2/utils/math.h @@ -141,20 +141,10 @@ void ColwiseMax(const int N, const int D, const T* x, T* y, // Decaf gemm provides a simpler interface to the gemm functions, with the // limitation that the data has to be contiguous in memory. -template <typename T, class Context, class Engine = DefaultEngine> -void Gemm( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const float alpha, - const T* A, - const T* B, - const float beta, - T* C, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); +template <typename T, class Context, class Engine=DefaultEngine> +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const T alpha, const T* A, + const T* B, const T beta, T* C, Context* context); // We also provide a gemm that has explicit lda, ldb and ldc specified. // In most cases you probably want to use the function above, though. @@ -179,18 +169,10 @@ void GemmEx( // to Trans, the output is: // CblasNoTrans: x is an N dim vector and y is an M dim vector. // CblasTrans: x is an M dim vector and y is an N dim vector. 
-template <typename T, class Context, class Engine = DefaultEngine> -void Gemv( - const CBLAS_TRANSPOSE TransA, - const int M, - const int N, - const float alpha, - const T* A, - const T* x, - const float beta, - T* y, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); +template <typename T, class Context, class Engine=DefaultEngine> +void Gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, + const T alpha, const T* A, const T* x, const T beta, + T* y, Context* context); template <typename T, class Context> void Set(const TIndex N, const T alpha, T* X, Context* context); @@ -236,31 +218,28 @@ void Select(const int N, const int D, const T* x, const int* idx, T* y, Context* context); template <typename T, class Context> -void Scale(const int N, const float alpha, const T* x, T* y, Context* context); +void Scale(const int N, const T alpha, const T* x, T* y, Context* context); // Different from the Scale function above, if alpha is passed in // as a pointer, we will assume that it lives on the Context device, // for example on GPU. template <typename T, class Context> -void Scale(const int N, const float* alpha, const T* x, T* y, Context* context); +void Scale(const int N, const T* alpha, const T* x, T* y, + Context* context); template <typename T, class Context> -void Axpy(const int N, const float alpha, const T* x, T* y, Context* context); +void Axpy(const int N, const T alpha, const T* x, T* y, Context* context); // Different from the Axpy function above, if alpha is passed in // as a pointer, we will assume that it lives on the Context device, // for example on GPU. template <typename T, class Context> -void Axpy(const int N, const float* alpha, const T* x, T* y, Context* context); +void Axpy(const int N, const T* alpha, const T* x, T* y, + Context* context); template <typename T, class Context> -void Axpby( - const int N, - const float alpha, - const T* x, - const T b, - T* y, - Context* context); +void Axpby(const int N, const T alpha, const T* x, const T b, T* y, + Context* context); template <typename T, class Context, int order> void Im2colNd( diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc index e4340dfaf1..5cac0c8339 100644 --- a/caffe2/utils/math_cpu.cc +++ b/caffe2/utils/math_cpu.cc @@ -58,18 +58,9 @@ namespace math { // CblasTrans, respectively, for each of A and B. template <> void Gemm<float, CPUContext>( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C, - CPUContext* context, - TensorProto::DataType math_type) { + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const float alpha, const float* A, + const float* B, const float beta, float* C, CPUContext* context) { auto C_mat = EigenMatrixMap<float>(C, N, M); if (beta == 0) { C_mat.setZero(); @@ -187,8 +178,7 @@ void Gemv<float, CPUContext>( const float* x, const float beta, float* y, - CPUContext* context, - TensorProto::DataType math_type) { + CPUContext* context) { EigenVectorMap<float> y_vec(y, TransA == CblasNoTrans ? 
M : N); if (beta == 0) { // In Caffe2 we often do a lazy initialization, which may contain NaNs in @@ -215,22 +205,19 @@ void Gemv<float, CPUContext>( } } -#define CAFFE2_SPECIALIZED_SCALE(T) \ - template <> \ - void Scale<T, CPUContext>( \ - const int n, const float alpha, const T* x, T* y, CPUContext* context) { \ - EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha; \ - } \ - template <> \ - void Scale<T, CPUContext>( \ - const int n, \ - const float* alpha, \ - const T* x, \ - T* y, \ - CPUContext* context) { \ - EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha); \ +#define CAFFE2_SPECIALIZED_SCALE(T) \ + template <> \ + void Scale<T, CPUContext>( \ + const int n, const T alpha, const T* x, T* y, CPUContext* context) { \ + EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha; \ + } \ + template <> \ + void Scale<T, CPUContext>( \ + const int n, const T* alpha, const T* x, T* y, CPUContext* context) { \ + EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha); \ } CAFFE2_SPECIALIZED_SCALE(float) +CAFFE2_SPECIALIZED_SCALE(double) #undef CAFFE2_SPECIALIZED_SCALE #define CAFFE2_SPECIALIZED_DOT(T) \ @@ -241,6 +228,7 @@ void Dot<T, CPUContext>( \ *y = ConstEigenVectorMap<T>(a, N).dot(ConstEigenVectorMap<T>(b, N)); \ } CAFFE2_SPECIALIZED_DOT(float) +CAFFE2_SPECIALIZED_DOT(double) #undef CAFFE2_SPECIALIZED_DOT #define CAFFE2_SPECIALIZED_AXPY(T) \ @@ -255,6 +243,7 @@ CAFFE2_SPECIALIZED_DOT(float) EigenVectorMap<T>(Y, N) += ConstEigenVectorMap<T>(x, N) * (*alpha); \ } CAFFE2_SPECIALIZED_AXPY(float) +CAFFE2_SPECIALIZED_AXPY(double) #undef CAFFE2_SPECIALIZED_AXPY #define CAFFE2_SPECIALIZED_AXPBY(T) \ @@ -265,24 +254,16 @@ void Axpby<T, CPUContext>(const int N, const T alpha, const T* x, \ y_vec = y_vec * beta + ConstEigenVectorMap<T>(x, N) * alpha; \ } CAFFE2_SPECIALIZED_AXPBY(float) +CAFFE2_SPECIALIZED_AXPBY(double) #undef CAFFE2_SPECIALIZED_AXPBY #else // CAFFE2_USE_EIGEN_FOR_BLAS template <> void Gemm<float, CPUContext>( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C, - CPUContext* context, - TensorProto::DataType math_type) { + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const float alpha, const float* A, + const float* B, const float beta, float* C, CPUContext* context) { int lda = (TransA == CblasNoTrans) ? K : M; int ldb = (TransB == CblasNoTrans) ? 
N : K; cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, @@ -311,39 +292,29 @@ void GemmEx<float, CPUContext>( template <> void Gemv<float, CPUContext>( - const CBLAS_TRANSPOSE TransA, - const int M, - const int N, - const float alpha, - const float* A, - const float* x, - const float beta, - float* y, - CPUContext* context, - TensorProto::DataType math_type) { + const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha, + const float* A, const float* x, const float beta, float* y, + CPUContext* context) { cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1); } -#define CAFFE2_SPECIALIZED_SCALE(T, prefix) \ - template <> \ - void Scale<T, CPUContext>( \ - const int n, const float alpha, const T* x, T* y, CPUContext* context) { \ - if (y != x) \ - cblas_##prefix##copy(n, x, 1, y, 1); \ - cblas_##prefix##scal(n, static_cast<float>(alpha), y, 1); \ - } \ - template <> \ - void Scale<T, CPUContext>( \ - const int n, \ - const float* alpha, \ - const T* x, \ - T* y, \ - CPUContext* context) { \ - if (y != x) \ - cblas_##prefix##copy(n, x, 1, y, 1); \ - cblas_##prefix##scal(n, static_cast<float>(*alpha), y, 1); \ +#define CAFFE2_SPECIALIZED_SCALE(T, prefix) \ + template <> \ + void Scale<T, CPUContext>( \ + const int n, const T alpha, const T* x, T* y, CPUContext* context) { \ + if (y != x) \ + cblas_##prefix##copy(n, x, 1, y, 1); \ + cblas_##prefix##scal(n, alpha, y, 1); \ + } \ + template <> \ + void Scale<T, CPUContext>( \ + const int n, const T* alpha, const T* x, T* y, CPUContext* context) { \ + if (y != x) \ + cblas_##prefix##copy(n, x, 1, y, 1); \ + cblas_##prefix##scal(n, *alpha, y, 1); \ } CAFFE2_SPECIALIZED_SCALE(float, s) +CAFFE2_SPECIALIZED_SCALE(double, d) #undef CAFFE2_SPECIALIZED_SCALE #define CAFFE2_SPECIALIZED_DOT(T, prefix) \ @@ -354,6 +325,7 @@ void Dot<T, CPUContext>( \ *y = cblas_##prefix##dot(N, a, 1, b, 1); \ } CAFFE2_SPECIALIZED_DOT(float, s) +CAFFE2_SPECIALIZED_DOT(double, d) #undef CAFFE2_SPECIALIZED_DOT #define CAFFE2_SPECIALIZED_AXPY(T, prefix) \ @@ -368,6 +340,7 @@ CAFFE2_SPECIALIZED_DOT(float, s) cblas_##prefix##axpy(N, *alpha, x, 1, y, 1); \ } CAFFE2_SPECIALIZED_AXPY(float, s) +CAFFE2_SPECIALIZED_AXPY(double, d) #undef CAFFE2_SPECIALIZED_AXPY // cblas_[sd]axpby is not a standard blas function, and if MKL is not present, @@ -389,6 +362,7 @@ void Axpby<T, CPUContext>(const int N, const T alpha, const T* x, \ } #endif // CAFFE2_USE_MKL CAFFE2_SPECIALIZED_AXPBY(float, s) +CAFFE2_SPECIALIZED_AXPBY(double, d) #undef CAFFE2_SPECIALIZED_AXPBY #endif // CAFFE2_USE_EIGEN_FOR_BLAS @@ -462,8 +436,11 @@ void Funcname<T, CPUContext>(const int N, const T* x, T* y, \ EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(x, N).array().expr(); \ } DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, exp) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, exp) DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, log) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, log) DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, square) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, square) #undef DELEGATE_SIMPLE_UNARY_FUNCTION #define DELEGATE_POWX_FUNCTION(T) \ @@ -473,6 +450,7 @@ void Powx<T, CPUContext>( \ EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(a, N).array().pow(b); \ } DELEGATE_POWX_FUNCTION(float) +DELEGATE_POWX_FUNCTION(double) #undef DELEGATE_POWX_FUNCTION #endif // CAFFE2_USE_MKL @@ -498,6 +476,7 @@ EIGEN_SIMPLE_BINARY_FUNCTION(int64_t, Funcname, expr) #define DEFINE_SIMPLE_BINARY_FUNCTION(Funcname, expr) \ EIGEN_SIMPLE_BINARY_FUNCTION(float, Funcname, expr) 
\ +EIGEN_SIMPLE_BINARY_FUNCTION(double, Funcname, expr) \ EIGEN_SIMPLE_BINARY_FUNCTION(int32_t, Funcname, expr) \ EIGEN_SIMPLE_BINARY_FUNCTION(int64_t, Funcname, expr) @@ -567,6 +546,7 @@ CAFFE2_SPECIALIZED_COLWISEMAX(float) DELEGATE_BROADCAST_BINARY_FUNCTION(int32_t, name, op) \ DELEGATE_BROADCAST_BINARY_FUNCTION(int64_t, name, op) \ DELEGATE_BROADCAST_BINARY_FUNCTION(float, name, op) \ + DELEGATE_BROADCAST_BINARY_FUNCTION(double, name, op) DEFINE_BROADCAST_BINARY_FUNCTION(Add, +) DEFINE_BROADCAST_BINARY_FUNCTION(Sub, -) @@ -622,6 +602,7 @@ CAFFE2_SPECIALIZED_SET(uint16_t); #define CAFFE2_DEFINE_BINARY_OP(name, op) \ CAFFE2_INSTANTIATE_BINARY_OP(name, op, float) \ + CAFFE2_INSTANTIATE_BINARY_OP(name, op, double) \ CAFFE2_INSTANTIATE_BINARY_OP(name, op, int32_t) \ CAFFE2_INSTANTIATE_BINARY_OP(name, op, int64_t) @@ -663,6 +644,7 @@ void Not<bool, CPUContext>( } CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(float); +CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(double); #undef CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH template <> @@ -735,6 +717,7 @@ void RandGaussian<float, CPUContext>( } CAFFE2_SPECIALIZED_SUM(float); +CAFFE2_SPECIALIZED_SUM(double); CAFFE2_SPECIALIZED_SUM(int32_t); CAFFE2_SPECIALIZED_SUM(int64_t); diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu index 46c5bc02f1..cfbe91b5d0 100644 --- a/caffe2/utils/math_gpu.cu +++ b/caffe2/utils/math_gpu.cu @@ -5,9 +5,8 @@ #include <thrust/system/cuda/detail/par.h> #include <thrust/version.h> -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/conversions.h" #include "caffe2/utils/math.h" +#include "caffe2/core/context_gpu.h" #if THRUST_VERSION >= 100800 #define THRUST_SUPPORTS_PER_THREAD @@ -33,30 +32,33 @@ void Funcname<T, CUDAContext>( \ } DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf); +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Exp, exp); DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf); +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Log, log); __device__ float cuda_sqrf(const float x) { return x * x; } +__device__ double cuda_sqr(const double x) { return x * x; } DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, cuda_sqrf); +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sqr, cuda_sqr); #undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION -#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Funcname, expr) \ - __global__ void _Kernel_##T##_##Funcname( \ - const int N, const T* a, const T* b, T* y) { \ - CUDA_1D_KERNEL_LOOP(i, N) { \ - float r = convert::To<T, float>(a[i]) expr convert::To<T, float>(b[i]); \ - y[i] = convert::To<float, T>(r); \ - } \ - } \ - template <> \ - void Funcname<T, CUDAContext>( \ - const int N, const T* a, const T* b, T* y, CUDAContext* context) { \ - _Kernel_##T##_##Funcname<<< \ - CAFFE_GET_BLOCKS(N), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>(N, a, b, y); \ +#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Funcname, expr) \ + __global__ void _Kernel_##T##_##Funcname( \ + const int N, const T* a, const T* b, T* y) { \ + CUDA_1D_KERNEL_LOOP(i, N) { \ + y[i] = a[i] expr b[i]; \ + } \ + } \ + template <> \ + void Funcname<T, CUDAContext>( \ + const int N, const T* a, const T* b, T* y, CUDAContext* context) { \ + _Kernel_##T##_##Funcname<<< \ + CAFFE_GET_BLOCKS(N), \ + CAFFE_CUDA_NUM_THREADS, \ + 0, \ + context->cuda_stream()>>>(N, a, b, y); \ } DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, +); @@ -64,27 +66,13 @@ DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Sub, -); DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, *); DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, /); 
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Add, +); -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Sub, -); -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Mul, *); -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Div, /); - // Caffe2 gemm provides a simpler interface to the gemm functions, with the // limitation that the data has to be contiguous in memory. template <> void Gemm<float, CUDAContext>( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C, - CUDAContext* context, - TensorProto::DataType math_type) { + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const float alpha, const float* A, + const float* B, const float beta, float* C, CUDAContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (TransA == CblasNoTrans) ? K : M; @@ -111,91 +99,11 @@ void Gemm<float, CUDAContext>( } template <> -void Gemm<float16, CUDAContext>( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const float alpha, - const float16* A, - const float16* B, - const float beta, - float16* C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (TransA == CblasNoTrans) ? K : M; - int ldb = (TransB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_CHECK(cublasSgemmEx( - context->cublas_handle(), - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &beta, - C, - CUDA_R_16F, - N)); - - } else if (math_type == TensorProto_DataType_FLOAT16) { - // convert alpha, beta from caffe2::float16 -> __half - __half alpha_fp16; - alpha_fp16.x = convert::To<float, float16>(alpha).x; - __half beta_fp16; - beta_fp16.x = convert::To<float, float16>(beta).x; - // call cublasHgemm - CUBLAS_CHECK(cublasHgemm( - context->cublas_handle(), - cuTransB, - cuTransA, - N, - M, - K, - &alpha_fp16, - (const __half*)B, - ldb, - (const __half*)A, - lda, - &beta_fp16, - (__half*)C, - N)); - } else { - // fail - CAFFE_THROW("Unsupported math type"); - } -} - -template <> void GemmEx<float, CUDAContext>( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int lda, - const float* B, - const int ldb, - const float beta, - float* C, - const int ldc, - CUDAContext* context) { + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const float alpha, const float* A, + const int lda, const float* B, const int ldb, const float beta, float* C, + const int ldc, CUDAContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. 
cublasOperation_t cuTransA = @@ -221,19 +129,40 @@ void GemmEx<float, CUDAContext>( template <> void Gemv<float, CUDAContext>( + const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha, + const float* A, const float* x, const float beta, float* y, + CUDAContext* context) { + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_ENFORCE(cublasSgemv( + context->cublas_handle(), + cuTransA, + N, + M, + &alpha, + A, + N, + x, + 1, + &beta, + y, + 1)); +} + +template <> +void Gemv<double, CUDAContext>( const CBLAS_TRANSPOSE TransA, const int M, const int N, - const float alpha, - const float* A, - const float* x, - const float beta, - float* y, - CUDAContext* context, - TensorProto::DataType math_type) { + const double alpha, + const double* A, + const double* x, + const double beta, + double* y, + CUDAContext* context) { cublasOperation_t cuTransA = (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBLAS_ENFORCE(cublasSgemv( + CUBLAS_ENFORCE(cublasDgemv( context->cublas_handle(), cuTransA, N, @@ -287,73 +216,6 @@ CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float); CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(double); #undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH -template <> -void Gemv<float16, CUDAContext>( - const CBLAS_TRANSPOSE TransA, - const int M, - const int N, - const float alpha, - const float16* A, - const float16* x, - const float beta, - float16* y, - CUDAContext* context, - TensorProto::DataType math_type) { - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - - // sort out what we need to call cublasSgemmEx / cublasHgemm - int m = (cuTransA == CUBLAS_OP_N) ? N : M; - int k = (cuTransA == CUBLAS_OP_N) ? M : N; - int LDA = (cuTransA == CUBLAS_OP_N) ? 
m : k; - int LDC = m; - - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_CHECK(cublasSgemmEx( - context->cublas_handle(), - cuTransA, - CUBLAS_OP_N, - m, - 1, - k, - &alpha, - A, - CUDA_R_16F, - LDA, - x, - CUDA_R_16F, - k, - &beta, - y, - CUDA_R_16F, - LDC)); - } else if (math_type == TensorProto_DataType_FLOAT16) { - __half alpha_fp16; - alpha_fp16.x = convert::To<float, float16>(alpha).x; - __half beta_fp16; - beta_fp16.x = convert::To<float, float16>(beta).x; - - CUBLAS_CHECK(cublasHgemm( - context->cublas_handle(), - cuTransA, - CUBLAS_OP_N, - m, - 1, - k, - &alpha_fp16, - (const __half*)A, - LDA, - (const __half*)x, - k, - &beta_fp16, - (__half*)y, - LDC)); - } else { - // fail - CAFFE_THROW("Unsupported math type"); - } -} - namespace { template <typename T> __global__ void SetKernel(const int N, const T alpha, T* Y) { @@ -376,7 +238,6 @@ CAFFE2_SPECIALIZED_CUDA_SET(double); CAFFE2_SPECIALIZED_CUDA_SET(bool); CAFFE2_SPECIALIZED_CUDA_SET(int8_t); CAFFE2_SPECIALIZED_CUDA_SET(int16_t); -CAFFE2_SPECIALIZED_CUDA_SET(float16); CAFFE2_SPECIALIZED_CUDA_SET(int); CAFFE2_SPECIALIZED_CUDA_SET(int64_t); CAFFE2_SPECIALIZED_CUDA_SET(char); @@ -386,11 +247,11 @@ CAFFE2_SPECIALIZED_CUDA_SET(uint16_t); namespace { template <typename T> -__global__ void -UniformShift(const int N, const float min, const float max, T* x) { - float scale = max - min; +__global__ void UniformShift(const int N, const T min, const T max, + T* x) { + T scale = max - min; CUDA_1D_KERNEL_LOOP(i, N) { - x[i] = convert::To<float, T>(convert::To<T, float>(x[i]) * scale + min); + x[i] = x[i] * scale + min; } } @@ -475,6 +336,7 @@ void RandGaussian<double, CUDAContext>( context->curand_generator(), r, even_n, mean, std)); } + template<> void Dot<float, CUDAContext>( const int n, const float* a, const float* b, float* y, @@ -484,28 +346,13 @@ void Dot<float, CUDAContext>( context->Copy<float, CPUContext, CUDAContext>(1, &result, y); } -template <> -void Dot<float16, CUDAContext>( - const int n, - const float16* a, - const float16* b, - float16* y, +template<> +void Dot<double, CUDAContext>( + const int n, const double* a, const double* b, double* y, CUDAContext* context) { - float16 result; - // execute with 32-bit math - CUBLAS_CHECK(cublasDotEx( - context->cublas_handle(), - n, - a, - CUDA_R_16F, - 1, - b, - CUDA_R_16F, - 1, - &result, - CUDA_R_16F, - CUDA_R_32F)); - context->Copy<float16, CPUContext, CUDAContext>(1, &result, y); + double result; + CUBLAS_ENFORCE(cublasDdot(context->cublas_handle(), n, a, 1, b, 1, y)); + context->Copy<double, CPUContext, CUDAContext>(1, &result, y); } // A previous version of caffe2 used Thrust but it turns out that thrust @@ -516,7 +363,7 @@ void Dot<float16, CUDAContext>( template <typename T> __global__ void SumKernel(const int N, const T* X, T* Y, bool square) { const int idx = threadIdx.x; - __shared__ float reduction_buffer[SUM_KERNEL_NTHREADS]; + __shared__ T reduction_buffer[SUM_KERNEL_NTHREADS]; reduction_buffer[idx] = 0; @@ -524,12 +371,11 @@ __global__ void SumKernel(const int N, const T* X, T* Y, bool square) { // N -> 128 if (!square) { for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) { - reduction_buffer[idx] += convert::To<T, float>(X[i]); + reduction_buffer[idx] += X[i]; } } else { for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) { - float Xi = convert::To<T, float>(X[i]); - reduction_buffer[idx] += Xi * Xi; + reduction_buffer[idx] += X[i] * X[i]; } } __syncthreads(); @@ -547,7 +393,7 @@ __global__ void SumKernel(const int N, const T* X, T* Y, bool square) { for (int 
i = 0; i < 32; ++i) { tmp += reduction_buffer[i]; } - *Y = convert::To<float, T>(tmp); + *Y = tmp; } } @@ -560,7 +406,7 @@ __global__ void SumKernel(const int N, const T* X, T* Y, bool square) { } CAFFE2_MATH_SUM_FUNC(float) -CAFFE2_MATH_SUM_FUNC(float16) +CAFFE2_MATH_SUM_FUNC(double) #undef CAFFE2_MATH_SUM_FUNC #define CAFFE2_MATH_SUMSQR_FUNC(T) \ @@ -592,33 +438,18 @@ void Select<float, CUDAContext>( 0, context->cuda_stream()>>>(N, D, x, idx, y); } -template <> -void Select<float16, CUDAContext>( - const int N, - const int D, - const float16* x, - const int* idx, - float16* y, - CUDAContext* context) { - SelectKernel<float16><<< - CAFFE_GET_BLOCKS(N), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(N, D, x, idx, y); -} - namespace { template <typename T> -__global__ void ScaleKernel(const int n, const float alpha, const T* x, T* y) { +__global__ void ScaleKernel( + const int n, const T alpha, const T* x, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { - // y[i] = convert::To<float,T>(convert::To<T, float>(x[i]) * alpha); - y[i] = convert::Get<T>(convert::Get<float>(x[i]) * alpha); + y[i] = x[i] * alpha; } } template <typename T> -__global__ void -ScaleKernelDeviceAlpha(const int n, const float* alpha, const T* x, T* y) { +__global__ void ScaleKernelDeviceAlpha( + const int n, const T* alpha, const T* x, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] * (*alpha); } @@ -630,20 +461,6 @@ __global__ void PowKernel(const int n, const T* x, const T exponent, T* y) { y[i] = powf(x[i], exponent); } } - -// fp16 specialization -template <> -__global__ void ScaleKernelDeviceAlpha( - const int n, - const float* alpha, - const float16* x, - float16* y) { - CUDA_1D_KERNEL_LOOP(i, n) { - y[i] = convert::To<float, float16>( - convert::To<float16, float>(x[i]) * (*alpha)); - } -} - } // namespace template <> @@ -672,17 +489,12 @@ void Scale<float, CUDAContext>( } template <> -void Scale<float16, CUDAContext>( - const int n, - const float alpha, - const float16* x, - float16* y, +void Scale<double, CUDAContext>( + const int n, const double alpha, const double *x, double* y, CUDAContext* context) { - ScaleKernel<float16><<< - CAFFE_GET_BLOCKS(n), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(n, alpha, x, y); + ScaleKernel<double><<< + CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( + n, alpha, x, y); } template <> @@ -695,17 +507,11 @@ void Scale<float, CUDAContext>( } template <> -void Scale<float16, CUDAContext>( - const int n, - const float* alpha, - const float16* x, - float16* y, +void Scale<double, CUDAContext>( + const int n, const double* alpha, const double *x, double* y, CUDAContext* context) { - ScaleKernelDeviceAlpha<float16><<< - CAFFE_GET_BLOCKS(n), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(n, alpha, x, y); + ScaleKernelDeviceAlpha<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS, + 0, context->cuda_stream()>>>(n, alpha, x, y); } template <> @@ -721,42 +527,18 @@ void Axpy<float, CUDAContext>( template <> void Axpy<double, CUDAContext>( const int N, - const float alpha, + const double alpha, const double* X, double* Y, CUDAContext* context) { - double alpha_d{alpha}; - CUBLAS_ENFORCE( - cublasDaxpy(context->cublas_handle(), N, &alpha_d, X, 1, Y, 1)); -} - -template <> -void Axpy<float16, CUDAContext>( - const int N, - const float alpha, - const float16* X, - float16* Y, - CUDAContext* context) { - CUBLAS_CHECK(cublasAxpyEx( - context->cublas_handle(), - N, - &alpha, - CUDA_R_16F, - X, - CUDA_R_16F, - 1, - Y, - CUDA_R_16F, - 1, - 
CUDA_R_32F)); + CUBLAS_ENFORCE(cublasDaxpy(context->cublas_handle(), N, &alpha, X, 1, Y, 1)); } namespace { template <typename T> -__global__ void AxpyKernel(const int n, const float* a, const T* x, T* y) { +__global__ void AxpyKernel(const int n, const T* a, const T* x, T* y) { CUDA_1D_KERNEL_LOOP(index, n) { - y[index] = convert::Get<T>( - convert::Get<float>(x[index]) * (*a) + convert::Get<float>(y[index])); + y[index] += x[index] * (*a); } } } // namespace @@ -770,19 +552,14 @@ void Axpy<float, CUDAContext>( } template <> -void Axpy<float16, CUDAContext>( - const int n, - const float* alpha, - const float16* X, - float16* Y, - CUDAContext* context) { - AxpyKernel<float16><<< - CAFFE_GET_BLOCKS(n), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(n, alpha, X, Y); +void Axpy<double, CUDAContext>( + const int n, const double* alpha, const double* X, + double* Y, CUDAContext* context) { + AxpyKernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS, + 0, context->cuda_stream()>>>(n, alpha, X, Y); } + namespace { template <typename T> __global__ void AxpbyKernel(const int n, const T a, const T* x, @@ -801,6 +578,14 @@ void Axpby<float, CUDAContext>( 0, context->cuda_stream()>>>(n, a, x, b, y); } +template <> +void Axpby<double, CUDAContext>( + const int n, const double a, const double* x, const double b, double* y, + CUDAContext* context) { + AxpbyKernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS, + 0, context->cuda_stream()>>>(n, a, x, b, y); +} + namespace { template <typename T> diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc index b1f930bec0..2ceeddd355 100644 --- a/caffe2/utils/math_gpu_test.cc +++ b/caffe2/utils/math_gpu_test.cc @@ -67,4 +67,61 @@ TEST(MathUtilGPUTest, testAddStripedBatch) { } } +#define TEST_GEMV_WITH_TYPE(field_name) \ + TEST(MathUtilGPUTest, testGemv_##field_name) { \ + if (!HasCudaGPU()) \ + return; \ + Workspace ws; \ + DeviceOption option; \ + option.set_device_type(CUDA); \ + CUDAContext context(option); \ + Blob* blobx = ws.CreateBlob("X"); \ + Blob* bloby = ws.CreateBlob("Y"); \ + Blob* blobz = ws.CreateBlob("Z"); \ + Blob* bloby_host = ws.CreateBlob("Y_host"); \ + \ + vector<int> shapex{64, 128}; \ + vector<int> shapey{64}; \ + vector<int> shapez{128}; \ + \ + auto* tensorx = blobx->GetMutable<Tensor<CUDAContext>>(); \ + tensorx->Resize(shapex); \ + math::Set<field_name, CUDAContext>( \ + 64 * 128, \ + (field_name)1.0, \ + tensorx->mutable_data<field_name>(), \ + &context); \ + \ + auto* tensory = bloby->GetMutable<Tensor<CUDAContext>>(); \ + tensory->Resize(shapey); \ + math::Set<field_name, CUDAContext>( \ + 64, (field_name)1.0, tensory->mutable_data<field_name>(), &context); \ + \ + auto* tensorz = blobz->GetMutable<Tensor<CUDAContext>>(); \ + tensorz->Resize(shapez); \ + \ + math::Gemv<field_name, CUDAContext>( \ + CblasTrans, \ + 64, \ + 128, \ + 1.0, \ + tensorx->template data<field_name>(), \ + tensory->mutable_data<field_name>(), \ + 0.0, \ + tensorz->template mutable_data<field_name>(), \ + &context); \ + context.FinishDeviceComputation(); \ + \ + auto* tensory_host = bloby_host->GetMutable<Tensor<CPUContext>>(); \ + tensory_host->CopyFrom<CUDAContext, CUDAContext>(*tensorz, &context); \ + context.FinishDeviceComputation(); \ + \ + for (int i = 0; i < 128; i++) { \ + EXPECT_EQ(tensory_host->data<field_name>()[i], 64.0); \ + } \ + } + +TEST_GEMV_WITH_TYPE(float); +TEST_GEMV_WITH_TYPE(double); + } // namespace caffe2 |
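For reference, the comment the revert rewrites from `// W * x` to `// X * W^T` describes the actual forward computation: the op multiplies the M x K input by the transposed N x K weight matrix, then adds the bias through a second rank-1 GEMM against the all-ones `bias_multiplier_`. A self-contained illustration (plain C++, not Caffe2 code) of what those two GEMM calls compute:

```cpp
#include <cstdio>
#include <vector>

// Standalone sketch of FullyConnectedOp's math: Y = X * W^T + 1 * b, where
// X is M x K, W is N x K (hence the CblasTrans on W in the op), b has
// length N, and Y is M x N.
int main() {
  const int M = 2, K = 3, N = 2;              // batch size, input dim, output dim
  std::vector<float> X = {1, 2, 3, 4, 5, 6};  // M x K, row-major
  std::vector<float> W = {1, 0, 1, 0, 1, 0};  // N x K, row-major
  std::vector<float> b = {0.5f, -0.5f};       // bias, length N
  std::vector<float> Y(M * N, 0.0f);

  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = b[n];  // second GEMM: bias_multiplier (all ones) times b
      for (int k = 0; k < K; ++k) {
        acc += X[m * K + k] * W[n * K + k];  // first GEMM: X * W^T
      }
      Y[m * N + n] = acc;
    }
  }

  for (int m = 0; m < M; ++m) {
    std::printf("Y[%d] = {%g, %g}\n", m, Y[m * N + 0], Y[m * N + 1]);
  }
  return 0;
}
```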