summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorAapo Kyrola <akyrola@fb.com>2017-04-17 21:23:45 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2017-04-17 21:31:20 -0700
commit9ab077dc9d0bbe651348a498dd5472dc4d51f0af (patch)
tree53060eb0ae8e9e03d765cf80d63e1f37e08ef1c8 /caffe2
parent391fd141150d250ee60027fad4acf8aecefbae51 (diff)
downloadpytorch-9ab077dc9d0bbe651348a498dd5472dc4d51f0af.tar.gz
pytorch-9ab077dc9d0bbe651348a498dd5472dc4d51f0af.tar.bz2
pytorch-9ab077dc9d0bbe651348a498dd5472dc4d51f0af.zip
Revert D4871248: [caffe2][PR] fp16 support for FullyConnected op
Summary: This reverts commit 6a991c2c993dcf0b1e18aa3f2ffbe19e693dbadd Differential Revision: D4871248 fbshipit-source-id: b6d812d09a00c83e363432e84742c503abfed65b
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/contrib/nervana/nervana_fc_op_gpu.cc9
-rw-r--r--caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc2
-rw-r--r--caffe2/contrib/nervana/nervana_math_gpu.cc16
-rw-r--r--caffe2/operators/elementwise_op.cu58
-rw-r--r--caffe2/operators/fully_connected_op.cc4
-rw-r--r--caffe2/operators/fully_connected_op.h111
-rw-r--r--caffe2/operators/fully_connected_op_gpu.cc57
-rw-r--r--caffe2/operators/sparse_to_dense_op.h1
-rw-r--r--caffe2/operators/square_root_divide_op.h2
-rw-r--r--caffe2/operators/utility_ops.cc13
-rw-r--r--caffe2/operators/utility_ops.h37
-rw-r--r--caffe2/operators/utility_ops_gpu.cc28
-rw-r--r--caffe2/utils/conversions.h182
-rw-r--r--caffe2/utils/math-detail.h36
-rw-r--r--caffe2/utils/math.h53
-rw-r--r--caffe2/utils/math_cpu.cc119
-rw-r--r--caffe2/utils/math_gpu.cu423
-rw-r--r--caffe2/utils/math_gpu_test.cc57
18 files changed, 312 insertions, 896 deletions
diff --git a/caffe2/contrib/nervana/nervana_fc_op_gpu.cc b/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
index 8d33a7c20a..b2328500c7 100644
--- a/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
+++ b/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
@@ -5,11 +5,8 @@
namespace caffe2 {
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
- FC,
- NERVANA,
- FullyConnectedOp<CUDAContext, NervanaEngine>);
+ FC, NERVANA, FullyConnectedOp<float, CUDAContext, NervanaEngine>);
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
- FCGradient,
- NERVANA,
- FullyConnectedGradientOp<CUDAContext, NervanaEngine>);
+ FCGradient, NERVANA,
+ FullyConnectedGradientOp<float, CUDAContext, NervanaEngine>);
} // namespace caffe2
diff --git a/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc b/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
index 3eb0fc3ace..a3ae3bb45a 100644
--- a/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
+++ b/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
@@ -49,7 +49,7 @@ TEST(NervanaFullyConnectedTest, Test) {
AddConstInput(std::vector<int>{6, 10}, 1., "W", &ws);
AddConstInput(std::vector<int>{6}, 0.1, "B", &ws);
unique_ptr<OperatorBase> op(
- new FullyConnectedOp<CUDAContext, NervanaEngine>(def, &ws));
+ new FullyConnectedOp<float, CUDAContext, NervanaEngine>(def, &ws));
EXPECT_NE(nullptr, op.get());
EXPECT_TRUE(op->Run());
Blob* Yblob = ws.GetBlob("Y");
diff --git a/caffe2/contrib/nervana/nervana_math_gpu.cc b/caffe2/contrib/nervana/nervana_math_gpu.cc
index 09c70e4343..f3010b9b95 100644
--- a/caffe2/contrib/nervana/nervana_math_gpu.cc
+++ b/caffe2/contrib/nervana/nervana_math_gpu.cc
@@ -11,18 +11,10 @@ namespace math {
// limitation that the data has to be contiguous in memory.
template <>
void Gemm<float, CUDAContext, NervanaEngine>(
- const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB,
- const int M,
- const int N,
- const int K,
- const float alpha,
- const float* A,
- const float* B,
- const float beta,
- float* C,
- CUDAContext* context,
- TensorProto::DataType math_type) {
+ const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
+ const int M, const int N, const int K, const float alpha, const float* A,
+ const float* B, const float beta, float* C, CUDAContext* context) {
+
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (TransA == CblasNoTrans) ? K : M;
diff --git a/caffe2/operators/elementwise_op.cu b/caffe2/operators/elementwise_op.cu
index c75bd23867..016a555310 100644
--- a/caffe2/operators/elementwise_op.cu
+++ b/caffe2/operators/elementwise_op.cu
@@ -5,7 +5,6 @@
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/elementwise_op.h"
-#include "caffe2/utils/conversions.h"
namespace caffe2 {
@@ -63,6 +62,9 @@ REGISTER_CUDA_OPERATOR( \
name, BinaryElementwiseOp< \
input_type, CUDAContext, Cuda##name##Functor, output_type>)
+#define CUDA_ADD(x, y) ((x) + (y))
+CUDA_FUNCTOR(Add, CUDA_ADD, NumericTypes, SameTypeAsInput);
+#undef CUDA_ADD
#define CUDA_SUB(x, y) ((x) - (y))
CUDA_FUNCTOR(Sub, CUDA_SUB, NumericTypes, SameTypeAsInput);
#undef CUDA_SUB
@@ -262,58 +264,4 @@ bool SumReduceLikeOp<CUDAContext>::DoRunWithType() {
REGISTER_CUDA_OPERATOR(SumReduceLike, SumReduceLikeOp<CUDAContext>);
-namespace {
-
-template <typename T, typename M>
-__global__ void binary_add_kernel(const int N, const T* a, const T* b, T* r){
- CUDA_1D_KERNEL_LOOP(idx, N){
- r[idx] = convert::To<M, T>(
- convert::To<T, M>(a[idx]) + convert::To<T, M>(b[idx]));
-}
-};
-}
-; // namespace
-
-// Actual Add operator, because the above macros are read-only.
-class CUDAAddOp final : public Operator<CUDAContext> {
- public:
- CUDAAddOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CUDAContext>(operator_def, ws){};
- ~CUDAAddOp() {}
-
- template <typename T, typename M>
- bool DoRunWithType() {
- auto& X0 = Input(0);
- auto& X1 = Input(1);
- auto* output = Output(0);
-
- output->ResizeLike(X0);
-
- binary_add_kernel<T, M><<<
- CAFFE_GET_BLOCKS(X0.size()),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- X0.size(),
- X0.template data<T>(),
- X1.template data<T>(),
- output->template mutable_data<T>());
- return true;
- }
-
- bool RunOnDevice() override {
- if (Input(0).IsType<float>()) {
- return DoRunWithType<float, float>();
- } else if (Input(0).IsType<float16>()) {
- return DoRunWithType<float16, float>();
- } else {
- return false;
- }
- }
-};
-
-namespace {
-REGISTER_CUDA_OPERATOR(Add, CUDAAddOp);
-} // namespace
-
} // namespace caffe2
diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc
index c00f199d5b..7a0e0bc5a5 100644
--- a/caffe2/operators/fully_connected_op.cc
+++ b/caffe2/operators/fully_connected_op.cc
@@ -3,8 +3,8 @@
namespace caffe2 {
namespace {
-REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<CPUContext>);
-REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<CPUContext>);
+REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(FC)
.NumInputs(3)
diff --git a/caffe2/operators/fully_connected_op.h b/caffe2/operators/fully_connected_op.h
index 45adc0a4d2..6e24c91103 100644
--- a/caffe2/operators/fully_connected_op.h
+++ b/caffe2/operators/fully_connected_op.h
@@ -3,13 +3,12 @@
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
-#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
-template <class Context, class Engine = DefaultEngine>
+template <typename T, class Context, class Engine = DefaultEngine>
class FullyConnectedOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -18,13 +17,7 @@ class FullyConnectedOp final : public Operator<Context> {
axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
~FullyConnectedOp() {}
- template <
- typename T_X,
- typename T_W,
- typename T_B,
- typename T_Y,
- typename MATH>
- bool DoRunWithType() {
+ bool RunOnDevice() override {
const auto& X = Input(0);
const auto& W = Input(1);
const auto& b = Input(2);
@@ -70,53 +63,44 @@ class FullyConnectedOp final : public Operator<Context> {
Y->Resize(Y_shape_cache_);
CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());
- // W * x
- math::Gemm<T_X, Context, Engine>(
+ // X * W^T
+ math::Gemm<T, Context, Engine>(
CblasNoTrans,
CblasTrans,
M,
N,
K,
1,
- X.template data<T_X>(),
- W.template data<T_W>(),
+ X.template data<T>(),
+ W.template data<T>(),
0,
- Y->template mutable_data<T_Y>(),
+ Y->template mutable_data<T>(),
&context_);
// Add bias term
if (bias_multiplier_.size() != M) {
// If the helper bias multiplier is not M, reshape and fill it with one.
bias_multiplier_.Resize(M);
- math::Set<T_B, Context>(
+ math::Set<T, Context>(
M,
- convert::To<float, T_B>(1),
- bias_multiplier_.template mutable_data<T_B>(),
+ static_cast<T>(1),
+ bias_multiplier_.template mutable_data<T>(),
&context_);
}
- math::Gemm<T_B, Context, Engine>(
+ math::Gemm<T, Context, Engine>(
CblasNoTrans,
CblasNoTrans,
M,
N,
1,
1,
- bias_multiplier_.template data<T_B>(),
- b.template data<T_B>(),
+ bias_multiplier_.template data<T>(),
+ b.template data<T>(),
1,
- Y->template mutable_data<T_Y>(),
+ Y->template mutable_data<T>(),
&context_);
return true;
}
- bool RunOnDevice() override {
- return DoRunWithType<
- float, // X
- float, // W
- float, // B
- float, // Y
- float>(); // Math
- }
-
protected:
size_t axis_{1};
// A local vector to cache the output shape so we don't need to recreate
@@ -125,7 +109,7 @@ class FullyConnectedOp final : public Operator<Context> {
Tensor<Context> bias_multiplier_;
};
-template <class Context, class Engine = DefaultEngine>
+template <typename T, class Context, class Engine = DefaultEngine>
class FullyConnectedGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -134,16 +118,7 @@ class FullyConnectedGradientOp : public Operator<Context> {
axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
~FullyConnectedGradientOp() {}
- template <
- typename T_X,
- typename T_W,
- typename T_DY,
- typename T_B,
- typename T_DX,
- typename T_DW,
- typename T_DB,
- typename MATH>
- bool DoRunWithType() {
+ bool RunOnDevice() override {
const auto& X = Input(0);
const auto& W = Input(1);
const auto& dY = Input(2);
@@ -162,72 +137,60 @@ class FullyConnectedGradientOp : public Operator<Context> {
db->Resize(N);
// Compute dW
- math::Gemm<T_DY, Context, Engine>(
+ math::Gemm<T, Context, Engine>(
CblasTrans,
CblasNoTrans,
N,
K,
M,
- convert::To<float, MATH>(1),
- dY.template data<T_DY>(),
- X.template data<T_X>(),
- convert::To<float, MATH>(0),
- dW->template mutable_data<T_DW>(),
+ 1,
+ dY.template data<T>(),
+ X.template data<T>(),
+ 0,
+ dW->template mutable_data<T>(),
&context_);
if (bias_multiplier_.size() != M) {
// If the helper bias multiplier is not M, reshape and fill it
// with one.
bias_multiplier_.Resize(M);
- math::Set<T_B, Context>(
+ math::Set<T, Context>(
M,
- convert::To<float, T_B>(1),
- bias_multiplier_.template mutable_data<T_B>(),
+ static_cast<T>(1),
+ bias_multiplier_.template mutable_data<T>(),
&context_);
}
// Compute dB
- math::Gemv<T_DY, Context>(
+ math::Gemv<T, Context>(
CblasTrans,
M,
N,
- convert::To<float, MATH>(1),
- dY.template data<T_DY>(),
- bias_multiplier_.template data<T_B>(),
- convert::To<float, MATH>(0),
- db->template mutable_data<T_DB>(),
+ 1,
+ dY.template data<T>(),
+ bias_multiplier_.template data<T>(),
+ 0,
+ db->template mutable_data<T>(),
&context_);
// Compute dX
if (OutputSize() == 3) {
auto* dX = Output(2);
dX->ResizeLike(X);
- math::Gemm<T_DX, Context, Engine>(
+ math::Gemm<T, Context, Engine>(
CblasNoTrans,
CblasNoTrans,
M,
K,
N,
- convert::To<float, MATH>(1),
- dY.template data<T_DY>(),
- W.template data<T_W>(),
- convert::To<float, MATH>(0),
- dX->template mutable_data<T_DX>(),
+ 1,
+ dY.template data<T>(),
+ W.template data<T>(),
+ 0,
+ dX->template mutable_data<T>(),
&context_);
}
return true;
}
- bool RunOnDevice() override {
- return DoRunWithType<
- float, // X
- float, // W
- float, // dY
- float, // B
- float, // dX
- float, // dW
- float, // dB
- float>(); // Math
- }
-
protected:
size_t axis_{1};
Tensor<Context> bias_multiplier_;
diff --git a/caffe2/operators/fully_connected_op_gpu.cc b/caffe2/operators/fully_connected_op_gpu.cc
index 07431862f8..8ee67acd0b 100644
--- a/caffe2/operators/fully_connected_op_gpu.cc
+++ b/caffe2/operators/fully_connected_op_gpu.cc
@@ -2,60 +2,9 @@
#include "caffe2/operators/fully_connected_op.h"
namespace caffe2 {
-
-template <>
-bool FullyConnectedOp<CUDAContext>::RunOnDevice() {
- if (Input(0).IsType<float>()) {
- return DoRunWithType<
- float, // X
- float, // W
- float, // B
- float, // Y
- float>(); // Math
- } else if (Input(0).IsType<float16>()) {
- return DoRunWithType<
- float16, // X
- float16, // W
- float16, // B
- float16, // Y
- float>(); // Math
- } else {
- CAFFE_THROW("Unsupported type");
- }
- return false;
-}
-
-template <>
-bool FullyConnectedGradientOp<CUDAContext>::RunOnDevice() {
- if (Input(0).IsType<float>()) {
- return DoRunWithType<
- float, // X
- float, // W
- float, // dY
- float, // B
- float, // dX
- float, // dW
- float, // dB
- float>(); // Math
- } else if (Input(0).IsType<float16>()) {
- return DoRunWithType<
- float16, // X
- float16, // W
- float16, // dY
- float16, // B
- float16, // dX
- float16, // dW
- float16, // dB
- float>(); // Math
- } else {
- CAFFE_THROW("Unsupported type");
- }
- return false;
-}
-
namespace {
-
-REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<CUDAContext>);
-REGISTER_CUDA_OPERATOR(FCGradient, FullyConnectedGradientOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(FCGradient,
+ FullyConnectedGradientOp<float, CUDAContext>);
} // namespace
} // namespace caffe2
diff --git a/caffe2/operators/sparse_to_dense_op.h b/caffe2/operators/sparse_to_dense_op.h
index 439d96ce65..d48b61755f 100644
--- a/caffe2/operators/sparse_to_dense_op.h
+++ b/caffe2/operators/sparse_to_dense_op.h
@@ -50,6 +50,7 @@ class SparseToDenseOp final : public Operator<Context> {
return DispatchHelper<
TensorTypes2<
float,
+ double,
int32_t,
int64_t,
GenericTensorImplementation>,
diff --git a/caffe2/operators/square_root_divide_op.h b/caffe2/operators/square_root_divide_op.h
index df018bf0d6..644c2bd96f 100644
--- a/caffe2/operators/square_root_divide_op.h
+++ b/caffe2/operators/square_root_divide_op.h
@@ -17,7 +17,7 @@ class SquareRootDivideOp final : public Operator<Context> {
: Operator<Context>(operator_def, ws) {}
bool RunOnDevice() override {
- return DispatchHelper<TensorTypes<float>>::call(this, Input(DATA));
+ return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA));
}
private:
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 99f0f20fa1..771da9c2e9 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -3,12 +3,6 @@
#include <cmath>
namespace caffe2 {
-
-template <>
-bool WeightedSumOp<CPUContext>::RunOnDevice() {
- return DoRunWithType<float>();
-}
-
namespace {
REGISTER_CPU_OPERATOR(WallClockTime, WallClockTimeOp<CPUContext>);
@@ -18,9 +12,10 @@ REGISTER_CPU_OPERATOR(FlattenToVec, FlattenToVecOp<CPUContext>);
REGISTER_CPU_OPERATOR(Alias, AliasOp<CPUContext>);
REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp<CPUContext>);
-REGISTER_CPU_OPERATOR(Sum, SumOp<CPUContext>);
-REGISTER_CPU_OPERATOR(SumInt, SumOp<CPUContext>);
-REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<CPUContext>);
+REGISTER_CPU_OPERATOR(Sum, SumOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(SumInt, SumOp<int, CPUContext>);
+
+REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
ScatterWeightedSum,
ScatterWeightedSumOp<float, CPUContext>);
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 85f08d2b99..95572317ef 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -250,14 +250,13 @@ class ResizeLikeOp : public Operator<Context> {
}
};
-template <class Context>
+template <typename T, class Context>
class SumOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(SumOp);
- template <typename T, typename M>
- bool DoRunWithType() {
+ bool RunOnDevice() override {
auto& input0 = Input(0);
auto* output = Output(0);
if (InputSize() == 1) {
@@ -298,16 +297,6 @@ class SumOp : public Operator<Context> {
}
return true;
}
-
- bool RunOnDevice() override {
- if (Input(0).template IsType<float>()) {
- return DoRunWithType<float, float>();
- } else if (Input(0).template IsType<int>()) {
- return DoRunWithType<int, int>();
- } else {
- return false;
- }
- }
};
// WeightedSumOp computes the weighted sum of several tensors. The input should
@@ -315,14 +304,13 @@ class SumOp : public Operator<Context> {
// shape, and weight_i are size 1 tensors that specifies the weight of each
// vector. Note that if one wants to do in-place computation, it could only be
// done with X_0 also as the output, but not other X_i.
-template <class Context>
+template <typename T, class Context>
class WeightedSumOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(WeightedSumOp);
- template <typename DstType>
- bool DoRunWithType() {
+ bool RunOnDevice() override {
DCHECK_EQ(InputSize() % 2, 0);
auto& X0 = Input(0);
auto& weight0 = Input(1);
@@ -331,11 +319,11 @@ class WeightedSumOp : public Operator<Context> {
int size = X0.size();
auto* output = Output(0);
output->ResizeLike(X0);
- math::Scale<DstType, Context>(
+ math::Scale<T, Context>(
size,
- weight0.template data<float>(),
- X0.template data<DstType>(),
- output->template mutable_data<DstType>(),
+ weight0.template data<T>(),
+ X0.template data<T>(),
+ output->template mutable_data<T>(),
&context_);
for (int i = 2; i < InputSize(); i += 2) {
auto& X = Input(i);
@@ -350,16 +338,15 @@ class WeightedSumOp : public Operator<Context> {
auto& weight = Input(i + 1);
DCHECK_EQ(X.size(), size);
DCHECK_EQ(weight.size(), 1);
- math::Axpy<DstType, Context>(
+ math::Axpy<T, Context>(
size,
- weight.template data<float>(),
- X.template data<DstType>(),
- output->template mutable_data<DstType>(),
+ weight.template data<T>(),
+ X.template data<T>(),
+ output->template mutable_data<T>(),
&context_);
}
return true;
}
- bool RunOnDevice() override;
};
/**
diff --git a/caffe2/operators/utility_ops_gpu.cc b/caffe2/operators/utility_ops_gpu.cc
index 0bb035328a..b3df226886 100644
--- a/caffe2/operators/utility_ops_gpu.cc
+++ b/caffe2/operators/utility_ops_gpu.cc
@@ -5,30 +5,6 @@
namespace caffe2 {
template <>
-bool WeightedSumOp<CUDAContext>::RunOnDevice() {
- if (Input(0).IsType<float>()) {
- return DoRunWithType<float>();
- } else if (Input(0).IsType<float16>()) {
- return DoRunWithType<float16>();
- } else {
- CAFFE_THROW("Unsupported inputs");
- }
- return false;
-}
-
-template <>
-bool SumOp<CUDAContext>::RunOnDevice() {
- if (Input(0).IsType<float>()) {
- return DoRunWithType<float, float>();
- } else if (Input(0).IsType<float16>()) {
- return DoRunWithType<float16, float16>();
- } else {
- CAFFE_THROW("Unsupported inputs");
- }
- return false;
-}
-
-template <>
class CopyOnDeviceLikeOp<CUDAContext, CUDAContext, CUDAContext>
: public Operator<CUDAContext> {
public:
@@ -59,7 +35,9 @@ REGISTER_CUDA_OPERATOR(Squeeze, SqueezeOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(ExpandDims, ExpandDimsOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(Alias, AliasOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(ResizeLike, ResizeLikeOp<CUDAContext>);
-REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(Sum, SumOp<float, CUDAContext>);
+
+REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(Shape, ShapeOp<CUDAContext>);
// From whatever the current context, ensure the output is TensorCPU
REGISTER_CUDA_OPERATOR(
diff --git a/caffe2/utils/conversions.h b/caffe2/utils/conversions.h
deleted file mode 100644
index 0c6c3238f7..0000000000
--- a/caffe2/utils/conversions.h
+++ /dev/null
@@ -1,182 +0,0 @@
-#pragma once
-
-#include <caffe2/core/types.h>
-
-#ifdef __CUDA_ARCH__
-#include <cuda_fp16.h>
-#endif
-
-#ifdef __CUDA_ARCH__
-#define CONVERSIONS_DECL __host__ __device__ inline
-#else
-#define CONVERSIONS_DECL inline
-#endif
-
-namespace caffe2 {
-
-namespace convert {
-
-namespace {
-inline float16 cpu_float2half_rn(float f) {
- float16 ret;
-
- static_assert(
- sizeof(unsigned int) == sizeof(float),
- "Programming error sizeof(unsigned int) != sizeof(float)");
-
- unsigned* xp = reinterpret_cast<unsigned int*>(&f);
- unsigned x = *xp;
- unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
- unsigned sign, exponent, mantissa;
-
- // Get rid of +NaN/-NaN case first.
- if (u > 0x7f800000) {
- ret.x = 0x7fffU;
- return ret;
- }
-
- sign = ((x >> 16) & 0x8000);
-
- // Get rid of +Inf/-Inf, +0/-0.
- if (u > 0x477fefff) {
- ret.x = sign | 0x7c00U;
- return ret;
- }
- if (u < 0x33000001) {
- ret.x = (sign | 0x0000);
- return ret;
- }
-
- exponent = ((u >> 23) & 0xff);
- mantissa = (u & 0x7fffff);
-
- if (exponent > 0x70) {
- shift = 13;
- exponent -= 0x70;
- } else {
- shift = 0x7e - exponent;
- exponent = 0;
- mantissa |= 0x800000;
- }
- lsb = (1 << shift);
- lsb_s1 = (lsb >> 1);
- lsb_m1 = (lsb - 1);
-
- // Round to nearest even.
- remainder = (mantissa & lsb_m1);
- mantissa >>= shift;
- if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
- ++mantissa;
- if (!(mantissa & 0x3ff)) {
- ++exponent;
- mantissa = 0;
- }
- }
-
- ret.x = (sign | (exponent << 10) | mantissa);
-
- return ret;
-}
-
-inline float cpu_half2float(float16 h) {
- unsigned sign = ((h.x >> 15) & 1);
- unsigned exponent = ((h.x >> 10) & 0x1f);
- unsigned mantissa = ((h.x & 0x3ff) << 13);
-
- if (exponent == 0x1f) { /* NaN or Inf */
- mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
- exponent = 0xff;
- } else if (!exponent) { /* Denorm or Zero */
- if (mantissa) {
- unsigned int msb;
- exponent = 0x71;
- do {
- msb = (mantissa & 0x400000);
- mantissa <<= 1; /* normalize */
- --exponent;
- } while (!msb);
- mantissa &= 0x7fffff; /* 1.mantissa is implicit */
- }
- } else {
- exponent += 0x70;
- }
-
- int temp = ((sign << 31) | (exponent << 23) | mantissa);
-
- unsigned* rp = reinterpret_cast<unsigned*>(&temp);
- return *rp;
-}
-
-}; // anonymous
-// general version: defer to static_cast
-template <typename IN, typename OUT>
-CONVERSIONS_DECL OUT To(const IN in) {
- return static_cast<OUT>(in);
-}
-
-#if __CUDA_ARCH__
-__device__ __inline__ __half inf_clip(__half h) {
- int isi = __hisinf(h);
- if (isi > 0) {
- // Exponent all ones except LSB (0x1e), mantissa is all ones (0x3ff)
- h.x = 0x7bffU;
- } else if (isi < 0) {
- // As above, negated
- h.x = 0x7bffU ^ 0x8000;
- }
- return h;
-}
-#endif
-
-// explicit for fp16
-template <>
-CONVERSIONS_DECL float16 To(const float in) {
-#if __CUDA_ARCH__
- // hacky interface between C2 fp16 and CUDA
- float16 ret;
- __half r;
- // r.x = __float2half_rn(in);
- // ret.x = inf_clip(r).x;
- ret.x = __float2half(in).x;
- return ret;
-#else
- return cpu_float2half_rn(in);
-#endif
-}
-
-template <>
-CONVERSIONS_DECL float To(const float16 in) {
-#if __CUDA_ARCH__
- __half tmp;
- tmp.x = in.x;
- return __half2float(tmp);
-#else
- return cpu_half2float(in);
-#endif
-};
-
-template <>
-CONVERSIONS_DECL float To(const float in) {
- return in;
-}
-
-template <typename OUT, typename IN>
-CONVERSIONS_DECL OUT Get(IN x) {
- return static_cast<OUT>(x);
-}
-
-template <>
-CONVERSIONS_DECL float Get(float16 x) {
- return To<float16, float>(x);
-}
-
-template <>
-CONVERSIONS_DECL float16 Get(float x) {
- return To<float, float16>(x);
-}
-
-}; // namespace convert
-
-}; // namespace caffe2
-
-#undef CONVERSIONS_DECL
diff --git a/caffe2/utils/math-detail.h b/caffe2/utils/math-detail.h
index 07a1f997d6..35a880a6d4 100644
--- a/caffe2/utils/math-detail.h
+++ b/caffe2/utils/math-detail.h
@@ -11,12 +11,8 @@ namespace detail {
template<typename T, class Context, int FixedSize>
struct ScaleImpl {
- inline void operator()(
- const int N,
- const float alpha,
- const T* x,
- T* y,
- Context* context) {
+ inline void
+ operator()(const int N, const T alpha, const T* x, T* y, Context* context) {
Scale(N, alpha, x, y, context);
}
};
@@ -26,7 +22,7 @@ template<typename T>
struct ScaleImpl<T, CPUContext, 1> {
inline void operator()(
const int N,
- const float alpha,
+ const T alpha,
const T* x,
T* y,
CPUContext* context) {
@@ -37,12 +33,8 @@ struct ScaleImpl<T, CPUContext, 1> {
template<typename T, class Context, int FixedSize>
struct AxpyImpl {
- inline void operator()(
- const int N,
- const float alpha,
- const T* x,
- T* y,
- Context* context) {
+ inline void
+ operator()(const int N, const T alpha, const T* x, T* y, Context* context) {
Axpy(N, alpha, x, y, context);
}
};
@@ -52,7 +44,7 @@ template<typename T>
struct AxpyImpl<T, CPUContext, 1> {
inline void operator()(
const int N,
- const float alpha,
+ const T alpha,
const T* x,
T* y,
CPUContext* context) {
@@ -65,22 +57,14 @@ struct AxpyImpl<T, CPUContext, 1> {
} // namespace detail
template <typename T, class Context, int FixedSize>
-inline void ScaleFixedSize(
- const int N,
- const float alpha,
- const T* x,
- T* y,
- Context* context) {
+inline void
+ScaleFixedSize(const int N, const T alpha, const T* x, T* y, Context* context) {
detail::ScaleImpl<T, Context, FixedSize>()(N, alpha, x, y, context);
}
template <typename T, class Context, int FixedSize>
-inline void AxpyFixedSize(
- const int N,
- const float alpha,
- const T* x,
- T* y,
- Context* context) {
+inline void
+AxpyFixedSize(const int N, const T alpha, const T* x, T* y, Context* context) {
detail::AxpyImpl<T, Context, FixedSize>()(N, alpha, x, y, context);
}
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index 105cb19733..a2472c0d33 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -141,20 +141,10 @@ void ColwiseMax(const int N, const int D, const T* x, T* y,
// Decaf gemm provides a simpler interface to the gemm functions, with the
// limitation that the data has to be contiguous in memory.
-template <typename T, class Context, class Engine = DefaultEngine>
-void Gemm(
- const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB,
- const int M,
- const int N,
- const int K,
- const float alpha,
- const T* A,
- const T* B,
- const float beta,
- T* C,
- Context* context,
- TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
+template <typename T, class Context, class Engine=DefaultEngine>
+void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
+ const int M, const int N, const int K, const T alpha, const T* A,
+ const T* B, const T beta, T* C, Context* context);
// We also provide a gemm that has explicit lda, ldb and ldc specified.
// In most cases you probably want to use the function above, though.
@@ -179,18 +169,10 @@ void GemmEx(
// to Trans, the output is:
// CblasNoTrans: x is an N dim vector and y is an M dim vector.
// CblasTrans: x is an M dim vector and y is an N dim vector.
-template <typename T, class Context, class Engine = DefaultEngine>
-void Gemv(
- const CBLAS_TRANSPOSE TransA,
- const int M,
- const int N,
- const float alpha,
- const T* A,
- const T* x,
- const float beta,
- T* y,
- Context* context,
- TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
+template <typename T, class Context, class Engine=DefaultEngine>
+void Gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const T alpha, const T* A, const T* x, const T beta,
+ T* y, Context* context);
template <typename T, class Context>
void Set(const TIndex N, const T alpha, T* X, Context* context);
@@ -236,31 +218,28 @@ void Select(const int N, const int D, const T* x, const int* idx, T* y,
Context* context);
template <typename T, class Context>
-void Scale(const int N, const float alpha, const T* x, T* y, Context* context);
+void Scale(const int N, const T alpha, const T* x, T* y, Context* context);
// Different from the Scale function above, if alpha is passed in
// as a pointer, we will assume that it lives on the Context device,
// for example on GPU.
template <typename T, class Context>
-void Scale(const int N, const float* alpha, const T* x, T* y, Context* context);
+void Scale(const int N, const T* alpha, const T* x, T* y,
+ Context* context);
template <typename T, class Context>
-void Axpy(const int N, const float alpha, const T* x, T* y, Context* context);
+void Axpy(const int N, const T alpha, const T* x, T* y, Context* context);
// Different from the Axpy function above, if alpha is passed in
// as a pointer, we will assume that it lives on the Context device,
// for example on GPU.
template <typename T, class Context>
-void Axpy(const int N, const float* alpha, const T* x, T* y, Context* context);
+void Axpy(const int N, const T* alpha, const T* x, T* y,
+ Context* context);
template <typename T, class Context>
-void Axpby(
- const int N,
- const float alpha,
- const T* x,
- const T b,
- T* y,
- Context* context);
+void Axpby(const int N, const T alpha, const T* x, const T b, T* y,
+ Context* context);
template <typename T, class Context, int order>
void Im2colNd(
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index e4340dfaf1..5cac0c8339 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -58,18 +58,9 @@ namespace math {
// CblasTrans, respectively, for each of A and B.
template <>
void Gemm<float, CPUContext>(
- const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB,
- const int M,
- const int N,
- const int K,
- const float alpha,
- const float* A,
- const float* B,
- const float beta,
- float* C,
- CPUContext* context,
- TensorProto::DataType math_type) {
+ const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
+ const int M, const int N, const int K, const float alpha, const float* A,
+ const float* B, const float beta, float* C, CPUContext* context) {
auto C_mat = EigenMatrixMap<float>(C, N, M);
if (beta == 0) {
C_mat.setZero();
@@ -187,8 +178,7 @@ void Gemv<float, CPUContext>(
const float* x,
const float beta,
float* y,
- CPUContext* context,
- TensorProto::DataType math_type) {
+ CPUContext* context) {
EigenVectorMap<float> y_vec(y, TransA == CblasNoTrans ? M : N);
if (beta == 0) {
// In Caffe2 we often do a lazy initialization, which may contain NaNs in
@@ -215,22 +205,19 @@ void Gemv<float, CPUContext>(
}
}
-#define CAFFE2_SPECIALIZED_SCALE(T) \
- template <> \
- void Scale<T, CPUContext>( \
- const int n, const float alpha, const T* x, T* y, CPUContext* context) { \
- EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha; \
- } \
- template <> \
- void Scale<T, CPUContext>( \
- const int n, \
- const float* alpha, \
- const T* x, \
- T* y, \
- CPUContext* context) { \
- EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha); \
+#define CAFFE2_SPECIALIZED_SCALE(T) \
+ template <> \
+ void Scale<T, CPUContext>( \
+ const int n, const T alpha, const T* x, T* y, CPUContext* context) { \
+ EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha; \
+ } \
+ template <> \
+ void Scale<T, CPUContext>( \
+ const int n, const T* alpha, const T* x, T* y, CPUContext* context) { \
+ EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha); \
}
CAFFE2_SPECIALIZED_SCALE(float)
+CAFFE2_SPECIALIZED_SCALE(double)
#undef CAFFE2_SPECIALIZED_SCALE
#define CAFFE2_SPECIALIZED_DOT(T) \
@@ -241,6 +228,7 @@ void Dot<T, CPUContext>( \
*y = ConstEigenVectorMap<T>(a, N).dot(ConstEigenVectorMap<T>(b, N)); \
}
CAFFE2_SPECIALIZED_DOT(float)
+CAFFE2_SPECIALIZED_DOT(double)
#undef CAFFE2_SPECIALIZED_DOT
#define CAFFE2_SPECIALIZED_AXPY(T) \
@@ -255,6 +243,7 @@ CAFFE2_SPECIALIZED_DOT(float)
EigenVectorMap<T>(Y, N) += ConstEigenVectorMap<T>(x, N) * (*alpha); \
}
CAFFE2_SPECIALIZED_AXPY(float)
+CAFFE2_SPECIALIZED_AXPY(double)
#undef CAFFE2_SPECIALIZED_AXPY
#define CAFFE2_SPECIALIZED_AXPBY(T) \
@@ -265,24 +254,16 @@ void Axpby<T, CPUContext>(const int N, const T alpha, const T* x, \
y_vec = y_vec * beta + ConstEigenVectorMap<T>(x, N) * alpha; \
}
CAFFE2_SPECIALIZED_AXPBY(float)
+CAFFE2_SPECIALIZED_AXPBY(double)
#undef CAFFE2_SPECIALIZED_AXPBY
#else // CAFFE2_USE_EIGEN_FOR_BLAS
template <>
void Gemm<float, CPUContext>(
- const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB,
- const int M,
- const int N,
- const int K,
- const float alpha,
- const float* A,
- const float* B,
- const float beta,
- float* C,
- CPUContext* context,
- TensorProto::DataType math_type) {
+ const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
+ const int M, const int N, const int K, const float alpha, const float* A,
+ const float* B, const float beta, float* C, CPUContext* context) {
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, ldb,
@@ -311,39 +292,29 @@ void GemmEx<float, CPUContext>(
template <>
void Gemv<float, CPUContext>(
- const CBLAS_TRANSPOSE TransA,
- const int M,
- const int N,
- const float alpha,
- const float* A,
- const float* x,
- const float beta,
- float* y,
- CPUContext* context,
- TensorProto::DataType math_type) {
+ const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha,
+ const float* A, const float* x, const float beta, float* y,
+ CPUContext* context) {
cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
}
-#define CAFFE2_SPECIALIZED_SCALE(T, prefix) \
- template <> \
- void Scale<T, CPUContext>( \
- const int n, const float alpha, const T* x, T* y, CPUContext* context) { \
- if (y != x) \
- cblas_##prefix##copy(n, x, 1, y, 1); \
- cblas_##prefix##scal(n, static_cast<float>(alpha), y, 1); \
- } \
- template <> \
- void Scale<T, CPUContext>( \
- const int n, \
- const float* alpha, \
- const T* x, \
- T* y, \
- CPUContext* context) { \
- if (y != x) \
- cblas_##prefix##copy(n, x, 1, y, 1); \
- cblas_##prefix##scal(n, static_cast<float>(*alpha), y, 1); \
+#define CAFFE2_SPECIALIZED_SCALE(T, prefix) \
+ template <> \
+ void Scale<T, CPUContext>( \
+ const int n, const T alpha, const T* x, T* y, CPUContext* context) { \
+ if (y != x) \
+ cblas_##prefix##copy(n, x, 1, y, 1); \
+ cblas_##prefix##scal(n, alpha, y, 1); \
+ } \
+ template <> \
+ void Scale<T, CPUContext>( \
+ const int n, const T* alpha, const T* x, T* y, CPUContext* context) { \
+ if (y != x) \
+ cblas_##prefix##copy(n, x, 1, y, 1); \
+ cblas_##prefix##scal(n, *alpha, y, 1); \
}
CAFFE2_SPECIALIZED_SCALE(float, s)
+CAFFE2_SPECIALIZED_SCALE(double, d)
#undef CAFFE2_SPECIALIZED_SCALE
#define CAFFE2_SPECIALIZED_DOT(T, prefix) \
@@ -354,6 +325,7 @@ void Dot<T, CPUContext>( \
*y = cblas_##prefix##dot(N, a, 1, b, 1); \
}
CAFFE2_SPECIALIZED_DOT(float, s)
+CAFFE2_SPECIALIZED_DOT(double, d)
#undef CAFFE2_SPECIALIZED_DOT
#define CAFFE2_SPECIALIZED_AXPY(T, prefix) \
@@ -368,6 +340,7 @@ CAFFE2_SPECIALIZED_DOT(float, s)
cblas_##prefix##axpy(N, *alpha, x, 1, y, 1); \
}
CAFFE2_SPECIALIZED_AXPY(float, s)
+CAFFE2_SPECIALIZED_AXPY(double, d)
#undef CAFFE2_SPECIALIZED_AXPY
// cblas_[sd]axpby is not a standard blas function, and if MKL is not present,
@@ -389,6 +362,7 @@ void Axpby<T, CPUContext>(const int N, const T alpha, const T* x, \
}
#endif // CAFFE2_USE_MKL
CAFFE2_SPECIALIZED_AXPBY(float, s)
+CAFFE2_SPECIALIZED_AXPBY(double, d)
#undef CAFFE2_SPECIALIZED_AXPBY
#endif // CAFFE2_USE_EIGEN_FOR_BLAS
@@ -462,8 +436,11 @@ void Funcname<T, CPUContext>(const int N, const T* x, T* y, \
EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(x, N).array().expr(); \
}
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, exp)
+DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, exp)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, log)
+DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, log)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, square)
+DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, square)
#undef DELEGATE_SIMPLE_UNARY_FUNCTION
#define DELEGATE_POWX_FUNCTION(T) \
@@ -473,6 +450,7 @@ void Powx<T, CPUContext>( \
EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(a, N).array().pow(b); \
}
DELEGATE_POWX_FUNCTION(float)
+DELEGATE_POWX_FUNCTION(double)
#undef DELEGATE_POWX_FUNCTION
#endif // CAFFE2_USE_MKL
@@ -498,6 +476,7 @@ EIGEN_SIMPLE_BINARY_FUNCTION(int64_t, Funcname, expr)
#define DEFINE_SIMPLE_BINARY_FUNCTION(Funcname, expr) \
EIGEN_SIMPLE_BINARY_FUNCTION(float, Funcname, expr) \
+EIGEN_SIMPLE_BINARY_FUNCTION(double, Funcname, expr) \
EIGEN_SIMPLE_BINARY_FUNCTION(int32_t, Funcname, expr) \
EIGEN_SIMPLE_BINARY_FUNCTION(int64_t, Funcname, expr)
@@ -567,6 +546,7 @@ CAFFE2_SPECIALIZED_COLWISEMAX(float)
DELEGATE_BROADCAST_BINARY_FUNCTION(int32_t, name, op) \
DELEGATE_BROADCAST_BINARY_FUNCTION(int64_t, name, op) \
DELEGATE_BROADCAST_BINARY_FUNCTION(float, name, op) \
+ DELEGATE_BROADCAST_BINARY_FUNCTION(double, name, op)
DEFINE_BROADCAST_BINARY_FUNCTION(Add, +)
DEFINE_BROADCAST_BINARY_FUNCTION(Sub, -)
@@ -622,6 +602,7 @@ CAFFE2_SPECIALIZED_SET(uint16_t);
#define CAFFE2_DEFINE_BINARY_OP(name, op) \
CAFFE2_INSTANTIATE_BINARY_OP(name, op, float) \
+ CAFFE2_INSTANTIATE_BINARY_OP(name, op, double) \
CAFFE2_INSTANTIATE_BINARY_OP(name, op, int32_t) \
CAFFE2_INSTANTIATE_BINARY_OP(name, op, int64_t)
@@ -663,6 +644,7 @@ void Not<bool, CPUContext>(
}
CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(float);
+CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(double);
#undef CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH
template <>
@@ -735,6 +717,7 @@ void RandGaussian<float, CPUContext>(
}
CAFFE2_SPECIALIZED_SUM(float);
+CAFFE2_SPECIALIZED_SUM(double);
CAFFE2_SPECIALIZED_SUM(int32_t);
CAFFE2_SPECIALIZED_SUM(int64_t);
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 46c5bc02f1..cfbe91b5d0 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -5,9 +5,8 @@
#include <thrust/system/cuda/detail/par.h>
#include <thrust/version.h>
-#include "caffe2/core/context_gpu.h"
-#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"
+#include "caffe2/core/context_gpu.h"
#if THRUST_VERSION >= 100800
#define THRUST_SUPPORTS_PER_THREAD
@@ -33,30 +32,33 @@ void Funcname<T, CUDAContext>( \
}
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf);
+DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Exp, exp);
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf);
+DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Log, log);
__device__ float cuda_sqrf(const float x) { return x * x; }
+__device__ double cuda_sqr(const double x) { return x * x; }
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, cuda_sqrf);
+DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sqr, cuda_sqr);
#undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION
-#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Funcname, expr) \
- __global__ void _Kernel_##T##_##Funcname( \
- const int N, const T* a, const T* b, T* y) { \
- CUDA_1D_KERNEL_LOOP(i, N) { \
- float r = convert::To<T, float>(a[i]) expr convert::To<T, float>(b[i]); \
- y[i] = convert::To<float, T>(r); \
- } \
- } \
- template <> \
- void Funcname<T, CUDAContext>( \
- const int N, const T* a, const T* b, T* y, CUDAContext* context) { \
- _Kernel_##T##_##Funcname<<< \
- CAFFE_GET_BLOCKS(N), \
- CAFFE_CUDA_NUM_THREADS, \
- 0, \
- context->cuda_stream()>>>(N, a, b, y); \
+#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Funcname, expr) \
+ __global__ void _Kernel_##T##_##Funcname( \
+ const int N, const T* a, const T* b, T* y) { \
+ CUDA_1D_KERNEL_LOOP(i, N) { \
+ y[i] = a[i] expr b[i]; \
+ } \
+ } \
+ template <> \
+ void Funcname<T, CUDAContext>( \
+ const int N, const T* a, const T* b, T* y, CUDAContext* context) { \
+ _Kernel_##T##_##Funcname<<< \
+ CAFFE_GET_BLOCKS(N), \
+ CAFFE_CUDA_NUM_THREADS, \
+ 0, \
+ context->cuda_stream()>>>(N, a, b, y); \
}
DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, +);
@@ -64,27 +66,13 @@ DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Sub, -);
DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, *);
DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, /);
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Add, +);
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Sub, -);
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Mul, *);
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Div, /);
-
// Caffe2 gemm provides a simpler interface to the gemm functions, with the
// limitation that the data has to be contiguous in memory.
template <>
void Gemm<float, CUDAContext>(
- const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB,
- const int M,
- const int N,
- const int K,
- const float alpha,
- const float* A,
- const float* B,
- const float beta,
- float* C,
- CUDAContext* context,
- TensorProto::DataType math_type) {
+ const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
+ const int M, const int N, const int K, const float alpha, const float* A,
+ const float* B, const float beta, float* C, CUDAContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (TransA == CblasNoTrans) ? K : M;
@@ -111,91 +99,11 @@ void Gemm<float, CUDAContext>(
}
template <>
-void Gemm<float16, CUDAContext>(
- const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB,
- const int M,
- const int N,
- const int K,
- const float alpha,
- const float16* A,
- const float16* B,
- const float beta,
- float16* C,
- CUDAContext* context,
- TensorProto::DataType math_type) {
- // Note that cublas follows fortran order, so the order is different from
- // the cblas convention.
- int lda = (TransA == CblasNoTrans) ? K : M;
- int ldb = (TransB == CblasNoTrans) ? N : K;
- cublasOperation_t cuTransA =
- (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
- cublasOperation_t cuTransB =
- (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
- if (math_type == TensorProto_DataType_FLOAT) {
- CUBLAS_CHECK(cublasSgemmEx(
- context->cublas_handle(),
- cuTransB,
- cuTransA,
- N,
- M,
- K,
- &alpha,
- B,
- CUDA_R_16F,
- ldb,
- A,
- CUDA_R_16F,
- lda,
- &beta,
- C,
- CUDA_R_16F,
- N));
-
- } else if (math_type == TensorProto_DataType_FLOAT16) {
- // convert alpha, beta from caffe2::float16 -> __half
- __half alpha_fp16;
- alpha_fp16.x = convert::To<float, float16>(alpha).x;
- __half beta_fp16;
- beta_fp16.x = convert::To<float, float16>(beta).x;
- // call cublasHgemm
- CUBLAS_CHECK(cublasHgemm(
- context->cublas_handle(),
- cuTransB,
- cuTransA,
- N,
- M,
- K,
- &alpha_fp16,
- (const __half*)B,
- ldb,
- (const __half*)A,
- lda,
- &beta_fp16,
- (__half*)C,
- N));
- } else {
- // fail
- CAFFE_THROW("Unsupported math type");
- }
-}
-
-template <>
void GemmEx<float, CUDAContext>(
- const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB,
- const int M,
- const int N,
- const int K,
- const float alpha,
- const float* A,
- const int lda,
- const float* B,
- const int ldb,
- const float beta,
- float* C,
- const int ldc,
- CUDAContext* context) {
+ const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
+ const int M, const int N, const int K, const float alpha, const float* A,
+ const int lda, const float* B, const int ldb, const float beta, float* C,
+ const int ldc, CUDAContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA =
@@ -221,19 +129,40 @@ void GemmEx<float, CUDAContext>(
template <>
void Gemv<float, CUDAContext>(
+ const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha,
+ const float* A, const float* x, const float beta, float* y,
+ CUDAContext* context) {
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+ CUBLAS_ENFORCE(cublasSgemv(
+ context->cublas_handle(),
+ cuTransA,
+ N,
+ M,
+ &alpha,
+ A,
+ N,
+ x,
+ 1,
+ &beta,
+ y,
+ 1));
+}
+
+template <>
+void Gemv<double, CUDAContext>(
const CBLAS_TRANSPOSE TransA,
const int M,
const int N,
- const float alpha,
- const float* A,
- const float* x,
- const float beta,
- float* y,
- CUDAContext* context,
- TensorProto::DataType math_type) {
+ const double alpha,
+ const double* A,
+ const double* x,
+ const double beta,
+ double* y,
+ CUDAContext* context) {
cublasOperation_t cuTransA =
(TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
- CUBLAS_ENFORCE(cublasSgemv(
+ CUBLAS_ENFORCE(cublasDgemv(
context->cublas_handle(),
cuTransA,
N,
@@ -287,73 +216,6 @@ CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float);
CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(double);
#undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH
-template <>
-void Gemv<float16, CUDAContext>(
- const CBLAS_TRANSPOSE TransA,
- const int M,
- const int N,
- const float alpha,
- const float16* A,
- const float16* x,
- const float beta,
- float16* y,
- CUDAContext* context,
- TensorProto::DataType math_type) {
- cublasOperation_t cuTransA =
- (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
-
- // sort out what we need to call cublasSgemmEx / cublasHgemm
- int m = (cuTransA == CUBLAS_OP_N) ? N : M;
- int k = (cuTransA == CUBLAS_OP_N) ? M : N;
- int LDA = (cuTransA == CUBLAS_OP_N) ? m : k;
- int LDC = m;
-
- if (math_type == TensorProto_DataType_FLOAT) {
- CUBLAS_CHECK(cublasSgemmEx(
- context->cublas_handle(),
- cuTransA,
- CUBLAS_OP_N,
- m,
- 1,
- k,
- &alpha,
- A,
- CUDA_R_16F,
- LDA,
- x,
- CUDA_R_16F,
- k,
- &beta,
- y,
- CUDA_R_16F,
- LDC));
- } else if (math_type == TensorProto_DataType_FLOAT16) {
- __half alpha_fp16;
- alpha_fp16.x = convert::To<float, float16>(alpha).x;
- __half beta_fp16;
- beta_fp16.x = convert::To<float, float16>(beta).x;
-
- CUBLAS_CHECK(cublasHgemm(
- context->cublas_handle(),
- cuTransA,
- CUBLAS_OP_N,
- m,
- 1,
- k,
- &alpha_fp16,
- (const __half*)A,
- LDA,
- (const __half*)x,
- k,
- &beta_fp16,
- (__half*)y,
- LDC));
- } else {
- // fail
- CAFFE_THROW("Unsupported math type");
- }
-}
-
namespace {
template <typename T>
__global__ void SetKernel(const int N, const T alpha, T* Y) {
@@ -376,7 +238,6 @@ CAFFE2_SPECIALIZED_CUDA_SET(double);
CAFFE2_SPECIALIZED_CUDA_SET(bool);
CAFFE2_SPECIALIZED_CUDA_SET(int8_t);
CAFFE2_SPECIALIZED_CUDA_SET(int16_t);
-CAFFE2_SPECIALIZED_CUDA_SET(float16);
CAFFE2_SPECIALIZED_CUDA_SET(int);
CAFFE2_SPECIALIZED_CUDA_SET(int64_t);
CAFFE2_SPECIALIZED_CUDA_SET(char);
@@ -386,11 +247,11 @@ CAFFE2_SPECIALIZED_CUDA_SET(uint16_t);
namespace {
template <typename T>
-__global__ void
-UniformShift(const int N, const float min, const float max, T* x) {
- float scale = max - min;
+__global__ void UniformShift(const int N, const T min, const T max,
+ T* x) {
+ T scale = max - min;
CUDA_1D_KERNEL_LOOP(i, N) {
- x[i] = convert::To<float, T>(convert::To<T, float>(x[i]) * scale + min);
+ x[i] = x[i] * scale + min;
}
}
@@ -475,6 +336,7 @@ void RandGaussian<double, CUDAContext>(
context->curand_generator(), r, even_n, mean, std));
}
+
template<>
void Dot<float, CUDAContext>(
const int n, const float* a, const float* b, float* y,
@@ -484,28 +346,13 @@ void Dot<float, CUDAContext>(
context->Copy<float, CPUContext, CUDAContext>(1, &result, y);
}
-template <>
-void Dot<float16, CUDAContext>(
- const int n,
- const float16* a,
- const float16* b,
- float16* y,
+template<>
+void Dot<double, CUDAContext>(
+ const int n, const double* a, const double* b, double* y,
CUDAContext* context) {
- float16 result;
- // execute with 32-bit math
- CUBLAS_CHECK(cublasDotEx(
- context->cublas_handle(),
- n,
- a,
- CUDA_R_16F,
- 1,
- b,
- CUDA_R_16F,
- 1,
- &result,
- CUDA_R_16F,
- CUDA_R_32F));
- context->Copy<float16, CPUContext, CUDAContext>(1, &result, y);
+ double result;
+ CUBLAS_ENFORCE(cublasDdot(context->cublas_handle(), n, a, 1, b, 1, y));
+ context->Copy<double, CPUContext, CUDAContext>(1, &result, y);
}
// A previous version of caffe2 used Thrust but it turns out that thrust
@@ -516,7 +363,7 @@ void Dot<float16, CUDAContext>(
template <typename T>
__global__ void SumKernel(const int N, const T* X, T* Y, bool square) {
const int idx = threadIdx.x;
- __shared__ float reduction_buffer[SUM_KERNEL_NTHREADS];
+ __shared__ T reduction_buffer[SUM_KERNEL_NTHREADS];
reduction_buffer[idx] = 0;
@@ -524,12 +371,11 @@ __global__ void SumKernel(const int N, const T* X, T* Y, bool square) {
// N -> 128
if (!square) {
for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) {
- reduction_buffer[idx] += convert::To<T, float>(X[i]);
+ reduction_buffer[idx] += X[i];
}
} else {
for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) {
- float Xi = convert::To<T, float>(X[i]);
- reduction_buffer[idx] += Xi * Xi;
+ reduction_buffer[idx] += X[i] * X[i];
}
}
__syncthreads();
@@ -547,7 +393,7 @@ __global__ void SumKernel(const int N, const T* X, T* Y, bool square) {
for (int i = 0; i < 32; ++i) {
tmp += reduction_buffer[i];
}
- *Y = convert::To<float, T>(tmp);
+ *Y = tmp;
}
}
@@ -560,7 +406,7 @@ __global__ void SumKernel(const int N, const T* X, T* Y, bool square) {
}
CAFFE2_MATH_SUM_FUNC(float)
-CAFFE2_MATH_SUM_FUNC(float16)
+CAFFE2_MATH_SUM_FUNC(double)
#undef CAFFE2_MATH_SUM_FUNC
#define CAFFE2_MATH_SUMSQR_FUNC(T) \
@@ -592,33 +438,18 @@ void Select<float, CUDAContext>(
0, context->cuda_stream()>>>(N, D, x, idx, y);
}
-template <>
-void Select<float16, CUDAContext>(
- const int N,
- const int D,
- const float16* x,
- const int* idx,
- float16* y,
- CUDAContext* context) {
- SelectKernel<float16><<<
- CAFFE_GET_BLOCKS(N),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(N, D, x, idx, y);
-}
-
namespace {
template <typename T>
-__global__ void ScaleKernel(const int n, const float alpha, const T* x, T* y) {
+__global__ void ScaleKernel(
+ const int n, const T alpha, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
- // y[i] = convert::To<float,T>(convert::To<T, float>(x[i]) * alpha);
- y[i] = convert::Get<T>(convert::Get<float>(x[i]) * alpha);
+ y[i] = x[i] * alpha;
}
}
template <typename T>
-__global__ void
-ScaleKernelDeviceAlpha(const int n, const float* alpha, const T* x, T* y) {
+__global__ void ScaleKernelDeviceAlpha(
+ const int n, const T* alpha, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = x[i] * (*alpha);
}
@@ -630,20 +461,6 @@ __global__ void PowKernel(const int n, const T* x, const T exponent, T* y) {
y[i] = powf(x[i], exponent);
}
}
-
-// fp16 specialization
-template <>
-__global__ void ScaleKernelDeviceAlpha(
- const int n,
- const float* alpha,
- const float16* x,
- float16* y) {
- CUDA_1D_KERNEL_LOOP(i, n) {
- y[i] = convert::To<float, float16>(
- convert::To<float16, float>(x[i]) * (*alpha));
- }
-}
-
} // namespace
template <>
@@ -672,17 +489,12 @@ void Scale<float, CUDAContext>(
}
template <>
-void Scale<float16, CUDAContext>(
- const int n,
- const float alpha,
- const float16* x,
- float16* y,
+void Scale<double, CUDAContext>(
+ const int n, const double alpha, const double *x, double* y,
CUDAContext* context) {
- ScaleKernel<float16><<<
- CAFFE_GET_BLOCKS(n),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(n, alpha, x, y);
+ ScaleKernel<double><<<
+ CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ n, alpha, x, y);
}
template <>
@@ -695,17 +507,11 @@ void Scale<float, CUDAContext>(
}
template <>
-void Scale<float16, CUDAContext>(
- const int n,
- const float* alpha,
- const float16* x,
- float16* y,
+void Scale<double, CUDAContext>(
+ const int n, const double* alpha, const double *x, double* y,
CUDAContext* context) {
- ScaleKernelDeviceAlpha<float16><<<
- CAFFE_GET_BLOCKS(n),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(n, alpha, x, y);
+ ScaleKernelDeviceAlpha<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
+ 0, context->cuda_stream()>>>(n, alpha, x, y);
}
template <>
@@ -721,42 +527,18 @@ void Axpy<float, CUDAContext>(
template <>
void Axpy<double, CUDAContext>(
const int N,
- const float alpha,
+ const double alpha,
const double* X,
double* Y,
CUDAContext* context) {
- double alpha_d{alpha};
- CUBLAS_ENFORCE(
- cublasDaxpy(context->cublas_handle(), N, &alpha_d, X, 1, Y, 1));
-}
-
-template <>
-void Axpy<float16, CUDAContext>(
- const int N,
- const float alpha,
- const float16* X,
- float16* Y,
- CUDAContext* context) {
- CUBLAS_CHECK(cublasAxpyEx(
- context->cublas_handle(),
- N,
- &alpha,
- CUDA_R_16F,
- X,
- CUDA_R_16F,
- 1,
- Y,
- CUDA_R_16F,
- 1,
- CUDA_R_32F));
+ CUBLAS_ENFORCE(cublasDaxpy(context->cublas_handle(), N, &alpha, X, 1, Y, 1));
}
namespace {
template <typename T>
-__global__ void AxpyKernel(const int n, const float* a, const T* x, T* y) {
+__global__ void AxpyKernel(const int n, const T* a, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(index, n) {
- y[index] = convert::Get<T>(
- convert::Get<float>(x[index]) * (*a) + convert::Get<float>(y[index]));
+ y[index] += x[index] * (*a);
}
}
} // namespace
@@ -770,19 +552,14 @@ void Axpy<float, CUDAContext>(
}
template <>
-void Axpy<float16, CUDAContext>(
- const int n,
- const float* alpha,
- const float16* X,
- float16* Y,
- CUDAContext* context) {
- AxpyKernel<float16><<<
- CAFFE_GET_BLOCKS(n),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(n, alpha, X, Y);
+void Axpy<double, CUDAContext>(
+ const int n, const double* alpha, const double* X,
+ double* Y, CUDAContext* context) {
+ AxpyKernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
+ 0, context->cuda_stream()>>>(n, alpha, X, Y);
}
+
namespace {
template <typename T>
__global__ void AxpbyKernel(const int n, const T a, const T* x,
@@ -801,6 +578,14 @@ void Axpby<float, CUDAContext>(
0, context->cuda_stream()>>>(n, a, x, b, y);
}
+template <>
+void Axpby<double, CUDAContext>(
+ const int n, const double a, const double* x, const double b, double* y,
+ CUDAContext* context) {
+ AxpbyKernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
+ 0, context->cuda_stream()>>>(n, a, x, b, y);
+}
+
namespace {
template <typename T>
diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc
index b1f930bec0..2ceeddd355 100644
--- a/caffe2/utils/math_gpu_test.cc
+++ b/caffe2/utils/math_gpu_test.cc
@@ -67,4 +67,61 @@ TEST(MathUtilGPUTest, testAddStripedBatch) {
}
}
+#define TEST_GEMV_WITH_TYPE(field_name) \
+ TEST(MathUtilGPUTest, testGemv_##field_name) { \
+ if (!HasCudaGPU()) \
+ return; \
+ Workspace ws; \
+ DeviceOption option; \
+ option.set_device_type(CUDA); \
+ CUDAContext context(option); \
+ Blob* blobx = ws.CreateBlob("X"); \
+ Blob* bloby = ws.CreateBlob("Y"); \
+ Blob* blobz = ws.CreateBlob("Z"); \
+ Blob* bloby_host = ws.CreateBlob("Y_host"); \
+ \
+ vector<int> shapex{64, 128}; \
+ vector<int> shapey{64}; \
+ vector<int> shapez{128}; \
+ \
+ auto* tensorx = blobx->GetMutable<Tensor<CUDAContext>>(); \
+ tensorx->Resize(shapex); \
+ math::Set<field_name, CUDAContext>( \
+ 64 * 128, \
+ (field_name)1.0, \
+ tensorx->mutable_data<field_name>(), \
+ &context); \
+ \
+ auto* tensory = bloby->GetMutable<Tensor<CUDAContext>>(); \
+ tensory->Resize(shapey); \
+ math::Set<field_name, CUDAContext>( \
+ 64, (field_name)1.0, tensory->mutable_data<field_name>(), &context); \
+ \
+ auto* tensorz = blobz->GetMutable<Tensor<CUDAContext>>(); \
+ tensorz->Resize(shapez); \
+ \
+ math::Gemv<field_name, CUDAContext>( \
+ CblasTrans, \
+ 64, \
+ 128, \
+ 1.0, \
+ tensorx->template data<field_name>(), \
+ tensory->mutable_data<field_name>(), \
+ 0.0, \
+ tensorz->template mutable_data<field_name>(), \
+ &context); \
+ context.FinishDeviceComputation(); \
+ \
+ auto* tensory_host = bloby_host->GetMutable<Tensor<CPUContext>>(); \
+ tensory_host->CopyFrom<CUDAContext, CUDAContext>(*tensorz, &context); \
+ context.FinishDeviceComputation(); \
+ \
+ for (int i = 0; i < 128; i++) { \
+ EXPECT_EQ(tensory_host->data<field_name>()[i], 64.0); \
+ } \
+ }
+
+TEST_GEMV_WITH_TYPE(float);
+TEST_GEMV_WITH_TYPE(double);
+
} // namespace caffe2