author     Richard Zou <zou3519@users.noreply.github.com>   2017-12-18 02:39:49 -0500
committer  Soumith Chintala <soumith@gmail.com>             2017-12-18 02:39:49 -0500
commit     847c56aeb5857fc4d3f5df88b9e8f937939bb8cc (patch)
tree       31bb9849fef7f8b6294bf4425cbe835ebf0262e6
parent     b86dc0c8ba3d4a6feb16863c04b931f43efd1ad8 (diff)
Add reduce arg to BCELoss (#3532)
* Add reduce arg to BCELoss

* Fix test precision
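For orientation, a minimal usage sketch of the new flag (assuming a build that includes this commit; the 0.3-era Variable API is used):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    input = Variable(torch.rand(3, 2).clamp_(1e-2, 1 - 1e-2), requires_grad=True)
    target = Variable(torch.rand(3, 2).gt(0.5).float())

    # Default (reduce=True): a single 1-element loss, averaged since size_average=True.
    scalar_loss = F.binary_cross_entropy(input, target)

    # reduce=False: one loss per element, same shape as input; size_average is ignored.
    elementwise = F.binary_cross_entropy(input, target, reduce=False)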
-rw-r--r--  aten/src/ATen/nn.yaml                     2
-rw-r--r--  aten/src/THCUNN/BCECriterion.cu          33
-rw-r--r--  aten/src/THCUNN/generic/BCECriterion.cu  40
-rw-r--r--  aten/src/THCUNN/generic/THCUNN.h          7
-rw-r--r--  aten/src/THNN/generic/BCECriterion.c     71
-rw-r--r--  aten/src/THNN/generic/THNN.h              7
-rw-r--r--  test/common_nn.py                         4
-rw-r--r--  test/test_nn.py                          27
-rw-r--r--  torch/nn/functional.py                    8
-rw-r--r--  torch/nn/modules/loss.py                 13
10 files changed, 185 insertions, 27 deletions
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index 3df33ee9be..08d87b1fbd 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -1,6 +1,6 @@
# Loss functions
-- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight={}, bool size_average=true)
+- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight={}, bool size_average=true, bool reduce=true)
cname: BCECriterion
- name: kl_div(Tensor self, Tensor target, bool size_average=true, bool reduce=true)
diff --git a/aten/src/THCUNN/BCECriterion.cu b/aten/src/THCUNN/BCECriterion.cu
index 04218dcc4a..ccb40008c3 100644
--- a/aten/src/THCUNN/BCECriterion.cu
+++ b/aten/src/THCUNN/BCECriterion.cu
@@ -33,6 +33,22 @@ struct bce_functor
};
template <typename Dtype, typename Acctype>
+struct bce_updateOutput_no_reduce_functor
+{
+ __forceinline__ __host__ __device__
+ void operator()(
+ const Dtype *input,
+ const Dtype *target,
+ Dtype *output)
+ {
+ assert(*input >= 0. && *input <= 1.);
+ *output = ScalarConvert<Acctype, Dtype>::to(
+ -(*target * THCNumerics<Acctype>::log(*input + eps<Acctype>()) +
+ (Acctype(1) - *target) * THCNumerics<Acctype>::log(Acctype(1) - *input + eps<Acctype>())));
+ }
+};
+
+template <typename Dtype, typename Acctype>
struct bce_functor_weights
{
template <class Tuple>
@@ -43,7 +59,22 @@ struct bce_functor_weights
Dtype t = thrust::get<1>(x);
Dtype w = thrust::get<2>(x);
assert(input >= 0. && input <= 1.);
- return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
+ return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) +
+ (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
+ }
+};
+
+template <typename Dtype, typename Acctype>
+struct bce_updateGradInput_no_reduce_functor
+{
+ __forceinline__ __host__ __device__
+ void operator()(
+ const Dtype *x,
+ const Dtype *t,
+ Dtype *gradInput)
+ {
+ *gradInput = ScalarConvert<Acctype,Dtype>::to(
+ - (*t - *x) / ((Acctype(1) - *x + eps<Acctype>()) * (*x + eps<Acctype>())));
}
};
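The two new functors are the element-wise value and input-gradient of BCE. For reference, the same math written out in NumPy (a sketch; EPS here matches the CPU path's 1e-12, while the CUDA side uses eps<Acctype>(), which may differ):

    import numpy as np

    EPS = 1e-12

    def bce(x, y):
        # mirrors bce_updateOutput_no_reduce_functor
        return -(y * np.log(x + EPS) + (1. - y) * np.log(1. - x + EPS))

    def bce_grad(x, y):
        # mirrors bce_updateGradInput_no_reduce_functor: d bce / d x
        return -(y - x) / ((1. - x + EPS) * (x + EPS))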
diff --git a/aten/src/THCUNN/generic/BCECriterion.cu b/aten/src/THCUNN/generic/BCECriterion.cu
index 4d9988b757..e98f1b05db 100644
--- a/aten/src/THCUNN/generic/BCECriterion.cu
+++ b/aten/src/THCUNN/generic/BCECriterion.cu
@@ -2,19 +2,32 @@
#define THC_GENERIC_FILE "generic/BCECriterion.cu"
#else
+#include "THCApply.cuh"
+
void THNN_(BCECriterion_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *target,
THCTensor *output,
bool sizeAverage,
- THCTensor *weights)
+ THCTensor *weights,
+ bool reduce)
{
THCUNN_check_nElement(state, input, target);
THCUNN_check_nElement(state, input, weights);
- THCTensor_(resize1d)(state, output, 1);
THCUNN_assertSameGPU(state, 3, input, target, weights);
+ if (!reduce) {
+ THCTensor_(resizeAs)(state, output, input);
+ THC_pointwiseApply3(state, input, target, output,
+ bce_updateOutput_no_reduce_functor<real, accreal>());
+ if (weights) {
+ THCTensor_(cmul)(state, output, output, weights);
+ }
+ return;
+ }
+
+ THCTensor_(resize1d)(state, output, 1);
ptrdiff_t size = THCTensor_(nElement)(state, input);
input = THCTensor_(newContiguous)(state, input);
@@ -58,22 +71,37 @@ void THNN_(BCECriterion_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *target,
+ THCTensor *gradOutput,
THCTensor *gradInput,
bool sizeAverage,
- THCTensor *weights)
+ THCTensor *weights,
+ bool reduce)
{
THCUNN_check_nElement(state, input, target);
THCUNN_check_nElement(state, input, weights);
THCUNN_assertSameGPU(state, 4, input, target, gradInput, weights);
+ THCTensor_(resizeAs)(state, gradInput, input);
+
+ if (!reduce) {
+ THCUNN_check_nElement(state, gradOutput, input);
+ THC_pointwiseApply3(state, input, target, gradInput,
+ bce_updateGradInput_no_reduce_functor<real, accreal>());
+ THCTensor_(cmul)(state, gradInput, gradInput, gradOutput);
+ if (weights) {
+ THCTensor_(cmul)(state, gradInput, gradInput, weights);
+ }
+ return;
+ }
+
+ THCUNN_check_dim_size(state, gradOutput, 1, 0, 1);
+
ptrdiff_t size = THCTensor_(nElement)(state, input);
- real norm = ScalarConvert<accreal, real>::to(sizeAverage ? accreal(1)/size : accreal(1));
+ real norm = ScalarConvert<accreal, real>::to((sizeAverage ? accreal(1)/size : accreal(1)) * THCTensor_(get1d)(state, gradOutput, 0));
input = THCTensor_(newContiguous)(state, input);
target = THCTensor_(newContiguous)(state, target);
- THCTensor_(resizeAs)(state, gradInput, input);
-
thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
thrust::device_ptr<real> gradInput_data(THCTensor_(data)(state, gradInput));
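Note the changed norm: in the reduce path the loss is loss = norm * sum_i l_i with norm = 1/N under sizeAverage (else 1), so the chain rule gives d loss / d x_i = gradOutput[0] * norm * d l_i / d x_i, which is exactly the scalar folded into norm above. A quick check that input gradients scale linearly with gradOutput (a sketch, assuming a build with this patch):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    def grad_wrt_input(grad_out):
        i = Variable(torch.DoubleTensor([[0.3, 0.6]]), requires_grad=True)
        t = Variable(torch.DoubleTensor([[1.0, 0.0]]))
        loss = F.binary_cross_entropy(i, t)           # reduced: 1-element output
        loss.backward(torch.DoubleTensor([grad_out]))
        return i.grad.data.clone()

    # doubling gradOutput doubles every input gradient
    assert (grad_wrt_input(2.0) - 2 * grad_wrt_input(1.0)).abs().max() < 1e-12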
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index e86583bde4..7f5002f484 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -66,15 +66,18 @@ TH_API void THNN_(BCECriterion_updateOutput)(
THCTensor *target,
THCTensor *output,
bool sizeAverage,
- THCTensor *weights); // [OPTIONAL]
+ THCTensor *weights, // [OPTIONAL]
+ bool reduce);
TH_API void THNN_(BCECriterion_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *target,
+ THCTensor *gradOutput,
THCTensor *gradInput,
bool sizeAverage,
- THCTensor *weights); // [OPTIONAL]
+ THCTensor *weights, // [OPTIONAL]
+ bool reduce);
TH_API void THNN_(ClassNLLCriterion_updateOutput)(
THCState *state,
diff --git a/aten/src/THNN/generic/BCECriterion.c b/aten/src/THNN/generic/BCECriterion.c
index b668370336..1f69c315e8 100644
--- a/aten/src/THNN/generic/BCECriterion.c
+++ b/aten/src/THNN/generic/BCECriterion.c
@@ -4,16 +4,38 @@
#define EPS 1e-12
-void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
- THTensor *target, THTensor *output,
- bool sizeAverage, THTensor *weights)
+void THNN_(BCECriterion_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ THTensor *output,
+ bool sizeAverage,
+ THTensor *weights,
+ bool reduce)
{
THNN_CHECK_NELEMENT(input, target);
THNN_CHECK_NELEMENT(input, weights);
+
+ if (!reduce) {
+ THTensor_(resizeAs)(output, input);
+ TH_TENSOR_APPLY3(real, input, real, target, real, output,
+ real x = *input_data;
+ real y = *target_data;
+ THAssertMsg(x >= 0. && x <= 1.,
+ "input value should be between 0~1, but got %f",
+ (double) x);
+ *output_data = -(log(x + EPS) * y + log(1. - x + EPS) * (1. - y));
+ );
+ if (weights) {
+ THTensor_(cmul)(output, output, weights);
+ }
+ return;
+ }
+
THTensor_(resize1d)(output, 1);
real sum = 0;
- if(weights)
+ if (weights) {
TH_TENSOR_APPLY3(real, input, real, target, real, weights,
real x = *input_data;
real y = *target_data;
@@ -22,8 +44,8 @@ void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
"input value should be between 0~1, but got %f",
(double) x);
sum -= (log(x + EPS) * y + log(1. - x + EPS) * (1. - y)) * w;
- )
- else
+ );
+ } else {
TH_TENSOR_APPLY2(real, input, real, target,
real x = *input_data;
real y = *target_data;
@@ -32,6 +54,7 @@ void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
(double) x);
sum -= log(x + EPS) * y + log(1. - x + EPS) * (1. - y);
);
+ }
if (sizeAverage)
@@ -40,21 +63,45 @@ void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
THTensor_(set1d)(output, 0, sum);
}
-void THNN_(BCECriterion_updateGradInput)(THNNState *state, THTensor *input,
- THTensor *target, THTensor *gradInput,
- bool sizeAverage, THTensor *weights)
+void THNN_(BCECriterion_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ bool sizeAverage,
+ THTensor *weights,
+ bool reduce)
{
THNN_CHECK_NELEMENT(input, target);
THNN_CHECK_NELEMENT(input, weights);
+ THTensor_(resizeAs)(gradInput, input);
- real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
+ if (!reduce) {
+ THNN_CHECK_NELEMENT(gradOutput, input);
+ TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
+ real x = *input_data;
+ real y = *target_data;
+ *gradInput_data = -(y - x) / ((1. - x + EPS) * (x + EPS));
+ );
- THTensor_(resizeAs)(gradInput, input);
+ if (weights) {
+ TH_TENSOR_APPLY3(real, gradInput, real, weights, real, gradOutput,
+ *gradInput_data = *gradInput_data * *weights_data * *gradOutput_data;
+ );
+ } else {
+ THTensor_(cmul)(gradInput, gradInput, gradOutput);
+ }
+ return;
+ }
+
+ THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, 1);
+ real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
real x = *input_data;
real y = *target_data;
- *gradInput_data = - norm * (y - x) / ((1. - x + EPS) * (x + EPS));
+ *gradInput_data = - norm * (y - x) / ((1. - x + EPS) * (x + EPS)) * THTensor_fastGet1d(gradOutput, 0);
);
if(weights)
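The closed-form gradient used in both branches, -(y - x) / ((1 - x + EPS) * (x + EPS)), can be sanity-checked against finite differences (a standalone sketch, not part of the patch):

    import numpy as np

    EPS = 1e-12
    bce = lambda x, y: -(y * np.log(x + EPS) + (1. - y) * np.log(1. - x + EPS))

    x, y, h = 0.3, 1.0, 1e-6
    numeric = (bce(x + h, y) - bce(x - h, y)) / (2 * h)
    analytic = -(y - x) / ((1. - x + EPS) * (x + EPS))
    assert abs(numeric - analytic) < 1e-4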
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index 32fe59ae3b..93436dfec1 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -34,14 +34,17 @@ TH_API void THNN_(BCECriterion_updateOutput)(
THTensor *target,
THTensor *output,
bool sizeAverage,
- THTensor *weights); // [OPTIONAL]
+ THTensor *weights, // [OPTIONAL]
+ bool reduce);
TH_API void THNN_(BCECriterion_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *target,
+ THTensor *gradOutput,
THTensor *gradInput,
bool sizeAverage,
- THTensor *weights); // [OPTIONAL]
+ THTensor *weights, // [OPTIONAL]
+ bool reduce);
TH_API void THNN_(ClassNLLCriterion_updateOutput)(
THNNState *state, // library's state
diff --git a/test/common_nn.py b/test/common_nn.py
index c0a2e1b13f..3b926ad7c0 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -400,6 +400,8 @@ criterion_tests = [
module_name='BCELoss',
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch.randn(15, 10).gt(0).double(),
+ reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
+ (i.numel() if get_size_average(m) else 1),
check_gradgrad=False,
),
dict(
@@ -407,6 +409,8 @@ criterion_tests = [
constructor_args_fn=lambda: (torch.rand(10),),
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch.randn(15, 10).gt(0).double(),
+ reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
+ (i.numel() if get_size_average(m) else 1),
desc='weights',
check_gradgrad=False,
),
diff --git a/test/test_nn.py b/test/test_nn.py
index ef65b18457..df948204da 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3952,6 +3952,31 @@ new_criterion_tests = [
]
+def bceloss_no_reduce_test():
+ t = torch.randn(15, 10).gt(0).double()
+ return dict(
+ fullname='BCELoss_no_reduce',
+ constructor=wrap_functional(
+ lambda i: F.binary_cross_entropy(i, Variable(t.type_as(i.data)), reduce=False)),
+ input_fn=lambda: torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2),
+ reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
+ check_gradgrad=False,
+ pickle=False)
+
+
+def bceloss_weights_no_reduce_test():
+ t = torch.randn(15, 10).gt(0).double()
+ weights = torch.rand(10)
+ return dict(
+ fullname='BCELoss_weights_no_reduce',
+ constructor=wrap_functional(
+ lambda i: F.binary_cross_entropy(i, Variable(t.type_as(i.data)),
+ weight=weights.type_as(i.data), reduce=False)),
+ input_fn=lambda: torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2),
+ reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+ check_gradgrad=False,
+ pickle=False)
+
+
def poissonnllloss_no_reduce_test():
t = Variable(torch.randn(10, 10))
return dict(
@@ -4175,6 +4200,8 @@ def smoothl1loss_no_reduce_test():
new_module_tests = [
+ bceloss_no_reduce_test(),
+ bceloss_weights_no_reduce_test(),
poissonnllloss_no_reduce_test(),
kldivloss_no_reduce_test(),
l1loss_no_reduce_test(),
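The new entries check that the unreduced loss matches the textbook formula element by element; distilled (a sketch of the property being tested, not the harness itself):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    i = Variable(torch.rand(15, 10).double().clamp_(2e-2, 1 - 2e-2))
    t = Variable(torch.randn(15, 10).gt(0).double())

    out = F.binary_cross_entropy(i, t, reduce=False)
    ref = -(t * i.log() + (1 - t) * (1 - i).log())
    assert (out - ref).abs().max().data[0] < 1e-8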
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a9b5a241e9..8ba929243e 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1212,7 +1212,7 @@ def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-1
return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index, reduce)
-def binary_cross_entropy(input, target, weight=None, size_average=True):
+def binary_cross_entropy(input, target, weight=None, size_average=True, reduce=True):
r"""Function that measures the Binary Cross Entropy
between the target and the output.
@@ -1227,6 +1227,10 @@ def binary_cross_entropy(input, target, weight=None, size_average=True):
over observations for each minibatch. However, if the field
sizeAverage is set to False, the losses are instead summed
for each minibatch. Default: ``True``
+ reduce (bool, optional): By default, the losses are averaged or summed over
+ observations for each minibatch depending on size_average. When reduce
+ is False, returns a loss per batch element instead and ignores
+ size_average. Default: True
Examples::
@@ -1248,7 +1252,7 @@ def binary_cross_entropy(input, target, weight=None, size_average=True):
if torch.is_tensor(weight):
weight = Variable(weight)
- return torch._C._nn.binary_cross_entropy(input, target, weight, size_average)
+ return torch._C._nn.binary_cross_entropy(input, target, weight, size_average, reduce)
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index b70a5ce08f..4b3620bf71 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -366,11 +366,17 @@ class BCELoss(_WeightedLoss):
over observations for each minibatch. However, if the field
size_average is set to ``False``, the losses are instead summed for
each minibatch. Default: ``True``
+ reduce (bool, optional): By default, the losses are averaged or summed over
+ observations for each minibatch depending on size_average. When reduce
+ is False, returns a loss per batch element instead and ignores
+ size_average. Default: True
Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
- Target: :math:`(N, *)`, same shape as the input
+ - Output: scalar. If `reduce` is False, then `(N, *)`, same shape as
+ input.
Examples::
@@ -381,10 +387,15 @@ class BCELoss(_WeightedLoss):
>>> output = loss(m(input), target)
>>> output.backward()
"""
+ def __init__(self, weight=None, size_average=True, reduce=True):
+ super(BCELoss, self).__init__(weight, size_average)
+ self.reduce = reduce
+
def forward(self, input, target):
_assert_no_grad(target)
return F.binary_cross_entropy(input, target, weight=self.weight,
- size_average=self.size_average)
+ size_average=self.size_average,
+ reduce=self.reduce)
class BCEWithLogitsLoss(Module):
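The module-level counterpart mirrors the functional form (a minimal sketch using the constructor added above; note that a non-scalar output needs an explicit gradient in backward):

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    loss_fn = nn.BCELoss(reduce=False)
    m = nn.Sigmoid()
    input = Variable(torch.randn(3, 2), requires_grad=True)
    target = Variable(torch.rand(3, 2).gt(0.5).float())

    per_element = loss_fn(m(input), target)   # shape (3, 2), same as input
    per_element.backward(torch.ones(3, 2))    # supply gradOutput explicitly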