diff options
author | Sebastian Messmer <messmer@fb.com> | 2019-01-17 15:47:16 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-17 15:56:52 -0800 |
commit | 3e85a2bcbfb1c0275ed61122e6a70a50ba2deece (patch) | |
tree | 5412d1025171a3c496505a619e7708c0c78fa577 /c10 | |
parent | a9438ba62f02af5032171c3773bfc54c348de298 (diff) | |
download | pytorch-3e85a2bcbfb1c0275ed61122e6a70a50ba2deece.tar.gz pytorch-3e85a2bcbfb1c0275ed61122e6a70a50ba2deece.tar.bz2 pytorch-3e85a2bcbfb1c0275ed61122e6a70a50ba2deece.zip |
Move c10 dispatcher back to ATen/core (#16050)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16050
The c10 dispatcher will (soon) depend on IValue and IValue can't be moved to c10 yet because it depends on at::Tensor, which depends on legacy Type dispatch and we don't want the legacy dispatch in c10.
So instead, we move the c10 dispatcher back to ATen/core until we can actually move at::Tensor to c10.
Reviewed By: ezyang
Differential Revision: D13684517
fbshipit-source-id: 1125f4254223907c52f96ff73034f6d4ae9fd0a7
Diffstat (limited to 'c10')
-rw-r--r-- | c10/core/dispatch/DeviceId.cpp | 1 | ||||
-rw-r--r-- | c10/core/dispatch/DeviceId.h | 36 | ||||
-rw-r--r-- | c10/core/dispatch/DispatchKey.cpp | 1 | ||||
-rw-r--r-- | c10/core/dispatch/DispatchKey.h | 97 | ||||
-rw-r--r-- | c10/core/dispatch/DispatchTable.cpp | 1 | ||||
-rw-r--r-- | c10/core/dispatch/DispatchTable.h | 154 | ||||
-rw-r--r-- | c10/core/dispatch/Dispatcher.cpp | 1 | ||||
-rw-r--r-- | c10/core/dispatch/Dispatcher.h | 60 | ||||
-rw-r--r-- | c10/core/dispatch/KernelRegistration.cpp | 1 | ||||
-rw-r--r-- | c10/core/dispatch/KernelRegistration.h | 141 | ||||
-rw-r--r-- | c10/core/dispatch/LayoutId.cpp | 1 | ||||
-rw-r--r-- | c10/core/dispatch/LayoutId.h | 22 | ||||
-rw-r--r-- | c10/core/dispatch/OpSchema.cpp | 1 | ||||
-rw-r--r-- | c10/core/dispatch/OpSchema.h | 263 | ||||
-rw-r--r-- | c10/core/dispatch/OpSchemaRegistration.cpp | 1 | ||||
-rw-r--r-- | c10/core/dispatch/OpSchemaRegistration.h | 18 | ||||
-rw-r--r-- | c10/core/opschema/layer_norm.cpp | 4 | ||||
-rw-r--r-- | c10/core/opschema/layer_norm.h | 40 | ||||
-rw-r--r-- | c10/test/core/dispatch/OpSchema_test.cpp | 27 |
19 files changed, 0 insertions, 870 deletions
diff --git a/c10/core/dispatch/DeviceId.cpp b/c10/core/dispatch/DeviceId.cpp deleted file mode 100644 index 60feda2fef..0000000000 --- a/c10/core/dispatch/DeviceId.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <c10/core/dispatch/DeviceId.h> diff --git a/c10/core/dispatch/DeviceId.h b/c10/core/dispatch/DeviceId.h deleted file mode 100644 index cdcaf4b635..0000000000 --- a/c10/core/dispatch/DeviceId.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include <c10/util/C++17.h> -#include <functional> -#include <iostream> - -namespace c10 { - -enum class DeviceTypeId : uint8_t { - // Don't use the int values here in the enum (i.e. don't do static_cast to or from int). - // Instead, if you want to serialize this, write a function with switch/case. - CPU = 0, - CUDA = 1, - UNDEFINED -}; - -inline std::ostream& operator<<(std::ostream& stream, DeviceTypeId device_type_id) { - switch(device_type_id) { - case c10::DeviceTypeId::CPU: return stream << "DeviceTypeId(CPU)"; - case c10::DeviceTypeId::CUDA: return stream << "DeviceTypeId(CUDA)"; - case c10::DeviceTypeId::UNDEFINED: return stream << "DeviceTypeId(UNDEFINED)"; - } - throw std::logic_error("Unknown DeviceTypeId: " + c10::guts::to_string(static_cast<int>(device_type_id))); -} - -} - -namespace std { - -template <> struct hash<c10::DeviceTypeId> { - size_t operator()(c10::DeviceTypeId v) const { - return std::hash<uint8_t>()(static_cast<uint8_t>(v)); - } -}; - -} diff --git a/c10/core/dispatch/DispatchKey.cpp b/c10/core/dispatch/DispatchKey.cpp deleted file mode 100644 index 1d736d865f..0000000000 --- a/c10/core/dispatch/DispatchKey.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <c10/core/dispatch/DispatchKey.h> diff --git a/c10/core/dispatch/DispatchKey.h b/c10/core/dispatch/DispatchKey.h deleted file mode 100644 index bb691bfd01..0000000000 --- a/c10/core/dispatch/DispatchKey.h +++ /dev/null @@ -1,97 +0,0 @@ -#pragma once - -#include <c10/core/dispatch/DeviceId.h> -#include <c10/core/dispatch/LayoutId.h> -#include <c10/util/typeid.h> - -#include <vector> -#include <functional> -#include <sstream> -#include <c10/util/Array.h> - -namespace c10 { - -namespace details { - -struct TensorParameterDispatchKey final { - // note: This dispatch key structure is not final yet and will change. Don't rely on it. - DeviceTypeId deviceTypeId; - LayoutId layoutId; - // TODO Move this TypeIdentifier to c10 namespace - caffe2::TypeIdentifier dataType; -}; -inline constexpr bool operator==(const TensorParameterDispatchKey& lhs, const TensorParameterDispatchKey& rhs) { - return lhs.deviceTypeId == rhs.deviceTypeId && lhs.layoutId == rhs.layoutId && lhs.dataType == rhs.dataType; -} - -inline std::ostream& operator<<(std::ostream& stream, const TensorParameterDispatchKey& key) { - return stream << "TensorKey(" << key.deviceTypeId << ", " << key.layoutId.value() << ", " << key.dataType << ")"; -} - -} // namespace details -} // namespace c10 - -namespace std { - template<> - struct hash<c10::details::TensorParameterDispatchKey> { - // TODO constexpr hashing - size_t operator()(const c10::details::TensorParameterDispatchKey& obj) const { - return std::hash<c10::DeviceTypeId>()(obj.deviceTypeId) ^ std::hash<c10::LayoutId>()(obj.layoutId) ^ std::hash<caffe2::TypeIdentifier>()(obj.dataType); - } - }; -} // namespace std - -namespace c10 { -/** - * The dispatch key encodes the runtime type identity of a function call arguments, - * specifying what aspects of this identity can be dynamically dispatched on. - * - * Intuitively, given a function signature like f(Tensor, int), a valid dispatch - * key for the arguments might be [CPUFloatTensor] (notice that 'f' is NOT included - * in the dispatch key, and the runtime type of 'int' is NOT considered for dispatch - * (since it is trivial). - * - * Dispatch keys permit equality tests and are hashable. - * - * @tparam num_dispatch_args The number of dispatchable arguments - */ -template<size_t num_dispatch_args> -struct DispatchKey final { - guts::array<details::TensorParameterDispatchKey, num_dispatch_args> argTypes; -}; - -template<size_t num_dispatch_args> -inline constexpr bool operator==(const DispatchKey<num_dispatch_args> &lhs, const DispatchKey<num_dispatch_args>& rhs) { - // TODO: Use AVX instructions to perform this equality test more quickly - return lhs.argTypes == rhs.argTypes; -} - -template<size_t num_dispatch_args> -inline std::ostream& operator<<(std::ostream& stream, const DispatchKey<num_dispatch_args>& key) { - stream << "DispatchKey("; - if (num_dispatch_args > 0) { - stream << "DispatchKey(" << key.argTypes[0]; - for (size_t i = 1; i < num_dispatch_args; ++i) { - stream << ", " << key.argTypes[i]; - } - stream << ")"; - } - return stream << ")"; -} - -} // namespace c10 - -namespace std { - template<size_t num_dispatch_args> - struct hash<c10::DispatchKey<num_dispatch_args>> { - // TODO constexpr hashing - size_t operator()(const c10::DispatchKey<num_dispatch_args>& obj) const { - size_t hash_value = 0; - for (const auto& argType : obj.argTypes) { - hash_value *= 10883; // prime - hash_value += std::hash<c10::details::TensorParameterDispatchKey>()(argType); - } - return hash_value; - } - }; -} // namespace std diff --git a/c10/core/dispatch/DispatchTable.cpp b/c10/core/dispatch/DispatchTable.cpp deleted file mode 100644 index fc3a86ed1b..0000000000 --- a/c10/core/dispatch/DispatchTable.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <c10/core/dispatch/DispatchTable.h> diff --git a/c10/core/dispatch/DispatchTable.h b/c10/core/dispatch/DispatchTable.h deleted file mode 100644 index 936d780b25..0000000000 --- a/c10/core/dispatch/DispatchTable.h +++ /dev/null @@ -1,154 +0,0 @@ -#pragma once - -#include <c10/core/dispatch/OpSchema.h> -#include <c10/util/LeftRight.h> -#include <c10/util/Metaprogramming.h> -#include <c10/util/flat_hash_map.h> - -#include <array> -#include <atomic> -#include <iostream> -#include <mutex> -#include <type_traits> -#include <unordered_map> - -namespace c10 { - -namespace details { -/// Kernel implementations in a thread-safe hash table. -template <class Key> -class ThreadsafeOperatorTable_ final { - public: - template <class Key_> - void emplace(Key_&& key, void* value) { - bool res = map_.write([&](ska::flat_hash_map<Key, void*>& map) -> bool { - auto result = map.emplace(std::forward<Key>(key), value); - return result.second; - }); - if (!res) { - std::ostringstream msg; - msg << "Tried to register conflicting kernels to the dispatcher: " << key; - throw std::logic_error(msg.str()); - } - } - - void erase(const Key& key) { - auto num_removed = - map_.write([&](ska::flat_hash_map<Key, void*>& map) -> size_t { - return map.erase(key); - }); - assert(num_removed <= 1); // This is not a multi-map - if (num_removed == 0) { - throw std::logic_error( - "Tried to deregister a kernel that isn't registered."); - } - } - - void* lookup(const Key& key) const { - return map_.read([&](const ska::flat_hash_map<Key, void*>& map) -> void* { - auto found = map.find(key); - if (found != map.end()) { - return found->second; - } else { - return nullptr; - } - }); - } - - private: - LeftRight<ska::flat_hash_map<Key, void*>> map_; -}; -} // namespace details - -/** - * Per-operator dispatch table. - * - * Given an operator specified by 'OpSchemaDef', this class records a dispatch - * table for various kernels provided for this operator. For example, if we - * consider the operator add(Tensor, Tensor), the dispatch table for this - * operator may contain implementations for various dynamic tensor types, such - * as (CPUFloatTensor, CPUFloatTensor), (CUDAFloatTensor, CUDAFloatTensor), etc. - * - * @tparam OpSchemaDef The operator signature this dispatch table encodes. - */ -// TODO: Support dispatch for meta-operators (which apply to all dynamic types) -template <class OpSchemaDef> -class DispatchTable final { - private: - using Schema = OpSchema<OpSchemaDef>; - - public: - DispatchTable() : kernels_() {} - - /** - * Register a kernel in the table at some dispatch key. - * @param func Concrete kernel function implementation to register - * @param dispatch_key Dispatch key to define when this kernel is selected - */ - void registerKernel( - typename Schema::signature::func_type* func, - typename Schema::dispatch::dispatch_key_type dispatch_key) { - kernels_.emplace(std::move(dispatch_key), reinterpret_cast<void*>(func)); - } - - /** - * Deregister the kernel for some dispatch key. - * - * @param dispatch_key Dispatch key to unregister. - */ - // TODO: This isn't going to work so well when we get more complicated - // override patterns! In this case, an operator will show up in multiple - // slots, and erasing them one-by-one is probably not such a good idea. - void deregisterKernel( - const typename Schema::dispatch::dispatch_key_type& dispatch_key) { - kernels_.erase(dispatch_key); - } - - /** - * Perform a dynamic dispatch on this table. - * - * @tparam Args Perfect forwarding template arguments to the dispatch - * @param args Arguments to invoke the function with - * @return Returned value of the operator - */ - template <class... Args> - typename Schema::signature::return_type call(Args&&... args) const { - // TODO Better error message, but need to take care that reference arguments - // match non-reference arguments and so on. - // static_assert(std::is_same<typename Schema::return_type (Args...), - // typename Schema::func_type>::value, "Argument types don't match - // operator signature"); - auto kernel_func = lookupKernelFunc_(args...); - return kernel_func(std::forward<Args>(args)...); - } - - private: - template <class... Args> - typename Schema::signature::func_type* lookupKernelFunc_( - const Args&... args) const { - auto dispatch_key = Schema::dispatch::dispatch_key(args...); - void* found = kernels_.lookup(dispatch_key); - if (found == nullptr) { - // TODO Better error message - include op name and dispatch key (i.e. - // argument types) - throw std::logic_error( - std::string() + "Didn't find kernel to dispatch to for operator '" + - Schema::metadata::name() + "'"); - } - return reinterpret_cast<typename Schema::signature::func_type*>(found); - } - - details::ThreadsafeOperatorTable_< - typename Schema::dispatch::dispatch_key_type> - kernels_; -}; - -} // namespace c10 - -/* - * Use this to access the dispatch table singleton for a given op schema. - * It has an implementation for each op schema def in a cpp file, because - * we can't rely on the one-definition-rule. - */ -template <class OpSchemaDef> -C10_API c10::DispatchTable<OpSchemaDef>& c10_dispatch_table(); diff --git a/c10/core/dispatch/Dispatcher.cpp b/c10/core/dispatch/Dispatcher.cpp deleted file mode 100644 index 81fabbce4a..0000000000 --- a/c10/core/dispatch/Dispatcher.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <c10/core/dispatch/Dispatcher.h> diff --git a/c10/core/dispatch/Dispatcher.h b/c10/core/dispatch/Dispatcher.h deleted file mode 100644 index c57d94340b..0000000000 --- a/c10/core/dispatch/Dispatcher.h +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once - -#include <c10/core/dispatch/DispatchTable.h> - -namespace c10 { - -/** - * Top-level dispatch interface for dispatching via the dynamic dispatcher. - */ -template<class OpSchemaDef> -class Dispatcher final { -public: - // Implementation note: this class abstracts over the fact that we have per-operator - // dispatch tables. This could be easily adjusted to have a single global hash - // table. - - /** - * Register an operator to the dispatch table for some operator schema. - * - * @tparam OpSchemaDef Operator schema to register this operator to (mandatory) - * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::registerOp (inferred) - * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::registerOp - * @return void - */ - template<class... Args> - static void registerKernel(Args&&... args) { - auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>(); - return dispatch_table_for_this_op.registerKernel(std::forward<Args>(args)...); - } - - /** - * Remove an operator from the dispatch table for some operator schema. - * - * @tparam OpSchemaDef Operator schema to deregister from (mandatory) - * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::deregisterOp (inferred) - * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::deregisterOp - * @return void - */ - template<class... Args> - static void deregisterKernel(Args&&... args) { - auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>(); - return dispatch_table_for_this_op.deregisterKernel(std::forward<Args>(args)...); - } - - /** - * Perform a dynamic dispatch to some operator - * - * @tparam OpSchemaDef Operator schema to dispatch with (mandatory) - * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::call (inferred) - * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::call - * @return Return type of this operator - */ - template<class... Args> - static typename OpSchema<OpSchemaDef>::signature::return_type call(Args&&... args) { - auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>(); - return dispatch_table_for_this_op.call(std::forward<Args>(args)...); - } -}; - -} // namespace c10 diff --git a/c10/core/dispatch/KernelRegistration.cpp b/c10/core/dispatch/KernelRegistration.cpp deleted file mode 100644 index a5a8a30b5e..0000000000 --- a/c10/core/dispatch/KernelRegistration.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <c10/core/dispatch/KernelRegistration.h> diff --git a/c10/core/dispatch/KernelRegistration.h b/c10/core/dispatch/KernelRegistration.h deleted file mode 100644 index 5b3c9e76b0..0000000000 --- a/c10/core/dispatch/KernelRegistration.h +++ /dev/null @@ -1,141 +0,0 @@ -#pragma once - -#include <c10/util/Optional.h> -#include <c10/core/dispatch/Dispatcher.h> -#include <c10/core/dispatch/OpSchema.h> - -/** - * To register your own kernel for an operator, do in one (!) cpp file: - * C10_REGISTER_KERNEL(OpSchemaDef) - * .kernel(&kernel_func) - * .dispatchKey(dispatch_key); - */ - -namespace c10 { - -// TODO Test different order for builder -// TODO Test no dispatch key defined - -/** - * Class which, on construction, registers an operator in the dispatch table. The intent is that - * this class is constructed at static initialization time so that operators automatically get - * registered when a dlopen() occurs. - * - * You shouldn't call this directly; instead, use the KernelRegistrationBuilder - * - * @tparam OpSchemaDef - */ -template<class OpSchemaDef> -class KernelRegistrar final { -private: - using Schema = OpSchema<OpSchemaDef>; -public: - /** - * @param kernel The concrete function implementation to register - * @param dispatch_key The dispatch key to register the function to - */ - KernelRegistrar(typename Schema::signature::func_type* kernel, typename Schema::dispatch::dispatch_key_type dispatch_key) - : dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { - Dispatcher<OpSchemaDef>::registerKernel(kernel, dispatch_key_); - } - - KernelRegistrar(KernelRegistrar&& rhs) - : dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(true) { - rhs.owns_registration_ = false; - } - - // not needed for now - KernelRegistrar& operator=(KernelRegistrar&& rhs) = delete; - - ~KernelRegistrar() { - if (owns_registration_) { - Dispatcher<OpSchemaDef>::deregisterKernel(dispatch_key_); - } - } - -private: - const typename Schema::dispatch::dispatch_key_type dispatch_key_; - bool owns_registration_; - - C10_DISABLE_COPY_AND_ASSIGN(KernelRegistrar); -}; - -/** - * Helper class for building a KernelRegistrar. This permits "keyword-argument" like syntax - * when performing operator registration, e.g., as in: - * - * C10_REGISTER_KERNEL(::ops::add_notensor) - * .kernel(&add_notensor_op) - * .dispatchKey("bla"); - * - * Expanded, this macro invocation looks like: - * - * static KernelRegistrar<::ops::add_notensor> _anon0 = - * KernelRegistrationBuilder<::ops::add_notensor, false, false>() - * .kernel(&add_notensor_op) - * .dispatchKey("bla"); - * - * The resulting full expression is implicitly convertible to a KernelRegistrar. - * - * @tparam OpSchemaDef The operator schema this is building a KernelRegistration for - * @tparam hasKernel Boolean for compile-time checking that a kernel is specified before finalizing the builder - * @tparam hasDispatchKey Boolean for compile-time checking thhat a dispatch key is specified before finalizing the builder - */ -template<class OpSchemaDef, uint64_t FieldsPresentFlags> -class KernelRegistrationBuilder final { -private: - using Schema = OpSchema<OpSchemaDef>; - - static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0; - static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1; - - c10::optional<typename Schema::signature::func_type*> kernel_; - c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_; - - public: - constexpr KernelRegistrationBuilder() - : KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {} - - constexpr KernelRegistrationBuilder( - c10::optional<typename Schema::signature::func_type*> kernel, - c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key) - : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {} - - /** - * Implicit coercion to KernelRegistrar<OpSchemaDef> that finalizes the builder and - * creates the object. - * @return Produced KernelRegistrar - */ - constexpr operator KernelRegistrar<OpSchemaDef>() && { - static_assert(FieldsPresentFlags & KERNEL_PRESENT, "Forgot to call .kernel() in kernel registration"); - static_assert(FieldsPresentFlags & DISPATCH_KEY_PRESENT, "Forgot to call .dispatchKey() in kernel registration"); - return KernelRegistrar<OpSchemaDef>(std::move(*kernel_), std::move(*dispatch_key_)); - } - - /** - * Specify the concrete function implementation for this dispatch registration - * @param kernel concrete function implementation to be registered - * @return "this" for method chaining - */ - constexpr KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(typename Schema::signature::func_type* kernel_func) && { - static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration"); - return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(*kernel_func, std::move(dispatch_key_)); - } - - /** - * Specify the dispatch key for this dispatch registration - * @param dispatch_key dispatch key to register the function to - * @return "this" for method chaining - */ - constexpr KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(typename Schema::dispatch::dispatch_key_type dispatch_key) && { - static_assert(!(FieldsPresentFlags & DISPATCH_KEY_PRESENT), "Tried to define kernel twice in same op registration"); - return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(kernel_), std::move(dispatch_key)); - } -}; - -} // namespace c10 - -// TODO Can the builder logic be moved to compile time? -// NB: Semicolon after applying this macro is MANDATORY -#define C10_REGISTER_KERNEL(OpSchemaDef) \ - static KernelRegistrar<OpSchemaDef> MACRO_CONCAT(__kernelRegistrationBuilder_, __COUNTER__) = KernelRegistrationBuilder<OpSchemaDef, 0>() diff --git a/c10/core/dispatch/LayoutId.cpp b/c10/core/dispatch/LayoutId.cpp deleted file mode 100644 index 15396ee955..0000000000 --- a/c10/core/dispatch/LayoutId.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <c10/core/dispatch/LayoutId.h> diff --git a/c10/core/dispatch/LayoutId.h b/c10/core/dispatch/LayoutId.h deleted file mode 100644 index d0648392f4..0000000000 --- a/c10/core/dispatch/LayoutId.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include <c10/util/IdWrapper.h> - -namespace c10 { - -class LayoutId final : public at::IdWrapper<LayoutId, uint8_t> { -public: - constexpr explicit LayoutId(underlying_type id): IdWrapper(id) {} - - constexpr uint8_t value() const { - return underlyingId(); - } - - // Don't use this default constructor! - // Unfortunately, a default constructor needs to be defined because of https://reviews.llvm.org/D41223 - constexpr LayoutId(): IdWrapper(0) {} -}; - -} - -C10_DEFINE_HASH_FOR_IDWRAPPER(c10::LayoutId) diff --git a/c10/core/dispatch/OpSchema.cpp b/c10/core/dispatch/OpSchema.cpp deleted file mode 100644 index c5feae4914..0000000000 --- a/c10/core/dispatch/OpSchema.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <c10/core/dispatch/OpSchema.h> diff --git a/c10/core/dispatch/OpSchema.h b/c10/core/dispatch/OpSchema.h deleted file mode 100644 index 73ec4fe537..0000000000 --- a/c10/core/dispatch/OpSchema.h +++ /dev/null @@ -1,263 +0,0 @@ -#pragma once - -#include <c10/core/dispatch/DispatchKey.h> -#include <c10/util/Array.h> -#include <c10/util/Metaprogramming.h> -#include <c10/core/DeviceType.h> -#include <c10/core/Tensor.h> - -namespace c10 { - -namespace details { - -/** - * If Arg is a Tensor or reference to a Tensor, provide the member constant value equal to true. Otherwise - * return false. - */ -template <class Arg> -using is_tensor_arg = std:: - is_same<C10Tensor, guts::remove_cv_t<guts::remove_reference_t<Arg>>>; - -inline DeviceTypeId to_device_type_id(DeviceType device_type) { - switch (device_type) { - case DeviceType::CPU: - return DeviceTypeId::CPU; - case DeviceType::CUDA: - return DeviceTypeId::CUDA; - default: - return DeviceTypeId::UNDEFINED; - } -} - -inline TensorParameterDispatchKey tensor_to_dispatch_key(const C10Tensor& tensor) { - return TensorParameterDispatchKey{ - to_device_type_id(tensor.impl()->device_type()), - LayoutId(0), - tensor.impl()->dtype().id()}; -} - -// Extract type ids for all tensors from an array of tensors -template<size_t num_dispatch_args, size_t num_tensor_args, size_t... indices> -guts::array<TensorParameterDispatchKey, num_dispatch_args> getDispatchTypeIds__(const guts::array<const C10Tensor*, num_tensor_args>& tensor_args, guts::index_sequence<indices...>) { - return {tensor_to_dispatch_key(*tensor_args[indices])...}; -} - -/** - * Extract the type ids of all tensors in a variadic list of arguments - * - * @tparam Args Inferred variadic list of argument types - * @param args List of arguments to get type ids from - * @return guts::array<TensorParameterDispatchKey, n>, where n is the number of tensor arguments (is_tensor_arg) in the class - */ -template<size_t num_dispatch_args, class... Args> -guts::array<TensorParameterDispatchKey, num_dispatch_args> getDispatchTypeIds_(const Args&... args) { - auto tensor_args = guts::filter_map<const C10Tensor*, is_tensor_arg>([] (const C10Tensor& v){return &v;}, args...); - return getDispatchTypeIds__<num_dispatch_args>(tensor_args, guts::make_index_sequence<num_dispatch_args>()); -} - -// TODO Test getDispatchTypeIds_ - -/** - * If T is a struct with a type field Signature, provides the member constant - * @tparam T - */ -template<class T, typename = void> -struct has_signature_defined : std::false_type {}; -template<class T> -struct has_signature_defined<T, guts::void_t< - typename T::Signature ->> : std::true_type {}; - -// TODO Test has_signature_defined - -template<class T, typename = void> -struct has_parameter_names_defined : std::false_type {}; -template<class T> -struct has_parameter_names_defined<T, guts::void_t< - decltype(T::parameter_names) ->> : std::true_type {}; - -// TODO Test has_parameter_names_defined - -template<class T, typename = void> -struct has_name_defined : std::false_type {}; -template<class T> -struct has_name_defined<T, guts::void_t< - decltype(T::name) ->> : std::true_type {}; - -// TODO Test has_name_defined - -/** - * Wrapper class around a user-provided schema definition some useful information about the schema. - * - * @tparam OpSchemaDef Operator schema definition. See OpSchema for more details. - */ -template<class OpSchemaDef> class OpSignatureSchema final { - static_assert(details::has_signature_defined<OpSchemaDef>::value, "Operator schema doesn't define a valid Signature member type."); - static_assert(guts::is_function_type<typename OpSchemaDef::Signature>::value, "Signature member of operator schema must be a function type."); - - using signature_traits = guts::function_traits<typename OpSchemaDef::Signature>; -public: - /** - * The function type OpSchemaDef::Signature - */ - using func_type = typename signature_traits::func_type; - /** - * The return type of the function OpSchemaDef::Signature - */ - using return_type = typename signature_traits::return_type; - /** - * A type list of the parameter types of OpSchemaDef::Signature - */ - using parameter_types = typename signature_traits::parameter_types; - - /** - * The number of arguments of OpSchemaDef::Signature - */ - static constexpr size_t num_args = guts::typelist::size<parameter_types>::value; - /** - * The number of tensor arguments (as per is_tensor_arg) in OpSchemaDef::Signature - */ - static constexpr size_t num_tensor_args = guts::typelist::count_if<details::is_tensor_arg, parameter_types>::value; - - static constexpr size_t num_outputs = OpSchemaDef::num_outputs(); - -private: - static_assert(details::has_parameter_names_defined<OpSchemaDef>::value, "Operator schema doesn't define parameter_names member."); - // TODO Allow simpler definition of parameter_names without having to spell out the guts::array type in the schema def. - static_assert(std::is_same<const guts::array<const char*, num_args>, decltype(OpSchemaDef::parameter_names)>::value, "Operator schema defines parameter_names member, but it isn't the correct type. Must be a static constexpr guts::array of const char* with one entry for each parameter."); - -public: - /** - * The names of the parameters (as per OpSchemaDef::parameter_names) - * @return Array - */ - static constexpr const guts::array<const char*, num_args>& parameter_names() { - return OpSchemaDef::parameter_names; - } -}; - -/** - * If T has a method dispatch_key, provide a member constant value equal to true. Otherwise return false. - * @tparam T - */ -template<class T, typename = void> -struct has_function_dispatch_key_defined : std::false_type {}; -template<class T> -struct has_function_dispatch_key_defined<T, guts::void_t< - decltype(&T::dispatch_key) ->> : std::true_type {}; - -/** - * Wrapper class around a user-defined schema definition providing a way of computing a dispatch key - * from arguments matching the signature of that schema. - * - * @tparam OpSchemaDef Operator schema definition. See OpSchema for more details. - * @tparam Enable Inferred, used to control specialization - */ -template<class OpSchemaDef, class Enable = void> class OpDispatchKeySchema final {}; - -// General case. Operator doesn't overwrite DispatchKey generation. Use default. -template<class OpSchemaDef> -class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<!has_function_dispatch_key_defined<OpSchemaDef>::value>> final { - using signature = OpSignatureSchema<OpSchemaDef>; - - // TODO Static assert that dispatch_key_type has operator<<(ostream, _) defined for debug output. - // TODO Use an ADL-based debugString(DispatchKey) function instead of operator<< for debug printing. - -public: - using dispatch_key_type = DispatchKey<OpSchemaDef::num_dispatch_args()>; - - template<class... Args> - static inline dispatch_key_type dispatch_key(const Args&... args) { - using guts::typelist::map_t; - using guts::typelist::typelist; - static_assert(std::is_same< - map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typelist<Args...>>>, - map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>> - >::value, "Invalid argument types passed to OpSchema::dispatch_key()"); - return dispatch_key_type { - details::getDispatchTypeIds_<OpSchemaDef::num_dispatch_args()>(args...) - }; - } -}; - -// Special case. Operator overwrites DispatchKey generation. Use that. -template<class OpSchemaDef> -class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<has_function_dispatch_key_defined<OpSchemaDef>::value>> final { - using signature = OpSignatureSchema<OpSchemaDef>; - - static_assert(guts::is_function_type<decltype(OpSchemaDef::dispatch_key)>::value, "Operator schema defines dispatch_key member, but it isn't a function."); - - using dispatch_key_traits = guts::function_traits<decltype(OpSchemaDef::dispatch_key)>; - -public: - using dispatch_key_type = typename dispatch_key_traits::return_type; - -private: - - static_assert(guts::is_equality_comparable<dispatch_key_type>::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have the equality operator defined. Please define it."); - static_assert(guts::is_hashable<dispatch_key_type>::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have an overload for std::hash. Please define it."); - - static_assert(std::is_same< - guts::typelist::map_t<guts::remove_cv_t, guts::typelist::map_t<guts::remove_reference_t, typename dispatch_key_traits::parameter_types>>, - guts::typelist::map_t<guts::remove_cv_t, guts::typelist::map_t<guts::remove_reference_t, typename signature::parameter_types>> - >::value, "Operator schema defines custom dispatch_key() derivation function, but the arguments don't match the operator signature."); - -public: - - template<class... Args> - static inline dispatch_key_type dispatch_key(const Args&... args) { - using guts::typelist::map_t; - using guts::typelist::typelist; - static_assert(std::is_same< - map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typelist<Args...>>>, - map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>> - >::value, "Invalid argument types passed to OpSchema::dispatch_key()"); - return OpSchemaDef::dispatch_key(args...); - } -}; - -template<class OpSchemaDef> -class OpMetadataSchema final { -private: - static_assert(has_name_defined<OpSchemaDef>::value, "The operator schema has to define a 'static constexpr const char* name = ...' member to specify the operator name."); - static_assert(std::is_same<const char* const, decltype(OpSchemaDef::name)>::value, "The 'name' member of the operator schema must have type 'static constexpr const char*'"); - -public: - static constexpr const char* name() { - return OpSchemaDef::name; - } -}; - -} // namespace details - -/** - * Wrapper class for user-defined OpSchemaDef, providing functionality for determining - * information about the signature and dispatching on that signature. This is the - * "public" facing class. - * - * @tparam OpSchemaDef User-defined OpSchemaDef. - * This struct is expected to define: - * - a function type Signature - * - a constexpr guts<const char*, n_args> parameter_names field (where n_args is - * the number of arguments in Signature) - */ -template <class OpSchemaDef> -class CAFFE2_API OpSchema final { - // TODO static_assert OpSchemaDef isn't an instanciation of OpSchema. If yes, the caller probably passed an OpSchema somewhere where an OpSchemaDef was expected and wants a good error message. -public: - using metadata = details::OpMetadataSchema<OpSchemaDef>; - /** - * Information about the signature - */ - using signature = details::OpSignatureSchema<OpSchemaDef>; - /** - * Functionality for dispatching on that signature - */ - using dispatch = details::OpDispatchKeySchema<OpSchemaDef>; -}; - -// TODO test OpSchema::dispatch stuff -} // namespace c10 diff --git a/c10/core/dispatch/OpSchemaRegistration.cpp b/c10/core/dispatch/OpSchemaRegistration.cpp deleted file mode 100644 index 468fa89c53..0000000000 --- a/c10/core/dispatch/OpSchemaRegistration.cpp +++ /dev/null @@ -1 +0,0 @@ -#include <c10/core/dispatch/OpSchemaRegistration.h> diff --git a/c10/core/dispatch/OpSchemaRegistration.h b/c10/core/dispatch/OpSchemaRegistration.h deleted file mode 100644 index dc81c05d31..0000000000 --- a/c10/core/dispatch/OpSchemaRegistration.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include <c10/core/dispatch/Dispatcher.h> - -// TODO Better error message when this definition is missing - -/** - * Macro for defining an operator schema. Every user-defined OpSchemaDef struct must - * invoke this macro on it. Internally, this arranges for the dispatch table for - * the operator to be created. - */ -#define C10_DEFINE_OP_SCHEMA(OpSchemaDef) \ - template<> \ - C10_EXPORT c10::DispatchTable<OpSchemaDef>& c10_dispatch_table<OpSchemaDef>() { \ - static c10::DispatchTable<OpSchemaDef> singleton; \ - return singleton; \ - } -// TODO Also register unboxed calling API here diff --git a/c10/core/opschema/layer_norm.cpp b/c10/core/opschema/layer_norm.cpp deleted file mode 100644 index c8f71ba509..0000000000 --- a/c10/core/opschema/layer_norm.cpp +++ /dev/null @@ -1,4 +0,0 @@ -#include <c10/core/opschema/layer_norm.h> -#include <c10/core/dispatch/OpSchemaRegistration.h> - -C10_DEFINE_OP_SCHEMA(c10::core::opschema::LayerNorm); diff --git a/c10/core/opschema/layer_norm.h b/c10/core/opschema/layer_norm.h deleted file mode 100644 index d80c9650b3..0000000000 --- a/c10/core/opschema/layer_norm.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once - -#include <c10/core/Tensor.h> -#include <c10/util/Array.h> - -namespace c10 { -namespace core { -namespace opschema { - -// TODO This op schema should probably not live in c10 since it's not a method -// on Tensor. It's only here as a proof-of-concept op and for LATTE team -// to be able to call caffe2 layer norm from PyTorch. -struct LayerNorm final { - static constexpr const char* name = "LayerNorm"; - - struct Cache final { - at::optional<C10Tensor> scale = at::nullopt; - at::optional<C10Tensor> bias = at::nullopt; - }; - - using Signature = void( - const C10Tensor& input, - const C10Tensor& output, - const C10Tensor& output_mean, - const C10Tensor& output_stddev, - int axis, - float epsilon, - Cache* cache); - - static constexpr size_t num_dispatch_args() {return 1;} - - static constexpr size_t num_outputs() {return 3;} - - static constexpr c10::guts::array<const char*, 7> parameter_names = { - {"input", "output", "output_mean", "output_stddev", "axis", "epsilon", "cache"}}; -}; - -} // namespace opschema -} // namespace core -} // namespace c10 diff --git a/c10/test/core/dispatch/OpSchema_test.cpp b/c10/test/core/dispatch/OpSchema_test.cpp deleted file mode 100644 index d10f7fc223..0000000000 --- a/c10/test/core/dispatch/OpSchema_test.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include <c10/core/dispatch/OpSchema.h> -#include <c10/util/Array.h> - -using namespace c10; - -static_assert(details::is_tensor_arg<C10Tensor>::value, ""); -static_assert(details::is_tensor_arg<const C10Tensor&>::value, ""); -static_assert(details::is_tensor_arg<C10Tensor&&>::value, ""); -static_assert(!details::is_tensor_arg<int>::value, ""); - -struct SchemaDef final { - using Signature = bool(int, C10Tensor, float, C10Tensor, C10Tensor, unsigned int); - static constexpr guts::array<const char*, 6> parameter_names = {{ - "1", "2", "3", "4", "5", "6" - }}; - static constexpr size_t num_dispatch_args() {return 3;} - static constexpr size_t num_outputs() {return 0;} -}; -static_assert(6 == OpSchema<SchemaDef>::signature::num_args, ""); -static_assert(3 == OpSchema<SchemaDef>::signature::num_tensor_args, ""); -static_assert(std::is_same<bool, typename OpSchema<SchemaDef>::signature::return_type>::value, ""); -static_assert( - std::is_same< - guts::typelist:: - typelist<int, C10Tensor, float, C10Tensor, C10Tensor, unsigned int>, - typename OpSchema<SchemaDef>::signature::parameter_types>::value, - ""); |