author    Elias Ellison <eellison@fb.com>    2019-02-08 11:34:40 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-02-08 12:32:11 -0800
commit    cd2dca3cafb49fbf64a4714571717f2390625e72 (patch)
tree      9dba9c98aa61081c1bd6a5cf834bd2332edcd22b /test
parent    5ada54e0bc797421bbea5b4ba36e93c1924e4d47 (diff)
Allow sequential modules in module list (#16882)
Summary: Fix for https://github.com/pytorch/pytorch/issues/16845

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16882

Differential Revision: D14007746

Pulled By: eellison

fbshipit-source-id: d7918275cc1de6a67320619c3203463f66783343
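For context, the pattern this change enables is iterating over an nn.ModuleList that contains nn.Sequential entries from inside a compiled script method. A minimal sketch, not part of the commit, using the legacy torch.jit.ScriptModule / __constants__ API of this era (module names here are illustrative):

    import torch
    import torch.nn as nn

    class AddWeight(torch.jit.ScriptModule):
        def __init__(self):
            # False disables the legacy 'optimize' flag, as in the tests below
            super(AddWeight, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, x):
            return self.weight + x

    class Stack(torch.jit.ScriptModule):
        # Declaring 'mods' constant lets the compiler unroll the loop over it
        __constants__ = ['mods']

        def __init__(self):
            super(Stack, self).__init__(False)
            # An nn.Sequential nested inside an nn.ModuleList: the case that
            # https://github.com/pytorch/pytorch/issues/16845 reported as broken
            self.mods = nn.ModuleList([AddWeight(), nn.Sequential(AddWeight(), AddWeight())])

        @torch.jit.script_method
        def forward(self, v):
            for mod in self.mods:
                v = mod(v)
            return v

    out = Stack()(torch.randn(2))  # compiles and runs without falling back to Python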
Diffstat (limited to 'test')
-rw-r--r--  test/test_jit.py | 57
1 file changed, 57 insertions(+), 0 deletions(-)
diff --git a/test/test_jit.py b/test/test_jit.py
index d37246c7b9..2caf025cd7 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5750,6 +5750,63 @@ a")
         i = torch.Tensor(2)
         hs(i)
+    def test_script_sequential_in_mod_list(self):
+        class Sub(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Sub, self).__init__(False)
+                self.weight = nn.Parameter(torch.randn(2))
+
+            @torch.jit.script_method
+            def forward(self, thing):
+                return self.weight + thing
+
+        class M(torch.jit.ScriptModule):
+            __constants__ = ['mods']
+
+            def __init__(self):
+                super(M, self).__init__(False)
+                self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])
+
+            @torch.jit.script_method
+            def forward(self, v):
+                for mod in self.mods:
+                    v = mod(v)
+                return v
+
+        m = M()
+        graph = str(m.graph)
+        self.assertTrue(graph.count("aten::add") == 5)
+        self.assertTrue("python" not in graph)
+
+    def test_script_nested_mod_list(self):
+        class Sub(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Sub, self).__init__(False)
+                self.weight = nn.Parameter(torch.randn(2))
+
+            @torch.jit.script_method
+            def forward(self, thing):
+                return self.weight + thing
+
+        class M(torch.jit.ScriptModule):
+            __constants__ = ['mods']
+
+            def __init__(self):
+                super(M, self).__init__(False)
+                self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])
+
+            @torch.jit.script_method
+            def forward(self, v):
+                for mod in self.mods:
+                    for m in mod:
+                        v = m(v)
+                return v
+
+        m = M()
+        graph = str(m.graph)
+        self.assertTrue(graph.count("aten::add") == 4)
+        self.assertTrue("python" not in graph)
+
     def test_constant_as_attr(self):
         class M(torch.jit.ScriptModule):
             __constants__ = ['dim']