diff options
author | Sebastian Messmer <messmer@fb.com> | 2019-01-22 13:21:38 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-22 13:29:11 -0800 |
commit | cd8f4154f41a88ccd4f55ce72af2036afbf826a7 (patch) | |
tree | d5b1330f11a0fe85fd4918a5b15f267431e031fd /aten | |
parent | 6192831b7655bce09b70e31c1e34ddaf6d9f146c (diff) | |
download | pytorch-cd8f4154f41a88ccd4f55ce72af2036afbf826a7.tar.gz pytorch-cd8f4154f41a88ccd4f55ce72af2036afbf826a7.tar.bz2 pytorch-cd8f4154f41a88ccd4f55ce72af2036afbf826a7.zip |
Avoid closure around kernel (#16165)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16165
Store kernels as direct function pointers instead of std::function.
Using direct function pointers avoids a performance risk std::function would introduce.
Reviewed By: ezyang
Differential Revision: D13738627
fbshipit-source-id: a348906c8a201436699681980a82ca95065a06a0
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/core/dispatch/DispatchTable.h | 28 | ||||
-rw-r--r-- | aten/src/ATen/core/dispatch/KernelRegistration.h | 16 | ||||
-rw-r--r-- | aten/src/ATen/core/dispatch/OpSchema.h | 58 |
3 files changed, 44 insertions, 58 deletions
diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index 39c850a14f..cc7ef8a8d2 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -21,8 +21,8 @@ template <class Key> class ThreadsafeOperatorTable_ final { public: template <class Key_> - void emplace(Key_&& key, KernelFunction value) { - bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> bool { + void emplace(Key_&& key, KernelFunction* value) { + bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction*>& map) -> bool { auto result = map.emplace(std::forward<Key>(key), std::move(value)); return result.second; }); @@ -35,7 +35,7 @@ class ThreadsafeOperatorTable_ final { void erase(const Key& key) { auto num_removed = - map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> size_t { + map_.write([&](ska::flat_hash_map<Key, KernelFunction*>& map) -> size_t { return map.erase(key); }); assert(num_removed <= 1); // This is not a multi-map @@ -45,11 +45,11 @@ class ThreadsafeOperatorTable_ final { } } - const KernelFunction* lookup(const Key& key) const { - return map_.read([&](const ska::flat_hash_map<Key, KernelFunction>& map) -> const KernelFunction* { + KernelFunction* lookup(const Key& key) const { + return map_.read([&](const ska::flat_hash_map<Key, KernelFunction*>& map) -> KernelFunction* { auto found = map.find(key); if (found != map.end()) { - return &found->second; + return found->second; } else { return nullptr; } @@ -57,7 +57,7 @@ class ThreadsafeOperatorTable_ final { } private: - LeftRight<ska::flat_hash_map<Key, KernelFunction>> map_; + LeftRight<ska::flat_hash_map<Key, KernelFunction*>> map_; }; } // namespace details @@ -87,9 +87,9 @@ class DispatchTable final { * @param dispatch_key Dispatch key to define when this kernel is selected */ void registerKernel( - KernelFunction func, + KernelFunction* func, typename Schema::dispatch::dispatch_key_type dispatch_key) { - kernels_.emplace(std::move(dispatch_key), std::move(func)); + kernels_.emplace(std::move(dispatch_key), func); } /** @@ -118,14 +118,14 @@ class DispatchTable final { // static_assert(std::is_same<typename Schema::return_type (Args...), // typename Schema::func_type>::value, "Argument types don't match // operator signature"); - const auto& kernel_func = lookupKernelFunc_(args); - return kernel_func(args); + KernelFunction* kernel_func = lookupKernelFunc_(args); + return (*kernel_func)(args); } private: - const KernelFunction& lookupKernelFunc_(ArrayRef<IValue> args) const { + KernelFunction* lookupKernelFunc_(ArrayRef<IValue> args) const { auto dispatch_key = Schema::dispatch::dispatch_key(args); - const KernelFunction* found = kernels_.lookup(dispatch_key); + KernelFunction* found = kernels_.lookup(dispatch_key); if (found == nullptr) { // TODO Better error message - include op name and dispatch key (i.e. // argument types) @@ -133,7 +133,7 @@ class DispatchTable final { std::string() + "Didn't find kernel to dispatch to for operator '" + Schema::metadata::name() + "'"); } - return *found; + return found; } details::ThreadsafeOperatorTable_< diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.h b/aten/src/ATen/core/dispatch/KernelRegistration.h index 13aa17ac52..935b6c00c7 100644 --- a/aten/src/ATen/core/dispatch/KernelRegistration.h +++ b/aten/src/ATen/core/dispatch/KernelRegistration.h @@ -34,7 +34,7 @@ public: * @param kernel The concrete function implementation to register * @param dispatch_key The dispatch key to register the function to */ - KernelRegistrar(KernelFunction kernel, typename Schema::dispatch::dispatch_key_type dispatch_key) + KernelRegistrar(KernelFunction* kernel, typename Schema::dispatch::dispatch_key_type dispatch_key) : dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { Dispatcher<OpSchemaDef>::registerKernel(kernel, dispatch_key_); } @@ -88,7 +88,7 @@ private: static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0; static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1; - c10::optional<KernelFunction> kernel_; + c10::optional<KernelFunction*> kernel_; c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_; public: @@ -96,7 +96,7 @@ private: : KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {} KernelRegistrationBuilder( - c10::optional<KernelFunction> kernel, + c10::optional<KernelFunction*> kernel, c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key) : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {} @@ -116,9 +116,10 @@ private: * @param kernel concrete function implementation to be registered * @return "this" for method chaining */ - KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(KernelFunction kernel_func) && { + template<KernelFunction* kernel_func> + KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel() && { static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration"); - return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(std::move(kernel_func), std::move(dispatch_key_)); + return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(kernel_func, std::move(dispatch_key_)); } /** @@ -126,8 +127,9 @@ private: * @param kernel concrete function implementation to be registered * @return "this" for method chaining */ - KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(typename Schema::signature::func_type* kernel_func) && { - return std::move(*this).kernel(Schema::signature::wrap_kernel(kernel_func)); + template<typename Schema::signature::func_type* kernel_func> + KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel() && { + return std::move(*this).template kernel<&Schema::signature::template wrap_kernel<kernel_func>>(); } /** diff --git a/aten/src/ATen/core/dispatch/OpSchema.h b/aten/src/ATen/core/dispatch/OpSchema.h index db6f3e722d..a386b8d5e8 100644 --- a/aten/src/ATen/core/dispatch/OpSchema.h +++ b/aten/src/ATen/core/dispatch/OpSchema.h @@ -10,8 +10,7 @@ namespace c10 { -// TODO Use folly::Function for perf -using KernelFunction = std::function<IValue(ArrayRef<IValue>)>; +using KernelFunction = IValue(ArrayRef<IValue>); namespace details { @@ -128,50 +127,33 @@ struct ivalue_to_arg_type<ArrayRef<T>> { } }; -template<class ReturnType, class ParamTypes, class FuncType> struct _wrapKernel {}; -template<class ReturnType, class... ParamTypes, class FuncType> struct _wrapKernel<ReturnType, guts::typelist::typelist<ParamTypes...>, FuncType> { +template<class ReturnType, class ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel {}; +template<class ReturnType, class... ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel<ReturnType, guts::typelist::typelist<ParamTypes...>, FuncType, kernel> { using parameter_types = guts::typelist::typelist<ParamTypes...>; template<size_t... indices> - static KernelFunction call(FuncType* kernel, guts::index_sequence<indices...>) { - return [kernel] (ArrayRef<IValue> args) -> IValue { - if (args.size() != sizeof...(ParamTypes)) { - throw std::runtime_error("Wrong number of arguments for operator call"); - } - return return_type_to_ivalue( - (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...) - ); - }; + static IValue call(ArrayRef<IValue> args, guts::index_sequence<indices...>) { + if (args.size() != sizeof...(ParamTypes)) { + throw std::runtime_error("Wrong number of arguments for operator call"); + } + return return_type_to_ivalue( + (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...) + ); } }; -template<class... ParamTypes, class FuncType> struct _wrapKernel<void, guts::typelist::typelist<ParamTypes...>, FuncType> { +template<class... ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel<void, guts::typelist::typelist<ParamTypes...>, FuncType, kernel> { using parameter_types = guts::typelist::typelist<ParamTypes...>; template<size_t... indices> - static KernelFunction call(FuncType* kernel, guts::index_sequence<indices...>) { - return [kernel] (ArrayRef<IValue> args) -> IValue { - if (args.size() != sizeof...(ParamTypes)) { - throw std::runtime_error("Wrong number of arguments for operator call"); - } - (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...); - return IValue(); - }; + static IValue call(ArrayRef<IValue> args, guts::index_sequence<indices...>) { + if (args.size() != sizeof...(ParamTypes)) { + throw std::runtime_error("Wrong number of arguments for operator call"); + } + (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...); + return IValue(); } }; -template<class SignatureTraits> -KernelFunction wrapKernel(typename SignatureTraits::func_type* kernel) { - using return_type = typename SignatureTraits::return_type; - using parameter_types = typename SignatureTraits::parameter_types; - using func_type = typename SignatureTraits::func_type; - constexpr size_t num_parameters = guts::typelist::size<parameter_types>::value; - - return _wrapKernel<return_type, parameter_types, func_type>::call( - kernel, - guts::make_index_sequence<num_parameters>() - ); -} - /** * Wrapper class around a user-provided schema definition some useful information about the schema. * @@ -207,8 +189,10 @@ public: static constexpr size_t num_outputs = OpSchemaDef::num_outputs(); - static KernelFunction wrap_kernel(func_type* kernel) { - return details::wrapKernel<signature_traits>(kernel); + template<func_type* kernel> + static IValue wrap_kernel(ArrayRef<IValue> args) { + constexpr size_t num_parameters = guts::typelist::size<parameter_types>::value; + return details::_wrapKernel<return_type, parameter_types, func_type, kernel>::call(args, guts::make_index_sequence<num_parameters>()); } private: |