| field | value | date |
|---|---|---|
| author | Adam Paszke <adam.paszke@gmail.com> | 2018-05-16 20:03:04 +0200 |
| committer | GitHub <noreply@github.com> | 2018-05-16 20:03:04 +0200 |
| commit | b45f2ff1ae5b8a7fb04f96a9d6de86df39c1dcb2 (patch) | |
| tree | 8c27506722d9fcb9055f51f9abf90cc1d5b6d63a /test | |
| parent | 28b0b16f9bd8678b1f31bb4fb6de1b31f6e87f3d (diff) | |
Remove CompiledFunction + clean up JIT tests (#7421)
Diffstat (limited to 'test')
23 files changed, 440 insertions, 1106 deletions
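Before the hunks: this commit replaces the old `@torch.jit.compile` / `CompiledFunction` test style with plain traced functions whose optimized graphs are checked against expect files. The following is a minimal sketch of that new pattern, reassembled from the hunks below and not part of the commit itself; `checkTrace` is assumed to be a `TestCase` helper defined elsewhere in `test_jit.py` that traces the function, compares it against eager execution, and returns the resulting GraphExecutor, and the class name `TestLSTMFusionSketch` is illustrative only.

```python
# Sketch only, under the assumptions stated above.
import torch
import torch.nn as nn
import torch.nn.functional as F

from common import TestCase  # PyTorch's test/common.py, as imported in test_jit.py


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    # Plain-Python LSTM cell used as the traced workload; the gate math matches
    # the fused graphs recorded in the expect files below.
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy


def LSTMCellF(input, hx, cx, *params):
    # Flattened-argument wrapper so the tracer only ever sees tensor inputs.
    return LSTMCell(input, (hx, cx), *params)


def get_lstm_inputs(device):
    # Random activations plus the weights of a throwaway nn.LSTMCell, which is
    # created only so the parameters get the right shapes.
    input = torch.randn(3, 10, dtype=torch.float, device=device)
    hx = torch.randn(3, 20, dtype=torch.float, device=device)
    cx = torch.randn(3, 20, dtype=torch.float, device=device)
    module = nn.LSTMCell(10, 20).float().to(device)
    return (input, hx, cx) + tuple(p.requires_grad_(False) for p in module.parameters())


class TestLSTMFusionSketch(TestCase):
    def test_lstm_fusion_cuda(self):
        inputs = get_lstm_inputs('cuda')
        # checkTrace (assumed helper) replaces @torch.jit.compile + assertCompiled.
        ge = self.checkTrace(LSTMCellF, inputs)
        # The expect file now pins down the canonicalized, optimized graph,
        # including the prim::FusionGroup_0[device=0] node.
        self.assertExpectedGraph(ge.graph_for(*inputs))
```

With this pattern, the matching expect files (`TestJit.test_lstm_fusion_cuda.expect`, `..._cpu.expect`, `..._concat.expect`) record the optimized graph itself rather than relying on `CompiledFunction` hit/miss counters, which is what the diff below does across the test suite.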
diff --git a/test/expect/TestJit.test_alexnet.expect b/test/expect/TestJit.test_alexnet.expect index 62f45eff89..9a927efd24 100644 --- a/test/expect/TestJit.test_alexnet.expect +++ b/test/expect/TestJit.test_alexnet.expect @@ -1,46 +1,49 @@ -graph(%1 : Double(10, 3, 224, 224) - %2 : Double(64, 3, 11, 11) - %3 : Double(64) - %4 : Double(192, 64, 5, 5) - %5 : Double(192) - %6 : Double(384, 192, 3, 3) - %7 : Double(384) - %8 : Double(256, 384, 3, 3) - %9 : Double(256) - %10 : Double(256, 256, 3, 3) - %11 : Double(256) - %12 : Double(4096, 9216) - %13 : Double(4096) - %14 : Double(4096, 4096) - %15 : Double(4096) - %16 : Double(1000, 4096) - %17 : Double(1000)) { - %19 : Double(10, 64, 55, 55), %20 : Handle = CppOp[ConvForward](%1, %2, %3), uses = [[%21.i0], []]; - %21 : Double(10, 64, 55, 55) = threshold[threshold={0}, value={0}, inplace=1](%19), uses = [%22.i0]; - %23 : Double(10, 64, 27, 27), %24 : Long(10, 64, 27, 27) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%21), uses = [[%25.i0], []]; - %26 : Double(10, 192, 27, 27), %27 : Handle = CppOp[ConvForward](%23, %4, %5), uses = [[%28.i0], []]; - %28 : Double(10, 192, 27, 27) = threshold[threshold={0}, value={0}, inplace=1](%26), uses = [%29.i0]; - %30 : Double(10, 192, 13, 13), %31 : Long(10, 192, 13, 13) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%28), uses = [[%32.i0], []]; - %33 : Double(10, 384, 13, 13), %34 : Handle = CppOp[ConvForward](%30, %6, %7), uses = [[%35.i0], []]; - %35 : Double(10, 384, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%33), uses = [%36.i0]; - %37 : Double(10, 256, 13, 13), %38 : Handle = CppOp[ConvForward](%35, %8, %9), uses = [[%39.i0], []]; - %39 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%37), uses = [%40.i0]; - %41 : Double(10, 256, 13, 13), %42 : Handle = CppOp[ConvForward](%39, %10, %11), uses = [[%43.i0], []]; - %43 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%41), uses = [%44.i0]; - %45 : Double(10, 256, 6, 6), %46 : Long(10, 256, 6, 6) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%43), uses = [[%47.i0], []]; - %47 : Double(10, 9216) = view[size=[10, 9216]](%45), uses = [%48.i0]; - %49 : Double(10, 9216), %50 : Handle = ^Dropout(0.5, True, False)(%47), uses = [[%53.i1], []]; - %51 : Double(9216!, 4096!) = t(%12), uses = [%53.i2]; - %52 : Double(10!, 4096) = expand[size=[10, 4096]](%13), uses = [%53.i0]; - %53 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%52, %49, %51), uses = [%54.i0]; - %54 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%53), uses = [%55.i0]; - %56 : Double(10, 4096), %57 : Handle = ^Dropout(0.5, True, False)(%54), uses = [[%60.i1], []]; - %58 : Double(4096!, 4096!) = t(%14), uses = [%60.i2]; - %59 : Double(10!, 4096) = expand[size=[10, 4096]](%15), uses = [%60.i0]; - %60 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%59, %56, %58), uses = [%61.i0]; - %61 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%60), uses = [%64.i1]; - %62 : Double(4096!, 1000!) 
= t(%16), uses = [%64.i2]; - %63 : Double(10!, 1000) = expand[size=[10, 1000]](%17), uses = [%64.i0]; - %64 : Double(10, 1000) = addmm[beta={1}, alpha={1}](%63, %61, %62), uses = [%0.i0]; - return (%64); +graph(%0 : Double(1, 3, 224, 224) + %1 : Double(64, 3, 11, 11) + %2 : Double(64) + %3 : Double(192, 64, 5, 5) + %4 : Double(192) + %5 : Double(384, 192, 3, 3) + %6 : Double(384) + %7 : Double(256, 384, 3, 3) + %8 : Double(256) + %9 : Double(256, 256, 3, 3) + %10 : Double(256) + %11 : Double(4096, 9216) + %12 : Double(4096) + %13 : Double(4096, 4096) + %14 : Double(4096) + %15 : Double(1000, 4096) + %16 : Double(1000)) { + %17 : Double(1, 64, 55, 55) = aten::_convolution[stride=[4, 4], padding=[2, 2], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%0, %1, %2), scope: AlexNet/Sequential[features]/Conv2d[0] + %18 : Double(1, 64, 55, 55) = aten::threshold[threshold={0}, value={0}](%17), scope: AlexNet/Sequential[features]/ReLU[1] + %19 : Double(1, 64, 27, 27), %20 : Long(1, 64, 27, 27) = aten::max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2] + %21 : Double(1, 192, 27, 27) = aten::_convolution[stride=[1, 1], padding=[2, 2], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%19, %3, %4), scope: AlexNet/Sequential[features]/Conv2d[3] + %22 : Double(1, 192, 27, 27) = aten::threshold[threshold={0}, value={0}](%21), scope: AlexNet/Sequential[features]/ReLU[4] + %23 : Double(1, 192, 13, 13), %24 : Long(1, 192, 13, 13) = aten::max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%22), scope: AlexNet/Sequential[features]/MaxPool2d[5] + %25 : Double(1, 384, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%23, %5, %6), scope: AlexNet/Sequential[features]/Conv2d[6] + %26 : Double(1, 384, 13, 13) = aten::threshold[threshold={0}, value={0}](%25), scope: AlexNet/Sequential[features]/ReLU[7] + %27 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%26, %7, %8), scope: AlexNet/Sequential[features]/Conv2d[8] + %28 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%27), scope: AlexNet/Sequential[features]/ReLU[9] + %29 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%28, %9, %10), scope: AlexNet/Sequential[features]/Conv2d[10] + %30 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%29), scope: AlexNet/Sequential[features]/ReLU[11] + %31 : Double(1, 256, 6, 6), %32 : Long(1, 256, 6, 6) = aten::max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%30), scope: AlexNet/Sequential[features]/MaxPool2d[12] + %33 : Long() = aten::size[dim=0](%31), scope: AlexNet + %34 : Long() = prim::Constant[value={9216}](), scope: AlexNet + %35 : Dynamic = aten::stack[dim=0](%33, %34), scope: AlexNet + %36 : Double(1, 9216) = aten::view(%31, %35), scope: AlexNet + %37 : Double(1, 9216), %38 : Handle = ^Dropout(0.5, True, False)(%36), scope: AlexNet/Sequential[classifier]/Dropout[0] + %39 : Double(9216!, 
4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1] + %40 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%12), scope: AlexNet/Sequential[classifier]/Linear[1] + %41 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%40, %37, %39), scope: AlexNet/Sequential[classifier]/Linear[1] + %42 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%41), scope: AlexNet/Sequential[classifier]/ReLU[2] + %43 : Double(1, 4096), %44 : Handle = ^Dropout(0.5, True, False)(%42), scope: AlexNet/Sequential[classifier]/Dropout[3] + %45 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4] + %46 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%14), scope: AlexNet/Sequential[classifier]/Linear[4] + %47 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%46, %43, %45), scope: AlexNet/Sequential[classifier]/Linear[4] + %48 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%47), scope: AlexNet/Sequential[classifier]/ReLU[5] + %49 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6] + %50 : Double(1, 1000) = aten::expand[size=[1, 1000], implicit=1](%16), scope: AlexNet/Sequential[classifier]/Linear[6] + %51 : Double(1, 1000) = aten::addmm[beta={1}, alpha={1}](%50, %48, %49), scope: AlexNet/Sequential[classifier]/Linear[6] + return (%51); } diff --git a/test/expect/TestJit.test_backward.expect b/test/expect/TestJit.test_backward.expect deleted file mode 100644 index d87b0448f6..0000000000 --- a/test/expect/TestJit.test_backward.expect +++ /dev/null @@ -1,23 +0,0 @@ -graph(%0 : Double(2, 2) - %1 : Double(2, 2) - -------- stage 1 -------- - %4 : Double(2, 2) - -------- stage 2 -------- - %8 : Double(2, 2!) - %9 : Double(2, 2)) { - %2 : Double(2, 2) = mul[other={2}](%1) - %3 : Double(2, 2) = mul(%2, %0) - ---------------- stage 1 ---------------- - %5 : Double(2, 2) = mul(%4, %0) - %6 : Double(2, 2) = mul(%4, %2) - %7 : Double(2, 2) = mul[other={2}](%5) - ---------------- stage 2 ---------------- - %10 : Double(2, 2) = mul(%8, %2) - %11 : Double(2, 2) = mul(%8, %4) - %12 : Double(2, 2) = mul[other={2}](%9) - %13 : Double(2, 2) = mul[other={2}](%11) - %14 : Double(2, 2) = mul(%12, %0) - %15 : Double(2, 2) = mul(%12, %4) - %16 : Double(2, 2) = CppOp[N5torch8autograd3AddE](%10, %14) - return (%3, %6, %7, %16, %15, %13); -} diff --git a/test/expect/TestJit.test_backward_opaque.expect b/test/expect/TestJit.test_backward_opaque.expect deleted file mode 100644 index 779de3a335..0000000000 --- a/test/expect/TestJit.test_backward_opaque.expect +++ /dev/null @@ -1,10 +0,0 @@ -graph(%0 : Double(3, 3) - %1 : Double(3, 3) - -------- stage 1 -------- - %3 : Double(3, 3)) { - %2 : Double(3, 3) = cross[dim=-1](%0, %1) - ---------------- stage 1 ---------------- - %4 : Double(3, 3) = cross[dim=-1](%1, %3) - %5 : Double(3, 3) = cross[dim=-1](%3, %0) - return (%2, %4, %5); -} diff --git a/test/expect/TestJit.test_c_function.expect b/test/expect/TestJit.test_c_function.expect deleted file mode 100644 index aafbe7b221..0000000000 --- a/test/expect/TestJit.test_c_function.expect +++ /dev/null @@ -1,6 +0,0 @@ -graph(%0 : Double(1, 3, 10, 10) - %1 : Double(8, 3, 3, 3) - %2 : Double(8)) { - %3 : Double(1, 8, 8, 8) = aten::_convolution[stride=[1, 1], padding=[0, 0], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%0, %1, %2) - return (%3); -} diff --git a/test/expect/TestJit.test_concat_fusion.expect b/test/expect/TestJit.test_concat_fusion.expect 
index 435872b62e..c1b45b1727 100644 --- a/test/expect/TestJit.test_concat_fusion.expect +++ b/test/expect/TestJit.test_concat_fusion.expect @@ -1,6 +1,6 @@ graph(%0 : Float(3, 20) %1 : Float(3, 20)) { - %2 : Float(6, 20) = prim::FusionGroup_0(%0, %1) + %2 : Float(6, 20) = prim::FusionGroup_0[device=0](%0, %1) return (%2); } with prim::FusionGroup_0 = graph(%3 : Float(3, 20) diff --git a/test/expect/TestJit.test_function_as_argument.expect b/test/expect/TestJit.test_function_as_argument.expect deleted file mode 100644 index 2570b327ff..0000000000 --- a/test/expect/TestJit.test_function_as_argument.expect +++ /dev/null @@ -1,26 +0,0 @@ -graph(%1 : Double(3, 10) - %2 : Double(3, 20) - %3 : Double(3, 20) - %4 : Double(80, 10) - %5 : Double(80, 20) - %6 : Double(80) - %7 : Double(80)) { - %8 : Double(10!, 80!) = Transpose[perm=[1, 0]](%4), uses = [%9.i0]; - %9 : UNKNOWN_TYPE = Transpose(%8), uses = [%10.i1]; - %10 : Double(3, 80) = FC(%1, %9, %6), uses = [%14.i0]; - %11 : Double(20!, 80!) = Transpose[perm=[1, 0]](%5), uses = [%12.i0]; - %12 : UNKNOWN_TYPE = Transpose(%11), uses = [%13.i1]; - %13 : Double(3, 80) = FC(%2, %12, %7), uses = [%14.i1]; - %14 : Double(3, 80) = Add(%10, %13), uses = [%15.i0]; - %16 : Double(3!, 20), %17 : Double(3!, 20), %18 : Double(3!, 20), %19 : Double(3!, 20) = Split[split=[20, 20, 20, 20], axis=1](%14), uses = [[%20.i0], [%21.i0], [%22.i0], [%23.i0]]; - %20 : Double(3, 20) = Sigmoid(%16), uses = [%25.i0]; - %21 : Double(3, 20) = Sigmoid(%17), uses = [%24.i0]; - %22 : Double(3, 20) = Tanh(%18), uses = [%25.i1]; - %23 : Double(3, 20) = Sigmoid(%19), uses = [%28.i0]; - %24 : Double(3, 20) = Mul(%21, %3), uses = [%26.i0]; - %25 : Double(3, 20) = Mul(%20, %22), uses = [%26.i1]; - %26 : Double(3, 20) = Add(%24, %25), uses = [%27.i0, %0.i1]; - %27 : Double(3, 20) = Tanh(%26), uses = [%28.i1]; - %28 : Double(3, 20) = Mul(%23, %27), uses = [%0.i0]; - return (%28, %26); -} diff --git a/test/expect/TestJit.test_fuse_last_device.expect b/test/expect/TestJit.test_fuse_last_device.expect new file mode 100644 index 0000000000..276fadc61f --- /dev/null +++ b/test/expect/TestJit.test_fuse_last_device.expect @@ -0,0 +1,14 @@ +graph(%0 : Float(1) + %1 : Float(1)) { + %2 : Float(1) = prim::FusionGroup_0[device=1](%0, %1) + return (%2); +} +with prim::FusionGroup_0 = graph(%6 : Float(1) + %9 : Float(1)) { + %10 : Float(1) = aten::add[alpha={1}](%6, %9) + %8 : Float(1) = aten::mul(%6, %10) + %5 : Float(1) = aten::add[other={1}, alpha={1}](%8) + %3 : Float(1) = aten::tanh(%5) + %1 : Float(1) = aten::sigmoid(%3) + return (%1); +} diff --git a/test/expect/TestJit.test_fusion_distribute-raw.expect b/test/expect/TestJit.test_fusion_distribute-raw.expect deleted file mode 100644 index 290ca22052..0000000000 --- a/test/expect/TestJit.test_fusion_distribute-raw.expect +++ /dev/null @@ -1,7 +0,0 @@ -graph(%0 : Float(4, 4) - %1 : Float(4, 4)) { - %2 : Float(4, 4) = aten::add[alpha={1}](%0, %1) - %3 : Float(4!, 2), %4 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%2) - %5 : Float(4, 2) = aten::mul(%3, %4) - return (%5); -} diff --git a/test/expect/TestJit.test_fusion_distribute.expect b/test/expect/TestJit.test_fusion_distribute.expect index 0b873fb724..4465074e55 100644 --- a/test/expect/TestJit.test_fusion_distribute.expect +++ b/test/expect/TestJit.test_fusion_distribute.expect @@ -2,7 +2,7 @@ graph(%0 : Float(4, 4) %1 : Float(4, 4)) { %2 : Float(4!, 2), %3 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%0) %4 : Float(4!, 2), %5 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%1) - %6 : Float(4, 
2) = prim::FusionGroup_0(%2, %4, %3, %5) + %6 : Float(4, 2) = prim::FusionGroup_0[device=0](%2, %4, %3, %5) return (%6); } with prim::FusionGroup_0 = graph(%3 : Float(4!, 2) diff --git a/test/expect/TestJit.test_index_inner.expect b/test/expect/TestJit.test_index_inner.expect deleted file mode 100644 index 02e3fa321f..0000000000 --- a/test/expect/TestJit.test_index_inner.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%0 : Double(4, 4)) { - %1 : Double(4) = select[dim=0, index=0](%0) - %2 : Double(4) = add[other={0}, alpha={1}](%1) - return (%2); -} diff --git a/test/expect/TestJit.test_index_trace.expect b/test/expect/TestJit.test_index_trace.expect deleted file mode 100644 index 45423ec05e..0000000000 --- a/test/expect/TestJit.test_index_trace.expect +++ /dev/null @@ -1,8 +0,0 @@ -graph(%0 : Double(4, 4) - -------- stage 1 -------- - %1 : Double(4!)) { - %2 : Double(4) = aten::select[dim=0, index=0](%0) - ---------------- stage 1 ---------------- - %3 : Double(4, 4) = CppOp[AsStridedBackward](%1) - return (%2, %3); -} diff --git a/test/expect/TestJit.test_input_pruning.expect b/test/expect/TestJit.test_input_pruning.expect deleted file mode 100644 index 5464d480ea..0000000000 --- a/test/expect/TestJit.test_input_pruning.expect +++ /dev/null @@ -1,12 +0,0 @@ -graph(%0 : Double(5, 5) - %1 : Double(5, 5) - -------- stage 1 -------- - %4 : Double(5, 5) - %5 : Double(5, 5)) { - %2 : Double(5, 5) = aten::mul(%0, %1) - %3 : Double(5, 5) = aten::add[alpha={1}](%0, %1) - ---------------- stage 1 ---------------- - %6 : Double(5, 5) = aten::mul(%4, %1) - %7 : Double(5, 5) = aten::add[alpha={1}](%5, %6) - return (%2, %3, %7); -} diff --git a/test/expect/TestJit.test_lstm.expect b/test/expect/TestJit.test_lstm.expect deleted file mode 100644 index 3bbdb7c2b7..0000000000 --- a/test/expect/TestJit.test_lstm.expect +++ /dev/null @@ -1,42 +0,0 @@ -graph(%1 : Double(3, 10) - %2 : Double(3, 20) - %3 : Double(3, 20) - %4 : Double(80, 10) - %5 : Double(80, 20) - %6 : Double(80) - %7 : Double(80)) { - %8 : Double(10!, 80!) = Transpose[perm=[1, 0]](%4), uses = [%9.i0]; - %9 : UNKNOWN_TYPE = Transpose(%8), uses = [%10.i1]; - %10 : Double(3, 80) = FC(%1, %9, %6), uses = [%32.i0]; - %11 : Double(20!, 80!) 
= Transpose[perm=[1, 0]](%5), uses = [%12.i0]; - %12 : UNKNOWN_TYPE = Transpose(%11), uses = [%13.i1]; - %13 : Double(3, 80) = FC(%2, %12, %7), uses = [%33.i0]; - %36 : Double(3!, 20), %39 : Double(3!, 20), %42 : Double(3!, 20), %45 : Double(3!, 20) = Split[split=[20, 20, 20, 20], axis=1](%13), uses = [[%29.i8], [%29.i6], [%29.i4], [%29.i2]]; - %35 : Double(3!, 20), %38 : Double(3!, 20), %41 : Double(3!, 20), %44 : Double(3!, 20) = Split[split=[20, 20, 20, 20], axis=1](%10), uses = [[%29.i7], [%29.i5], [%29.i3], [%29.i1]]; - %30 : Double(3, 20), %31 : Double(3, 20) = fusion_group_0(%3, %44, %45, %41, %42, %38, %39, %35, %36), uses = [[%0.i0], [%0.i1]]; - return (%30, %31); -} -with fusion_group_0 = graph(%13 : Double(3, 20) - %23 : Double(3!, 20) - %24 : Double(3!, 20) - %26 : Double(3!, 20) - %27 : Double(3!, 20) - %29 : Double(3!, 20) - %30 : Double(3!, 20) - %32 : Double(3!, 20) - %33 : Double(3!, 20)) { - %34 : Double(3, 20) = Add(%32, %33), uses = [%22.i0]; - %31 : Double(3, 20) = Add(%29, %30), uses = [%20.i0]; - %28 : Double(3, 20) = Add(%26, %27), uses = [%18.i0]; - %25 : Double(3, 20) = Add(%23, %24), uses = [%16.i0]; - %22 : Double(3, 20) = Sigmoid(%34), uses = [%11.i0]; - %20 : Double(3, 20) = Sigmoid(%31), uses = [%14.i0]; - %18 : Double(3, 20) = Tanh(%28), uses = [%11.i1]; - %16 : Double(3, 20) = Sigmoid(%25), uses = [%3.i0]; - %14 : Double(3, 20) = Mul(%20, %13), uses = [%8.i0]; - %11 : Double(3, 20) = Mul(%22, %18), uses = [%8.i1]; - %8 : Double(3, 20) = Add(%14, %11), uses = [%5.i0, %0.i1]; - %5 : Double(3, 20) = Tanh(%8), uses = [%3.i1]; - %3 : Double(3, 20) = Mul(%16, %5), uses = [%0.i0]; - return (%3, %8); -} diff --git a/test/expect/TestJit.test_lstm_fusion_concat.expect b/test/expect/TestJit.test_lstm_fusion_concat.expect new file mode 100644 index 0000000000..65bee38a8c --- /dev/null +++ b/test/expect/TestJit.test_lstm_fusion_concat.expect @@ -0,0 +1,41 @@ +graph(%0 : Float(3, 10) + %1 : Float(3, 20) + %2 : Float(3, 20) + %3 : Float(80, 10) + %4 : Float(80, 20) + %5 : Float(80) + %6 : Float(80)) { + %7 : Float(10!, 80!) = aten::t(%3) + %8 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%5, %0, %7) + %9 : Float(20!, 80!) 
= aten::t(%4) + %10 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%6, %1, %9) + %11 : Float(3!, 20), %12 : Float(3!, 20), %13 : Float(3!, 20), %14 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%8) + %15 : Float(3!, 20), %16 : Float(3!, 20), %17 : Float(3!, 20), %18 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%10) + %19 : Float(6, 20) = prim::FusionGroup_0[device=0](%2, %14, %18, %13, %17, %12, %16, %11, %15) + return (%19); +} +with prim::FusionGroup_0 = graph(%14 : Float(3, 20) + %24 : Float(3!, 20) + %25 : Float(3!, 20) + %27 : Float(3!, 20) + %28 : Float(3!, 20) + %30 : Float(3!, 20) + %31 : Float(3!, 20) + %33 : Float(3!, 20) + %34 : Float(3!, 20)) { + %35 : Float(3, 20) = aten::add[alpha={1}](%33, %34) + %32 : Float(3, 20) = aten::add[alpha={1}](%30, %31) + %29 : Float(3, 20) = aten::add[alpha={1}](%27, %28) + %26 : Float(3, 20) = aten::add[alpha={1}](%24, %25) + %23 : Float(3, 20) = aten::sigmoid(%35) + %21 : Float(3, 20) = aten::sigmoid(%32) + %19 : Float(3, 20) = aten::tanh(%29) + %17 : Float(3, 20) = aten::sigmoid(%26) + %15 : Float(3, 20) = aten::mul(%21, %14) + %12 : Float(3, 20) = aten::mul(%23, %19) + %9 : Float(3, 20) = aten::add[alpha={1}](%15, %12) + %6 : Float(3, 20) = aten::tanh(%9) + %5 : Float(3, 20) = aten::mul(%17, %6) + %2 : Float(6, 20) = aten::cat[dim=0](%5, %9) + return (%2); +} diff --git a/test/expect/TestJit.test_lstm_fusion.expect b/test/expect/TestJit.test_lstm_fusion_cpu.expect index 10a537fe98..500996c9bc 100644 --- a/test/expect/TestJit.test_lstm_fusion.expect +++ b/test/expect/TestJit.test_lstm_fusion_cpu.expect @@ -6,15 +6,13 @@ graph(%0 : Float(3, 10) %5 : Float(80) %6 : Float(80)) { %7 : Float(10!, 80!) = aten::t(%3) - %8 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=1](%5) - %9 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%8, %0, %7) - %10 : Float(20!, 80!) = aten::t(%4) - %11 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=1](%6) - %12 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%11, %1, %10) - %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%9) - %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%12) - %21 : Float(3, 20), %22 : Float(3, 20) = prim::FusionGroup_0(%2, %16, %20, %15, %19, %14, %18, %13, %17) - return (%21, %22); + %8 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%5, %0, %7) + %9 : Float(20!, 80!) = aten::t(%4) + %10 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%6, %1, %9) + %11 : Float(3!, 20), %12 : Float(3!, 20), %13 : Float(3!, 20), %14 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%8) + %15 : Float(3!, 20), %16 : Float(3!, 20), %17 : Float(3!, 20), %18 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%10) + %19 : Float(3, 20), %20 : Float(3, 20) = prim::FusionGroup_0[device=-1](%2, %14, %18, %13, %17, %12, %16, %11, %15) + return (%19, %20); } with prim::FusionGroup_0 = graph(%12 : Float(3, 20) %22 : Float(3!, 20) diff --git a/test/expect/TestJit.test_lstm_fusion_cuda.expect b/test/expect/TestJit.test_lstm_fusion_cuda.expect new file mode 100644 index 0000000000..015fcc1a00 --- /dev/null +++ b/test/expect/TestJit.test_lstm_fusion_cuda.expect @@ -0,0 +1,40 @@ +graph(%0 : Float(3, 10) + %1 : Float(3, 20) + %2 : Float(3, 20) + %3 : Float(80, 10) + %4 : Float(80, 20) + %5 : Float(80) + %6 : Float(80)) { + %7 : Float(10!, 80!) = aten::t(%3) + %8 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%5, %0, %7) + %9 : Float(20!, 80!) 
= aten::t(%4) + %10 : Float(3, 80) = aten::addmm[beta={1}, alpha={1}](%6, %1, %9) + %11 : Float(3!, 20), %12 : Float(3!, 20), %13 : Float(3!, 20), %14 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%8) + %15 : Float(3!, 20), %16 : Float(3!, 20), %17 : Float(3!, 20), %18 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%10) + %19 : Float(3, 20), %20 : Float(3, 20) = prim::FusionGroup_0[device=0](%2, %14, %18, %13, %17, %12, %16, %11, %15) + return (%19, %20); +} +with prim::FusionGroup_0 = graph(%12 : Float(3, 20) + %22 : Float(3!, 20) + %23 : Float(3!, 20) + %25 : Float(3!, 20) + %26 : Float(3!, 20) + %28 : Float(3!, 20) + %29 : Float(3!, 20) + %31 : Float(3!, 20) + %32 : Float(3!, 20)) { + %33 : Float(3, 20) = aten::add[alpha={1}](%31, %32) + %30 : Float(3, 20) = aten::add[alpha={1}](%28, %29) + %27 : Float(3, 20) = aten::add[alpha={1}](%25, %26) + %24 : Float(3, 20) = aten::add[alpha={1}](%22, %23) + %21 : Float(3, 20) = aten::sigmoid(%33) + %19 : Float(3, 20) = aten::sigmoid(%30) + %17 : Float(3, 20) = aten::tanh(%27) + %15 : Float(3, 20) = aten::sigmoid(%24) + %13 : Float(3, 20) = aten::mul(%19, %12) + %10 : Float(3, 20) = aten::mul(%21, %17) + %7 : Float(3, 20) = aten::add[alpha={1}](%13, %10) + %4 : Float(3, 20) = aten::tanh(%7) + %2 : Float(3, 20) = aten::mul(%15, %4) + return (%2, %7); +} diff --git a/test/expect/TestJit.test_matmul_native.expect b/test/expect/TestJit.test_matmul_native.expect deleted file mode 100644 index 84f234edbb..0000000000 --- a/test/expect/TestJit.test_matmul_native.expect +++ /dev/null @@ -1,5 +0,0 @@ -graph(%0 : Double(1, 1) - %1 : Double(1, 1)) { - %2 : Double(1, 1) = aten::matmul(%0, %1) - return (%2); -} diff --git a/test/expect/TestJit.test_output_pruning.expect b/test/expect/TestJit.test_output_pruning.expect deleted file mode 100644 index a89e0b4678..0000000000 --- a/test/expect/TestJit.test_output_pruning.expect +++ /dev/null @@ -1,10 +0,0 @@ -graph(%0 : Double(5, 5) - %1 : Double(5, 5) - -------- stage 1 -------- - %4 : Double(5, 5)) { - %2 : Double(5, 5) = aten::mul(%0, %1) - %3 : Double(5, 5) = aten::add[alpha={1}](%1, %1) - ---------------- stage 1 ---------------- - %5 : Double(5, 5) = aten::mul(%4, %1) - return (%2, %3, %5); -} diff --git a/test/expect/TestJit.test_repeated_input.expect b/test/expect/TestJit.test_repeated_input.expect index 8db8e85d9f..57e57066ef 100644 --- a/test/expect/TestJit.test_repeated_input.expect +++ b/test/expect/TestJit.test_repeated_input.expect @@ -1,7 +1,5 @@ graph(%0 : Double(2, 2) - %1 : Double(2, 2) - -------- stage 1 -------- - %3 : Double(2, 2!)) { + %1 : Double(2, 2)) { %2 : Double(2, 2) = aten::add[alpha={1}](%0, %1) - return (%2, %3, %3); + return (%2); } diff --git a/test/expect/TestJit.test_repeated_output.expect b/test/expect/TestJit.test_repeated_output.expect index 1f8c156f15..b3baff631e 100644 --- a/test/expect/TestJit.test_repeated_output.expect +++ b/test/expect/TestJit.test_repeated_output.expect @@ -1,10 +1,5 @@ graph(%0 : Double(2, 2) - %1 : Double(2, 2) - -------- stage 1 -------- - %3 : Double(2, 2) - %4 : Double(2, 2)) { + %1 : Double(2, 2)) { %2 : Double(2, 2) = aten::add[alpha={1}](%0, %1) - ---------------- stage 1 ---------------- - %5 : Double(2, 2) = aten::add[alpha={1}](%3, %4) - return (%2, %2, %5, %5); + return (%2, %2); } diff --git a/test/expect/TestJit.test_saved_output.expect b/test/expect/TestJit.test_saved_output.expect deleted file mode 100644 index 77145061e1..0000000000 --- a/test/expect/TestJit.test_saved_output.expect +++ /dev/null @@ -1,8 +0,0 @@ -graph(%0 : Double(4, 4) 
- -------- stage 1 -------- - %2 : Double(4, 4!)) { - %1 : Double(4, 4) = aten::sigmoid(%0) - ---------------- stage 1 ---------------- - %3 : Double(4, 4) = aten::_sigmoid_backward(%2, %1) - return (%1, %3); -} diff --git a/test/expect/TestJit.test_shape_analysis_broadcast.expect b/test/expect/TestJit.test_shape_analysis_broadcast.expect index 1f42efa52e..bbe5b74164 100644 --- a/test/expect/TestJit.test_shape_analysis_broadcast.expect +++ b/test/expect/TestJit.test_shape_analysis_broadcast.expect @@ -1,7 +1,7 @@ graph(%a : Double(3, 1, 5) %b : Double(4, 1, 8, 5)) { - %3 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%a) - %4 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%b) - %2 : Double(4, 3, 8, 5) = aten::add[alpha={1}](%3, %4) - return (%2); + %2 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%a) + %3 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%b) + %4 : Double(4, 3, 8, 5) = aten::add[alpha={1}](%2, %3) + return (%4); } diff --git a/test/test_jit.py b/test/test_jit.py index 0428cd6d00..a48a0e8c50 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -7,6 +7,7 @@ from itertools import product, chain import torch.jit.frontend from torch.autograd import Variable, Function from torch.autograd.function import traceable +from torch.testing import assert_allclose from common import TestCase, run_tests, IS_WINDOWS from textwrap import dedent import os @@ -18,6 +19,7 @@ import textwrap import numpy as np import tempfile import shutil +import warnings from torch.jit.frontend import NotSupportedError @@ -45,6 +47,10 @@ PY35 = sys.version_info >= (3, 5) WINDOWS = sys.platform == 'win32' +def LSTMCellF(input, hx, cx, *params): + return LSTMCell(input, (hx, cx), *params) + + def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): hx, cx = hidden gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) @@ -61,97 +67,91 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): def LSTMCellC(*args, **kwargs): - hy, cy = LSTMCell(*args, **kwargs) + hy, cy = LSTMCellF(*args, **kwargs) return torch.cat((hy, cy)) +def get_lstm_inputs(device): + input = torch.randn(3, 10, dtype=torch.float, device=device) + hx = torch.randn(3, 20, dtype=torch.float, device=device) + cx = torch.randn(3, 20, dtype=torch.float, device=device) + module = nn.LSTMCell(10, 20).to(torch.float, device) # Just to allocate weights with correct sizes + return (input, hx, cx) + tuple(p.requires_grad_(False) for p in module.parameters()) + + class TestJit(TestCase): - maxDiff = None + def assertExpectedONNXGraph(self, trace, *args, **kwargs): + torch.onnx._optimize_trace(trace, aten=False) + self.assertExpectedGraph(trace, *args, **kwargs) - @contextmanager - def assertCompiled(self, compiled_fn): - self.assertIsInstance(compiled_fn, torch._C.CompiledFunction) - hits, misses = compiled_fn.hits, compiled_fn.misses - yield - self.assertLess(hits, compiled_fn.hits) - self.assertEqual(misses, compiled_fn.misses) - - def assertExpectedTrace(self, trace, *args, **kwargs): - torch._C._jit_pass_lint(trace.graph()) - torch._C._jit_pass_dce(trace.graph()) - torch._C._jit_pass_lint(trace.graph()) - trace.set_graph(torch._C._jit_pass_canonicalize(trace.graph())) - torch._C._jit_pass_lint(trace.graph()) - self.assertExpected(str(trace), *args, **kwargs) + def assertExpectedGraph(self, trace, *args, **kwargs): + if isinstance(trace, torch._C.Graph): + graph = trace + else: + graph = trace.graph() + + torch._C._jit_pass_lint(graph) + 
torch._C._jit_pass_dce(graph) + torch._C._jit_pass_lint(graph) + graph = torch._C._jit_pass_canonicalize(graph) + torch._C._jit_pass_lint(graph) + self.assertExpected(str(graph), *args, **kwargs) def assertExportImport(self, trace, inputs): initializers = [] - graph1 = trace.graph() - ge1 = torch._C.GraphExecutor(graph1, False) - out1 = ge1(*inputs) + def run(graph): + return torch._C.GraphExecutor(graph, False)(*inputs) + proto, _ = trace.graph().export(initializers, onnx_opset_version=0, defer_weight_export=False, export_raw_ir=True) + self.assertFalse(initializers) - graph2, initializers = torch._C._jit_import_graph(proto) - ge2 = torch._C.GraphExecutor(graph2, False) - out2 = ge2(*inputs) + imported_graph, initializers = torch._C._jit_import_graph(proto) + self.assertFalse(initializers) - self.assertEqual(out1, out2) + self.assertEqual(run(trace.graph()), run(imported_graph)) - def test_simple(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.Tensor([0.7]), requires_grad=True) + def run_pass(self, name, trace): + if isinstance(trace, torch._C.Graph): + graph = trace + set_graph = False + else: + set_graph = True + graph = trace.graph() - def f(x, y): - return torch.sigmoid(torch.tanh(x * (x + y))) + torch._C._jit_pass_lint(graph) + result = getattr(torch._C, '_jit_pass_' + name)(graph) + if result is not None: + graph = result + torch._C._jit_pass_lint(graph) - trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0) - self.assertExpectedTrace(trace) - self.assertExportImport(trace, (x, y)) + if set_graph: + trace.set_graph(graph) + return graph - # matmul is currently implemented as a native function, which - # exercises different codepaths in the JIT. The following two - # tests ensure that (1) matmul indeed traces into an atomic, - # native operation, and (2) the JIT knows how to run it + def test_simple(self): + x = torch.tensor([0.4], requires_grad=True) + y = torch.tensor([0.7], requires_grad=True) - def test_matmul_native(self): - x = Variable(torch.Tensor([[0.4]]), requires_grad=True) - y = Variable(torch.Tensor([[0.7]]), requires_grad=True) + def f(x, y): + return torch.sigmoid(torch.tanh(x * (x + y))) - trace, z = torch.jit.get_trace_graph(lambda x, y: x.matmul(y), (x, y), nderivs=0) - torch._C._jit_pass_lint(trace.graph()) - torch._C._jit_pass_dce(trace.graph()) - self.assertExpectedTrace(trace) + trace, z = torch.jit.get_trace_graph(f, (x, y)) + self.assertExpectedGraph(trace) self.assertExportImport(trace, (x, y)) - def test_matmul_native_run(self): - x = Variable(torch.Tensor([[0.4]]), requires_grad=True) - y = Variable(torch.Tensor([[0.7]]), requires_grad=True) - - @torch.jit.compile(nderivs=0) - def fn(x, y): - return x.matmul(y) - - z = fn(x, y) - with self.assertCompiled(fn): - z2 = fn(x, y) - self.assertEqual(z, z2) - # index-2 is not implemented in interpreter @unittest.expectedFailure def test_index(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.LongTensor([0]), requires_grad=True) + x = torch.tensor([0.4], requires_grad=True) + y = torch.tensor([0], dtype=torch.int64, requires_grad=True) @torch.jit.compile(nderivs=0) def fn(x, y): return x[y] - z = fn(x, y) - with self.assertCompiled(fn): - z2 = fn(x, y) - self.assertEqual(z, z2) + fn(x, y) # Fails # Backwards tracing was broken for indexing by a constant, # because it's internally implemented using as_strided, @@ -159,26 +159,22 @@ class TestJit(TestCase): # currently supported.) 
It currently works because # slice() is now not marked as traceable. def test_index_constant(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) + x = torch.tensor([0.4], requires_grad=True) - @torch.jit.compile(nderivs=1) def fn(x): return x[0] - z = fn(x) - z.backward() - grad = x.grad.clone() - x.grad.zero_() - with self.assertCompiled(fn): - z2 = fn(x) - z2.backward() - grad2 = x.grad.clone() - self.assertEqual(z, z2) - self.assertEqual(grad, grad2) + def run(f): + y = f(x) + grad = torch.autograd.grad(y, x)[0].clone() + return y, grad + + traced_fn = torch.jit.trace(torch.ones(1))(fn) + self.assertEqual(run(fn), run(traced_fn)) def test_scopes(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.Tensor([0.7]), requires_grad=True) + x = torch.tensor([0.4], requires_grad=True) + y = torch.tensor([0.7], requires_grad=True) def f(x, y): out = x + y @@ -190,7 +186,7 @@ class TestJit(TestCase): return out trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0) - self.assertExpectedTrace(trace) + self.assertExpectedGraph(trace) self.assertExportImport(trace, (x, y)) def test_scopes_intermediate_node(self): @@ -200,12 +196,10 @@ class TestJit(TestCase): return F.log_softmax(x, dim=0) net = Net() - t = Variable(torch.ones(2), requires_grad=True) - trace, _ = torch.jit.get_trace_graph(net, (t, )) - self.assertExportImport(trace, (t, )) - torch.onnx._optimize_trace(trace, False) - - self.assertExpectedTrace(trace) + t = torch.ones(2, requires_grad=True) + trace, _ = torch.jit.get_trace_graph(net, (t,)) + self.assertExportImport(trace, (t,)) + self.assertExpectedONNXGraph(trace) def test_scopes_identity_node(self): @@ -225,93 +219,55 @@ class TestJit(TestCase): model = Net() - t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True) + t = torch.ones(1, 3, 227, 227, requires_grad=True) with torch.onnx.set_training(model, False): - trace, _ = torch.jit.get_trace_graph(model, (t, )) - - self.assertExportImport(trace, (t, ) + tuple(model.parameters())) + trace, _ = torch.jit.get_trace_graph(model, (t,)) - torch.onnx._optimize_trace(trace, False) - - self.assertExpectedTrace(trace) + self.assertExportImport(trace, (t,) + tuple(model.parameters())) + self.assertExpectedONNXGraph(trace) + # TODO: Fuser doesn't work at all when inputs require grad. 
Fix that @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_fusion(self): - input = Variable(torch.randn(3, 10).float().cuda()) - hx = Variable(torch.randn(3, 20).float().cuda()) - cx = Variable(torch.randn(3, 20).float().cuda()) - module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes - - trace, _ = torch.jit.get_trace_graph(LSTMCell, (input, (hx, cx)) + tuple(module.parameters())) - torch._C._jit_pass_lint(trace.graph()) - torch._C._jit_pass_dce(trace.graph()) - torch._C._jit_pass_lint(trace.graph()) - self.assertExportImport(trace, (input, hx, cx) + tuple(module.parameters())) - torch._C._jit_pass_fuse(trace.graph()) - self.assertExpectedTrace(trace) - - def run_lstm_fusion(self, use_cuda): - def to_type(x): - x = x.float() - if use_cuda: - x = x.cuda() - return x - - def rand_v(a, b): - return Variable(to_type(torch.randn(a, b))) - - input = rand_v(3, 10) - hx = rand_v(3, 20) - cx = rand_v(3, 20) - module = to_type(nn.LSTMCell(10, 20)) # Just to allocate weights with correct sizes - - CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCell) - - z = CompiledLSTMCell(input, (hx, cx), *module.parameters()) - with self.assertCompiled(CompiledLSTMCell): - z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters()) - self.assertEqual(z, z2) - - @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_run_lstm_fusion_cuda(self): - self.run_lstm_fusion(True) + def test_lstm_fusion_cuda(self): + inputs = get_lstm_inputs('cuda') + ge = self.checkTrace(LSTMCellF, inputs) + self.assertExpectedGraph(ge.graph_for(*inputs)) @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") - def test_run_lstm_fusion_cpu(self): - self.run_lstm_fusion(False) + def test_lstm_fusion_cpu(self): + inputs = get_lstm_inputs('cpu') + try: + ge = self.checkTrace(LSTMCellF, inputs) + self.assertExpectedGraph(ge.graph_for(*inputs)) + except RuntimeError as e: + if 'Failed to compile' in e.args[0]: + warnings.warn('CPU fuser test has failed! 
This is not a hard failure, ' + 'because the kernels sometimes trigger bugs in compilers ' + '(most notably GCC 7.2).') + raise unittest.SkipTest('Failed to compile') + else: + raise @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_run_lstm_fusion_concat(self): - input = Variable(torch.randn(3, 10).float().cuda()) - hx = Variable(torch.randn(3, 20).float().cuda()) - cx = Variable(torch.randn(3, 20).float().cuda()) - module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes - - CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCellC) - - z = CompiledLSTMCell(input, (hx, cx), *module.parameters()) - with self.assertCompiled(CompiledLSTMCell): - z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters()) - self.assertEqual(z, z2) + def test_lstm_fusion_concat(self): + inputs = get_lstm_inputs('cuda') + ge = self.checkTrace(LSTMCellC, inputs) + self.assertExpectedGraph(ge.graph_for(*inputs)) @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_concat_fusion(self): - hx = Variable(torch.randn(3, 20).float().cuda()) - cx = Variable(torch.randn(3, 20).float().cuda()) + hx = torch.randn(3, 20, dtype=torch.float, device='cuda') + cx = torch.randn(3, 20, dtype=torch.float, device='cuda') - def Foo(hx, cx): + def foo(hx, cx): return torch.cat((hx + cx, hx * cx)) - trace, _ = torch.jit.get_trace_graph(Foo, (hx, cx)) - self.assertExportImport(trace, (hx, cx)) - torch._C._jit_pass_lint(trace.graph()) - torch._C._jit_pass_fuse(trace.graph()) - self.assertExpectedTrace(trace) + ge = self.checkTrace(foo, (hx, cx)) + self.assertExpectedGraph(ge.graph_for(hx, cx)) @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @@ -319,16 +275,15 @@ class TestJit(TestCase): def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 - x = Variable(torch.randn(4, 4).float().cuda()) - y = Variable(torch.randn(4, 4).float().cuda()) - trace, _ = torch.jit.get_trace_graph(f, (x, y), nderivs=0) - torch._C._jit_pass_lint(trace.graph()) - torch._C._jit_pass_dce(trace.graph()) - self.assertExpectedTrace(trace, 'raw') - self.assertExportImport(trace, (x, y)) - torch._C._jit_pass_fuse(trace.graph()) - self.assertExpectedTrace(trace) + x = torch.randn(4, 4, dtype=torch.float, device='cuda') + y = torch.randn(4, 4, dtype=torch.float, device='cuda') + + ge = self.checkTrace(f, (x, y)) + self.assertExpectedGraph(ge.graph_for(x, y)) + + # TODO: adapt this test to check that GraphExecutor treats them differently + @unittest.skip("Need to be adjusted to Graph Executor") def test_arg_configurations(self): """Different arg configurations should trigger different traces""" x = Variable(torch.FloatTensor(4, 4).uniform_()) @@ -373,97 +328,43 @@ class TestJit(TestCase): self.assertEqual(fn.hits, 0) def test_cse(self): - x = Variable(torch.Tensor([0.4, 0.3]), requires_grad=True) - y = Variable(torch.Tensor([0.7, 0.5]), requires_grad=True) - - trace, inputs = torch._C._tracer_enter((x, y), 0) + x = torch.tensor([0.4, 0.3], requires_grad=True) + y = torch.tensor([0.7, 0.5], requires_grad=True) def fn(x, y): w = (x + y) * (x + y) * (x + y) t = torch.tanh(w) + torch.tanh(w) z = (x + y) * (x + y) * (x + y) + t return z - z = fn(*inputs) - torch._C._tracer_exit((z,)) - torch._C._jit_pass_lint(trace.graph()) - torch._C._jit_pass_cse(trace.graph()) - self.assertExpectedTrace(trace) + trace, _ = 
torch.jit.get_trace_graph(fn, (x, y), nderivs=0) + self.run_pass('cse', trace) + self.assertExpectedGraph(trace) self.assertExportImport(trace, (x, y)) - def test_compile_run_twice(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.Tensor([0.7]), requires_grad=True) - - @torch.jit.compile(nderivs=0, optimize=False) - def doit(x, y): - return torch.sigmoid(torch.tanh(x * (x + y))) - - z = doit(x, y) - with self.assertCompiled(doit): - z2 = doit(x, y) - self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y)))) - self.assertEqual(z, z2) - - @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_compile_addc(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda() - y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda() + def test_shape_analysis_broadcast(self): + def broadcast(a, b): + return a + b - @torch.jit.compile(nderivs=0) - def doit(x, y): - return torch.sigmoid(torch.tanh(x * (x + y) + 1)) + x = torch.randn(3, 1, 5, requires_grad=True) + y = torch.randn(4, 1, 8, 5, requires_grad=True) - z = doit(x, y) - with self.assertCompiled(doit): - z2 = doit(x, y) - self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y) + 1))) - self.assertEqual(z, z2) + graph = torch.jit._script_graph(broadcast) + torch._C._jit_pass_shape_analysis(graph, (x, y), False) + self.assertExpectedGraph(graph) @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") - def test_compile_fuse_last_device(self): - max_device = torch.cuda.device_count() - 1 - x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda(max_device) - y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda(max_device) + def test_fuse_last_device(self): + device = 'cuda:' + str(torch.cuda.device_count() - 1) + x = torch.tensor([0.4], dtype=torch.float, device=device) + y = torch.tensor([0.7], dtype=torch.float, device=device) - @torch.jit.compile(nderivs=0) def doit(x, y): return torch.sigmoid(torch.tanh(x * (x + y) + 1)) - z = doit(x, y) - with self.assertCompiled(doit): - z2 = doit(x, y) - self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y) + 1))) - self.assertEqual(z, z2) - - def test_traced_function(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.Tensor([0.7]), requires_grad=True) - - @torch.jit.compile(nderivs=0) - def doit(x, y): - return torch.sigmoid(torch.tanh(x * (x + y))) - - z = doit(x, y) - with self.assertCompiled(doit): - z2 = doit(x, y) - self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y)))) - self.assertEqual(z, z2) - - def test_disabled_traced_function(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.Tensor([0.7]), requires_grad=True) - - @torch.jit.compile(enabled=False) - def doit(x, y): - return torch.sigmoid(torch.tanh(x * (x + y))) - - z = doit(x, y) - z2 = doit(x, y) - self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y)))) - self.assertEqual(z, z2) + ge = self.checkTrace(doit, (x, y)) + self.assertExpectedGraph(ge.graph_for(x, y)) def test_assign_traces(self): """Check that output Variables are assigned traces before they are saved.""" @@ -480,61 +381,17 @@ class TestJit(TestCase): a, = ctx.saved_tensors return a * grad_a - x = Variable(torch.randn(10, 10), requires_grad=True) + x = torch.randn(10, 10, requires_grad=True) trace, out = torch.jit.get_trace_graph(MyFn.apply, x, nderivs=1) 
out.sum().backward() - torch._C._jit_pass_dce(trace.graph()) - self.assertExpectedTrace(trace) - - def test_legacy_traced_module(self): - input = Variable(torch.randn(3, 10)) - hx = Variable(torch.randn(3, 20)) - cx = Variable(torch.randn(3, 20)) - - @torch.jit.compile(nderivs=0) - class MyLSTMCell(nn.LSTMCell): - pass - - lstm = MyLSTMCell(10, 20) - - out = lstm(input, (hx, cx)) - with self.assertCompiled(lstm): - out2 = lstm(input, (hx, cx)) - self.assertEqual(out, out2) - - def test_autograd_closure(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.Tensor([0.7]), requires_grad=True) - - trace, inputs = torch._C._tracer_enter((x, y), 1) - - def fn(x, y): - z = torch.sigmoid(x * (x + y)) - w = torch.abs(x * x * x + y) + Variable(torch.ones(1)) - return z, w - z, w = fn(*inputs) - - torch._C._tracer_exit((z, w)) - torch._C._jit_pass_lint(trace.graph()) - - (z * w).backward() - torch._C._jit_pass_dce(trace.graph()) - torch._C._jit_pass_lint(trace.graph()) - - x_grad = x.grad.data.clone() - x.grad.data.zero_() - - function = torch._C._jit_createInterpreterFactory(trace) - torch._C._jit_pass_lint(trace.graph()) - z2, w2 = function()(x, y) - (z2 * w2).backward() - self.assertEqual(z, z2) - self.assertEqual(w, w2) - self.assertEqual(x.grad.data, x_grad) + self.run_pass('dce', trace) + self.assertExpectedGraph(trace) + # TODO: update verify to work with GraphExecutors + @unittest.skip("verify needs to be updated to work with GraphExecutors") def test_verify(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.Tensor([0.7]), requires_grad=True) + x = torch.tensor([0.4], requires_grad=True) + y = torch.tensor([0.7], requires_grad=True) @torch.jit.compile def f(x, y): @@ -545,37 +402,14 @@ class TestJit(TestCase): torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[]) def test_constant(self): - x = Variable(torch.randn(2, 2), requires_grad=True) - - trace, (tx,) = torch._C._tracer_enter((x,), 0) - - y = Variable(torch.diag(torch.Tensor([2, 2]))) - z = tx.matmul(y) - - torch._C._tracer_exit((z,)) - function = torch._C._jit_createInterpreterFactory(trace) - - z2 = function()(x) - self.assertEqual(z, z2) - - y.data.fill_(1000) # make sure the data has been cloned + x = torch.randn(2, 2, requires_grad=True) - x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True) - z3 = function()(x2) - self.assertEqual(z3.data, torch.ones(2, 2) * 4) - - def test_c_function(self): - x = Variable(torch.randn(1, 3, 10, 10)) - m = nn.Conv2d(3, 8, 3, 1) + def f(x): + return x.matmul(torch.diag(torch.tensor([2., 2.]))) - trace, inputs = torch._C._tracer_enter((x,) + tuple(m.parameters()), 0) - y = m(inputs[0]) - torch._C._tracer_exit((y,)) - self.assertExpectedTrace(trace) - self.assertExportImport(trace, inputs) + self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),)) def test_legacy_fail(self): - class MyLegacyFn(Function): def forward(self, x): return x @@ -583,24 +417,22 @@ class TestJit(TestCase): def backward(self, grad_output): return grad_output - x = Variable(torch.Tensor([0]), requires_grad=True) - trace, inputs = torch._C._tracer_enter((x,), 0) - self.assertRaisesRegex(RuntimeError, "MyLegacyFn", lambda: MyLegacyFn()(*inputs)) - torch._C._tracer_exit(inputs) + x = torch.tensor([0.], requires_grad=True) + with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"): + torch.jit.get_trace_graph(lambda x: MyLegacyFn()(x), (x,), nderivs=0) def test_inplace_transplant(self): - x = Variable(torch.Tensor([0]), requires_grad=True) - 
trace, inputs = torch._C._tracer_enter((x,), 0) + x = torch.tensor([0.], requires_grad=True) def fn(x): y = x.clone() y.add_(2) y.add_(3) return y - y = fn(*inputs) - torch._C._tracer_exit((y,)) - self.assertExpectedTrace(trace) - self.assertExportImport(trace, inputs) + + trace, _ = torch.jit.get_trace_graph(fn, (x,), nderivs=0) + self.assertExpectedGraph(trace) + self.assertExportImport(trace, (x,)) def test_inplace_flags(self): class InplaceFn(Function): @@ -622,8 +454,7 @@ class TestJit(TestCase): def backward(ctx, go): return go - x = Variable(torch.Tensor([0]), requires_grad=True) - trace, inputs = torch._C._tracer_enter((x,), 0) + x = torch.tensor([0], requires_grad=True) def fn(x): y = RegularFn.apply(x) @@ -631,9 +462,9 @@ class TestJit(TestCase): y = InplaceFn.apply(y) y = RegularFn.apply(y) return y - y = fn(*inputs) - torch._C._tracer_exit((y,)) - torch._C._jit_pass_dce(trace.graph()) + + trace, _ = torch.jit.get_trace_graph(fn, (x,), nderivs=0) + self.run_pass('dce', trace) ops = [n for n in trace.graph().nodes()] for op in ops: self.assertTrue(op.hasAttribute('inplace')) @@ -653,125 +484,13 @@ class TestJit(TestCase): def backward(self, grad): return grad - @torch.jit.compile(nderivs=0) def fn(x): return MyInplaceFn.apply(x) - x = Variable(torch.randn(5, 5)) - fn(x) # trace - with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'): - fn(x) - - def test_backward(self): - a = Variable(torch.randn(2, 2), requires_grad=True) - b = Variable(torch.randn(2, 2), requires_grad=True) - - x = a - y = a * b - - trace, inputs = torch._C._tracer_enter((x, y), 2) - def fn(x, y): - return y * 2 * x - z = fn(*inputs) - torch._C._tracer_exit((z,)) - torch._C._jit_pass_lint(trace.graph()) - - # Run first backward - grad, = torch.autograd.grad(z, x, Variable(torch.ones(2, 2), requires_grad=True), create_graph=True) - torch._C._jit_pass_lint(trace.graph()) - - # Run second backward - grad.sum().backward(create_graph=True) - torch._C._jit_pass_lint(trace.graph()) - - # Run dead code elimination to remove unused trace nodes - torch._C._jit_pass_dce(trace.graph()) - # This is nondeterministic, see: - # https://github.com/ezyang/pytorch/issues/227 - # self.assertExpectedTrace(trace) - self.skipTest("output is nondeterministic on Travis/Python 3.5") - - def test_backward_opaque(self): - x = Variable(torch.randn(3, 3), requires_grad=True) - y = Variable(torch.randn(3, 3), requires_grad=True) - - trace, inputs = torch._C._tracer_enter((x, y), 2) - - def fn(x, y): - return x.cross(y) - z = fn(*inputs) - torch._C._tracer_exit((z,)) - torch._C._jit_pass_lint(trace.graph()) - - # Run first backward - grad, = torch.autograd.grad(z, x, Variable(torch.ones(3, 3), requires_grad=True), create_graph=True) - torch._C._jit_pass_lint(trace.graph()) - - # Run dead code elimination to remove unused trace nodes - torch._C._jit_pass_dce(trace.graph()) - # This is nondeterministic, see: - # https://github.com/ezyang/pytorch/issues/227 - # self.assertExpectedTrace(trace) - self.skipTest("output is nondeterministic on Travis/Python 3.5") - - def test_backward_closure(self): - """Check that autograd closures handle multiple stages correctly.""" - x = Variable(torch.randn(1), requires_grad=True) - - @torch.jit.compile(nderivs=2) - def fn(x): - return x * x - - # Generate trace - grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True) - self.assertFalse(fn.has_trace_for(x)) - grad_x.backward() - self.assertTrue(fn.has_trace_for(x)) - - x_grad = x.grad.data.clone() - x.grad.data.zero_() - - # Run the trace - 
with self.assertCompiled(fn): - output = fn(x) - grad_x, = torch.autograd.grad(output, (x,), create_graph=True) - grad_x.backward() - - self.assertEqual(x.grad.data, x_grad) - - def test_trace_expire(self): - x = Variable(torch.randn(2, 2), requires_grad=True) - y = Variable(torch.randn(2, 2), requires_grad=True) - - def record_trace(num_backwards): - trace, inputs = torch._C._tracer_enter((x, y), num_backwards) - - def fn(x, y): - return y * 2 * x - z = fn(*inputs) - torch._C._tracer_exit((z,)) - return z, trace - - def check(expired, complete): - self.assertEqual(trace.is_expired, expired) - self.assertEqual(trace.is_complete, complete) - - z, trace = record_trace(0) - check(False, True) - del z - check(False, True) - - z, trace = record_trace(1) - check(False, False) - del z - check(True, False) - - z, trace = record_trace(1) - check(False, False) - z.sum().backward() - check(False, True) - del z - check(False, True) + x = torch.randn(5, 5) + ge = torch._C.GraphExecutor(fn, (x,)) + with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'): + ge(x) def do_trace_size(self, requires_grad): def fn(x): @@ -787,7 +506,7 @@ class TestJit(TestCase): # Check that the trace looks ok trace, _ = torch.jit.get_trace_graph(fn, (x,)) - self.assertExpectedTrace(trace) + self.assertExpectedGraph(trace) def test_trace_size(self): self.do_trace_size(False) @@ -797,62 +516,31 @@ class TestJit(TestCase): def test_trace_size_with_grad(self): self.do_trace_size(True) - def test_multiuse_fn(self): - x = Variable(torch.randn(2, 2), requires_grad=True) - w = Variable(torch.randn(2, 2), requires_grad=True) - - @torch.jit.compile - def cell(x, w): - return x * w + 2 - - out = cell(cell(cell(x, w), w), w) - self.assertFalse(cell.has_trace_for(x, w)) - - out.sum().backward() - self.assertTrue(cell.has_trace_for(x, w)) - - torch.jit.verify(cell, (x, w), devices=[]) - + # TODO: implement + @unittest.expectedFailure def test_output_unflatten(self): """Check that outputs of traced functions retain the original structure and nesting""" - x = Variable(torch.randn(2, 2), requires_grad=True) - def fn(x): return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4) - expected_out = fn(x) - fn = torch.jit.compile(fn) - - def recursive_sum(obj): - if isinstance(obj, Variable): - return obj.sum() - else: - return sum(recursive_sum(o) for o in obj) - - recursive_sum(fn(x)).backward() - self.assertTrue(fn.has_trace_for(x)) - with self.assertCompiled(fn): - self.assertEqual(fn(x), expected_out) + self.checkTrace(fn, (torch.randn(2, 2),)) + # TODO: implement + @unittest.expectedFailure def test_input_flatten(self): """Check that inputs to traced functions are flattened""" - def make_var(): - return Variable(torch.randn(1), requires_grad=True) - x = (make_var(), (make_var(), make_var())) def fn(x, t): y, z = t return x * y * z - expected_out = fn(*x) - fn = torch.jit.compile(fn) - fn(*x).backward() - self.assertTrue(fn.has_trace_for(*x)) - with self.assertCompiled(fn): - self.assertEqual(fn(*x), expected_out) + inputs = (torch.randn(1), (torch.randn(1), torch.randn(1))) + self.checkTrace(fn, inputs) + # TODO: adapt to a GraphExecutor test + @unittest.skip("Need to instrument GraphExecutors a bit more") def test_flags(self): - x = Variable(torch.randn(2, 2)) + x, y = torch.randn(2, 2) y = Variable(torch.randn(2, 2)) @torch.jit.compile @@ -876,53 +564,15 @@ class TestJit(TestCase): self.assertEqual(grad_v, expected_grad) self.assertEqual(fn.has_trace_for(x, y), rx or ry) - def test_no_grad_fallback(self): - """Check that Traceable falls 
back to num_backwards=0 if in no-backprop mode""" - x = Variable(torch.randn(2, 2)) - y = Variable(torch.randn(2, 2), requires_grad=True) - - @torch.jit.compile - def fn(x, y): - return x * x + x * y - - out = fn(x, y) - self.assertFalse(fn.has_trace_for(x, y)) - with torch.no_grad(): - out = fn(x, y) - self.assertTrue(fn.has_trace_for(x, y)) - with self.assertCompiled(fn): - out2 = fn(x, y) - self.assertEqual(out, out2) - - def test_backward_flag_checks(self): - x = Variable(torch.randn(1), requires_grad=True) - - @torch.jit.compile(nderivs=2) - def fn(x): - return x * x - - grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True) - self.assertFalse(fn.has_trace_for(x)) - grad_x.backward() - self.assertTrue(fn.has_trace_for(x)) - - with self.assertRaisesRegex(RuntimeError, 'was compiled with'): - fn(x).backward(Variable(torch.ones(1), requires_grad=True)) - with self.assertRaisesRegex(RuntimeError, 'was compiled with'): - grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True) - grad_x.backward(Variable(torch.ones(1), requires_grad=True)) - - # TODO: Test executing this - def test_python_ir(self): - x = Variable(torch.Tensor([0.4]), requires_grad=True) - y = Variable(torch.Tensor([0.7]), requires_grad=True) + x = torch.tensor([0.4], requires_grad=True) + y = torch.tensor([0.7], requires_grad=True) def doit(x, y): return torch.sigmoid(torch.tanh(x * (x + y))) - traced, _ = torch.jit.get_trace_graph(doit, (x, y)) - g = torch._C._jit_get_graph(traced) + trace, _ = torch.jit.get_trace_graph(doit, (x, y)) + g = trace.graph() g2 = torch._C.Graph() g_to_g2 = {} for node in g.inputs(): @@ -937,9 +587,9 @@ class TestJit(TestCase): g2.registerOutput(g_to_g2[node]) t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2])) - assert(t_node.attributeNames() == ["a"]) + self.assertEqual(t_node.attributeNames(), ["a"]) g2.appendNode(t_node) - assert(torch.equal(torch.ones([2, 2]), t_node.t("a"))) + self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a"))) self.assertExpected(str(g2)) @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @@ -950,245 +600,40 @@ class TestJit(TestCase): self.assertExpected(torch._C._jit_run_cpp_tests()) def test_batchnorm(self): - x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True) + x = torch.ones(2, 2, 2, 2) trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x) - self.assertExpectedTrace(trace) + self.assertExpectedGraph(trace) def test_dropout(self): - x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True) + x = torch.ones(2, 2) trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x) - self.assertExpectedTrace(trace) - - def test_batchnorm_run_twice(self): - @torch.jit.compile(nderivs=0) - class MyBatchNorm2d(nn.BatchNorm2d): - pass - - bn = MyBatchNorm2d(1) - x = Variable(torch.randn(5, 1, 2, 1)) - z = bn(x) - with self.assertCompiled(bn): - z2 = bn(x) - self.assertEqual(z, z2) - - def test_non_decorator_use_fails(self): - MyLSTM = torch.jit.compile(nn.LSTM) - self.assertRaisesRegex(TypeError, "class decorator", lambda: MyLSTM(2, 2)) + self.assertExpectedGraph(trace) def test_conv(self): - x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True) + x = torch.ones(20, 16, 50, 40) trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x) - self.assertExpectedTrace(trace) - - def test_reuse_function(self): - @torch.jit.compile(nderivs=0) - def clinear(*args): - return F.linear(*args) - - def cast(x): - return x - - input = Variable(cast(torch.randn(1, 1))) - weights = 
@@ -950,245 +600,40 @@ class TestJit(TestCase):
        self.assertExpected(torch._C._jit_run_cpp_tests())

    def test_batchnorm(self):
-        x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True)
+        x = torch.ones(2, 2, 2, 2)
        trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x)
-        self.assertExpectedTrace(trace)
+        self.assertExpectedGraph(trace)

    def test_dropout(self):
-        x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
+        x = torch.ones(2, 2)
        trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
-        self.assertExpectedTrace(trace)
-
-    def test_batchnorm_run_twice(self):
-        @torch.jit.compile(nderivs=0)
-        class MyBatchNorm2d(nn.BatchNorm2d):
-            pass
-
-        bn = MyBatchNorm2d(1)
-        x = Variable(torch.randn(5, 1, 2, 1))
-        z = bn(x)
-        with self.assertCompiled(bn):
-            z2 = bn(x)
-        self.assertEqual(z, z2)
-
-    def test_non_decorator_use_fails(self):
-        MyLSTM = torch.jit.compile(nn.LSTM)
-        self.assertRaisesRegex(TypeError, "class decorator", lambda: MyLSTM(2, 2))
+        self.assertExpectedGraph(trace)

    def test_conv(self):
-        x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
+        x = torch.ones(20, 16, 50, 40)
        trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
-        self.assertExpectedTrace(trace)
-
-    def test_reuse_function(self):
-        @torch.jit.compile(nderivs=0)
-        def clinear(*args):
-            return F.linear(*args)
-
-        def cast(x):
-            return x
-
-        input = Variable(cast(torch.randn(1, 1)))
-        weights = Variable(cast(torch.randn(1, 1)))
-        bias = Variable(cast(torch.randn(1, 1)))
-
-        # linear AKA addmm without bias is of particular interest
-        # because we allocate a zero-filled new variable when we execute,
-        # and then *fill* it with the result
-
-        r1_ = clinear(input, weights)
-        with self.assertCompiled(clinear):
-            r1 = clinear(r1_, weights)
-        r2 = F.linear(F.linear(input, weights), weights)
-
-        self.assertEqual(r1, r2)
-
-    def test_unused_input(self):
-        @torch.jit.compile(nderivs=1)
-        def fn(a, b, c):
-            return a + b
-
-        a, b, c = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(3)]
-        fn(a, b, c).sum().backward()
-        with self.assertCompiled(fn):
-            fn(a, b, c).sum().backward()
+        self.assertExpectedGraph(trace)

    def test_repeated_input(self):
-        @torch.jit.compile(nderivs=1)
        def fn(a, b):
            return a + b

-        a, b = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(2)]
-        fn(a, a).sum().backward()
-        with self.assertCompiled(fn):
-            fn(a, a).sum().backward()
-        with self.assertCompiled(fn):
-            fn(a, b).sum().backward()
-        self.assertExpected(str(fn.graph_for(a, a)))
+        ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
+        self.assertExpectedGraph(ge.graph)

    def test_repeated_output(self):
-        @torch.jit.compile(nderivs=1)
        def fn(a, b):
            z = a + b
            return z, z

-        a, b = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(2)]
-        sum(fn(a, b)).sum().backward()
-        with self.assertCompiled(fn):
-            sum(fn(a, b)).sum().backward()
-        self.assertExpected(str(fn.graph_for(a, b)))
-
-    def test_re_enter(self):
-        @torch.jit.compile(nderivs=1)
-        def fn(a, b):
-            return a + b
-
-        @torch.jit.compile(nderivs=1)
-        def fn2(a, b, c):
-            return fn(a, b) + c
-
-        a, b, c = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(3)]
-
-        fn(a, b).sum().backward()
-        with self.assertCompiled(fn):
-            fn(a, b).sum().backward()
-
-        fn2(a, b, c).sum().backward()
-        with self.assertCompiled(fn2):
-            fn2(a, b, c).sum().backward()
-
-    def test_mini_wlm(self):
-        """Exercise null-edge pruning in the tracer."""
-
-        @torch.jit.compile
-        class MyModel(nn.Module):
-            def __init__(self):
-                super(MyModel, self).__init__()
-                self.encoder = nn.Embedding(2, 2)
-
-            def forward(self, input, hidden):
-                emb = self.encoder(input)
-                hidden = hidden.clone()  # simulate some RNN operation
-                return emb, hidden
-
-        model = MyModel()
-
-        x = Variable(torch.LongTensor([[0, 1], [1, 0]]))
-        y = Variable(torch.FloatTensor([0]))
-
-        z, _ = model(x, y)
-        z.sum().backward()
-        self.assertTrue(model.has_trace_for(x, y))
-
-        with self.assertCompiled(model):
-            z, _ = model(x, y)
-            z.sum().backward()
-
-    def test_module_cast(self):
-        """Compiled modules can be casted to other data types"""
-        @torch.jit.compile(nderivs=0)
-        class Adder(nn.Module):
-            def __init__(self):
-                super(Adder, self).__init__()
-                self.y = nn.Parameter(torch.randn(2, 2))
-
-            def forward(self, x):
-                return x + self.y
-
-        x = Variable(torch.randn(2, 2).float())
-        # Wrap it in a sequential to make sure it works for submodules
-        a = nn.Sequential(Adder()).float()
-
-        def check_type(caster):
-            caster(a)
-            a(caster(x))
-            with self.assertCompiled(a[0]):
-                a(caster(x))
-
-        check_type(lambda x: x)
-        check_type(lambda x: x.double())
-        if torch.cuda.is_available():
-            check_type(lambda x: x.float().cuda())
-            check_type(lambda x: x.double().cuda())
-        self.assertEqual(a[0].hits, 4 if torch.cuda.is_available() else 2)
-
-    # Tracer fails when it receives the same grad variable as multiple input to
-    # traced region. The problem is that it's not immediately obvious how to
-    # assign multiple inputs to this Variable. It might be possible to solve
-    # this using the view mechanism, but this requires some thought.
-    # In general, it should be supported, because the user has no control
-    # over this (and it's quite common, e.g. the sum call below will pass the same
-    # grad variable as both inputs to grad of fn).
-    @unittest.skip("Broken - repeated grads trigger an assertion failure.")
-    def test_repeated_grad(self):
-        @torch.jit.compile
-        def fn(x):
-            return x * x, x + x
-
-        x = Variable(torch.randn(5, 5), requires_grad=True)
-        # This shouldn't raise!
-        sum(fn(x)).sum().backward()
-
-    def test_input_pruning(self):
-        """Check that stage 1 will return only one value"""
-        # One of the inputs doesn't require grad, so it should be pruned
-        @torch.jit.compile
-        def fn(x, y):
-            return x * y, x + y
-
-        x = Variable(torch.randn(5, 5), requires_grad=True)
-        y = Variable(torch.randn(5, 5))
-
-        out = fn(x, y)
-        (out[0] * out[1]).sum().backward()
-        with self.assertCompiled(fn):
-            fn(x, y)
-        self.assertExpected(str(fn.graph_for(x, y)))
-
-    def test_output_pruning(self):
-        """Check that stage 1 will take one value as an argument"""
-        # One of the outputs doesn't require grad, so it should be pruned
-        @torch.jit.compile
-        def fn(x, y):
-            return x * y, y + y
-
-        x = Variable(torch.randn(5, 5), requires_grad=True)
-        y = Variable(torch.randn(5, 5))
-
-        out = fn(x, y)
-        (out[0] * out[1]).sum().backward()
-        with self.assertCompiled(fn):
-            fn(x, y)
-        self.assertExpected(str(fn.graph_for(x, y)))
+        ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
+        self.assertExpectedGraph(ge.graph)

    @skipIfNoTorchVision
    def test_alexnet(self):
-        return
-        x = Variable(torch.randn(10, 3, 224, 224).fill_(1.0), requires_grad=True)
+        x = torch.ones(1, 3, 224, 224)
        trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
-        self.assertExpectedTrace(trace)
-        self.assertExportImport(trace, (x, ))
-        # NB: Purposely NOT testing protobuf export here
-
-    def test_debug_info(self):
-        """Check that debug info doesn't crash and has some reasonable info"""
-
-        @torch.jit.compile(nderivs=1)
-        def fn(x, y):
-            return x * y + x + y
-
-        x = Variable(torch.randn(5, 5), requires_grad=True)
-        y = Variable(torch.randn(5, 5), requires_grad=True)
-
-        out = fn(x, y)
-
-        out.sum().backward()
-
-        for _ in range(0, 100):
-            out = fn(x, y)
-        info_str = fn.jit_debug_info()
-        self.assertTrue("hits: 100" in info_str)
-        self.assertTrue("stage 1" in info_str)
+        self.assertExpectedGraph(trace)

    # Inplace copies don't work with tracer yet.
    # This is actually somewhat important to support correctly
@@ -1197,7 +642,7 @@ class TestJit(TestCase):
    # viewed portion.
    @unittest.expectedFailure
    def test_inplace_copy(self):
-        x = Variable(torch.randn(4, 4), requires_grad=True)
+        x = torch.randn(4, 4, requires_grad=True)

        def f(x):
            out = Variable(torch.zeros(x.size()))
@@ -1205,28 +650,9 @@ class TestJit(TestCase):
            return out

        trace, z = torch.jit.get_trace_graph(f, (x, ), nderivs=0)
-        torch._C._jit_pass_lint(trace.graph())
-        torch._C._jit_pass_dce(trace.graph())
-        self.assertExpectedTrace(trace)
-        self.assertExportImport(trace, (x, ))
-
-    def test_index_trace(self):
-        x = Variable(torch.randn(4, 4), requires_grad=True)
-        trace, z = torch.jit.get_trace_graph(lambda x: x[0], (x, ), nderivs=1)
-        z.sum().backward()
-        torch._C._jit_pass_lint(trace.graph())
-        torch._C._jit_pass_dce(trace.graph())
-        self.assertExpectedTrace(trace)
-
-    def test_saved_output(self):
-        x = Variable(torch.randn(4, 4), requires_grad=True)
-
-        @torch.jit.compile(nderivs=1)
-        def fn(x):
-            return x.sigmoid()
-
-        fn(x).sum().backward()
-        self.assertExpected(str(fn.graph_for(x)))
+        self.run_pass('dce', trace)
+        self.assertExpectedGraph(trace)
+        self.assertExportImport(trace, (x,))

    def test_shared_param(self):
@@ -1239,18 +665,18 @@ class TestJit(TestCase):
                return x * self.a + self.b

        m = MyModule()
-        trace, _ = torch.jit.get_trace_graph(m, (Variable(torch.randn(2, 2)),), nderivs=0)
+        trace, _ = torch.jit.get_trace_graph(m, (torch.randn(2, 2),), nderivs=0)
        self.assertEqual(len(list(trace.graph().inputs())), 2)
-        self.assertExpected(str(trace))
+        self.assertExpectedGraph(trace)

    def test_nested_inplace(self):
-        x = Variable(torch.randn(2, 2))
+        x = torch.randn(2, 2)
        trace, _ = torch.jit.get_trace_graph(lambda x: F.threshold(x, 0, 0, inplace=True), (x,), nderivs=0)
-        self.assertExpectedTrace(trace)
-        self.assertExportImport(trace, (x, ))
+        self.assertExpectedGraph(trace)
+        self.assertExportImport(trace, (x,))

-    def checkGraphExecutor(self, func, reference_tensors, input_tensors=None,
-                           optimize=True, drop=None, allow_unused=False):
+    def checkTrace(self, func, reference_tensors, input_tensors=None,
+                   optimize=True, drop=None, allow_unused=False):
        def allSum(vs):
            # drop allows us to remove some values from ever being used
            # to test unused outputs
@@ -1262,11 +688,10 @@ class TestJit(TestCase):
        if input_tensors is None:
            input_tensors = reference_tensors

-        nograd_inputs = [Variable(t) for t in reference_tensors]
-        recording_inputs = [Variable(t, requires_grad=True)
-                            for t in reference_tensors]
+        nograd_inputs = reference_tensors
+        recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]

-        ge = torch._C.GraphExecutor(func, [Variable(t) for t in input_tensors], optimize)
+        ge = torch.jit.trace(*input_tensors, optimize=optimize)(func)

        # test no gradients case
@@ -1307,7 +732,12 @@ class TestJit(TestCase):
        self.assertEqual(outputs, outputs_ge)
        self.assertEqual(grads, grads_ge)
-        self.assertEqual(grads2, grads2_ge)
+        for g2, g2_ge in zip(grads2, grads2_ge):
+            if g2 is None and g2_ge is None:
+                continue
+            self.assertTrue(torch.allclose(g2, g2_ge, atol=5e-4, rtol=1e-4))
+
+        return ge

    def run_ge_tests(self, optimize, use_cuda):
        def rand(*args):
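checkTrace, introduced above in place of checkGraphExecutor, runs the function both eagerly and through a traced executor and compares outputs and gradients. The snippet below is a simplified standalone sketch of that comparison, not the helper itself: it reuses the decorator-style torch.jit.trace call and the allclose tolerances visible in the hunk, while the function f and the input shapes are made up for illustration.

import torch

def f(a, b):
    return a * b + b

inputs = [torch.randn(2, 3, requires_grad=True) for _ in range(2)]

# trace f once at example inputs (decorator-style API of this era)
traced_f = torch.jit.trace(*[t.detach() for t in inputs], optimize=True)(f)

out_ref = f(*inputs)
out_ge = traced_f(*inputs)
grads_ref = torch.autograd.grad(out_ref.sum(), inputs)
grads_ge = torch.autograd.grad(out_ge.sum(), inputs)

assert torch.allclose(out_ref, out_ge)
assert all(torch.allclose(g, g_ge, atol=5e-4, rtol=1e-4)
           for g, g_ge in zip(grads_ref, grads_ge))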
@@ -1315,27 +745,27 @@ class TestJit(TestCase):
            if use_cuda:
                t = t.cuda()
            return t
-        self.checkGraphExecutor(lambda a, b: a * b + b,
-                                [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
-                                optimize=optimize)
+        self.checkTrace(lambda a, b: a * b + b,
+                        [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
+                        optimize=optimize)
        # trivial identity
-        self.checkGraphExecutor(lambda a, b: (
+        self.checkTrace(lambda a, b: (
            b, a), [rand(1), rand(1)], optimize=optimize)

        def foo(a):
            t = a * a
            return t * t, 4 * t
-        self.checkGraphExecutor(foo, [rand(1)], optimize=optimize)
+        self.checkTrace(foo, [rand(1)], optimize=optimize)
        # unused input
-        self.checkGraphExecutor(
+        self.checkTrace(
            lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize, allow_unused=True)
        # test outputs that do not get used in grad
-        self.checkGraphExecutor(foo, [rand(1)], drop=1, optimize=optimize)
+        self.checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
        # test autograd fallback
-        self.checkGraphExecutor(lambda a, b: a * b /
-                                (a - 2 * b) + b, [rand(1), rand(1)],
-                                optimize=optimize)
+        self.checkTrace(lambda a, b: a * b /
+                        (a - 2 * b) + b, [rand(1), rand(1)],
+                        optimize=optimize)

    def test_ge_unoptimized(self):
        self.run_ge_tests(False, False)
@@ -1373,10 +803,12 @@ class TestJit(TestCase):
        self.assertEqual(g2result, g2result2)

    def test_trace_annotation(self):
-        @torch.jit.trace(Variable(torch.rand(1)))
+        @torch.jit.trace(torch.rand(1))
        def foo(a):
            return a + a + a
-        s = Variable(torch.rand(2))
+
+        x = torch.randn(5, 5)
+        self.assertEqual(foo(x), x + x + x)

    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
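test_trace_annotation above is the smallest demonstration of the decorator form of torch.jit.trace used throughout this file: trace at one example input, then call the traced function at other shapes and expect results that match eager execution. A standalone sketch of the same calls, assuming a build from this commit's era and with the same toy function as the test:

import torch

@torch.jit.trace(torch.rand(1))  # example input is only used to record the trace
def foo(a):
    return a + a + a

x = torch.randn(5, 5)
# the traced function is expected to agree with eager execution at a new shape
print(torch.equal(foo(x), x + x + x))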
@@ -1439,86 +871,6 @@ class TestJit(TestCase):
        self.assertEqual(out, out_state)
        self.assertNotEqual(out, out_ones)

-    def test_shape_prop_mismatch_output(self):
-        with self.assertRaises(RuntimeError):
-            cu = torch.jit.CompilationUnit('''
-            def test_shape_prop_mismatch_output(a):
-                b = slice(a, dim=0, end=-2, start=2, step=1)
-                b = topk(a, dim=0, k=2, largest=True, sorted=True)
-                return b
-            ''')
-        inputs = [torch.zeros(10)]
-        outputs = [torch.zeros(2), torch.from_numpy(np.array([1, 5])).long()]
-
-        real_outs = cu.test_shape_prop_mismatch_output(*inputs)
-        self.assertEqual(real_outs, outputs)
-
-    def test_view_shape_prop(self):
-        cu = torch.jit.CompilationUnit('''
-        def test_view_shape_prop(a):
-            return view(a, size=[-1])
-        ''')
-        inputs = [torch.zeros(10, 10)]
-        outputs = torch.zeros(100)
-
-        real_outs = cu.test_view_shape_prop(*inputs)
-        self.assertEqual(real_outs, outputs)
-
-    def test_integral_shape_inference(self):
-        cu = torch.jit.CompilationUnit('''
-        def test_integral_shape_inference(a):
-            return a / a
-        ''')
-        inputs = [torch.ones(10, 10).type(torch.LongTensor)]
-        outputs = torch.ones(10, 10)
-
-        self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
-
-    def test_shape_analysis_broadcast(self):
-        def broadcast(a, b):
-            return a + b
-
-        x = torch.randn(3, 1, 5, requires_grad=True)
-        y = torch.randn(4, 1, 8, 5, requires_grad=True)
-
-        graph = torch.jit._script_graph(broadcast)
-        torch._C._jit_pass_shape_analysis(graph, (x, y), False)
-        self.assertExpected(str(graph))
-
-    def test_fuser_multiple_blocks(self):
-        cu = torch.jit.CompilationUnit('''
-        def test_fuser_multiple_blocks(this, that, theother, meme):
-            i = 0
-            while i < 20:
-                this = cat([this, meme], dim=0)
-                that = cat([that, meme], dim=0)
-                theother = cat([theother, meme], dim=0)
-                i = i + 1
-            return this, that, theother
-        ''')
-
-        inputs = [torch.ones(0, 10, 10)] * 3
-        inputs += [torch.ones(1, 10, 10)]
-        outputs = [torch.ones(20, 10, 10)] * 3
-
-        self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
-
-    def test_dropout_script(self):
-
-        eg = torch.zeros(1, 2, 3, requires_grad=True)
-
-        @torch.jit.trace(eg)
-        def foo(x):
-            x = torch.neg(x)
-            return F.dropout(x)
-
-        class MyDrop(nn.Module):
-            def forward(self, x):
-                return foo(x)
-
-        f = io.BytesIO()
-        torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
-
    def test_python_function(self):
        class MyFn(Function):
            @staticmethod
@@ -1752,6 +1104,61 @@ class TestScript(TestCase):
        # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
        self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)

+    def test_view_shape_prop(self):
+        cu = torch.jit.CompilationUnit('''
+        def test_view_shape_prop(a):
+            return view(a, size=[-1])
+        ''')
+        inputs = [torch.zeros(10, 10)]
+        outputs = torch.zeros(100)
+
+        real_outs = cu.test_view_shape_prop(*inputs)
+        self.assertEqual(real_outs, outputs)
+
+    def test_integral_shape_inference(self):
+        cu = torch.jit.CompilationUnit('''
+        def test_integral_shape_inference(a):
+            return a / a
+        ''')
+        inputs = [torch.ones(10, 10).type(torch.LongTensor)]
+        outputs = torch.ones(10, 10)
+
+        self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
+
+    def test_fuser_multiple_blocks(self):
+        cu = torch.jit.CompilationUnit('''
+        def test_fuser_multiple_blocks(this, that, theother, meme):
+            i = 0
+            while i < 20:
+                this = cat([this, meme], dim=0)
+                that = cat([that, meme], dim=0)
+                theother = cat([theother, meme], dim=0)
+                i = i + 1
+            return this, that, theother
+        ''')
+
+        inputs = [torch.ones(0, 10, 10)] * 3
+        inputs += [torch.ones(1, 10, 10)]
+        outputs = [torch.ones(20, 10, 10)] * 3
+
+        self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
+
+    def test_dropout_script(self):
+
+        eg = torch.zeros(1, 2, 3, requires_grad=True)
+
+        @torch.jit.trace(eg)
+        def foo(x):
+            x = torch.neg(x)
+            return F.dropout(x)
+
+        class MyDrop(nn.Module):
+            def forward(self, x):
+                return foo(x)
+
+        f = io.BytesIO()
+        torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
+
    @unittest.skip("RuntimeError: VariableType::ID() not implemented")
    def test_cast(self):
        script = '''