From 847c56aeb5857fc4d3f5df88b9e8f937939bb8cc Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 18 Dec 2017 02:39:49 -0500
Subject: Add reduce arg to BCELoss (#3532)

* Add reduce arg to BCELoss

* Fix test precision
---
 aten/src/ATen/nn.yaml                   |  2 +-
 aten/src/THCUNN/BCECriterion.cu         | 33 +++++++++++++++-
 aten/src/THCUNN/generic/BCECriterion.cu | 40 ++++++++++++++++---
 aten/src/THCUNN/generic/THCUNN.h        |  7 +++-
 aten/src/THNN/generic/BCECriterion.c    | 71 +++++++++++++++++++++++++++------
 aten/src/THNN/generic/THNN.h            |  7 +++-
 test/common_nn.py                       |  4 ++
 test/test_nn.py                         | 27 +++++++++++++
 torch/nn/functional.py                  |  8 +++-
 torch/nn/modules/loss.py                | 13 +++++-
 10 files changed, 185 insertions(+), 27 deletions(-)

diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index 3df33ee9be..08d87b1fbd 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -1,6 +1,6 @@
 # Loss functions
 
-- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight={}, bool size_average=true)
+- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight={}, bool size_average=true, bool reduce=true)
   cname: BCECriterion
 
 - name: kl_div(Tensor self, Tensor target, bool size_average=true, bool reduce=true)
diff --git a/aten/src/THCUNN/BCECriterion.cu b/aten/src/THCUNN/BCECriterion.cu
index 04218dcc4a..ccb40008c3 100644
--- a/aten/src/THCUNN/BCECriterion.cu
+++ b/aten/src/THCUNN/BCECriterion.cu
@@ -32,6 +32,22 @@ struct bce_functor
   }
 };
 
+template <typename Dtype, typename Acctype>
+struct bce_updateOutput_no_reduce_functor
+{
+  __forceinline__ __host__ __device__
+  void operator()(
+      const Dtype *input,
+      const Dtype *target,
+      Dtype *output)
+  {
+    assert(*input >= 0. && *input <= 1.);
+    *output = ScalarConvert<Acctype, Dtype>::to(
+        -(*target * THCNumerics<Acctype>::log(*input + eps<Acctype>()) +
+          (Acctype(1) - *target) * THCNumerics<Acctype>::log(Acctype(1) - *input + eps<Acctype>())));
+  }
+};
+
 template <typename Dtype, typename Acctype>
 struct bce_functor_weights
 {
@@ -43,7 +59,22 @@ struct bce_functor_weights
     Dtype t = thrust::get<1>(x);
     Dtype w = thrust::get<2>(x);
     assert(input >= 0. && input <= 1.);
-    return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
+    return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) +
+        (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
   }
 };
+
+template <typename Dtype, typename Acctype>
+struct bce_updateGradInput_no_reduce_functor
+{
+  __forceinline__ __host__ __device__
+  void operator()(
+      const Dtype *x,
+      const Dtype *t,
+      Dtype *gradInput)
+  {
+    *gradInput = ScalarConvert<Acctype, Dtype>::to(
+        - (*t - *x) / ((Acctype(1) - *x + eps<Acctype>()) * (*x + eps<Acctype>())));
+  }
+};
 
diff --git a/aten/src/THCUNN/generic/BCECriterion.cu b/aten/src/THCUNN/generic/BCECriterion.cu
index 4d9988b757..e98f1b05db 100644
--- a/aten/src/THCUNN/generic/BCECriterion.cu
+++ b/aten/src/THCUNN/generic/BCECriterion.cu
@@ -2,19 +2,32 @@
 #define THC_GENERIC_FILE "generic/BCECriterion.cu"
 #else
 
+#include "THCApply.cuh"
+
 void THNN_(BCECriterion_updateOutput)(
           THCState *state,
           THCTensor *input,
          THCTensor *target,
           THCTensor *output,
           bool sizeAverage,
-          THCTensor *weights)
+          THCTensor *weights,
+          bool reduce)
 {
   THCUNN_check_nElement(state, input, target);
   THCUNN_check_nElement(state, input, weights);
-  THCTensor_(resize1d)(state, output, 1);
   THCUNN_assertSameGPU(state, 3, input, target, weights);
 
+  if (!reduce) {
+    THCTensor_(resizeAs)(state, output, input);
+    THC_pointwiseApply3(state, input, target, output,
+      bce_updateOutput_no_reduce_functor<real, accreal>());
+    if (weights) {
+      THCTensor_(cmul)(state, output, output, weights);
+    }
+    return;
+  }
+
+  THCTensor_(resize1d)(state, output, 1);
 
   ptrdiff_t size = THCTensor_(nElement)(state, input);
 
   input = THCTensor_(newContiguous)(state, input);
@@ -58,22 +71,37 @@ void THNN_(BCECriterion_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *target,
+          THCTensor *gradOutput,
           THCTensor *gradInput,
           bool sizeAverage,
-          THCTensor *weights)
+          THCTensor *weights,
+          bool reduce)
 {
   THCUNN_check_nElement(state, input, target);
   THCUNN_check_nElement(state, input, weights);
   THCUNN_assertSameGPU(state, 4, input, target, gradInput, weights);
 
+  THCTensor_(resizeAs)(state, gradInput, input);
+
+  if (!reduce) {
+    THCUNN_check_nElement(state, gradOutput, input);
+    THC_pointwiseApply3(state, input, target, gradInput,
+      bce_updateGradInput_no_reduce_functor<real, accreal>());
+    THCTensor_(cmul)(state, gradInput, gradInput, gradOutput);
+    if (weights) {
+      THCTensor_(cmul)(state, gradInput, gradInput, weights);
+    }
+    return;
+  }
+
+  THCUNN_check_dim_size(state, gradOutput, 1, 0, 1);
+
   ptrdiff_t size = THCTensor_(nElement)(state, input);
-  real norm = ScalarConvert<accreal, real>::to(sizeAverage ? accreal(1)/size : accreal(1));
+  real norm = ScalarConvert<accreal, real>::to((sizeAverage ? accreal(1)/size : accreal(1)) * THCTensor_(get1d)(state, gradOutput, 0));
 
   input = THCTensor_(newContiguous)(state, input);
   target = THCTensor_(newContiguous)(state, target);
 
-  THCTensor_(resizeAs)(state, gradInput, input);
-
   thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
   thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
   thrust::device_ptr<real> gradInput_data(THCTensor_(data)(state, gradInput));
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index e86583bde4..7f5002f484 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -66,15 +66,18 @@ TH_API void THNN_(BCECriterion_updateOutput)(
           THCTensor *target,
           THCTensor *output,
           bool sizeAverage,
-          THCTensor *weights);        // [OPTIONAL]
+          THCTensor *weights,         // [OPTIONAL]
+          bool reduce);
 
 TH_API void THNN_(BCECriterion_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *target,
+          THCTensor *gradOutput,
           THCTensor *gradInput,
           bool sizeAverage,
-          THCTensor *weights);        // [OPTIONAL]
+          THCTensor *weights,         // [OPTIONAL]
+          bool reduce);
 
 TH_API void THNN_(ClassNLLCriterion_updateOutput)(
           THCState *state,
diff --git a/aten/src/THNN/generic/BCECriterion.c b/aten/src/THNN/generic/BCECriterion.c
index b668370336..1f69c315e8 100644
--- a/aten/src/THNN/generic/BCECriterion.c
+++ b/aten/src/THNN/generic/BCECriterion.c
@@ -4,16 +4,38 @@
 #define EPS 1e-12
 
-void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
-                                      THTensor *target, THTensor *output,
-                                      bool sizeAverage, THTensor *weights)
+void THNN_(BCECriterion_updateOutput)(
+    THNNState *state,
+    THTensor *input,
+    THTensor *target,
+    THTensor *output,
+    bool sizeAverage,
+    THTensor *weights,
+    bool reduce)
 {
   THNN_CHECK_NELEMENT(input, target);
   THNN_CHECK_NELEMENT(input, weights);
+
+  if (!reduce) {
+    THTensor_(resizeAs)(output, input);
+    TH_TENSOR_APPLY3(real, input, real, target, real, output,
+      real x = *input_data;
+      real y = *target_data;
+      THAssertMsg(x >= 0. && x <= 1.,
+        "input value should be between 0~1, but got %f",
+        (double) x);
+      *output_data = -(log(x + EPS) * y + log(1. - x + EPS) * (1. - y));
+    );
+    if (weights) {
+      THTensor_(cmul)(output, output, weights);
+    }
+    return;
+  }
+
   THTensor_(resize1d)(output, 1);
 
   real sum = 0;
 
-  if(weights)
+  if (weights) {
     TH_TENSOR_APPLY3(real, input, real, target, real, weights,
       real x = *input_data;
       real y = *target_data;
@@ -22,8 +44,8 @@ void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
         "input value should be between 0~1, but got %f",
         (double) x);
       sum -= (log(x + EPS) * y + log(1. - x + EPS) * (1. - y)) * w;
-    )
-  else
+    );
+  } else {
     TH_TENSOR_APPLY2(real, input, real, target,
       real x = *input_data;
       real y = *target_data;
@@ -32,6 +54,7 @@ void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
         (double) x);
       sum -= log(x + EPS) * y + log(1. - x + EPS) * (1. - y);
     );
+  }
 
   if (sizeAverage)
     sum /= THTensor_(nElement)(input);
@@ -40,21 +63,45 @@ void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
   THTensor_(set1d)(output, 0, sum);
 }
 
-void THNN_(BCECriterion_updateGradInput)(THNNState *state, THTensor *input,
-                                         THTensor *target, THTensor *gradInput,
-                                         bool sizeAverage, THTensor *weights)
+void THNN_(BCECriterion_updateGradInput)(
+    THNNState *state,
+    THTensor *input,
+    THTensor *target,
+    THTensor *gradOutput,
+    THTensor *gradInput,
+    bool sizeAverage,
+    THTensor *weights,
+    bool reduce)
 {
   THNN_CHECK_NELEMENT(input, target);
   THNN_CHECK_NELEMENT(input, weights);
+  THTensor_(resizeAs)(gradInput, input);
 
-  real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
+  if (!reduce) {
+    THNN_CHECK_NELEMENT(gradOutput, input);
+    TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
+      real x = *input_data;
+      real y = *target_data;
+      *gradInput_data = -(y - x) / ((1. - x + EPS) * (x + EPS));
+    );
 
-  THTensor_(resizeAs)(gradInput, input);
+    if (weights) {
+      TH_TENSOR_APPLY3(real, gradInput, real, weights, real, gradOutput,
+        *gradInput_data = *gradInput_data * *weights_data * *gradOutput_data;
+      );
+    } else {
+      THTensor_(cmul)(gradInput, gradInput, gradOutput);
+    }
+    return;
+  }
 
+  THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, 1);
+  real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
   TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
     real x = *input_data;
     real y = *target_data;
-    *gradInput_data = - norm * (y - x) / ((1. - x + EPS) * (x + EPS));
+    *gradInput_data = - norm * (y - x) / ((1. - x + EPS) * (x + EPS)) * THTensor_fastGet1d(gradOutput, 0);
   );
 
   if(weights)
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index 32fe59ae3b..93436dfec1 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -34,14 +34,17 @@ TH_API void THNN_(BCECriterion_updateOutput)(
           THTensor *target,
           THTensor *output,
           bool sizeAverage,
-          THTensor *weights);         // [OPTIONAL]
+          THTensor *weights,          // [OPTIONAL]
+          bool reduce);
 TH_API void THNN_(BCECriterion_updateGradInput)(
           THNNState *state,
           THTensor *input,
           THTensor *target,
+          THTensor *gradOutput,
           THTensor *gradInput,
           bool sizeAverage,
-          THTensor *weights);         // [OPTIONAL]
+          THTensor *weights,          // [OPTIONAL]
+          bool reduce);
 
 TH_API void THNN_(ClassNLLCriterion_updateOutput)(
           THNNState *state,            // library's state
diff --git a/test/common_nn.py b/test/common_nn.py
index c0a2e1b13f..3b926ad7c0 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -400,6 +400,8 @@ criterion_tests = [
         module_name='BCELoss',
         input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
         target_fn=lambda: torch.randn(15, 10).gt(0).double(),
+        reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
+            (i.numel() if get_size_average(m) else 1),
         check_gradgrad=False,
     ),
     dict(
@@ -407,6 +409,8 @@ criterion_tests = [
         constructor_args_fn=lambda: (torch.rand(10),),
         input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
         target_fn=lambda: torch.randn(15, 10).gt(0).double(),
+        reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
+            (i.numel() if get_size_average(m) else 1),
         desc='weights',
         check_gradgrad=False,
     ),
diff --git a/test/test_nn.py b/test/test_nn.py
index ef65b18457..df948204da 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3952,6 +3952,31 @@ new_criterion_tests = [
 ]
 
 
+def bceloss_no_reduce_test():
+    t = torch.randn(15, 10).gt(0).double()
+    return dict(
+        fullname='BCELoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, Variable(t.type_as(i.data)), reduce=False)),
+        input_fn=lambda: torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2),
+        reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
+        check_gradgrad=False,
+        pickle=False)
+
+
+def bceloss_weights_no_reduce_test():
+    t = torch.randn(15, 10).gt(0).double()
+    weights = torch.rand(10)
+    return dict(
+        fullname='BCELoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, Variable(t.type_as(i.data)),
+                                             weight=weights.type_as(i.data), reduce=False)),
+        input_fn=lambda: torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2),
+        reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+        check_gradgrad=False,
+        pickle=False)
+
+
 def poissonnllloss_no_reduce_test():
     t = Variable(torch.randn(10, 10))
     return dict(
@@ -4175,6 +4200,8 @@ def smoothl1loss_no_reduce_test():
 
 
 new_module_tests = [
+    bceloss_no_reduce_test(),
+    bceloss_weights_no_reduce_test(),
     poissonnllloss_no_reduce_test(),
     kldivloss_no_reduce_test(),
     l1loss_no_reduce_test(),
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a9b5a241e9..8ba929243e 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1212,7 +1212,7 @@ def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-1
     return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index, reduce)
 
 
-def binary_cross_entropy(input, target, weight=None, size_average=True):
+def binary_cross_entropy(input, target, weight=None, size_average=True, reduce=True):
     r"""Function that measures the Binary Cross Entropy
     between the target and the output.
 
@@ -1227,6 +1227,10 @@ def binary_cross_entropy(input, target, weight=None, size_average=True):
             over observations for each minibatch. However, if the field
             sizeAverage is set to False, the losses are instead summed
             for each minibatch. Default: ``True``
+        reduce (bool, optional): By default, the losses are averaged or summed over
+            observations for each minibatch depending on size_average. When reduce
+            is False, returns a loss per batch element instead and ignores
+            size_average. Default: True
 
     Examples::
 
@@ -1248,7 +1252,7 @@ def binary_cross_entropy(input, target, weight=None, size_average=True):
     if torch.is_tensor(weight):
         weight = Variable(weight)
 
-    return torch._C._nn.binary_cross_entropy(input, target, weight, size_average)
+    return torch._C._nn.binary_cross_entropy(input, target, weight, size_average, reduce)
 
 
 def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index b70a5ce08f..4b3620bf71 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -366,11 +366,17 @@ class BCELoss(_WeightedLoss):
             over observations for each minibatch. However, if the field
             size_average is set to ``False``, the losses are instead summed
             for each minibatch. Default: ``True``
+        reduce (bool, optional): By default, the losses are averaged or summed over
+            observations for each minibatch depending on size_average. When reduce
+            is False, returns a loss per batch element instead and ignores
+            size_average. Default: True
 
     Shape:
         - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
         - Target: :math:`(N, *)`, same shape as the input
+        - Output: scalar. If `reduce` is False, then `(N, *)`, same shape as
+          input.
 
     Examples::
 
@@ -381,10 +387,15 @@ class BCELoss(_WeightedLoss):
         >>> output = loss(m(input), target)
         >>> output.backward()
     """
+    def __init__(self, weight=None, size_average=True, reduce=True):
+        super(BCELoss, self).__init__(weight, size_average)
+        self.reduce = reduce
+
     def forward(self, input, target):
         _assert_no_grad(target)
         return F.binary_cross_entropy(input, target, weight=self.weight,
-                                      size_average=self.size_average)
+                                      size_average=self.size_average,
+                                      reduce=self.reduce)
 
 
 class BCEWithLogitsLoss(Module):
-- 
cgit v1.2.3
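
For context, a minimal usage sketch of the new `reduce` flag follows. It is not
part of the commit; it assumes a PyTorch build that includes this patch, and
the tensor shapes (4x3) are illustrative only:

    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F

    # BCELoss expects probabilities in [0, 1], e.g. sigmoid outputs.
    input = Variable(torch.rand(4, 3).clamp_(1e-2, 1 - 1e-2), requires_grad=True)
    target = Variable(torch.rand(4, 3).round())

    # Default path (reduce=True): a single scalar, averaged over all elements
    # when size_average=True, summed otherwise.
    scalar_loss = F.binary_cross_entropy(input, target)

    # New path (reduce=False): one loss per element, same (4, 3) shape as the
    # input; size_average is ignored here.
    elementwise_loss = F.binary_cross_entropy(input, target, reduce=False)
    assert elementwise_loss.size() == input.size()

    # Backward through an unreduced loss takes a grad_output of matching shape,
    # which is exactly what the new gradOutput argument threads down to THNN.
    elementwise_loss.backward(torch.ones(input.size()))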