author     Wanchao Liang <wanchaol@users.noreply.github.com>  2019-04-03 16:50:46 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-03 16:58:33 -0700
commit     843e6234f5f87f281a487fd4f8434e07101ee3ed (patch)
tree       b1a6499bbcfb619f00f6d90605046c2d0a422359
parent     0512e4e32348f01dc2184031a7c4d4644a455ac3 (diff)
Fix layernorm ad formula on weight and bias (#18233)
Summary: Fix the layernorm formula when weight and bias are passed in.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18233
Differential Revision: D14760375
Pulled By: wanchaol
fbshipit-source-id: d6bd3b137bc04c391aa5c24d021d1f811ba2a877
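For context, the core of the change: layer norm over the trailing normalized_shape dims is now computed as batch norm over a (1, n, -1) view of the input with no affine parameters, and the weight/bias affine step is applied (and differentiated) separately. A minimal eager-mode sketch of that equivalence (illustrative only, not part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 3, 5)
    normalized_shape = [5]

    # n = product of the leading (non-normalized) dims, as in the new forward
    n = 1
    for i in range(x.dim() - len(normalized_shape)):
        n *= x.size(i)

    # Batch norm over a (1, n, -1) view computes per-group mean/var over the
    # flattened trailing dims, which is exactly what layer norm does per group.
    x_reshaped = x.contiguous().view(1, n, -1)
    bn_out = F.batch_norm(x_reshaped, None, None, training=True, eps=1e-5)
    print(torch.allclose(bn_out.view(x.size()),
                         F.layer_norm(x, normalized_shape, eps=1e-5),
                         atol=1e-6))  # True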
-rw-r--r--  aten/src/ATen/native/Normalization.cpp |  2
-rw-r--r--  test/test_jit.py                       | 13
-rw-r--r--  torch/csrc/jit/symbolic_script.cpp     | 51
3 files changed, 47 insertions, 19 deletions
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index e4be45198c..b1d3b31cfd 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -384,7 +384,7 @@ Tensor instance_norm(
Tensor layer_norm(const Tensor& input, IntArrayRef normalized_shape,
const Tensor& weight /* optional */, const Tensor& bias /* optional */,
double eps, bool cudnn_enabled) {
-
+
int64_t normalized_ndim = normalized_shape.size();
AT_CHECK(normalized_ndim >= 1,
diff --git a/test/test_jit.py b/test/test_jit.py
index 70d01df507..79cfba33bd 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -12170,10 +12170,15 @@ nn_functional_tests = [
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
'', (True, 'aten::_batch_norm_impl_index')),
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
- ('layer_norm', (S, S, S, S), ([5],),),
- ('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
- ('layer_norm', (S, S, S, S), ([5], None, (S,)), 'with_only_bias'),
- ('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
+ ('layer_norm', (S, S, S, S), ([5],), '',
+ (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+ ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+ (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+ ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+ (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+ ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+ non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+ (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
('group_norm', (S, S, S), (1, torch.rand(5),),),
('local_response_norm', (S, S, S), (2, ),),
('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '', (True, 'aten::nll_loss_forward')),
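Beyond the graph-shape checks above, the corrected formulas can also be sanity-checked numerically. A hedged sketch using torch.autograd.gradcheck (not part of this test suite), comparing layer_norm's analytic gradients for input, weight, and bias against finite differences:

    import torch
    import torch.nn.functional as F

    # gradcheck needs double precision for its finite-difference comparison
    x = torch.randn(2, 3, 5, dtype=torch.double, requires_grad=True)
    w = torch.rand(5, dtype=torch.double, requires_grad=True)
    b = torch.rand(5, dtype=torch.double, requires_grad=True)

    assert torch.autograd.gradcheck(
        lambda inp, weight, bias: F.layer_norm(inp, [5], weight, bias),
        (x, w, b))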
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index c897d73657..3cfcd7c272 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -660,6 +660,7 @@ const std::vector<std::string> functions = {
return torch.adaptive_avg_pool3d(self, output_size), backward
+
def batch_norm(input : Tensor,
weight : Optional[Tensor],
bias : Optional[Tensor],
@@ -685,21 +686,27 @@ const std::vector<std::string> functions = {
return output, backward
def layer_norm(input : Tensor,
- normalied_shape : List[int],
+ normalized_shape : List[int],
weight : Optional[Tensor],
bias : Optional[Tensor],
eps : float,
cudnn_enable : bool):
+ input_ndim = input.dim()
+ normalized_ndim = len(normalized_shape)
+ n = 1
+ for i in range(input_ndim - normalized_ndim):
+ n *= input.size(i)
+
+ input_reshape = input.contiguous().view(1, n, -1)
+
bn_out, save1, save2, impl_idx = torch._batch_norm_impl_index(
- input, weight, bias, None, None, True,
+ input_reshape, None, None, None, None, True,
0.0, eps, cudnn_enable)
- has_weight = weight is not None
- has_bias = bias is not None
- bn_out = bn_out.view(input.sizes())
+ bn_out = bn_out.view(input.size())
if weight is not None and bias is not None:
- output = bias.addcmul(bn_out, weight)
+ output = bias.addcmul(bn_out, weight, value=1)
elif weight is not None:
output = bn_out.mul(weight)
elif bias is not None:
@@ -708,16 +715,32 @@ const std::vector<std::string> functions = {
output = bn_out
def backward(grad_output):
- if weight is not None:
- grad_output = grad_output * torch.t(weight)
- weight = grad_output * torch.t(bn_out)
+ if weight is not None and bias is not None:
+ grad_bn_out = grad_output * weight
+ grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+ grad_bias = grad_output._grad_sum_to_size(bias.size())
+ elif weight is not None:
+ grad_bn_out = grad_output * weight
+ grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+ grad_bias = None
+ elif bias is not None:
+ grad_bn_out = grad_output
+ grad_weight = None
+ grad_bias = grad_output._grad_sum_to_size(bias.size())
+ else:
+ grad_bn_out = grad_output
+ grad_weight = None
+ grad_bias = None
- grad_output = grad_output.reshape(input.sizes())
- dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
- impl_idx, input, grad_output, weight, None, None,
- save1, save2, True, eps, [True, has_weight, has_bias])
- return dinput, None, dweight, dbias, None, None
+ grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)
+
+ grad_input, _, _ = torch._batch_norm_impl_index_backward(
+ impl_idx, input_reshape, grad_bn_out, None, None, None,
+ save1, save2, True, eps, [True, False, False])
+
+ grad_input = grad_input.view(input.size())
+ return grad_input, None, grad_weight, grad_bias, None, None
return output, backward
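To check the new backward by hand: grad_weight sums grad_output * bn_out down to weight's shape, and grad_bias sums grad_output down to bias's shape (that is what _grad_sum_to_size does above, since weight and bias broadcast over the leading dims). An eager-mode verification sketch (illustrative only, not part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 5)
    w = torch.rand(5, requires_grad=True)
    b = torch.rand(5, requires_grad=True)
    go = torch.randn(2, 3, 5)                    # arbitrary upstream gradient

    out = F.layer_norm(x, [5], w, b)
    out.backward(go)

    bn_out = F.layer_norm(x, [5])                # normalized output, no affine
    grad_weight = (go * bn_out).sum(dim=(0, 1))  # reduce to weight.size()
    grad_bias = go.sum(dim=(0, 1))               # reduce to bias.size()
    print(torch.allclose(grad_weight, w.grad, atol=1e-5),
          torch.allclose(grad_bias, b.grad, atol=1e-5))  # True True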