author     Wanchao Liang <wanchaol@users.noreply.github.com>                    2019-04-03 16:50:46 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>   2019-04-03 16:58:33 -0700
commit     843e6234f5f87f281a487fd4f8434e07101ee3ed (patch)
tree       b1a6499bbcfb619f00f6d90605046c2d0a422359
parent     0512e4e32348f01dc2184031a7c4d4644a455ac3 (diff)
Fix layernorm ad formula on weight and bias (#18233)
Summary:
Fix the layer_norm AD formula when weight and bias are passed in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18233
Differential Revision: D14760375
Pulled By: wanchaol
fbshipit-source-id: d6bd3b137bc04c391aa5c24d021d1f811ba2a877
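
For reference, the gradient formulas that the fixed AD script encodes for the affine parameters are simple reductions: grad_weight is grad_output * normalized_input summed over the non-normalized dimensions, and grad_bias is grad_output summed the same way. Below is a minimal eager-mode sketch of those two reductions; the helper name and the reshape-based reduction are illustrative only and are not the TorchScript symbolic itself.

    # Sketch of layer_norm's affine-parameter gradients. Assumes bn_out is the
    # normalized input (zero mean / unit variance over normalized_shape) and
    # that weight and bias broadcast over the trailing normalized dimensions.
    def layer_norm_affine_backward(grad_output, bn_out, weight, bias):
        # gradient flowing back into the normalized input
        grad_bn_out = grad_output * weight if weight is not None else grad_output
        grad_weight = None
        grad_bias = None
        if weight is not None:
            # reduce all leading (non-normalized) dims down to weight's shape
            grad_weight = (grad_output * bn_out).reshape(-1, weight.numel()).sum(0).view_as(weight)
        if bias is not None:
            grad_bias = grad_output.reshape(-1, bias.numel()).sum(0).view_as(bias)
        return grad_bn_out, grad_weight, grad_bias
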
-rw-r--r--  aten/src/ATen/native/Normalization.cpp |  2
-rw-r--r--  test/test_jit.py                        | 13
-rw-r--r--  torch/csrc/jit/symbolic_script.cpp      | 51
3 files changed, 47 insertions(+), 19 deletions(-)
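
The symbolic_script.cpp change below computes layer_norm by flattening the leading dimensions, viewing the input as (1, n, -1), and running batch_norm with no affine parameters, then applying weight and bias separately. A rough eager-mode sketch of that equivalence follows, using the public F.batch_norm instead of the internal torch._batch_norm_impl_index; the helper name and the tolerance are illustrative assumptions.

    import torch
    import torch.nn.functional as F

    def layer_norm_via_batch_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
        # n = product of the leading (non-normalized) dimensions
        n = 1
        for i in range(input.dim() - len(normalized_shape)):
            n *= input.size(i)
        # view as (1, n, -1): each slice over normalized_shape becomes one
        # batch_norm channel, so per-channel statistics match layer_norm's
        input_reshape = input.contiguous().view(1, n, -1)
        bn_out = F.batch_norm(input_reshape, None, None, None, None,
                              training=True, momentum=0.0, eps=eps)
        out = bn_out.view(input.size())
        if weight is not None:
            out = out * weight
        if bias is not None:
            out = out + bias
        return out

    x = torch.randn(2, 3, 5)
    w, b = torch.rand(5), torch.rand(5)
    ref = F.layer_norm(x, [5], w, b, 1e-5)
    assert torch.allclose(layer_norm_via_batch_norm(x, [5], w, b), ref, atol=1e-5)
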
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index e4be45198c..b1d3b31cfd 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -384,7 +384,7 @@ Tensor instance_norm(
 Tensor layer_norm(const Tensor& input, IntArrayRef normalized_shape,
     const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     double eps, bool cudnn_enabled) {
-
+
   int64_t normalized_ndim = normalized_shape.size();
   AT_CHECK(normalized_ndim >= 1,
diff --git a/test/test_jit.py b/test/test_jit.py
index 70d01df507..79cfba33bd 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -12170,10 +12170,15 @@ nn_functional_tests = [
     ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ), '',
         (True, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
-    ('layer_norm', (S, S, S, S), ([5],),),
-    ('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
-    ('layer_norm', (S, S, S, S), ([5], None, (S,)), 'with_only_bias'),
-    ('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
+    ('layer_norm', (S, S, S, S), ([5],), '',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+                                  non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+     (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
     ('group_norm', (S, S, S), (1, torch.rand(5),),),
     ('local_response_norm', (S, S, S), (2, ),),
     ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '', (True, 'aten::nll_loss_forward')),
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index c897d73657..3cfcd7c272 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -660,6 +660,7 @@ const std::vector<std::string> functions = {

             return torch.adaptive_avg_pool3d(self, output_size), backward

+
         def batch_norm(input : Tensor,
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
@@ -685,21 +686,27 @@ const std::vector<std::string> functions = {
             return output, backward

         def layer_norm(input : Tensor,
-                       normalied_shape : List[int],
+                       normalized_shape : List[int],
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
                        eps : float,
                        cudnn_enable : bool):
+            input_ndim = input.dim()
+            normalized_ndim = len(normalized_shape)
+            n = 1
+            for i in range(input_ndim - normalized_ndim):
+                n *= input.size(i)
+
+            input_reshape = input.contiguous().view(1, n, -1)
+
             bn_out, save1, save2, impl_idx = torch._batch_norm_impl_index(
-                input, weight, bias, None, None, True,
+                input_reshape, None, None, None, None, True,
                 0.0, eps, cudnn_enable)
-            has_weight = weight is not None
-            has_bias = bias is not None
-            bn_out = bn_out.view(input.sizes())
+            bn_out = bn_out.view(input.size())
             if weight is not None and bias is not None:
-                output = bias.addcmul(bn_out, weight)
+                output = bias.addcmul(bn_out, weight, value=1)
             elif weight is not None:
                 output = bn_out.mul(weight)
             elif bias is not None:
@@ -708,16 +715,32 @@ const std::vector<std::string> functions = {
                 output = bn_out

             def backward(grad_output):
-                if weight is not None:
-                    grad_output = grad_output * torch.t(weight)
-                    weight = grad_output * torch.t(bn_out)
+                if weight is not None and bias is not None:
+                    grad_bn_out = grad_output * weight
+                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+                    grad_bias = grad_output._grad_sum_to_size(bias.size())
+                elif weight is not None:
+                    grad_bn_out = grad_output * weight
+                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+                    grad_bias = None
+                elif bias is not None:
+                    grad_bn_out = grad_output
+                    grad_weight = None
+                    grad_bias = grad_output._grad_sum_to_size(bias.size())
+                else:
+                    grad_bn_out = grad_output
+                    grad_weight = None
+                    grad_bias = None

-                grad_output = grad_output.reshape(input.sizes())
-                dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
-                    impl_idx, input, grad_output, weight, None, None,
-                    save1, save2, True, eps, [True, has_weight, has_bias])
-                return dinput, None, dweight, dbias, None, None
+                grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)
+
+                grad_input, _, _ = torch._batch_norm_impl_index_backward(
+                    impl_idx, input_reshape, grad_bn_out, None, None, None,
+                    save1, save2, True, eps, [True, False, False])
+
+                grad_input = grad_input.view(input.size())
+                return grad_input, None, grad_weight, grad_bias, None, None

             return output, backward
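
As a quick illustration of what the corrected backward should produce (not part of this commit or its tests): eager-mode autograd gradients for weight and bias reduce to the same simple sums used above, and the analytic formulas pass a numerical gradcheck. The shapes below are arbitrary.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(2, 3, 5, dtype=torch.double, requires_grad=True)
    w = torch.rand(5, dtype=torch.double, requires_grad=True)
    b = torch.rand(5, dtype=torch.double, requires_grad=True)

    out = F.layer_norm(x, [5], weight=w, bias=b, eps=1e-5)
    grad_out = torch.randn_like(out)
    gx, gw, gb = torch.autograd.grad(out, (x, w, b), grad_out)

    # grad_bias: grad_output summed over the leading dims;
    # grad_weight: grad_output * normalized(input) summed the same way
    normalized = F.layer_norm(x, [5], eps=1e-5)
    assert torch.allclose(gb, grad_out.reshape(-1, 5).sum(0))
    assert torch.allclose(gw, (grad_out * normalized).reshape(-1, 5).sum(0))

    # numerical check of the analytic layer_norm gradients as a whole
    torch.autograd.gradcheck(
        lambda x, w, b: F.layer_norm(x, [5], w, b, 1e-5), (x, w, b))
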