diff options
author | Sebastian Messmer <messmer@fb.com> | 2019-03-30 00:03:43 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-30 00:07:16 -0700 |
commit | 9abc8a5b47d116342a5c277c62cfce81fd9dd331 (patch) | |
tree | 896ecd6bbe4c8ee828e42a92183832e2064c20fe /caffe2 | |
parent | 6095814229b2354d4445bd7083e7d13d37f772aa (diff) | |
download | pytorch-9abc8a5b47d116342a5c277c62cfce81fd9dd331.tar.gz pytorch-9abc8a5b47d116342a5c277c62cfce81fd9dd331.tar.bz2 pytorch-9abc8a5b47d116342a5c277c62cfce81fd9dd331.zip |
New operator registration MVP (#18161)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18161
This introduces version 0 for the new operator registration.
For now, it only works with kernels that are defined as stack-based functions.
This is not the intended public API for defining kernels, but it is the basis on which the public APIs will be defined (see the diffs stacked on top of this one),
and it is also the API used for exposing caffe2 operators.
This diff also switches the mechanism for exposing caffe2 operators to the new mechanism.
Reviewed By: dzhulgakov
Differential Revision: D14514231
fbshipit-source-id: 454ab7b5b46a10203aa27b175400d23f818dd1df
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/core/c10_operator.h | 107 |
1 file changed, 58 insertions, 49 deletions
diff --git a/caffe2/core/c10_operator.h b/caffe2/core/c10_operator.h index e23259b291..240a16be25 100644 --- a/caffe2/core/c10_operator.h +++ b/caffe2/core/c10_operator.h @@ -1,9 +1,8 @@ #pragma once -#include <vector> -#include <ATen/core/dispatch/OpSchemaRegistration.h> -#include <ATen/core/dispatch/KernelRegistration.h> #include <ATen/core/function_schema.h> +#include <ATen/core/op_registration/op_registration.h> +#include <vector> namespace caffe2 { namespace detail { @@ -80,12 +79,11 @@ inline void _call_caffe2_op_from_c10( // might reuse one of the preallocated tensors but doesn't have to. } -template <const c10::OperatorHandle& (*OpHandle)(), class Caffe2Operator> +template <const c10::FunctionSchema& (*Schema)(), class Caffe2Operator> void call_caffe2_op_from_c10( c10::Stack* stack, c10::KernelCache* cache) { // TODO Pass in correct cache type - _call_caffe2_op_from_c10( - stack, OpHandle().schema(), &_call_caffe2_op<Caffe2Operator>); + _call_caffe2_op_from_c10(stack, Schema(), &_call_caffe2_op<Caffe2Operator>); } inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName, std::vector<c10::Argument> inputs, std::vector<c10::Argument> outputs) { @@ -105,6 +103,9 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName std::move(outputs)); } +inline std::unique_ptr<c10::KernelCache> noCache() { + return nullptr; +} } } @@ -154,56 +155,64 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName * input an input of type TensorList. There must be no other tensor inputs. 
*/ #ifndef C10_MOBILE -#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName) \ - namespace caffe2 { \ - namespace _c10_ops { \ - C10_DECLARE_OP_SCHEMA(OperatorName); \ - } \ +#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName) \ + namespace caffe2 { \ + namespace _c10_ops { \ + CAFFE2_API const ::c10::FunctionSchema& schema_##OperatorName(); \ + } \ } // TODO This macro should take a JIT schema string instead of a vector of inputs and outputs. -#define C10_REGISTER_CAFFE2_OPERATOR_CPU( \ - OperatorName, Inputs, Outputs, OperatorClass) \ - /* Register the op schema with the c10 dispatcher */ \ - namespace caffe2 { \ - namespace _c10_ops { \ - C10_DEFINE_OP_SCHEMA( \ - OperatorName, \ - caffe2::detail::make_function_schema_for_c10( \ - #OperatorName, \ - Inputs, \ - Outputs)); \ - } \ - } \ - /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \ - namespace c10 { \ - C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache<Cache>()*/ \ - .kernel<&caffe2::detail::call_caffe2_op_from_c10< \ - ::caffe2::_c10_ops::OperatorName, \ - OperatorClass>>() \ - .dispatchKey(CPUTensorId()); \ - } - -#define C10_REGISTER_CAFFE2_OPERATOR_CUDA(OperatorName, OperatorClass) \ - namespace c10 { \ - C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache<Cache>()*/ \ - .kernel<&caffe2::detail::call_caffe2_op_from_c10< \ - ::caffe2::_c10_ops::OperatorName, \ - OperatorClass>>() \ - .dispatchKey(CUDATensorId()); \ - } +#define C10_REGISTER_CAFFE2_OPERATOR_CPU( \ + OperatorName, Inputs, Outputs, OperatorClass) \ + /* Register the op schema with the c10 dispatcher */ \ + namespace caffe2 { \ + namespace _c10_ops { \ + C10_EXPORT const ::c10::FunctionSchema& schema_##OperatorName() { \ + static ::c10::FunctionSchema schema = \ + ::caffe2::detail::make_function_schema_for_c10( \ + #OperatorName, Inputs, Outputs); \ + return schema; \ + } \ + } \ + } \ + /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \ + static auto 
registry_##OperatorName##_##__COUNTER__ = \ + ::c10::RegisterOperators().op( \ + ::caffe2::_c10_ops::schema_##OperatorName(), \ + ::c10::kernel( \ + &::caffe2::detail::call_caffe2_op_from_c10< \ + ::caffe2::_c10_ops::schema_##OperatorName, \ + OperatorClass>, \ + &::caffe2::detail::noCache), \ + ::c10::dispatchKey(::c10::CPUTensorId())); + +#define C10_REGISTER_CAFFE2_OPERATOR_CUDA(OperatorName, OperatorClass) \ + /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \ + static auto registry_##OperatorName##_##__COUNTER__ = \ + ::c10::RegisterOperators().op( \ + ::caffe2::_c10_ops::schema_##OperatorName(), \ + ::c10::kernel( \ + &::caffe2::detail::call_caffe2_op_from_c10< \ + ::caffe2::_c10_ops::schema_##OperatorName, \ + OperatorClass>, \ + &::caffe2::detail::noCache), \ + ::c10::dispatchKey(::c10::CUDATensorId())); // You should never manually call the C10_REGISTER_CAFFE2_OPERATOR_HIP macro. // The C10_REGISTER_CAFFE2_OPERATOR_CUDA macro from above will be automatically // rewritten to C10_REGISTER_CAFFE2_OPERATOR_HIP by hipify. -#define C10_REGISTER_CAFFE2_OPERATOR_HIP(OperatorName, OperatorClass) \ - namespace c10 { \ - C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache<Cache>()*/ \ - .kernel<&caffe2::detail::call_caffe2_op_from_c10< \ - ::caffe2::_c10_ops::OperatorName, \ - OperatorClass>>() \ - .dispatchKey(HIPTensorId()); \ - } +#define C10_REGISTER_CAFFE2_OPERATOR_HIP(OperatorName, OperatorClass) \ + /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \ + static auto registry_##OperatorName##_##__COUNTER__ = \ + ::c10::RegisterOperators().op( \ + ::caffe2::_c10_ops::schema_##OperatorName(), \ + ::c10::kernel( \ + &::caffe2::detail::call_caffe2_op_from_c10< \ + ::caffe2::_c10_ops::schema_##OperatorName, \ + OperatorClass>, \ + &::caffe2::detail::noCache), \ + ::c10::dispatchKey(::c10::HIPTensorId())); #else // Don't use c10 dispatcher on mobile because of binary size |