diff options
author | Lu Fang <lufang@fb.com> | 2019-04-08 16:01:30 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-08 16:06:00 -0700 |
commit | 443a58e03d00fa04a429ab625c1fc7e3a7e4d529 (patch) | |
tree | fa8e7bbe1a66fba3d5e086bd9439cada437b7b15 /caffe2 | |
parent | 09c19e10682884efe8433ea06009de589a5b4183 (diff) | |
download | pytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.tar.gz pytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.tar.bz2 pytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.zip |
Export C10 operator in PyTorch Model (#18210)
Summary:
Almost there, feel free to review.
these c10 operators are exported to _caffe2 domain.
TODO:
- [x] let the onnx checker pass
- [x] test tensor list as argument
- [x] test caffe2 backend and converter
- [x] check the c10 schema can be exported to onnx
- [x] refactor the test case to share some code
- [x] fix the problem in ONNX_ATEN_FALLBACK
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18210
Reviewed By: zrphercule
Differential Revision: D14600916
Pulled By: houseroad
fbshipit-source-id: 2592a75f21098fb6ceb38c5d00ee40e9e01cd144
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/core/c10_operator.h | 3 | ||||
-rw-r--r-- | caffe2/python/operator_test/torch_integration_test.py | 4 |
2 files changed, 6 insertions, 1 deletions
diff --git a/caffe2/core/c10_operator.h b/caffe2/core/c10_operator.h index 240a16be25..86d8911a1c 100644 --- a/caffe2/core/c10_operator.h +++ b/caffe2/core/c10_operator.h @@ -1,6 +1,7 @@ #pragma once #include <ATen/core/function_schema.h> +#include <ATen/core/interned_strings.h> #include <ATen/core/op_registration/op_registration.h> #include <vector> @@ -97,7 +98,7 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName IValue()); return c10::FunctionSchema( - std::string("_caffe2::") + OperatorName, + Symbol::caffe2(OperatorName).toQualString(), "", std::move(actual_inputs), std::move(outputs)); diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index d8ce5b69a3..07343d986e 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -446,3 +446,7 @@ class TorchIntegration(hu.HypothesisTestCase): @unittest.skipIf(not workspace.has_cuda_support, "No cuda support") def test_roi_align_cuda(self): self._test_roi_align(device="cuda") + + +if __name__ == '__main__': + unittest.main() |