diff options
author | Elias Ellison <eellison@fb.com> | 2019-02-08 11:34:40 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-08 12:32:11 -0800 |
commit | cd2dca3cafb49fbf64a4714571717f2390625e72 (patch) | |
tree | 9dba9c98aa61081c1bd6a5cf834bd2332edcd22b /test | |
parent | 5ada54e0bc797421bbea5b4ba36e93c1924e4d47 (diff) | |
download | pytorch-cd2dca3cafb49fbf64a4714571717f2390625e72.tar.gz pytorch-cd2dca3cafb49fbf64a4714571717f2390625e72.tar.bz2 pytorch-cd2dca3cafb49fbf64a4714571717f2390625e72.zip |
Allow sequential modules in module list (#16882)
Summary:
Fix for https://github.com/pytorch/pytorch/issues/16845
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16882
Differential Revision: D14007746
Pulled By: eellison
fbshipit-source-id: d7918275cc1de6a67320619c3203463f66783343
Diffstat (limited to 'test')
-rw-r--r-- | test/test_jit.py | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index d37246c7b9..2caf025cd7 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5750,6 +5750,65 @@ a") i = torch.Tensor(2) hs(i) + def test_script_sequential_in_mod_list(self): + class Sub(torch.jit.ScriptModule): + def __init__(self): + super(Sub, self).__init__(False) + self.weight = nn.Parameter(torch.randn(2)) + + @torch.jit.script_method + def forward(self, thing): + return self.weight + thing + + class M(torch.jit.ScriptModule): + __constants__ = ['mods'] + + def __init__(self): + super(M, self).__init__(False) + self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())]) + + @torch.jit.script_method + def forward(self, v): + for mod in self.mods: + v = mod(v) + return v + + m = M() + graph = str(m.graph) + print(graph) + return + self.assertTrue(graph.count("aten::add") == 4) + self.assertTrue("python" not in graph) + + def test_script_nested_mod_list(self): + class Sub(torch.jit.ScriptModule): + def __init__(self): + super(Sub, self).__init__(False) + self.weight = nn.Parameter(torch.randn(2)) + + @torch.jit.script_method + def forward(self, thing): + return self.weight + thing + + class M(torch.jit.ScriptModule): + __constants__ = ['mods'] + + def __init__(self): + super(M, self).__init__(False) + self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])]) + + @torch.jit.script_method + def forward(self, v): + for mod in self.mods: + for m in mod: + v = m(v) + return v + + m = M() + graph = str(m.graph) + self.assertTrue(graph.count("aten::add") == 4) + self.assertTrue("python" not in graph) + def test_constant_as_attr(self): class M(torch.jit.ScriptModule): __constants__ = ['dim'] |