diff options
author | Priya Goyal <prigoyal@fb.com> | 2019-04-23 11:41:44 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-23 11:49:49 -0700 |
commit | 0d0acba3bd4034cfb1e01aa74c55930b53deb8d4 (patch) | |
tree | d2465a979f355f673494a1eb13e1fbc9cbf76424 | |
parent | e9c8f372c49431907ac525a2abbfe212e549f61e (diff) | |
download | pytorch-0d0acba3bd4034cfb1e01aa74c55930b53deb8d4.tar.gz pytorch-0d0acba3bd4034cfb1e01aa74c55930b53deb8d4.tar.bz2 pytorch-0d0acba3bd4034cfb1e01aa74c55930b53deb8d4.zip |
Allow extracting element-wise loss in softmax (#19579)
Summary:
Often times, we want to experiment with loss per element (image etc.). This changeset allows getting per element loss as well. This output is optional.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19579
Reviewed By: jerryzh168
Differential Revision: D15035797
Pulled By: prigoyal
fbshipit-source-id: 562dea514f49c1f2f1cbbc083a1938dc019a75c4
-rw-r--r-- | caffe2/operators/softmax_ops.cu | 4 | ||||
-rw-r--r-- | caffe2/operators/softmax_with_loss_op.cc | 2 |
2 files changed, 4 insertions, 2 deletions
diff --git a/caffe2/operators/softmax_ops.cu b/caffe2/operators/softmax_ops.cu index df943aee03..2835d78ed3 100644 --- a/caffe2/operators/softmax_ops.cu +++ b/caffe2/operators/softmax_ops.cu @@ -394,7 +394,9 @@ bool SoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() { math::Scale<float, float, CUDAContext>( 1, scale_ / total_weight, avg_loss_data, avg_loss_data, &context_); } - + if (OutputSize() > 2) { + OutputTensorAlias(2, losses_); + } return true; } diff --git a/caffe2/operators/softmax_with_loss_op.cc b/caffe2/operators/softmax_with_loss_op.cc index f61560c85b..405a50eb8a 100644 --- a/caffe2/operators/softmax_with_loss_op.cc +++ b/caffe2/operators/softmax_with_loss_op.cc @@ -14,7 +14,7 @@ REGISTER_CPU_OPERATOR( // Input: X (logits), T (labels); Output: P (probs), Y OPERATOR_SCHEMA(SoftmaxWithLoss) .NumInputs(2, 3) - .NumOutputs(2) + .NumOutputs({2, 3}) .TensorInferenceFunction([](const OperatorDef& def, const vector<TensorShape>& in) { ArgumentHelper helper(def); |