summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2018-05-16 20:03:04 +0200
committerGitHub <noreply@github.com>2018-05-16 20:03:04 +0200
commitb45f2ff1ae5b8a7fb04f96a9d6de86df39c1dcb2 (patch)
tree8c27506722d9fcb9055f51f9abf90cc1d5b6d63a /test
parent28b0b16f9bd8678b1f31bb4fb6de1b31f6e87f3d (diff)
downloadpytorch-b45f2ff1ae5b8a7fb04f96a9d6de86df39c1dcb2.tar.gz
pytorch-b45f2ff1ae5b8a7fb04f96a9d6de86df39c1dcb2.tar.bz2
pytorch-b45f2ff1ae5b8a7fb04f96a9d6de86df39c1dcb2.zip
Remove CompiledFunction + clean up JIT tests (#7421)
Diffstat (limited to 'test')
-rw-r--r--test/expect/TestJit.test_alexnet.expect93
-rw-r--r--test/expect/TestJit.test_backward.expect23
-rw-r--r--test/expect/TestJit.test_backward_opaque.expect10
-rw-r--r--test/expect/TestJit.test_c_function.expect6
-rw-r--r--test/expect/TestJit.test_concat_fusion.expect2
-rw-r--r--test/expect/TestJit.test_function_as_argument.expect26
-rw-r--r--test/expect/TestJit.test_fuse_last_device.expect14
-rw-r--r--test/expect/TestJit.test_fusion_distribute-raw.expect7
-rw-r--r--test/expect/TestJit.test_fusion_distribute.expect2
-rw-r--r--test/expect/TestJit.test_index_inner.expect5
-rw-r--r--test/expect/TestJit.test_index_trace.expect8
-rw-r--r--test/expect/TestJit.test_input_pruning.expect12
-rw-r--r--test/expect/TestJit.test_lstm.expect42
-rw-r--r--test/expect/TestJit.test_lstm_fusion_concat.expect41
-rw-r--r--test/expect/TestJit.test_lstm_fusion_cpu.expect (renamed from test/expect/TestJit.test_lstm_fusion.expect)16
-rw-r--r--test/expect/TestJit.test_lstm_fusion_cuda.expect40
-rw-r--r--test/expect/TestJit.test_matmul_native.expect5
-rw-r--r--test/expect/TestJit.test_output_pruning.expect10
-rw-r--r--test/expect/TestJit.test_repeated_input.expect6
-rw-r--r--test/expect/TestJit.test_repeated_output.expect9
-rw-r--r--test/expect/TestJit.test_saved_output.expect8
-rw-r--r--test/expect/TestJit.test_shape_analysis_broadcast.expect8
-rw-r--r--test/test_jit.py1153
23 files changed, 440 insertions, 1106 deletions
diff --git a/test/expect/TestJit.test_alexnet.expect b/test/expect/TestJit.test_alexnet.expect
index 62f45eff89..9a927efd24 100644
--- a/test/expect/TestJit.test_alexnet.expect
+++ b/test/expect/TestJit.test_alexnet.expect
@@ -1,46 +1,49 @@
-graph(%1 : Double(10, 3, 224, 224)
- %2 : Double(64, 3, 11, 11)
- %3 : Double(64)
- %4 : Double(192, 64, 5, 5)
- %5 : Double(192)
- %6 : Double(384, 192, 3, 3)
- %7 : Double(384)
- %8 : Double(256, 384, 3, 3)
- %9 : Double(256)
- %10 : Double(256, 256, 3, 3)
- %11 : Double(256)
- %12 : Double(4096, 9216)
- %13 : Double(4096)
- %14 : Double(4096, 4096)
- %15 : Double(4096)
- %16 : Double(1000, 4096)
- %17 : Double(1000)) {
- %19 : Double(10, 64, 55, 55), %20 : Handle = CppOp[ConvForward](%1, %2, %3), uses = [[%21.i0], []];
- %21 : Double(10, 64, 55, 55) = threshold[threshold={0}, value={0}, inplace=1](%19), uses = [%22.i0];
- %23 : Double(10, 64, 27, 27), %24 : Long(10, 64, 27, 27) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%21), uses = [[%25.i0], []];
- %26 : Double(10, 192, 27, 27), %27 : Handle = CppOp[ConvForward](%23, %4, %5), uses = [[%28.i0], []];
- %28 : Double(10, 192, 27, 27) = threshold[threshold={0}, value={0}, inplace=1](%26), uses = [%29.i0];
- %30 : Double(10, 192, 13, 13), %31 : Long(10, 192, 13, 13) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%28), uses = [[%32.i0], []];
- %33 : Double(10, 384, 13, 13), %34 : Handle = CppOp[ConvForward](%30, %6, %7), uses = [[%35.i0], []];
- %35 : Double(10, 384, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%33), uses = [%36.i0];
- %37 : Double(10, 256, 13, 13), %38 : Handle = CppOp[ConvForward](%35, %8, %9), uses = [[%39.i0], []];
- %39 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%37), uses = [%40.i0];
- %41 : Double(10, 256, 13, 13), %42 : Handle = CppOp[ConvForward](%39, %10, %11), uses = [[%43.i0], []];
- %43 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%41), uses = [%44.i0];
- %45 : Double(10, 256, 6, 6), %46 : Long(10, 256, 6, 6) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%43), uses = [[%47.i0], []];
- %47 : Double(10, 9216) = view[size=[10, 9216]](%45), uses = [%48.i0];
- %49 : Double(10, 9216), %50 : Handle = ^Dropout(0.5, True, False)(%47), uses = [[%53.i1], []];
- %51 : Double(9216!, 4096!) = t(%12), uses = [%53.i2];
- %52 : Double(10!, 4096) = expand[size=[10, 4096]](%13), uses = [%53.i0];
- %53 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%52, %49, %51), uses = [%54.i0];
- %54 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%53), uses = [%55.i0];
- %56 : Double(10, 4096), %57 : Handle = ^Dropout(0.5, True, False)(%54), uses = [[%60.i1], []];
- %58 : Double(4096!, 4096!) = t(%14), uses = [%60.i2];
- %59 : Double(10!, 4096) = expand[size=[10, 4096]](%15), uses = [%60.i0];
- %60 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%59, %56, %58), uses = [%61.i0];
- %61 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%60), uses = [%64.i1];
- %62 : Double(4096!, 1000!) = t(%16), uses = [%64.i2];
- %63 : Double(10!, 1000) = expand[size=[10, 1000]](%17), uses = [%64.i0];
- %64 : Double(10, 1000) = addmm[beta={1}, alpha={1}](%63, %61, %62), uses = [%0.i0];
- return (%64);
+graph(%0 : Double(1, 3, 224, 224)
+ %1 : Double(64, 3, 11, 11)
+ %2 : Double(64)
+ %3 : Double(192, 64, 5, 5)
+ %4 : Double(192)
+ %5 : Double(384, 192, 3, 3)
+ %6 : Double(384)
+ %7 : Double(256, 384, 3, 3)
+ %8 : Double(256)
+ %9 : Double(256, 256, 3, 3)
+ %10 : Double(256)
+ %11 : Double(4096, 9216)
+ %12 : Double(4096)
+ %13 : Double(4096, 4096)
+ %14 : Double(4096)
+ %15 : Double(1000, 4096)
+ %16 : Double(1000)) {
+ %17 : Double(1, 64, 55, 55) = aten::_convolution[stride=[4, 4], padding=[2, 2], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%0, %1, %2), scope: AlexNet/Sequential[features]/Conv2d[0]
+ %18 : Double(1, 64, 55, 55) = aten::threshold[threshold={0}, value={0}](%17), scope: AlexNet/Sequential[features]/ReLU[1]
+ %19 : Double(1, 64, 27, 27), %20 : Long(1, 64, 27, 27) = aten::max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+ %21 : Double(1, 192, 27, 27) = aten::_convolution[stride=[1, 1], padding=[2, 2], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%19, %3, %4), scope: AlexNet/Sequential[features]/Conv2d[3]
+ %22 : Double(1, 192, 27, 27) = aten::threshold[threshold={0}, value={0}](%21), scope: AlexNet/Sequential[features]/ReLU[4]
+ %23 : Double(1, 192, 13, 13), %24 : Long(1, 192, 13, 13) = aten::max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%22), scope: AlexNet/Sequential[features]/MaxPool2d[5]
+ %25 : Double(1, 384, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%23, %5, %6), scope: AlexNet/Sequential[features]/Conv2d[6]
+ %26 : Double(1, 384, 13, 13) = aten::threshold[threshold={0}, value={0}](%25), scope: AlexNet/Sequential[features]/ReLU[7]
+ %27 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%26, %7, %8), scope: AlexNet/Sequential[features]/Conv2d[8]
+ %28 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%27), scope: AlexNet/Sequential[features]/ReLU[9]
+ %29 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%28, %9, %10), scope: AlexNet/Sequential[features]/Conv2d[10]
+ %30 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%29), scope: AlexNet/Sequential[features]/ReLU[11]
+ %31 : Double(1, 256, 6, 6), %32 : Long(1, 256, 6, 6) = aten::max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%30), scope: AlexNet/Sequential[features]/MaxPool2d[12]
+ %33 : Long() = aten::size[dim=0](%31), scope: AlexNet
+ %34 : Long() = prim::Constant[value={9216}](), scope: AlexNet
+ %35 : Dynamic = aten::stack[dim=0](%33, %34), scope: AlexNet
+ %36 : Double(1, 9216) = aten::view(%31, %35), scope: AlexNet
+ %37 : Double(1, 9216), %38 : Handle = ^Dropout(0.5, True, False)(%36), scope: AlexNet/Sequential[classifier]/Dropout[0]
+ %39 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
+ %40 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%12), scope: AlexNet/Sequential[classifier]/Linear[1]
+ %41 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%40, %37, %39), scope: AlexNet/Sequential[classifier]/Linear[1]
+ %42 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%41), scope: AlexNet/Sequential[classifier]/ReLU[2]
+ %43 : Double(1, 4096), %44 : Handle = ^Dropout(0.5, True, False)(%42), scope: AlexNet/Sequential[classifier]/Dropout[3]
+ %45 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
+ %46 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%14), scope: AlexNet/Sequential[classifier]/Linear[4]
+ %47 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%46, %43, %45), scope: AlexNet/Sequential[classifier]/Linear[4]
+ %48 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%47), scope: AlexNet/Sequential[classifier]/ReLU[5]
+ %49 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
+ %50 : Double(1, 1000) = aten::expand[size=[1, 1000], implicit=1](%16), scope: AlexNet/Sequential[classifier]/Linear[6]
+ %51 : Double(1, 1000) = aten::addmm[beta={1}, alpha={1}](%50, %48, %49), scope: AlexNet/Sequential[classifier]/Linear[6]
+ return (%51);
}
diff --git a/test/expect/TestJit.test_backward.expect b/test/expect/TestJit.test_backward.expect
deleted file mode 100644
index d87b0448f6..0000000000
--- a/test/expect/TestJit.test_backward.expect
+++ /dev/null
@@ -1,23 +0,0 @@
-graph(%0 : Double(2, 2)
- %1 : Double(2, 2)
- -------- stage 1 --------
- %4 : Double(2, 2)
- -------- stage 2 --------
- %8 : Double(2, 2!)
- %9 : Double(2, 2)) {
- %2 : Double(2, 2) = mul[other={2}](%1)
- %3 : Double(2, 2) = mul(%2, %0)
- ---------------- stage 1 ----------------
- %5 : Double(2, 2) = mul(%4, %0)
- %6 : Double(2, 2) = mul(%4, %2)
- %7 : Double(2, 2) = mul[other={2}](%5)
- ---------------- stage 2 ----------------
- %10 : Double(2, 2) = mul(%8, %2)
- %11 : Double(2, 2) = mul(%8, %4)
- %12 : Double(2, 2) = mul[other={2}](%9)
- %13 : Double(2, 2) = mul[other={2}](%11)
- %14 : Double(2, 2) = mul(%12, %0)
- %15 : Double(2, 2) = mul(%12, %4)
- %16 : Double(2, 2) = CppOp[N5torch8autograd3AddE](%10, %14)
- return (%3, %6, %7, %16, %15, %13);
-}
diff --git a/test/expect/TestJit.test_backward_opaque.expect b/test/expect/TestJit.test_backward_opaque.expect
deleted file mode 100644
index 779de3a335..0000000000
--- a/test/expect/TestJit.test_backward_opaque.expect
+++ /dev/null
@@ -1,10 +0,0 @@
-graph(%0 : Double(3, 3)
- %1 : Double(3, 3)
- -------- stage 1 --------
- %3 : Double(3, 3)) {
- %2 : Double(3, 3) = cross[dim=-1](%0, %1)
- ---------------- stage 1 ----------------
- %4 : Double(3, 3) = cross[dim=-1](%1, %3)
- %5 : Double(3, 3) = cross[dim=-1](%3, %0)
- return (%2, %4, %5);
-}
diff --git a/test/expect/TestJit.test_c_function.expect b/test/expect/TestJit.test_c_function.expect
deleted file mode 100644
index aafbe7b221..0000000000
--- a/test/expect/TestJit.test_c_function.expect
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%0 : Double(1, 3, 10, 10)
- %1 : Double(8, 3, 3, 3)
- %2 : Double(8)) {
- %3 : Double(1, 8, 8, 8) = aten::_convolution[stride=[1, 1], padding=[0, 0], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%0, %1, %2)
- return (%3);
-}
diff --git a/test/expect/TestJit.test_concat_fusion.expect b/test/expect/TestJit.test_concat_fusion.expect
index 435872b62e..c1b45b1727 100644
--- a/test/expect/TestJit.test_concat_fusion.expect
+++ b/test/expect/TestJit.test_concat_fusion.expect
@@ -1,6 +1,6 @@
graph(%0 : Float(3, 20)
%1 : Float(3, 20)) {
- %2 : Float(6, 20) = prim::FusionGroup_0(%0, %1)
+ %2 : Float(6, 20) = prim::FusionGroup_0[device=0](%0, %1)
return (%2);
}
with prim::FusionGroup_0 = graph(%3 : Float(3, 20)
diff --git a/test/expect/TestJit.test_function_as_argument.expect b/test/expect/TestJit.test_function_as_argument.expect
deleted file mode 100644
index 2570b327ff..0000000000
--- a/test/expect/TestJit.test_function_as_argument.expect
+++ /dev/null
@@ -1,26 +0,0 @@
-graph(%1 : Double(3, 10)
- %2 : Double(3, 20)
- %3 : Double(3, 20)
- %4 : Double(80, 10)
- %5 : Double(80, 20)
- %6 : Double(80)
- %7 : Double(80)) {
- %8 : Double(10!, 80!) = Transpose[perm=[1, 0]](%4), uses = [%9.i0];
- %9 : UNKNOWN_TYPE = Transpose(%8), uses = [%10.i1];
- %10 : Double(3, 80) = FC(%1, %9, %6), uses = [%14.i0];
- %11 : Double(20!, 80!) = Transpose[perm=[1, 0]](%5), uses = [%12.i0];
- %12 : UNKNOWN_TYPE = Transpose(%11), uses = [%13.i1];
- %13 : Double(3, 80) = FC(%2, %12, %7), uses = [%14.i1];
- %14 : Double(3, 80) = Add(%10, %13), uses = [%15.i0];
- %16 : Double(3!, 20), %17 : Double(3!, 20), %18 : Double(3!, 20), %19 : Double(3!, 20) = Split[split=[20, 20, 20, 20], axis=1](%14), uses = [[%20.i0], [%21.i0], [%22.i0], [%23.i0]];
- %20 : Double(3, 20) = Sigmoid(%16), uses = [%25.i0];
- %21 : Double(3, 20) = Sigmoid(%17), uses = [%24.i0];
- %22 : Double(3, 20) = Tanh(%18), uses = [%25.i1];
- %23 : Double(3, 20) = Sigmoid(%19), uses = [%28.i0];
- %24 : Double(3, 20) = Mul(%21, %3), uses = [%26.i0];
- %25 : Double(3, 20) = Mul(%20, %22), uses = [%26.i1];
- %26 : Double(3, 20) = Add(%24, %25), uses = [%27.i0, %0.i1];
- %27 : Double(3, 20) = Tanh(%26), uses = [%28.i1];
- %28 : Double(3, 20) = Mul(%23, %27), uses = [%0.i0];
- return (%28, %26);
-}
diff --git a/test/expect/TestJit.test_fuse_last_device.expect b/test/expect/TestJit.test_fuse_last_device.expect
new file mode 100644
index 0000000000..276fadc61f
--- /dev/null
+++ b/test/expect/TestJit.test_fuse_last_device.expect
@@ -0,0 +1,14 @@
+graph(%0 : Float(1)
+ %1 : Float(1)) {
+ %2 : Float(1) = prim::FusionGroup_0[device=1](%0, %1)
+ return (%2);
+}
+with prim::FusionGroup_0 = graph(%6 : Float(1)
+ %9 : Float(1)) {
+ %10 : Float(1) = aten::add[alpha={1}](%6, %9)
+ %8 : Float(1) = aten::mul(%6, %10)
+ %5 : Float(1) = aten::add[other={1}, alpha={1}](%8)
+ %3 : Float(1) = aten::tanh(%5)
+ %1 : Float(1) = aten::sigmoid(%3)
+ return (%1);
+}
diff --git a/test/expect/TestJit.test_fusion_distribute-raw.expect b/test/expect/TestJit.test_fusion_distribute-raw.expect
deleted file mode 100644
index 290ca22052..0000000000
--- a/test/expect/TestJit.test_fusion_distribute-raw.expect
+++ /dev/null
@@ -1,7 +0,0 @@
-graph(%0 : Float(4, 4)
- %1 : Float(4, 4)) {
- %2 : Float(4, 4) = aten::add[alpha={1}](%0, %1)
- %3 : Float(4!, 2), %4 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%2)
- %5 : Float(4, 2) = aten::mul(%3, %4)
- return (%5);
-}
diff --git a/test/expect/TestJit.test_fusion_distribute.expect b/test/expect/TestJit.test_fusion_distribute.expect
index 0b873fb724..4465074e55 100644
--- a/test/expect/TestJit.test_fusion_distribute.expect
+++ b/test/expect/TestJit.test_fusion_distribute.expect
@@ -2,7 +2,7 @@ graph(%0 : Float(4, 4)
%1 : Float(4, 4)) {
%2 : Float(4!, 2), %3 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%0)
%4 : Float(4!, 2), %5 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%1)
- %6 : Float(4, 2) = prim::FusionGroup_0(%2, %4, %3, %5)
+ %6 : Float(4, 2) = prim::FusionGroup_0[device=0](%2, %4, %3, %5)
return (%6);
}
with prim::FusionGroup_0 = graph(%3 : Float(4!, 2)
diff --git a/test/expect/TestJit.test_index_inner.expect b/test/expect/TestJit.test_index_inner.expect
deleted file mode 100644
index 02e3fa321f..0000000000
--- a/test/expect/TestJit.test_index_inner.expect
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%0 : Double(4, 4)) {
- %1 : Double(4) = select[dim=0, index=0](%0)
- %2 : Double(4) = add[other={0}, alpha={1}](%1)
- return (%2);
-}
diff --git a/test/expect/TestJit.test_index_trace.expect b/test/expect/TestJit.test_index_trace.expect
deleted file mode 100644
index 45423ec05e..0000000000
--- a/test/expect/TestJit.test_index_trace.expect
+++ /dev/null
@@ -1,8 +0,0 @@
-graph(%0 : Double(4, 4)
- -------- stage 1 --------
- %1 : Double(4!)) {
- %2 : Double(4) = aten::select[dim=0, index=0](%0)
- ---------------- stage 1 ----------------
- %3 : Double(4, 4) = CppOp[AsStridedBackward](%1)
- return (%2, %3);
-}
diff --git a/test/expect/TestJit.test_input_pruning.expect b/test/expect/TestJit.test_input_pruning.expect
deleted file mode 100644
index 5464d480ea..0000000000
--- a/test/expect/TestJit.test_input_pruning.expect
+++ /dev/null
@@ -1,12 +0,0 @@
-graph(%0 : Double(5, 5)
- %1 : Double(5, 5)
- -------- stage 1 --------
- %4 : Double(5, 5)
- %5 : Double(5, 5)) {
- %2 : Double(5, 5) = aten::mul(%0, %1)
- %3 : Double(5, 5) = aten::add[alpha={1}](%0, %1)
- ---------------- stage 1 ----------------
- %6 : Double(5, 5) = aten::mul(%4, %1)
- %7 : Double(5, 5) = aten::add[alpha={1}](%5, %6)
- return (%2, %3, %7);
-}
diff --git a/test/expect/TestJit.test_lstm.expect b/test/expect/TestJit.test_lstm.expect
deleted file mode 100644
index 3bbdb7c2b7..0000000000
--- a/test/expect/TestJit.test_lstm.expect
+++ /dev/null
@@ -1,42 +0,0 @@
-graph(%1 : Double(3, 10)
- %2 : Double(3, 20)
- %3 : Double(3, 20)
- %4 : Double(80, 10)
- %5 : Double(80, 20)
- %6 : Double(80)
- %7 : Double(80)) {
- %8 : Double(10!, 80!) = Transpose[perm=[1, 0]](%4), uses = [%9.i0];
- %9 : UNKNOWN_TYPE = Transpose(%8), uses = [%10.i1];
- %10 : Double(3, 80) = FC(%1, %9, %6), uses = [%32.i0];
- %11 : Double(20!, 80!) = Transpose[perm=[1, 0]](%5), uses = [%12.i0];
- %12 : UNKNOWN_TYPE = Transpose(%11), uses = [%13.i1];
- %13 : Double(3, 80) = FC(%2, %12, %7), uses = [%33.i0];
- %36 : Double(3!, 20), %39 : Double(3!, 20), %42 : Double(3!, 20), %45 : Double(3!, 20) = Split[split=[20, 20, 20, 20], axis=1](%13), uses = [[%29.i8], [%29.i6], [%29.i4], [%29.i2]];
- %35 : Double(3!, 20), %38 : Double(3!, 20), %41 : Double(3!, 20), %44 : Double(3!, 20) = Split[split=[20, 20, 20, 20], axis=1](%10), uses = [[%29.i7], [%29.i5], [%29.i3], [%29.i1]];
- %30 : Double(3, 20), %31 : Double(3, 20) = fusion_group_0(%3, %44, %45, %41, %42, %38, %39, %35, %36), uses = [[%0.i0], [%0.i1]];
- return (%30, %31);
-}
-with fusion_group_0 = graph(%13 : Double(3, 20)
- %23 : Double(3!, 20)
- %24 : Double(3!, 20)
- %26 : Double(3!, 20)
- %27 : Double(3!, 20)
- %29 : Double(3!, 20)
- %30 : Double(3!, 20)
- %32 : Double(3!, 20)
- %33 : Double(3!, 20)) {
- %34 : Double(3, 20) = Add(%32, %33), uses = [%22.i0];
- %31 : Double(3, 20) = Add(%29, %30), uses = [%20.i0];
- %28 : Double(3, 20) = Add(%26, %27), uses = [%18.i0];
- %25 : Double(3, 20) = Add(%23, %24), uses = [%16.i0];
- %22 : Double(3, 20) = Sigmoid(%34), uses = [%11.i0];
- %20 : Double(3, 20) = Sigmoid(%31), uses = [%14.i0];
- %18 : Double(3, 20) = Tanh(%28), uses = [%11.i1];
- %16 : Double(3, 20) = Sigmoid(%25), uses = [%3.i0];
- %14 : Double(3, 20) = Mul(%20, %13), uses = [%8.i0];
- %11 : Double(3, 20) = Mul(%22, %18), uses = [%8.i1];
- %8 : Double(3, 20) = Add(%14, %11), uses = [%5.i0, %0.i1];
- %5 : Double(3, 20) = Tanh(%8), uses = [%3.i1];
- %3 : Double(3, 20) = Mul(%16, %5), uses = [%0.i0];
- return (%3, %8);
-}
diff --git a/test/expect/TestJit.test_lstm_fusion_concat.expect b/test/expect/TestJit.test_lstm_fusion_concat.expect
new file mode 100644
index 0000000000..65bee38a8c
--- /dev/null
+++ b/test/expect/TestJit.test_lstm_fusion_concat.expect
@@ -0,0 +1,41 @@
+graph(%0 : Float(3, 10)
+ %1 : Float(3, 20)
+ %2 : Float(3, 20)
+ %3 : Float(80, 10)
+ %4 : Float(80, 20)
+ %5 : Float(80)
+ %6 : Float(80)) {
+ %7 : Float(10!, 80!) = aten::t(%3)
+ %8 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%5, %0, %7)
+ %9 : Float(20!, 80!) = aten::t(%4)
+ %10 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%6, %1, %9)
+ %11 : Float(3!, 20), %12 : Float(3!, 20), %13 : Float(3!, 20), %14 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%8)
+ %15 : Float(3!, 20), %16 : Float(3!, 20), %17 : Float(3!, 20), %18 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%10)
+ %19 : Float(6, 20) = prim::FusionGroup_0[device=0](%2, %14, %18, %13, %17, %12, %16, %11, %15)
+ return (%19);
+}
+with prim::FusionGroup_0 = graph(%14 : Float(3, 20)
+ %24 : Float(3!, 20)
+ %25 : Float(3!, 20)
+ %27 : Float(3!, 20)
+ %28 : Float(3!, 20)
+ %30 : Float(3!, 20)
+ %31 : Float(3!, 20)
+ %33 : Float(3!, 20)
+ %34 : Float(3!, 20)) {
+ %35 : Float(3, 20) = aten::add[alpha={1}](%33, %34)
+ %32 : Float(3, 20) = aten::add[alpha={1}](%30, %31)
+ %29 : Float(3, 20) = aten::add[alpha={1}](%27, %28)
+ %26 : Float(3, 20) = aten::add[alpha={1}](%24, %25)
+ %23 : Float(3, 20) = aten::sigmoid(%35)
+ %21 : Float(3, 20) = aten::sigmoid(%32)
+ %19 : Float(3, 20) = aten::tanh(%29)
+ %17 : Float(3, 20) = aten::sigmoid(%26)
+ %15 : Float(3, 20) = aten::mul(%21, %14)
+ %12 : Float(3, 20) = aten::mul(%23, %19)
+ %9 : Float(3, 20) = aten::add[alpha={1}](%15, %12)
+ %6 : Float(3, 20) = aten::tanh(%9)
+ %5 : Float(3, 20) = aten::mul(%17, %6)
+ %2 : Float(6, 20) = aten::cat[dim=0](%5, %9)
+ return (%2);
+}
diff --git a/test/expect/TestJit.test_lstm_fusion.expect b/test/expect/TestJit.test_lstm_fusion_cpu.expect
index 10a537fe98..500996c9bc 100644
--- a/test/expect/TestJit.test_lstm_fusion.expect
+++ b/test/expect/TestJit.test_lstm_fusion_cpu.expect
@@ -6,15 +6,13 @@ graph(%0 : Float(3, 10)
%5 : Float(80)
%6 : Float(80)) {
%7 : Float(10!, 80!) = aten::t(%3)
- %8 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=1](%5)
- %9 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%8, %0, %7)
- %10 : Float(20!, 80!) = aten::t(%4)
- %11 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=1](%6)
- %12 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%11, %1, %10)
- %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%9)
- %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%12)
- %21 : Float(3, 20), %22 : Float(3, 20) = prim::FusionGroup_0(%2, %16, %20, %15, %19, %14, %18, %13, %17)
- return (%21, %22);
+ %8 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%5, %0, %7)
+ %9 : Float(20!, 80!) = aten::t(%4)
+ %10 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%6, %1, %9)
+ %11 : Float(3!, 20), %12 : Float(3!, 20), %13 : Float(3!, 20), %14 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%8)
+ %15 : Float(3!, 20), %16 : Float(3!, 20), %17 : Float(3!, 20), %18 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%10)
+ %19 : Float(3, 20), %20 : Float(3, 20) = prim::FusionGroup_0[device=-1](%2, %14, %18, %13, %17, %12, %16, %11, %15)
+ return (%19, %20);
}
with prim::FusionGroup_0 = graph(%12 : Float(3, 20)
%22 : Float(3!, 20)
diff --git a/test/expect/TestJit.test_lstm_fusion_cuda.expect b/test/expect/TestJit.test_lstm_fusion_cuda.expect
new file mode 100644
index 0000000000..015fcc1a00
--- /dev/null
+++ b/test/expect/TestJit.test_lstm_fusion_cuda.expect
@@ -0,0 +1,40 @@
+graph(%0 : Float(3, 10)
+ %1 : Float(3, 20)
+ %2 : Float(3, 20)
+ %3 : Float(80, 10)
+ %4 : Float(80, 20)
+ %5 : Float(80)
+ %6 : Float(80)) {
+ %7 : Float(10!, 80!) = aten::t(%3)
+ %8 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%5, %0, %7)
+ %9 : Float(20!, 80!) = aten::t(%4)
+ %10 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%6, %1, %9)
+ %11 : Float(3!, 20), %12 : Float(3!, 20), %13 : Float(3!, 20), %14 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%8)
+ %15 : Float(3!, 20), %16 : Float(3!, 20), %17 : Float(3!, 20), %18 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%10)
+ %19 : Float(3, 20), %20 : Float(3, 20) = prim::FusionGroup_0[device=0](%2, %14, %18, %13, %17, %12, %16, %11, %15)
+ return (%19, %20);
+}
+with prim::FusionGroup_0 = graph(%12 : Float(3, 20)
+ %22 : Float(3!, 20)
+ %23 : Float(3!, 20)
+ %25 : Float(3!, 20)
+ %26 : Float(3!, 20)
+ %28 : Float(3!, 20)
+ %29 : Float(3!, 20)
+ %31 : Float(3!, 20)
+ %32 : Float(3!, 20)) {
+ %33 : Float(3, 20) = aten::add[alpha={1}](%31, %32)
+ %30 : Float(3, 20) = aten::add[alpha={1}](%28, %29)
+ %27 : Float(3, 20) = aten::add[alpha={1}](%25, %26)
+ %24 : Float(3, 20) = aten::add[alpha={1}](%22, %23)
+ %21 : Float(3, 20) = aten::sigmoid(%33)
+ %19 : Float(3, 20) = aten::sigmoid(%30)
+ %17 : Float(3, 20) = aten::tanh(%27)
+ %15 : Float(3, 20) = aten::sigmoid(%24)
+ %13 : Float(3, 20) = aten::mul(%19, %12)
+ %10 : Float(3, 20) = aten::mul(%21, %17)
+ %7 : Float(3, 20) = aten::add[alpha={1}](%13, %10)
+ %4 : Float(3, 20) = aten::tanh(%7)
+ %2 : Float(3, 20) = aten::mul(%15, %4)
+ return (%2, %7);
+}
diff --git a/test/expect/TestJit.test_matmul_native.expect b/test/expect/TestJit.test_matmul_native.expect
deleted file mode 100644
index 84f234edbb..0000000000
--- a/test/expect/TestJit.test_matmul_native.expect
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%0 : Double(1, 1)
- %1 : Double(1, 1)) {
- %2 : Double(1, 1) = aten::matmul(%0, %1)
- return (%2);
-}
diff --git a/test/expect/TestJit.test_output_pruning.expect b/test/expect/TestJit.test_output_pruning.expect
deleted file mode 100644
index a89e0b4678..0000000000
--- a/test/expect/TestJit.test_output_pruning.expect
+++ /dev/null
@@ -1,10 +0,0 @@
-graph(%0 : Double(5, 5)
- %1 : Double(5, 5)
- -------- stage 1 --------
- %4 : Double(5, 5)) {
- %2 : Double(5, 5) = aten::mul(%0, %1)
- %3 : Double(5, 5) = aten::add[alpha={1}](%1, %1)
- ---------------- stage 1 ----------------
- %5 : Double(5, 5) = aten::mul(%4, %1)
- return (%2, %3, %5);
-}
diff --git a/test/expect/TestJit.test_repeated_input.expect b/test/expect/TestJit.test_repeated_input.expect
index 8db8e85d9f..57e57066ef 100644
--- a/test/expect/TestJit.test_repeated_input.expect
+++ b/test/expect/TestJit.test_repeated_input.expect
@@ -1,7 +1,5 @@
graph(%0 : Double(2, 2)
- %1 : Double(2, 2)
- -------- stage 1 --------
- %3 : Double(2, 2!)) {
+ %1 : Double(2, 2)) {
%2 : Double(2, 2) = aten::add[alpha={1}](%0, %1)
- return (%2, %3, %3);
+ return (%2);
}
diff --git a/test/expect/TestJit.test_repeated_output.expect b/test/expect/TestJit.test_repeated_output.expect
index 1f8c156f15..b3baff631e 100644
--- a/test/expect/TestJit.test_repeated_output.expect
+++ b/test/expect/TestJit.test_repeated_output.expect
@@ -1,10 +1,5 @@
graph(%0 : Double(2, 2)
- %1 : Double(2, 2)
- -------- stage 1 --------
- %3 : Double(2, 2)
- %4 : Double(2, 2)) {
+ %1 : Double(2, 2)) {
%2 : Double(2, 2) = aten::add[alpha={1}](%0, %1)
- ---------------- stage 1 ----------------
- %5 : Double(2, 2) = aten::add[alpha={1}](%3, %4)
- return (%2, %2, %5, %5);
+ return (%2, %2);
}
diff --git a/test/expect/TestJit.test_saved_output.expect b/test/expect/TestJit.test_saved_output.expect
deleted file mode 100644
index 77145061e1..0000000000
--- a/test/expect/TestJit.test_saved_output.expect
+++ /dev/null
@@ -1,8 +0,0 @@
-graph(%0 : Double(4, 4)
- -------- stage 1 --------
- %2 : Double(4, 4!)) {
- %1 : Double(4, 4) = aten::sigmoid(%0)
- ---------------- stage 1 ----------------
- %3 : Double(4, 4) = aten::_sigmoid_backward(%2, %1)
- return (%1, %3);
-}
diff --git a/test/expect/TestJit.test_shape_analysis_broadcast.expect b/test/expect/TestJit.test_shape_analysis_broadcast.expect
index 1f42efa52e..bbe5b74164 100644
--- a/test/expect/TestJit.test_shape_analysis_broadcast.expect
+++ b/test/expect/TestJit.test_shape_analysis_broadcast.expect
@@ -1,7 +1,7 @@
graph(%a : Double(3, 1, 5)
%b : Double(4, 1, 8, 5)) {
- %3 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%a)
- %4 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%b)
- %2 : Double(4, 3, 8, 5) = aten::add[alpha={1}](%3, %4)
- return (%2);
+ %2 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%a)
+ %3 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%b)
+ %4 : Double(4, 3, 8, 5) = aten::add[alpha={1}](%2, %3)
+ return (%4);
}
diff --git a/test/test_jit.py b/test/test_jit.py
index 0428cd6d00..a48a0e8c50 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -7,6 +7,7 @@ from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
from torch.autograd.function import traceable
+from torch.testing import assert_allclose
from common import TestCase, run_tests, IS_WINDOWS
from textwrap import dedent
import os
@@ -18,6 +19,7 @@ import textwrap
import numpy as np
import tempfile
import shutil
+import warnings
from torch.jit.frontend import NotSupportedError
@@ -45,6 +47,10 @@ PY35 = sys.version_info >= (3, 5)
WINDOWS = sys.platform == 'win32'
+def LSTMCellF(input, hx, cx, *params):
+ return LSTMCell(input, (hx, cx), *params)
+
+
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
hx, cx = hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
@@ -61,97 +67,91 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
def LSTMCellC(*args, **kwargs):
- hy, cy = LSTMCell(*args, **kwargs)
+ hy, cy = LSTMCellF(*args, **kwargs)
return torch.cat((hy, cy))
+def get_lstm_inputs(device):
+ input = torch.randn(3, 10, dtype=torch.float, device=device)
+ hx = torch.randn(3, 20, dtype=torch.float, device=device)
+ cx = torch.randn(3, 20, dtype=torch.float, device=device)
+ module = nn.LSTMCell(10, 20).to(torch.float, device) # Just to allocate weights with correct sizes
+ return (input, hx, cx) + tuple(p.requires_grad_(False) for p in module.parameters())
+
+
class TestJit(TestCase):
- maxDiff = None
+ def assertExpectedONNXGraph(self, trace, *args, **kwargs):
+ torch.onnx._optimize_trace(trace, aten=False)
+ self.assertExpectedGraph(trace, *args, **kwargs)
- @contextmanager
- def assertCompiled(self, compiled_fn):
- self.assertIsInstance(compiled_fn, torch._C.CompiledFunction)
- hits, misses = compiled_fn.hits, compiled_fn.misses
- yield
- self.assertLess(hits, compiled_fn.hits)
- self.assertEqual(misses, compiled_fn.misses)
-
- def assertExpectedTrace(self, trace, *args, **kwargs):
- torch._C._jit_pass_lint(trace.graph())
- torch._C._jit_pass_dce(trace.graph())
- torch._C._jit_pass_lint(trace.graph())
- trace.set_graph(torch._C._jit_pass_canonicalize(trace.graph()))
- torch._C._jit_pass_lint(trace.graph())
- self.assertExpected(str(trace), *args, **kwargs)
+ def assertExpectedGraph(self, trace, *args, **kwargs):
+ if isinstance(trace, torch._C.Graph):
+ graph = trace
+ else:
+ graph = trace.graph()
+
+ torch._C._jit_pass_lint(graph)
+ torch._C._jit_pass_dce(graph)
+ torch._C._jit_pass_lint(graph)
+ graph = torch._C._jit_pass_canonicalize(graph)
+ torch._C._jit_pass_lint(graph)
+ self.assertExpected(str(graph), *args, **kwargs)
def assertExportImport(self, trace, inputs):
initializers = []
- graph1 = trace.graph()
- ge1 = torch._C.GraphExecutor(graph1, False)
- out1 = ge1(*inputs)
+ def run(graph):
+ return torch._C.GraphExecutor(graph, False)(*inputs)
+
proto, _ = trace.graph().export(initializers, onnx_opset_version=0,
defer_weight_export=False, export_raw_ir=True)
+ self.assertFalse(initializers)
- graph2, initializers = torch._C._jit_import_graph(proto)
- ge2 = torch._C.GraphExecutor(graph2, False)
- out2 = ge2(*inputs)
+ imported_graph, initializers = torch._C._jit_import_graph(proto)
+ self.assertFalse(initializers)
- self.assertEqual(out1, out2)
+ self.assertEqual(run(trace.graph()), run(imported_graph))
- def test_simple(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.Tensor([0.7]), requires_grad=True)
+ def run_pass(self, name, trace):
+ if isinstance(trace, torch._C.Graph):
+ graph = trace
+ set_graph = False
+ else:
+ set_graph = True
+ graph = trace.graph()
- def f(x, y):
- return torch.sigmoid(torch.tanh(x * (x + y)))
+ torch._C._jit_pass_lint(graph)
+ result = getattr(torch._C, '_jit_pass_' + name)(graph)
+ if result is not None:
+ graph = result
+ torch._C._jit_pass_lint(graph)
- trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
- self.assertExpectedTrace(trace)
- self.assertExportImport(trace, (x, y))
+ if set_graph:
+ trace.set_graph(graph)
+ return graph
- # matmul is currently implemented as a native function, which
- # exercises different codepaths in the JIT. The following two
- # tests ensure that (1) matmul indeed traces into an atomic,
- # native operation, and (2) the JIT knows how to run it
+ def test_simple(self):
+ x = torch.tensor([0.4], requires_grad=True)
+ y = torch.tensor([0.7], requires_grad=True)
- def test_matmul_native(self):
- x = Variable(torch.Tensor([[0.4]]), requires_grad=True)
- y = Variable(torch.Tensor([[0.7]]), requires_grad=True)
+ def f(x, y):
+ return torch.sigmoid(torch.tanh(x * (x + y)))
- trace, z = torch.jit.get_trace_graph(lambda x, y: x.matmul(y), (x, y), nderivs=0)
- torch._C._jit_pass_lint(trace.graph())
- torch._C._jit_pass_dce(trace.graph())
- self.assertExpectedTrace(trace)
+ trace, z = torch.jit.get_trace_graph(f, (x, y))
+ self.assertExpectedGraph(trace)
self.assertExportImport(trace, (x, y))
- def test_matmul_native_run(self):
- x = Variable(torch.Tensor([[0.4]]), requires_grad=True)
- y = Variable(torch.Tensor([[0.7]]), requires_grad=True)
-
- @torch.jit.compile(nderivs=0)
- def fn(x, y):
- return x.matmul(y)
-
- z = fn(x, y)
- with self.assertCompiled(fn):
- z2 = fn(x, y)
- self.assertEqual(z, z2)
-
# index-2 is not implemented in interpreter
@unittest.expectedFailure
def test_index(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.LongTensor([0]), requires_grad=True)
+ x = torch.tensor([0.4], requires_grad=True)
+ y = torch.tensor([0], dtype=torch.int64, requires_grad=True)
@torch.jit.compile(nderivs=0)
def fn(x, y):
return x[y]
- z = fn(x, y)
- with self.assertCompiled(fn):
- z2 = fn(x, y)
- self.assertEqual(z, z2)
+ fn(x, y) # Fails
# Backwards tracing was broken for indexing by a constant,
# because it's internally implemented using as_strided,
@@ -159,26 +159,22 @@ class TestJit(TestCase):
# currently supported.) It currently works because
# slice() is now not marked as traceable.
def test_index_constant(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
+ x = torch.tensor([0.4], requires_grad=True)
- @torch.jit.compile(nderivs=1)
def fn(x):
return x[0]
- z = fn(x)
- z.backward()
- grad = x.grad.clone()
- x.grad.zero_()
- with self.assertCompiled(fn):
- z2 = fn(x)
- z2.backward()
- grad2 = x.grad.clone()
- self.assertEqual(z, z2)
- self.assertEqual(grad, grad2)
+ def run(f):
+ y = f(x)
+ grad = torch.autograd.grad(y, x)[0].clone()
+ return y, grad
+
+ traced_fn = torch.jit.trace(torch.ones(1))(fn)
+ self.assertEqual(run(fn), run(traced_fn))
def test_scopes(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.Tensor([0.7]), requires_grad=True)
+ x = torch.tensor([0.4], requires_grad=True)
+ y = torch.tensor([0.7], requires_grad=True)
def f(x, y):
out = x + y
@@ -190,7 +186,7 @@ class TestJit(TestCase):
return out
trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
- self.assertExpectedTrace(trace)
+ self.assertExpectedGraph(trace)
self.assertExportImport(trace, (x, y))
def test_scopes_intermediate_node(self):
@@ -200,12 +196,10 @@ class TestJit(TestCase):
return F.log_softmax(x, dim=0)
net = Net()
- t = Variable(torch.ones(2), requires_grad=True)
- trace, _ = torch.jit.get_trace_graph(net, (t, ))
- self.assertExportImport(trace, (t, ))
- torch.onnx._optimize_trace(trace, False)
-
- self.assertExpectedTrace(trace)
+ t = torch.ones(2, requires_grad=True)
+ trace, _ = torch.jit.get_trace_graph(net, (t,))
+ self.assertExportImport(trace, (t,))
+ self.assertExpectedONNXGraph(trace)
def test_scopes_identity_node(self):
@@ -225,93 +219,55 @@ class TestJit(TestCase):
model = Net()
- t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)
+ t = torch.ones(1, 3, 227, 227, requires_grad=True)
with torch.onnx.set_training(model, False):
- trace, _ = torch.jit.get_trace_graph(model, (t, ))
-
- self.assertExportImport(trace, (t, ) + tuple(model.parameters()))
+ trace, _ = torch.jit.get_trace_graph(model, (t,))
- torch.onnx._optimize_trace(trace, False)
-
- self.assertExpectedTrace(trace)
+ self.assertExportImport(trace, (t,) + tuple(model.parameters()))
+ self.assertExpectedONNXGraph(trace)
+ # TODO: Fuser doesn't work at all when inputs require grad. Fix that
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_lstm_fusion(self):
- input = Variable(torch.randn(3, 10).float().cuda())
- hx = Variable(torch.randn(3, 20).float().cuda())
- cx = Variable(torch.randn(3, 20).float().cuda())
- module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes
-
- trace, _ = torch.jit.get_trace_graph(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
- torch._C._jit_pass_lint(trace.graph())
- torch._C._jit_pass_dce(trace.graph())
- torch._C._jit_pass_lint(trace.graph())
- self.assertExportImport(trace, (input, hx, cx) + tuple(module.parameters()))
- torch._C._jit_pass_fuse(trace.graph())
- self.assertExpectedTrace(trace)
-
- def run_lstm_fusion(self, use_cuda):
- def to_type(x):
- x = x.float()
- if use_cuda:
- x = x.cuda()
- return x
-
- def rand_v(a, b):
- return Variable(to_type(torch.randn(a, b)))
-
- input = rand_v(3, 10)
- hx = rand_v(3, 20)
- cx = rand_v(3, 20)
- module = to_type(nn.LSTMCell(10, 20)) # Just to allocate weights with correct sizes
-
- CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCell)
-
- z = CompiledLSTMCell(input, (hx, cx), *module.parameters())
- with self.assertCompiled(CompiledLSTMCell):
- z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters())
- self.assertEqual(z, z2)
-
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_run_lstm_fusion_cuda(self):
- self.run_lstm_fusion(True)
+ def test_lstm_fusion_cuda(self):
+ inputs = get_lstm_inputs('cuda')
+ ge = self.checkTrace(LSTMCellF, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- def test_run_lstm_fusion_cpu(self):
- self.run_lstm_fusion(False)
+ def test_lstm_fusion_cpu(self):
+ inputs = get_lstm_inputs('cpu')
+ try:
+ ge = self.checkTrace(LSTMCellF, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
+ except RuntimeError as e:
+ if 'Failed to compile' in e.args[0]:
+ warnings.warn('CPU fuser test has failed! This is not a hard failure, '
+ 'because the kernels sometimes trigger bugs in compilers '
+ '(most notably GCC 7.2).')
+ raise unittest.SkipTest('Failed to compile')
+ else:
+ raise
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_run_lstm_fusion_concat(self):
- input = Variable(torch.randn(3, 10).float().cuda())
- hx = Variable(torch.randn(3, 20).float().cuda())
- cx = Variable(torch.randn(3, 20).float().cuda())
- module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes
-
- CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCellC)
-
- z = CompiledLSTMCell(input, (hx, cx), *module.parameters())
- with self.assertCompiled(CompiledLSTMCell):
- z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters())
- self.assertEqual(z, z2)
+ def test_lstm_fusion_concat(self):
+ inputs = get_lstm_inputs('cuda')
+ ge = self.checkTrace(LSTMCellC, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_concat_fusion(self):
- hx = Variable(torch.randn(3, 20).float().cuda())
- cx = Variable(torch.randn(3, 20).float().cuda())
+ hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
+ cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
- def Foo(hx, cx):
+ def foo(hx, cx):
return torch.cat((hx + cx, hx * cx))
- trace, _ = torch.jit.get_trace_graph(Foo, (hx, cx))
- self.assertExportImport(trace, (hx, cx))
- torch._C._jit_pass_lint(trace.graph())
- torch._C._jit_pass_fuse(trace.graph())
- self.assertExpectedTrace(trace)
+ ge = self.checkTrace(foo, (hx, cx))
+ self.assertExpectedGraph(ge.graph_for(hx, cx))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -319,16 +275,15 @@ class TestJit(TestCase):
def f(x, y):
z1, z2 = (x + y).chunk(2, dim=1)
return z1 * z2
- x = Variable(torch.randn(4, 4).float().cuda())
- y = Variable(torch.randn(4, 4).float().cuda())
- trace, _ = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
- torch._C._jit_pass_lint(trace.graph())
- torch._C._jit_pass_dce(trace.graph())
- self.assertExpectedTrace(trace, 'raw')
- self.assertExportImport(trace, (x, y))
- torch._C._jit_pass_fuse(trace.graph())
- self.assertExpectedTrace(trace)
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+
+ ge = self.checkTrace(f, (x, y))
+ self.assertExpectedGraph(ge.graph_for(x, y))
+
+ # TODO: adapt this test to check that GraphExecutor treats them differently
+ @unittest.skip("Need to be adjusted to Graph Executor")
def test_arg_configurations(self):
"""Different arg configurations should trigger different traces"""
x = Variable(torch.FloatTensor(4, 4).uniform_())
@@ -373,97 +328,43 @@ class TestJit(TestCase):
self.assertEqual(fn.hits, 0)
def test_cse(self):
- x = Variable(torch.Tensor([0.4, 0.3]), requires_grad=True)
- y = Variable(torch.Tensor([0.7, 0.5]), requires_grad=True)
-
- trace, inputs = torch._C._tracer_enter((x, y), 0)
+ x = torch.tensor([0.4, 0.3], requires_grad=True)
+ y = torch.tensor([0.7, 0.5], requires_grad=True)
def fn(x, y):
w = (x + y) * (x + y) * (x + y)
t = torch.tanh(w) + torch.tanh(w)
z = (x + y) * (x + y) * (x + y) + t
return z
- z = fn(*inputs)
- torch._C._tracer_exit((z,))
- torch._C._jit_pass_lint(trace.graph())
- torch._C._jit_pass_cse(trace.graph())
- self.assertExpectedTrace(trace)
+ trace, _ = torch.jit.get_trace_graph(fn, (x, y), nderivs=0)
+ self.run_pass('cse', trace)
+ self.assertExpectedGraph(trace)
self.assertExportImport(trace, (x, y))
- def test_compile_run_twice(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.Tensor([0.7]), requires_grad=True)
-
- @torch.jit.compile(nderivs=0, optimize=False)
- def doit(x, y):
- return torch.sigmoid(torch.tanh(x * (x + y)))
-
- z = doit(x, y)
- with self.assertCompiled(doit):
- z2 = doit(x, y)
- self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
- self.assertEqual(z, z2)
-
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_compile_addc(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda()
- y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda()
+ def test_shape_analysis_broadcast(self):
+ def broadcast(a, b):
+ return a + b
- @torch.jit.compile(nderivs=0)
- def doit(x, y):
- return torch.sigmoid(torch.tanh(x * (x + y) + 1))
+ x = torch.randn(3, 1, 5, requires_grad=True)
+ y = torch.randn(4, 1, 8, 5, requires_grad=True)
- z = doit(x, y)
- with self.assertCompiled(doit):
- z2 = doit(x, y)
- self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y) + 1)))
- self.assertEqual(z, z2)
+ graph = torch.jit._script_graph(broadcast)
+ torch._C._jit_pass_shape_analysis(graph, (x, y), False)
+ self.assertExpectedGraph(graph)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
- def test_compile_fuse_last_device(self):
- max_device = torch.cuda.device_count() - 1
- x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda(max_device)
- y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda(max_device)
+ def test_fuse_last_device(self):
+ device = 'cuda:' + str(torch.cuda.device_count() - 1)
+ x = torch.tensor([0.4], dtype=torch.float, device=device)
+ y = torch.tensor([0.7], dtype=torch.float, device=device)
- @torch.jit.compile(nderivs=0)
def doit(x, y):
return torch.sigmoid(torch.tanh(x * (x + y) + 1))
- z = doit(x, y)
- with self.assertCompiled(doit):
- z2 = doit(x, y)
- self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y) + 1)))
- self.assertEqual(z, z2)
-
- def test_traced_function(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.Tensor([0.7]), requires_grad=True)
-
- @torch.jit.compile(nderivs=0)
- def doit(x, y):
- return torch.sigmoid(torch.tanh(x * (x + y)))
-
- z = doit(x, y)
- with self.assertCompiled(doit):
- z2 = doit(x, y)
- self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
- self.assertEqual(z, z2)
-
- def test_disabled_traced_function(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.Tensor([0.7]), requires_grad=True)
-
- @torch.jit.compile(enabled=False)
- def doit(x, y):
- return torch.sigmoid(torch.tanh(x * (x + y)))
-
- z = doit(x, y)
- z2 = doit(x, y)
- self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
- self.assertEqual(z, z2)
+ ge = self.checkTrace(doit, (x, y))
+ self.assertExpectedGraph(ge.graph_for(x, y))
def test_assign_traces(self):
"""Check that output Variables are assigned traces before they are saved."""
@@ -480,61 +381,17 @@ class TestJit(TestCase):
a, = ctx.saved_tensors
return a * grad_a
- x = Variable(torch.randn(10, 10), requires_grad=True)
+ x = torch.randn(10, 10, requires_grad=True)
trace, out = torch.jit.get_trace_graph(MyFn.apply, x, nderivs=1)
out.sum().backward()
- torch._C._jit_pass_dce(trace.graph())
- self.assertExpectedTrace(trace)
-
- def test_legacy_traced_module(self):
- input = Variable(torch.randn(3, 10))
- hx = Variable(torch.randn(3, 20))
- cx = Variable(torch.randn(3, 20))
-
- @torch.jit.compile(nderivs=0)
- class MyLSTMCell(nn.LSTMCell):
- pass
-
- lstm = MyLSTMCell(10, 20)
-
- out = lstm(input, (hx, cx))
- with self.assertCompiled(lstm):
- out2 = lstm(input, (hx, cx))
- self.assertEqual(out, out2)
-
- def test_autograd_closure(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.Tensor([0.7]), requires_grad=True)
-
- trace, inputs = torch._C._tracer_enter((x, y), 1)
-
- def fn(x, y):
- z = torch.sigmoid(x * (x + y))
- w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
- return z, w
- z, w = fn(*inputs)
-
- torch._C._tracer_exit((z, w))
- torch._C._jit_pass_lint(trace.graph())
-
- (z * w).backward()
- torch._C._jit_pass_dce(trace.graph())
- torch._C._jit_pass_lint(trace.graph())
-
- x_grad = x.grad.data.clone()
- x.grad.data.zero_()
-
- function = torch._C._jit_createInterpreterFactory(trace)
- torch._C._jit_pass_lint(trace.graph())
- z2, w2 = function()(x, y)
- (z2 * w2).backward()
- self.assertEqual(z, z2)
- self.assertEqual(w, w2)
- self.assertEqual(x.grad.data, x_grad)
+ self.run_pass('dce', trace)
+ self.assertExpectedGraph(trace)
+ # TODO: update verify to work with GraphExecutors
+ @unittest.skip("verify needs to be updated to work with GraphExecutors")
def test_verify(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.Tensor([0.7]), requires_grad=True)
+ x = torch.tensor([0.4], requires_grad=True)
+ y = torch.tensor([0.7], requires_grad=True)
@torch.jit.compile
def f(x, y):
@@ -545,37 +402,14 @@ class TestJit(TestCase):
torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
def test_constant(self):
- x = Variable(torch.randn(2, 2), requires_grad=True)
-
- trace, (tx,) = torch._C._tracer_enter((x,), 0)
-
- y = Variable(torch.diag(torch.Tensor([2, 2])))
- z = tx.matmul(y)
-
- torch._C._tracer_exit((z,))
- function = torch._C._jit_createInterpreterFactory(trace)
-
- z2 = function()(x)
- self.assertEqual(z, z2)
-
- y.data.fill_(1000) # make sure the data has been cloned
+ x = torch.randn(2, 2, requires_grad=True)
- x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True)
- z3 = function()(x2)
- self.assertEqual(z3.data, torch.ones(2, 2) * 4)
-
- def test_c_function(self):
- x = Variable(torch.randn(1, 3, 10, 10))
- m = nn.Conv2d(3, 8, 3, 1)
+ def f(x):
+ return x.matmul(torch.diag(torch.tensor([2., 2.])))
- trace, inputs = torch._C._tracer_enter((x,) + tuple(m.parameters()), 0)
- y = m(inputs[0])
- torch._C._tracer_exit((y,))
- self.assertExpectedTrace(trace)
- self.assertExportImport(trace, inputs)
+ self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))
def test_legacy_fail(self):
-
class MyLegacyFn(Function):
def forward(self, x):
return x
@@ -583,24 +417,22 @@ class TestJit(TestCase):
def backward(self, grad_output):
return grad_output
- x = Variable(torch.Tensor([0]), requires_grad=True)
- trace, inputs = torch._C._tracer_enter((x,), 0)
- self.assertRaisesRegex(RuntimeError, "MyLegacyFn", lambda: MyLegacyFn()(*inputs))
- torch._C._tracer_exit(inputs)
+ x = torch.tensor([0.], requires_grad=True)
+ with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
+ torch.jit.get_trace_graph(lambda x: MyLegacyFn()(x), (x,), nderivs=0)
def test_inplace_transplant(self):
- x = Variable(torch.Tensor([0]), requires_grad=True)
- trace, inputs = torch._C._tracer_enter((x,), 0)
+ x = torch.tensor([0.], requires_grad=True)
def fn(x):
y = x.clone()
y.add_(2)
y.add_(3)
return y
- y = fn(*inputs)
- torch._C._tracer_exit((y,))
- self.assertExpectedTrace(trace)
- self.assertExportImport(trace, inputs)
+
+ trace, _ = torch.jit.get_trace_graph(fn, (x,), nderivs=0)
+ self.assertExpectedGraph(trace)
+ self.assertExportImport(trace, (x,))
def test_inplace_flags(self):
class InplaceFn(Function):
@@ -622,8 +454,7 @@ class TestJit(TestCase):
def backward(ctx, go):
return go
- x = Variable(torch.Tensor([0]), requires_grad=True)
- trace, inputs = torch._C._tracer_enter((x,), 0)
+ x = torch.tensor([0], requires_grad=True)
def fn(x):
y = RegularFn.apply(x)
@@ -631,9 +462,9 @@ class TestJit(TestCase):
y = InplaceFn.apply(y)
y = RegularFn.apply(y)
return y
- y = fn(*inputs)
- torch._C._tracer_exit((y,))
- torch._C._jit_pass_dce(trace.graph())
+
+ trace, _ = torch.jit.get_trace_graph(fn, (x,), nderivs=0)
+ self.run_pass('dce', trace)
ops = [n for n in trace.graph().nodes()]
for op in ops:
self.assertTrue(op.hasAttribute('inplace'))
@@ -653,125 +484,13 @@ class TestJit(TestCase):
def backward(self, grad):
return grad
- @torch.jit.compile(nderivs=0)
def fn(x):
return MyInplaceFn.apply(x)
- x = Variable(torch.randn(5, 5))
- fn(x) # trace
- with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
- fn(x)
-
- def test_backward(self):
- a = Variable(torch.randn(2, 2), requires_grad=True)
- b = Variable(torch.randn(2, 2), requires_grad=True)
-
- x = a
- y = a * b
-
- trace, inputs = torch._C._tracer_enter((x, y), 2)
- def fn(x, y):
- return y * 2 * x
- z = fn(*inputs)
- torch._C._tracer_exit((z,))
- torch._C._jit_pass_lint(trace.graph())
-
- # Run first backward
- grad, = torch.autograd.grad(z, x, Variable(torch.ones(2, 2), requires_grad=True), create_graph=True)
- torch._C._jit_pass_lint(trace.graph())
-
- # Run second backward
- grad.sum().backward(create_graph=True)
- torch._C._jit_pass_lint(trace.graph())
-
- # Run dead code elimination to remove unused trace nodes
- torch._C._jit_pass_dce(trace.graph())
- # This is nondeterministic, see:
- # https://github.com/ezyang/pytorch/issues/227
- # self.assertExpectedTrace(trace)
- self.skipTest("output is nondeterministic on Travis/Python 3.5")
-
- def test_backward_opaque(self):
- x = Variable(torch.randn(3, 3), requires_grad=True)
- y = Variable(torch.randn(3, 3), requires_grad=True)
-
- trace, inputs = torch._C._tracer_enter((x, y), 2)
-
- def fn(x, y):
- return x.cross(y)
- z = fn(*inputs)
- torch._C._tracer_exit((z,))
- torch._C._jit_pass_lint(trace.graph())
-
- # Run first backward
- grad, = torch.autograd.grad(z, x, Variable(torch.ones(3, 3), requires_grad=True), create_graph=True)
- torch._C._jit_pass_lint(trace.graph())
-
- # Run dead code elimination to remove unused trace nodes
- torch._C._jit_pass_dce(trace.graph())
- # This is nondeterministic, see:
- # https://github.com/ezyang/pytorch/issues/227
- # self.assertExpectedTrace(trace)
- self.skipTest("output is nondeterministic on Travis/Python 3.5")
-
- def test_backward_closure(self):
- """Check that autograd closures handle multiple stages correctly."""
- x = Variable(torch.randn(1), requires_grad=True)
-
- @torch.jit.compile(nderivs=2)
- def fn(x):
- return x * x
-
- # Generate trace
- grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
- self.assertFalse(fn.has_trace_for(x))
- grad_x.backward()
- self.assertTrue(fn.has_trace_for(x))
-
- x_grad = x.grad.data.clone()
- x.grad.data.zero_()
-
- # Run the trace
- with self.assertCompiled(fn):
- output = fn(x)
- grad_x, = torch.autograd.grad(output, (x,), create_graph=True)
- grad_x.backward()
-
- self.assertEqual(x.grad.data, x_grad)
-
- def test_trace_expire(self):
- x = Variable(torch.randn(2, 2), requires_grad=True)
- y = Variable(torch.randn(2, 2), requires_grad=True)
-
- def record_trace(num_backwards):
- trace, inputs = torch._C._tracer_enter((x, y), num_backwards)
-
- def fn(x, y):
- return y * 2 * x
- z = fn(*inputs)
- torch._C._tracer_exit((z,))
- return z, trace
-
- def check(expired, complete):
- self.assertEqual(trace.is_expired, expired)
- self.assertEqual(trace.is_complete, complete)
-
- z, trace = record_trace(0)
- check(False, True)
- del z
- check(False, True)
-
- z, trace = record_trace(1)
- check(False, False)
- del z
- check(True, False)
-
- z, trace = record_trace(1)
- check(False, False)
- z.sum().backward()
- check(False, True)
- del z
- check(False, True)
+ x = torch.randn(5, 5)
+ ge = torch._C.GraphExecutor(fn, (x,))
+ with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
+ ge(x)
def do_trace_size(self, requires_grad):
def fn(x):
@@ -787,7 +506,7 @@ class TestJit(TestCase):
# Check that the trace looks ok
trace, _ = torch.jit.get_trace_graph(fn, (x,))
- self.assertExpectedTrace(trace)
+ self.assertExpectedGraph(trace)
def test_trace_size(self):
self.do_trace_size(False)
@@ -797,62 +516,31 @@ class TestJit(TestCase):
def test_trace_size_with_grad(self):
self.do_trace_size(True)
- def test_multiuse_fn(self):
- x = Variable(torch.randn(2, 2), requires_grad=True)
- w = Variable(torch.randn(2, 2), requires_grad=True)
-
- @torch.jit.compile
- def cell(x, w):
- return x * w + 2
-
- out = cell(cell(cell(x, w), w), w)
- self.assertFalse(cell.has_trace_for(x, w))
-
- out.sum().backward()
- self.assertTrue(cell.has_trace_for(x, w))
-
- torch.jit.verify(cell, (x, w), devices=[])
-
+ # TODO: implement
+ @unittest.expectedFailure
def test_output_unflatten(self):
"""Check that outputs of traced functions retain the original structure and nesting"""
- x = Variable(torch.randn(2, 2), requires_grad=True)
-
def fn(x):
return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
- expected_out = fn(x)
- fn = torch.jit.compile(fn)
-
- def recursive_sum(obj):
- if isinstance(obj, Variable):
- return obj.sum()
- else:
- return sum(recursive_sum(o) for o in obj)
-
- recursive_sum(fn(x)).backward()
- self.assertTrue(fn.has_trace_for(x))
- with self.assertCompiled(fn):
- self.assertEqual(fn(x), expected_out)
+ self.checkTrace(fn, (torch.randn(2, 2),))
+ # TODO: implement
+ @unittest.expectedFailure
def test_input_flatten(self):
"""Check that inputs to traced functions are flattened"""
- def make_var():
- return Variable(torch.randn(1), requires_grad=True)
- x = (make_var(), (make_var(), make_var()))
def fn(x, t):
y, z = t
return x * y * z
- expected_out = fn(*x)
- fn = torch.jit.compile(fn)
- fn(*x).backward()
- self.assertTrue(fn.has_trace_for(*x))
- with self.assertCompiled(fn):
- self.assertEqual(fn(*x), expected_out)
+ inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
+ self.checkTrace(fn, inputs)
+ # TODO: adapt to a GraphExecutor test
+ @unittest.skip("Need to instrument GraphExecutors a bit more")
def test_flags(self):
- x = Variable(torch.randn(2, 2))
+ x, y = torch.randn(2, 2)
y = Variable(torch.randn(2, 2))
@torch.jit.compile
@@ -876,53 +564,15 @@ class TestJit(TestCase):
self.assertEqual(grad_v, expected_grad)
self.assertEqual(fn.has_trace_for(x, y), rx or ry)
- def test_no_grad_fallback(self):
- """Check that Traceable falls back to num_backwards=0 if in no-backprop mode"""
- x = Variable(torch.randn(2, 2))
- y = Variable(torch.randn(2, 2), requires_grad=True)
-
- @torch.jit.compile
- def fn(x, y):
- return x * x + x * y
-
- out = fn(x, y)
- self.assertFalse(fn.has_trace_for(x, y))
- with torch.no_grad():
- out = fn(x, y)
- self.assertTrue(fn.has_trace_for(x, y))
- with self.assertCompiled(fn):
- out2 = fn(x, y)
- self.assertEqual(out, out2)
-
- def test_backward_flag_checks(self):
- x = Variable(torch.randn(1), requires_grad=True)
-
- @torch.jit.compile(nderivs=2)
- def fn(x):
- return x * x
-
- grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
- self.assertFalse(fn.has_trace_for(x))
- grad_x.backward()
- self.assertTrue(fn.has_trace_for(x))
-
- with self.assertRaisesRegex(RuntimeError, 'was compiled with'):
- fn(x).backward(Variable(torch.ones(1), requires_grad=True))
- with self.assertRaisesRegex(RuntimeError, 'was compiled with'):
- grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
- grad_x.backward(Variable(torch.ones(1), requires_grad=True))
-
- # TODO: Test executing this
-
def test_python_ir(self):
- x = Variable(torch.Tensor([0.4]), requires_grad=True)
- y = Variable(torch.Tensor([0.7]), requires_grad=True)
+ x = torch.tensor([0.4], requires_grad=True)
+ y = torch.tensor([0.7], requires_grad=True)
def doit(x, y):
return torch.sigmoid(torch.tanh(x * (x + y)))
- traced, _ = torch.jit.get_trace_graph(doit, (x, y))
- g = torch._C._jit_get_graph(traced)
+ trace, _ = torch.jit.get_trace_graph(doit, (x, y))
+ g = trace.graph()
g2 = torch._C.Graph()
g_to_g2 = {}
for node in g.inputs():
@@ -937,9 +587,9 @@ class TestJit(TestCase):
g2.registerOutput(g_to_g2[node])
t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
- assert(t_node.attributeNames() == ["a"])
+ self.assertEqual(t_node.attributeNames(), ["a"])
g2.appendNode(t_node)
- assert(torch.equal(torch.ones([2, 2]), t_node.t("a")))
+ self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
self.assertExpected(str(g2))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@@ -950,245 +600,40 @@ class TestJit(TestCase):
self.assertExpected(torch._C._jit_run_cpp_tests())
def test_batchnorm(self):
- x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True)
+ x = torch.ones(2, 2, 2, 2)
trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x)
- self.assertExpectedTrace(trace)
+ self.assertExpectedGraph(trace)
def test_dropout(self):
- x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
+ x = torch.ones(2, 2)
trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
- self.assertExpectedTrace(trace)
-
- def test_batchnorm_run_twice(self):
- @torch.jit.compile(nderivs=0)
- class MyBatchNorm2d(nn.BatchNorm2d):
- pass
-
- bn = MyBatchNorm2d(1)
- x = Variable(torch.randn(5, 1, 2, 1))
- z = bn(x)
- with self.assertCompiled(bn):
- z2 = bn(x)
- self.assertEqual(z, z2)
-
- def test_non_decorator_use_fails(self):
- MyLSTM = torch.jit.compile(nn.LSTM)
- self.assertRaisesRegex(TypeError, "class decorator", lambda: MyLSTM(2, 2))
+ self.assertExpectedGraph(trace)
def test_conv(self):
- x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
+ x = torch.ones(20, 16, 50, 40)
trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
- self.assertExpectedTrace(trace)
-
- def test_reuse_function(self):
- @torch.jit.compile(nderivs=0)
- def clinear(*args):
- return F.linear(*args)
-
- def cast(x):
- return x
-
- input = Variable(cast(torch.randn(1, 1)))
- weights = Variable(cast(torch.randn(1, 1)))
- bias = Variable(cast(torch.randn(1, 1)))
-
- # linear AKA addmm without bias is of particular interest
- # because we allocate a zero-filled new variable when we execute,
- # and then *fill* it with the result
-
- r1_ = clinear(input, weights)
- with self.assertCompiled(clinear):
- r1 = clinear(r1_, weights)
- r2 = F.linear(F.linear(input, weights), weights)
-
- self.assertEqual(r1, r2)
-
- def test_unused_input(self):
- @torch.jit.compile(nderivs=1)
- def fn(a, b, c):
- return a + b
-
- a, b, c = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(3)]
- fn(a, b, c).sum().backward()
- with self.assertCompiled(fn):
- fn(a, b, c).sum().backward()
+ self.assertExpectedGraph(trace)
def test_repeated_input(self):
- @torch.jit.compile(nderivs=1)
def fn(a, b):
return a + b
- a, b = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(2)]
- fn(a, a).sum().backward()
- with self.assertCompiled(fn):
- fn(a, a).sum().backward()
- with self.assertCompiled(fn):
- fn(a, b).sum().backward()
- self.assertExpected(str(fn.graph_for(a, a)))
+ ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
+ self.assertExpectedGraph(ge.graph)
def test_repeated_output(self):
- @torch.jit.compile(nderivs=1)
def fn(a, b):
z = a + b
return z, z
- a, b = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(2)]
- sum(fn(a, b)).sum().backward()
- with self.assertCompiled(fn):
- sum(fn(a, b)).sum().backward()
- self.assertExpected(str(fn.graph_for(a, b)))
-
- def test_re_enter(self):
- @torch.jit.compile(nderivs=1)
- def fn(a, b):
- return a + b
-
- @torch.jit.compile(nderivs=1)
- def fn2(a, b, c):
- return fn(a, b) + c
-
- a, b, c = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(3)]
-
- fn(a, b).sum().backward()
- with self.assertCompiled(fn):
- fn(a, b).sum().backward()
-
- fn2(a, b, c).sum().backward()
- with self.assertCompiled(fn2):
- fn2(a, b, c).sum().backward()
-
- def test_mini_wlm(self):
- """Exercise null-edge pruning in the tracer."""
-
- @torch.jit.compile
- class MyModel(nn.Module):
- def __init__(self):
- super(MyModel, self).__init__()
- self.encoder = nn.Embedding(2, 2)
-
- def forward(self, input, hidden):
- emb = self.encoder(input)
- hidden = hidden.clone() # simulate some RNN operation
- return emb, hidden
-
- model = MyModel()
-
- x = Variable(torch.LongTensor([[0, 1], [1, 0]]))
- y = Variable(torch.FloatTensor([0]))
-
- z, _ = model(x, y)
- z.sum().backward()
- self.assertTrue(model.has_trace_for(x, y))
-
- with self.assertCompiled(model):
- z, _ = model(x, y)
- z.sum().backward()
-
- def test_module_cast(self):
- """Compiled modules can be casted to other data types"""
- @torch.jit.compile(nderivs=0)
- class Adder(nn.Module):
- def __init__(self):
- super(Adder, self).__init__()
- self.y = nn.Parameter(torch.randn(2, 2))
-
- def forward(self, x):
- return x + self.y
-
- x = Variable(torch.randn(2, 2).float())
- # Wrap it in a sequential to make sure it works for submodules
- a = nn.Sequential(Adder()).float()
-
- def check_type(caster):
- caster(a)
- a(caster(x))
- with self.assertCompiled(a[0]):
- a(caster(x))
-
- check_type(lambda x: x)
- check_type(lambda x: x.double())
- if torch.cuda.is_available():
- check_type(lambda x: x.float().cuda())
- check_type(lambda x: x.double().cuda())
- self.assertEqual(a[0].hits, 4 if torch.cuda.is_available() else 2)
-
- # Tracer fails when it receives the same grad variable as multiple input to
- # traced region. The problem is that it's not immediately obvious how to
- # assign multiple inputs to this Variable. It might be possible to solve
- # this using the view mechanism, but this requires some thought.
- # In general, it should be supported, because the user has no control
- # over this (and it's quite common, e.g. the sum call below will pass the same
- # grad variable as both inputs to grad of fn).
- @unittest.skip("Broken - repeated grads trigger an assertion failure.")
- def test_repeated_grad(self):
- @torch.jit.compile
- def fn(x):
- return x * x, x + x
-
- x = Variable(torch.randn(5, 5), requires_grad=True)
- # This shouldn't raise!
- sum(fn(x)).sum().backward()
-
- def test_input_pruning(self):
- """Check that stage 1 will return only one value"""
- # One of the inputs doesn't require grad, so it should be pruned
- @torch.jit.compile
- def fn(x, y):
- return x * y, x + y
-
- x = Variable(torch.randn(5, 5), requires_grad=True)
- y = Variable(torch.randn(5, 5))
-
- out = fn(x, y)
- (out[0] * out[1]).sum().backward()
- with self.assertCompiled(fn):
- fn(x, y)
- self.assertExpected(str(fn.graph_for(x, y)))
-
- def test_output_pruning(self):
- """Check that stage 1 will take one value as an argument"""
- # One of the outputs doesn't require grad, so it should be pruned
- @torch.jit.compile
- def fn(x, y):
- return x * y, y + y
-
- x = Variable(torch.randn(5, 5), requires_grad=True)
- y = Variable(torch.randn(5, 5))
-
- out = fn(x, y)
- (out[0] * out[1]).sum().backward()
- with self.assertCompiled(fn):
- fn(x, y)
- self.assertExpected(str(fn.graph_for(x, y)))
+ ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
+ self.assertExpectedGraph(ge.graph)
@skipIfNoTorchVision
def test_alexnet(self):
- return
- x = Variable(torch.randn(10, 3, 224, 224).fill_(1.0), requires_grad=True)
+ x = torch.ones(1, 3, 224, 224)
trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
- self.assertExpectedTrace(trace)
- self.assertExportImport(trace, (x, ))
- # NB: Purposely NOT testing protobuf export here
-
- def test_debug_info(self):
- """Check that debug info doesn't crash and has some reasonable info"""
-
- @torch.jit.compile(nderivs=1)
- def fn(x, y):
- return x * y + x + y
-
- x = Variable(torch.randn(5, 5), requires_grad=True)
- y = Variable(torch.randn(5, 5), requires_grad=True)
-
- out = fn(x, y)
-
- out.sum().backward()
-
- for _ in range(0, 100):
- out = fn(x, y)
- info_str = fn.jit_debug_info()
- self.assertTrue("hits: 100" in info_str)
- self.assertTrue("stage 1" in info_str)
+ self.assertExpectedGraph(trace)
# Inplace copies don't work with tracer yet.
# This is actually somewhat important to support correctly
@@ -1197,7 +642,7 @@ class TestJit(TestCase):
# viewed portion.
@unittest.expectedFailure
def test_inplace_copy(self):
- x = Variable(torch.randn(4, 4), requires_grad=True)
+ x = torch.randn(4, 4, requires_grad=True)
def f(x):
out = Variable(torch.zeros(x.size()))
@@ -1205,28 +650,9 @@ class TestJit(TestCase):
return out
trace, z = torch.jit.get_trace_graph(f, (x, ), nderivs=0)
- torch._C._jit_pass_lint(trace.graph())
- torch._C._jit_pass_dce(trace.graph())
- self.assertExpectedTrace(trace)
- self.assertExportImport(trace, (x, ))
-
- def test_index_trace(self):
- x = Variable(torch.randn(4, 4), requires_grad=True)
- trace, z = torch.jit.get_trace_graph(lambda x: x[0], (x, ), nderivs=1)
- z.sum().backward()
- torch._C._jit_pass_lint(trace.graph())
- torch._C._jit_pass_dce(trace.graph())
- self.assertExpectedTrace(trace)
-
- def test_saved_output(self):
- x = Variable(torch.randn(4, 4), requires_grad=True)
-
- @torch.jit.compile(nderivs=1)
- def fn(x):
- return x.sigmoid()
-
- fn(x).sum().backward()
- self.assertExpected(str(fn.graph_for(x)))
+ self.run_pass('dce', trace)
+ self.assertExpectedGraph(trace)
+ self.assertExportImport(trace, (x,))
def test_shared_param(self):
@@ -1239,18 +665,18 @@ class TestJit(TestCase):
return x * self.a + self.b
m = MyModule()
- trace, _ = torch.jit.get_trace_graph(m, (Variable(torch.randn(2, 2)),), nderivs=0)
+ trace, _ = torch.jit.get_trace_graph(m, (torch.randn(2, 2),), nderivs=0)
self.assertEqual(len(list(trace.graph().inputs())), 2)
- self.assertExpected(str(trace))
+ self.assertExpectedGraph(trace)
def test_nested_inplace(self):
- x = Variable(torch.randn(2, 2))
+ x = torch.randn(2, 2)
trace, _ = torch.jit.get_trace_graph(lambda x: F.threshold(x, 0, 0, inplace=True), (x,), nderivs=0)
- self.assertExpectedTrace(trace)
- self.assertExportImport(trace, (x, ))
+ self.assertExpectedGraph(trace)
+ self.assertExportImport(trace, (x,))
- def checkGraphExecutor(self, func, reference_tensors, input_tensors=None,
- optimize=True, drop=None, allow_unused=False):
+ def checkTrace(self, func, reference_tensors, input_tensors=None,
+ optimize=True, drop=None, allow_unused=False):
def allSum(vs):
# drop allows us to remove some values from ever being used
# to test unused outputs
@@ -1262,11 +688,10 @@ class TestJit(TestCase):
if input_tensors is None:
input_tensors = reference_tensors
- nograd_inputs = [Variable(t) for t in reference_tensors]
- recording_inputs = [Variable(t, requires_grad=True)
- for t in reference_tensors]
+ nograd_inputs = reference_tensors
+ recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
- ge = torch._C.GraphExecutor(func, [Variable(t) for t in input_tensors], optimize)
+ ge = torch.jit.trace(*input_tensors, optimize=optimize)(func)
# test no gradients case
@@ -1307,7 +732,12 @@ class TestJit(TestCase):
self.assertEqual(outputs, outputs_ge)
self.assertEqual(grads, grads_ge)
- self.assertEqual(grads2, grads2_ge)
+ for g2, g2_ge in zip(grads2, grads2_ge):
+ if g2 is None and g2_ge is None:
+ continue
+ self.assertTrue(torch.allclose(g2, g2_ge, atol=5e-4, rtol=1e-4))
+
+ return ge
def run_ge_tests(self, optimize, use_cuda):
def rand(*args):
@@ -1315,27 +745,27 @@ class TestJit(TestCase):
if use_cuda:
t = t.cuda()
return t
- self.checkGraphExecutor(lambda a, b: a * b + b,
- [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
- optimize=optimize)
+ self.checkTrace(lambda a, b: a * b + b,
+ [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
+ optimize=optimize)
# trivial identity
- self.checkGraphExecutor(lambda a, b: (
+ self.checkTrace(lambda a, b: (
b, a), [rand(1), rand(1)], optimize=optimize)
def foo(a):
t = a * a
return t * t, 4 * t
- self.checkGraphExecutor(foo, [rand(1)], optimize=optimize)
+ self.checkTrace(foo, [rand(1)], optimize=optimize)
# unused input
- self.checkGraphExecutor(
+ self.checkTrace(
lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
allow_unused=True)
# test outputs that do not get used in grad
- self.checkGraphExecutor(foo, [rand(1)], drop=1, optimize=optimize)
+ self.checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
# test autograd fallback
- self.checkGraphExecutor(lambda a, b: a * b /
- (a - 2 * b) + b, [rand(1), rand(1)],
- optimize=optimize)
+ self.checkTrace(lambda a, b: a * b /
+ (a - 2 * b) + b, [rand(1), rand(1)],
+ optimize=optimize)
def test_ge_unoptimized(self):
self.run_ge_tests(False, False)
@@ -1373,10 +803,12 @@ class TestJit(TestCase):
self.assertEqual(g2result, g2result2)
def test_trace_annotation(self):
- @torch.jit.trace(Variable(torch.rand(1)))
+ @torch.jit.trace(torch.rand(1))
def foo(a):
return a + a + a
- s = Variable(torch.rand(2))
+
+ x = torch.randn(5, 5)
+ self.assertEqual(foo(x), x + x + x)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "calls .cuda()")
@@ -1439,86 +871,6 @@ class TestJit(TestCase):
self.assertEqual(out, out_state)
self.assertNotEqual(out, out_ones)
- def test_shape_prop_mismatch_output(self):
- with self.assertRaises(RuntimeError):
- cu = torch.jit.CompilationUnit('''
- def test_shape_prop_mismatch_output(a):
- b = slice(a, dim=0, end=-2, start=2, step=1)
- b = topk(a, dim=0, k=2, largest=True, sorted=True)
- return b
- ''')
- inputs = [torch.zeros(10)]
- outputs = [torch.zeros(2), torch.from_numpy(np.array([1, 5])).long()]
-
- real_outs = cu.test_shape_prop_mismatch_output(*inputs)
- self.assertEqual(real_outs, outputs)
-
- def test_view_shape_prop(self):
- cu = torch.jit.CompilationUnit('''
- def test_view_shape_prop(a):
- return view(a, size=[-1])
- ''')
- inputs = [torch.zeros(10, 10)]
- outputs = torch.zeros(100)
-
- real_outs = cu.test_view_shape_prop(*inputs)
- self.assertEqual(real_outs, outputs)
-
- def test_integral_shape_inference(self):
- cu = torch.jit.CompilationUnit('''
- def test_integral_shape_inference(a):
- return a / a
- ''')
- inputs = [torch.ones(10, 10).type(torch.LongTensor)]
- outputs = torch.ones(10, 10)
-
- self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
-
- def test_shape_analysis_broadcast(self):
- def broadcast(a, b):
- return a + b
-
- x = torch.randn(3, 1, 5, requires_grad=True)
- y = torch.randn(4, 1, 8, 5, requires_grad=True)
-
- graph = torch.jit._script_graph(broadcast)
- torch._C._jit_pass_shape_analysis(graph, (x, y), False)
- self.assertExpected(str(graph))
-
- def test_fuser_multiple_blocks(self):
- cu = torch.jit.CompilationUnit('''
- def test_fuser_multiple_blocks(this, that, theother, meme):
- i = 0
- while i < 20:
- this = cat([this, meme], dim=0)
- that = cat([that, meme], dim=0)
- theother = cat([theother, meme], dim=0)
- i = i + 1
- return this, that, theother
- ''')
-
- inputs = [torch.ones(0, 10, 10)] * 3
- inputs += [torch.ones(1, 10, 10)]
- outputs = [torch.ones(20, 10, 10)] * 3
-
- self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
-
- def test_dropout_script(self):
-
- eg = torch.zeros(1, 2, 3, requires_grad=True)
-
- @torch.jit.trace(eg)
- def foo(x):
- x = torch.neg(x)
- return F.dropout(x)
-
- class MyDrop(nn.Module):
- def forward(self, x):
- return foo(x)
-
- f = io.BytesIO()
- torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
-
def test_python_function(self):
class MyFn(Function):
@staticmethod
@@ -1752,6 +1104,61 @@ class TestScript(TestCase):
# NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
+ def test_view_shape_prop(self):
+ cu = torch.jit.CompilationUnit('''
+ def test_view_shape_prop(a):
+ return view(a, size=[-1])
+ ''')
+ inputs = [torch.zeros(10, 10)]
+ outputs = torch.zeros(100)
+
+ real_outs = cu.test_view_shape_prop(*inputs)
+ self.assertEqual(real_outs, outputs)
+
+ def test_integral_shape_inference(self):
+ cu = torch.jit.CompilationUnit('''
+ def test_integral_shape_inference(a):
+ return a / a
+ ''')
+ inputs = [torch.ones(10, 10).type(torch.LongTensor)]
+ outputs = torch.ones(10, 10)
+
+ self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
+
+ def test_fuser_multiple_blocks(self):
+ cu = torch.jit.CompilationUnit('''
+ def test_fuser_multiple_blocks(this, that, theother, meme):
+ i = 0
+ while i < 20:
+ this = cat([this, meme], dim=0)
+ that = cat([that, meme], dim=0)
+ theother = cat([theother, meme], dim=0)
+ i = i + 1
+ return this, that, theother
+ ''')
+
+ inputs = [torch.ones(0, 10, 10)] * 3
+ inputs += [torch.ones(1, 10, 10)]
+ outputs = [torch.ones(20, 10, 10)] * 3
+
+ self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
+
+ def test_dropout_script(self):
+
+ eg = torch.zeros(1, 2, 3, requires_grad=True)
+
+ @torch.jit.trace(eg)
+ def foo(x):
+ x = torch.neg(x)
+ return F.dropout(x)
+
+ class MyDrop(nn.Module):
+ def forward(self, x):
+ return foo(x)
+
+ f = io.BytesIO()
+ torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
+
@unittest.skip("RuntimeError: VariableType::ID() not implemented")
def test_cast(self):
script = '''