summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/operators/softmax_op.cc72
-rw-r--r--caffe2/operators/softmax_op.h1
-rw-r--r--caffe2/operators/softmax_shared.cc55
-rw-r--r--caffe2/operators/softmax_shared.h21
-rw-r--r--caffe2/operators/softmax_utils.cc38
-rw-r--r--caffe2/operators/softmax_utils.h23
-rw-r--r--caffe2/operators/softmax_with_loss_op.cc91
-rw-r--r--caffe2/operators/spatial_softmax_with_loss_op.cc3
8 files changed, 138 insertions, 166 deletions
diff --git a/caffe2/operators/softmax_op.cc b/caffe2/operators/softmax_op.cc
index 2a021ab730..e2418cb429 100644
--- a/caffe2/operators/softmax_op.cc
+++ b/caffe2/operators/softmax_op.cc
@@ -1,49 +1,29 @@
#include "caffe2/operators/softmax_op.h"
-#include "caffe2/operators/softmax_shared.h"
+
+#include "caffe2/operators/softmax_utils.h"
namespace caffe2 {
// Implementation for the CPU context.
template <>
bool SoftmaxOp<float, CPUContext>::RunOnDevice() {
- auto& X = Input(0);
-
- const auto canonical_axis = X.canonical_axis_index(axis_);
+ const auto& X = Input(0);
+ const int canonical_axis = X.canonical_axis_index(axis_);
const int N = X.size_to_dim(canonical_axis);
const int D = X.size_from_dim(canonical_axis);
auto* Y = Output(0, X.sizes(), at::dtype<float>());
- float* Ydata = Y->template mutable_data<float>();
- // First, get scales
+ const float* X_data = X.data<float>();
+ float* Y_data = Y->mutable_data<float>();
+ if (N == 0) {
+ return true;
+ }
if (!scale_.defined()) {
scale_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
} else if (scale_.numel() != N) {
scale_.Resize(N);
}
-
- if (!rowmax_.defined()) {
- rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
- } else if (rowmax_.numel() != N) {
- rowmax_.Resize(N);
- }
-
- if (!sum_multiplier_.defined()) {
- sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
- math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
- } else if (sum_multiplier_.numel() != D) {
- sum_multiplier_.Resize(D);
- math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
- }
-
- SoftmaxCPU(
- context_,
- N,
- D,
- X.data<float>(),
- Ydata,
- scale_.mutable_data<float>(),
- sum_multiplier_.data<float>(),
- false,
- rowmax_.mutable_data<float>());
+ softmax_utils::SoftmaxCPU<float>(
+ N, D, false, X_data, Y_data, scale_.mutable_data<float>(), &context_);
return true;
}
@@ -65,10 +45,12 @@ bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
if (!sum_multiplier_.defined()) {
sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
- math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+ math::Set<float, CPUContext>(
+ D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
} else if (sum_multiplier_.numel() != D) {
sum_multiplier_.Resize(D);
- math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+ math::Set<float, CPUContext>(
+ D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
}
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
@@ -81,12 +63,21 @@ bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
context_.CopySameDevice<float>(Y.numel(), dYdata, dXdata);
float* scaledata = scale_.mutable_data<float>();
for (int i = 0; i < N; ++i) {
- math::Dot<float, CPUContext>(D, Ydata + i * D, dYdata + i * D,
- scaledata + i, &context_);
+ math::Dot<float, CPUContext>(
+ D, Ydata + i * D, dYdata + i * D, scaledata + i, &context_);
}
- math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans, N, D, 1, -1,
- scaledata, sum_multiplier_.data<float>(), 1,
- dXdata, &context_);
+ math::Gemm<float, CPUContext>(
+ CblasNoTrans,
+ CblasNoTrans,
+ N,
+ D,
+ 1,
+ -1,
+ scaledata,
+ sum_multiplier_.data<float>(),
+ 1,
+ dXdata,
+ &context_);
math::Mul<float, CPUContext>(Y.numel(), dXdata, Ydata, dXdata, &context_);
return true;
}
@@ -184,7 +175,8 @@ class GetSoftmaxGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
- def_.type() + "Gradient", "",
+ def_.type() + "Gradient",
+ "",
vector<string>{O(0), GO(0)},
vector<string>{GI(0)});
}
@@ -192,4 +184,4 @@ class GetSoftmaxGradient : public GradientMakerBase {
REGISTER_GRADIENT(Softmax, GetSoftmaxGradient);
REGISTER_GRADIENT(SoftmaxFp16, GetSoftmaxGradient);
-} // namespace caffe2
+} // namespace caffe2
diff --git a/caffe2/operators/softmax_op.h b/caffe2/operators/softmax_op.h
index cd081a18fb..d75a8ec65d 100644
--- a/caffe2/operators/softmax_op.h
+++ b/caffe2/operators/softmax_op.h
@@ -16,6 +16,7 @@ class SoftmaxOp final : public Operator<Context> {
: Operator<Context>(std::forward<Args>(args)...),
axis_(this->template GetSingleArgument<int>("axis", 1)) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
+
bool RunOnDevice() override;
protected:
diff --git a/caffe2/operators/softmax_shared.cc b/caffe2/operators/softmax_shared.cc
deleted file mode 100644
index c1b3761879..0000000000
--- a/caffe2/operators/softmax_shared.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-#include "caffe2/core/context.h"
-#include "caffe2/core/operator.h"
-#include "caffe2/utils/math.h"
-
-namespace caffe2 {
-
-void SoftmaxCPU(
- CPUContext& context,
- const int N,
- const int D,
- const float* Xdata,
- float* Ydata,
- float* scale,
- const float* sum_multiplier,
- bool logarithmic,
- float* rowmax) {
- math::RowwiseMax<float, CPUContext>(N, D, Xdata, rowmax, &context);
- // Put the intermediate result X - max(X) into Y
- context.template CopyFromCPU<float>(N * D, Xdata, Ydata);
- // Subtract the max (for numerical reasons)
- math::Gemm<float, CPUContext>(
- CblasNoTrans,
- CblasNoTrans,
- N,
- D,
- 1,
- -1,
- rowmax,
- sum_multiplier,
- 1,
- Ydata,
- &context);
- // Exponentiation
- math::Exp<float, CPUContext>(N * D, Ydata, Ydata, &context);
- math::Gemv<float, CPUContext>(
- CblasNoTrans, N, D, 1, Ydata, sum_multiplier, 0, scale, &context);
- // Do division
- // TODO(Yangqing): maybe implement it more beautifully?
- if (!logarithmic) {
- for (int i = 0; i < N; ++i) {
- for (int j = 0; j < D; ++j) {
- Ydata[i * D + j] /= scale[i];
- }
- }
- } else {
- for (int i = 0; i < N; ++i) {
- for (int j = 0; j < D; ++j) {
- Ydata[i * D + j] =
- Xdata[i * D + j] - rowmax[i] - log(fmaxf(scale[i], 1e-20f));
- }
- }
- }
-}
-
-} // namespace caffe2
diff --git a/caffe2/operators/softmax_shared.h b/caffe2/operators/softmax_shared.h
deleted file mode 100644
index 60c2bd0ab5..0000000000
--- a/caffe2/operators/softmax_shared.h
+++ /dev/null
@@ -1,21 +0,0 @@
-#ifndef CAFFE2_OPERATORS_SOFTMAX_SHARED_H_
-#define CAFFE2_OPERATORS_SOFTMAX_SHARED_H_
-
-#include "caffe2/core/context.h"
-#include "caffe2/core/operator.h"
-
-namespace caffe2 {
-
-void SoftmaxCPU(
- CPUContext& context,
- const int N,
- const int D,
- const float* Xdata,
- float* Ydata,
- float* scale,
- const float* sum_multiplier,
- bool logarithmic,
- float* rowmax);
-} // namespace caffe2
-
-#endif // #define CAFFE2_OPERATORS_SOFTMAX_SHARED_H_
diff --git a/caffe2/operators/softmax_utils.cc b/caffe2/operators/softmax_utils.cc
new file mode 100644
index 0000000000..98288e6e4b
--- /dev/null
+++ b/caffe2/operators/softmax_utils.cc
@@ -0,0 +1,38 @@
+#include "caffe2/operators/softmax_utils.h"
+
+#include "caffe2/core/context.h"
+#include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+namespace softmax_utils {
+
+#define CAFFE2_SPECIALIZED_SOFTMAX_CPU(T) \
+ template <> \
+ void SoftmaxCPU<T>( \
+ const int N, \
+ const int D, \
+ const bool logarithmic, \
+ const T* X, \
+ T* Y, \
+ T* scratch, \
+ CPUContext* context) { \
+ ConstEigenArrayMap<T> X_arr(X, D, N); \
+ EigenArrayMap<T> Y_arr(Y, D, N); \
+ EigenVectorArrayMap<T> scratch_arr(scratch, N); \
+ scratch_arr = X_arr.colwise().maxCoeff().transpose(); \
+ Y_arr = X_arr.rowwise() - scratch_arr.transpose(); \
+ math::Exp<T, CPUContext>(N * D, Y, Y, context); \
+ if (logarithmic) { \
+ scratch_arr += Y_arr.colwise().sum().log().transpose(); \
+ Y_arr = X_arr.rowwise() - scratch_arr.transpose(); \
+ } else { \
+ scratch_arr = Y_arr.colwise().sum().inverse().transpose(); \
+ Y_arr = Y_arr.rowwise() * scratch_arr.transpose(); \
+ } \
+ }
+CAFFE2_SPECIALIZED_SOFTMAX_CPU(float)
+#undef CAFFE2_SPECIALIZED_SOFTMAX_CPU
+
+} // namespace softmax_utils
+} // namespace caffe2
diff --git a/caffe2/operators/softmax_utils.h b/caffe2/operators/softmax_utils.h
new file mode 100644
index 0000000000..5b2d7cb323
--- /dev/null
+++ b/caffe2/operators/softmax_utils.h
@@ -0,0 +1,23 @@
+#ifndef CAFFE2_OPERATORS_SOFTMAX_UTILS_H_
+#define CAFFE2_OPERATORS_SOFTMAX_UTILS_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+namespace softmax_utils {
+
+template <typename T>
+void SoftmaxCPU(
+ int N,
+ int D,
+ bool logarithmic,
+ const T* X,
+ T* Y,
+ T* scratch,
+ CPUContext* context);
+
+} // namespace softmax_utils
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_SOFTMAX_UTILS_H_
diff --git a/caffe2/operators/softmax_with_loss_op.cc b/caffe2/operators/softmax_with_loss_op.cc
index 36a77408d4..f61560c85b 100644
--- a/caffe2/operators/softmax_with_loss_op.cc
+++ b/caffe2/operators/softmax_with_loss_op.cc
@@ -1,5 +1,8 @@
-#include "softmax_with_loss_op.h"
-#include "softmax_shared.h"
+#include "caffe2/operators/softmax_with_loss_op.h"
+
+#include <vector>
+
+#include "caffe2/operators/softmax_utils.h"
namespace caffe2 {
@@ -12,28 +15,28 @@ REGISTER_CPU_OPERATOR(
OPERATOR_SCHEMA(SoftmaxWithLoss)
.NumInputs(2, 3)
.NumOutputs(2)
- .TensorInferenceFunction(
- [](const OperatorDef& def, const vector<TensorShape>& in) {
- ArgumentHelper helper(def);
- auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
-
- vector<TensorShape> out(2);
-
- auto logits = in[0]; // Tensor with Shape [batch_size, num_classes]
- auto labels = in[1]; // Tensor with shape [batch_size, ]
- const auto canonical_axis =
- canonical_axis_index_(axis, logits.dims().size());
- const int batch_size =
- size_to_dim_(canonical_axis, GetDimsVector(logits));
- const int num_classes =
- size_from_dim_(canonical_axis, GetDimsVector(logits));
-
- out[0].set_data_type(logits.data_type());
- out[0].add_dims(batch_size);
- out[0].add_dims(num_classes);
-
- return out;
- })
+ .TensorInferenceFunction([](const OperatorDef& def,
+ const vector<TensorShape>& in) {
+ ArgumentHelper helper(def);
+ auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
+
+ vector<TensorShape> out(2);
+
+ auto logits = in[0]; // Tensor with Shape [batch_size, num_classes]
+ auto labels = in[1]; // Tensor with shape [batch_size, ]
+ const auto canonical_axis =
+ canonical_axis_index_(axis, logits.dims().size());
+ const int batch_size =
+ size_to_dim_(canonical_axis, GetDimsVector(logits));
+ const int num_classes =
+ size_from_dim_(canonical_axis, GetDimsVector(logits));
+
+ out[0].set_data_type(logits.data_type());
+ out[0].add_dims(batch_size);
+ out[0].add_dims(num_classes);
+
+ return out;
+ })
.SetDoc(R"DOC(
Combined Softmax and Cross-Entropy loss operator. The operator first computes the softmax normalized values for each layer in the batch of the given input, then computes cross-entropy loss. This operator is numerically more stable than separate `Softmax` and `CrossEntropy` ops. The inputs are a 2-D tensor `logits` of size (batch_size x input_feature_dimensions), which represents the unscaled log probabilities, and a 1-dimensional integer `labels` tensor for ground truth. An optional third input blob (`weight_tensor`) can be used to weight the samples for the loss, which is useful if the training set is unbalanced. This operator outputs a `softmax` tensor which contains the probability for each label for each example (same shape is `logits` input), and a scalar `loss` value, which is the averaged cross-entropy loss between the softmax probabilities and the ground truth values. Use parameter `label_prob`=1 to enable inputting labels as a probability distribution.
@@ -132,10 +135,18 @@ avgloss: 10.667433
</details>
)DOC")
- .Arg("label_prob","*(type: int; default: 0)* Setting to 1 enables inputting labels as probability distribution.")
- .Arg("axis","*(type: int; default: 1)* Axis of the inputs when coerced to 2D.")
- .Arg("scale","*(type: float)* Average loss output scaling factor (must be >= 0).")
- .Arg("order","*(type: string; default: 'NCHW')* Order of blob dimensions (only 'NCHW' is supported currently).")
+ .Arg(
+ "label_prob",
+ "*(type: int; default: 0)* Setting to 1 enables inputting labels as probability distribution.")
+ .Arg(
+ "axis",
+ "*(type: int; default: 1)* Axis of the inputs when coerced to 2D.")
+ .Arg(
+ "scale",
+ "*(type: float)* Average loss output scaling factor (must be >= 0).")
+ .Arg(
+ "order",
+ "*(type: string; default: 'NCHW')* Order of blob dimensions (only 'NCHW' is supported currently).")
.Input(0, "logits", "*(type: Tensor`<float>`)* Input tensor.")
.Input(1, "labels", "*(type: Tensor`<float>`)* Ground truth label tensor.")
.Input(
@@ -178,36 +189,20 @@ bool SoftmaxWithLossOp<float, CPUContext>::RunOnDevice() {
}
}
- if (!sum_multiplier_.defined()) {
- sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
- math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
- } else if (sum_multiplier_.numel() != D) {
- sum_multiplier_.Resize(D);
- math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
- }
-
if (!losses_.defined()) {
losses_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
} else if (losses_.numel() != N) {
losses_.Resize(N);
}
- if (!rowmax_.defined()) {
- rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
- } else if (rowmax_.numel() != N) {
- rowmax_.Resize(N);
- }
-
- SoftmaxCPU(
- context_,
+ softmax_utils::SoftmaxCPU<float>(
N,
D,
+ !label_prob_mode_,
X.data<float>(),
Pdata,
losses_.mutable_data<float>(),
- sum_multiplier_.data<float>(),
- !label_prob_mode_,
- rowmax_.mutable_data<float>());
+ &context_);
// Then compute cross entropy
float loss_sum = 0.0;
@@ -382,5 +377,5 @@ class GetSoftmaxWithLossGradient : public GradientMakerBase {
};
REGISTER_GRADIENT(SoftmaxWithLoss, GetSoftmaxWithLossGradient);
-}
+} // namespace
} // namespace caffe2
diff --git a/caffe2/operators/spatial_softmax_with_loss_op.cc b/caffe2/operators/spatial_softmax_with_loss_op.cc
index 09464b0e05..d345fe175b 100644
--- a/caffe2/operators/spatial_softmax_with_loss_op.cc
+++ b/caffe2/operators/spatial_softmax_with_loss_op.cc
@@ -1,5 +1,4 @@
-#include "spatial_softmax_with_loss_op.h"
-#include "softmax_shared.h"
+#include "caffe2/operators/spatial_softmax_with_loss_op.h"
namespace caffe2 {