diff options
82 files changed, 1197 insertions, 1620 deletions
diff --git a/aten/src/ATen/core/dispatch/DeviceId.cpp b/aten/src/ATen/core/dispatch/DeviceId.cpp deleted file mode 100644 index 7c65fbe70c..0000000000 --- a/aten/src/ATen/core/dispatch/DeviceId.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <ATen/core/dispatch/DeviceId.h> diff --git a/aten/src/ATen/core/dispatch/DeviceId.h b/aten/src/ATen/core/dispatch/DeviceId.h deleted file mode 100644 index cdcaf4b635..0000000000 --- a/aten/src/ATen/core/dispatch/DeviceId.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include <c10/util/C++17.h> -#include <functional> -#include <iostream> - -namespace c10 { - -enum class DeviceTypeId : uint8_t { - // Don't use the int values here in the enum (i.e. don't do static_cast to or from int). - // Instead, if you want to serialize this, write a function with switch/case. - CPU = 0, - CUDA = 1, - UNDEFINED -}; - -inline std::ostream& operator<<(std::ostream& stream, DeviceTypeId device_type_id) { - switch(device_type_id) { - case c10::DeviceTypeId::CPU: return stream << "DeviceTypeId(CPU)"; - case c10::DeviceTypeId::CUDA: return stream << "DeviceTypeId(CUDA)"; - case c10::DeviceTypeId::UNDEFINED: return stream << "DeviceTypeId(UNDEFINED)"; - } - throw std::logic_error("Unknown DeviceTypeId: " + c10::guts::to_string(static_cast<int>(device_type_id))); -} - -} - -namespace std { - -template <> struct hash<c10::DeviceTypeId> { - size_t operator()(c10::DeviceTypeId v) const { - return std::hash<uint8_t>()(static_cast<uint8_t>(v)); - } -}; - -} diff --git a/aten/src/ATen/core/dispatch/DispatchKey.cpp b/aten/src/ATen/core/dispatch/DispatchKey.cpp deleted file mode 100644 index 33e8e29b28..0000000000 --- a/aten/src/ATen/core/dispatch/DispatchKey.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <ATen/core/dispatch/DispatchKey.h> diff --git a/aten/src/ATen/core/dispatch/DispatchKey.h b/aten/src/ATen/core/dispatch/DispatchKey.h deleted file mode 100644 index 7c6cf15f61..0000000000 --- a/aten/src/ATen/core/dispatch/DispatchKey.h +++ /dev/null @@ -1,97 +0,0 
@@ -#pragma once - -#include <ATen/core/dispatch/DeviceId.h> -#include <ATen/core/dispatch/LayoutId.h> -#include <c10/util/typeid.h> - -#include <vector> -#include <functional> -#include <sstream> -#include <c10/util/Array.h> - -namespace c10 { - -namespace details { - -struct TensorParameterDispatchKey final { - // note: This dispatch key structure is not final yet and will change. Don't rely on it. - DeviceTypeId deviceTypeId; - LayoutId layoutId; - // TODO Move this TypeIdentifier to c10 namespace - caffe2::TypeIdentifier dataType; -}; -inline constexpr bool operator==(const TensorParameterDispatchKey& lhs, const TensorParameterDispatchKey& rhs) { - return lhs.deviceTypeId == rhs.deviceTypeId && lhs.layoutId == rhs.layoutId && lhs.dataType == rhs.dataType; -} - -inline std::ostream& operator<<(std::ostream& stream, const TensorParameterDispatchKey& key) { - return stream << "TensorKey(" << key.deviceTypeId << ", " << key.layoutId.value() << ", " << key.dataType << ")"; -} - -} // namespace details -} // namespace c10 - -namespace std { - template<> - struct hash<c10::details::TensorParameterDispatchKey> { - // TODO constexpr hashing - size_t operator()(const c10::details::TensorParameterDispatchKey& obj) const { - return std::hash<c10::DeviceTypeId>()(obj.deviceTypeId) ^ std::hash<c10::LayoutId>()(obj.layoutId) ^ std::hash<caffe2::TypeIdentifier>()(obj.dataType); - } - }; -} // namespace std - -namespace c10 { -/** - * The dispatch key encodes the runtime type identity of a function call arguments, - * specifying what aspects of this identity can be dynamically dispatched on. - * - * Intuitively, given a function signature like f(Tensor, int), a valid dispatch - * key for the arguments might be [CPUFloatTensor] (notice that 'f' is NOT included - * in the dispatch key, and the runtime type of 'int' is NOT considered for dispatch - * (since it is trivial). - * - * Dispatch keys permit equality tests and are hashable. 
- * - * @tparam num_dispatch_args The number of dispatchable arguments - */ -template<size_t num_dispatch_args> -struct DispatchKey final { - guts::array<details::TensorParameterDispatchKey, num_dispatch_args> argTypes; -}; - -template<size_t num_dispatch_args> -inline constexpr bool operator==(const DispatchKey<num_dispatch_args> &lhs, const DispatchKey<num_dispatch_args>& rhs) { - // TODO: Use AVX instructions to perform this equality test more quickly - return lhs.argTypes == rhs.argTypes; -} - -template<size_t num_dispatch_args> -inline std::ostream& operator<<(std::ostream& stream, const DispatchKey<num_dispatch_args>& key) { - stream << "DispatchKey("; - if (num_dispatch_args > 0) { - stream << "DispatchKey(" << key.argTypes[0]; - for (size_t i = 1; i < num_dispatch_args; ++i) { - stream << ", " << key.argTypes[i]; - } - stream << ")"; - } - return stream << ")"; -} - -} // namespace c10 - -namespace std { - template<size_t num_dispatch_args> - struct hash<c10::DispatchKey<num_dispatch_args>> { - // TODO constexpr hashing - size_t operator()(const c10::DispatchKey<num_dispatch_args>& obj) const { - size_t hash_value = 0; - for (const auto& argType : obj.argTypes) { - hash_value *= 10883; // prime - hash_value += std::hash<c10::details::TensorParameterDispatchKey>()(argType); - } - return hash_value; - } - }; -} // namespace std diff --git a/aten/src/ATen/core/dispatch/DispatchTable.cpp b/aten/src/ATen/core/dispatch/DispatchTable.cpp deleted file mode 100644 index 0bf7c5903d..0000000000 --- a/aten/src/ATen/core/dispatch/DispatchTable.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <ATen/core/dispatch/DispatchTable.h> diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index 22206658c5..c005314c27 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -1,28 +1,29 @@ #pragma once -#include <ATen/core/dispatch/OpSchema.h> +#include <ATen/core/function_schema.h> 
#include <c10/util/LeftRight.h> #include <c10/util/Metaprogramming.h> #include <c10/util/flat_hash_map.h> #include <ATen/core/ivalue.h> +#include <ATen/core/dispatch/KernelFunction.h> #include <array> #include <atomic> #include <iostream> #include <mutex> #include <type_traits> +#include <sstream> #include <unordered_map> namespace c10 { /** * The type of a user-supplied function to initialize the kernel cache. - * this is stored together with the kernel function in the dispatch table + * this is stored together with the KernelFunction in the DispatchTable * so we can create a new cache instance when a kernel is looked up * from the dispatch table. */ -using KernelStateCreatorFunction = std::unique_ptr<c10::KernelState> (); - +using KernelCacheCreatorFunction = std::unique_ptr<c10::KernelCache> (); /** * The dispatch table stores a pointer to a kernel function and a pointer * to a function initializing a cache for the kernel. If the kernel wants @@ -32,72 +33,100 @@ using KernelStateCreatorFunction = std::unique_ptr<c10::KernelState> (); * this same cache instance. */ struct DispatchTableEntry final { - KernelFunction* kernel_func; - KernelStateCreatorFunction* state_creator_func; + /*not-nullable*/ KernelFunction* kernel_func; + /*not-nullable*/ KernelCacheCreatorFunction* cache_creator_func; }; -namespace details { +namespace detail { /// Kernel implementations in a thread-safe hash table. 
-template <class Key> class ThreadsafeOperatorTable_ final { public: - template <class Key_> - void emplace(Key_&& key, const DispatchTableEntry& value) { - bool res = map_.write([&](ska::flat_hash_map<Key, DispatchTableEntry>& map) -> bool { - auto result = map.emplace(std::forward<Key>(key), value); + void emplace(TensorTypeId key, const DispatchTableEntry& value, const std::string& operator_name) { + bool res = map_.write([&](ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> bool { + auto result = map.emplace(key, value); return result.second; }); if (!res) { - AT_ERROR("Tried to register conflicting kernels to the dispatcher: ", key); + AT_ERROR("Tried to register multiple kernels with same dispatch key '", + dispatch_key_to_string(key), "' for operator '", operator_name ,"'."); } } - void erase(const Key& key) { - auto num_removed = - map_.write([&](ska::flat_hash_map<Key, DispatchTableEntry>& map) -> size_t { - return map.erase(key); - }); - assert(num_removed <= 1); // This is not a multi-map - if (num_removed == 0) { - AT_ERROR("Tried to deregister a kernel that isn't registered."); - } + void erase(TensorTypeId key, const std::string& operator_name) { + map_.write([&](ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) { + auto num_removed = map.erase(key); + + assert(num_removed <= 1); // This is not a multi-map + if (num_removed == 0) { + AT_ERROR("Tried to deregister a kernel with dispatch key '", + dispatch_key_to_string(key), "' for operator '", operator_name, + "' but that kernel isn't registered. 
Registered dispatch keys are: ", + list_all_dispatch_keys(map)); + } + }); } - const DispatchTableEntry* lookup(const Key& key) const { - return map_.read([&](const ska::flat_hash_map<Key, DispatchTableEntry>& map) -> const DispatchTableEntry* { + const DispatchTableEntry* lookup(TensorTypeId key, const string& operator_name) const { + return map_.read([&](const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> const DispatchTableEntry* { auto found = map.find(key); if (found != map.end()) { return &found->second; } else { - return nullptr; + AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name, + "'. Tried to look up kernel for dispatch key '", dispatch_key_to_string(key), + "'. Registered dispatch keys are: ", list_all_dispatch_keys(map)); } }); } + bool isEmpty() const { + return map_.read([&](const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> bool { + return map.size() == 0; + }); + } + private: - LeftRight<ska::flat_hash_map<Key, DispatchTableEntry>> map_; + static std::string list_all_dispatch_keys(const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) { + if (map.size() == 0) { + return ""; + } + std::ostringstream str; + str << dispatch_key_to_string(map.begin()->first); + for (auto iter = ++map.begin(); iter != map.end(); ++iter) { + str << ", " << dispatch_key_to_string(iter->first); + } + return str.str(); + } + + static std::string dispatch_key_to_string(TensorTypeId id) { + return std::string(toString(tensorTypeIdToBackend(id))) + "[" + guts::to_string(id) + "]"; + } + + LeftRight<ska::flat_hash_map<TensorTypeId, DispatchTableEntry>> map_; }; -} // namespace details +} // namespace detail /** * Per-operator dispatch table. * - * Given an operator specified by 'OpSchemaDef', this class records a dispatch + * Given an operator specified by a FunctionSchema, this class records a dispatch * table for various kernels provided for this operator. 
For example, if we * consider the operator add(Tensor, Tensor), the dispatch table for this * operator may contain implementations for various dynamic tensor types, such - * as (CPUFloatTensor, CPUFloatTensor), (CUDAFloatTensor, CUDAFloatTensor), etc. - * - * @tparam OpSchemaDef The operator signature this dispatch table encodes. + * as CPUTensorId, CUDATensorId, etc. */ -// TODO: Support dispatch for meta-operators (which apply to all dynamic types) -template <class OpSchemaDef> class DispatchTable final { - private: - using Schema = OpSchema<OpSchemaDef>; - public: - DispatchTable() : kernels_() {} + explicit DispatchTable(const FunctionSchema& schema) + : kernels_() + , reverse_index_of_first_tensor_arg_( + schema.arguments().size() - get_index_of_first_tensor_arg_(schema)) + , operator_name_(schema.name()) {} + + DispatchTable(DispatchTable&&) = default; + DispatchTable& operator=(DispatchTable&&) = default; + DispatchTable(const DispatchTable&) = delete; + DispatchTable& operator=(const DispatchTable&) = delete; /** * Register a kernel in the table at some dispatch key. @@ -105,9 +134,9 @@ class DispatchTable final { * @param dispatch_key Dispatch key to define when this kernel is selected */ void registerKernel( - typename Schema::dispatch::dispatch_key_type dispatch_key, + TensorTypeId dispatch_key, const DispatchTableEntry& kernel) { - kernels_.emplace(std::move(dispatch_key), kernel); + kernels_.emplace(dispatch_key, kernel, operator_name_); } /** @@ -118,9 +147,8 @@ class DispatchTable final { // TODO: This isn't going to work so well when we get more complicated // override patterns! In this case, an operator will show up in multiple // slots, and erasing them one-by-one is probably not such a good idea. 
- void deregisterKernel( - const typename Schema::dispatch::dispatch_key_type& dispatch_key) { - kernels_.erase(dispatch_key); + void deregisterKernel(TensorTypeId dispatch_key) { + kernels_.erase(dispatch_key, operator_name_); } /** @@ -131,30 +159,39 @@ class DispatchTable final { * @return Kernel function pointing to the right kernel for the given arguments */ const DispatchTableEntry& lookup(const Stack* stack) const { - auto dispatch_key = Schema::dispatch::dispatch_key(stack); - const DispatchTableEntry* found = kernels_.lookup(dispatch_key); - if (found == nullptr) { - // TODO Better error message - include op name and dispatch key (i.e. - // argument types) - AT_ERROR("Didn't find kernel to dispatch to for operator '", Schema::metadata::name(), "'"); - } - return *found; + TensorTypeId dispatch_key = torch::jit::peek( + *stack, + 0, + reverse_index_of_first_tensor_arg_ + ).toTensor().type_id(); + return *kernels_.lookup(dispatch_key, operator_name_); + } + + bool isEmpty() const { + return kernels_.isEmpty(); } private: + static size_t get_index_of_first_tensor_arg_(const FunctionSchema& schema) { + for (size_t i = 0; i < schema.arguments().size(); ++i) { + if (schema.arguments()[i].type()->isSubtypeOf(DynamicType::get())) { // DynamicType means it's a tensor + return i; + } + } + + AT_ERROR("Tried to create dispatch table for operator schema ", schema.name(), " that doesn't have tensor arguments."); + } + detail::ThreadsafeOperatorTable_ kernels_; - details::ThreadsafeOperatorTable_< - typename Schema::dispatch::dispatch_key_type> - kernels_; + // this is caching the index so we don't have to parse the schema inputs + // again and again for each dispatcher lookup. + // reverse_index means this is the distance from the first tensor argument + // to argument_list.end(), i.e. from the top of the stack. + // Since it is distance to end(), this means it's 1-indexed, + // i.e. '1' is the last argument. 
+ size_t reverse_index_of_first_tensor_arg_; + std::string operator_name_; }; } // namespace c10 - -/* - * Use this to access the dispatch table singleton for a given op schema. - * It has an implementation for each op schema def in a cpp file, because - * we can't rely on the one-definition-rule. - */ -template <class OpSchemaDef> -C10_API c10::DispatchTable<OpSchemaDef>& c10_dispatch_table(); diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 9f76d8d969..832b687316 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -1 +1,8 @@ #include <ATen/core/dispatch/Dispatcher.h> + +namespace c10 { +C10_EXPORT Dispatcher& Dispatcher::singleton() { + static Dispatcher _singleton; + return _singleton; +} +} diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 1d845750ef..bcca890c16 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -1,9 +1,14 @@ #pragma once #include <ATen/core/dispatch/DispatchTable.h> +#include <c10/util/Exception.h> +#include <mutex> +#include <list> namespace c10 { +class CAFFE2_API OperatorHandle; + /** * This class represents an operator kernel, i.e. an operator *after* it was * dispatched to a certain device. You can use it to call the kernel. @@ -14,12 +19,11 @@ namespace c10 { * Also, keeping around the OpKernel instance will keep around a local cache * that is used by some kernels to get better performance when they're called * multiple times (mostly Caffe2 kernels do that). + * + * OpKernel is not threadsafe. 
*/ -class OpKernel final { +class CAFFE2_API OpKernel final { public: - explicit OpKernel(KernelFunction* kernel, KernelStateCreatorFunction* state_creator) - : kernel_(kernel), state_creator_(state_creator) {} - OpKernel(OpKernel&&) = default; OpKernel& operator=(OpKernel&&) = default; OpKernel(const OpKernel&) = delete; @@ -28,60 +32,137 @@ public: /** * Call the operator kernel with the given arguments. */ - void call(Stack* stack) { - if (state_.get() == nullptr) { - AT_ASSERT(state_creator_ != nullptr); - state_ = (*state_creator_)(); + void call(Stack* stack) const { + if (cache_.get() == nullptr) { + AT_ASSERT(cache_creator_ != nullptr); + cache_ = (*cache_creator_)(); } - return (*kernel_)(stack, state_.get()); + return (*kernel_)(stack, cache_.get()); } private: - // The kernel function is a global C function, not a std::function. - // That is, ownership is not an issue. + explicit OpKernel(KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator) + : kernel_(kernel), cache_creator_(cache_creator) {} + friend class Dispatcher; + KernelFunction* kernel_; - KernelStateCreatorFunction* state_creator_; - std::unique_ptr<c10::KernelState> state_; + KernelCacheCreatorFunction* cache_creator_; + mutable std::unique_ptr<c10::KernelCache> cache_; }; /** * Top-level dispatch interface for dispatching via the dynamic dispatcher. */ -template<class OpSchemaDef> -class Dispatcher final { +class CAFFE2_API Dispatcher final { private: - using Schema = OpSchema<OpSchemaDef>; + struct OperatorDef final { + explicit OperatorDef(FunctionSchema schema_) + : dispatchTable(schema_) + , schema(std::move(schema_)) {} + + DispatchTable dispatchTable; + FunctionSchema schema; + }; + friend class OperatorHandle; + public: // Implementation note: this class abstracts over the fact that we have per-operator // dispatch tables. This could be easily adjusted to have a single global hash // table. 
+ static Dispatcher& singleton(); + /** - * Register an operator to the dispatch table for some operator schema. + * Register a new operator schema. The handle returned can be used to register + * kernels to this operator or to call it. */ - static void registerKernel(typename Schema::dispatch::dispatch_key_type dispatch_key, KernelFunction* kernel_func, KernelStateCreatorFunction* state_creator_func) { - auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>(); - return dispatch_table_for_this_op.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, state_creator_func}); - } + OperatorHandle registerSchema(FunctionSchema schema); /** - * Remove an operator from the dispatch table for some operator schema. + * Remove an operator from the dispatcher. Make sure you removed + * all kernels for this operatorbefore calling this. */ - static void deregisterKernel(const typename Schema::dispatch::dispatch_key_type& dispatch_key) { - auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>(); - return dispatch_table_for_this_op.deregisterKernel(dispatch_key); - } + void deregisterSchema(const OperatorHandle& op); /** - * Perform a dynamic dispatch and get the kernel for an operator + * Register an operator to the dispatch table for an operator. */ - static OpKernel lookup(const Stack* stack) { - auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>(); - const DispatchTableEntry& kernel = dispatch_table_for_this_op.lookup(stack); - return OpKernel(kernel.kernel_func, kernel.state_creator_func); + void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func); + + /** + * Remove an operator from the dispatch table for an operator. + */ + void deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key); + + /** + * Perform a dynamic dispatch and get the kernel for an operator. 
+ */ + OpKernel lookup(const OperatorHandle& op, const Stack* stack) const; + +private: + std::list<OperatorDef> operators_; + std::mutex mutex_; +}; + +/** + * This is a handle to an operator schema registered with the dispatcher. + * This handle can be used to register kernels with the dispatcher or + * to lookup a kernel for a certain set of arguments. + */ +class CAFFE2_API OperatorHandle final { +public: + OperatorHandle(OperatorHandle&&) = default; + OperatorHandle& operator=(OperatorHandle&&) = default; + OperatorHandle(const OperatorHandle&) = default; + OperatorHandle& operator=(const OperatorHandle&) = default; + + const FunctionSchema& schema() const { + return operatorDefIterator_->schema; } +private: + explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorDefIterator) + : operatorDefIterator_(std::move(operatorDefIterator)) {} + friend class Dispatcher; + + std::list<Dispatcher::OperatorDef>::iterator operatorDefIterator_; }; + + +inline OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) { + // we need a lock to avoid concurrent writes + std::lock_guard<std::mutex> lock(mutex_); + + operators_.emplace_back(std::move(schema)); + return OperatorHandle(--operators_.end()); +} + +inline void Dispatcher::deregisterSchema(const OperatorHandle& op) { + // we need a lock to avoid concurrent writes + std::lock_guard<std::mutex> lock(mutex_); + + if (!op.operatorDefIterator_->dispatchTable.isEmpty()) { + AT_ERROR("Tried to deregister op schema that still has kernels registered"); + } + operators_.erase(op.operatorDefIterator_); +} + +inline void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) { + // note: this doesn't need the mutex because write operations on the list keep iterators intact. 
+ op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, cache_creator_func}); +} + +inline void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) { + // note: this doesn't need the mutex because write operations on the list keep iterators intact. + op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key); +} + +inline OpKernel Dispatcher::lookup(const OperatorHandle& op, const Stack* stack) const { + // note: this doesn't need the mutex because write operations on the list keep iterators intact. + const DispatchTableEntry& kernel = op.operatorDefIterator_->dispatchTable.lookup(stack); + return OpKernel(kernel.kernel_func, kernel.cache_creator_func); +} + } // namespace c10 diff --git a/aten/src/ATen/core/dispatch/KernelCache.h b/aten/src/ATen/core/dispatch/KernelCache.h new file mode 100644 index 0000000000..e879b9d09f --- /dev/null +++ b/aten/src/ATen/core/dispatch/KernelCache.h @@ -0,0 +1,20 @@ +#pragma once + +namespace c10 { + +/** + * A kernel can keep around a cache to have better performance when it's + * called multiple times. This is used by a lot of caffe2 kernels, for example + * conv_op stores a set of tensors for intermediate values to avoid having + * to reallocate them on each call. + * This cache owned by the call site (i.e. stored inside the OpKernel object) + * kept at the call site to call into the kernel) and passed in to the kernel + * as a function argument. It must inherit from KernelCache so the call site + * knows how to store and destruct it. 
+ */ +class CAFFE2_API KernelCache { +public: + virtual ~KernelCache() = default; +}; + +} diff --git a/aten/src/ATen/core/dispatch/KernelFunction.h b/aten/src/ATen/core/dispatch/KernelFunction.h new file mode 100644 index 0000000000..67a296fa4e --- /dev/null +++ b/aten/src/ATen/core/dispatch/KernelFunction.h @@ -0,0 +1,16 @@ +#pragma once + +#include <ATen/core/dispatch/KernelCache.h> +#include <ATen/core/stack.h> + +namespace c10 { + +using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace. + +/** + * This is the basic ABI for any kernel call. Each kernel is registered as a + * function pointer `KernelFunction*`, i.e. kernels are not allowed to be closures. + */ +using KernelFunction = void(Stack*, KernelCache* cache); + +} diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.cpp b/aten/src/ATen/core/dispatch/KernelRegistration.cpp deleted file mode 100644 index 0848300f5e..0000000000 --- a/aten/src/ATen/core/dispatch/KernelRegistration.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <ATen/core/dispatch/KernelRegistration.h> diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.h b/aten/src/ATen/core/dispatch/KernelRegistration.h index df198c437d..eaab0c15e0 100644 --- a/aten/src/ATen/core/dispatch/KernelRegistration.h +++ b/aten/src/ATen/core/dispatch/KernelRegistration.h @@ -2,13 +2,20 @@ #include <c10/util/Optional.h> #include <ATen/core/dispatch/Dispatcher.h> -#include <ATen/core/dispatch/OpSchema.h> /** * To register your own kernel for an operator, do in one (!) 
cpp file: - * C10_REGISTER_KERNEL(OpSchemaDef) - * .kernel(&kernel_func) + * C10_REGISTER_KERNEL(OperatorHandle) + * .kernel<decltype(&kernel_func), &kernel_func>() * .dispatchKey(dispatch_key); + * + * Example: + * + * Tensor my_kernel_cpu(Tensor in) {...} + * + * C10_REGISTER_KERNEL(MyOpSchema) + * .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>() + * .dispatchKey(CPUTensorId()); */ namespace c10 { @@ -17,30 +24,29 @@ namespace c10 { // TODO Test no dispatch key defined /** - * Class which, on construction, registers an operator in the dispatch table. The intent is that + * Class which, on construction, registers an operator in the dispatch table. The intent is that * this class is constructed at static initialization time so that operators automatically get * registered when a dlopen() occurs. * - * You shouldn't call this directly; instead, use the KernelRegistrationBuilder - * - * @tparam OpSchemaDef + * You shouldn't call this directly; instead, use the C10_REGISTER_KERNEL macros. 
*/ -template<class OpSchemaDef> class KernelRegistrar final { -private: - using Schema = OpSchema<OpSchemaDef>; public: + using OpHandleGetter = const OperatorHandle& (); + /** - * @param kernel The concrete function implementation to register + * @param op The operator to register the kernel for * @param dispatch_key The dispatch key to register the function to + * @param kernel The concrete function implementation to register + * @param cache_creator A function initializing the cache for the kernel */ - KernelRegistrar(typename Schema::dispatch::dispatch_key_type dispatch_key, KernelFunction* kernel, KernelStateCreatorFunction* state_creator) - : dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { - Dispatcher<OpSchemaDef>::registerKernel(dispatch_key_, kernel, state_creator); + explicit KernelRegistrar(OpHandleGetter *op, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator) + : op_(std::move(op)), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { + Dispatcher::singleton().registerKernel(op_(), dispatch_key_, kernel, cache_creator); } KernelRegistrar(KernelRegistrar&& rhs) - : dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(true) { + : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(true) { rhs.owns_registration_ = false; } @@ -49,17 +55,111 @@ public: ~KernelRegistrar() { if (owns_registration_) { - Dispatcher<OpSchemaDef>::deregisterKernel(dispatch_key_); + Dispatcher::singleton().deregisterKernel(op_(), dispatch_key_); } } private: - const typename Schema::dispatch::dispatch_key_type dispatch_key_; + OpHandleGetter *op_; + const TensorTypeId dispatch_key_; bool owns_registration_; C10_DISABLE_COPY_AND_ASSIGN(KernelRegistrar); }; +namespace detail { +// ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and +// cast it to the type that should be passed to the kernel function. 
+// Examples: If the IValue contains a plain type like an int, return that. +// If the IValue contains an IntList, return it as ArrayRef<int>. +template<class T> +struct ivalue_to_arg_type { + static T call(const IValue& v) { + return std::move(v).to<T>(); + } +}; +template<class T> +struct ivalue_to_arg_type<ArrayRef<T>> { + static ArrayRef<T> call(const IValue& v) { + return v.to<intrusive_ptr<ivalue::List<T>>>()->elements(); + } +}; + +// call_with_ivalue_args: Take a function pointer and an ArrayRef<IValue> +// containing the arguments to call the function pointer with, and call it. +// The extra_args are appended as additional arguments at the end of the function call. +// Example: +// int myfunc(int a, ArrayRef<int> b, string c); +// int main() { +// std::vector<IValue> ivalue_args = {IValue(2), IntList::create(3, 4)}; +// call_with_ivalue_args<decltype(myfunc), &myfunc>(ivalue_args, "extra_arg"); +// } +template<class FuncType, FuncType* func, class... ExtraArgs, size_t... ivalue_arg_indices> +typename guts::function_traits<FuncType>::return_type call_with_ivalue_args_(ArrayRef<IValue> ivalue_args, guts::index_sequence<ivalue_arg_indices...>, ExtraArgs&&... extra_args) { + using IValueArgTypes = typename guts::function_traits<FuncType>::parameter_types; + return (*func)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<ivalue_arg_indices, IValueArgTypes>>>>::call(ivalue_args[ivalue_arg_indices])..., std::forward<ExtraArgs>(extra_args)...); +} + +template<class FuncType, FuncType* func, class... ExtraArgs> +typename guts::function_traits<FuncType>::return_type call_with_ivalue_args(ArrayRef<IValue> ivalue_args, ExtraArgs&&... 
extra_args) { + constexpr size_t num_ivalue_args = guts::function_traits<FuncType>::number_of_parameters - sizeof...(ExtraArgs); + AT_ASSERTM(num_ivalue_args == ivalue_args.size(), "Wrong number of ivalue arguments"); + return call_with_ivalue_args_<FuncType, func>(ivalue_args, guts::make_index_sequence<num_ivalue_args>(), std::forward<ExtraArgs>(extra_args)...); +} + +template<class OutputType> +struct push_outputs final { + static void call(OutputType&& output, Stack* stack) { + push_outputs<std::tuple<OutputType>>(std::tuple<OutputType>(std::move(output)), stack); + } +}; +template<class... OutputTypes> +struct push_outputs<std::tuple<OutputTypes...>> final { + static void call(std::tuple<OutputTypes...>&& output, Stack* stack) { + for (size_t i = 0; i < sizeof...(OutputTypes); ++i) { + torch::jit::push(return_type_to_ivalue(std::move(output))); + } + } +}; + +// SFINAE over (1) does the operator kernel have a cache and (2) does it return a value or void +template<class CacheTypeOrVoid, class FuncType, FuncType* kernel, class Enable = void> struct wrap_kernel {}; +// SFINAE version for kernels with output and with cache +template<class CacheTypeOrVoid, class FuncType, FuncType* kernel> +struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<!std::is_same<void, CacheTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final { + static typename guts::function_traits<FuncType>::return_type call(Stack* stack, KernelCache* cache) { + constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters - 1; // -1 because it takes the kernel cache as last argument + auto output = call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs), static_cast<CacheTypeOrVoid*>(cache)); + push_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), stack); + } +}; +// SFINAE version for kernels with output and without a cache +template<class 
CacheTypeOrVoid, class FuncType, FuncType* kernel> +struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<std::is_same<void, CacheTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final { + static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* /*cache*/) { + constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters; + auto output = call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs)); + push_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), stack); + } +}; +// SFINAE version for kernels without output and with a cache +template<class CacheTypeOrVoid, class FuncType, FuncType* kernel> +struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<!std::is_same<void, CacheTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final { + static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* cache) { + constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters - 1; // -1 because it takes the kernel cache as last argument + call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs), static_cast<CacheTypeOrVoid*>(cache)); + } +}; +// SFINAE version for kernels without output and without a cache +template<class CacheTypeOrVoid, class FuncType, FuncType* kernel> +struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<std::is_same<void, CacheTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final { + static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* /*cache*/) { + constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters; + call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, 
num_inputs)); + } +}; +} + /** * Helper class for building a KernelRegistrar. This permits "keyword-argument" like syntax * when performing operator registration, e.g., as in: @@ -76,52 +176,51 @@ private: * .dispatchKey("bla"); * * The resulting full expression is implicitly convertible to a KernelRegistrar. - * - * @tparam OpSchemaDef The operator schema this is building a KernelRegistration for - * @tparam FieldsPresentFlags Remembers which fields are already set in the builder */ -template<class OpSchemaDef, class StateTypeOrVoid, uint64_t FieldsPresentFlags> +template<class CacheTypeOrVoid, uint64_t FieldsPresentFlags> class KernelRegistrationBuilder final { private: - using Schema = OpSchema<OpSchemaDef>; - static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 0; static constexpr uint64_t KERNEL_PRESENT = 0x01 << 1; - static constexpr uint64_t STATE_PRESENT = 0x01 << 2; + static constexpr uint64_t CACHE_PRESENT = 0x01 << 2; + + using OpHandleGetter = KernelRegistrar::OpHandleGetter; - static std::unique_ptr<c10::KernelState> defaultStateCreator() { + static std::unique_ptr<c10::KernelCache> defaultCacheCreator() { return nullptr; } - template<class State> - static std::unique_ptr<c10::KernelState> stateCreator() { - static_assert(std::is_default_constructible<State>::value, "State class must be default constructible"); - return guts::make_unique<State>(); + template<class Cache> + static std::unique_ptr<c10::KernelCache> cacheCreator() { + static_assert(std::is_default_constructible<Cache>::value, "Cache class must be default constructible"); + return guts::make_unique<Cache>(); } - c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_; + OpHandleGetter *op_; + c10::optional<TensorTypeId> dispatch_key_; KernelFunction* kernel_; - KernelStateCreatorFunction* state_creator_; + KernelCacheCreatorFunction* cache_creator_; public: - constexpr KernelRegistrationBuilder() - : KernelRegistrationBuilder(c10::nullopt, nullptr, 
&defaultStateCreator) {} + constexpr explicit KernelRegistrationBuilder(OpHandleGetter *op) + : KernelRegistrationBuilder(std::move(op), c10::nullopt, nullptr, &defaultCacheCreator) {} - constexpr KernelRegistrationBuilder( - c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key, + constexpr explicit KernelRegistrationBuilder( + OpHandleGetter *op, + c10::optional<TensorTypeId> dispatch_key, KernelFunction* kernel, - KernelStateCreatorFunction* state_creator) - : dispatch_key_(std::move(dispatch_key)), kernel_(kernel), state_creator_(state_creator) {} + KernelCacheCreatorFunction* cache_creator) + : op_(std::move(op)), dispatch_key_(std::move(dispatch_key)), kernel_(kernel), cache_creator_(cache_creator) {} /** - * Implicit coercion to KernelRegistrar<OpSchemaDef> that finalizes the builder and + * Implicit coercion to KernelRegistrar that finalizes the builder and * creates the object. * @return Produced KernelRegistrar */ - operator KernelRegistrar<OpSchemaDef>() && { + operator KernelRegistrar() && { static_assert(FieldsPresentFlags & KERNEL_PRESENT, "Forgot to call .kernel() in kernel registration"); static_assert(FieldsPresentFlags & DISPATCH_KEY_PRESENT, "Forgot to call .dispatchKey() in kernel registration"); - return KernelRegistrar<OpSchemaDef>(std::move(*dispatch_key_), kernel_, state_creator_); + return KernelRegistrar(op_, std::move(*dispatch_key_), kernel_, cache_creator_); } /** @@ -129,9 +228,9 @@ private: * @param dispatch_key dispatch key to register the function to * @return "this" for method chaining */ - constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(typename Schema::dispatch::dispatch_key_type dispatch_key) && { + constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(TensorTypeId dispatch_key) && { static_assert(!(FieldsPresentFlags & DISPATCH_KEY_PRESENT), "Tried to define kernel twice in same op 
registration"); - return KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(dispatch_key), kernel_, state_creator_); + return KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(op_), std::move(dispatch_key), kernel_, cache_creator_); } /** @@ -140,10 +239,10 @@ private: * @return "this" for method chaining */ template<KernelFunction* kernel_func> - constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && { + constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && { static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration"); - // TODO Better error message when kernel function mismatches, one common mismatch is missing state parameter or state parameter present while not expected. - return KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT>(std::move(dispatch_key_), kernel_func, state_creator_); + // TODO Better error message when kernel function mismatches, one common mismatch is missing cache parameter or cache parameter present while not expected. 
+ return KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT>(std::move(op_), std::move(dispatch_key_), kernel_func, cache_creator_); } /** @@ -151,9 +250,10 @@ private: * @param kernel concrete function implementation to be registered * @return "this" for method chaining */ - template<typename Schema::signature::template func_type_with_state<StateTypeOrVoid>* kernel_func> - constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && { - return std::move(*this).template kernel<&Schema::signature::template wrap_kernel<StateTypeOrVoid, kernel_func>>(); + template<class FuncType, FuncType* kernel_func> + constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && { + // TODO Better error message if FuncType is not a func type + return std::move(*this).template kernel<&detail::wrap_kernel<CacheTypeOrVoid, FuncType, kernel_func>::call>(); } /** @@ -161,20 +261,19 @@ private: * @param dispatch_key dispatch key to register the function to * @return "this" for method chaining */ - template<class State> - constexpr KernelRegistrationBuilder<OpSchemaDef, State, FieldsPresentFlags | STATE_PRESENT> withState() && { - static_assert(!(FieldsPresentFlags & STATE_PRESENT), "Tried to define state twice in same op registration"); - static_assert(std::is_base_of<c10::KernelState, State>::value, "State must inherit from c10::KernelState"); + template<class Cache> + constexpr KernelRegistrationBuilder<Cache, FieldsPresentFlags | CACHE_PRESENT> withCache() && { + static_assert(!(FieldsPresentFlags & CACHE_PRESENT), "Tried to define cache twice in same op registration"); + static_assert(std::is_base_of<c10::KernelCache, Cache>::value, "Cache must inherit from c10::KernelCache"); - static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Cannot set the state after the kernel function is already set. 
Please call .withState() first and .kernel() later in the chain."); + static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Cannot set the cache after the kernel function is already set. Please call .withCache() first and .kernel() later in the chain."); - return KernelRegistrationBuilder<OpSchemaDef, State, FieldsPresentFlags | STATE_PRESENT>(std::move(dispatch_key_), kernel_, &stateCreator<State>); + return KernelRegistrationBuilder<Cache, FieldsPresentFlags | CACHE_PRESENT>(std::move(op_), std::move(dispatch_key_), kernel_, &cacheCreator<Cache>); } }; } // namespace c10 -// TODO Can the builder logic be moved to compile time? // NB: Semicolon after applying this macro is MANDATORY -#define C10_REGISTER_KERNEL(OpSchemaDef) \ - static KernelRegistrar<OpSchemaDef> MACRO_CONCAT(__kernelRegistrationBuilder_, __COUNTER__) = KernelRegistrationBuilder<OpSchemaDef, void, 0>() +#define C10_REGISTER_KERNEL(OperatorHandle) \ + static KernelRegistrar MACRO_CONCAT(__kernelRegistrationBuilder_, __COUNTER__) = KernelRegistrationBuilder<void, 0>(OperatorHandle) diff --git a/aten/src/ATen/core/dispatch/LayoutId.cpp b/aten/src/ATen/core/dispatch/LayoutId.cpp deleted file mode 100644 index bee71d8dcc..0000000000 --- a/aten/src/ATen/core/dispatch/LayoutId.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <ATen/core/dispatch/LayoutId.h> diff --git a/aten/src/ATen/core/dispatch/LayoutId.h b/aten/src/ATen/core/dispatch/LayoutId.h deleted file mode 100644 index d0648392f4..0000000000 --- a/aten/src/ATen/core/dispatch/LayoutId.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include <c10/util/IdWrapper.h> - -namespace c10 { - -class LayoutId final : public at::IdWrapper<LayoutId, uint8_t> { -public: - constexpr explicit LayoutId(underlying_type id): IdWrapper(id) {} - - constexpr uint8_t value() const { - return underlyingId(); - } - - // Don't use this default constructor! 
- // Unfortunately, a default constructor needs to be defined because of https://reviews.llvm.org/D41223 - constexpr LayoutId(): IdWrapper(0) {} -}; - -} - -C10_DEFINE_HASH_FOR_IDWRAPPER(c10::LayoutId) diff --git a/aten/src/ATen/core/dispatch/OpSchema.cpp b/aten/src/ATen/core/dispatch/OpSchema.cpp deleted file mode 100644 index 595989569c..0000000000 --- a/aten/src/ATen/core/dispatch/OpSchema.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <ATen/core/dispatch/OpSchema.h> diff --git a/aten/src/ATen/core/dispatch/OpSchema.h b/aten/src/ATen/core/dispatch/OpSchema.h deleted file mode 100644 index d88bd7db87..0000000000 --- a/aten/src/ATen/core/dispatch/OpSchema.h +++ /dev/null @@ -1,406 +0,0 @@ -#pragma once - -#include <ATen/core/dispatch/DispatchKey.h> -#include <ATen/core/ivalue.h> -#include <c10/util/Array.h> -#include <c10/util/Metaprogramming.h> -#include <c10/util/TypeList.h> -#include <c10/core/DeviceType.h> -#include <ATen/core/Tensor.h> -#include <ATen/core/stack.h> - -namespace c10 { - -/** - * A kernel can keep around a cache to have better performance when it's - * called multiple times. This is used by a lot of caffe2 kernels. - * This cache owned by the call site and passed in to the kernel as a function - * argument. It must inherit from KernelState so the call site knows how to - * store and destruct it. - */ -class CAFFE2_API KernelState { -public: - virtual ~KernelState() = default; -}; - -using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace. - -/** - * This is the basic ABI for any kernel call. Each kernel is registered as a - * pointer to a global C function of this type. - */ -using KernelFunction = void(Stack*, KernelState* state); - -namespace details { - -/** - * If Arg is a Tensor or reference to a Tensor, provide the member constant value equal to true. Otherwise - * return false. 
- */ -template <class Arg> -using is_tensor_arg = std:: - is_same<at::Tensor, guts::remove_cv_t<guts::remove_reference_t<Arg>>>; - -inline DeviceTypeId to_device_type_id(DeviceType device_type) { - switch (device_type) { - case DeviceType::CPU: - return DeviceTypeId::CPU; - case DeviceType::CUDA: - return DeviceTypeId::CUDA; - default: - return DeviceTypeId::UNDEFINED; - } -} - -inline TensorParameterDispatchKey tensor_to_dispatch_key(const at::Tensor& tensor) { - return TensorParameterDispatchKey{ - to_device_type_id(tensor.device().type()), - LayoutId(0), - tensor.dtype().id()}; -} - -template<size_t index, size_t offset, class ParameterTypes, class Enable = void> struct get_ith_tensor_arg_ { - static_assert(!std::is_same<ParameterTypes, ParameterTypes>::value, "Index out of bounds"); -}; -template<size_t index, size_t offset, class Head, class... Tail> -struct get_ith_tensor_arg_<index, offset, guts::typelist::typelist<Head, Tail...>, guts::enable_if_t<index == 0 && is_tensor_arg<Head>::value>> { - static at::Tensor call(ArrayRef<IValue> args) { - if (!args[offset].isTensor()) { - throw std::runtime_error("Expected argument " + guts::to_string(offset) + " to be of type Tensor but found different type."); - } - return args[offset].toTensor(); - } -}; -template<size_t index, size_t offset, class Head, class... Tail> -struct get_ith_tensor_arg_<index, offset, guts::typelist::typelist<Head, Tail...>, guts::enable_if_t<index != 0 || !is_tensor_arg<Head>::value>> { - static at::Tensor call(ArrayRef<IValue> args) { - return get_ith_tensor_arg_<(is_tensor_arg<Head>::value ? (index-1) : index), offset + 1, guts::typelist::typelist<Tail...>>::call(args); - } -}; -template<class ParameterTypes, size_t index> at::Tensor get_ith_tensor_arg(ArrayRef<IValue> args) { - return get_ith_tensor_arg_<index, 0, ParameterTypes>::call(args); -} - -// Extract type ids for all tensors from an array of tensors -template<class OpSchemaDef, size_t... 
indices> -guts::array<TensorParameterDispatchKey, OpSchemaDef::num_dispatch_args()> getDispatchTypeIds__(ArrayRef<IValue> args, guts::index_sequence<indices...>) { - using ParameterTypes = typename guts::function_traits<typename OpSchemaDef::Signature>::parameter_types; - return {tensor_to_dispatch_key(get_ith_tensor_arg<ParameterTypes, indices>(args))...}; -} - -/** - * Extract the type ids of all tensors in a variadic list of arguments - * - * @tparam Args Inferred variadic list of argument types - * @param args List of arguments to get type ids from - * @return guts::array<TensorParameterDispatchKey, n>, where n is the number of tensor arguments (is_tensor_arg) in the class - */ -template<class OpSchemaDef> -guts::array<TensorParameterDispatchKey, OpSchemaDef::num_dispatch_args()> getDispatchTypeIds_(ArrayRef<IValue> args) { - return getDispatchTypeIds__<OpSchemaDef>(args, guts::make_index_sequence<OpSchemaDef::num_dispatch_args()>()); -} - -// TODO Test getDispatchTypeIds_ - -/** - * If T is a struct with a type field Signature, provides the member constant - * @tparam T - */ -template<class T, typename = void> -struct has_signature_defined : std::false_type {}; -template<class T> -struct has_signature_defined<T, guts::void_t< - typename T::Signature ->> : std::true_type {}; - -// TODO Test has_signature_defined - -template<class T, typename = void> -struct has_parameter_names_defined : std::false_type {}; -template<class T> -struct has_parameter_names_defined<T, guts::void_t< - decltype(T::parameter_names) ->> : std::true_type {}; - -// TODO Test has_parameter_names_defined - -template<class T, typename = void> -struct has_name_defined : std::false_type {}; -template<class T> -struct has_name_defined<T, guts::void_t< - decltype(T::name) ->> : std::true_type {}; - -// TODO Test has_name_defined - -template<class T> -struct ivalue_to_arg_type { - static T call(const IValue& v) { - return std::move(v).to<T>(); - } -}; -template<class T> -struct 
ivalue_to_arg_type<ArrayRef<T>> { - static ArrayRef<T> call(const IValue& v) { - return v.to<intrusive_ptr<ivalue::List<T>>>()->elements(); - } -}; - -template<class FuncType, class... ExtraArgs, size_t... ivalue_arg_indices> -typename guts::function_traits<FuncType>::return_type call_with_ivalue_args_(FuncType* func, ArrayRef<IValue> ivalue_args, guts::index_sequence<ivalue_arg_indices...>, ExtraArgs&&... extra_args) { - using IValueArgTypes = typename guts::function_traits<FuncType>::parameter_types; - return (*func)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<ivalue_arg_indices, IValueArgTypes>>>>::call(ivalue_args[ivalue_arg_indices])..., std::forward<ExtraArgs>(extra_args)...); -} - -template<class FuncType, class... ExtraArgs> -typename guts::function_traits<FuncType>::return_type call_with_ivalue_args(FuncType* func, ArrayRef<IValue> ivalue_args, ExtraArgs&&... extra_args) { - constexpr size_t num_ivalue_args = guts::function_traits<FuncType>::number_of_parameters - sizeof...(ExtraArgs); - return call_with_ivalue_args_<FuncType>(func, ivalue_args, guts::make_index_sequence<num_ivalue_args>(), std::forward<ExtraArgs>(extra_args)...); -} - -template<class OutputType> -struct write_outputs final { - static void call(OutputType&& output, ArrayRef<IValue> outputs) { - write_outputs<std::tuple<OutputType>>(std::tuple<OutputType>(std::move(output)), outputs); - } -}; -template<class... OutputTypes> -struct write_outputs<std::tuple<OutputTypes...>> final { - static void call(std::tuple<OutputTypes...>&& output, ArrayRef<IValue> outputs) { - AT_ASSERT(outputs.size() == sizeof...(OutputTypes)); // Mismatch in number of returns between kernel function and operator schema. 
- for (size_t i = 0; i < sizeof...(OutputTypes); ++i) { - outputs[i] = return_type_to_ivalue(std::move(output)); - } - } -}; - - -// SFINAE over (1) does the operator kernel have state and (2) does it return a value or void -template<class StateTypeOrVoid, class FuncType, class Enable = void> struct call_kernel_with_ivalue_args {}; -// SFINAE version for kernels with output and with state -template<class StateTypeOrVoid, class FuncType> -struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<!std::is_same<void, StateTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final { - static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* state) { - auto output = call_with_ivalue_args(func, ivalue_args, static_cast<StateTypeOrVoid*>(state)); - write_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), outputs); - } -}; -// SFINAE version for kernels with output and without state -template<class StateTypeOrVoid, class FuncType> -struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<std::is_same<void, StateTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final { - static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* /*state*/) { - auto output = call_with_ivalue_args(func, ivalue_args); - write_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), outputs); - } -}; -// SFINAE version for kernels without output and with state -template<class StateTypeOrVoid, class FuncType> -struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<!std::is_same<void, StateTypeOrVoid>::value && std::is_same<void, typename 
guts::function_traits<FuncType>::return_type>::value>> final { - static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* state) { - call_with_ivalue_args(func, ivalue_args, static_cast<StateTypeOrVoid*>(state)); - } -}; -// SFINAE version for kernels without output and without state -template<class StateTypeOrVoid, class FuncType> -struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<std::is_same<void, StateTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final { - static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* /*state*/) { - call_with_ivalue_args(func, ivalue_args); - } -}; - -template<class FuncType, class AddedParameter, class Enable = void> struct add_ptr_parameter_if_not_void final {}; -template<class Return, class... Parameters, class AddedParameter> -struct add_ptr_parameter_if_not_void<Return(Parameters...), AddedParameter, guts::enable_if_t<!std::is_same<void, AddedParameter>::value>> final { - using type = Return(Parameters..., AddedParameter*); -}; -template<class FuncType> struct add_ptr_parameter_if_not_void<FuncType, void, void> final { - using type = FuncType; -}; - -/** - * Wrapper class around a user-provided schema definition some useful information about the schema. - * - * @tparam OpSchemaDef Operator schema definition. See OpSchema for more details. 
- */ -template<class OpSchemaDef> class OpSignatureSchema final { - static_assert(details::has_signature_defined<OpSchemaDef>::value, "Operator schema doesn't define a valid Signature member type."); - static_assert(guts::is_function_type<typename OpSchemaDef::Signature>::value, "Signature member of operator schema must be a function type."); - - using signature_traits = guts::function_traits<typename OpSchemaDef::Signature>; -public: - /** - * The function type OpSchemaDef::Signature - */ - using func_type = typename signature_traits::func_type; - /** - * The return type of the function OpSchemaDef::Signature - */ - using return_type = typename signature_traits::return_type; - /** - * A type list of the parameter types of OpSchemaDef::Signature - */ - using parameter_types = typename signature_traits::parameter_types; - - /** - * The number of arguments of OpSchemaDef::Signature - */ - static constexpr size_t num_args = guts::typelist::size<parameter_types>::value; - /** - * The number of tensor arguments (as per is_tensor_arg) in OpSchemaDef::Signature - */ - static constexpr size_t num_tensor_args = guts::typelist::count_if<details::is_tensor_arg, parameter_types>::value; - - static constexpr size_t num_outputs = OpSchemaDef::num_outputs(); - - template<class StateTypeOrVoid> using func_type_with_state = typename add_ptr_parameter_if_not_void<func_type, StateTypeOrVoid>::type; - - template<class StateTypeOrVoid, func_type_with_state<StateTypeOrVoid>* kernel> - static void wrap_kernel(Stack* stack, KernelState* state) { - constexpr size_t num_inputs = guts::typelist::size<parameter_types>::value; - constexpr size_t num_outputs = 1; // TODO allow multiple outputs if it's a tuple - - ArrayRef<IValue> inputs = torch::jit::peekSlice(*stack, 0, num_inputs + num_outputs, num_inputs); - ArrayRef<IValue> outputs = torch::jit::peekSlice(*stack, 0, num_outputs, num_outputs); - - call_kernel_with_ivalue_args<StateTypeOrVoid, 
func_type_with_state<StateTypeOrVoid>>::call(kernel, inputs, outputs, state); - } - -private: - static_assert(details::has_parameter_names_defined<OpSchemaDef>::value, "Operator schema doesn't define parameter_names member."); - // TODO Allow simpler definition of parameter_names without having to spell out the guts::array type in the schema def. - static_assert(std::is_same<const guts::array<const char*, num_args>, decltype(OpSchemaDef::parameter_names)>::value, "Operator schema defines parameter_names member, but it isn't the correct type. Must be a static constexpr guts::array of const char* with one entry for each parameter."); - -public: - /** - * The names of the parameters (as per OpSchemaDef::parameter_names) - * @return Array - */ - static constexpr const guts::array<const char*, num_args>& parameter_names() { - return OpSchemaDef::parameter_names; - } -}; - -/** - * If T has a method dispatch_key, provide a member constant value equal to true. Otherwise return false. - * @tparam T - */ -template<class T, typename = void> -struct has_function_dispatch_key_defined : std::false_type {}; -template<class T> -struct has_function_dispatch_key_defined<T, guts::void_t< - decltype(&T::dispatch_key) ->> : std::true_type {}; - -/** - * Wrapper class around a user-defined schema definition providing a way of computing a dispatch key - * from arguments matching the signature of that schema. - * - * @tparam OpSchemaDef Operator schema definition. See OpSchema for more details. - * @tparam Enable Inferred, used to control specialization - */ -template<class OpSchemaDef, class Enable = void> class OpDispatchKeySchema final {}; - -// General case. Operator doesn't overwrite DispatchKey generation. Use default. 
-template<class OpSchemaDef> -class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<!has_function_dispatch_key_defined<OpSchemaDef>::value>> final { - using signature = OpSignatureSchema<OpSchemaDef>; - - // TODO Static assert that dispatch_key_type has operator<<(ostream, _) defined for debug output. - // TODO Use an ADL-based debugString(DispatchKey) function instead of operator<< for debug printing. - -public: - using dispatch_key_type = DispatchKey<OpSchemaDef::num_dispatch_args()>; - - static inline dispatch_key_type dispatch_key(const Stack* stack) { - /* TODO Should we make this a runtime assert now? - using guts::typelist::map_t; - using guts::typelist::typelist; - static_assert(std::is_same< - map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typelist<Args...>>>, - map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>> - >::value, "Invalid argument types passed to OpSchema::dispatch_key()");*/ - return dispatch_key_type { - details::getDispatchTypeIds_<OpSchemaDef>(torch::jit::last(*stack, signature::num_args)) - }; - } -}; - -// Special case. Operator overwrites DispatchKey generation. Use that. -template<class OpSchemaDef> -class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<has_function_dispatch_key_defined<OpSchemaDef>::value>> final { - using signature = OpSignatureSchema<OpSchemaDef>; - - static_assert(guts::is_function_type<decltype(OpSchemaDef::dispatch_key)>::value, "Operator schema defines dispatch_key member, but it isn't a function."); - - using dispatch_key_traits = guts::function_traits<decltype(OpSchemaDef::dispatch_key)>; - -public: - using dispatch_key_type = typename dispatch_key_traits::return_type; - -private: - - static_assert(guts::is_equality_comparable<dispatch_key_type>::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have the equality operator defined. 
Please define it."); - static_assert(guts::is_hashable<dispatch_key_type>::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have an overload for std::hash. Please define it."); - - static_assert(std::is_same< - guts::typelist::typelist<const Stack*>, - typename dispatch_key_traits::parameter_types - >::value, "Operator schema defines custom dispatch_key() derivation function, but it has the wrong signature. Expected to take one argument, which is of type const Stack*."); - -public: - - static inline dispatch_key_type dispatch_key(const Stack* stack) { - /* TODO Should we make this a runtime assert now? - using guts::typelist::map_t; - using guts::typelist::typelist; - static_assert(std::is_same< - map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typelist<Args...>>>, - map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>> - >::value, "Invalid argument types passed to OpSchema::dispatch_key()"); - */ - return OpSchemaDef::dispatch_key(stack); - } -}; - -template<class OpSchemaDef> -class OpMetadataSchema final { -private: - static_assert(has_name_defined<OpSchemaDef>::value, "The operator schema has to define a 'static constexpr const char* name = ...' member to specify the operator name."); - static_assert(std::is_same<const char* const, decltype(OpSchemaDef::name)>::value, "The 'name' member of the operator schema must have type 'static constexpr const char*'"); - -public: - static constexpr const char* name() { - return OpSchemaDef::name; - } -}; - -} // namespace details - -/** - * Wrapper class for user-defined OpSchemaDef, providing functionality for determining - * information about the signature and dispatching on that signature. This is the - * "public" facing class. - * - * @tparam OpSchemaDef User-defined OpSchemaDef. 
- * This struct is expected to define: - * - a function type Signature - * - a constexpr guts<const char*, n_args> parameter_names field (where n_args is - * the number of arguments in Signature) - */ -template <class OpSchemaDef> -class CAFFE2_API OpSchema final { - // TODO static_assert OpSchemaDef isn't an instanciation of OpSchema. If yes, the caller probably passed an OpSchema somewhere where an OpSchemaDef was expected and wants a good error message. -public: - using metadata = details::OpMetadataSchema<OpSchemaDef>; - /** - * Information about the signature - */ - using signature = details::OpSignatureSchema<OpSchemaDef>; - /** - * Functionality for dispatching on that signature - */ - using dispatch = details::OpDispatchKeySchema<OpSchemaDef>; -}; - -// TODO test OpSchema::dispatch stuff -} // namespace c10 diff --git a/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp b/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp deleted file mode 100644 index f153b5e198..0000000000 --- a/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <ATen/core/dispatch/OpSchemaRegistration.h> diff --git a/aten/src/ATen/core/dispatch/OpSchemaRegistration.h b/aten/src/ATen/core/dispatch/OpSchemaRegistration.h index 29ecde2f7b..07af15027d 100644 --- a/aten/src/ATen/core/dispatch/OpSchemaRegistration.h +++ b/aten/src/ATen/core/dispatch/OpSchemaRegistration.h @@ -2,17 +2,38 @@ #include <ATen/core/dispatch/Dispatcher.h> -// TODO Better error message when this definition is missing +namespace c10 { +namespace detail { +class OpSchemaRegistrar final { +public: + explicit OpSchemaRegistrar(FunctionSchema schema) + : opHandle_(c10::Dispatcher::singleton().registerSchema(std::move(schema))) {} + + ~OpSchemaRegistrar() { + c10::Dispatcher::singleton().deregisterSchema(opHandle_); + } + + const c10::OperatorHandle& opHandle() const { + return opHandle_; + } + +private: + c10::OperatorHandle opHandle_; +}; +} // namespace detail +} // namespace 
c10 /** - * Macro for defining an operator schema. Every user-defined OpSchemaDef struct must - * invoke this macro on it. Internally, this arranges for the dispatch table for + * Macro for defining an operator schema. Every operator schema must + * invoke C10_DECLARE_OP_SCHEMA in a header and C10_DEFINE_OP_SCHEMA in one (!) + * cpp file. Internally, this arranges for the dispatch table for * the operator to be created. */ -#define C10_DEFINE_OP_SCHEMA(OpSchemaDef) \ - template<> \ - C10_EXPORT c10::DispatchTable<OpSchemaDef>& c10_dispatch_table<OpSchemaDef>() { \ - static c10::DispatchTable<OpSchemaDef> singleton; \ - return singleton; \ +#define C10_DECLARE_OP_SCHEMA(Name) \ + CAFFE2_API const c10::OperatorHandle& Name(); \ + +#define C10_DEFINE_OP_SCHEMA(Name, Schema) \ + C10_EXPORT const c10::OperatorHandle& Name() { \ + static ::c10::detail::OpSchemaRegistrar registrar(Schema); \ + return registrar.opHandle(); \ } -// TODO Also register unboxed calling API here diff --git a/aten/src/ATen/core/dispatch/OpSchema_test.cpp b/aten/src/ATen/core/dispatch/OpSchema_test.cpp deleted file mode 100644 index 1f5f16a9f4..0000000000 --- a/aten/src/ATen/core/dispatch/OpSchema_test.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include <ATen/core/dispatch/OpSchema.h> -#include <c10/util/Array.h> -#include <ATen/core/Tensor.h> - -using namespace c10; -using at::Tensor; - -static_assert(details::is_tensor_arg<Tensor>::value, ""); -static_assert(details::is_tensor_arg<const Tensor&>::value, ""); -static_assert(details::is_tensor_arg<Tensor&&>::value, ""); -static_assert(!details::is_tensor_arg<int>::value, ""); - -struct SchemaDef final { - using Signature = bool(int, Tensor, float, Tensor, Tensor, unsigned int); - static constexpr guts::array<const char*, 6> parameter_names = {{ - "1", "2", "3", "4", "5", "6" - }}; - static constexpr size_t num_dispatch_args() {return 3;} - static constexpr size_t num_outputs() {return 0;} -}; -static_assert(6 == OpSchema<SchemaDef>::signature::num_args, 
""); -static_assert(3 == OpSchema<SchemaDef>::signature::num_tensor_args, ""); -static_assert(std::is_same<bool, typename OpSchema<SchemaDef>::signature::return_type>::value, ""); -static_assert( - std::is_same< - guts::typelist:: - typelist<int, Tensor, float, Tensor, Tensor, unsigned int>, - typename OpSchema<SchemaDef>::signature::parameter_types>::value, - ""); diff --git a/aten/src/ATen/core/dispatch/README.md b/aten/src/ATen/core/dispatch/README.md new file mode 100644 index 0000000000..1cb2af9056 --- /dev/null +++ b/aten/src/ATen/core/dispatch/README.md @@ -0,0 +1,12 @@ +This folder contains the c10 dispatcher. This dispatcher is a single point +through which we are planning to route all kernel calls. +Existing dispatch mechanisms from legacy PyTorch or caffe2 are planned to +be replaced. + +This folder contains the following files: +- Dispatcher.h: Main facade interface. Code using the dispatcher should only use this. +- DispatchTable.h: Implementation of the actual dispatch mechanism. Hash table with kernels, lookup, ... +- KernelCache.h: An interface operator kernels can use to inherit from if they need to keep around a cache between invocations +- KernelFunction.h: The core interface (i.e. 
function pointer) for calling a kernel +- OpSchemaRegistration.h: The mechanisms to register new operators with the c10 dispatcher +- KernelRegistration.h: The mechanisms to register kernels with the c10 dispatcher diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index bca3f3b46e..df7cdcd3b0 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -5,7 +5,7 @@ #include <ATen/core/functional.h> #include <ATen/core/Type.h> #include <ATen/core/TensorMethods.h> - +#include <c10/util/TypeList.h> #include <caffe2/core/common.h> #include <memory> @@ -233,6 +233,7 @@ struct CAFFE2_API OptionalType: public SingleElementType<TypeKind::OptionalType, ss << "Optional[" << getElementType()->python_str() << "]"; return ss.str(); } + // common cast Optional[Tensor] for undefined tensor type static OptionalTypePtr ofTensor(); private: @@ -982,26 +983,51 @@ CAFFE2_API c10::optional<TypePtr> unifyTypes( const TypePtr& t1, const TypePtr& t2); -template <typename T> -TypePtr getTypePtr() { -#define TYPE_STR(Type) #Type, " ", - AT_ERROR( - "Type ", - c10::demangle_type<T>(), - " could not be converted to any of the known types { ", - C10_FORALL_TYPES(TYPE_STR) "}"); -#undef TYPE_STR -} +namespace detail { +template <typename T> struct getTypePtr_ final { + static_assert(guts::false_t<T>::value, "Type could not be converted to any of the known types."); +}; -template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); } -template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); } -template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); } -template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); } -template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); } -template<> inline TypePtr getTypePtr<std::string>() { return StringType::get(); } -template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); } -template<> 
inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); } -template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); } +template<> struct getTypePtr_<at::Tensor> final { + static TypePtr call() { return DynamicType::get(); } +}; +template<> struct getTypePtr_<double> final { + static TypePtr call() { return FloatType::get(); } +}; +template<> struct getTypePtr_<int64_t> final { + static TypePtr call() { return IntType::get(); } +}; +template<> struct getTypePtr_<bool> final { + static TypePtr call() { return BoolType::get(); } +}; +template<> struct getTypePtr_<at::Scalar> final { + static TypePtr call() { return NumberType::get(); } +}; +template<> struct getTypePtr_<std::string> final { + static TypePtr call() { return StringType::get(); } +}; +template<class T> struct getTypePtr_<std::vector<T>> final { + static TypePtr call() { + static auto type = ListType::create(getTypePtr_<T>::call()); + return type; + } +}; +template<class T> struct getTypePtr_<ArrayRef<T>> final { + static TypePtr call() { + static auto type = ListType::create(getTypePtr_<T>::call()); + return type; + } +}; +template<class T> struct getTypePtr_<at::optional<T>> final { + static TypePtr call() { + static auto type = OptionalType::create(getTypePtr_<T>::call()); + return type; + } +}; +} +template<class T> inline TypePtr getTypePtr() { + return detail::getTypePtr_<T>::call(); +} CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value); CAFFE2_API TypePtr attemptToRecoverType(const IValue& input_ivalue); diff --git a/aten/src/ATen/core/opschema/layer_norm.cpp b/aten/src/ATen/core/opschema/layer_norm.cpp index be908a58fc..d0749d0b14 100644 --- a/aten/src/ATen/core/opschema/layer_norm.cpp +++ b/aten/src/ATen/core/opschema/layer_norm.cpp @@ -1,4 +1,25 @@ #include <ATen/core/opschema/layer_norm.h> #include <ATen/core/dispatch/OpSchemaRegistration.h> -C10_DEFINE_OP_SCHEMA(c10::core::opschema::LayerNorm); +namespace c10 { +namespace 
core { +namespace opschema { + // TODO Parse schema string instead of creating FunctionSchema manually + C10_DEFINE_OP_SCHEMA(LayerNorm, FunctionSchema( + "caffe2::layer_norm_dont_use_this_op_yet", + (std::vector<c10::Argument>{ + c10::Argument("input"), + c10::Argument("axis", IntType::get()), + c10::Argument("epsilon", FloatType::get()), + c10::Argument("output", OptionalType::ofTensor(), c10::nullopt, IValue()), + c10::Argument("output_mean", OptionalType::ofTensor(), c10::nullopt, IValue()), + c10::Argument("output_stdev", OptionalType::ofTensor(), c10::nullopt, IValue()) + }), (std::vector<c10::Argument>{ + c10::Argument("output"), + c10::Argument("mean"), + c10::Argument("stdev") + }) + )); +} +} +} diff --git a/aten/src/ATen/core/opschema/layer_norm.h b/aten/src/ATen/core/opschema/layer_norm.h index 58255d0b96..0ea81ae60b 100644 --- a/aten/src/ATen/core/opschema/layer_norm.h +++ b/aten/src/ATen/core/opschema/layer_norm.h @@ -1,35 +1,12 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include <ATen/core/blob.h> -#include <ATen/core/dispatch/OpSchema.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace c10 { namespace core { namespace opschema { -// TODO This op schema should probably not live in c10 since it's not a method -// on Tensor. It's only here as a proof-of-concept op and for LATTE team -// to be able to call caffe2 layer norm from PyTorch. 
-struct LayerNorm final { - static constexpr const char* name = "LayerNorm"; - - using Signature = std::tuple<at::Tensor, int, float, at::Tensor, at::Tensor, at::Tensor> ( - const at::Tensor& input, - int axis, - float epsilon, - const at::Tensor& output, - const at::Tensor& output_mean, - const at::Tensor& output_stdev); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 3;} - - static constexpr c10::guts::array<const char*, 6> parameter_names = { - {"input", "axis", "epsilon", "output", "output_mean", "output_stdev"}}; -}; +C10_DECLARE_OP_SCHEMA(LayerNorm); } // namespace opschema } // namespace core diff --git a/aten/src/ATen/core/stack.h b/aten/src/ATen/core/stack.h index 32c07a4d53..e51b6b5ef2 100644 --- a/aten/src/ATen/core/stack.h +++ b/aten/src/ATen/core/stack.h @@ -29,6 +29,9 @@ using Operation = std::function<int(Stack&)>; static inline IValue& peek(Stack& stack, size_t i, size_t N) { return *(stack.end() - N + i); } +static inline const IValue& peek(const Stack& stack, size_t i, size_t N) { + return *(stack.end() - N + i); +} // treat the last N elements of the stack as a list, looking up the // slice starting at index i and having length len static inline at::ArrayRef<IValue> peekSlice( diff --git a/c10/test/util/TypeList_test.cpp b/c10/test/util/TypeList_test.cpp index dba1476974..dd811b5771 100644 --- a/c10/test/util/TypeList_test.cpp +++ b/c10/test/util/TypeList_test.cpp @@ -137,5 +137,11 @@ namespace test_map_types_to_values { static_assert(std::is_same<decltype(expected), decltype(result)>::value, ""); EXPECT_EQ(expected, result); } +} +namespace test_find_if { + static_assert(0 == find_if<typelist<char&>, std::is_reference>::value, ""); + static_assert(0 == find_if<typelist<char&, int, char&, int&>, std::is_reference>::value, ""); + static_assert(2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value, ""); + static_assert(3 == find_if<typelist<char, int, char, int&>, 
std::is_reference>::value, ""); } diff --git a/c10/util/TypeList.h b/c10/util/TypeList.h index 4a55e9aaae..bb0d6bd7c6 100644 --- a/c10/util/TypeList.h +++ b/c10/util/TypeList.h @@ -3,11 +3,14 @@ #include <c10/util/C++17.h> #include <c10/util/TypeTraits.h> -namespace c10 { namespace guts { namespace typelist { +namespace c10 { namespace guts { -namespace detail { template<class... T> struct false_t : std::false_type {}; -} +template<template<class> class... T> struct false_higher_t : std::false_type {}; + +namespace typelist { + + /** * Type holding a list of types for compile time type computations @@ -25,7 +28,7 @@ private: * 3 == size<typelist<int, int, double>>::value */ template<class TypeList> struct size final { - static_assert(detail::false_t<TypeList>::value, "In typelist::size<T>, T must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::size<T>, T must be typelist<...>."); }; template<class... Types> struct size<typelist<Types...>> final { static constexpr size_t value = sizeof...(Types); @@ -39,7 +42,7 @@ template<class... Types> struct size<typelist<Types...>> final { * std::tuple<int, string> == to_tuple_t<typelist<int, string>> */ template<class TypeList> struct to_tuple final { - static_assert(detail::false_t<TypeList>::value, "In typelist::to_tuple<T>, T must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::to_tuple<T>, T must be typelist<...>."); }; template<class... Types> struct to_tuple<typelist<Types...>> final { using type = std::tuple<Types...>; @@ -55,7 +58,7 @@ template<class TypeList> using to_tuple_t = typename to_tuple<TypeList>::type; * typelist<int, string> == from_tuple_t<std::tuple<int, string>> */ template<class Tuple> struct from_tuple final { - static_assert(detail::false_t<Tuple>::value, "In typelist::from_tuple<T>, T must be std::tuple<...>."); + static_assert(false_t<Tuple>::value, "In typelist::from_tuple<T>, T must be std::tuple<...>."); }; template<class... 
Types> struct from_tuple<std::tuple<Types...>> final { using type = typelist<Types...>; @@ -70,7 +73,7 @@ template<class Tuple> using from_tuple_t = typename from_tuple<Tuple>::type; * typelist<int, string, int> == concat_t<typelist<int, string>, typelist<int>> */ template<class... TypeLists> struct concat final { - static_assert(detail::false_t<TypeLists...>::value, "In typelist::concat<T1, ...>, the T arguments each must be typelist<...>."); + static_assert(false_t<TypeLists...>::value, "In typelist::concat<T1, ...>, the T arguments each must be typelist<...>."); }; template<class... Head1Types, class... Head2Types, class... TailLists> struct concat<typelist<Head1Types...>, typelist<Head2Types...>, TailLists...> final { @@ -94,7 +97,7 @@ template<class... TypeLists> using concat_t = typename concat<TypeLists...>::typ * typelist<int&, const string&&> == filter_t<std::is_reference, typelist<void, string, int&, bool, const string&&, int>> */ template<template <class> class Condition, class TypeList> struct filter final { - static_assert(detail::false_t<TypeList>::value, "In typelist::filter<Condition, TypeList>, the TypeList argument must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::filter<Condition, TypeList>, the TypeList argument must be typelist<...>."); }; template<template <class> class Condition, class Head, class... 
Tail> struct filter<Condition, typelist<Head, Tail...>> final { @@ -137,7 +140,7 @@ struct count_if final { * false == true_for_each_type<std::is_reference, typelist<int&, const float&&, MyClass>>::value */ template<template <class> class Condition, class TypeList> struct true_for_each_type final { - static_assert(detail::false_t<TypeList>::value, "In typelist::true_for_each_type<Condition, TypeList>, the TypeList argument must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::true_for_each_type<Condition, TypeList>, the TypeList argument must be typelist<...>."); }; template<template <class> class Condition, class... Types> struct true_for_each_type<Condition, typelist<Types...>> final @@ -153,7 +156,7 @@ struct true_for_each_type<Condition, typelist<Types...>> final * typelist<int&, double&, string&> == map_t<std::add_lvalue_reference_t, typelist<int, double, string>> */ template<template <class> class Mapper, class TypeList> struct map final { - static_assert(detail::false_t<TypeList>::value, "In typelist::map<Mapper, TypeList>, the TypeList argument must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::map<Mapper, TypeList>, the TypeList argument must be typelist<...>."); }; template<template <class> class Mapper, class... Types> struct map<Mapper, typelist<Types...>> final { @@ -170,7 +173,7 @@ using map_t = typename map<Mapper, TypeList>::type; * int == head_t<typelist<int, string>> */ template<class TypeList> struct head final { - static_assert(detail::false_t<TypeList>::value, "In typelist::head<T>, the T argument must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::head<T>, the T argument must be typelist<...>."); }; template<class Head, class... Tail> struct head<typelist<Head, Tail...>> final { using type = Head; @@ -185,7 +188,7 @@ template<class TypeList> using head_t = typename head<TypeList>::type; /// Base template. 
template<size_t Index, class TypeList> struct element final { - static_assert(detail::false_t<TypeList>::value, "In typelist::element<T>, the T argument must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::element<T>, the T argument must be typelist<...>."); }; /// Successful case, we have reached the zero index and can "return" the head type. @@ -213,7 +216,7 @@ using element_t = typename element<Index, TypeList>::type; template <class TypeList> struct last final { static_assert( - detail::false_t<TypeList>::value, + false_t<TypeList>::value, "In typelist::last<T>, the T argument must be typelist<...>."); }; template <class Head, class... Tail> @@ -236,7 +239,7 @@ static_assert( * typelist<int, string> == reverse_t<typelist<string, int>> */ template<class TypeList> struct reverse final { - static_assert(detail::false_t<TypeList>::value, "In typelist::reverse<T>, the T argument must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::reverse<T>, the T argument must be typelist<...>."); }; template<class Head, class... Tail> struct reverse<typelist<Head, Tail...>> final { using type = concat_t<typename reverse<typelist<Tail...>>::type, typelist<Head>>; @@ -247,6 +250,28 @@ template<> struct reverse<typelist<>> final { template<class TypeList> using reverse_t = typename reverse<TypeList>::type; +/** + * Find the index of the first type in a typelist fulfilling a type trait condition. 
+ * Example: + * + * 2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value + */ +template<class TypeList, template<class> class Condition, class Enable = void> struct find_if final { + static_assert(false_t<TypeList>::value, "In typelist::find_if<TypeList, Condition>, the TypeList argument must be typelist<...>."); +}; +template<template<class> class Condition> struct find_if<typelist<>, Condition, void> final { + static_assert(false_higher_t<Condition>::value, "In typelist::find_if<Type/List, Condition>, didn't find any type fulfilling the Condition."); +}; +template<class Head, class... Tail, template<class> class Condition> +struct find_if<typelist<Head, Tail...>, Condition, enable_if_t<Condition<Head>::value>> final { + static constexpr size_t value = 0; +}; +template<class Head, class... Tail, template<class> class Condition> +struct find_if<typelist<Head, Tail...>, Condition, enable_if_t<!Condition<Head>::value>> final { + static constexpr size_t value = 1 + find_if<typelist<Tail...>, Condition>::value; +}; + + /** * Maps a list of types into a list of values. @@ -282,7 +307,7 @@ template<class T> struct type_ final { using type = T; }; template<class TypeList> struct map_types_to_values final { - static_assert(detail::false_t<TypeList>::value, "In typelist::map_types_to_values<T>, the T argument must be typelist<...>."); + static_assert(false_t<TypeList>::value, "In typelist::map_types_to_values<T>, the T argument must be typelist<...>."); }; template<class... 
Types> struct map_types_to_values<typelist<Types...>> final { template<class Func> diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 253428645d..a12767c456 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -203,6 +203,16 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> { return XBlobGetMutableTensor(outputs_.at(idx), dims, options); } + void SetOutputTensor(int idx, Tensor tensor) { + // also update the tensor in the hack + if (!isLegacyOperator()) { + output_tensors_[idx] = tensor.UnsafeSharedInstance(); + } + + // update the tensor in the workspace + BlobSetTensor(outputs_.at(idx), std::move(tensor)); + } + inline Tensor* OutputTensor(int idx, at::IntList dims, at::TensorOptions options) { if (isLegacyOperator()) { diff --git a/caffe2/core/operator_c10wrapper.h b/caffe2/core/operator_c10wrapper.h index fb0c458ce9..ab0230d2af 100644 --- a/caffe2/core/operator_c10wrapper.h +++ b/caffe2/core/operator_c10wrapper.h @@ -23,19 +23,15 @@ using extract_type_t = typename ParameterDef::type; * * REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(C10Add, C2MyAddOpName) * - * Note: This wrapper currently only supports C10 ops that have exactly one - * output and take that in the last parameter as "Tensor* output". 
- * TODO: Figure out a better way to handle output parameters */ template < - class OpSchemaDef, + const c10::OperatorHandle& (*OperatorHandle)(), class Context, bool use_array_input, + size_t num_output_parameters, class ParameterDefTuple> class C10OperatorWrapper final : public Operator<Context> { - using Schema = c10::OpSchema<OpSchemaDef>; - public: static_assert( c10::guts::is_instantiation_of<std::tuple, ParameterDefTuple>::value, @@ -49,28 +45,39 @@ class C10OperatorWrapper final : public Operator<Context> { C10OperatorWrapper(const OperatorDef& operator_def, Workspace* ws) : Operator<Context>(operator_def, ws), + op_(OperatorHandle()), kernel_(at::nullopt), parameters_(parse_parameters_( operator_def, - c10::guts::make_index_sequence<num_parameters()>())) {} + c10::guts::make_index_sequence<num_parameters()>())) { - static constexpr size_t num_inputs() { - return Schema::signature::num_args - num_outputs() - num_parameters(); + AT_ASSERT(operator_def.output_size() == op_.schema().returns().size()); + AT_ASSERT(operator_def.input_size() == num_inputs()); } - static constexpr size_t num_parameters() { - return std::tuple_size<ParameterDefTuple>::value; + size_t num_inputs() { + return op_.schema().arguments().size() - num_output_parameters - num_parameters(); } - static constexpr size_t num_outputs() { - return Schema::signature::num_outputs; + static constexpr size_t num_parameters() { + return std::tuple_size<ParameterDefTuple>::value; } bool RunOnDevice() override { - RunOnDevice_( - c10::guts::make_index_sequence<num_inputs()>(), - c10::guts::make_index_sequence<num_outputs()>(), - c10::guts::make_index_sequence<num_parameters()>()); + // due to caching the stack_, concurrent calling is not allowed. 
+ // TODO thread_local might fix this + std::lock_guard<std::mutex> lock(mutex_); + + AT_ASSERT(stack_.size() == 0); + + pushInputs_(); + pushParameters_(guts::make_index_sequence<num_parameters()>()); + pushOutputParameters_(); + + callKernel_(); + + popOutputs_(); + return true; } @@ -91,56 +98,44 @@ class C10OperatorWrapper final : public Operator<Context> { return Parameter::parse(ArgumentHelper(operator_def)); } - template < - size_t... InputIndex, - size_t... OutputIndex, - size_t... ParameterIndex> - c10::guts::enable_if_t< - details::true_t<InputIndex...>::value && - !use_array_input, - void> - RunOnDevice_( - c10::guts::index_sequence<InputIndex...>, - c10::guts::index_sequence<OutputIndex...>, - c10::guts::index_sequence<ParameterIndex...>) { - Stack stack; - torch::jit::push(stack, - IValue(at::Tensor(C10Tensor(Input(InputIndex))))..., - IValue(std::get<ParameterIndex>(parameters_))..., - IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))... - ); - call_(&stack); - // TODO Do we have to Write outputs from stack back into the workspace? + void pushInputs_() { + if (use_array_input) { + stack_.emplace_back(ivalue::TensorList::create(array_inputs_())); + } else { + for (size_t i = 0; i < num_inputs(); ++i) { + stack_.emplace_back(at::Tensor(C10Tensor(Input(i)))); + } + } } - template < - size_t... InputIndex, - size_t... OutputIndex, - size_t... ParameterIndex> - c10::guts::enable_if_t< - details::true_t<InputIndex...>::value && - use_array_input, - void> - RunOnDevice_( - c10::guts::index_sequence<InputIndex...>, - c10::guts::index_sequence<OutputIndex...>, - c10::guts::index_sequence<ParameterIndex...>) { - Stack stack; - torch::jit::push(stack, - IValue(ivalue::TensorList::create(array_inputs_())), - IValue(std::get<ParameterIndex>(parameters_))..., - IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))... - ); - call_(&stack); - // TODO Do we have to Write outputs from stack back into the workspace? + template<size_t... 
ParameterIndex> + void pushParameters_(guts::index_sequence<ParameterIndex...>) { + (void)std::initializer_list<int>{( + stack_.emplace_back(std::get<ParameterIndex>(parameters_)) + , 0)...}; + } + + void pushOutputParameters_() { + for (size_t i = 0; i < num_output_parameters; ++i) { + stack_.emplace_back(at::Tensor(C10Tensor(*Output(i)))); + } } - void call_(Stack* stack) { + void callKernel_() { + AT_ASSERT(stack_.size() == op_.schema().arguments().size()); if (!kernel_.has_value()) { // TODO if kernel is already set, try re-dispatch to assert it goes to the same kernel - kernel_ = c10::Dispatcher<OpSchemaDef>::lookup(stack); + kernel_ = c10::Dispatcher::singleton().lookup(op_, &stack_); } - kernel_->call(stack); + kernel_->call(&stack_); + } + + void popOutputs_() { + AT_ASSERT(stack_.size() == op_.schema().returns().size()); + for (size_t i = 0; i < op_.schema().returns().size(); ++i) { + OperatorBase::SetOutputTensor(i, Tensor(C10Tensor(std::move(stack_[i]).toTensor()))); + } + stack_.clear(); } std::vector<at::Tensor> array_inputs_() { @@ -152,8 +147,15 @@ class C10OperatorWrapper final : public Operator<Context> { return result; } + c10::OperatorHandle op_; c10::optional<OpKernel> kernel_; + // this is stored as a member here to avoid having to re-allocate a stack + // for each call. Between kernel calls, stack_.size() == 0, but capacity + // should not need to be grown anymore after the first call. + std::vector<IValue> stack_; + std::mutex mutex_; + ParameterTuple parameters_; }; @@ -174,39 +176,41 @@ C10_DECLARE_REGISTRY( // TODO Currently we only register the CPU variant. This is going to be fixed // once the tensor detemplatization lands. 
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OpSchemaDef, Name) \ - C10_REGISTER_CLASS( \ - C10OperatorRegistry, \ - Name, \ - C10OperatorWrapper<OpSchemaDef, CPUContext, false, std::tuple<>>) +#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OperatorHandle, Name, NumOutputParameters) \ + C10_REGISTER_CLASS( \ + C10OperatorRegistry, \ + Name, \ + C10OperatorWrapper<OperatorHandle, CPUContext, false, NumOutputParameters, std::tuple<>>) #define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( \ - OpSchemaDef, Name, ...) \ + OperatorHandle, Name, NumOutputParameters, ...) \ C10_REGISTER_CLASS( \ C10OperatorRegistry, \ Name, \ C10OperatorWrapper< \ - OpSchemaDef, \ + OperatorHandle, \ CPUContext, \ false, \ + NumOutputParameters, \ std::tuple<__VA_ARGS__>>) #define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT( \ - OpSchemaDef, Name) \ + OperatorHandle, Name, NumOutputParameters) \ C10_REGISTER_CLASS( \ C10OperatorRegistry, \ Name, \ - C10OperatorWrapper<OpSchemaDef, CPUContext, true, std::tuple<>>) + C10OperatorWrapper<OperatorHandle, CPUContext, true, NumOutputParameters, std::tuple<>>) #define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( \ - OpSchemaDef, Name, ...) \ + OperatorHandle, Name, NumOutputParameters, ...) 
\ C10_REGISTER_CLASS( \ C10OperatorRegistry, \ Name, \ C10OperatorWrapper< \ - OpSchemaDef, \ + OperatorHandle, \ CPUContext, \ true, \ + NumOutputParameters, \ std::tuple<__VA_ARGS__>>) } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/cpu/add_cpu.cc b/caffe2/operators/experimental/c10/cpu/add_cpu.cc index 9b81641731..9ec1aa1a8b 100644 --- a/caffe2/operators/experimental/c10/cpu/add_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/add_cpu.cc @@ -15,7 +15,7 @@ void add_op_cpu_impl( const at::Tensor& B_, const at::Tensor& C_, bool legacy_broadcast, - int axis) { + int64_t axis) { Tensor A{C10Tensor(A_)}; Tensor B{C10Tensor(B_)}; Tensor C{C10Tensor(C_)}; @@ -74,13 +74,6 @@ void add_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Add) - .kernel<&caffe2::add_op_cpu_impl<float>>() - .dispatchKey(c10::DispatchKey<2>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}}); + .kernel<decltype(caffe2::add_op_cpu_impl<float>), &caffe2::add_op_cpu_impl<float>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc index 6cd8cbf94f..cc5823d9e3 100644 --- a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc @@ -10,7 +10,7 @@ using std::vector; namespace caffe2 { namespace { -struct State final : public c10::KernelState { +struct Cache final : public c10::KernelCache { at::Tensor scratch = at::Tensor(C10Tensor(empty({}, CPU))); }; @@ -18,7 +18,7 @@ template <class T, class Context> void averaged_loss_op_cpu_impl( const at::Tensor& X_, const at::Tensor& sum_, - State* state) { + Cache* state) { Tensor X{C10Tensor(X_)}; Tensor sum{C10Tensor(sum_)}; CPUContext context; @@ -48,10 
+48,7 @@ void averaged_loss_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::AveragedLoss) - .withState<caffe2::State>() - .kernel<&caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}}); + .withCache<caffe2::Cache>() + .kernel<decltype(caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc index 3edb5ea841..a34c098d82 100644 --- a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc @@ -53,29 +53,21 @@ void batch_gather_op_cpu_impl( } } } + +void batch_gather_op_cpu(const at::Tensor& data, + const at::Tensor& indices, + const at::Tensor& output) { + switch (data.scalar_type()) { + case ScalarType::Int: return batch_gather_op_cpu_impl<int>(data, indices, output); + case ScalarType::Long: return batch_gather_op_cpu_impl<int64_t>(data, indices, output); + default: throw std::runtime_error(string() + "Unsupported dtype: " + toString(data.scalar_type())); + } +} } // namespace } // namespace caffe2 namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::BatchGather) - .kernel<&caffe2::batch_gather_op_cpu_impl<int64_t>>() - .dispatchKey(c10::DispatchKey<2>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int64_t>()}}); - -C10_REGISTER_KERNEL(caffe2::ops::BatchGather) - .kernel<&caffe2::batch_gather_op_cpu_impl<int32_t>>() - .dispatchKey(c10::DispatchKey<2>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - 
LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int32_t>()}}); + .kernel<decltype(caffe2::batch_gather_op_cpu), &caffe2::batch_gather_op_cpu>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc index 8d977466eb..dfe85237a4 100644 --- a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc @@ -11,7 +11,7 @@ namespace math = caffe2::math; namespace caffe2 { namespace { -struct State final : public c10::KernelState { +struct Cache final : public c10::KernelCache { std::shared_ptr<at::Tensor> scratch; }; @@ -20,10 +20,10 @@ void batch_matmul_op_cpu_impl( const at::Tensor& A_, const at::Tensor& B_, const at::Tensor& Y_, - int trans_a, - int trans_b, - int broadcast, - State* state) { + int64_t trans_a, + int64_t trans_b, + int64_t broadcast, + Cache* cache) { Tensor A{C10Tensor(A_)}; Tensor B{C10Tensor(B_)}; Tensor Y{C10Tensor(Y_)}; @@ -273,14 +273,7 @@ void batch_matmul_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::BatchMatmul) - .withState<caffe2::State>() - .kernel<&caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>() - .dispatchKey(c10::DispatchKey<2>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}}); + .withCache<caffe2::Cache>() + .kernel<decltype(caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc index 62c2a62fca..a2203c5927 100644 
--- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc @@ -74,62 +74,22 @@ void cast_op_cpu_impl( CAFFE_THROW("Unexpected 'to' argument value: ", to); } } +void cast_op_cpu( + const at::Tensor& input, + const at::Tensor& output, + int64_t to) { + switch (input.scalar_type()) { +#define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to); + AT_FORALL_SCALAR_TYPES(CASE) +#undef CASE + default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type())); + } +} } // namespace } // namespace caffe2 namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<float>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}}); -C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<int32_t>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int32_t>()}}); -C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<bool>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<bool>()}}); -C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<uint8_t>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<uint8_t>()}}); -C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<int8_t>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int8_t>()}}); -C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<uint16_t>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ 
- DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<uint16_t>()}}); -C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<int16_t>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int16_t>()}}); -C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<int64_t>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int64_t>()}}); -C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel<&caffe2::cast_op_cpu_impl<double>>() - .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<double>()}}); + .kernel<decltype(caffe2::cast_op_cpu), &caffe2::cast_op_cpu>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc index 858675cfce..931a7d221e 100644 --- a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc @@ -16,8 +16,8 @@ void concat_op_cpu_impl( ArrayRef<at::Tensor> inputs, const at::Tensor& output_, const at::Tensor& split_, - int axis, - int add_axis) { + int64_t axis, + int64_t add_axis) { Tensor output{C10Tensor(output_)}; Tensor split{C10Tensor(split_)}; CPUContext context; @@ -108,6 +108,6 @@ void concat_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Concat) - .kernel<&caffe2::concat_op_cpu_impl<float, CPUContext>>() - .dispatchKey(c10::DeviceTypeId::CPU); + .kernel<decltype(caffe2::concat_op_cpu_impl<float, CPUContext>), &caffe2::concat_op_cpu_impl<float, CPUContext>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc index 1db6ed2c22..2e9f608bd1 100644 
--- a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc @@ -28,8 +28,6 @@ void enforce_finite_op_impl_cpu(const at::Tensor& input_) { namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::EnforceFinite) - .kernel<&caffe2::enforce_finite_op_impl_cpu<float>>() - .dispatchKey({DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}); + .kernel<decltype(caffe2::enforce_finite_op_impl_cpu<float>), &caffe2::enforce_finite_op_impl_cpu<float>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc index 38c968163c..a7195009b8 100644 --- a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc @@ -8,7 +8,7 @@ using caffe2::Tensor; namespace caffe2 { namespace { -struct State final : public c10::KernelState { +struct Cache final : public c10::KernelCache { std::vector<int64_t> dims; bool initialized = false; }; @@ -18,38 +18,38 @@ void expand_dims_op_cpu_impl( const at::Tensor& input_, const at::Tensor& output_, ArrayRef<int64_t> dims, - State* state) { + Cache* cache) { Tensor input{C10Tensor(input_)}; Tensor output{C10Tensor(output_)}; - if (!state->initialized) { - state->dims = dims.vec(); - auto originalSize = state->dims.size(); + if (!cache->initialized) { + cache->dims = dims.vec(); + auto originalSize = cache->dims.size(); CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided."); - std::sort(state->dims.begin(), state->dims.end()); - state->dims.erase( - std::unique(state->dims.begin(), state->dims.end()), state->dims.end()); - if (state->dims.size() < originalSize) { + std::sort(cache->dims.begin(), cache->dims.end()); + cache->dims.erase( + std::unique(cache->dims.begin(), cache->dims.end()), cache->dims.end()); + if (cache->dims.size() < originalSize) { LOG(WARNING) << "Parameter 
`dims` has repeated dimensions."; } CAFFE_ENFORCE( - state->dims.front() >= 0, "Dimension ids must be non-negative."); - state->initialized = true; + cache->dims.front() >= 0, "Dimension ids must be non-negative."); + cache->initialized = true; } output.CopyFrom(input); - if (state->dims.empty()) { + if (cache->dims.empty()) { return; } auto newDims = input.sizes().vec(); CAFFE_ENFORCE_GE( - input.sizes().size() + state->dims.size(), - state->dims.back() + 1, + input.sizes().size() + cache->dims.size(), + cache->dims.back() + 1, "Input needs at least ", - (1 + state->dims.back() - state->dims.size()), + (1 + cache->dims.back() - cache->dims.size()), " dimensions given `dims`."); - for (const auto dim : state->dims) { + for (const auto dim : cache->dims) { newDims.insert(newDims.begin() + dim, 1); } output.Reshape(newDims); @@ -59,9 +59,7 @@ void expand_dims_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::ExpandDims) - .withState<caffe2::State>() - .kernel<&caffe2::expand_dims_op_cpu_impl<float>>() - .dispatchKey({DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}); + .withCache<caffe2::Cache>() + .kernel<decltype(caffe2::expand_dims_op_cpu_impl<float>), &caffe2::expand_dims_op_cpu_impl<float>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc index 18bed5d868..99e05458db 100644 --- a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc @@ -12,7 +12,7 @@ using caffe2::Tensor; namespace caffe2 { namespace { -struct State final : public c10::KernelState { +struct Cache final : public c10::KernelCache { vector<int64_t> Y_shape_cache_; at::Tensor bias_multiplier_ = at::Tensor(C10Tensor(Tensor())); }; @@ -23,9 +23,9 @@ void fc_op_cpu_impl( const at::Tensor& W_, const at::Tensor& b_, const at::Tensor& Y_, - int axis, - int axis_w, - State* state) { + int64_t axis, + int64_t axis_w, + 
Cache* cache) { Tensor X{C10Tensor(X_)}; Tensor W{C10Tensor(W_)}; Tensor b{C10Tensor(b_)}; @@ -68,12 +68,12 @@ void fc_op_cpu_impl( CAFFE_ENFORCE(N == b.dim32(0), dimErrorString()); CAFFE_ENFORCE(N == b.numel(), dimErrorString()); - state->Y_shape_cache_ = X.sizes().vec(); + cache->Y_shape_cache_ = X.sizes().vec(); // This is an invariant of canonical_axis, so we can DCHECK. - DCHECK_LE(canonical_axis + 1, state->Y_shape_cache_.size()); - state->Y_shape_cache_.resize(canonical_axis + 1); - state->Y_shape_cache_[canonical_axis] = N; - Y.Resize(state->Y_shape_cache_); + DCHECK_LE(canonical_axis + 1, cache->Y_shape_cache_.size()); + cache->Y_shape_cache_.resize(canonical_axis + 1); + cache->Y_shape_cache_[canonical_axis] = N; + Y.Resize(cache->Y_shape_cache_); CAFFE_ENFORCE(M * N == Y.numel(), dimErrorString()); if (X.numel() == 0) { @@ -103,7 +103,7 @@ void fc_op_cpu_impl( static_cast<Context*>(&context), math_type); // Add bias term - Tensor bias_multiplier(state->bias_multiplier_); + Tensor bias_multiplier(cache->bias_multiplier_); ReinitializeTensor(&bias_multiplier, {M}, at::dtype<DataType>().device(CPU)); caffe2::math::Set<DataType, Context>( M, @@ -129,17 +129,7 @@ void fc_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::FullyConnected) - .withState<caffe2::State>() - .kernel<&caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>() - .dispatchKey(c10::DispatchKey<3>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}}); + .withCache<caffe2::Cache>() + .kernel<decltype(caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git 
a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc index 376bc924f7..6db668547f 100644 --- a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc @@ -76,8 +76,8 @@ void constant_fill_op_cpu_impl( ArrayRef<int64_t> shape, ArrayRef<int64_t> extra_shape, bool input_as_shape, - int dtype, - c10::IValue value) { + int64_t dtype, + c10::Scalar value) { Tensor output{C10Tensor(output_)}; CPUContext context; @@ -102,12 +102,6 @@ void constant_fill_op_cpu_impl( value.toInt(), output.template mutable_data<int64_t>(), static_cast<CPUContext*>(&context)); - } else if (dtype == caffe2::TensorProto_DataType_BOOL) { - caffe2::math::Set<bool, CPUContext>( - output.numel(), - value.toBool(), - output.template mutable_data<bool>(), - static_cast<CPUContext*>(&context)); } else { throw std::logic_error( "Unimplemented data type for ConstantFill: " + @@ -122,8 +116,8 @@ void uniform_fill_op_cpu_impl( ArrayRef<int64_t> shape, ArrayRef<int64_t> extra_shape, bool input_as_shape, - float min, - float max) { + double min, + double max) { Tensor output{C10Tensor(output_)}; CPUContext context; @@ -154,22 +148,22 @@ void uniform_fill_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::ConstantFill) - .kernel<&caffe2::constant_fill_op_cpu_impl>() - .dispatchKey(c10::DeviceTypeId::CPU); + .kernel<decltype(caffe2::constant_fill_op_cpu_impl), &caffe2::constant_fill_op_cpu_impl>() + .dispatchKey(CPUTensorId()); C10_REGISTER_KERNEL(caffe2::ops::UniformFill) - .kernel<&caffe2::uniform_fill_op_cpu_impl>() - .dispatchKey(c10::DeviceTypeId::CPU); + .kernel<decltype(caffe2::uniform_fill_op_cpu_impl), &caffe2::uniform_fill_op_cpu_impl>() + .dispatchKey(CPUTensorId()); -C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<float>) - .kernel<&caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>() - .dispatchKey(c10::DeviceTypeId::CPU); 
+C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill) + .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>() + .dispatchKey(CPUTensorId()); -C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<int>) - .kernel<&caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>() - .dispatchKey(c10::DeviceTypeId::CPU); +C10_REGISTER_KERNEL(caffe2::ops::GivenTensorIntFill) + .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>() + .dispatchKey(CPUTensorId()); -C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<int64_t>) - .kernel<&caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>() - .dispatchKey(c10::DeviceTypeId::CPU); +C10_REGISTER_KERNEL(caffe2::ops::GivenTensorInt64Fill) + .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc index 26d16235db..23bbaf38dc 100644 --- a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc @@ -12,7 +12,7 @@ template <class DataType, class Context> void flatten_op_cpu_impl( const at::Tensor& input_, const at::Tensor& output_, - int axis) { + int64_t axis) { Tensor input{C10Tensor(input_)}; Tensor output{C10Tensor(output_)}; CPUContext context; @@ -30,8 +30,6 @@ void flatten_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Flatten) - .kernel<&caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>>() - .dispatchKey({DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}); + .kernel<decltype(caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::flatten_op_cpu_impl<float, 
caffe2::CPUContext>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc index 38f7385ddf..247e1bb452 100644 --- a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc @@ -16,7 +16,7 @@ void mul_op_cpu_impl( const at::Tensor& B_, const at::Tensor& C_, bool legacy_broadcast, - int axis) { + int64_t axis) { Tensor A{C10Tensor(A_)}; Tensor B{C10Tensor(B_)}; Tensor C{C10Tensor(C_)}; @@ -75,13 +75,6 @@ void mul_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Mul) - .kernel<&caffe2::mul_op_cpu_impl<float>>() - .dispatchKey(c10::DispatchKey<2>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}}); + .kernel<decltype(caffe2::mul_op_cpu_impl<float>), &caffe2::mul_op_cpu_impl<float>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc index 7c5b8620ef..67c7ee2431 100644 --- a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc @@ -44,8 +44,6 @@ void relu_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Relu) - .kernel<&caffe2::relu_op_cpu_impl<float>>() - .dispatchKey({DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}); + .kernel<decltype(caffe2::relu_op_cpu_impl<float>), &caffe2::relu_op_cpu_impl<float>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc index 78febf5de2..470d6332e9 100644 --- a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc @@ -27,8 
+27,6 @@ void sigmoid_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Sigmoid) - .kernel<&caffe2::sigmoid_op_cpu_impl<float>>() - .dispatchKey({DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}); + .kernel<decltype(caffe2::sigmoid_op_cpu_impl<float>), &caffe2::sigmoid_op_cpu_impl<float>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc index 32853fd82b..4255e2ae31 100644 --- a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc @@ -74,13 +74,6 @@ void sigmoid_cross_entropy_with_logits_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::SigmoidCrossEntropyWithLogits) - .kernel<&caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>() - .dispatchKey(c10::DispatchKey<2>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}}); + .kernel<decltype(caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl), &caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc index 94711168cf..1d8b1802e3 100644 --- a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc @@ -10,7 +10,7 @@ namespace caffe2 { namespace { template <typename InputType, typename IndexType> -void sparse_lengths_sum_op_cpu_impl( +void sparse_lengths_sum_op_cpu_impl_( const at::Tensor& dataInput_, const at::Tensor& indicesInput_, const at::Tensor& 
lengthsInput_, @@ -55,62 +55,37 @@ void sparse_lengths_sum_op_cpu_impl( USE_MEAN, out_data); } + +template<typename IndexType> +void sparse_lengths_sum_op_cpu_impl( + const at::Tensor& dataInput, + const at::Tensor& indicesInput, + const at::Tensor& lengthsInput, + const at::Tensor& output) { + switch (dataInput.scalar_type()) { + case ScalarType::Float: return sparse_lengths_sum_op_cpu_impl_<float, IndexType>(dataInput, indicesInput, lengthsInput, output); + case ScalarType::Half: return sparse_lengths_sum_op_cpu_impl_<at::Half, IndexType>(dataInput, indicesInput, lengthsInput, output); + default: throw std::runtime_error(string() + "Unsupported dtype for input data " + toString(dataInput.scalar_type())); + } +} + +void sparse_lengths_sum_op_cpu( + const at::Tensor& dataInput, + const at::Tensor& indicesInput, + const at::Tensor& lengthsInput, + const at::Tensor& output) { + switch (indicesInput.scalar_type()) { + case ScalarType::Int: return sparse_lengths_sum_op_cpu_impl<int>(dataInput, indicesInput, lengthsInput, output); + case ScalarType::Long: return sparse_lengths_sum_op_cpu_impl<int64_t>(dataInput, indicesInput, lengthsInput, output); + default: throw std::runtime_error(string() + "Unsupported dtype for input indices " + toString(indicesInput.scalar_type())); + } +} + } // namespace } // namespace caffe2 namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) - .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<float, int32_t>>() - .dispatchKey(c10::DispatchKey<3>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int32_t>()}, - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int>()}}); -C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) - .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int32_t>>() - 
.dispatchKey(c10::DispatchKey<3>{ - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<at::Half>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int32_t>()}, - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int>()}}); -C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) - .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<float, int64_t>>() - .dispatchKey(c10::DispatchKey<3>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int64_t>()}, - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int>()}}); -C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) - .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int64_t>>() - .dispatchKey(c10::DispatchKey<3>{ - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<at::Half>()}, - c10::details::TensorParameterDispatchKey{ - DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int64_t>()}, - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<int>()}}); + .kernel<decltype(caffe2::sparse_lengths_sum_op_cpu), &caffe2::sparse_lengths_sum_op_cpu>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc index 1165978f9e..4d3a9812ce 100644 --- a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc @@ -23,8 +23,6 @@ void stop_gradient_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::StopGradient) - .kernel<&caffe2::stop_gradient_op_cpu_impl<float>>() - 
.dispatchKey({DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}); + .kernel<decltype(caffe2::stop_gradient_op_cpu_impl<float>), &caffe2::stop_gradient_op_cpu_impl<float>>() + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/schemas/add.cc b/caffe2/operators/experimental/c10/schemas/add.cc index cfb778df12..88906b7a52 100644 --- a/caffe2/operators/experimental/c10/schemas/add.cc +++ b/caffe2/operators/experimental/c10/schemas/add.cc @@ -1,10 +1,24 @@ #include "caffe2/operators/experimental/c10/schemas/add.h" -#include <ATen/core/dispatch/OpSchemaRegistration.h> #include "caffe2/core/operator_c10wrapper.h" using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::Add); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(Add, FunctionSchema( + "_c10_experimental::Add", + (std::vector<c10::Argument>{ + c10::Argument("input1"), + c10::Argument("input2"), + c10::Argument("output"), + c10::Argument("legacy_broadcast", BoolType::get()), + c10::Argument("axis", IntType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { @@ -32,6 +46,7 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::Add, C10Add_DontUseThisOpYet, + 1, ParameterHelper<LegacyBroadcastParameter>, ParameterHelper<AxisParameter>) } diff --git a/caffe2/operators/experimental/c10/schemas/add.h b/caffe2/operators/experimental/c10/schemas/add.h index fba907ab31..4dfa99ae77 100644 --- a/caffe2/operators/experimental/c10/schemas/add.h +++ b/caffe2/operators/experimental/c10/schemas/add.h @@ -1,29 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include "caffe2/core/context_base.h" +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct Add final { - static constexpr const char* name = "add"; - - using Signature = void( - const at::Tensor& input1, 
- const at::Tensor& input2, - const at::Tensor& output, - bool legacy_broadcast, - int axis); - - static constexpr size_t num_dispatch_args() {return 2;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 5> parameter_names = { - {"input1", "input2", "output", "legacy_broadcast", "axis"}}; -}; +C10_DECLARE_OP_SCHEMA(Add); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/averaged_loss.cc b/caffe2/operators/experimental/c10/schemas/averaged_loss.cc index c276f84a65..c40752c9f2 100644 --- a/caffe2/operators/experimental/c10/schemas/averaged_loss.cc +++ b/caffe2/operators/experimental/c10/schemas/averaged_loss.cc @@ -4,11 +4,25 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::AveragedLoss); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(AveragedLoss, FunctionSchema( + "_c10_experimental::AveragedLoss", + (std::vector<c10::Argument>{ + c10::Argument("input"), + c10::Argument("output") + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::AveragedLoss, - C10AveragedLoss_DontUseThisOpYet) + C10AveragedLoss_DontUseThisOpYet, + 1 + ) } diff --git a/caffe2/operators/experimental/c10/schemas/averaged_loss.h b/caffe2/operators/experimental/c10/schemas/averaged_loss.h index 37e6906dbe..548bd075f2 100644 --- a/caffe2/operators/experimental/c10/schemas/averaged_loss.h +++ b/caffe2/operators/experimental/c10/schemas/averaged_loss.h @@ -1,29 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include "caffe2/core/context_base.h" -#include "caffe2/core/tensor.h" -#include <ATen/core/blob.h> -#include <ATen/core/dispatch/OpSchema.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct AveragedLoss final { - static constexpr const 
char* name = "averaged_loss"; - - using Signature = void( - const at::Tensor& input, - const at::Tensor& output); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 2> parameter_names = { - {"input", "output"}}; -}; +C10_DECLARE_OP_SCHEMA(AveragedLoss); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/batch_gather.cc b/caffe2/operators/experimental/c10/schemas/batch_gather.cc index 5a45dd1bbc..a659ee90c3 100644 --- a/caffe2/operators/experimental/c10/schemas/batch_gather.cc +++ b/caffe2/operators/experimental/c10/schemas/batch_gather.cc @@ -4,10 +4,25 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::BatchGather); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(BatchGather, FunctionSchema( + "_c10_experimental::BatchGather", + (std::vector<c10::Argument>{ + c10::Argument("data"), + c10::Argument("indices"), + c10::Argument("output") + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH( ops::BatchGather, - C10BatchGather_DontUseThisOpYet) + C10BatchGather_DontUseThisOpYet, + 1 + ) } diff --git a/caffe2/operators/experimental/c10/schemas/batch_gather.h b/caffe2/operators/experimental/c10/schemas/batch_gather.h index d745efa35f..214c67ff99 100644 --- a/caffe2/operators/experimental/c10/schemas/batch_gather.h +++ b/caffe2/operators/experimental/c10/schemas/batch_gather.h @@ -1,27 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include "caffe2/core/context_base.h" +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct BatchGather final { - static constexpr const char* name = "batch_gather"; - - using Signature = void( - const at::Tensor& data, - const at::Tensor& indices, - const 
at::Tensor& output); - - static constexpr size_t num_dispatch_args() {return 2;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 3> parameter_names = { - {"data", "indices", "output"}}; -}; +C10_DECLARE_OP_SCHEMA(BatchGather); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/batch_matmul.cc b/caffe2/operators/experimental/c10/schemas/batch_matmul.cc index 682609b7ea..70f8ba8194 100644 --- a/caffe2/operators/experimental/c10/schemas/batch_matmul.cc +++ b/caffe2/operators/experimental/c10/schemas/batch_matmul.cc @@ -4,7 +4,23 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::BatchMatmul); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(BatchMatmul, FunctionSchema( + "_c10_experimental::BatchMatmul", + (std::vector<c10::Argument>{ + c10::Argument("A"), + c10::Argument("B"), + c10::Argument("output"), + c10::Argument("trans_a", IntType::get()), + c10::Argument("trans_b", IntType::get()), + c10::Argument("broadcast", IntType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { struct TransAParameter final { @@ -41,6 +57,7 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::BatchMatmul, C10BatchMatMul_DontUseThisOpYet, + 1, ParameterHelper<TransAParameter>, ParameterHelper<TransBParameter>, ParameterHelper<BroadcastParameter>) diff --git a/caffe2/operators/experimental/c10/schemas/batch_matmul.h b/caffe2/operators/experimental/c10/schemas/batch_matmul.h index 2d815869f8..191e0e6a57 100644 --- a/caffe2/operators/experimental/c10/schemas/batch_matmul.h +++ b/caffe2/operators/experimental/c10/schemas/batch_matmul.h @@ -1,37 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include "caffe2/core/context_base.h" -#include <ATen/core/blob.h> -#include <ATen/core/dispatch/OpSchema.h> +#include 
<ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct BatchMatmul final { - static constexpr const char* name = "batch_matmul"; - - using Signature = void( - const at::Tensor& A, - const at::Tensor& B, - const at::Tensor& output, - int trans_a, - int trans_b, - int broadcast); - - static constexpr size_t num_dispatch_args() {return 2;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 6> parameter_names = { - {"A", - "B", - "output", - "trans_a", - "trans_b", - "broadcast"}}; -}; +C10_DECLARE_OP_SCHEMA(BatchMatmul); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/cast.cc b/caffe2/operators/experimental/c10/schemas/cast.cc index d5615a6b76..6fce08764b 100644 --- a/caffe2/operators/experimental/c10/schemas/cast.cc +++ b/caffe2/operators/experimental/c10/schemas/cast.cc @@ -5,7 +5,20 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::Cast); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(Cast, FunctionSchema( + "_c10_experimental::Cast", + (std::vector<c10::Argument>{ + c10::Argument("input"), + c10::Argument("output"), + c10::Argument("to_dtype", IntType::get()), + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { @@ -22,5 +35,6 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::Cast, C10Cast_DontUseThisOpYet, + 1, ToParameter) } diff --git a/caffe2/operators/experimental/c10/schemas/cast.h b/caffe2/operators/experimental/c10/schemas/cast.h index 095348b768..979637b251 100644 --- a/caffe2/operators/experimental/c10/schemas/cast.h +++ b/caffe2/operators/experimental/c10/schemas/cast.h @@ -1,27 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include "caffe2/core/context_base.h" +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace 
ops { -struct Cast final { - static constexpr const char* name = "cast"; - - using Signature = void( - const at::Tensor& input1, - const at::Tensor& output, - int64_t to_dtype); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 3> parameter_names = { - {"input", "output", "to"}}; -}; +C10_DECLARE_OP_SCHEMA(Cast); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/concat.cc b/caffe2/operators/experimental/c10/schemas/concat.cc index 75c39b4ba7..dddd2a4e75 100644 --- a/caffe2/operators/experimental/c10/schemas/concat.cc +++ b/caffe2/operators/experimental/c10/schemas/concat.cc @@ -4,7 +4,22 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::Concat); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(Concat, FunctionSchema( + "_c10_experimental::Concat", + (std::vector<c10::Argument>{ + c10::Argument("inputs", ListType::ofTensors()), + c10::Argument("output"), + c10::Argument("split_info", FloatType::get()), + c10::Argument("add", IntType::get()), + c10::Argument("add_axis", IntType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { struct AxisParameter final { @@ -31,6 +46,7 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( ops::Concat, C10Concat_DontUseThisOpYet, + 2, ParameterHelper<AxisParameter>, ParameterHelper<AddAxisParameter>) } diff --git a/caffe2/operators/experimental/c10/schemas/concat.h b/caffe2/operators/experimental/c10/schemas/concat.h index 2309de2d82..aecaf4057a 100644 --- a/caffe2/operators/experimental/c10/schemas/concat.h +++ b/caffe2/operators/experimental/c10/schemas/concat.h @@ -1,36 +1,11 @@ #pragma once -#include <ATen/core/dispatch/OpSchema.h> -#include <ATen/core/dispatch/DeviceId.h> -#include <ATen/core/Tensor.h> -#include 
<c10/util/Array.h> -#include <c10/util/ArrayRef.h> -#include "caffe2/core/context_base.h" -#include <ATen/core/ivalue.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct Concat final { - static constexpr const char* name = "concat"; - - using Signature = void( - ArrayRef<at::Tensor> inputs, - const at::Tensor& output, - const at::Tensor& split_info, - int add, - int add_axis); - - static constexpr size_t num_outputs() {return 2;} - - static constexpr c10::guts::array<const char*, 5> parameter_names = { - {"inputs", "output", "split_info_output", "add", "add_axis"}}; - - static c10::DeviceTypeId dispatch_key( - const Stack* arguments) { - return c10::DeviceTypeId::CPU; - } -}; +C10_DECLARE_OP_SCHEMA(Concat); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/enforce_finite.cc b/caffe2/operators/experimental/c10/schemas/enforce_finite.cc index 82d5f18755..170bc2592a 100644 --- a/caffe2/operators/experimental/c10/schemas/enforce_finite.cc +++ b/caffe2/operators/experimental/c10/schemas/enforce_finite.cc @@ -4,10 +4,22 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::EnforceFinite); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(EnforceFinite, FunctionSchema( + "_c10_experimental::EnforceFinite", + (std::vector<c10::Argument>{ + c10::Argument("input") + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH( ops::EnforceFinite, - C10EnforceFinite_DontUseThisOpYet) + C10EnforceFinite_DontUseThisOpYet, + 0) } diff --git a/caffe2/operators/experimental/c10/schemas/enforce_finite.h b/caffe2/operators/experimental/c10/schemas/enforce_finite.h index f811e2b6d8..704136c1f6 100644 --- a/caffe2/operators/experimental/c10/schemas/enforce_finite.h +++ b/caffe2/operators/experimental/c10/schemas/enforce_finite.h @@ -1,23 +1,11 @@ #pragma 
once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct EnforceFinite final { - static constexpr const char* name = "enforce_finite"; - - using Signature = void(const at::Tensor& input); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 0;} - - static constexpr c10::guts::array<const char*, 1> parameter_names = { - {"input"}}; -}; +C10_DECLARE_OP_SCHEMA(EnforceFinite); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/expand_dims.cc b/caffe2/operators/experimental/c10/schemas/expand_dims.cc index 2145a1c1f9..939edd7263 100644 --- a/caffe2/operators/experimental/c10/schemas/expand_dims.cc +++ b/caffe2/operators/experimental/c10/schemas/expand_dims.cc @@ -6,7 +6,20 @@ using caffe2::CPUContext; using c10::intrusive_ptr; using c10::ivalue::IntList; -C10_DEFINE_OP_SCHEMA(caffe2::ops::ExpandDims); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(ExpandDims, FunctionSchema( + "_c10_experimental::ExpandDims", + (std::vector<c10::Argument>{ + c10::Argument("input"), + c10::Argument("output"), + c10::Argument("dims", ListType::ofInts()) + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { struct DimsParameter final { @@ -22,5 +35,6 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::ExpandDims, C10ExpandDims_DontUseThisOpYet, + 1, DimsParameter) } diff --git a/caffe2/operators/experimental/c10/schemas/expand_dims.h b/caffe2/operators/experimental/c10/schemas/expand_dims.h index f30c7dd2eb..fa3ab8f99f 100644 --- a/caffe2/operators/experimental/c10/schemas/expand_dims.h +++ b/caffe2/operators/experimental/c10/schemas/expand_dims.h @@ -1,30 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include 
"caffe2/core/context_base.h" -#include <ATen/core/ivalue.h> -#include <ATen/core/blob.h> -#include <ATen/core/dispatch/OpSchema.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct ExpandDims final { - static constexpr const char* name = "expand_dims"; - - using Signature = void( - const at::Tensor& input, - const at::Tensor& output, - ArrayRef<int64_t> dims); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 3> parameter_names = { - {"input", "output", "dims"}}; -}; +C10_DECLARE_OP_SCHEMA(ExpandDims); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/fc.cc b/caffe2/operators/experimental/c10/schemas/fc.cc index 4a1cdb41ff..29784a587c 100644 --- a/caffe2/operators/experimental/c10/schemas/fc.cc +++ b/caffe2/operators/experimental/c10/schemas/fc.cc @@ -4,7 +4,23 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::FullyConnected); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(FullyConnected, FunctionSchema( + "_c10_experimental::FullyConnected", + (std::vector<c10::Argument>{ + c10::Argument("X"), + c10::Argument("W"), + c10::Argument("b"), + c10::Argument("output"), + c10::Argument("axis", IntType::get()), + c10::Argument("axis_w", IntType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { struct AxisParameter final { @@ -32,6 +48,7 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::FullyConnected, C10FC_DontUseThisOpYet, + 1, ParameterHelper<AxisParameter>, ParameterHelper<AxisWParameter>) } diff --git a/caffe2/operators/experimental/c10/schemas/fc.h b/caffe2/operators/experimental/c10/schemas/fc.h index 5b2a30fa2c..1aed0eb311 100644 --- a/caffe2/operators/experimental/c10/schemas/fc.h +++ 
b/caffe2/operators/experimental/c10/schemas/fc.h @@ -1,32 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include "caffe2/core/tensor.h" -#include <ATen/core/blob.h> -#include <ATen/core/dispatch/OpSchema.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct FullyConnected final { - static constexpr const char* name = "FC"; - - using Signature = void( - const at::Tensor& X, - const at::Tensor& W, - const at::Tensor& b, - const at::Tensor& output, - int axis, - int axis_w); - - static constexpr size_t num_dispatch_args() {return 3;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 6> parameter_names = { - {"X", "W", "b", "output", "axis", "axis_w"}}; -}; +C10_DECLARE_OP_SCHEMA(FullyConnected); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/filler.cc b/caffe2/operators/experimental/c10/schemas/filler.cc index 8dfd0655fd..3d7de21f6b 100644 --- a/caffe2/operators/experimental/c10/schemas/filler.cc +++ b/caffe2/operators/experimental/c10/schemas/filler.cc @@ -8,12 +8,73 @@ using c10::C10Tensor; using c10::ivalue::IntList; using c10::intrusive_ptr; -C10_DEFINE_OP_SCHEMA(caffe2::ops::ConstantFill); -C10_DEFINE_OP_SCHEMA(caffe2::ops::UniformFill); - -C10_DEFINE_OP_SCHEMA(caffe2::ops::GivenTensorFill<float>); -C10_DEFINE_OP_SCHEMA(caffe2::ops::GivenTensorFill<int>); -C10_DEFINE_OP_SCHEMA(caffe2::ops::GivenTensorFill<int64_t>); +namespace caffe2 { +namespace ops { +// TODO Parse schema strings instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(ConstantFill, FunctionSchema( + "_c10_experimental::ConstantFill", + (std::vector<c10::Argument>{ + c10::Argument("inputs", ListType::ofTensors()), + c10::Argument("output"), + c10::Argument("shape", ListType::ofInts()), + c10::Argument("extra_shape", ListType::ofInts()), + c10::Argument("input_as_shape", BoolType::get()), + 
c10::Argument("dtype", IntType::get()), + c10::Argument("value", NumberType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +C10_DEFINE_OP_SCHEMA(UniformFill, FunctionSchema( + "_c10_experimental::ConstantFill", + (std::vector<c10::Argument>{ + c10::Argument("inputs", ListType::ofTensors()), + c10::Argument("output"), + c10::Argument("shape", ListType::ofInts()), + c10::Argument("extra_shape", ListType::ofInts()), + c10::Argument("input_as_shape", BoolType::get()), + c10::Argument("min", FloatType::get()), + c10::Argument("max", FloatType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +C10_DEFINE_OP_SCHEMA(GivenTensorFill, FunctionSchema( + "_c10_experimental::ConstantFill", + (std::vector<c10::Argument>{ + c10::Argument("inputs", ListType::ofTensors()), + c10::Argument("output"), + c10::Argument("shape", ListType::ofInts()), + c10::Argument("extra_shape", ListType::ofInts()), + c10::Argument("input_as_shape", BoolType::get()), + c10::Argument("values"), + }), (std::vector<c10::Argument>{ + }) +)); +C10_DEFINE_OP_SCHEMA(GivenTensorIntFill, FunctionSchema( + "_c10_experimental::ConstantFill", + (std::vector<c10::Argument>{ + c10::Argument("inputs", ListType::ofTensors()), + c10::Argument("output"), + c10::Argument("shape", ListType::ofInts()), + c10::Argument("extra_shape", ListType::ofInts()), + c10::Argument("input_as_shape", BoolType::get()), + c10::Argument("values"), + }), (std::vector<c10::Argument>{ + }) +)); +C10_DEFINE_OP_SCHEMA(GivenTensorInt64Fill, FunctionSchema( + "_c10_experimental::ConstantFill", + (std::vector<c10::Argument>{ + c10::Argument("inputs", ListType::ofTensors()), + c10::Argument("output"), + c10::Argument("shape", ListType::ofInts()), + c10::Argument("extra_shape", ListType::ofInts()), + c10::Argument("input_as_shape", BoolType::get()), + c10::Argument("values"), + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { struct ShapeParameter final { @@ -136,6 +197,7 @@ namespace caffe2 { 
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( ops::ConstantFill, C10ConstantFill_DontUseThisOpYet, + 1, ShapeParameter, ExtraShapeParameter, InputAsShapeParameter, @@ -144,6 +206,7 @@ REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( ops::UniformFill, C10UniformFill_DontUseThisOpYet, + 1, ShapeParameter, ExtraShapeParameter, InputAsShapeParameter, @@ -151,22 +214,25 @@ REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( MaxParameter) REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( - ops::GivenTensorFill<float>, + ops::GivenTensorFill, C10GivenTensorFill_DontUseThisOpYet, + 1, ShapeParameter, ExtraShapeParameter, InputAsShapeParameter, ValuesParameter<float>) REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( - ops::GivenTensorFill<int>, + ops::GivenTensorIntFill, C10GivenTensorIntFill_DontUseThisOpYet, + 1, ShapeParameter, ExtraShapeParameter, InputAsShapeParameter, ValuesParameter<int>) REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( - ops::GivenTensorFill<int64_t>, + ops::GivenTensorInt64Fill, C10GivenTensorInt64Fill_DontUseThisOpYet, + 1, ShapeParameter, ExtraShapeParameter, InputAsShapeParameter, diff --git a/caffe2/operators/experimental/c10/schemas/filler.h b/caffe2/operators/experimental/c10/schemas/filler.h index ef0b04662e..616893d25a 100644 --- a/caffe2/operators/experimental/c10/schemas/filler.h +++ b/caffe2/operators/experimental/c10/schemas/filler.h @@ -1,105 +1,15 @@ #pragma once -#include <ATen/core/dispatch/OpSchema.h> -#include <ATen/core/dispatch/DeviceId.h> -#include <ATen/core/Tensor.h> -#include <ATen/core/ivalue.h> -#include <c10/util/Array.h> -#include <c10/util/ArrayRef.h> -#include "caffe2/core/context_base.h" +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -// 
GivenTensorFill -// GivenTensorInt64Fill -// GivenTensorIntFill - -template <class T> -struct GivenTensorFill final { - static constexpr const char* name = "given_tensor_fill"; - - using Signature = void( - ArrayRef<at::Tensor> inputs, - const at::Tensor& output, - ArrayRef<int64_t> shape, - ArrayRef<int64_t> extra_shape, - bool input_as_shape, - const at::Tensor& values); - - static constexpr c10::guts::array<const char*, 6> parameter_names = { - {"inputs", - "output", - "shape", - "extra_shape", - "input_as_shape", - "values"}}; - - static constexpr size_t num_outputs() {return 1;} - - static c10::DeviceTypeId dispatch_key( - const Stack* stack) { - return c10::DeviceTypeId::CPU; - } -}; - -struct ConstantFill final { - static constexpr const char* name = "constant_fill"; - - using Signature = void( - ArrayRef<at::Tensor> inputs, - const at::Tensor& output, - ArrayRef<int64_t> shape, - ArrayRef<int64_t> extra_shape, - bool input_as_shape, - int dtype, - IValue value); - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 7> parameter_names = { - {"inputs", - "output", - "shape", - "extra_shape", - "input_as_shape", - "dtype", - "value"}}; - - static c10::DeviceTypeId dispatch_key( - const Stack* stack) { - return c10::DeviceTypeId::CPU; - } -}; - -struct UniformFill final { - static constexpr const char* name = "uniform_fill"; - - using Signature = void( - ArrayRef<at::Tensor> inputs, - const at::Tensor& output, - ArrayRef<int64_t> shape, - ArrayRef<int64_t> extra_shape, - bool input_as_shape, - float min, - float max); - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 7> parameter_names = { - {"inputs", - "output", - "shape", - "extra_shape", - "input_as_shape", - "min", - "max"}}; - - static c10::DeviceTypeId dispatch_key( - const Stack* stack) { - return c10::DeviceTypeId::CPU; - } -}; +C10_DECLARE_OP_SCHEMA(GivenTensorFill); 
+C10_DECLARE_OP_SCHEMA(GivenTensorIntFill); +C10_DECLARE_OP_SCHEMA(GivenTensorInt64Fill); +C10_DECLARE_OP_SCHEMA(ConstantFill); +C10_DECLARE_OP_SCHEMA(UniformFill); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/flatten.cc b/caffe2/operators/experimental/c10/schemas/flatten.cc index 8a14e58d07..c1918654c6 100644 --- a/caffe2/operators/experimental/c10/schemas/flatten.cc +++ b/caffe2/operators/experimental/c10/schemas/flatten.cc @@ -4,7 +4,20 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::Flatten); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(Flatten, FunctionSchema( + "_c10_experimental::Flatten", + (std::vector<c10::Argument>{ + c10::Argument("input"), + c10::Argument("output"), + c10::Argument("axis", IntType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { struct AxisParameter final { @@ -22,5 +35,6 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::Flatten, C10Flatten_DontUseThisOpYet, + 1, ParameterHelper<AxisParameter>) } diff --git a/caffe2/operators/experimental/c10/schemas/flatten.h b/caffe2/operators/experimental/c10/schemas/flatten.h index 0ee1773cb6..9c53462528 100644 --- a/caffe2/operators/experimental/c10/schemas/flatten.h +++ b/caffe2/operators/experimental/c10/schemas/flatten.h @@ -1,27 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include "caffe2/core/context_base.h" +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct Flatten final { - static constexpr const char* name = "flatten"; - - using Signature = void( - const at::Tensor& input, - const at::Tensor& output, - int axis); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 3> parameter_names = { 
- {"input", "output", "axis"}}; -}; +C10_DECLARE_OP_SCHEMA(Flatten); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/layer_norm.cc b/caffe2/operators/experimental/c10/schemas/layer_norm.cc index e9dddfb2db..4fc6b18999 100644 --- a/caffe2/operators/experimental/c10/schemas/layer_norm.cc +++ b/caffe2/operators/experimental/c10/schemas/layer_norm.cc @@ -28,6 +28,7 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( c10::core::opschema::LayerNorm, C10LayerNorm_DontUseThisOpYet, + 3, ParameterHelper<AxisParameter>, ParameterHelper<EpsilonParameter>) } diff --git a/caffe2/operators/experimental/c10/schemas/mul.cc b/caffe2/operators/experimental/c10/schemas/mul.cc index 9111dfda29..723aeab92b 100644 --- a/caffe2/operators/experimental/c10/schemas/mul.cc +++ b/caffe2/operators/experimental/c10/schemas/mul.cc @@ -4,7 +4,22 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::Mul); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(Mul, FunctionSchema( + "_c10_experimental::Mul", + (std::vector<c10::Argument>{ + c10::Argument("input1"), + c10::Argument("input2"), + c10::Argument("output"), + c10::Argument("legacy_broadcast", BoolType::get()), + c10::Argument("axis", IntType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { @@ -32,6 +47,7 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::Mul, C10Mul_DontUseThisOpYet, + 1, ParameterHelper<LegacyBroadcastParameter>, ParameterHelper<AxisParameter>) } diff --git a/caffe2/operators/experimental/c10/schemas/mul.h b/caffe2/operators/experimental/c10/schemas/mul.h index 12a1780330..54b64f4f74 100644 --- a/caffe2/operators/experimental/c10/schemas/mul.h +++ b/caffe2/operators/experimental/c10/schemas/mul.h @@ -1,29 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include 
"caffe2/core/context_base.h" +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct Mul final { - static constexpr const char* name = "mul"; - - using Signature = void( - const at::Tensor& input1, - const at::Tensor& input2, - const at::Tensor& output, - bool legacy_broadcast, - int axis); - - static constexpr size_t num_dispatch_args() {return 2;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 5> parameter_names = { - {"input1", "input2", "output", "legacy_broadcast", "axis"}}; -}; +C10_DECLARE_OP_SCHEMA(Mul); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/relu.cc b/caffe2/operators/experimental/c10/schemas/relu.cc index 1d08b20be5..91b95f1820 100644 --- a/caffe2/operators/experimental/c10/schemas/relu.cc +++ b/caffe2/operators/experimental/c10/schemas/relu.cc @@ -4,10 +4,24 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::Relu); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(Relu, FunctionSchema( + "_c10_experimental::Relu", + (std::vector<c10::Argument>{ + c10::Argument("input"), + c10::Argument("output") + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH( ops::Relu, - C10Relu_DontUseThisOpYet) + C10Relu_DontUseThisOpYet, + 1 + ) } diff --git a/caffe2/operators/experimental/c10/schemas/relu.h b/caffe2/operators/experimental/c10/schemas/relu.h index bf1b8fd03c..ea0aa89367 100644 --- a/caffe2/operators/experimental/c10/schemas/relu.h +++ b/caffe2/operators/experimental/c10/schemas/relu.h @@ -1,24 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct Relu final { - static constexpr const char* name = "relu"; - - using Signature = - 
void(const at::Tensor& input, const at::Tensor& output); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 2> parameter_names = { - {"input", "output"}}; -}; +C10_DECLARE_OP_SCHEMA(Relu); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid.cc b/caffe2/operators/experimental/c10/schemas/sigmoid.cc index fbcfe1e449..ebc6fabfe5 100644 --- a/caffe2/operators/experimental/c10/schemas/sigmoid.cc +++ b/caffe2/operators/experimental/c10/schemas/sigmoid.cc @@ -4,10 +4,23 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::Sigmoid); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(Sigmoid, FunctionSchema( + "_c10_experimental::Sigmoid", + (std::vector<c10::Argument>{ + c10::Argument("input"), + c10::Argument("output") + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH( ops::Sigmoid, - C10Sigmoid_DontUseThisOpYet) + C10Sigmoid_DontUseThisOpYet, + 1) } diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid.h b/caffe2/operators/experimental/c10/schemas/sigmoid.h index 326dc078f4..5d5ff41a59 100644 --- a/caffe2/operators/experimental/c10/schemas/sigmoid.h +++ b/caffe2/operators/experimental/c10/schemas/sigmoid.h @@ -1,24 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct Sigmoid final { - static constexpr const char* name = "sigmoid"; - - using Signature = - void(const at::Tensor& input, const at::Tensor& output); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 2> parameter_names = { - {"input", "output"}}; -}; 
+C10_DECLARE_OP_SCHEMA(Sigmoid); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc index 0e34417200..f7aa2acddb 100644 --- a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc +++ b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc @@ -4,7 +4,22 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::SigmoidCrossEntropyWithLogits); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(SigmoidCrossEntropyWithLogits, FunctionSchema( + "_c10_experimental::SigmoidCrossEntropyWithLogits", + (std::vector<c10::Argument>{ + c10::Argument("input1"), + c10::Argument("input2"), + c10::Argument("output"), + c10::Argument("log_D_trick", BoolType::get()), + c10::Argument("unjoined_lr_loss", BoolType::get()) + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace { struct LogDTrickParameter final { @@ -31,6 +46,7 @@ namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::SigmoidCrossEntropyWithLogits, C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet, + 1, ParameterHelper<LogDTrickParameter>, ParameterHelper<UnjoinedLRLossParameter>) } diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h index 7fb7a88ece..671c2e2523 100644 --- a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h +++ b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h @@ -1,28 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> +#include <ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct SigmoidCrossEntropyWithLogits final { - static constexpr 
const char* name = "sigmoid_cross_entropy_with_logits"; - - using Signature = void( - const at::Tensor& input1, - const at::Tensor& input2, - const at::Tensor& output, - bool log_D_trick, - bool unjoined_lr_loss); - - static constexpr size_t num_dispatch_args() {return 2;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 5> parameter_names = { - {"input1", "input2", "output", "log_d_trick", "unjoined_lr_loss"}}; -}; +C10_DECLARE_OP_SCHEMA(SigmoidCrossEntropyWithLogits); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc index a2e70180c5..1dd558b4ee 100644 --- a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc +++ b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc @@ -4,10 +4,25 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::SparseLengthsSum); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(SparseLengthsSum, FunctionSchema( + "_c10_experimental::SparseLengthsSum", + (std::vector<c10::Argument>{ + c10::Argument("data"), + c10::Argument("indices"), + c10::Argument("lengths"), + c10::Argument("output") + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH( ops::SparseLengthsSum, - C10SparseLengthsSum_DontUseThisOpYet) + C10SparseLengthsSum_DontUseThisOpYet, + 1) } diff --git a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h index 16d23e733e..a4054e1aae 100644 --- a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h +++ b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h @@ -1,27 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> +#include 
<ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct SparseLengthsSum final { - static constexpr const char* name = "sparse_lengths_sum"; - - using Signature = void( - const at::Tensor& data, - const at::Tensor& indices, - const at::Tensor& lengths, - const at::Tensor& output); - - static constexpr size_t num_dispatch_args() {return 3;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 4> parameter_names = { - {"data", "indices", "lengths", "output"}}; -}; +C10_DECLARE_OP_SCHEMA(SparseLengthsSum); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/experimental/c10/schemas/stop_gradient.cc b/caffe2/operators/experimental/c10/schemas/stop_gradient.cc index 630e74056b..4c26785e29 100644 --- a/caffe2/operators/experimental/c10/schemas/stop_gradient.cc +++ b/caffe2/operators/experimental/c10/schemas/stop_gradient.cc @@ -4,10 +4,24 @@ using caffe2::CPUContext; -C10_DEFINE_OP_SCHEMA(caffe2::ops::StopGradient); +namespace caffe2 { +namespace ops { +// TODO Parse schema string instead of creating FunctionSchema manually +C10_DEFINE_OP_SCHEMA(StopGradient, FunctionSchema( + "_c10_experimental::StopGradient", + (std::vector<c10::Argument>{ + c10::Argument("input"), + c10::Argument("output") + }), (std::vector<c10::Argument>{ + }) +)); +} +} namespace caffe2 { REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH( ops::StopGradient, - C10StopGradient_DontUseThisOpYet) + C10StopGradient_DontUseThisOpYet, + 1 + ) } diff --git a/caffe2/operators/experimental/c10/schemas/stop_gradient.h b/caffe2/operators/experimental/c10/schemas/stop_gradient.h index f89f942eb7..bb130e2c8c 100644 --- a/caffe2/operators/experimental/c10/schemas/stop_gradient.h +++ b/caffe2/operators/experimental/c10/schemas/stop_gradient.h @@ -1,26 +1,11 @@ #pragma once -#include <ATen/core/Tensor.h> -#include <c10/util/Array.h> -#include "caffe2/core/context_base.h" +#include 
<ATen/core/dispatch/OpSchemaRegistration.h> namespace caffe2 { namespace ops { -struct StopGradient final { - static constexpr const char* name = "stop_gradient"; - - using Signature = void( - const at::Tensor& input, - const at::Tensor& output); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 1;} - - static constexpr c10::guts::array<const char*, 2> parameter_names = { - {"input", "output"}}; -}; +C10_DECLARE_OP_SCHEMA(StopGradient); } // namespace ops } // namespace caffe2 diff --git a/caffe2/operators/layer_norm_op.cc b/caffe2/operators/layer_norm_op.cc index 1b096f4dc8..f452f4bf3c 100644 --- a/caffe2/operators/layer_norm_op.cc +++ b/caffe2/operators/layer_norm_op.cc @@ -186,15 +186,15 @@ to the end.) // Register layer norm with c10 namespace { -struct State final : public c10::KernelState { +struct Cache final : public c10::KernelCache { at::optional<at::Tensor> scale = at::nullopt; at::optional<at::Tensor> bias = at::nullopt; }; template <class DataType> -void layer_norm_c10(c10::Stack* stack, c10::KernelState* state) { // TODO Pass in correct state type +void layer_norm_c10(c10::Stack* stack, c10::KernelCache* cache_) { // TODO Pass in correct cache type c10::ArrayRef<c10::IValue> inputs = torch::jit::peekSlice(*stack, 0, 3, 6); - c10::ArrayRef<c10::IValue> outputs = torch::jit::peekSlice(*stack, 0, 3, 3); + c10::ArrayRef<c10::IValue> outputs = torch::jit::peekSlice(*stack, 3, 3, 6); caffe2::Tensor X{c10::C10Tensor(inputs[0].toTensor())}; int64_t axis = inputs[1].toInt(); @@ -204,7 +204,7 @@ void layer_norm_c10(c10::Stack* stack, c10::KernelState* state) { // TODO Pass i caffe2::Tensor sig{c10::C10Tensor(outputs[2].toTensor())}; caffe2::CPUContext context; - State* cache = static_cast<State*>(state); + Cache* cache = static_cast<Cache*>(cache_); if (!cache->scale.has_value()) { cache->scale = at::Tensor(caffe2::empty({0}, at::dtype<float>())); } @@ -224,19 +224,19 @@ void layer_norm_c10(c10::Stack* 
stack, c10::KernelState* state) { // TODO Pass i X, &Y, &mean, &sig, canonical_axis, epsilon, &scale, &bias, static_cast<caffe2::CPUContext*>(&context) ); - torch::jit::peek(*stack, 0, 3) = at::Tensor(c10::C10Tensor(std::move(Y))); - torch::jit::peek(*stack, 1, 3) = at::Tensor(c10::C10Tensor(std::move(mean))); - torch::jit::peek(*stack, 2, 3) = at::Tensor(c10::C10Tensor(std::move(sig))); + torch::jit::drop(*stack, 6); + torch::jit::push(*stack, + at::Tensor(c10::C10Tensor(std::move(Y))), + at::Tensor(c10::C10Tensor(std::move(mean))), + at::Tensor(c10::C10Tensor(std::move(sig))) + ); return; } } namespace c10 { C10_REGISTER_KERNEL(c10::core::opschema::LayerNorm) - .withState<State>() + .withCache<Cache>() .kernel<&layer_norm_c10<float>>() - .dispatchKey(c10::DispatchKey<1>{ - c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, - LayoutId(0), - caffe2::TypeMeta::Id<float>()}}); + .dispatchKey(CPUTensorId()); } // namespace c10 diff --git a/torch/csrc/jit/c10_ops/layer_norm.cpp b/torch/csrc/jit/c10_ops/layer_norm.cpp index 7375645fac..02fdf89af4 100644 --- a/torch/csrc/jit/c10_ops/layer_norm.cpp +++ b/torch/csrc/jit/c10_ops/layer_norm.cpp @@ -27,27 +27,34 @@ namespace jit { namespace { RegisterOperators reg({ Operator( - "caffe2::layer_norm_dont_use_this_op_yet(Tensor input, int axis, float epsilon) -> (Tensor, Tensor, Tensor)", + //Note: This schema is: caffe2::layer_norm_dont_use_this_op_yet(Tensor input, int axis, float epsilon, Tensor? output = None, Tensor? output_mean = None, Tensor? 
output_stdev = None) -> (Tensor, Tensor, Tensor) + c10::core::opschema::LayerNorm().schema(), [](Stack& stack) { - Tensor tensor_input = std::move(stack[stack.size()-3]).toTensor(); + Tensor tensor_input = std::move(stack[stack.size()-6]).toTensor(); if (tensor_input.requires_grad()) { throw std::runtime_error("Autograd not yet supported for c10 ops."); } auto device = tensor_input.device(); - torch::jit::peek(stack, 0, 3) = torch::autograd::Variable(std::move(tensor_input)).data(); - // push output fields as outputs to stack - push(stack, at::empty({0}, device), at::empty({0}, device), at::empty({0}, device)); + // unwrap inputs from variable + torch::jit::peek(stack, 0, 6) = torch::autograd::Variable(std::move(tensor_input)).data(); - c10::Dispatcher<c10::core::opschema::LayerNorm>::lookup(&stack).call(&stack); + // allocate the output tensors that aren't set yet + for (int i = 3; i < 6; ++i) { + // TODO this should just check for isNone, not for undefined tensor. @wanchaol is working on this. + if (torch::jit::peek(stack, i, 6).isNone() || !torch::jit::peek(stack, i, 6).toTensor().defined()) { + torch::jit::peek(stack, i, 6) = at::empty({0}, device); + } + } + + // call caffe2 kernel + c10::Dispatcher::singleton().lookup(c10::core::opschema::LayerNorm(), &stack).call(&stack); - // move outputs down the stack to where the inputs were before + // wrap outputs into Variable for (int i = 0; i < 3; ++i) { - torch::jit::peek(stack, i, 6) = torch::autograd::make_variable(std::move(torch::jit::peek(stack, i, 3)).toTensor(), false); + torch::jit::peek(stack, i, 3) = torch::autograd::make_variable(std::move(torch::jit::peek(stack, i, 3)).toTensor(), false); } - drop(stack, 3); // drop inputs - return 0; }) }); |