summaryrefslogtreecommitdiff
path: root/aten
diff options
context:
space:
mode:
authorSebastian Messmer <messmer@fb.com>2019-02-01 12:44:55 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-01 13:52:01 -0800
commitaaa8ace48642a1a5774332084161e0f93f171e1f (patch)
tree4f468a87d962d1c83710d618e563fc9515286175 /aten
parenta40e8ce7c553d61024fdf4f8f2b7b13ff606e77b (diff)
downloadpytorch-aaa8ace48642a1a5774332084161e0f93f171e1f.tar.gz
pytorch-aaa8ace48642a1a5774332084161e0f93f171e1f.tar.bz2
pytorch-aaa8ace48642a1a5774332084161e0f93f171e1f.zip
Implement new c10 dispatcher (#16625)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16625 This is a squash of multiple PRs that refactored the old c10 dispatcher into a new one that follows the c10 dispatcher design doc. It is now unboxed and follows the Stack semantics from JIT. It also uses the runtime JIT schema instead of its own compile time schema definitions. Reviewed By: ezyang Differential Revision: D13907069 fbshipit-source-id: edcc4806ccd21474fdfb5a98516219b1956db13d
Diffstat (limited to 'aten')
-rw-r--r--aten/src/ATen/core/dispatch/DeviceId.cpp1
-rw-r--r--aten/src/ATen/core/dispatch/DeviceId.h36
-rw-r--r--aten/src/ATen/core/dispatch/DispatchKey.cpp1
-rw-r--r--aten/src/ATen/core/dispatch/DispatchKey.h97
-rw-r--r--aten/src/ATen/core/dispatch/DispatchTable.cpp1
-rw-r--r--aten/src/ATen/core/dispatch/DispatchTable.h159
-rw-r--r--aten/src/ATen/core/dispatch/Dispatcher.cpp7
-rw-r--r--aten/src/ATen/core/dispatch/Dispatcher.h143
-rw-r--r--aten/src/ATen/core/dispatch/KernelCache.h20
-rw-r--r--aten/src/ATen/core/dispatch/KernelFunction.h16
-rw-r--r--aten/src/ATen/core/dispatch/KernelRegistration.cpp1
-rw-r--r--aten/src/ATen/core/dispatch/KernelRegistration.h213
-rw-r--r--aten/src/ATen/core/dispatch/LayoutId.cpp1
-rw-r--r--aten/src/ATen/core/dispatch/LayoutId.h22
-rw-r--r--aten/src/ATen/core/dispatch/OpSchema.cpp1
-rw-r--r--aten/src/ATen/core/dispatch/OpSchema.h406
-rw-r--r--aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp1
-rw-r--r--aten/src/ATen/core/dispatch/OpSchemaRegistration.h39
-rw-r--r--aten/src/ATen/core/dispatch/OpSchema_test.cpp29
-rw-r--r--aten/src/ATen/core/dispatch/README.md12
-rw-r--r--aten/src/ATen/core/jit_type.h66
-rw-r--r--aten/src/ATen/core/opschema/layer_norm.cpp23
-rw-r--r--aten/src/ATen/core/opschema/layer_norm.h27
-rw-r--r--aten/src/ATen/core/stack.h3
24 files changed, 524 insertions, 801 deletions
diff --git a/aten/src/ATen/core/dispatch/DeviceId.cpp b/aten/src/ATen/core/dispatch/DeviceId.cpp
deleted file mode 100644
index 7c65fbe70c..0000000000
--- a/aten/src/ATen/core/dispatch/DeviceId.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/DeviceId.h>
diff --git a/aten/src/ATen/core/dispatch/DeviceId.h b/aten/src/ATen/core/dispatch/DeviceId.h
deleted file mode 100644
index cdcaf4b635..0000000000
--- a/aten/src/ATen/core/dispatch/DeviceId.h
+++ /dev/null
@@ -1,36 +0,0 @@
-#pragma once
-
-#include <c10/util/C++17.h>
-#include <functional>
-#include <iostream>
-
-namespace c10 {
-
-enum class DeviceTypeId : uint8_t {
- // Don't use the int values here in the enum (i.e. don't do static_cast to or from int).
- // Instead, if you want to serialize this, write a function with switch/case.
- CPU = 0,
- CUDA = 1,
- UNDEFINED
-};
-
-inline std::ostream& operator<<(std::ostream& stream, DeviceTypeId device_type_id) {
- switch(device_type_id) {
- case c10::DeviceTypeId::CPU: return stream << "DeviceTypeId(CPU)";
- case c10::DeviceTypeId::CUDA: return stream << "DeviceTypeId(CUDA)";
- case c10::DeviceTypeId::UNDEFINED: return stream << "DeviceTypeId(UNDEFINED)";
- }
- throw std::logic_error("Unknown DeviceTypeId: " + c10::guts::to_string(static_cast<int>(device_type_id)));
-}
-
-}
-
-namespace std {
-
-template <> struct hash<c10::DeviceTypeId> {
- size_t operator()(c10::DeviceTypeId v) const {
- return std::hash<uint8_t>()(static_cast<uint8_t>(v));
- }
-};
-
-}
diff --git a/aten/src/ATen/core/dispatch/DispatchKey.cpp b/aten/src/ATen/core/dispatch/DispatchKey.cpp
deleted file mode 100644
index 33e8e29b28..0000000000
--- a/aten/src/ATen/core/dispatch/DispatchKey.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/DispatchKey.h>
diff --git a/aten/src/ATen/core/dispatch/DispatchKey.h b/aten/src/ATen/core/dispatch/DispatchKey.h
deleted file mode 100644
index 7c6cf15f61..0000000000
--- a/aten/src/ATen/core/dispatch/DispatchKey.h
+++ /dev/null
@@ -1,97 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/DeviceId.h>
-#include <ATen/core/dispatch/LayoutId.h>
-#include <c10/util/typeid.h>
-
-#include <vector>
-#include <functional>
-#include <sstream>
-#include <c10/util/Array.h>
-
-namespace c10 {
-
-namespace details {
-
-struct TensorParameterDispatchKey final {
- // note: This dispatch key structure is not final yet and will change. Don't rely on it.
- DeviceTypeId deviceTypeId;
- LayoutId layoutId;
- // TODO Move this TypeIdentifier to c10 namespace
- caffe2::TypeIdentifier dataType;
-};
-inline constexpr bool operator==(const TensorParameterDispatchKey& lhs, const TensorParameterDispatchKey& rhs) {
- return lhs.deviceTypeId == rhs.deviceTypeId && lhs.layoutId == rhs.layoutId && lhs.dataType == rhs.dataType;
-}
-
-inline std::ostream& operator<<(std::ostream& stream, const TensorParameterDispatchKey& key) {
- return stream << "TensorKey(" << key.deviceTypeId << ", " << key.layoutId.value() << ", " << key.dataType << ")";
-}
-
-} // namespace details
-} // namespace c10
-
-namespace std {
- template<>
- struct hash<c10::details::TensorParameterDispatchKey> {
- // TODO constexpr hashing
- size_t operator()(const c10::details::TensorParameterDispatchKey& obj) const {
- return std::hash<c10::DeviceTypeId>()(obj.deviceTypeId) ^ std::hash<c10::LayoutId>()(obj.layoutId) ^ std::hash<caffe2::TypeIdentifier>()(obj.dataType);
- }
- };
-} // namespace std
-
-namespace c10 {
-/**
- * The dispatch key encodes the runtime type identity of a function call arguments,
- * specifying what aspects of this identity can be dynamically dispatched on.
- *
- * Intuitively, given a function signature like f(Tensor, int), a valid dispatch
- * key for the arguments might be [CPUFloatTensor] (notice that 'f' is NOT included
- * in the dispatch key, and the runtime type of 'int' is NOT considered for dispatch
- * (since it is trivial).
- *
- * Dispatch keys permit equality tests and are hashable.
- *
- * @tparam num_dispatch_args The number of dispatchable arguments
- */
-template<size_t num_dispatch_args>
-struct DispatchKey final {
- guts::array<details::TensorParameterDispatchKey, num_dispatch_args> argTypes;
-};
-
-template<size_t num_dispatch_args>
-inline constexpr bool operator==(const DispatchKey<num_dispatch_args> &lhs, const DispatchKey<num_dispatch_args>& rhs) {
- // TODO: Use AVX instructions to perform this equality test more quickly
- return lhs.argTypes == rhs.argTypes;
-}
-
-template<size_t num_dispatch_args>
-inline std::ostream& operator<<(std::ostream& stream, const DispatchKey<num_dispatch_args>& key) {
- stream << "DispatchKey(";
- if (num_dispatch_args > 0) {
- stream << "DispatchKey(" << key.argTypes[0];
- for (size_t i = 1; i < num_dispatch_args; ++i) {
- stream << ", " << key.argTypes[i];
- }
- stream << ")";
- }
- return stream << ")";
-}
-
-} // namespace c10
-
-namespace std {
- template<size_t num_dispatch_args>
- struct hash<c10::DispatchKey<num_dispatch_args>> {
- // TODO constexpr hashing
- size_t operator()(const c10::DispatchKey<num_dispatch_args>& obj) const {
- size_t hash_value = 0;
- for (const auto& argType : obj.argTypes) {
- hash_value *= 10883; // prime
- hash_value += std::hash<c10::details::TensorParameterDispatchKey>()(argType);
- }
- return hash_value;
- }
- };
-} // namespace std
diff --git a/aten/src/ATen/core/dispatch/DispatchTable.cpp b/aten/src/ATen/core/dispatch/DispatchTable.cpp
deleted file mode 100644
index 0bf7c5903d..0000000000
--- a/aten/src/ATen/core/dispatch/DispatchTable.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/DispatchTable.h>
diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h
index 22206658c5..c005314c27 100644
--- a/aten/src/ATen/core/dispatch/DispatchTable.h
+++ b/aten/src/ATen/core/dispatch/DispatchTable.h
@@ -1,28 +1,29 @@
#pragma once
-#include <ATen/core/dispatch/OpSchema.h>
+#include <ATen/core/function_schema.h>
#include <c10/util/LeftRight.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <ATen/core/ivalue.h>
+#include <ATen/core/dispatch/KernelFunction.h>
#include <array>
#include <atomic>
#include <iostream>
#include <mutex>
#include <type_traits>
+#include <sstream>
#include <unordered_map>
namespace c10 {
/**
* The type of a user-supplied function to initialize the kernel cache.
- * this is stored together with the kernel function in the dispatch table
+ * this is stored together with the KernelFunction in the DispatchTable
* so we can create a new cache instance when a kernel is looked up
* from the dispatch table.
*/
-using KernelStateCreatorFunction = std::unique_ptr<c10::KernelState> ();
-
+using KernelCacheCreatorFunction = std::unique_ptr<c10::KernelCache> ();
/**
* The dispatch table stores a pointer to a kernel function and a pointer
* to a function initializing a cache for the kernel. If the kernel wants
@@ -32,72 +33,100 @@ using KernelStateCreatorFunction = std::unique_ptr<c10::KernelState> ();
* this same cache instance.
*/
struct DispatchTableEntry final {
- KernelFunction* kernel_func;
- KernelStateCreatorFunction* state_creator_func;
+ /*not-nullable*/ KernelFunction* kernel_func;
+ /*not-nullable*/ KernelCacheCreatorFunction* cache_creator_func;
};
-namespace details {
+namespace detail {
/// Kernel implementations in a thread-safe hash table.
-template <class Key>
class ThreadsafeOperatorTable_ final {
public:
- template <class Key_>
- void emplace(Key_&& key, const DispatchTableEntry& value) {
- bool res = map_.write([&](ska::flat_hash_map<Key, DispatchTableEntry>& map) -> bool {
- auto result = map.emplace(std::forward<Key>(key), value);
+ void emplace(TensorTypeId key, const DispatchTableEntry& value, const std::string& operator_name) {
+ bool res = map_.write([&](ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> bool {
+ auto result = map.emplace(key, value);
return result.second;
});
if (!res) {
- AT_ERROR("Tried to register conflicting kernels to the dispatcher: ", key);
+ AT_ERROR("Tried to register multiple kernels with same dispatch key '",
+ dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
}
}
- void erase(const Key& key) {
- auto num_removed =
- map_.write([&](ska::flat_hash_map<Key, DispatchTableEntry>& map) -> size_t {
- return map.erase(key);
- });
- assert(num_removed <= 1); // This is not a multi-map
- if (num_removed == 0) {
- AT_ERROR("Tried to deregister a kernel that isn't registered.");
- }
+ void erase(TensorTypeId key, const std::string& operator_name) {
+ map_.write([&](ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) {
+ auto num_removed = map.erase(key);
+
+ assert(num_removed <= 1); // This is not a multi-map
+ if (num_removed == 0) {
+ AT_ERROR("Tried to deregister a kernel with dispatch key '",
+ dispatch_key_to_string(key), "' for operator '", operator_name,
+ "' but that kernel isn't registered. Registered dispatch keys are: ",
+ list_all_dispatch_keys(map));
+ }
+ });
}
- const DispatchTableEntry* lookup(const Key& key) const {
- return map_.read([&](const ska::flat_hash_map<Key, DispatchTableEntry>& map) -> const DispatchTableEntry* {
+ const DispatchTableEntry* lookup(TensorTypeId key, const string& operator_name) const {
+ return map_.read([&](const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> const DispatchTableEntry* {
auto found = map.find(key);
if (found != map.end()) {
return &found->second;
} else {
- return nullptr;
+ AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name,
+ "'. Tried to look up kernel for dispatch key '", dispatch_key_to_string(key),
+ "'. Registered dispatch keys are: ", list_all_dispatch_keys(map));
}
});
}
+ bool isEmpty() const {
+ return map_.read([&](const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> bool {
+ return map.size() == 0;
+ });
+ }
+
private:
- LeftRight<ska::flat_hash_map<Key, DispatchTableEntry>> map_;
+ static std::string list_all_dispatch_keys(const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) {
+ if (map.size() == 0) {
+ return "";
+ }
+ std::ostringstream str;
+ str << dispatch_key_to_string(map.begin()->first);
+ for (auto iter = ++map.begin(); iter != map.end(); ++iter) {
+ str << ", " << dispatch_key_to_string(iter->first);
+ }
+ return str.str();
+ }
+
+ static std::string dispatch_key_to_string(TensorTypeId id) {
+ return std::string(toString(tensorTypeIdToBackend(id))) + "[" + guts::to_string(id) + "]";
+ }
+
+ LeftRight<ska::flat_hash_map<TensorTypeId, DispatchTableEntry>> map_;
};
-} // namespace details
+} // namespace detail
/**
* Per-operator dispatch table.
*
- * Given an operator specified by 'OpSchemaDef', this class records a dispatch
+ * Given an operator specified by a FunctionSchema, this class records a dispatch
* table for various kernels provided for this operator. For example, if we
* consider the operator add(Tensor, Tensor), the dispatch table for this
* operator may contain implementations for various dynamic tensor types, such
- * as (CPUFloatTensor, CPUFloatTensor), (CUDAFloatTensor, CUDAFloatTensor), etc.
- *
- * @tparam OpSchemaDef The operator signature this dispatch table encodes.
+ * as CPUTensorId, CUDATensorId, etc.
*/
-// TODO: Support dispatch for meta-operators (which apply to all dynamic types)
-template <class OpSchemaDef>
class DispatchTable final {
- private:
- using Schema = OpSchema<OpSchemaDef>;
-
public:
- DispatchTable() : kernels_() {}
+ explicit DispatchTable(const FunctionSchema& schema)
+ : kernels_()
+ , reverse_index_of_first_tensor_arg_(
+ schema.arguments().size() - get_index_of_first_tensor_arg_(schema))
+ , operator_name_(schema.name()) {}
+
+ DispatchTable(DispatchTable&&) = default;
+ DispatchTable& operator=(DispatchTable&&) = default;
+ DispatchTable(const DispatchTable&) = delete;
+ DispatchTable& operator=(const DispatchTable&) = delete;
/**
* Register a kernel in the table at some dispatch key.
@@ -105,9 +134,9 @@ class DispatchTable final {
* @param dispatch_key Dispatch key to define when this kernel is selected
*/
void registerKernel(
- typename Schema::dispatch::dispatch_key_type dispatch_key,
+ TensorTypeId dispatch_key,
const DispatchTableEntry& kernel) {
- kernels_.emplace(std::move(dispatch_key), kernel);
+ kernels_.emplace(dispatch_key, kernel, operator_name_);
}
/**
@@ -118,9 +147,8 @@ class DispatchTable final {
// TODO: This isn't going to work so well when we get more complicated
// override patterns! In this case, an operator will show up in multiple
// slots, and erasing them one-by-one is probably not such a good idea.
- void deregisterKernel(
- const typename Schema::dispatch::dispatch_key_type& dispatch_key) {
- kernels_.erase(dispatch_key);
+ void deregisterKernel(TensorTypeId dispatch_key) {
+ kernels_.erase(dispatch_key, operator_name_);
}
/**
@@ -131,30 +159,39 @@ class DispatchTable final {
* @return Kernel function pointing to the right kernel for the given arguments
*/
const DispatchTableEntry& lookup(const Stack* stack) const {
- auto dispatch_key = Schema::dispatch::dispatch_key(stack);
- const DispatchTableEntry* found = kernels_.lookup(dispatch_key);
- if (found == nullptr) {
- // TODO Better error message - include op name and dispatch key (i.e.
- // argument types)
- AT_ERROR("Didn't find kernel to dispatch to for operator '", Schema::metadata::name(), "'");
- }
- return *found;
+ TensorTypeId dispatch_key = torch::jit::peek(
+ *stack,
+ 0,
+ reverse_index_of_first_tensor_arg_
+ ).toTensor().type_id();
+ return *kernels_.lookup(dispatch_key, operator_name_);
+ }
+
+ bool isEmpty() const {
+ return kernels_.isEmpty();
}
private:
+ static size_t get_index_of_first_tensor_arg_(const FunctionSchema& schema) {
+ for (size_t i = 0; i < schema.arguments().size(); ++i) {
+ if (schema.arguments()[i].type()->isSubtypeOf(DynamicType::get())) { // DynamicType means it's a tensor
+ return i;
+ }
+ }
+
+ AT_ERROR("Tried to create dispatch table for operator schema ", schema.name(), " that doesn't have tensor arguments.");
+ }
+ detail::ThreadsafeOperatorTable_ kernels_;
- details::ThreadsafeOperatorTable_<
- typename Schema::dispatch::dispatch_key_type>
- kernels_;
+ // this is caching the index so we don't have to parse the schema inputs
+ // again and again for each dispatcher lookup.
+ // reverse_index means this is the distance from the first tensor argument
+ // to argument_list.end(), i.e. from the top of the stack.
+ // Since it is distance to end(), this means it's 1-indexed,
+ // i.e. '1' is the last argument.
+ size_t reverse_index_of_first_tensor_arg_;
+ std::string operator_name_;
};
} // namespace c10
-
-/*
- * Use this to access the dispatch table singleton for a given op schema.
- * It has an implementation for each op schema def in a cpp file, because
- * we can't rely on the one-definition-rule.
- */
-template <class OpSchemaDef>
-C10_API c10::DispatchTable<OpSchemaDef>& c10_dispatch_table();
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp
index 9f76d8d969..832b687316 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.cpp
+++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -1 +1,8 @@
#include <ATen/core/dispatch/Dispatcher.h>
+
+namespace c10 {
+C10_EXPORT Dispatcher& Dispatcher::singleton() {
+ static Dispatcher _singleton;
+ return _singleton;
+}
+}
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index 1d845750ef..bcca890c16 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -1,9 +1,14 @@
#pragma once
#include <ATen/core/dispatch/DispatchTable.h>
+#include <c10/util/Exception.h>
+#include <mutex>
+#include <list>
namespace c10 {
+class CAFFE2_API OperatorHandle;
+
/**
* This class represents an operator kernel, i.e. an operator *after* it was
* dispatched to a certain device. You can use it to call the kernel.
@@ -14,12 +19,11 @@ namespace c10 {
* Also, keeping around the OpKernel instance will keep around a local cache
* that is used by some kernels to get better performance when they're called
* multiple times (mostly Caffe2 kernels do that).
+ *
+ * OpKernel is not threadsafe.
*/
-class OpKernel final {
+class CAFFE2_API OpKernel final {
public:
- explicit OpKernel(KernelFunction* kernel, KernelStateCreatorFunction* state_creator)
- : kernel_(kernel), state_creator_(state_creator) {}
-
OpKernel(OpKernel&&) = default;
OpKernel& operator=(OpKernel&&) = default;
OpKernel(const OpKernel&) = delete;
@@ -28,60 +32,137 @@ public:
/**
* Call the operator kernel with the given arguments.
*/
- void call(Stack* stack) {
- if (state_.get() == nullptr) {
- AT_ASSERT(state_creator_ != nullptr);
- state_ = (*state_creator_)();
+ void call(Stack* stack) const {
+ if (cache_.get() == nullptr) {
+ AT_ASSERT(cache_creator_ != nullptr);
+ cache_ = (*cache_creator_)();
}
- return (*kernel_)(stack, state_.get());
+ return (*kernel_)(stack, cache_.get());
}
private:
- // The kernel function is a global C function, not a std::function.
- // That is, ownership is not an issue.
+ explicit OpKernel(KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator)
+ : kernel_(kernel), cache_creator_(cache_creator) {}
+ friend class Dispatcher;
+
KernelFunction* kernel_;
- KernelStateCreatorFunction* state_creator_;
- std::unique_ptr<c10::KernelState> state_;
+ KernelCacheCreatorFunction* cache_creator_;
+ mutable std::unique_ptr<c10::KernelCache> cache_;
};
/**
* Top-level dispatch interface for dispatching via the dynamic dispatcher.
*/
-template<class OpSchemaDef>
-class Dispatcher final {
+class CAFFE2_API Dispatcher final {
private:
- using Schema = OpSchema<OpSchemaDef>;
+ struct OperatorDef final {
+ explicit OperatorDef(FunctionSchema schema_)
+ : dispatchTable(schema_)
+ , schema(std::move(schema_)) {}
+
+ DispatchTable dispatchTable;
+ FunctionSchema schema;
+ };
+ friend class OperatorHandle;
+
public:
// Implementation note: this class abstracts over the fact that we have per-operator
// dispatch tables. This could be easily adjusted to have a single global hash
// table.
+ static Dispatcher& singleton();
+
/**
- * Register an operator to the dispatch table for some operator schema.
+ * Register a new operator schema. The handle returned can be used to register
+ * kernels to this operator or to call it.
*/
- static void registerKernel(typename Schema::dispatch::dispatch_key_type dispatch_key, KernelFunction* kernel_func, KernelStateCreatorFunction* state_creator_func) {
- auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
- return dispatch_table_for_this_op.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, state_creator_func});
- }
+ OperatorHandle registerSchema(FunctionSchema schema);
/**
- * Remove an operator from the dispatch table for some operator schema.
+ * Remove an operator from the dispatcher. Make sure you removed
+ * all kernels for this operatorbefore calling this.
*/
- static void deregisterKernel(const typename Schema::dispatch::dispatch_key_type& dispatch_key) {
- auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
- return dispatch_table_for_this_op.deregisterKernel(dispatch_key);
- }
+ void deregisterSchema(const OperatorHandle& op);
/**
- * Perform a dynamic dispatch and get the kernel for an operator
+ * Register an operator to the dispatch table for an operator.
*/
- static OpKernel lookup(const Stack* stack) {
- auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
- const DispatchTableEntry& kernel = dispatch_table_for_this_op.lookup(stack);
- return OpKernel(kernel.kernel_func, kernel.state_creator_func);
+ void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func);
+
+ /**
+ * Remove an operator from the dispatch table for an operator.
+ */
+ void deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key);
+
+ /**
+ * Perform a dynamic dispatch and get the kernel for an operator.
+ */
+ OpKernel lookup(const OperatorHandle& op, const Stack* stack) const;
+
+private:
+ std::list<OperatorDef> operators_;
+ std::mutex mutex_;
+};
+
+/**
+ * This is a handle to an operator schema registered with the dispatcher.
+ * This handle can be used to register kernels with the dispatcher or
+ * to lookup a kernel for a certain set of arguments.
+ */
+class CAFFE2_API OperatorHandle final {
+public:
+ OperatorHandle(OperatorHandle&&) = default;
+ OperatorHandle& operator=(OperatorHandle&&) = default;
+ OperatorHandle(const OperatorHandle&) = default;
+ OperatorHandle& operator=(const OperatorHandle&) = default;
+
+ const FunctionSchema& schema() const {
+ return operatorDefIterator_->schema;
}
+private:
+ explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorDefIterator)
+ : operatorDefIterator_(std::move(operatorDefIterator)) {}
+ friend class Dispatcher;
+
+ std::list<Dispatcher::OperatorDef>::iterator operatorDefIterator_;
};
+
+
+inline OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) {
+ // we need a lock to avoid concurrent writes
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ operators_.emplace_back(std::move(schema));
+ return OperatorHandle(--operators_.end());
+}
+
+inline void Dispatcher::deregisterSchema(const OperatorHandle& op) {
+ // we need a lock to avoid concurrent writes
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
+ AT_ERROR("Tried to deregister op schema that still has kernels registered");
+ }
+ operators_.erase(op.operatorDefIterator_);
+}
+
+inline void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
+ // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+ op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, cache_creator_func});
+}
+
+inline void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) {
+ // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+ op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key);
+}
+
+inline OpKernel Dispatcher::lookup(const OperatorHandle& op, const Stack* stack) const {
+ // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+ const DispatchTableEntry& kernel = op.operatorDefIterator_->dispatchTable.lookup(stack);
+ return OpKernel(kernel.kernel_func, kernel.cache_creator_func);
+}
+
} // namespace c10
diff --git a/aten/src/ATen/core/dispatch/KernelCache.h b/aten/src/ATen/core/dispatch/KernelCache.h
new file mode 100644
index 0000000000..e879b9d09f
--- /dev/null
+++ b/aten/src/ATen/core/dispatch/KernelCache.h
@@ -0,0 +1,20 @@
+#pragma once
+
+namespace c10 {
+
+/**
+ * A kernel can keep around a cache to have better performance when it's
+ * called multiple times. This is used by a lot of caffe2 kernels, for example
+ * conv_op stores a set of tensors for intermediate values to avoid having
+ * to reallocate them on each call.
+ * This cache owned by the call site (i.e. stored inside the OpKernel object)
+ * kept at the call site to call into the kernel) and passed in to the kernel
+ * as a function argument. It must inherit from KernelCache so the call site
+ * knows how to store and destruct it.
+ */
+class CAFFE2_API KernelCache {
+public:
+ virtual ~KernelCache() = default;
+};
+
+}
diff --git a/aten/src/ATen/core/dispatch/KernelFunction.h b/aten/src/ATen/core/dispatch/KernelFunction.h
new file mode 100644
index 0000000000..67a296fa4e
--- /dev/null
+++ b/aten/src/ATen/core/dispatch/KernelFunction.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <ATen/core/dispatch/KernelCache.h>
+#include <ATen/core/stack.h>
+
+namespace c10 {
+
+using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
+
+/**
+ * This is the basic ABI for any kernel call. Each kernel is registered as a
+ * function pointer `KernelFunction*`, i.e. kernels are not allowed to be closures.
+ */
+using KernelFunction = void(Stack*, KernelCache* cache);
+
+}
diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.cpp b/aten/src/ATen/core/dispatch/KernelRegistration.cpp
deleted file mode 100644
index 0848300f5e..0000000000
--- a/aten/src/ATen/core/dispatch/KernelRegistration.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.h b/aten/src/ATen/core/dispatch/KernelRegistration.h
index df198c437d..eaab0c15e0 100644
--- a/aten/src/ATen/core/dispatch/KernelRegistration.h
+++ b/aten/src/ATen/core/dispatch/KernelRegistration.h
@@ -2,13 +2,20 @@
#include <c10/util/Optional.h>
#include <ATen/core/dispatch/Dispatcher.h>
-#include <ATen/core/dispatch/OpSchema.h>
/**
* To register your own kernel for an operator, do in one (!) cpp file:
- * C10_REGISTER_KERNEL(OpSchemaDef)
- * .kernel(&kernel_func)
+ * C10_REGISTER_KERNEL(OperatorHandle)
+ * .kernel<decltype(&kernel_func), &kernel_func>()
* .dispatchKey(dispatch_key);
+ *
+ * Example:
+ *
+ * Tensor my_kernel_cpu(Tensor in) {...}
+ *
+ * C10_REGISTER_KERNEL(MyOpSchema)
+ * .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>()
+ * .dispatchKey(CPUTensorId());
*/
namespace c10 {
@@ -17,30 +24,29 @@ namespace c10 {
// TODO Test no dispatch key defined
/**
- * Class which, on construction, registers an operator in the dispatch table. The intent is that
+ * Class which, on construction, registers an operator in the dispatch table. The intent is that
* this class is constructed at static initialization time so that operators automatically get
* registered when a dlopen() occurs.
*
- * You shouldn't call this directly; instead, use the KernelRegistrationBuilder
- *
- * @tparam OpSchemaDef
+ * You shouldn't call this directly; instead, use the C10_REGISTER_KERNEL macros.
*/
-template<class OpSchemaDef>
class KernelRegistrar final {
-private:
- using Schema = OpSchema<OpSchemaDef>;
public:
+ using OpHandleGetter = const OperatorHandle& ();
+
/**
- * @param kernel The concrete function implementation to register
+ * @param op The operator to register the kernel for
* @param dispatch_key The dispatch key to register the function to
+ * @param kernel The concrete function implementation to register
+ * @param cache_creator A function initializing the cache for the kernel
*/
- KernelRegistrar(typename Schema::dispatch::dispatch_key_type dispatch_key, KernelFunction* kernel, KernelStateCreatorFunction* state_creator)
- : dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
- Dispatcher<OpSchemaDef>::registerKernel(dispatch_key_, kernel, state_creator);
+ explicit KernelRegistrar(OpHandleGetter *op, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator)
+ : op_(std::move(op)), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
+ Dispatcher::singleton().registerKernel(op_(), dispatch_key_, kernel, cache_creator);
}
KernelRegistrar(KernelRegistrar&& rhs)
- : dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(true) {
+ : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(true) {
rhs.owns_registration_ = false;
}
@@ -49,17 +55,111 @@ public:
~KernelRegistrar() {
if (owns_registration_) {
- Dispatcher<OpSchemaDef>::deregisterKernel(dispatch_key_);
+ Dispatcher::singleton().deregisterKernel(op_(), dispatch_key_);
}
}
private:
- const typename Schema::dispatch::dispatch_key_type dispatch_key_;
+ OpHandleGetter *op_;
+ const TensorTypeId dispatch_key_;
bool owns_registration_;
C10_DISABLE_COPY_AND_ASSIGN(KernelRegistrar);
};
+namespace detail {
+// ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and
+// cast it to the type that should be passed to the kernel function.
+// Examples: If the IValue contains a plain type like an int, return that.
+// If the IValue contains an IntList, return it as ArrayRef<int>.
+template<class T>
+struct ivalue_to_arg_type {
+ static T call(const IValue& v) {
+ return std::move(v).to<T>();
+ }
+};
+template<class T>
+struct ivalue_to_arg_type<ArrayRef<T>> {
+ static ArrayRef<T> call(const IValue& v) {
+ return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
+ }
+};
+
+// call_with_ivalue_args: Take a function pointer and an ArrayRef<IValue>
+// containing the arguments to call the function pointer with, and call it.
+// The extra_args are appended as additional arguments at the end of the function call.
+// Example:
+// int myfunc(int a, ArrayRef<int> b, string c);
+// int main() {
+// std::vector<IValue> ivalue_args = {IValue(2), IntList::create(3, 4)};
+// call_with_ivalue_args<decltype(myfunc), &myfunc>(ivalue_args, "extra_arg");
+// }
+template<class FuncType, FuncType* func, class... ExtraArgs, size_t... ivalue_arg_indices>
+typename guts::function_traits<FuncType>::return_type call_with_ivalue_args_(ArrayRef<IValue> ivalue_args, guts::index_sequence<ivalue_arg_indices...>, ExtraArgs&&... extra_args) {
+ using IValueArgTypes = typename guts::function_traits<FuncType>::parameter_types;
+ return (*func)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<ivalue_arg_indices, IValueArgTypes>>>>::call(ivalue_args[ivalue_arg_indices])..., std::forward<ExtraArgs>(extra_args)...);
+}
+
+template<class FuncType, FuncType* func, class... ExtraArgs>
+typename guts::function_traits<FuncType>::return_type call_with_ivalue_args(ArrayRef<IValue> ivalue_args, ExtraArgs&&... extra_args) {
+ constexpr size_t num_ivalue_args = guts::function_traits<FuncType>::number_of_parameters - sizeof...(ExtraArgs);
+ AT_ASSERTM(num_ivalue_args == ivalue_args.size(), "Wrong number of ivalue arguments");
+ return call_with_ivalue_args_<FuncType, func>(ivalue_args, guts::make_index_sequence<num_ivalue_args>(), std::forward<ExtraArgs>(extra_args)...);
+}
+
+template<class OutputType>
+struct push_outputs final {
+ static void call(OutputType&& output, Stack* stack) {
+ push_outputs<std::tuple<OutputType>>(std::tuple<OutputType>(std::move(output)), stack);
+ }
+};
+template<class... OutputTypes>
+struct push_outputs<std::tuple<OutputTypes...>> final {
+ static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
+ for (size_t i = 0; i < sizeof...(OutputTypes); ++i) {
+ torch::jit::push(return_type_to_ivalue(std::move(output)));
+ }
+ }
+};
+
+// SFINAE over (1) does the operator kernel have a cache and (2) does it return a value or void
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel, class Enable = void> struct wrap_kernel {};
+// SFINAE version for kernels with output and with cache
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
+struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<!std::is_same<void, CacheTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
+ static typename guts::function_traits<FuncType>::return_type call(Stack* stack, KernelCache* cache) {
+ constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters - 1; // -1 because it takes the kernel cache as last argument
+ auto output = call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs), static_cast<CacheTypeOrVoid*>(cache));
+ push_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), stack);
+ }
+};
+// SFINAE version for kernels with output and without a cache
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
+struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<std::is_same<void, CacheTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
+ static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* /*cache*/) {
+ constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters;
+ auto output = call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs));
+ push_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), stack);
+ }
+};
+// SFINAE version for kernels without output and with a cache
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
+struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<!std::is_same<void, CacheTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
+ static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* cache) {
+ constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters - 1; // -1 because it takes the kernel cache as last argument
+ call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs), static_cast<CacheTypeOrVoid*>(cache));
+ }
+};
+// SFINAE version for kernels without output and without a cache
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
+struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<std::is_same<void, CacheTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
+ static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* /*cache*/) {
+ constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters;
+ call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs));
+ }
+};
+}
+
/**
* Helper class for building a KernelRegistrar. This permits "keyword-argument" like syntax
* when performing operator registration, e.g., as in:
@@ -76,52 +176,51 @@ private:
* .dispatchKey("bla");
*
* The resulting full expression is implicitly convertible to a KernelRegistrar.
- *
- * @tparam OpSchemaDef The operator schema this is building a KernelRegistration for
- * @tparam FieldsPresentFlags Remembers which fields are already set in the builder
*/
-template<class OpSchemaDef, class StateTypeOrVoid, uint64_t FieldsPresentFlags>
+template<class CacheTypeOrVoid, uint64_t FieldsPresentFlags>
class KernelRegistrationBuilder final {
private:
- using Schema = OpSchema<OpSchemaDef>;
-
static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 0;
static constexpr uint64_t KERNEL_PRESENT = 0x01 << 1;
- static constexpr uint64_t STATE_PRESENT = 0x01 << 2;
+ static constexpr uint64_t CACHE_PRESENT = 0x01 << 2;
+
+ using OpHandleGetter = KernelRegistrar::OpHandleGetter;
- static std::unique_ptr<c10::KernelState> defaultStateCreator() {
+ static std::unique_ptr<c10::KernelCache> defaultCacheCreator() {
return nullptr;
}
- template<class State>
- static std::unique_ptr<c10::KernelState> stateCreator() {
- static_assert(std::is_default_constructible<State>::value, "State class must be default constructible");
- return guts::make_unique<State>();
+ template<class Cache>
+ static std::unique_ptr<c10::KernelCache> cacheCreator() {
+ static_assert(std::is_default_constructible<Cache>::value, "Cache class must be default constructible");
+ return guts::make_unique<Cache>();
}
- c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
+ OpHandleGetter *op_;
+ c10::optional<TensorTypeId> dispatch_key_;
KernelFunction* kernel_;
- KernelStateCreatorFunction* state_creator_;
+ KernelCacheCreatorFunction* cache_creator_;
public:
- constexpr KernelRegistrationBuilder()
- : KernelRegistrationBuilder(c10::nullopt, nullptr, &defaultStateCreator) {}
+ constexpr explicit KernelRegistrationBuilder(OpHandleGetter *op)
+ : KernelRegistrationBuilder(std::move(op), c10::nullopt, nullptr, &defaultCacheCreator) {}
- constexpr KernelRegistrationBuilder(
- c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key,
+ constexpr explicit KernelRegistrationBuilder(
+ OpHandleGetter *op,
+ c10::optional<TensorTypeId> dispatch_key,
KernelFunction* kernel,
- KernelStateCreatorFunction* state_creator)
- : dispatch_key_(std::move(dispatch_key)), kernel_(kernel), state_creator_(state_creator) {}
+ KernelCacheCreatorFunction* cache_creator)
+ : op_(std::move(op)), dispatch_key_(std::move(dispatch_key)), kernel_(kernel), cache_creator_(cache_creator) {}
/**
- * Implicit coercion to KernelRegistrar<OpSchemaDef> that finalizes the builder and
+ * Implicit coercion to KernelRegistrar that finalizes the builder and
* creates the object.
* @return Produced KernelRegistrar
*/
- operator KernelRegistrar<OpSchemaDef>() && {
+ operator KernelRegistrar() && {
static_assert(FieldsPresentFlags & KERNEL_PRESENT, "Forgot to call .kernel() in kernel registration");
static_assert(FieldsPresentFlags & DISPATCH_KEY_PRESENT, "Forgot to call .dispatchKey() in kernel registration");
- return KernelRegistrar<OpSchemaDef>(std::move(*dispatch_key_), kernel_, state_creator_);
+ return KernelRegistrar(op_, std::move(*dispatch_key_), kernel_, cache_creator_);
}
/**
@@ -129,9 +228,9 @@ private:
* @param dispatch_key dispatch key to register the function to
* @return "this" for method chaining
*/
- constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(typename Schema::dispatch::dispatch_key_type dispatch_key) && {
+ constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(TensorTypeId dispatch_key) && {
static_assert(!(FieldsPresentFlags & DISPATCH_KEY_PRESENT), "Tried to define kernel twice in same op registration");
- return KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(dispatch_key), kernel_, state_creator_);
+ return KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(op_), std::move(dispatch_key), kernel_, cache_creator_);
}
/**
@@ -140,10 +239,10 @@ private:
* @return "this" for method chaining
*/
template<KernelFunction* kernel_func>
- constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
+ constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration");
- // TODO Better error message when kernel function mismatches, one common mismatch is missing state parameter or state parameter present while not expected.
- return KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT>(std::move(dispatch_key_), kernel_func, state_creator_);
+ // TODO Better error message when kernel function mismatches, one common mismatch is missing cache parameter or cache parameter present while not expected.
+ return KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT>(std::move(op_), std::move(dispatch_key_), kernel_func, cache_creator_);
}
/**
@@ -151,9 +250,10 @@ private:
* @param kernel concrete function implementation to be registered
* @return "this" for method chaining
*/
- template<typename Schema::signature::template func_type_with_state<StateTypeOrVoid>* kernel_func>
- constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
- return std::move(*this).template kernel<&Schema::signature::template wrap_kernel<StateTypeOrVoid, kernel_func>>();
+ template<class FuncType, FuncType* kernel_func>
+ constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
+ // TODO Better error message if FuncType is not a func type
+ return std::move(*this).template kernel<&detail::wrap_kernel<CacheTypeOrVoid, FuncType, kernel_func>::call>();
}
/**
@@ -161,20 +261,19 @@ private:
* @param dispatch_key dispatch key to register the function to
* @return "this" for method chaining
*/
- template<class State>
- constexpr KernelRegistrationBuilder<OpSchemaDef, State, FieldsPresentFlags | STATE_PRESENT> withState() && {
- static_assert(!(FieldsPresentFlags & STATE_PRESENT), "Tried to define state twice in same op registration");
- static_assert(std::is_base_of<c10::KernelState, State>::value, "State must inherit from c10::KernelState");
+ template<class Cache>
+ constexpr KernelRegistrationBuilder<Cache, FieldsPresentFlags | CACHE_PRESENT> withCache() && {
+ static_assert(!(FieldsPresentFlags & CACHE_PRESENT), "Tried to define cache twice in same op registration");
+ static_assert(std::is_base_of<c10::KernelCache, Cache>::value, "Cache must inherit from c10::KernelCache");
- static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Cannot set the state after the kernel function is already set. Please call .withState() first and .kernel() later in the chain.");
+ static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Cannot set the cache after the kernel function is already set. Please call .withCache() first and .kernel() later in the chain.");
- return KernelRegistrationBuilder<OpSchemaDef, State, FieldsPresentFlags | STATE_PRESENT>(std::move(dispatch_key_), kernel_, &stateCreator<State>);
+ return KernelRegistrationBuilder<Cache, FieldsPresentFlags | CACHE_PRESENT>(std::move(op_), std::move(dispatch_key_), kernel_, &cacheCreator<Cache>);
}
};
} // namespace c10
-// TODO Can the builder logic be moved to compile time?
// NB: Semicolon after applying this macro is MANDATORY
-#define C10_REGISTER_KERNEL(OpSchemaDef) \
- static KernelRegistrar<OpSchemaDef> MACRO_CONCAT(__kernelRegistrationBuilder_, __COUNTER__) = KernelRegistrationBuilder<OpSchemaDef, void, 0>()
+#define C10_REGISTER_KERNEL(OperatorHandle) \
+ static KernelRegistrar MACRO_CONCAT(__kernelRegistrationBuilder_, __COUNTER__) = KernelRegistrationBuilder<void, 0>(OperatorHandle)
diff --git a/aten/src/ATen/core/dispatch/LayoutId.cpp b/aten/src/ATen/core/dispatch/LayoutId.cpp
deleted file mode 100644
index bee71d8dcc..0000000000
--- a/aten/src/ATen/core/dispatch/LayoutId.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/LayoutId.h>
diff --git a/aten/src/ATen/core/dispatch/LayoutId.h b/aten/src/ATen/core/dispatch/LayoutId.h
deleted file mode 100644
index d0648392f4..0000000000
--- a/aten/src/ATen/core/dispatch/LayoutId.h
+++ /dev/null
@@ -1,22 +0,0 @@
-#pragma once
-
-#include <c10/util/IdWrapper.h>
-
-namespace c10 {
-
-class LayoutId final : public at::IdWrapper<LayoutId, uint8_t> {
-public:
- constexpr explicit LayoutId(underlying_type id): IdWrapper(id) {}
-
- constexpr uint8_t value() const {
- return underlyingId();
- }
-
- // Don't use this default constructor!
- // Unfortunately, a default constructor needs to be defined because of https://reviews.llvm.org/D41223
- constexpr LayoutId(): IdWrapper(0) {}
-};
-
-}
-
-C10_DEFINE_HASH_FOR_IDWRAPPER(c10::LayoutId)
diff --git a/aten/src/ATen/core/dispatch/OpSchema.cpp b/aten/src/ATen/core/dispatch/OpSchema.cpp
deleted file mode 100644
index 595989569c..0000000000
--- a/aten/src/ATen/core/dispatch/OpSchema.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/OpSchema.h>
diff --git a/aten/src/ATen/core/dispatch/OpSchema.h b/aten/src/ATen/core/dispatch/OpSchema.h
deleted file mode 100644
index d88bd7db87..0000000000
--- a/aten/src/ATen/core/dispatch/OpSchema.h
+++ /dev/null
@@ -1,406 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/DispatchKey.h>
-#include <ATen/core/ivalue.h>
-#include <c10/util/Array.h>
-#include <c10/util/Metaprogramming.h>
-#include <c10/util/TypeList.h>
-#include <c10/core/DeviceType.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/core/stack.h>
-
-namespace c10 {
-
-/**
- * A kernel can keep around a cache to have better performance when it's
- * called multiple times. This is used by a lot of caffe2 kernels.
- * This cache owned by the call site and passed in to the kernel as a function
- * argument. It must inherit from KernelState so the call site knows how to
- * store and destruct it.
- */
-class CAFFE2_API KernelState {
-public:
- virtual ~KernelState() = default;
-};
-
-using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
-
-/**
- * This is the basic ABI for any kernel call. Each kernel is registered as a
- * pointer to a global C function of this type.
- */
-using KernelFunction = void(Stack*, KernelState* state);
-
-namespace details {
-
-/**
- * If Arg is a Tensor or reference to a Tensor, provide the member constant value equal to true. Otherwise
- * return false.
- */
-template <class Arg>
-using is_tensor_arg = std::
- is_same<at::Tensor, guts::remove_cv_t<guts::remove_reference_t<Arg>>>;
-
-inline DeviceTypeId to_device_type_id(DeviceType device_type) {
- switch (device_type) {
- case DeviceType::CPU:
- return DeviceTypeId::CPU;
- case DeviceType::CUDA:
- return DeviceTypeId::CUDA;
- default:
- return DeviceTypeId::UNDEFINED;
- }
-}
-
-inline TensorParameterDispatchKey tensor_to_dispatch_key(const at::Tensor& tensor) {
- return TensorParameterDispatchKey{
- to_device_type_id(tensor.device().type()),
- LayoutId(0),
- tensor.dtype().id()};
-}
-
-template<size_t index, size_t offset, class ParameterTypes, class Enable = void> struct get_ith_tensor_arg_ {
- static_assert(!std::is_same<ParameterTypes, ParameterTypes>::value, "Index out of bounds");
-};
-template<size_t index, size_t offset, class Head, class... Tail>
-struct get_ith_tensor_arg_<index, offset, guts::typelist::typelist<Head, Tail...>, guts::enable_if_t<index == 0 && is_tensor_arg<Head>::value>> {
- static at::Tensor call(ArrayRef<IValue> args) {
- if (!args[offset].isTensor()) {
- throw std::runtime_error("Expected argument " + guts::to_string(offset) + " to be of type Tensor but found different type.");
- }
- return args[offset].toTensor();
- }
-};
-template<size_t index, size_t offset, class Head, class... Tail>
-struct get_ith_tensor_arg_<index, offset, guts::typelist::typelist<Head, Tail...>, guts::enable_if_t<index != 0 || !is_tensor_arg<Head>::value>> {
- static at::Tensor call(ArrayRef<IValue> args) {
- return get_ith_tensor_arg_<(is_tensor_arg<Head>::value ? (index-1) : index), offset + 1, guts::typelist::typelist<Tail...>>::call(args);
- }
-};
-template<class ParameterTypes, size_t index> at::Tensor get_ith_tensor_arg(ArrayRef<IValue> args) {
- return get_ith_tensor_arg_<index, 0, ParameterTypes>::call(args);
-}
-
-// Extract type ids for all tensors from an array of tensors
-template<class OpSchemaDef, size_t... indices>
-guts::array<TensorParameterDispatchKey, OpSchemaDef::num_dispatch_args()> getDispatchTypeIds__(ArrayRef<IValue> args, guts::index_sequence<indices...>) {
- using ParameterTypes = typename guts::function_traits<typename OpSchemaDef::Signature>::parameter_types;
- return {tensor_to_dispatch_key(get_ith_tensor_arg<ParameterTypes, indices>(args))...};
-}
-
-/**
- * Extract the type ids of all tensors in a variadic list of arguments
- *
- * @tparam Args Inferred variadic list of argument types
- * @param args List of arguments to get type ids from
- * @return guts::array<TensorParameterDispatchKey, n>, where n is the number of tensor arguments (is_tensor_arg) in the class
- */
-template<class OpSchemaDef>
-guts::array<TensorParameterDispatchKey, OpSchemaDef::num_dispatch_args()> getDispatchTypeIds_(ArrayRef<IValue> args) {
- return getDispatchTypeIds__<OpSchemaDef>(args, guts::make_index_sequence<OpSchemaDef::num_dispatch_args()>());
-}
-
-// TODO Test getDispatchTypeIds_
-
-/**
- * If T is a struct with a type field Signature, provides the member constant
- * @tparam T
- */
-template<class T, typename = void>
-struct has_signature_defined : std::false_type {};
-template<class T>
-struct has_signature_defined<T, guts::void_t<
- typename T::Signature
->> : std::true_type {};
-
-// TODO Test has_signature_defined
-
-template<class T, typename = void>
-struct has_parameter_names_defined : std::false_type {};
-template<class T>
-struct has_parameter_names_defined<T, guts::void_t<
- decltype(T::parameter_names)
->> : std::true_type {};
-
-// TODO Test has_parameter_names_defined
-
-template<class T, typename = void>
-struct has_name_defined : std::false_type {};
-template<class T>
-struct has_name_defined<T, guts::void_t<
- decltype(T::name)
->> : std::true_type {};
-
-// TODO Test has_name_defined
-
-template<class T>
-struct ivalue_to_arg_type {
- static T call(const IValue& v) {
- return std::move(v).to<T>();
- }
-};
-template<class T>
-struct ivalue_to_arg_type<ArrayRef<T>> {
- static ArrayRef<T> call(const IValue& v) {
- return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
- }
-};
-
-template<class FuncType, class... ExtraArgs, size_t... ivalue_arg_indices>
-typename guts::function_traits<FuncType>::return_type call_with_ivalue_args_(FuncType* func, ArrayRef<IValue> ivalue_args, guts::index_sequence<ivalue_arg_indices...>, ExtraArgs&&... extra_args) {
- using IValueArgTypes = typename guts::function_traits<FuncType>::parameter_types;
- return (*func)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<ivalue_arg_indices, IValueArgTypes>>>>::call(ivalue_args[ivalue_arg_indices])..., std::forward<ExtraArgs>(extra_args)...);
-}
-
-template<class FuncType, class... ExtraArgs>
-typename guts::function_traits<FuncType>::return_type call_with_ivalue_args(FuncType* func, ArrayRef<IValue> ivalue_args, ExtraArgs&&... extra_args) {
- constexpr size_t num_ivalue_args = guts::function_traits<FuncType>::number_of_parameters - sizeof...(ExtraArgs);
- return call_with_ivalue_args_<FuncType>(func, ivalue_args, guts::make_index_sequence<num_ivalue_args>(), std::forward<ExtraArgs>(extra_args)...);
-}
-
-template<class OutputType>
-struct write_outputs final {
- static void call(OutputType&& output, ArrayRef<IValue> outputs) {
- write_outputs<std::tuple<OutputType>>(std::tuple<OutputType>(std::move(output)), outputs);
- }
-};
-template<class... OutputTypes>
-struct write_outputs<std::tuple<OutputTypes...>> final {
- static void call(std::tuple<OutputTypes...>&& output, ArrayRef<IValue> outputs) {
- AT_ASSERT(outputs.size() == sizeof...(OutputTypes)); // Mismatch in number of returns between kernel function and operator schema.
- for (size_t i = 0; i < sizeof...(OutputTypes); ++i) {
- outputs[i] = return_type_to_ivalue(std::move(output));
- }
- }
-};
-
-
-// SFINAE over (1) does the operator kernel have state and (2) does it return a value or void
-template<class StateTypeOrVoid, class FuncType, class Enable = void> struct call_kernel_with_ivalue_args {};
-// SFINAE version for kernels with output and with state
-template<class StateTypeOrVoid, class FuncType>
-struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<!std::is_same<void, StateTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
- static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* state) {
- auto output = call_with_ivalue_args(func, ivalue_args, static_cast<StateTypeOrVoid*>(state));
- write_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), outputs);
- }
-};
-// SFINAE version for kernels with output and without state
-template<class StateTypeOrVoid, class FuncType>
-struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<std::is_same<void, StateTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
- static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* /*state*/) {
- auto output = call_with_ivalue_args(func, ivalue_args);
- write_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), outputs);
- }
-};
-// SFINAE version for kernels without output and with state
-template<class StateTypeOrVoid, class FuncType>
-struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<!std::is_same<void, StateTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
- static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* state) {
- call_with_ivalue_args(func, ivalue_args, static_cast<StateTypeOrVoid*>(state));
- }
-};
-// SFINAE version for kernels without output and without state
-template<class StateTypeOrVoid, class FuncType>
-struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<std::is_same<void, StateTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
- static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* /*state*/) {
- call_with_ivalue_args(func, ivalue_args);
- }
-};
-
-template<class FuncType, class AddedParameter, class Enable = void> struct add_ptr_parameter_if_not_void final {};
-template<class Return, class... Parameters, class AddedParameter>
-struct add_ptr_parameter_if_not_void<Return(Parameters...), AddedParameter, guts::enable_if_t<!std::is_same<void, AddedParameter>::value>> final {
- using type = Return(Parameters..., AddedParameter*);
-};
-template<class FuncType> struct add_ptr_parameter_if_not_void<FuncType, void, void> final {
- using type = FuncType;
-};
-
-/**
- * Wrapper class around a user-provided schema definition some useful information about the schema.
- *
- * @tparam OpSchemaDef Operator schema definition. See OpSchema for more details.
- */
-template<class OpSchemaDef> class OpSignatureSchema final {
- static_assert(details::has_signature_defined<OpSchemaDef>::value, "Operator schema doesn't define a valid Signature member type.");
- static_assert(guts::is_function_type<typename OpSchemaDef::Signature>::value, "Signature member of operator schema must be a function type.");
-
- using signature_traits = guts::function_traits<typename OpSchemaDef::Signature>;
-public:
- /**
- * The function type OpSchemaDef::Signature
- */
- using func_type = typename signature_traits::func_type;
- /**
- * The return type of the function OpSchemaDef::Signature
- */
- using return_type = typename signature_traits::return_type;
- /**
- * A type list of the parameter types of OpSchemaDef::Signature
- */
- using parameter_types = typename signature_traits::parameter_types;
-
- /**
- * The number of arguments of OpSchemaDef::Signature
- */
- static constexpr size_t num_args = guts::typelist::size<parameter_types>::value;
- /**
- * The number of tensor arguments (as per is_tensor_arg) in OpSchemaDef::Signature
- */
- static constexpr size_t num_tensor_args = guts::typelist::count_if<details::is_tensor_arg, parameter_types>::value;
-
- static constexpr size_t num_outputs = OpSchemaDef::num_outputs();
-
- template<class StateTypeOrVoid> using func_type_with_state = typename add_ptr_parameter_if_not_void<func_type, StateTypeOrVoid>::type;
-
- template<class StateTypeOrVoid, func_type_with_state<StateTypeOrVoid>* kernel>
- static void wrap_kernel(Stack* stack, KernelState* state) {
- constexpr size_t num_inputs = guts::typelist::size<parameter_types>::value;
- constexpr size_t num_outputs = 1; // TODO allow multiple outputs if it's a tuple
-
- ArrayRef<IValue> inputs = torch::jit::peekSlice(*stack, 0, num_inputs + num_outputs, num_inputs);
- ArrayRef<IValue> outputs = torch::jit::peekSlice(*stack, 0, num_outputs, num_outputs);
-
- call_kernel_with_ivalue_args<StateTypeOrVoid, func_type_with_state<StateTypeOrVoid>>::call(kernel, inputs, outputs, state);
- }
-
-private:
- static_assert(details::has_parameter_names_defined<OpSchemaDef>::value, "Operator schema doesn't define parameter_names member.");
- // TODO Allow simpler definition of parameter_names without having to spell out the guts::array type in the schema def.
- static_assert(std::is_same<const guts::array<const char*, num_args>, decltype(OpSchemaDef::parameter_names)>::value, "Operator schema defines parameter_names member, but it isn't the correct type. Must be a static constexpr guts::array of const char* with one entry for each parameter.");
-
-public:
- /**
- * The names of the parameters (as per OpSchemaDef::parameter_names)
- * @return Array
- */
- static constexpr const guts::array<const char*, num_args>& parameter_names() {
- return OpSchemaDef::parameter_names;
- }
-};
-
-/**
- * If T has a method dispatch_key, provide a member constant value equal to true. Otherwise return false.
- * @tparam T
- */
-template<class T, typename = void>
-struct has_function_dispatch_key_defined : std::false_type {};
-template<class T>
-struct has_function_dispatch_key_defined<T, guts::void_t<
- decltype(&T::dispatch_key)
->> : std::true_type {};
-
-/**
- * Wrapper class around a user-defined schema definition providing a way of computing a dispatch key
- * from arguments matching the signature of that schema.
- *
- * @tparam OpSchemaDef Operator schema definition. See OpSchema for more details.
- * @tparam Enable Inferred, used to control specialization
- */
-template<class OpSchemaDef, class Enable = void> class OpDispatchKeySchema final {};
-
-// General case. Operator doesn't overwrite DispatchKey generation. Use default.
-template<class OpSchemaDef>
-class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<!has_function_dispatch_key_defined<OpSchemaDef>::value>> final {
- using signature = OpSignatureSchema<OpSchemaDef>;
-
- // TODO Static assert that dispatch_key_type has operator<<(ostream, _) defined for debug output.
- // TODO Use an ADL-based debugString(DispatchKey) function instead of operator<< for debug printing.
-
-public:
- using dispatch_key_type = DispatchKey<OpSchemaDef::num_dispatch_args()>;
-
- static inline dispatch_key_type dispatch_key(const Stack* stack) {
- /* TODO Should we make this a runtime assert now?
- using guts::typelist::map_t;
- using guts::typelist::typelist;
- static_assert(std::is_same<
- map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typelist<Args...>>>,
- map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>>
- >::value, "Invalid argument types passed to OpSchema::dispatch_key()");*/
- return dispatch_key_type {
- details::getDispatchTypeIds_<OpSchemaDef>(torch::jit::last(*stack, signature::num_args))
- };
- }
-};
-
-// Special case. Operator overwrites DispatchKey generation. Use that.
-template<class OpSchemaDef>
-class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<has_function_dispatch_key_defined<OpSchemaDef>::value>> final {
- using signature = OpSignatureSchema<OpSchemaDef>;
-
- static_assert(guts::is_function_type<decltype(OpSchemaDef::dispatch_key)>::value, "Operator schema defines dispatch_key member, but it isn't a function.");
-
- using dispatch_key_traits = guts::function_traits<decltype(OpSchemaDef::dispatch_key)>;
-
-public:
- using dispatch_key_type = typename dispatch_key_traits::return_type;
-
-private:
-
- static_assert(guts::is_equality_comparable<dispatch_key_type>::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have the equality operator defined. Please define it.");
- static_assert(guts::is_hashable<dispatch_key_type>::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have an overload for std::hash. Please define it.");
-
- static_assert(std::is_same<
- guts::typelist::typelist<const Stack*>,
- typename dispatch_key_traits::parameter_types
- >::value, "Operator schema defines custom dispatch_key() derivation function, but it has the wrong signature. Expected to take one argument, which is of type const Stack*.");
-
-public:
-
- static inline dispatch_key_type dispatch_key(const Stack* stack) {
- /* TODO Should we make this a runtime assert now?
- using guts::typelist::map_t;
- using guts::typelist::typelist;
- static_assert(std::is_same<
- map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typelist<Args...>>>,
- map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>>
- >::value, "Invalid argument types passed to OpSchema::dispatch_key()");
- */
- return OpSchemaDef::dispatch_key(stack);
- }
-};
-
-template<class OpSchemaDef>
-class OpMetadataSchema final {
-private:
- static_assert(has_name_defined<OpSchemaDef>::value, "The operator schema has to define a 'static constexpr const char* name = ...' member to specify the operator name.");
- static_assert(std::is_same<const char* const, decltype(OpSchemaDef::name)>::value, "The 'name' member of the operator schema must have type 'static constexpr const char*'");
-
-public:
- static constexpr const char* name() {
- return OpSchemaDef::name;
- }
-};
-
-} // namespace details
-
-/**
- * Wrapper class for user-defined OpSchemaDef, providing functionality for determining
- * information about the signature and dispatching on that signature. This is the
- * "public" facing class.
- *
- * @tparam OpSchemaDef User-defined OpSchemaDef.
- * This struct is expected to define:
- * - a function type Signature
- * - a constexpr guts<const char*, n_args> parameter_names field (where n_args is
- * the number of arguments in Signature)
- */
-template <class OpSchemaDef>
-class CAFFE2_API OpSchema final {
- // TODO static_assert OpSchemaDef isn't an instanciation of OpSchema. If yes, the caller probably passed an OpSchema somewhere where an OpSchemaDef was expected and wants a good error message.
-public:
- using metadata = details::OpMetadataSchema<OpSchemaDef>;
- /**
- * Information about the signature
- */
- using signature = details::OpSignatureSchema<OpSchemaDef>;
- /**
- * Functionality for dispatching on that signature
- */
- using dispatch = details::OpDispatchKeySchema<OpSchemaDef>;
-};
-
-// TODO test OpSchema::dispatch stuff
-} // namespace c10
diff --git a/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp b/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp
deleted file mode 100644
index f153b5e198..0000000000
--- a/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
diff --git a/aten/src/ATen/core/dispatch/OpSchemaRegistration.h b/aten/src/ATen/core/dispatch/OpSchemaRegistration.h
index 29ecde2f7b..07af15027d 100644
--- a/aten/src/ATen/core/dispatch/OpSchemaRegistration.h
+++ b/aten/src/ATen/core/dispatch/OpSchemaRegistration.h
@@ -2,17 +2,38 @@
#include <ATen/core/dispatch/Dispatcher.h>
-// TODO Better error message when this definition is missing
+namespace c10 {
+namespace detail {
+class OpSchemaRegistrar final {
+public:
+ explicit OpSchemaRegistrar(FunctionSchema schema)
+ : opHandle_(c10::Dispatcher::singleton().registerSchema(std::move(schema))) {}
+
+ ~OpSchemaRegistrar() {
+ c10::Dispatcher::singleton().deregisterSchema(opHandle_);
+ }
+
+ const c10::OperatorHandle& opHandle() const {
+ return opHandle_;
+ }
+
+private:
+ c10::OperatorHandle opHandle_;
+};
+} // namespace detail
+} // namespace c10
/**
- * Macro for defining an operator schema. Every user-defined OpSchemaDef struct must
- * invoke this macro on it. Internally, this arranges for the dispatch table for
+ * Macro for defining an operator schema. Every operator schema must
+ * invoke C10_DECLARE_OP_SCHEMA in a header and C10_DEFINE_OP_SCHEMA in one (!)
+ * cpp file. Internally, this arranges for the dispatch table for
* the operator to be created.
*/
-#define C10_DEFINE_OP_SCHEMA(OpSchemaDef) \
- template<> \
- C10_EXPORT c10::DispatchTable<OpSchemaDef>& c10_dispatch_table<OpSchemaDef>() { \
- static c10::DispatchTable<OpSchemaDef> singleton; \
- return singleton; \
+#define C10_DECLARE_OP_SCHEMA(Name) \
+ CAFFE2_API const c10::OperatorHandle& Name(); \
+
+#define C10_DEFINE_OP_SCHEMA(Name, Schema) \
+ C10_EXPORT const c10::OperatorHandle& Name() { \
+ static ::c10::detail::OpSchemaRegistrar registrar(Schema); \
+ return registrar.opHandle(); \
}
-// TODO Also register unboxed calling API here
diff --git a/aten/src/ATen/core/dispatch/OpSchema_test.cpp b/aten/src/ATen/core/dispatch/OpSchema_test.cpp
deleted file mode 100644
index 1f5f16a9f4..0000000000
--- a/aten/src/ATen/core/dispatch/OpSchema_test.cpp
+++ /dev/null
@@ -1,29 +0,0 @@
-#include <ATen/core/dispatch/OpSchema.h>
-#include <c10/util/Array.h>
-#include <ATen/core/Tensor.h>
-
-using namespace c10;
-using at::Tensor;
-
-static_assert(details::is_tensor_arg<Tensor>::value, "");
-static_assert(details::is_tensor_arg<const Tensor&>::value, "");
-static_assert(details::is_tensor_arg<Tensor&&>::value, "");
-static_assert(!details::is_tensor_arg<int>::value, "");
-
-struct SchemaDef final {
- using Signature = bool(int, Tensor, float, Tensor, Tensor, unsigned int);
- static constexpr guts::array<const char*, 6> parameter_names = {{
- "1", "2", "3", "4", "5", "6"
- }};
- static constexpr size_t num_dispatch_args() {return 3;}
- static constexpr size_t num_outputs() {return 0;}
-};
-static_assert(6 == OpSchema<SchemaDef>::signature::num_args, "");
-static_assert(3 == OpSchema<SchemaDef>::signature::num_tensor_args, "");
-static_assert(std::is_same<bool, typename OpSchema<SchemaDef>::signature::return_type>::value, "");
-static_assert(
- std::is_same<
- guts::typelist::
- typelist<int, Tensor, float, Tensor, Tensor, unsigned int>,
- typename OpSchema<SchemaDef>::signature::parameter_types>::value,
- "");
diff --git a/aten/src/ATen/core/dispatch/README.md b/aten/src/ATen/core/dispatch/README.md
new file mode 100644
index 0000000000..1cb2af9056
--- /dev/null
+++ b/aten/src/ATen/core/dispatch/README.md
@@ -0,0 +1,12 @@
+This folder contains the c10 dispatcher. This dispatcher is a single point
+through which we are planning to route all kernel calls.
+Existing dispatch mechanisms from legacy PyTorch or caffe2 are planned to
+be replaced.
+
+This folder contains the following files:
+- Dispatcher.h: Main facade interface. Code using the dispatcher should only use this.
+- DispatchTable.h: Implementation of the actual dispatch mechanism. Hash table with kernels, lookup, ...
+- KernelCache.h: An interface operator kernels can use to inherit from if they need to keep around a cache between invocations
+- KernelFunction.h: The core interface (i.e. function pointer) for calling a kernel
+- OpSchemaRegistration.h: The mechanisms to register new operators with the c10 dispatcher
+- KernelRegistration.h: The mechanisms to register kernels with the c10 dispatcher
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index bca3f3b46e..df7cdcd3b0 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -5,7 +5,7 @@
#include <ATen/core/functional.h>
#include <ATen/core/Type.h>
#include <ATen/core/TensorMethods.h>
-
+#include <c10/util/TypeList.h>
#include <caffe2/core/common.h>
#include <memory>
@@ -233,6 +233,7 @@ struct CAFFE2_API OptionalType: public SingleElementType<TypeKind::OptionalType,
ss << "Optional[" << getElementType()->python_str() << "]";
return ss.str();
}
+
// common cast Optional[Tensor] for undefined tensor type
static OptionalTypePtr ofTensor();
private:
@@ -982,26 +983,51 @@ CAFFE2_API c10::optional<TypePtr> unifyTypes(
const TypePtr& t1,
const TypePtr& t2);
-template <typename T>
-TypePtr getTypePtr() {
-#define TYPE_STR(Type) #Type, " ",
- AT_ERROR(
- "Type ",
- c10::demangle_type<T>(),
- " could not be converted to any of the known types { ",
- C10_FORALL_TYPES(TYPE_STR) "}");
-#undef TYPE_STR
-}
+namespace detail {
+template <typename T> struct getTypePtr_ final {
+ static_assert(guts::false_t<T>::value, "Type could not be converted to any of the known types.");
+};
-template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); }
-template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); }
-template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); }
-template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); }
-template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); }
-template<> inline TypePtr getTypePtr<std::string>() { return StringType::get(); }
-template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); }
-template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
-template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); }
+template<> struct getTypePtr_<at::Tensor> final {
+ static TypePtr call() { return DynamicType::get(); }
+};
+template<> struct getTypePtr_<double> final {
+ static TypePtr call() { return FloatType::get(); }
+};
+template<> struct getTypePtr_<int64_t> final {
+ static TypePtr call() { return IntType::get(); }
+};
+template<> struct getTypePtr_<bool> final {
+ static TypePtr call() { return BoolType::get(); }
+};
+template<> struct getTypePtr_<at::Scalar> final {
+ static TypePtr call() { return NumberType::get(); }
+};
+template<> struct getTypePtr_<std::string> final {
+ static TypePtr call() { return StringType::get(); }
+};
+template<class T> struct getTypePtr_<std::vector<T>> final {
+ static TypePtr call() {
+ static auto type = ListType::create(getTypePtr_<T>::call());
+ return type;
+ }
+};
+template<class T> struct getTypePtr_<ArrayRef<T>> final {
+ static TypePtr call() {
+ static auto type = ListType::create(getTypePtr_<T>::call());
+ return type;
+ }
+};
+template<class T> struct getTypePtr_<at::optional<T>> final {
+ static TypePtr call() {
+ static auto type = OptionalType::create(getTypePtr_<T>::call());
+ return type;
+ }
+};
+}
+template<class T> inline TypePtr getTypePtr() {
+ return detail::getTypePtr_<T>::call();
+}
CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value);
CAFFE2_API TypePtr attemptToRecoverType(const IValue& input_ivalue);
diff --git a/aten/src/ATen/core/opschema/layer_norm.cpp b/aten/src/ATen/core/opschema/layer_norm.cpp
index be908a58fc..d0749d0b14 100644
--- a/aten/src/ATen/core/opschema/layer_norm.cpp
+++ b/aten/src/ATen/core/opschema/layer_norm.cpp
@@ -1,4 +1,25 @@
#include <ATen/core/opschema/layer_norm.h>
#include <ATen/core/dispatch/OpSchemaRegistration.h>
-C10_DEFINE_OP_SCHEMA(c10::core::opschema::LayerNorm);
+namespace c10 {
+namespace core {
+namespace opschema {
+ // TODO Parse schema string instead of creating FunctionSchema manually
+ C10_DEFINE_OP_SCHEMA(LayerNorm, FunctionSchema(
+ "caffe2::layer_norm_dont_use_this_op_yet",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("axis", IntType::get()),
+ c10::Argument("epsilon", FloatType::get()),
+ c10::Argument("output", OptionalType::ofTensor(), c10::nullopt, IValue()),
+ c10::Argument("output_mean", OptionalType::ofTensor(), c10::nullopt, IValue()),
+ c10::Argument("output_stdev", OptionalType::ofTensor(), c10::nullopt, IValue())
+ }), (std::vector<c10::Argument>{
+ c10::Argument("output"),
+ c10::Argument("mean"),
+ c10::Argument("stdev")
+ })
+ ));
+}
+}
+}
diff --git a/aten/src/ATen/core/opschema/layer_norm.h b/aten/src/ATen/core/opschema/layer_norm.h
index 58255d0b96..0ea81ae60b 100644
--- a/aten/src/ATen/core/opschema/layer_norm.h
+++ b/aten/src/ATen/core/opschema/layer_norm.h
@@ -1,35 +1,12 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include <ATen/core/blob.h>
-#include <ATen/core/dispatch/OpSchema.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace c10 {
namespace core {
namespace opschema {
-// TODO This op schema should probably not live in c10 since it's not a method
-// on Tensor. It's only here as a proof-of-concept op and for LATTE team
-// to be able to call caffe2 layer norm from PyTorch.
-struct LayerNorm final {
- static constexpr const char* name = "LayerNorm";
-
- using Signature = std::tuple<at::Tensor, int, float, at::Tensor, at::Tensor, at::Tensor> (
- const at::Tensor& input,
- int axis,
- float epsilon,
- const at::Tensor& output,
- const at::Tensor& output_mean,
- const at::Tensor& output_stdev);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 3;}
-
- static constexpr c10::guts::array<const char*, 6> parameter_names = {
- {"input", "axis", "epsilon", "output", "output_mean", "output_stdev"}};
-};
+C10_DECLARE_OP_SCHEMA(LayerNorm);
} // namespace opschema
} // namespace core
diff --git a/aten/src/ATen/core/stack.h b/aten/src/ATen/core/stack.h
index 32c07a4d53..e51b6b5ef2 100644
--- a/aten/src/ATen/core/stack.h
+++ b/aten/src/ATen/core/stack.h
@@ -29,6 +29,9 @@ using Operation = std::function<int(Stack&)>;
static inline IValue& peek(Stack& stack, size_t i, size_t N) {
return *(stack.end() - N + i);
}
+static inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
+ return *(stack.end() - N + i);
+}
// treat the last N elements of the stack as a list, looking up the
// slice starting at index i and having length len
static inline at::ArrayRef<IValue> peekSlice(