summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorLu Fang <lufang@fb.com>2019-04-08 16:01:30 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-08 16:06:00 -0700
commit443a58e03d00fa04a429ab625c1fc7e3a7e4d529 (patch)
treefa8e7bbe1a66fba3d5e086bd9439cada437b7b15 /caffe2
parent09c19e10682884efe8433ea06009de589a5b4183 (diff)
downloadpytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.tar.gz
pytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.tar.bz2
pytorch-443a58e03d00fa04a429ab625c1fc7e3a7e4d529.zip
Export C10 operator in PyTorch Model (#18210)
Summary: Almost there, feel free to review. these c10 operators are exported to _caffe2 domain. TODO: - [x] let the onnx checker pass - [x] test tensor list as argument - [x] test caffe2 backend and converter - [x] check the c10 schema can be exported to onnx - [x] refactor the test case to share some code - [x] fix the problem in ONNX_ATEN_FALLBACK Pull Request resolved: https://github.com/pytorch/pytorch/pull/18210 Reviewed By: zrphercule Differential Revision: D14600916 Pulled By: houseroad fbshipit-source-id: 2592a75f21098fb6ceb38c5d00ee40e9e01cd144
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/core/c10_operator.h3
-rw-r--r--caffe2/python/operator_test/torch_integration_test.py4
2 files changed, 6 insertions, 1 deletions
diff --git a/caffe2/core/c10_operator.h b/caffe2/core/c10_operator.h
index 240a16be25..86d8911a1c 100644
--- a/caffe2/core/c10_operator.h
+++ b/caffe2/core/c10_operator.h
@@ -1,6 +1,7 @@
#pragma once
#include <ATen/core/function_schema.h>
+#include <ATen/core/interned_strings.h>
#include <ATen/core/op_registration/op_registration.h>
#include <vector>
@@ -97,7 +98,7 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName
IValue());
return c10::FunctionSchema(
- std::string("_caffe2::") + OperatorName,
+ Symbol::caffe2(OperatorName).toQualString(),
"",
std::move(actual_inputs),
std::move(outputs));
diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py
index d8ce5b69a3..07343d986e 100644
--- a/caffe2/python/operator_test/torch_integration_test.py
+++ b/caffe2/python/operator_test/torch_integration_test.py
@@ -446,3 +446,7 @@ class TorchIntegration(hu.HypothesisTestCase):
@unittest.skipIf(not workspace.has_cuda_support, "No cuda support")
def test_roi_align_cuda(self):
self._test_roi_align(device="cuda")
+
+
+if __name__ == '__main__':
+ unittest.main()