summary | refs | log | tree | commit | diff
path: root/aten
diff options
context:
space:
mode:
author: Sebastian Messmer <messmer@fb.com>, 2019-01-22 13:21:38 -0800
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>, 2019-01-22 13:29:11 -0800
commit: cd8f4154f41a88ccd4f55ce72af2036afbf826a7 (patch)
tree: d5b1330f11a0fe85fd4918a5b15f267431e031fd /aten
parent: 6192831b7655bce09b70e31c1e34ddaf6d9f146c (diff)
download: pytorch-cd8f4154f41a88ccd4f55ce72af2036afbf826a7.tar.gz
pytorch-cd8f4154f41a88ccd4f55ce72af2036afbf826a7.tar.bz2
pytorch-cd8f4154f41a88ccd4f55ce72af2036afbf826a7.zip
Avoid closure around kernel (#16165)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16165

Store kernels as direct function pointers instead of std::function.
Using direct function pointers avoids a performance risk std::function would introduce.

Reviewed By: ezyang

Differential Revision: D13738627

fbshipit-source-id: a348906c8a201436699681980a82ca95065a06a0
Diffstat (limited to 'aten')
-rw-r--r-- aten/src/ATen/core/dispatch/DispatchTable.h | 28
-rw-r--r-- aten/src/ATen/core/dispatch/KernelRegistration.h | 16
-rw-r--r-- aten/src/ATen/core/dispatch/OpSchema.h | 58
3 files changed, 44 insertions, 58 deletions
diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h
index 39c850a14f..cc7ef8a8d2 100644
--- a/aten/src/ATen/core/dispatch/DispatchTable.h
+++ b/aten/src/ATen/core/dispatch/DispatchTable.h
@@ -21,8 +21,8 @@ template <class Key>
class ThreadsafeOperatorTable_ final {
public:
template <class Key_>
- void emplace(Key_&& key, KernelFunction value) {
- bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> bool {
+ void emplace(Key_&& key, KernelFunction* value) {
+ bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction*>& map) -> bool {
auto result = map.emplace(std::forward<Key>(key), std::move(value));
return result.second;
});
@@ -35,7 +35,7 @@ class ThreadsafeOperatorTable_ final {
void erase(const Key& key) {
auto num_removed =
- map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> size_t {
+ map_.write([&](ska::flat_hash_map<Key, KernelFunction*>& map) -> size_t {
return map.erase(key);
});
assert(num_removed <= 1); // This is not a multi-map
@@ -45,11 +45,11 @@ class ThreadsafeOperatorTable_ final {
}
}
- const KernelFunction* lookup(const Key& key) const {
- return map_.read([&](const ska::flat_hash_map<Key, KernelFunction>& map) -> const KernelFunction* {
+ KernelFunction* lookup(const Key& key) const {
+ return map_.read([&](const ska::flat_hash_map<Key, KernelFunction*>& map) -> KernelFunction* {
auto found = map.find(key);
if (found != map.end()) {
- return &found->second;
+ return found->second;
} else {
return nullptr;
}
@@ -57,7 +57,7 @@ class ThreadsafeOperatorTable_ final {
}
private:
- LeftRight<ska::flat_hash_map<Key, KernelFunction>> map_;
+ LeftRight<ska::flat_hash_map<Key, KernelFunction*>> map_;
};
} // namespace details
@@ -87,9 +87,9 @@ class DispatchTable final {
* @param dispatch_key Dispatch key to define when this kernel is selected
*/
void registerKernel(
- KernelFunction func,
+ KernelFunction* func,
typename Schema::dispatch::dispatch_key_type dispatch_key) {
- kernels_.emplace(std::move(dispatch_key), std::move(func));
+ kernels_.emplace(std::move(dispatch_key), func);
}
/**
@@ -118,14 +118,14 @@ class DispatchTable final {
// static_assert(std::is_same<typename Schema::return_type (Args...),
// typename Schema::func_type>::value, "Argument types don't match
// operator signature");
- const auto& kernel_func = lookupKernelFunc_(args);
- return kernel_func(args);
+ KernelFunction* kernel_func = lookupKernelFunc_(args);
+ return (*kernel_func)(args);
}
private:
- const KernelFunction& lookupKernelFunc_(ArrayRef<IValue> args) const {
+ KernelFunction* lookupKernelFunc_(ArrayRef<IValue> args) const {
auto dispatch_key = Schema::dispatch::dispatch_key(args);
- const KernelFunction* found = kernels_.lookup(dispatch_key);
+ KernelFunction* found = kernels_.lookup(dispatch_key);
if (found == nullptr) {
// TODO Better error message - include op name and dispatch key (i.e.
// argument types)
@@ -133,7 +133,7 @@ class DispatchTable final {
std::string() + "Didn't find kernel to dispatch to for operator '" +
Schema::metadata::name() + "'");
}
- return *found;
+ return found;
}
details::ThreadsafeOperatorTable_<
diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.h b/aten/src/ATen/core/dispatch/KernelRegistration.h
index 13aa17ac52..935b6c00c7 100644
--- a/aten/src/ATen/core/dispatch/KernelRegistration.h
+++ b/aten/src/ATen/core/dispatch/KernelRegistration.h
@@ -34,7 +34,7 @@ public:
* @param kernel The concrete function implementation to register
* @param dispatch_key The dispatch key to register the function to
*/
- KernelRegistrar(KernelFunction kernel, typename Schema::dispatch::dispatch_key_type dispatch_key)
+ KernelRegistrar(KernelFunction* kernel, typename Schema::dispatch::dispatch_key_type dispatch_key)
: dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
Dispatcher<OpSchemaDef>::registerKernel(kernel, dispatch_key_);
}
@@ -88,7 +88,7 @@ private:
static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0;
static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1;
- c10::optional<KernelFunction> kernel_;
+ c10::optional<KernelFunction*> kernel_;
c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
public:
@@ -96,7 +96,7 @@ private:
: KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {}
KernelRegistrationBuilder(
- c10::optional<KernelFunction> kernel,
+ c10::optional<KernelFunction*> kernel,
c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key)
: kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {}
@@ -116,9 +116,10 @@ private:
* @param kernel concrete function implementation to be registered
* @return "this" for method chaining
*/
- KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(KernelFunction kernel_func) && {
+ template<KernelFunction* kernel_func>
+ KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration");
- return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(std::move(kernel_func), std::move(dispatch_key_));
+ return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(kernel_func, std::move(dispatch_key_));
}
/**
@@ -126,8 +127,9 @@ private:
* @param kernel concrete function implementation to be registered
* @return "this" for method chaining
*/
- KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(typename Schema::signature::func_type* kernel_func) && {
- return std::move(*this).kernel(Schema::signature::wrap_kernel(kernel_func));
+ template<typename Schema::signature::func_type* kernel_func>
+ KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
+ return std::move(*this).template kernel<&Schema::signature::template wrap_kernel<kernel_func>>();
}
/**
diff --git a/aten/src/ATen/core/dispatch/OpSchema.h b/aten/src/ATen/core/dispatch/OpSchema.h
index db6f3e722d..a386b8d5e8 100644
--- a/aten/src/ATen/core/dispatch/OpSchema.h
+++ b/aten/src/ATen/core/dispatch/OpSchema.h
@@ -10,8 +10,7 @@
namespace c10 {
-// TODO Use folly::Function for perf
-using KernelFunction = std::function<IValue(ArrayRef<IValue>)>;
+using KernelFunction = IValue(ArrayRef<IValue>);
namespace details {
@@ -128,50 +127,33 @@ struct ivalue_to_arg_type<ArrayRef<T>> {
}
};
-template<class ReturnType, class ParamTypes, class FuncType> struct _wrapKernel {};
-template<class ReturnType, class... ParamTypes, class FuncType> struct _wrapKernel<ReturnType, guts::typelist::typelist<ParamTypes...>, FuncType> {
+template<class ReturnType, class ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel {};
+template<class ReturnType, class... ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel<ReturnType, guts::typelist::typelist<ParamTypes...>, FuncType, kernel> {
using parameter_types = guts::typelist::typelist<ParamTypes...>;
template<size_t... indices>
- static KernelFunction call(FuncType* kernel, guts::index_sequence<indices...>) {
- return [kernel] (ArrayRef<IValue> args) -> IValue {
- if (args.size() != sizeof...(ParamTypes)) {
- throw std::runtime_error("Wrong number of arguments for operator call");
- }
- return return_type_to_ivalue(
- (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...)
- );
- };
+ static IValue call(ArrayRef<IValue> args, guts::index_sequence<indices...>) {
+ if (args.size() != sizeof...(ParamTypes)) {
+ throw std::runtime_error("Wrong number of arguments for operator call");
+ }
+ return return_type_to_ivalue(
+ (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...)
+ );
}
};
-template<class... ParamTypes, class FuncType> struct _wrapKernel<void, guts::typelist::typelist<ParamTypes...>, FuncType> {
+template<class... ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel<void, guts::typelist::typelist<ParamTypes...>, FuncType, kernel> {
using parameter_types = guts::typelist::typelist<ParamTypes...>;
template<size_t... indices>
- static KernelFunction call(FuncType* kernel, guts::index_sequence<indices...>) {
- return [kernel] (ArrayRef<IValue> args) -> IValue {
- if (args.size() != sizeof...(ParamTypes)) {
- throw std::runtime_error("Wrong number of arguments for operator call");
- }
- (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...);
- return IValue();
- };
+ static IValue call(ArrayRef<IValue> args, guts::index_sequence<indices...>) {
+ if (args.size() != sizeof...(ParamTypes)) {
+ throw std::runtime_error("Wrong number of arguments for operator call");
+ }
+ (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...);
+ return IValue();
}
};
-template<class SignatureTraits>
-KernelFunction wrapKernel(typename SignatureTraits::func_type* kernel) {
- using return_type = typename SignatureTraits::return_type;
- using parameter_types = typename SignatureTraits::parameter_types;
- using func_type = typename SignatureTraits::func_type;
- constexpr size_t num_parameters = guts::typelist::size<parameter_types>::value;
-
- return _wrapKernel<return_type, parameter_types, func_type>::call(
- kernel,
- guts::make_index_sequence<num_parameters>()
- );
-}
-
/**
* Wrapper class around a user-provided schema definition some useful information about the schema.
*
@@ -207,8 +189,10 @@ public:
static constexpr size_t num_outputs = OpSchemaDef::num_outputs();
- static KernelFunction wrap_kernel(func_type* kernel) {
- return details::wrapKernel<signature_traits>(kernel);
+ template<func_type* kernel>
+ static IValue wrap_kernel(ArrayRef<IValue> args) {
+ constexpr size_t num_parameters = guts::typelist::size<parameter_types>::value;
+ return details::_wrapKernel<return_type, parameter_types, func_type, kernel>::call(args, guts::make_index_sequence<num_parameters>());
}
private: