Diffstat (limited to 'test')
-rw-r--r-- | test/expect/TestPytorchExportModes.test_aten_fallback.expect       |   2
-rw-r--r-- | test/expect/TestPytorchExportModes.test_onnx_aten.expect           |   2
-rw-r--r-- | test/expect/TestScript.test_listconstruct_erasure.expect           |   2
-rw-r--r-- | test/expect/TestScript.test_onnx_raw_export_script_truediv.expect  |   2
-rw-r--r-- | test/onnx/expect/TestOperators.test_c2_op.expect                   | 173
-rw-r--r-- | test/onnx/test_operators.py                                        |  26
-rw-r--r-- | test/onnx/test_pytorch_onnx_caffe2.py                              | 185
-rw-r--r-- | test/onnx/verify.py                                                |   3
8 files changed, 390 insertions, 5 deletions
diff --git a/test/expect/TestPytorchExportModes.test_aten_fallback.expect b/test/expect/TestPytorchExportModes.test_aten_fallback.expect
index fdb6194f2d..8e903310eb 100644
--- a/test/expect/TestPytorchExportModes.test_aten_fallback.expect
+++ b/test/expect/TestPytorchExportModes.test_aten_fallback.expect
@@ -13,5 +13,5 @@ ModelProto {
         Node {type: "ATen", inputs: [2], outputs: [3,4], attributes: [{ name: 'operator', type: string, value: 'qr'}]}
       ]
     }
-  opset_import: [OperatorSetIdProto { domain: }],
+  opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
 }
diff --git a/test/expect/TestPytorchExportModes.test_onnx_aten.expect b/test/expect/TestPytorchExportModes.test_onnx_aten.expect
index 222fa42704..22f1c57f95 100644
--- a/test/expect/TestPytorchExportModes.test_onnx_aten.expect
+++ b/test/expect/TestPytorchExportModes.test_onnx_aten.expect
@@ -12,5 +12,5 @@ ModelProto {
         Node {type: "ATen", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}]}
       ]
     }
-  opset_import: [OperatorSetIdProto { domain: }],
+  opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
 }
diff --git a/test/expect/TestScript.test_listconstruct_erasure.expect b/test/expect/TestScript.test_listconstruct_erasure.expect
index 09442626e4..818a115e47 100644
--- a/test/expect/TestScript.test_listconstruct_erasure.expect
+++ b/test/expect/TestScript.test_listconstruct_erasure.expect
@@ -16,5 +16,5 @@ ModelProto {
         Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]}
       ]
     }
-  opset_import: [OperatorSetIdProto { domain: }],
+  opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
 }
diff --git a/test/expect/TestScript.test_onnx_raw_export_script_truediv.expect b/test/expect/TestScript.test_onnx_raw_export_script_truediv.expect
index 4a0f296688..dbaff42ed7 100644
--- a/test/expect/TestScript.test_onnx_raw_export_script_truediv.expect
+++ b/test/expect/TestScript.test_onnx_raw_export_script_truediv.expect
@@ -21,5 +21,5 @@ ModelProto {
         Node {type: "add", inputs: [x,z,1], outputs: [10], attributes: []}
       ]
     }
-  opset_import: [OperatorSetIdProto { domain: }],
+  opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}OperatorSetIdProto { domain: org.pytorch.prim}],
 }
diff --git a/test/onnx/expect/TestOperators.test_c2_op.expect b/test/onnx/expect/TestOperators.test_c2_op.expect
new file mode 100644
index 0000000000..568df7594c
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_c2_op.expect
@@ -0,0 +1,173 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.1"
+graph {
+  node {
+    input: "0"
+    input: "1"
+    input: "2"
+    input: "3"
+    output: "4"
+    output: "5"
+    op_type: "GenerateProposals"
+    attribute {
+      name: "spatial_scale"
+      f: 2
+      type: FLOAT
+    }
+    attribute {
+      name: "pre_nms_topN"
+      i: 6000
+      type: INT
+    }
+    attribute {
+      name: "post_nms_topN"
+      i: 300
+      type: INT
+    }
+    attribute {
+      name: "nms_thresh"
+      f: 0.7
+      type: FLOAT
+    }
+    attribute {
+      name: "min_size"
+      f: 16
+      type: FLOAT
+    }
+    attribute {
+      name: "angle_bound_on"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "angle_bound_lo"
+      i: -90
+      type: INT
+    }
+    attribute {
+      name: "angle_bound_hi"
+      i: 90
+      type: INT
+    }
+    attribute {
+      name: "clip_angle_thresh"
+      f: 1
+      type: FLOAT
+    }
+    domain: "org.pytorch._caffe2"
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+          dim {
+            dim_value: 10
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 16
+          }
+          dim {
+            dim_value: 10
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "3"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 4
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "4"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 0
+          }
+          dim {
+            dim_value: 5
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "5"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 0
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
+opset_import {
+  domain: "org.pytorch._caffe2"
+  version: 0
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index 061c3b0ea1..77764f55db 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -569,6 +569,32 @@ class TestOperators(TestCase):
         x = torch.randn(3, 4).float()
         self.assertONNX(MyModule(), (x,), _retain_param_name=False)
 
+    def test_c2_op(self):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+            def forward(self, scores, bbox_deltas, im_info, anchors):
+                a, b = torch.ops._caffe2.GenerateProposals(
+                    (scores), (bbox_deltas), (im_info), (anchors),
+                    2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
+                )
+                return a, b
+
+        model = MyModel()
+        A = 4
+        H = 10
+        W = 8
+        img_count = 3
+        scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
+        bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
+                                     dtype=torch.float32)
+        bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
+        im_info = torch.ones(img_count, 3, dtype=torch.float32)
+        anchors = torch.ones(A, 4, dtype=torch.float32)
+        inputs = (scores, bbox_deltas, im_info, anchors)
+        self.assertONNX(model, inputs)
+
 
 if __name__ == '__main__':
     no_onnx_dep_flag = '--no-onnx'
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index daea9ef412..f61c6808c3 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -33,6 +33,8 @@ import model_defs.word_language_model as word_language_model
 from model_defs.mnist import MNIST
 from model_defs.lstm_flattening_result import LstmFlatteningResult
 from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
+from caffe2.python.operator_test.torch_integration_test import (generate_rois_rotated,
+                                                                create_bbox_transform_inputs)
 
 import onnx
 import caffe2.python.onnx.backend as c2
@@ -1230,6 +1232,189 @@ class TestCaffe2Backend(unittest.TestCase):
         x = torch.randn(3, 3, requires_grad=True)
         self.run_model_test(NarrowModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    def test_c2_roi_align(self):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+            def forward(self, feature, rois):
+                roi_feature = torch.ops._caffe2.RoIAlign(
+                    feature, rois, order="NCHW", spatial_scale=1.0,
+                    pooled_h=3, pooled_w=3, sampling_ratio=3,
+                )
+                return roi_feature
+
+        def rand_roi(N, C, H, W):
+            return [
+                float(int(N * np.random.rand())),
+                0.5 * np.random.rand() * W,
+                0.5 * np.random.rand() * H,
+                (0.5 + 0.5 * np.random.rand()) * W,
+                (0.5 + 0.5 * np.random.rand()) * H,
+            ]
+
+        N, C, H, W = 1, 4, 10, 8
+        feature = torch.randn(N, C, H, W)
+        rois = torch.tensor([rand_roi(N, C, H, W) for _ in range(10)])
+        inputs = (feature, rois)
+        self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3)
+
+    def test_c2_generate_proposals(self):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+            def forward(self, scores, bbox_deltas, im_info, anchors):
+                a, b = torch.ops._caffe2.GenerateProposals(
+                    scores, bbox_deltas, im_info, anchors,
+                    2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
+                )
+                return a, b
+
+        A = 4
+        H = 10
+        W = 8
+        img_count = 3
+        scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
+        bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
+                                     dtype=torch.float32)
+        bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
+        im_info = torch.ones(img_count, 3, dtype=torch.float32)
+        anchors = torch.ones(A, 4, dtype=torch.float32)
+        inputs = (scores, bbox_deltas, im_info, anchors)
+        self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3)
+
+    def test_c2_bbox_transform(self):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+            def forward(self, rois, deltas, im_info):
+                a, b = torch.ops._caffe2.BBoxTransform(
+                    rois,
+                    deltas,
+                    im_info,
+                    weights=[1., 1., 1., 1.],
+                    apply_scale=False,
+                    rotated=True,
+                    angle_bound_on=True,
+                    angle_bound_lo=-90,
+                    angle_bound_hi=90,
+                    clip_angle_thresh=0.5,
+                )
+                return a, b
+
+        roi_counts = [0, 2, 3, 4, 5]
+        batch_size = len(roi_counts)
+        total_rois = sum(roi_counts)
+        im_dims = np.random.randint(100, 600, batch_size)
+        rois = generate_rois_rotated(roi_counts, im_dims)
+        box_dim = 5
+        num_classes = 7
+        deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32)
+        im_info = np.zeros((batch_size, 3)).astype(np.float32)
+        im_info[:, 0] = im_dims
+        im_info[:, 1] = im_dims
+        im_info[:, 2] = 1.0
+        im_info = torch.zeros((batch_size, 3))
+        inputs = (torch.tensor(rois), torch.tensor(deltas), torch.tensor(im_info))
+        self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3)
+
+    # BoxWithNMSLimits has requirements for the inputs, so randomly generated inputs
+    # in Caffe2BackendTestEmbed doesn't work with this op.
+    @skipIfEmbed
+    def test_c2_box_with_nms_limits(self):
+        roi_counts = [0, 2, 3, 4, 5]
+        num_classes = 7
+        rotated = False
+        angle_bound_on = True
+        clip_angle_thresh = 0.5
+        rois, deltas, im_info = create_bbox_transform_inputs(
+            roi_counts, num_classes, rotated
+        )
+        pred_bbox, batch_splits = [
+            t.detach().numpy()
+            for t in torch.ops._caffe2.BBoxTransform(
+                torch.tensor(rois),
+                torch.tensor(deltas),
+                torch.tensor(im_info),
+                [1.0, 1.0, 1.0, 1.0],
+                False,
+                rotated,
+                angle_bound_on,
+                -90,
+                90,
+                clip_angle_thresh,
+            )
+        ]
+        class_prob = np.random.randn(sum(roi_counts), num_classes).astype(np.float32)
+        score_thresh = 0.5
+        nms_thresh = 0.5
+        topk_per_image = int(sum(roi_counts) / 2)
+
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+            def forward(self, class_prob, pred_bbox, batch_splits):
+                a, b, c, d = torch.ops._caffe2.BoxWithNMSLimit(
+                    class_prob,
+                    pred_bbox,
+                    batch_splits,
+                    score_thresh=score_thresh,
+                    nms=nms_thresh,
+                    detections_per_im=topk_per_image,
+                    soft_nms_enabled=False,
+                    soft_nms_method="linear",
+                    soft_nms_sigma=0.5,
+                    soft_nms_min_score_thres=0.001,
+                    rotated=rotated,
+                )
+                return a, b, c, d
+
+        inputs = (torch.tensor(class_prob), torch.tensor(pred_bbox), torch.tensor(batch_splits))
+        self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3)
+
+    def test_c2_inference_lstm(self):
+        num_layers = 4
+        seq_lens = 6
+        emb_lens = 10
+        has_bias = True
+        batch_first = True
+        is_bidirectional = True
+
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+            def forward(self, lstm_in):
+                a, b, c = torch.ops._caffe2.InferenceLSTM(
+                    lstm_in, num_layers, has_bias, batch_first, is_bidirectional
+                )
+                return a, b, c
+
+        num_directions = 2
+        bsz = 5
+        hidden_size = 7
+        hx = np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32)
+        inputs = np.random.randn(bsz, seq_lens, emb_lens).astype(np.float32)
+        torch_lstm = torch.nn.LSTM(
+            emb_lens,
+            hidden_size,
+            batch_first=batch_first,
+            bidirectional=is_bidirectional,
+            bias=has_bias,
+            num_layers=num_layers,
+        )
+        lstm_in = [
+            torch.from_numpy(inputs),
+            torch.from_numpy(hx),
+            torch.from_numpy(hx),
+        ] + [param.detach() for param in torch_lstm._flat_weights]
+
+        self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3)
+
+
 # a bit of metaprogramming to set up all the rnn tests
diff --git a/test/onnx/verify.py b/test/onnx/verify.py
index 61defae9ca..95f42a180c 100644
--- a/test/onnx/verify.py
+++ b/test/onnx/verify.py
@@ -68,7 +68,8 @@ class Errors(object):
         """
         if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
             try:
-                np.testing.assert_allclose(x, y, rtol=self.rtol, atol=self.atol, equal_nan=False, verbose=True)
+                np.testing.assert_allclose(x, y, rtol=self.rtol, atol=self.atol,
+                                           equal_nan=True, verbose=True)
             except AssertionError as e:
                 raise k("{}{}".format(colonize(msg), str(e).lstrip()))
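
For context (not part of the diff): a minimal sketch of how the GenerateProposals export path exercised by test_c2_op could be driven outside the test harness. It assumes a PyTorch build with the Caffe2 operators registered under torch.ops._caffe2 (the same requirement the tests above have); the output filename proposals.onnx is arbitrary, and the operator arguments simply mirror test_c2_op.

import torch

class GenerateProposalsModel(torch.nn.Module):
    def forward(self, scores, bbox_deltas, im_info, anchors):
        # Positional arguments after the four tensors, in order: spatial_scale,
        # pre_nms_topN, post_nms_topN, nms_thresh, min_size, angle_bound_on,
        # angle_bound_lo, angle_bound_hi, clip_angle_thresh (values from test_c2_op).
        rois, rois_probs = torch.ops._caffe2.GenerateProposals(
            scores, bbox_deltas, im_info, anchors,
            2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
        )
        return rois, rois_probs

img_count, A, H, W = 3, 4, 10, 8
scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
                             dtype=torch.float32).view(img_count, 4 * A, H, W)
im_info = torch.ones(img_count, 3, dtype=torch.float32)
anchors = torch.ones(A, 4, dtype=torch.float32)

# Exporting emits the custom node in the "org.pytorch._caffe2" domain, matching
# the opset_import entries recorded in TestOperators.test_c2_op.expect above.
torch.onnx.export(GenerateProposalsModel(),
                  (scores, bbox_deltas, im_info, anchors),
                  "proposals.onnx")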