summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPriya Goyal <prigoyal@fb.com>2019-04-23 11:41:44 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-23 11:49:49 -0700
commit0d0acba3bd4034cfb1e01aa74c55930b53deb8d4 (patch)
treed2465a979f355f673494a1eb13e1fbc9cbf76424
parente9c8f372c49431907ac525a2abbfe212e549f61e (diff)
downloadpytorch-0d0acba3bd4034cfb1e01aa74c55930b53deb8d4.tar.gz
pytorch-0d0acba3bd4034cfb1e01aa74c55930b53deb8d4.tar.bz2
pytorch-0d0acba3bd4034cfb1e01aa74c55930b53deb8d4.zip
Allow extracting element-wise loss in softmax (#19579)
Summary: Often, we want to experiment with the loss per element (per image, etc.). This changeset allows retrieving the per-element loss as well, via an optional extra output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19579 Reviewed By: jerryzh168 Differential Revision: D15035797 Pulled By: prigoyal fbshipit-source-id: 562dea514f49c1f2f1cbbc083a1938dc019a75c4
-rw-r--r--caffe2/operators/softmax_ops.cu4
-rw-r--r--caffe2/operators/softmax_with_loss_op.cc2
2 files changed, 4 insertions, 2 deletions
diff --git a/caffe2/operators/softmax_ops.cu b/caffe2/operators/softmax_ops.cu
index df943aee03..2835d78ed3 100644
--- a/caffe2/operators/softmax_ops.cu
+++ b/caffe2/operators/softmax_ops.cu
@@ -394,7 +394,9 @@ bool SoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
math::Scale<float, float, CUDAContext>(
1, scale_ / total_weight, avg_loss_data, avg_loss_data, &context_);
}
-
+ if (OutputSize() > 2) {
+ OutputTensorAlias(2, losses_);
+ }
return true;
}
diff --git a/caffe2/operators/softmax_with_loss_op.cc b/caffe2/operators/softmax_with_loss_op.cc
index f61560c85b..405a50eb8a 100644
--- a/caffe2/operators/softmax_with_loss_op.cc
+++ b/caffe2/operators/softmax_with_loss_op.cc
@@ -14,7 +14,7 @@ REGISTER_CPU_OPERATOR(
// Input: X (logits), T (labels); Output: P (probs), Y
OPERATOR_SCHEMA(SoftmaxWithLoss)
.NumInputs(2, 3)
- .NumOutputs(2)
+ .NumOutputs({2, 3})
.TensorInferenceFunction([](const OperatorDef& def,
const vector<TensorShape>& in) {
ArgumentHelper helper(def);