summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorLu Fang <lufang@fb.com>2019-04-08 16:01:30 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-08 16:06:00 -0700
commit443a58e03d00fa04a429ab625c1fc7e3a7e4d529 (patch)
treefa8e7bbe1a66fba3d5e086bd9439cada437b7b15 /test
parent09c19e10682884efe8433ea06009de589a5b4183 (diff)
downloadpytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.tar.gz
pytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.tar.bz2
pytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.zip
Export C10 operator in PyTorch Model (#18210)
Summary: Almost there, feel free to review. these c10 operators are exported to _caffe2 domain. TODO: - [x] let the onnx checker pass - [x] test tensor list as argument - [x] test caffe2 backend and converter - [x] check the c10 schema can be exported to onnx - [x] refactor the test case to share some code - [x] fix the problem in ONNX_ATEN_FALLBACK Pull Request resolved: https://github.com/pytorch/pytorch/pull/18210 Reviewed By: zrphercule Differential Revision: D14600916 Pulled By: houseroad fbshipit-source-id: 2592a75f21098fb6ceb38c5d00ee40e9e01cd144
Diffstat (limited to 'test')
-rw-r--r--test/expect/TestPytorchExportModes.test_aten_fallback.expect2
-rw-r--r--test/expect/TestPytorchExportModes.test_onnx_aten.expect2
-rw-r--r--test/expect/TestScript.test_listconstruct_erasure.expect2
-rw-r--r--test/expect/TestScript.test_onnx_raw_export_script_truediv.expect2
-rw-r--r--test/onnx/expect/TestOperators.test_c2_op.expect173
-rw-r--r--test/onnx/test_operators.py26
-rw-r--r--test/onnx/test_pytorch_onnx_caffe2.py185
-rw-r--r--test/onnx/verify.py3
8 files changed, 390 insertions, 5 deletions
diff --git a/test/expect/TestPytorchExportModes.test_aten_fallback.expect b/test/expect/TestPytorchExportModes.test_aten_fallback.expect
index fdb6194f2d..8e903310eb 100644
--- a/test/expect/TestPytorchExportModes.test_aten_fallback.expect
+++ b/test/expect/TestPytorchExportModes.test_aten_fallback.expect
@@ -13,5 +13,5 @@ ModelProto {
Node {type: "ATen", inputs: [2], outputs: [3,4], attributes: [{ name: 'operator', type: string, value: 'qr'}]}
]
}
- opset_import: [OperatorSetIdProto { domain: }],
+ opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
}
diff --git a/test/expect/TestPytorchExportModes.test_onnx_aten.expect b/test/expect/TestPytorchExportModes.test_onnx_aten.expect
index 222fa42704..22f1c57f95 100644
--- a/test/expect/TestPytorchExportModes.test_onnx_aten.expect
+++ b/test/expect/TestPytorchExportModes.test_onnx_aten.expect
@@ -12,5 +12,5 @@ ModelProto {
Node {type: "ATen", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}]}
]
}
- opset_import: [OperatorSetIdProto { domain: }],
+ opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
}
diff --git a/test/expect/TestScript.test_listconstruct_erasure.expect b/test/expect/TestScript.test_listconstruct_erasure.expect
index 09442626e4..818a115e47 100644
--- a/test/expect/TestScript.test_listconstruct_erasure.expect
+++ b/test/expect/TestScript.test_listconstruct_erasure.expect
@@ -16,5 +16,5 @@ ModelProto {
Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]}
]
}
- opset_import: [OperatorSetIdProto { domain: }],
+ opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
}
diff --git a/test/expect/TestScript.test_onnx_raw_export_script_truediv.expect b/test/expect/TestScript.test_onnx_raw_export_script_truediv.expect
index 4a0f296688..dbaff42ed7 100644
--- a/test/expect/TestScript.test_onnx_raw_export_script_truediv.expect
+++ b/test/expect/TestScript.test_onnx_raw_export_script_truediv.expect
@@ -21,5 +21,5 @@ ModelProto {
Node {type: "add", inputs: [x,z,1], outputs: [10], attributes: []}
]
}
- opset_import: [OperatorSetIdProto { domain: }],
+ opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}OperatorSetIdProto { domain: org.pytorch.prim}],
}
diff --git a/test/onnx/expect/TestOperators.test_c2_op.expect b/test/onnx/expect/TestOperators.test_c2_op.expect
new file mode 100644
index 0000000000..568df7594c
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_c2_op.expect
@@ -0,0 +1,173 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.1"
+graph {
+ node {
+ input: "0"
+ input: "1"
+ input: "2"
+ input: "3"
+ output: "4"
+ output: "5"
+ op_type: "GenerateProposals"
+ attribute {
+ name: "spatial_scale"
+ f: 2
+ type: FLOAT
+ }
+ attribute {
+ name: "pre_nms_topN"
+ i: 6000
+ type: INT
+ }
+ attribute {
+ name: "post_nms_topN"
+ i: 300
+ type: INT
+ }
+ attribute {
+ name: "nms_thresh"
+ f: 0.7
+ type: FLOAT
+ }
+ attribute {
+ name: "min_size"
+ f: 16
+ type: FLOAT
+ }
+ attribute {
+ name: "angle_bound_on"
+ i: 1
+ type: INT
+ }
+ attribute {
+ name: "angle_bound_lo"
+ i: -90
+ type: INT
+ }
+ attribute {
+ name: "angle_bound_hi"
+ i: 90
+ type: INT
+ }
+ attribute {
+ name: "clip_angle_thresh"
+ f: 1
+ type: FLOAT
+ }
+ domain: "org.pytorch._caffe2"
+ }
+ name: "torch-jit-export"
+ input {
+ name: "0"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 4
+ }
+ dim {
+ dim_value: 10
+ }
+ dim {
+ dim_value: 8
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 16
+ }
+ dim {
+ dim_value: 10
+ }
+ dim {
+ dim_value: 8
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "2"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "3"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 4
+ }
+ dim {
+ dim_value: 4
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "4"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 0
+ }
+ dim {
+ dim_value: 5
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "5"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 0
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
+opset_import {
+ domain: "org.pytorch._caffe2"
+ version: 0
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index 061c3b0ea1..77764f55db 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -569,6 +569,32 @@ class TestOperators(TestCase):
x = torch.randn(3, 4).float()
self.assertONNX(MyModule(), (x,), _retain_param_name=False)
+ def test_c2_op(self):
+ class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+
+ def forward(self, scores, bbox_deltas, im_info, anchors):
+ a, b = torch.ops._caffe2.GenerateProposals(
+ (scores), (bbox_deltas), (im_info), (anchors),
+ 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
+ )
+ return a, b
+
+ model = MyModel()
+ A = 4
+ H = 10
+ W = 8
+ img_count = 3
+ scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
+ bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
+ dtype=torch.float32)
+ bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
+ im_info = torch.ones(img_count, 3, dtype=torch.float32)
+ anchors = torch.ones(A, 4, dtype=torch.float32)
+ inputs = (scores, bbox_deltas, im_info, anchors)
+ self.assertONNX(model, inputs)
+
if __name__ == '__main__':
no_onnx_dep_flag = '--no-onnx'
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index daea9ef412..f61c6808c3 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -33,6 +33,8 @@ import model_defs.word_language_model as word_language_model
from model_defs.mnist import MNIST
from model_defs.lstm_flattening_result import LstmFlatteningResult
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
+from caffe2.python.operator_test.torch_integration_test import (generate_rois_rotated,
+ create_bbox_transform_inputs)
import onnx
import caffe2.python.onnx.backend as c2
@@ -1230,6 +1232,189 @@ class TestCaffe2Backend(unittest.TestCase):
x = torch.randn(3, 3, requires_grad=True)
self.run_model_test(NarrowModel(), train=False, input=x, batch_size=BATCH_SIZE)
+ def test_c2_roi_align(self):
+ class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+
+ def forward(self, feature, rois):
+ roi_feature = torch.ops._caffe2.RoIAlign(
+ feature, rois, order="NCHW", spatial_scale=1.0,
+ pooled_h=3, pooled_w=3, sampling_ratio=3,
+ )
+ return roi_feature
+
+ def rand_roi(N, C, H, W):
+ return [
+ float(int(N * np.random.rand())),
+ 0.5 * np.random.rand() * W,
+ 0.5 * np.random.rand() * H,
+ (0.5 + 0.5 * np.random.rand()) * W,
+ (0.5 + 0.5 * np.random.rand()) * H,
+ ]
+
+ N, C, H, W = 1, 4, 10, 8
+ feature = torch.randn(N, C, H, W)
+ rois = torch.tensor([rand_roi(N, C, H, W) for _ in range(10)])
+ inputs = (feature, rois)
+ self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3)
+
+ def test_c2_generate_proposals(self):
+ class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+
+ def forward(self, scores, bbox_deltas, im_info, anchors):
+ a, b = torch.ops._caffe2.GenerateProposals(
+ scores, bbox_deltas, im_info, anchors,
+ 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
+ )
+ return a, b
+
+ A = 4
+ H = 10
+ W = 8
+ img_count = 3
+ scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
+ bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
+ dtype=torch.float32)
+ bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
+ im_info = torch.ones(img_count, 3, dtype=torch.float32)
+ anchors = torch.ones(A, 4, dtype=torch.float32)
+ inputs = (scores, bbox_deltas, im_info, anchors)
+ self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3)
+
+ def test_c2_bbox_transform(self):
+ class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+
+ def forward(self, rois, deltas, im_info):
+ a, b = torch.ops._caffe2.BBoxTransform(
+ rois,
+ deltas,
+ im_info,
+ weights=[1., 1., 1., 1.],
+ apply_scale=False,
+ rotated=True,
+ angle_bound_on=True,
+ angle_bound_lo=-90,
+ angle_bound_hi=90,
+ clip_angle_thresh=0.5,
+ )
+ return a, b
+
+ roi_counts = [0, 2, 3, 4, 5]
+ batch_size = len(roi_counts)
+ total_rois = sum(roi_counts)
+ im_dims = np.random.randint(100, 600, batch_size)
+ rois = generate_rois_rotated(roi_counts, im_dims)
+ box_dim = 5
+ num_classes = 7
+ deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32)
+ im_info = np.zeros((batch_size, 3)).astype(np.float32)
+ im_info[:, 0] = im_dims
+ im_info[:, 1] = im_dims
+ im_info[:, 2] = 1.0
+ im_info = torch.zeros((batch_size, 3))
+ inputs = (torch.tensor(rois), torch.tensor(deltas), torch.tensor(im_info))
+ self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3)
+
+ # BoxWithNMSLimits has requirements for the inputs, so randomly generated inputs
+ # in Caffe2BackendTestEmbed doesn't work with this op.
+ @skipIfEmbed
+ def test_c2_box_with_nms_limits(self):
+ roi_counts = [0, 2, 3, 4, 5]
+ num_classes = 7
+ rotated = False
+ angle_bound_on = True
+ clip_angle_thresh = 0.5
+ rois, deltas, im_info = create_bbox_transform_inputs(
+ roi_counts, num_classes, rotated
+ )
+ pred_bbox, batch_splits = [
+ t.detach().numpy()
+ for t in torch.ops._caffe2.BBoxTransform(
+ torch.tensor(rois),
+ torch.tensor(deltas),
+ torch.tensor(im_info),
+ [1.0, 1.0, 1.0, 1.0],
+ False,
+ rotated,
+ angle_bound_on,
+ -90,
+ 90,
+ clip_angle_thresh,
+ )
+ ]
+ class_prob = np.random.randn(sum(roi_counts), num_classes).astype(np.float32)
+ score_thresh = 0.5
+ nms_thresh = 0.5
+ topk_per_image = int(sum(roi_counts) / 2)
+
+ class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+
+ def forward(self, class_prob, pred_bbox, batch_splits):
+ a, b, c, d = torch.ops._caffe2.BoxWithNMSLimit(
+ class_prob,
+ pred_bbox,
+ batch_splits,
+ score_thresh=score_thresh,
+ nms=nms_thresh,
+ detections_per_im=topk_per_image,
+ soft_nms_enabled=False,
+ soft_nms_method="linear",
+ soft_nms_sigma=0.5,
+ soft_nms_min_score_thres=0.001,
+ rotated=rotated,
+ )
+ return a, b, c, d
+
+ inputs = (torch.tensor(class_prob), torch.tensor(pred_bbox), torch.tensor(batch_splits))
+ self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3)
+
+ def test_c2_inference_lstm(self):
+ num_layers = 4
+ seq_lens = 6
+ emb_lens = 10
+ has_bias = True
+ batch_first = True
+ is_bidirectional = True
+
+ class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+
+ def forward(self, lstm_in):
+ a, b, c = torch.ops._caffe2.InferenceLSTM(
+ lstm_in, num_layers, has_bias, batch_first, is_bidirectional
+ )
+ return a, b, c
+
+ num_directions = 2
+ bsz = 5
+ hidden_size = 7
+ hx = np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32)
+ inputs = np.random.randn(bsz, seq_lens, emb_lens).astype(np.float32)
+ torch_lstm = torch.nn.LSTM(
+ emb_lens,
+ hidden_size,
+ batch_first=batch_first,
+ bidirectional=is_bidirectional,
+ bias=has_bias,
+ num_layers=num_layers,
+ )
+ lstm_in = [
+ torch.from_numpy(inputs),
+ torch.from_numpy(hx),
+ torch.from_numpy(hx),
+ ] + [param.detach() for param in torch_lstm._flat_weights]
+
+ self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3)
+
+
# a bit of metaprogramming to set up all the rnn tests
diff --git a/test/onnx/verify.py b/test/onnx/verify.py
index 61defae9ca..95f42a180c 100644
--- a/test/onnx/verify.py
+++ b/test/onnx/verify.py
@@ -68,7 +68,8 @@ class Errors(object):
"""
if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
try:
- np.testing.assert_allclose(x, y, rtol=self.rtol, atol=self.atol, equal_nan=False, verbose=True)
+ np.testing.assert_allclose(x, y, rtol=self.rtol, atol=self.atol,
+ equal_nan=True, verbose=True)
except AssertionError as e:
raise
k("{}{}".format(colonize(msg), str(e).lstrip()))