-rw-r--r--  aten/src/ATen/core/dispatch/DeviceId.cpp | 1
-rw-r--r--  aten/src/ATen/core/dispatch/DeviceId.h | 36
-rw-r--r--  aten/src/ATen/core/dispatch/DispatchKey.cpp | 1
-rw-r--r--  aten/src/ATen/core/dispatch/DispatchKey.h | 97
-rw-r--r--  aten/src/ATen/core/dispatch/DispatchTable.cpp | 1
-rw-r--r--  aten/src/ATen/core/dispatch/DispatchTable.h | 159
-rw-r--r--  aten/src/ATen/core/dispatch/Dispatcher.cpp | 7
-rw-r--r--  aten/src/ATen/core/dispatch/Dispatcher.h | 143
-rw-r--r--  aten/src/ATen/core/dispatch/KernelCache.h | 20
-rw-r--r--  aten/src/ATen/core/dispatch/KernelFunction.h | 16
-rw-r--r--  aten/src/ATen/core/dispatch/KernelRegistration.cpp | 1
-rw-r--r--  aten/src/ATen/core/dispatch/KernelRegistration.h | 213
-rw-r--r--  aten/src/ATen/core/dispatch/LayoutId.cpp | 1
-rw-r--r--  aten/src/ATen/core/dispatch/LayoutId.h | 22
-rw-r--r--  aten/src/ATen/core/dispatch/OpSchema.cpp | 1
-rw-r--r--  aten/src/ATen/core/dispatch/OpSchema.h | 406
-rw-r--r--  aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp | 1
-rw-r--r--  aten/src/ATen/core/dispatch/OpSchemaRegistration.h | 39
-rw-r--r--  aten/src/ATen/core/dispatch/OpSchema_test.cpp | 29
-rw-r--r--  aten/src/ATen/core/dispatch/README.md | 12
-rw-r--r--  aten/src/ATen/core/jit_type.h | 66
-rw-r--r--  aten/src/ATen/core/opschema/layer_norm.cpp | 23
-rw-r--r--  aten/src/ATen/core/opschema/layer_norm.h | 27
-rw-r--r--  aten/src/ATen/core/stack.h | 3
-rw-r--r--  c10/test/util/TypeList_test.cpp | 6
-rw-r--r--  c10/util/TypeList.h | 55
-rw-r--r--  caffe2/core/operator.h | 10
-rw-r--r--  caffe2/core/operator_c10wrapper.h | 146
-rw-r--r--  caffe2/operators/experimental/c10/cpu/add_cpu.cc | 13
-rw-r--r--  caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc | 13
-rw-r--r--  caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc | 32
-rw-r--r--  caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc | 23
-rw-r--r--  caffe2/operators/experimental/c10/cpu/cast_cpu.cc | 66
-rw-r--r--  caffe2/operators/experimental/c10/cpu/concat_cpu.cc | 8
-rw-r--r--  caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc | 6
-rw-r--r--  caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc | 40
-rw-r--r--  caffe2/operators/experimental/c10/cpu/fc_cpu.cc | 36
-rw-r--r--  caffe2/operators/experimental/c10/cpu/filler_cpu.cc | 40
-rw-r--r--  caffe2/operators/experimental/c10/cpu/flatten_cpu.cc | 8
-rw-r--r--  caffe2/operators/experimental/c10/cpu/mul_cpu.cc | 13
-rw-r--r--  caffe2/operators/experimental/c10/cpu/relu_cpu.cc | 6
-rw-r--r--  caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc | 6
-rw-r--r--  caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc | 11
-rw-r--r--  caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc | 83
-rw-r--r--  caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc | 6
-rw-r--r--  caffe2/operators/experimental/c10/schemas/add.cc | 19
-rw-r--r--  caffe2/operators/experimental/c10/schemas/add.h | 22
-rw-r--r--  caffe2/operators/experimental/c10/schemas/averaged_loss.cc | 18
-rw-r--r--  caffe2/operators/experimental/c10/schemas/averaged_loss.h | 22
-rw-r--r--  caffe2/operators/experimental/c10/schemas/batch_gather.cc | 19
-rw-r--r--  caffe2/operators/experimental/c10/schemas/batch_gather.h | 20
-rw-r--r--  caffe2/operators/experimental/c10/schemas/batch_matmul.cc | 19
-rw-r--r--  caffe2/operators/experimental/c10/schemas/batch_matmul.h | 30
-rw-r--r--  caffe2/operators/experimental/c10/schemas/cast.cc | 16
-rw-r--r--  caffe2/operators/experimental/c10/schemas/cast.h | 20
-rw-r--r--  caffe2/operators/experimental/c10/schemas/concat.cc | 18
-rw-r--r--  caffe2/operators/experimental/c10/schemas/concat.h | 29
-rw-r--r--  caffe2/operators/experimental/c10/schemas/enforce_finite.cc | 16
-rw-r--r--  caffe2/operators/experimental/c10/schemas/enforce_finite.h | 16
-rw-r--r--  caffe2/operators/experimental/c10/schemas/expand_dims.cc | 16
-rw-r--r--  caffe2/operators/experimental/c10/schemas/expand_dims.h | 23
-rw-r--r--  caffe2/operators/experimental/c10/schemas/fc.cc | 19
-rw-r--r--  caffe2/operators/experimental/c10/schemas/fc.h | 25
-rw-r--r--  caffe2/operators/experimental/c10/schemas/filler.cc | 84
-rw-r--r--  caffe2/operators/experimental/c10/schemas/filler.h | 102
-rw-r--r--  caffe2/operators/experimental/c10/schemas/flatten.cc | 16
-rw-r--r--  caffe2/operators/experimental/c10/schemas/flatten.h | 20
-rw-r--r--  caffe2/operators/experimental/c10/schemas/layer_norm.cc | 1
-rw-r--r--  caffe2/operators/experimental/c10/schemas/mul.cc | 18
-rw-r--r--  caffe2/operators/experimental/c10/schemas/mul.h | 22
-rw-r--r--  caffe2/operators/experimental/c10/schemas/relu.cc | 18
-rw-r--r--  caffe2/operators/experimental/c10/schemas/relu.h | 17
-rw-r--r--  caffe2/operators/experimental/c10/schemas/sigmoid.cc | 17
-rw-r--r--  caffe2/operators/experimental/c10/schemas/sigmoid.h | 17
-rw-r--r--  caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc | 18
-rw-r--r--  caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h | 21
-rw-r--r--  caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc | 19
-rw-r--r--  caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h | 20
-rw-r--r--  caffe2/operators/experimental/c10/schemas/stop_gradient.cc | 18
-rw-r--r--  caffe2/operators/experimental/c10/schemas/stop_gradient.h | 19
-rw-r--r--  caffe2/operators/layer_norm_op.cc | 24
-rw-r--r--  torch/csrc/jit/c10_ops/layer_norm.cpp | 27
82 files changed, 1197 insertions, 1620 deletions
diff --git a/aten/src/ATen/core/dispatch/DeviceId.cpp b/aten/src/ATen/core/dispatch/DeviceId.cpp
deleted file mode 100644
index 7c65fbe70c..0000000000
--- a/aten/src/ATen/core/dispatch/DeviceId.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/DeviceId.h>
diff --git a/aten/src/ATen/core/dispatch/DeviceId.h b/aten/src/ATen/core/dispatch/DeviceId.h
deleted file mode 100644
index cdcaf4b635..0000000000
--- a/aten/src/ATen/core/dispatch/DeviceId.h
+++ /dev/null
@@ -1,36 +0,0 @@
-#pragma once
-
-#include <c10/util/C++17.h>
-#include <functional>
-#include <iostream>
-
-namespace c10 {
-
-enum class DeviceTypeId : uint8_t {
- // Don't use the int values here in the enum (i.e. don't do static_cast to or from int).
- // Instead, if you want to serialize this, write a function with switch/case.
- CPU = 0,
- CUDA = 1,
- UNDEFINED
-};
-
-inline std::ostream& operator<<(std::ostream& stream, DeviceTypeId device_type_id) {
- switch(device_type_id) {
- case c10::DeviceTypeId::CPU: return stream << "DeviceTypeId(CPU)";
- case c10::DeviceTypeId::CUDA: return stream << "DeviceTypeId(CUDA)";
- case c10::DeviceTypeId::UNDEFINED: return stream << "DeviceTypeId(UNDEFINED)";
- }
- throw std::logic_error("Unknown DeviceTypeId: " + c10::guts::to_string(static_cast<int>(device_type_id)));
-}
-
-}
-
-namespace std {
-
-template <> struct hash<c10::DeviceTypeId> {
- size_t operator()(c10::DeviceTypeId v) const {
- return std::hash<uint8_t>()(static_cast<uint8_t>(v));
- }
-};
-
-}
diff --git a/aten/src/ATen/core/dispatch/DispatchKey.cpp b/aten/src/ATen/core/dispatch/DispatchKey.cpp
deleted file mode 100644
index 33e8e29b28..0000000000
--- a/aten/src/ATen/core/dispatch/DispatchKey.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/DispatchKey.h>
diff --git a/aten/src/ATen/core/dispatch/DispatchKey.h b/aten/src/ATen/core/dispatch/DispatchKey.h
deleted file mode 100644
index 7c6cf15f61..0000000000
--- a/aten/src/ATen/core/dispatch/DispatchKey.h
+++ /dev/null
@@ -1,97 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/DeviceId.h>
-#include <ATen/core/dispatch/LayoutId.h>
-#include <c10/util/typeid.h>
-
-#include <vector>
-#include <functional>
-#include <sstream>
-#include <c10/util/Array.h>
-
-namespace c10 {
-
-namespace details {
-
-struct TensorParameterDispatchKey final {
- // note: This dispatch key structure is not final yet and will change. Don't rely on it.
- DeviceTypeId deviceTypeId;
- LayoutId layoutId;
- // TODO Move this TypeIdentifier to c10 namespace
- caffe2::TypeIdentifier dataType;
-};
-inline constexpr bool operator==(const TensorParameterDispatchKey& lhs, const TensorParameterDispatchKey& rhs) {
- return lhs.deviceTypeId == rhs.deviceTypeId && lhs.layoutId == rhs.layoutId && lhs.dataType == rhs.dataType;
-}
-
-inline std::ostream& operator<<(std::ostream& stream, const TensorParameterDispatchKey& key) {
- return stream << "TensorKey(" << key.deviceTypeId << ", " << key.layoutId.value() << ", " << key.dataType << ")";
-}
-
-} // namespace details
-} // namespace c10
-
-namespace std {
- template<>
- struct hash<c10::details::TensorParameterDispatchKey> {
- // TODO constexpr hashing
- size_t operator()(const c10::details::TensorParameterDispatchKey& obj) const {
- return std::hash<c10::DeviceTypeId>()(obj.deviceTypeId) ^ std::hash<c10::LayoutId>()(obj.layoutId) ^ std::hash<caffe2::TypeIdentifier>()(obj.dataType);
- }
- };
-} // namespace std
-
-namespace c10 {
-/**
- * The dispatch key encodes the runtime type identity of a function call arguments,
- * specifying what aspects of this identity can be dynamically dispatched on.
- *
- * Intuitively, given a function signature like f(Tensor, int), a valid dispatch
- * key for the arguments might be [CPUFloatTensor] (notice that 'f' is NOT included
- * in the dispatch key, and the runtime type of 'int' is NOT considered for dispatch
- * (since it is trivial).
- *
- * Dispatch keys permit equality tests and are hashable.
- *
- * @tparam num_dispatch_args The number of dispatchable arguments
- */
-template<size_t num_dispatch_args>
-struct DispatchKey final {
- guts::array<details::TensorParameterDispatchKey, num_dispatch_args> argTypes;
-};
-
-template<size_t num_dispatch_args>
-inline constexpr bool operator==(const DispatchKey<num_dispatch_args> &lhs, const DispatchKey<num_dispatch_args>& rhs) {
- // TODO: Use AVX instructions to perform this equality test more quickly
- return lhs.argTypes == rhs.argTypes;
-}
-
-template<size_t num_dispatch_args>
-inline std::ostream& operator<<(std::ostream& stream, const DispatchKey<num_dispatch_args>& key) {
- stream << "DispatchKey(";
- if (num_dispatch_args > 0) {
- stream << "DispatchKey(" << key.argTypes[0];
- for (size_t i = 1; i < num_dispatch_args; ++i) {
- stream << ", " << key.argTypes[i];
- }
- stream << ")";
- }
- return stream << ")";
-}
-
-} // namespace c10
-
-namespace std {
- template<size_t num_dispatch_args>
- struct hash<c10::DispatchKey<num_dispatch_args>> {
- // TODO constexpr hashing
- size_t operator()(const c10::DispatchKey<num_dispatch_args>& obj) const {
- size_t hash_value = 0;
- for (const auto& argType : obj.argTypes) {
- hash_value *= 10883; // prime
- hash_value += std::hash<c10::details::TensorParameterDispatchKey>()(argType);
- }
- return hash_value;
- }
- };
-} // namespace std
diff --git a/aten/src/ATen/core/dispatch/DispatchTable.cpp b/aten/src/ATen/core/dispatch/DispatchTable.cpp
deleted file mode 100644
index 0bf7c5903d..0000000000
--- a/aten/src/ATen/core/dispatch/DispatchTable.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/DispatchTable.h>
diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h
index 22206658c5..c005314c27 100644
--- a/aten/src/ATen/core/dispatch/DispatchTable.h
+++ b/aten/src/ATen/core/dispatch/DispatchTable.h
@@ -1,28 +1,29 @@
#pragma once
-#include <ATen/core/dispatch/OpSchema.h>
+#include <ATen/core/function_schema.h>
#include <c10/util/LeftRight.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <ATen/core/ivalue.h>
+#include <ATen/core/dispatch/KernelFunction.h>
#include <array>
#include <atomic>
#include <iostream>
#include <mutex>
#include <type_traits>
+#include <sstream>
#include <unordered_map>
namespace c10 {
/**
* The type of a user-supplied function to initialize the kernel cache.
- * this is stored together with the kernel function in the dispatch table
+ * This is stored together with the KernelFunction in the DispatchTable
* so we can create a new cache instance when a kernel is looked up
* from the dispatch table.
*/
-using KernelStateCreatorFunction = std::unique_ptr<c10::KernelState> ();
-
+using KernelCacheCreatorFunction = std::unique_ptr<c10::KernelCache> ();
/**
* The dispatch table stores a pointer to a kernel function and a pointer
* to a function initializing a cache for the kernel. If the kernel wants
@@ -32,72 +33,100 @@ using KernelStateCreatorFunction = std::unique_ptr<c10::KernelState> ();
* this same cache instance.
*/
struct DispatchTableEntry final {
- KernelFunction* kernel_func;
- KernelStateCreatorFunction* state_creator_func;
+ /*not-nullable*/ KernelFunction* kernel_func;
+ /*not-nullable*/ KernelCacheCreatorFunction* cache_creator_func;
};
-namespace details {
+namespace detail {
/// Kernel implementations in a thread-safe hash table.
-template <class Key>
class ThreadsafeOperatorTable_ final {
public:
- template <class Key_>
- void emplace(Key_&& key, const DispatchTableEntry& value) {
- bool res = map_.write([&](ska::flat_hash_map<Key, DispatchTableEntry>& map) -> bool {
- auto result = map.emplace(std::forward<Key>(key), value);
+ void emplace(TensorTypeId key, const DispatchTableEntry& value, const std::string& operator_name) {
+ bool res = map_.write([&](ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> bool {
+ auto result = map.emplace(key, value);
return result.second;
});
if (!res) {
- AT_ERROR("Tried to register conflicting kernels to the dispatcher: ", key);
+ AT_ERROR("Tried to register multiple kernels with same dispatch key '",
+ dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
}
}
- void erase(const Key& key) {
- auto num_removed =
- map_.write([&](ska::flat_hash_map<Key, DispatchTableEntry>& map) -> size_t {
- return map.erase(key);
- });
- assert(num_removed <= 1); // This is not a multi-map
- if (num_removed == 0) {
- AT_ERROR("Tried to deregister a kernel that isn't registered.");
- }
+ void erase(TensorTypeId key, const std::string& operator_name) {
+ map_.write([&](ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) {
+ auto num_removed = map.erase(key);
+
+ assert(num_removed <= 1); // This is not a multi-map
+ if (num_removed == 0) {
+ AT_ERROR("Tried to deregister a kernel with dispatch key '",
+ dispatch_key_to_string(key), "' for operator '", operator_name,
+ "' but that kernel isn't registered. Registered dispatch keys are: ",
+ list_all_dispatch_keys(map));
+ }
+ });
}
- const DispatchTableEntry* lookup(const Key& key) const {
- return map_.read([&](const ska::flat_hash_map<Key, DispatchTableEntry>& map) -> const DispatchTableEntry* {
+ const DispatchTableEntry* lookup(TensorTypeId key, const string& operator_name) const {
+ return map_.read([&](const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> const DispatchTableEntry* {
auto found = map.find(key);
if (found != map.end()) {
return &found->second;
} else {
- return nullptr;
+ AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name,
+ "'. Tried to look up kernel for dispatch key '", dispatch_key_to_string(key),
+ "'. Registered dispatch keys are: ", list_all_dispatch_keys(map));
}
});
}
+ bool isEmpty() const {
+ return map_.read([&](const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> bool {
+ return map.size() == 0;
+ });
+ }
+
private:
- LeftRight<ska::flat_hash_map<Key, DispatchTableEntry>> map_;
+ static std::string list_all_dispatch_keys(const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) {
+ if (map.size() == 0) {
+ return "";
+ }
+ std::ostringstream str;
+ str << dispatch_key_to_string(map.begin()->first);
+ for (auto iter = ++map.begin(); iter != map.end(); ++iter) {
+ str << ", " << dispatch_key_to_string(iter->first);
+ }
+ return str.str();
+ }
+
+ static std::string dispatch_key_to_string(TensorTypeId id) {
+ return std::string(toString(tensorTypeIdToBackend(id))) + "[" + guts::to_string(id) + "]";
+ }
+
+ LeftRight<ska::flat_hash_map<TensorTypeId, DispatchTableEntry>> map_;
};
-} // namespace details
+} // namespace detail
/**
* Per-operator dispatch table.
*
- * Given an operator specified by 'OpSchemaDef', this class records a dispatch
+ * Given an operator specified by a FunctionSchema, this class records a dispatch
* table for various kernels provided for this operator. For example, if we
* consider the operator add(Tensor, Tensor), the dispatch table for this
* operator may contain implementations for various dynamic tensor types, such
- * as (CPUFloatTensor, CPUFloatTensor), (CUDAFloatTensor, CUDAFloatTensor), etc.
- *
- * @tparam OpSchemaDef The operator signature this dispatch table encodes.
+ * as CPUTensorId, CUDATensorId, etc.
*/
-// TODO: Support dispatch for meta-operators (which apply to all dynamic types)
-template <class OpSchemaDef>
class DispatchTable final {
- private:
- using Schema = OpSchema<OpSchemaDef>;
-
public:
- DispatchTable() : kernels_() {}
+ explicit DispatchTable(const FunctionSchema& schema)
+ : kernels_()
+ , reverse_index_of_first_tensor_arg_(
+ schema.arguments().size() - get_index_of_first_tensor_arg_(schema))
+ , operator_name_(schema.name()) {}
+
+ DispatchTable(DispatchTable&&) = default;
+ DispatchTable& operator=(DispatchTable&&) = default;
+ DispatchTable(const DispatchTable&) = delete;
+ DispatchTable& operator=(const DispatchTable&) = delete;
/**
* Register a kernel in the table at some dispatch key.
@@ -105,9 +134,9 @@ class DispatchTable final {
* @param dispatch_key Dispatch key to define when this kernel is selected
*/
void registerKernel(
- typename Schema::dispatch::dispatch_key_type dispatch_key,
+ TensorTypeId dispatch_key,
const DispatchTableEntry& kernel) {
- kernels_.emplace(std::move(dispatch_key), kernel);
+ kernels_.emplace(dispatch_key, kernel, operator_name_);
}
/**
@@ -118,9 +147,8 @@ class DispatchTable final {
// TODO: This isn't going to work so well when we get more complicated
// override patterns! In this case, an operator will show up in multiple
// slots, and erasing them one-by-one is probably not such a good idea.
- void deregisterKernel(
- const typename Schema::dispatch::dispatch_key_type& dispatch_key) {
- kernels_.erase(dispatch_key);
+ void deregisterKernel(TensorTypeId dispatch_key) {
+ kernels_.erase(dispatch_key, operator_name_);
}
/**
@@ -131,30 +159,39 @@ class DispatchTable final {
* @return Kernel function pointing to the right kernel for the given arguments
*/
const DispatchTableEntry& lookup(const Stack* stack) const {
- auto dispatch_key = Schema::dispatch::dispatch_key(stack);
- const DispatchTableEntry* found = kernels_.lookup(dispatch_key);
- if (found == nullptr) {
- // TODO Better error message - include op name and dispatch key (i.e.
- // argument types)
- AT_ERROR("Didn't find kernel to dispatch to for operator '", Schema::metadata::name(), "'");
- }
- return *found;
+ TensorTypeId dispatch_key = torch::jit::peek(
+ *stack,
+ 0,
+ reverse_index_of_first_tensor_arg_
+ ).toTensor().type_id();
+ return *kernels_.lookup(dispatch_key, operator_name_);
+ }
+
+ bool isEmpty() const {
+ return kernels_.isEmpty();
}
private:
+ static size_t get_index_of_first_tensor_arg_(const FunctionSchema& schema) {
+ for (size_t i = 0; i < schema.arguments().size(); ++i) {
+ if (schema.arguments()[i].type()->isSubtypeOf(DynamicType::get())) { // DynamicType means it's a tensor
+ return i;
+ }
+ }
+
+ AT_ERROR("Tried to create dispatch table for operator schema ", schema.name(), " that doesn't have tensor arguments.");
+ }
+ detail::ThreadsafeOperatorTable_ kernels_;
- details::ThreadsafeOperatorTable_<
- typename Schema::dispatch::dispatch_key_type>
- kernels_;
+ // this is caching the index so we don't have to parse the schema inputs
+ // again and again for each dispatcher lookup.
+ // reverse_index means this is the distance from the first tensor argument
+ // to argument_list.end(), i.e. from the top of the stack.
+ // Since it is distance to end(), this means it's 1-indexed,
+ // i.e. '1' is the last argument.
+ size_t reverse_index_of_first_tensor_arg_;
+ std::string operator_name_;
};
} // namespace c10
-
-/*
- * Use this to access the dispatch table singleton for a given op schema.
- * It has an implementation for each op schema def in a cpp file, because
- * we can't rely on the one-definition-rule.
- */
-template <class OpSchemaDef>
-C10_API c10::DispatchTable<OpSchemaDef>& c10_dispatch_table();
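
A brief worked example of the reverse-index bookkeeping above, using a hypothetical schema that is not part of this patch: for add(Tensor a, Tensor b, int alpha), get_index_of_first_tensor_arg_ returns 0 and arguments().size() is 3, so reverse_index_of_first_tensor_arg_ is 3. With a, b and alpha pushed onto the stack in that order, lookup() peeks three slots below the top of the stack and reads 'a':

    // Sketch only: what DispatchTable::lookup does for the hypothetical schema
    // add(Tensor a, Tensor b, int alpha). Stack layout, top at the right: ... a, b, alpha
    TensorTypeId dispatch_key = torch::jit::peek(
        *stack,
        /*i=*/0,
        /*N=*/3  // reverse_index_of_first_tensor_arg_
    ).toTensor().type_id();   // reads 'a', the first tensor argument
    const DispatchTableEntry& entry = *kernels_.lookup(dispatch_key, operator_name_);
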
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp
index 9f76d8d969..832b687316 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.cpp
+++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -1 +1,8 @@
#include <ATen/core/dispatch/Dispatcher.h>
+
+namespace c10 {
+C10_EXPORT Dispatcher& Dispatcher::singleton() {
+ static Dispatcher _singleton;
+ return _singleton;
+}
+}
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index 1d845750ef..bcca890c16 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -1,9 +1,14 @@
#pragma once
#include <ATen/core/dispatch/DispatchTable.h>
+#include <c10/util/Exception.h>
+#include <mutex>
+#include <list>
namespace c10 {
+class CAFFE2_API OperatorHandle;
+
/**
* This class represents an operator kernel, i.e. an operator *after* it was
* dispatched to a certain device. You can use it to call the kernel.
@@ -14,12 +19,11 @@ namespace c10 {
* Also, keeping around the OpKernel instance will keep around a local cache
* that is used by some kernels to get better performance when they're called
* multiple times (mostly Caffe2 kernels do that).
+ *
+ * OpKernel is not threadsafe.
*/
-class OpKernel final {
+class CAFFE2_API OpKernel final {
public:
- explicit OpKernel(KernelFunction* kernel, KernelStateCreatorFunction* state_creator)
- : kernel_(kernel), state_creator_(state_creator) {}
-
OpKernel(OpKernel&&) = default;
OpKernel& operator=(OpKernel&&) = default;
OpKernel(const OpKernel&) = delete;
@@ -28,60 +32,137 @@ public:
/**
* Call the operator kernel with the given arguments.
*/
- void call(Stack* stack) {
- if (state_.get() == nullptr) {
- AT_ASSERT(state_creator_ != nullptr);
- state_ = (*state_creator_)();
+ void call(Stack* stack) const {
+ if (cache_.get() == nullptr) {
+ AT_ASSERT(cache_creator_ != nullptr);
+ cache_ = (*cache_creator_)();
}
- return (*kernel_)(stack, state_.get());
+ return (*kernel_)(stack, cache_.get());
}
private:
- // The kernel function is a global C function, not a std::function.
- // That is, ownership is not an issue.
+ explicit OpKernel(KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator)
+ : kernel_(kernel), cache_creator_(cache_creator) {}
+ friend class Dispatcher;
+
KernelFunction* kernel_;
- KernelStateCreatorFunction* state_creator_;
- std::unique_ptr<c10::KernelState> state_;
+ KernelCacheCreatorFunction* cache_creator_;
+ mutable std::unique_ptr<c10::KernelCache> cache_;
};
/**
* Top-level dispatch interface for dispatching via the dynamic dispatcher.
*/
-template<class OpSchemaDef>
-class Dispatcher final {
+class CAFFE2_API Dispatcher final {
private:
- using Schema = OpSchema<OpSchemaDef>;
+ struct OperatorDef final {
+ explicit OperatorDef(FunctionSchema schema_)
+ : dispatchTable(schema_)
+ , schema(std::move(schema_)) {}
+
+ DispatchTable dispatchTable;
+ FunctionSchema schema;
+ };
+ friend class OperatorHandle;
+
public:
// Implementation note: this class abstracts over the fact that we have per-operator
// dispatch tables. This could be easily adjusted to have a single global hash
// table.
+ static Dispatcher& singleton();
+
/**
- * Register an operator to the dispatch table for some operator schema.
+ * Register a new operator schema. The handle returned can be used to register
+ * kernels to this operator or to call it.
*/
- static void registerKernel(typename Schema::dispatch::dispatch_key_type dispatch_key, KernelFunction* kernel_func, KernelStateCreatorFunction* state_creator_func) {
- auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
- return dispatch_table_for_this_op.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, state_creator_func});
- }
+ OperatorHandle registerSchema(FunctionSchema schema);
/**
- * Remove an operator from the dispatch table for some operator schema.
+ * Remove an operator from the dispatcher. Make sure you removed
+ * all kernels for this operator before calling this.
*/
- static void deregisterKernel(const typename Schema::dispatch::dispatch_key_type& dispatch_key) {
- auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
- return dispatch_table_for_this_op.deregisterKernel(dispatch_key);
- }
+ void deregisterSchema(const OperatorHandle& op);
/**
- * Perform a dynamic dispatch and get the kernel for an operator
+ * Register a kernel in the dispatch table for an operator.
*/
- static OpKernel lookup(const Stack* stack) {
- auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
- const DispatchTableEntry& kernel = dispatch_table_for_this_op.lookup(stack);
- return OpKernel(kernel.kernel_func, kernel.state_creator_func);
+ void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func);
+
+ /**
+ * Remove a kernel from the dispatch table for an operator.
+ */
+ void deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key);
+
+ /**
+ * Perform a dynamic dispatch and get the kernel for an operator.
+ */
+ OpKernel lookup(const OperatorHandle& op, const Stack* stack) const;
+
+private:
+ std::list<OperatorDef> operators_;
+ std::mutex mutex_;
+};
+
+/**
+ * This is a handle to an operator schema registered with the dispatcher.
+ * This handle can be used to register kernels with the dispatcher or
+ * to lookup a kernel for a certain set of arguments.
+ */
+class CAFFE2_API OperatorHandle final {
+public:
+ OperatorHandle(OperatorHandle&&) = default;
+ OperatorHandle& operator=(OperatorHandle&&) = default;
+ OperatorHandle(const OperatorHandle&) = default;
+ OperatorHandle& operator=(const OperatorHandle&) = default;
+
+ const FunctionSchema& schema() const {
+ return operatorDefIterator_->schema;
}
+private:
+ explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorDefIterator)
+ : operatorDefIterator_(std::move(operatorDefIterator)) {}
+ friend class Dispatcher;
+
+ std::list<Dispatcher::OperatorDef>::iterator operatorDefIterator_;
};
+
+
+inline OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) {
+ // we need a lock to avoid concurrent writes
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ operators_.emplace_back(std::move(schema));
+ return OperatorHandle(--operators_.end());
+}
+
+inline void Dispatcher::deregisterSchema(const OperatorHandle& op) {
+ // we need a lock to avoid concurrent writes
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
+ AT_ERROR("Tried to deregister op schema that still has kernels registered");
+ }
+ operators_.erase(op.operatorDefIterator_);
+}
+
+inline void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
+ // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+ op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, cache_creator_func});
+}
+
+inline void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) {
+ // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+ op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key);
+}
+
+inline OpKernel Dispatcher::lookup(const OperatorHandle& op, const Stack* stack) const {
+ // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+ const DispatchTableEntry& kernel = op.operatorDefIterator_->dispatchTable.lookup(stack);
+ return OpKernel(kernel.kernel_func, kernel.cache_creator_func);
+}
+
} // namespace c10
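
To show how the pieces above fit together, here is a minimal usage sketch of the handle-based Dispatcher API. The kernel my_kernel, the cache creator noCache and the example function are hypothetical, error handling is omitted, and CPUTensorId() is the dispatch key used elsewhere in this patch:

    #include <ATen/core/dispatch/Dispatcher.h>
    #include <memory>

    // Hypothetical kernel following the KernelFunction ABI (see KernelFunction.h below).
    void my_kernel(c10::Stack* stack, c10::KernelCache* /*cache*/) {
      // read inputs from *stack and push outputs back onto it
    }

    // registerKernel expects a cache creator; this one creates no cache.
    std::unique_ptr<c10::KernelCache> noCache() { return nullptr; }

    void example(c10::FunctionSchema schema, c10::Stack& stack) {
      auto& dispatcher = c10::Dispatcher::singleton();
      c10::OperatorHandle op = dispatcher.registerSchema(std::move(schema));
      dispatcher.registerKernel(op, CPUTensorId(), &my_kernel, &noCache);

      c10::OpKernel kernel = dispatcher.lookup(op, &stack);  // dispatch on the first tensor argument
      kernel.call(&stack);

      dispatcher.deregisterKernel(op, CPUTensorId());
      dispatcher.deregisterSchema(op);  // only valid once all kernels are deregistered
    }
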
diff --git a/aten/src/ATen/core/dispatch/KernelCache.h b/aten/src/ATen/core/dispatch/KernelCache.h
new file mode 100644
index 0000000000..e879b9d09f
--- /dev/null
+++ b/aten/src/ATen/core/dispatch/KernelCache.h
@@ -0,0 +1,20 @@
+#pragma once
+
+namespace c10 {
+
+/**
+ * A kernel can keep around a cache to have better performance when it's
+ * called multiple times. This is used by a lot of caffe2 kernels, for example
+ * conv_op stores a set of tensors for intermediate values to avoid having
+ * to reallocate them on each call.
+ * This cache is owned by the call site (i.e. stored inside the OpKernel object
+ * kept at the call site to call into the kernel) and passed in to the kernel
+ * as a function argument. It must inherit from KernelCache so the call site
+ * knows how to store and destruct it.
+ */
+class CAFFE2_API KernelCache {
+public:
+ virtual ~KernelCache() = default;
+};
+
+}
diff --git a/aten/src/ATen/core/dispatch/KernelFunction.h b/aten/src/ATen/core/dispatch/KernelFunction.h
new file mode 100644
index 0000000000..67a296fa4e
--- /dev/null
+++ b/aten/src/ATen/core/dispatch/KernelFunction.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <ATen/core/dispatch/KernelCache.h>
+#include <ATen/core/stack.h>
+
+namespace c10 {
+
+using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
+
+/**
+ * This is the basic ABI for any kernel call. Each kernel is registered as a
+ * function pointer `KernelFunction*`, i.e. kernels are not allowed to be closures.
+ */
+using KernelFunction = void(Stack*, KernelCache* cache);
+
+}
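
As a hedged sketch of a kernel that conforms to this ABI and keeps a cache, the following uses made-up names (CounterCache, counting_kernel) and assumes the kernel is registered with a matching cache creator:

    #include <ATen/core/dispatch/KernelFunction.h>

    // Illustrative cache: scratch state reused across calls from the same call site.
    struct CounterCache final : c10::KernelCache {
      int64_t calls = 0;
    };

    // Matches the KernelFunction ABI: void(Stack*, KernelCache*).
    void counting_kernel(c10::Stack* stack, c10::KernelCache* cache) {
      auto* my_cache = static_cast<CounterCache*>(cache);  // cache instance created by the call site
      ++my_cache->calls;
      // inputs are read from *stack and outputs pushed back onto it
    }
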
diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.cpp b/aten/src/ATen/core/dispatch/KernelRegistration.cpp
deleted file mode 100644
index 0848300f5e..0000000000
--- a/aten/src/ATen/core/dispatch/KernelRegistration.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.h b/aten/src/ATen/core/dispatch/KernelRegistration.h
index df198c437d..eaab0c15e0 100644
--- a/aten/src/ATen/core/dispatch/KernelRegistration.h
+++ b/aten/src/ATen/core/dispatch/KernelRegistration.h
@@ -2,13 +2,20 @@
#include <c10/util/Optional.h>
#include <ATen/core/dispatch/Dispatcher.h>
-#include <ATen/core/dispatch/OpSchema.h>
/**
* To register your own kernel for an operator, do in one (!) cpp file:
- * C10_REGISTER_KERNEL(OpSchemaDef)
- * .kernel(&kernel_func)
+ * C10_REGISTER_KERNEL(OperatorHandle)
+ * .kernel<decltype(kernel_func), &kernel_func>()
* .dispatchKey(dispatch_key);
+ *
+ * Example:
+ *
+ * Tensor my_kernel_cpu(Tensor in) {...}
+ *
+ * C10_REGISTER_KERNEL(MyOpSchema)
+ * .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>()
+ * .dispatchKey(CPUTensorId());
*/
namespace c10 {
@@ -17,30 +24,29 @@ namespace c10 {
// TODO Test no dispatch key defined
/**
- * Class which, on construction, registers an operator in the dispatch table. The intent is that
+ * Class which, on construction, registers an operator in the dispatch table. The intent is that
* this class is constructed at static initialization time so that operators automatically get
* registered when a dlopen() occurs.
*
- * You shouldn't call this directly; instead, use the KernelRegistrationBuilder
- *
- * @tparam OpSchemaDef
+ * You shouldn't call this directly; instead, use the C10_REGISTER_KERNEL macros.
*/
-template<class OpSchemaDef>
class KernelRegistrar final {
-private:
- using Schema = OpSchema<OpSchemaDef>;
public:
+ using OpHandleGetter = const OperatorHandle& ();
+
/**
- * @param kernel The concrete function implementation to register
+ * @param op The operator to register the kernel for
* @param dispatch_key The dispatch key to register the function to
+ * @param kernel The concrete function implementation to register
+ * @param cache_creator A function initializing the cache for the kernel
*/
- KernelRegistrar(typename Schema::dispatch::dispatch_key_type dispatch_key, KernelFunction* kernel, KernelStateCreatorFunction* state_creator)
- : dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
- Dispatcher<OpSchemaDef>::registerKernel(dispatch_key_, kernel, state_creator);
+ explicit KernelRegistrar(OpHandleGetter *op, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator)
+ : op_(std::move(op)), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
+ Dispatcher::singleton().registerKernel(op_(), dispatch_key_, kernel, cache_creator);
}
KernelRegistrar(KernelRegistrar&& rhs)
- : dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(true) {
+ : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(true) {
rhs.owns_registration_ = false;
}
@@ -49,17 +55,111 @@ public:
~KernelRegistrar() {
if (owns_registration_) {
- Dispatcher<OpSchemaDef>::deregisterKernel(dispatch_key_);
+ Dispatcher::singleton().deregisterKernel(op_(), dispatch_key_);
}
}
private:
- const typename Schema::dispatch::dispatch_key_type dispatch_key_;
+ OpHandleGetter *op_;
+ const TensorTypeId dispatch_key_;
bool owns_registration_;
C10_DISABLE_COPY_AND_ASSIGN(KernelRegistrar);
};
+namespace detail {
+// ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and
+// cast it to the type that should be passed to the kernel function.
+// Examples: If the IValue contains a plain type like an int, return that.
+// If the IValue contains an IntList, return it as ArrayRef<int>.
+template<class T>
+struct ivalue_to_arg_type {
+ static T call(const IValue& v) {
+ return std::move(v).to<T>();
+ }
+};
+template<class T>
+struct ivalue_to_arg_type<ArrayRef<T>> {
+ static ArrayRef<T> call(const IValue& v) {
+ return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
+ }
+};
+
+// call_with_ivalue_args: Take a function pointer and an ArrayRef<IValue>
+// containing the arguments to call the function pointer with, and call it.
+// The extra_args are appended as additional arguments at the end of the function call.
+// Example:
+// int myfunc(int a, ArrayRef<int> b, string c);
+// int main() {
+// std::vector<IValue> ivalue_args = {IValue(2), IntList::create(3, 4)};
+// call_with_ivalue_args<decltype(myfunc), &myfunc>(ivalue_args, "extra_arg");
+// }
+template<class FuncType, FuncType* func, class... ExtraArgs, size_t... ivalue_arg_indices>
+typename guts::function_traits<FuncType>::return_type call_with_ivalue_args_(ArrayRef<IValue> ivalue_args, guts::index_sequence<ivalue_arg_indices...>, ExtraArgs&&... extra_args) {
+ using IValueArgTypes = typename guts::function_traits<FuncType>::parameter_types;
+ return (*func)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<ivalue_arg_indices, IValueArgTypes>>>>::call(ivalue_args[ivalue_arg_indices])..., std::forward<ExtraArgs>(extra_args)...);
+}
+
+template<class FuncType, FuncType* func, class... ExtraArgs>
+typename guts::function_traits<FuncType>::return_type call_with_ivalue_args(ArrayRef<IValue> ivalue_args, ExtraArgs&&... extra_args) {
+ constexpr size_t num_ivalue_args = guts::function_traits<FuncType>::number_of_parameters - sizeof...(ExtraArgs);
+ AT_ASSERTM(num_ivalue_args == ivalue_args.size(), "Wrong number of ivalue arguments");
+ return call_with_ivalue_args_<FuncType, func>(ivalue_args, guts::make_index_sequence<num_ivalue_args>(), std::forward<ExtraArgs>(extra_args)...);
+}
+
+template<class OutputType>
+struct push_outputs final {
+ static void call(OutputType&& output, Stack* stack) {
+ push_outputs<std::tuple<OutputType>>(std::tuple<OutputType>(std::move(output)), stack);
+ }
+};
+template<class... OutputTypes>
+struct push_outputs<std::tuple<OutputTypes...>> final {
+ static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
+ for (size_t i = 0; i < sizeof...(OutputTypes); ++i) {
+ torch::jit::push(return_type_to_ivalue(std::move(output)));
+ }
+ }
+};
+
+// SFINAE over (1) does the operator kernel have a cache and (2) does it return a value or void
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel, class Enable = void> struct wrap_kernel {};
+// SFINAE version for kernels with output and with cache
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
+struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<!std::is_same<void, CacheTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
+ static typename guts::function_traits<FuncType>::return_type call(Stack* stack, KernelCache* cache) {
+ constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters - 1; // -1 because it takes the kernel cache as last argument
+ auto output = call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs), static_cast<CacheTypeOrVoid*>(cache));
+ push_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), stack);
+ }
+};
+// SFINAE version for kernels with output and without a cache
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
+struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<std::is_same<void, CacheTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
+ static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* /*cache*/) {
+ constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters;
+ auto output = call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs));
+ push_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), stack);
+ }
+};
+// SFINAE version for kernels without output and with a cache
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
+struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<!std::is_same<void, CacheTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
+ static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* cache) {
+ constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters - 1; // -1 because it takes the kernel cache as last argument
+ call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs), static_cast<CacheTypeOrVoid*>(cache));
+ }
+};
+// SFINAE version for kernels without output and without a cache
+template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
+struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<std::is_same<void, CacheTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
+ static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* /*cache*/) {
+ constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters;
+ call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs));
+ }
+};
+}
+
/**
* Helper class for building a KernelRegistrar. This permits "keyword-argument" like syntax
* when performing operator registration, e.g., as in:
@@ -76,52 +176,51 @@ private:
* .dispatchKey("bla");
*
* The resulting full expression is implicitly convertible to a KernelRegistrar.
- *
- * @tparam OpSchemaDef The operator schema this is building a KernelRegistration for
- * @tparam FieldsPresentFlags Remembers which fields are already set in the builder
*/
-template<class OpSchemaDef, class StateTypeOrVoid, uint64_t FieldsPresentFlags>
+template<class CacheTypeOrVoid, uint64_t FieldsPresentFlags>
class KernelRegistrationBuilder final {
private:
- using Schema = OpSchema<OpSchemaDef>;
-
static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 0;
static constexpr uint64_t KERNEL_PRESENT = 0x01 << 1;
- static constexpr uint64_t STATE_PRESENT = 0x01 << 2;
+ static constexpr uint64_t CACHE_PRESENT = 0x01 << 2;
+
+ using OpHandleGetter = KernelRegistrar::OpHandleGetter;
- static std::unique_ptr<c10::KernelState> defaultStateCreator() {
+ static std::unique_ptr<c10::KernelCache> defaultCacheCreator() {
return nullptr;
}
- template<class State>
- static std::unique_ptr<c10::KernelState> stateCreator() {
- static_assert(std::is_default_constructible<State>::value, "State class must be default constructible");
- return guts::make_unique<State>();
+ template<class Cache>
+ static std::unique_ptr<c10::KernelCache> cacheCreator() {
+ static_assert(std::is_default_constructible<Cache>::value, "Cache class must be default constructible");
+ return guts::make_unique<Cache>();
}
- c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
+ OpHandleGetter *op_;
+ c10::optional<TensorTypeId> dispatch_key_;
KernelFunction* kernel_;
- KernelStateCreatorFunction* state_creator_;
+ KernelCacheCreatorFunction* cache_creator_;
public:
- constexpr KernelRegistrationBuilder()
- : KernelRegistrationBuilder(c10::nullopt, nullptr, &defaultStateCreator) {}
+ constexpr explicit KernelRegistrationBuilder(OpHandleGetter *op)
+ : KernelRegistrationBuilder(std::move(op), c10::nullopt, nullptr, &defaultCacheCreator) {}
- constexpr KernelRegistrationBuilder(
- c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key,
+ constexpr explicit KernelRegistrationBuilder(
+ OpHandleGetter *op,
+ c10::optional<TensorTypeId> dispatch_key,
KernelFunction* kernel,
- KernelStateCreatorFunction* state_creator)
- : dispatch_key_(std::move(dispatch_key)), kernel_(kernel), state_creator_(state_creator) {}
+ KernelCacheCreatorFunction* cache_creator)
+ : op_(std::move(op)), dispatch_key_(std::move(dispatch_key)), kernel_(kernel), cache_creator_(cache_creator) {}
/**
- * Implicit coercion to KernelRegistrar<OpSchemaDef> that finalizes the builder and
+ * Implicit coercion to KernelRegistrar that finalizes the builder and
* creates the object.
* @return Produced KernelRegistrar
*/
- operator KernelRegistrar<OpSchemaDef>() && {
+ operator KernelRegistrar() && {
static_assert(FieldsPresentFlags & KERNEL_PRESENT, "Forgot to call .kernel() in kernel registration");
static_assert(FieldsPresentFlags & DISPATCH_KEY_PRESENT, "Forgot to call .dispatchKey() in kernel registration");
- return KernelRegistrar<OpSchemaDef>(std::move(*dispatch_key_), kernel_, state_creator_);
+ return KernelRegistrar(op_, std::move(*dispatch_key_), kernel_, cache_creator_);
}
/**
@@ -129,9 +228,9 @@ private:
* @param dispatch_key dispatch key to register the function to
* @return "this" for method chaining
*/
- constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(typename Schema::dispatch::dispatch_key_type dispatch_key) && {
+ constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(TensorTypeId dispatch_key) && {
static_assert(!(FieldsPresentFlags & DISPATCH_KEY_PRESENT), "Tried to define kernel twice in same op registration");
- return KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(dispatch_key), kernel_, state_creator_);
+ return KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(op_), std::move(dispatch_key), kernel_, cache_creator_);
}
/**
@@ -140,10 +239,10 @@ private:
* @return "this" for method chaining
*/
template<KernelFunction* kernel_func>
- constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
+ constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration");
- // TODO Better error message when kernel function mismatches, one common mismatch is missing state parameter or state parameter present while not expected.
- return KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT>(std::move(dispatch_key_), kernel_func, state_creator_);
+ // TODO Better error message when kernel function mismatches, one common mismatch is missing cache parameter or cache parameter present while not expected.
+ return KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT>(std::move(op_), std::move(dispatch_key_), kernel_func, cache_creator_);
}
/**
@@ -151,9 +250,10 @@ private:
* @param kernel concrete function implementation to be registered
* @return "this" for method chaining
*/
- template<typename Schema::signature::template func_type_with_state<StateTypeOrVoid>* kernel_func>
- constexpr KernelRegistrationBuilder<OpSchemaDef, StateTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
- return std::move(*this).template kernel<&Schema::signature::template wrap_kernel<StateTypeOrVoid, kernel_func>>();
+ template<class FuncType, FuncType* kernel_func>
+ constexpr KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
+ // TODO Better error message if FuncType is not a func type
+ return std::move(*this).template kernel<&detail::wrap_kernel<CacheTypeOrVoid, FuncType, kernel_func>::call>();
}
/**
@@ -161,20 +261,19 @@ private:
* @param dispatch_key dispatch key to register the function to
* @return "this" for method chaining
*/
- template<class State>
- constexpr KernelRegistrationBuilder<OpSchemaDef, State, FieldsPresentFlags | STATE_PRESENT> withState() && {
- static_assert(!(FieldsPresentFlags & STATE_PRESENT), "Tried to define state twice in same op registration");
- static_assert(std::is_base_of<c10::KernelState, State>::value, "State must inherit from c10::KernelState");
+ template<class Cache>
+ constexpr KernelRegistrationBuilder<Cache, FieldsPresentFlags | CACHE_PRESENT> withCache() && {
+ static_assert(!(FieldsPresentFlags & CACHE_PRESENT), "Tried to define cache twice in same op registration");
+ static_assert(std::is_base_of<c10::KernelCache, Cache>::value, "Cache must inherit from c10::KernelCache");
- static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Cannot set the state after the kernel function is already set. Please call .withState() first and .kernel() later in the chain.");
+ static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Cannot set the cache after the kernel function is already set. Please call .withCache() first and .kernel() later in the chain.");
- return KernelRegistrationBuilder<OpSchemaDef, State, FieldsPresentFlags | STATE_PRESENT>(std::move(dispatch_key_), kernel_, &stateCreator<State>);
+ return KernelRegistrationBuilder<Cache, FieldsPresentFlags | CACHE_PRESENT>(std::move(op_), std::move(dispatch_key_), kernel_, &cacheCreator<Cache>);
}
};
} // namespace c10
-// TODO Can the builder logic be moved to compile time?
// NB: Semicolon after applying this macro is MANDATORY
-#define C10_REGISTER_KERNEL(OpSchemaDef) \
- static KernelRegistrar<OpSchemaDef> MACRO_CONCAT(__kernelRegistrationBuilder_, __COUNTER__) = KernelRegistrationBuilder<OpSchemaDef, void, 0>()
+#define C10_REGISTER_KERNEL(OperatorHandle) \
+ static KernelRegistrar MACRO_CONCAT(__kernelRegistrationBuilder_, __COUNTER__) = KernelRegistrationBuilder<void, 0>(OperatorHandle)
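
For completeness, a hedged sketch of a registration that uses the cache support of this builder. MyCache, my_kernel_with_cache and my_op_handle are illustrative names; my_op_handle is assumed to be a function returning the OperatorHandle obtained from Dispatcher::registerSchema for the matching schema:

    #include <ATen/core/dispatch/KernelRegistration.h>

    // Illustrative only; follows the builder API defined above.
    struct MyCache final : c10::KernelCache {};

    // With .withCache<MyCache>(), the cache is passed as the kernel's last parameter.
    at::Tensor my_kernel_with_cache(at::Tensor input, MyCache* /*cache*/) {
      return input;
    }

    const c10::OperatorHandle& my_op_handle();  // assumed to return the registered handle

    C10_REGISTER_KERNEL(my_op_handle)
        .withCache<MyCache>()
        .kernel<decltype(my_kernel_with_cache), &my_kernel_with_cache>()
        .dispatchKey(CPUTensorId());
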
diff --git a/aten/src/ATen/core/dispatch/LayoutId.cpp b/aten/src/ATen/core/dispatch/LayoutId.cpp
deleted file mode 100644
index bee71d8dcc..0000000000
--- a/aten/src/ATen/core/dispatch/LayoutId.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/LayoutId.h>
diff --git a/aten/src/ATen/core/dispatch/LayoutId.h b/aten/src/ATen/core/dispatch/LayoutId.h
deleted file mode 100644
index d0648392f4..0000000000
--- a/aten/src/ATen/core/dispatch/LayoutId.h
+++ /dev/null
@@ -1,22 +0,0 @@
-#pragma once
-
-#include <c10/util/IdWrapper.h>
-
-namespace c10 {
-
-class LayoutId final : public at::IdWrapper<LayoutId, uint8_t> {
-public:
- constexpr explicit LayoutId(underlying_type id): IdWrapper(id) {}
-
- constexpr uint8_t value() const {
- return underlyingId();
- }
-
- // Don't use this default constructor!
- // Unfortunately, a default constructor needs to be defined because of https://reviews.llvm.org/D41223
- constexpr LayoutId(): IdWrapper(0) {}
-};
-
-}
-
-C10_DEFINE_HASH_FOR_IDWRAPPER(c10::LayoutId)
diff --git a/aten/src/ATen/core/dispatch/OpSchema.cpp b/aten/src/ATen/core/dispatch/OpSchema.cpp
deleted file mode 100644
index 595989569c..0000000000
--- a/aten/src/ATen/core/dispatch/OpSchema.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/OpSchema.h>
diff --git a/aten/src/ATen/core/dispatch/OpSchema.h b/aten/src/ATen/core/dispatch/OpSchema.h
deleted file mode 100644
index d88bd7db87..0000000000
--- a/aten/src/ATen/core/dispatch/OpSchema.h
+++ /dev/null
@@ -1,406 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/DispatchKey.h>
-#include <ATen/core/ivalue.h>
-#include <c10/util/Array.h>
-#include <c10/util/Metaprogramming.h>
-#include <c10/util/TypeList.h>
-#include <c10/core/DeviceType.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/core/stack.h>
-
-namespace c10 {
-
-/**
- * A kernel can keep around a cache to have better performance when it's
- * called multiple times. This is used by a lot of caffe2 kernels.
- * This cache owned by the call site and passed in to the kernel as a function
- * argument. It must inherit from KernelState so the call site knows how to
- * store and destruct it.
- */
-class CAFFE2_API KernelState {
-public:
- virtual ~KernelState() = default;
-};
-
-using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
-
-/**
- * This is the basic ABI for any kernel call. Each kernel is registered as a
- * pointer to a global C function of this type.
- */
-using KernelFunction = void(Stack*, KernelState* state);
-
-namespace details {
-
-/**
- * If Arg is a Tensor or reference to a Tensor, provide the member constant value equal to true. Otherwise
- * return false.
- */
-template <class Arg>
-using is_tensor_arg = std::
- is_same<at::Tensor, guts::remove_cv_t<guts::remove_reference_t<Arg>>>;
-
-inline DeviceTypeId to_device_type_id(DeviceType device_type) {
- switch (device_type) {
- case DeviceType::CPU:
- return DeviceTypeId::CPU;
- case DeviceType::CUDA:
- return DeviceTypeId::CUDA;
- default:
- return DeviceTypeId::UNDEFINED;
- }
-}
-
-inline TensorParameterDispatchKey tensor_to_dispatch_key(const at::Tensor& tensor) {
- return TensorParameterDispatchKey{
- to_device_type_id(tensor.device().type()),
- LayoutId(0),
- tensor.dtype().id()};
-}
-
-template<size_t index, size_t offset, class ParameterTypes, class Enable = void> struct get_ith_tensor_arg_ {
- static_assert(!std::is_same<ParameterTypes, ParameterTypes>::value, "Index out of bounds");
-};
-template<size_t index, size_t offset, class Head, class... Tail>
-struct get_ith_tensor_arg_<index, offset, guts::typelist::typelist<Head, Tail...>, guts::enable_if_t<index == 0 && is_tensor_arg<Head>::value>> {
- static at::Tensor call(ArrayRef<IValue> args) {
- if (!args[offset].isTensor()) {
- throw std::runtime_error("Expected argument " + guts::to_string(offset) + " to be of type Tensor but found different type.");
- }
- return args[offset].toTensor();
- }
-};
-template<size_t index, size_t offset, class Head, class... Tail>
-struct get_ith_tensor_arg_<index, offset, guts::typelist::typelist<Head, Tail...>, guts::enable_if_t<index != 0 || !is_tensor_arg<Head>::value>> {
- static at::Tensor call(ArrayRef<IValue> args) {
- return get_ith_tensor_arg_<(is_tensor_arg<Head>::value ? (index-1) : index), offset + 1, guts::typelist::typelist<Tail...>>::call(args);
- }
-};
-template<class ParameterTypes, size_t index> at::Tensor get_ith_tensor_arg(ArrayRef<IValue> args) {
- return get_ith_tensor_arg_<index, 0, ParameterTypes>::call(args);
-}
-
-// Extract type ids for all tensors from an array of tensors
-template<class OpSchemaDef, size_t... indices>
-guts::array<TensorParameterDispatchKey, OpSchemaDef::num_dispatch_args()> getDispatchTypeIds__(ArrayRef<IValue> args, guts::index_sequence<indices...>) {
- using ParameterTypes = typename guts::function_traits<typename OpSchemaDef::Signature>::parameter_types;
- return {tensor_to_dispatch_key(get_ith_tensor_arg<ParameterTypes, indices>(args))...};
-}
-
-/**
- * Extract the type ids of all tensors in a variadic list of arguments
- *
- * @tparam Args Inferred variadic list of argument types
- * @param args List of arguments to get type ids from
- * @return guts::array<TensorParameterDispatchKey, n>, where n is the number of tensor arguments (is_tensor_arg) in the class
- */
-template<class OpSchemaDef>
-guts::array<TensorParameterDispatchKey, OpSchemaDef::num_dispatch_args()> getDispatchTypeIds_(ArrayRef<IValue> args) {
- return getDispatchTypeIds__<OpSchemaDef>(args, guts::make_index_sequence<OpSchemaDef::num_dispatch_args()>());
-}
-
-// TODO Test getDispatchTypeIds_
-
-/**
- * If T is a struct with a type field Signature, provides the member constant
- * @tparam T
- */
-template<class T, typename = void>
-struct has_signature_defined : std::false_type {};
-template<class T>
-struct has_signature_defined<T, guts::void_t<
- typename T::Signature
->> : std::true_type {};
-
-// TODO Test has_signature_defined
-
-template<class T, typename = void>
-struct has_parameter_names_defined : std::false_type {};
-template<class T>
-struct has_parameter_names_defined<T, guts::void_t<
- decltype(T::parameter_names)
->> : std::true_type {};
-
-// TODO Test has_parameter_names_defined
-
-template<class T, typename = void>
-struct has_name_defined : std::false_type {};
-template<class T>
-struct has_name_defined<T, guts::void_t<
- decltype(T::name)
->> : std::true_type {};
-
-// TODO Test has_name_defined
-
-template<class T>
-struct ivalue_to_arg_type {
- static T call(const IValue& v) {
- return std::move(v).to<T>();
- }
-};
-template<class T>
-struct ivalue_to_arg_type<ArrayRef<T>> {
- static ArrayRef<T> call(const IValue& v) {
- return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
- }
-};
-
-template<class FuncType, class... ExtraArgs, size_t... ivalue_arg_indices>
-typename guts::function_traits<FuncType>::return_type call_with_ivalue_args_(FuncType* func, ArrayRef<IValue> ivalue_args, guts::index_sequence<ivalue_arg_indices...>, ExtraArgs&&... extra_args) {
- using IValueArgTypes = typename guts::function_traits<FuncType>::parameter_types;
- return (*func)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<ivalue_arg_indices, IValueArgTypes>>>>::call(ivalue_args[ivalue_arg_indices])..., std::forward<ExtraArgs>(extra_args)...);
-}
-
-template<class FuncType, class... ExtraArgs>
-typename guts::function_traits<FuncType>::return_type call_with_ivalue_args(FuncType* func, ArrayRef<IValue> ivalue_args, ExtraArgs&&... extra_args) {
- constexpr size_t num_ivalue_args = guts::function_traits<FuncType>::number_of_parameters - sizeof...(ExtraArgs);
- return call_with_ivalue_args_<FuncType>(func, ivalue_args, guts::make_index_sequence<num_ivalue_args>(), std::forward<ExtraArgs>(extra_args)...);
-}
-
-template<class OutputType>
-struct write_outputs final {
- static void call(OutputType&& output, ArrayRef<IValue> outputs) {
- write_outputs<std::tuple<OutputType>>(std::tuple<OutputType>(std::move(output)), outputs);
- }
-};
-template<class... OutputTypes>
-struct write_outputs<std::tuple<OutputTypes...>> final {
- static void call(std::tuple<OutputTypes...>&& output, ArrayRef<IValue> outputs) {
- AT_ASSERT(outputs.size() == sizeof...(OutputTypes)); // Mismatch in number of returns between kernel function and operator schema.
- for (size_t i = 0; i < sizeof...(OutputTypes); ++i) {
- outputs[i] = return_type_to_ivalue(std::move(output));
- }
- }
-};
-
-
-// SFINAE over (1) does the operator kernel have state and (2) does it return a value or void
-template<class StateTypeOrVoid, class FuncType, class Enable = void> struct call_kernel_with_ivalue_args {};
-// SFINAE version for kernels with output and with state
-template<class StateTypeOrVoid, class FuncType>
-struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<!std::is_same<void, StateTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
- static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* state) {
- auto output = call_with_ivalue_args(func, ivalue_args, static_cast<StateTypeOrVoid*>(state));
- write_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), outputs);
- }
-};
-// SFINAE version for kernels with output and without state
-template<class StateTypeOrVoid, class FuncType>
-struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<std::is_same<void, StateTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
- static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* /*state*/) {
- auto output = call_with_ivalue_args(func, ivalue_args);
- write_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), outputs);
- }
-};
-// SFINAE version for kernels without output and with state
-template<class StateTypeOrVoid, class FuncType>
-struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<!std::is_same<void, StateTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
- static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* state) {
- call_with_ivalue_args(func, ivalue_args, static_cast<StateTypeOrVoid*>(state));
- }
-};
-// SFINAE version for kernels without output and without state
-template<class StateTypeOrVoid, class FuncType>
-struct call_kernel_with_ivalue_args<StateTypeOrVoid, FuncType, guts::enable_if_t<std::is_same<void, StateTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
- static typename guts::function_traits<FuncType>::return_type call(FuncType* func, ArrayRef<IValue> ivalue_args, ArrayRef<IValue> outputs, c10::KernelState* /*state*/) {
- call_with_ivalue_args(func, ivalue_args);
- }
-};
-
-template<class FuncType, class AddedParameter, class Enable = void> struct add_ptr_parameter_if_not_void final {};
-template<class Return, class... Parameters, class AddedParameter>
-struct add_ptr_parameter_if_not_void<Return(Parameters...), AddedParameter, guts::enable_if_t<!std::is_same<void, AddedParameter>::value>> final {
- using type = Return(Parameters..., AddedParameter*);
-};
-template<class FuncType> struct add_ptr_parameter_if_not_void<FuncType, void, void> final {
- using type = FuncType;
-};
-
-/**
- * Wrapper class around a user-provided schema definition, providing some useful information about the schema.
- *
- * @tparam OpSchemaDef Operator schema definition. See OpSchema for more details.
- */
-template<class OpSchemaDef> class OpSignatureSchema final {
- static_assert(details::has_signature_defined<OpSchemaDef>::value, "Operator schema doesn't define a valid Signature member type.");
- static_assert(guts::is_function_type<typename OpSchemaDef::Signature>::value, "Signature member of operator schema must be a function type.");
-
- using signature_traits = guts::function_traits<typename OpSchemaDef::Signature>;
-public:
- /**
- * The function type OpSchemaDef::Signature
- */
- using func_type = typename signature_traits::func_type;
- /**
- * The return type of the function OpSchemaDef::Signature
- */
- using return_type = typename signature_traits::return_type;
- /**
- * A type list of the parameter types of OpSchemaDef::Signature
- */
- using parameter_types = typename signature_traits::parameter_types;
-
- /**
- * The number of arguments of OpSchemaDef::Signature
- */
- static constexpr size_t num_args = guts::typelist::size<parameter_types>::value;
- /**
- * The number of tensor arguments (as per is_tensor_arg) in OpSchemaDef::Signature
- */
- static constexpr size_t num_tensor_args = guts::typelist::count_if<details::is_tensor_arg, parameter_types>::value;
-
- static constexpr size_t num_outputs = OpSchemaDef::num_outputs();
-
- template<class StateTypeOrVoid> using func_type_with_state = typename add_ptr_parameter_if_not_void<func_type, StateTypeOrVoid>::type;
-
- template<class StateTypeOrVoid, func_type_with_state<StateTypeOrVoid>* kernel>
- static void wrap_kernel(Stack* stack, KernelState* state) {
- constexpr size_t num_inputs = guts::typelist::size<parameter_types>::value;
- constexpr size_t num_outputs = 1; // TODO allow multiple outputs if it's a tuple
-
- ArrayRef<IValue> inputs = torch::jit::peekSlice(*stack, 0, num_inputs + num_outputs, num_inputs);
- ArrayRef<IValue> outputs = torch::jit::peekSlice(*stack, 0, num_outputs, num_outputs);
-
- call_kernel_with_ivalue_args<StateTypeOrVoid, func_type_with_state<StateTypeOrVoid>>::call(kernel, inputs, outputs, state);
- }
-
-private:
- static_assert(details::has_parameter_names_defined<OpSchemaDef>::value, "Operator schema doesn't define parameter_names member.");
- // TODO Allow simpler definition of parameter_names without having to spell out the guts::array type in the schema def.
- static_assert(std::is_same<const guts::array<const char*, num_args>, decltype(OpSchemaDef::parameter_names)>::value, "Operator schema defines parameter_names member, but it isn't the correct type. Must be a static constexpr guts::array of const char* with one entry for each parameter.");
-
-public:
- /**
- * The names of the parameters (as per OpSchemaDef::parameter_names)
- * @return Array
- */
- static constexpr const guts::array<const char*, num_args>& parameter_names() {
- return OpSchemaDef::parameter_names;
- }
-};
-
-/**
- * If T has a method dispatch_key, provide a member constant value equal to true. Otherwise return false.
- * @tparam T
- */
-template<class T, typename = void>
-struct has_function_dispatch_key_defined : std::false_type {};
-template<class T>
-struct has_function_dispatch_key_defined<T, guts::void_t<
- decltype(&T::dispatch_key)
->> : std::true_type {};
-
-/**
- * Wrapper class around a user-defined schema definition providing a way of computing a dispatch key
- * from arguments matching the signature of that schema.
- *
- * @tparam OpSchemaDef Operator schema definition. See OpSchema for more details.
- * @tparam Enable Inferred, used to control specialization
- */
-template<class OpSchemaDef, class Enable = void> class OpDispatchKeySchema final {};
-
-// General case. Operator doesn't overwrite DispatchKey generation. Use default.
-template<class OpSchemaDef>
-class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<!has_function_dispatch_key_defined<OpSchemaDef>::value>> final {
- using signature = OpSignatureSchema<OpSchemaDef>;
-
- // TODO Static assert that dispatch_key_type has operator<<(ostream, _) defined for debug output.
- // TODO Use an ADL-based debugString(DispatchKey) function instead of operator<< for debug printing.
-
-public:
- using dispatch_key_type = DispatchKey<OpSchemaDef::num_dispatch_args()>;
-
- static inline dispatch_key_type dispatch_key(const Stack* stack) {
- /* TODO Should we make this a runtime assert now?
- using guts::typelist::map_t;
- using guts::typelist::typelist;
- static_assert(std::is_same<
- map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typelist<Args...>>>,
- map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>>
- >::value, "Invalid argument types passed to OpSchema::dispatch_key()");*/
- return dispatch_key_type {
- details::getDispatchTypeIds_<OpSchemaDef>(torch::jit::last(*stack, signature::num_args))
- };
- }
-};
-
-// Special case. Operator overwrites DispatchKey generation. Use that.
-template<class OpSchemaDef>
-class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<has_function_dispatch_key_defined<OpSchemaDef>::value>> final {
- using signature = OpSignatureSchema<OpSchemaDef>;
-
- static_assert(guts::is_function_type<decltype(OpSchemaDef::dispatch_key)>::value, "Operator schema defines dispatch_key member, but it isn't a function.");
-
- using dispatch_key_traits = guts::function_traits<decltype(OpSchemaDef::dispatch_key)>;
-
-public:
- using dispatch_key_type = typename dispatch_key_traits::return_type;
-
-private:
-
- static_assert(guts::is_equality_comparable<dispatch_key_type>::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have the equality operator defined. Please define it.");
- static_assert(guts::is_hashable<dispatch_key_type>::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have an overload for std::hash. Please define it.");
-
- static_assert(std::is_same<
- guts::typelist::typelist<const Stack*>,
- typename dispatch_key_traits::parameter_types
- >::value, "Operator schema defines custom dispatch_key() derivation function, but it has the wrong signature. Expected to take one argument, which is of type const Stack*.");
-
-public:
-
- static inline dispatch_key_type dispatch_key(const Stack* stack) {
- /* TODO Should we make this a runtime assert now?
- using guts::typelist::map_t;
- using guts::typelist::typelist;
- static_assert(std::is_same<
- map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typelist<Args...>>>,
- map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>>
- >::value, "Invalid argument types passed to OpSchema::dispatch_key()");
- */
- return OpSchemaDef::dispatch_key(stack);
- }
-};
-
-template<class OpSchemaDef>
-class OpMetadataSchema final {
-private:
- static_assert(has_name_defined<OpSchemaDef>::value, "The operator schema has to define a 'static constexpr const char* name = ...' member to specify the operator name.");
- static_assert(std::is_same<const char* const, decltype(OpSchemaDef::name)>::value, "The 'name' member of the operator schema must have type 'static constexpr const char*'");
-
-public:
- static constexpr const char* name() {
- return OpSchemaDef::name;
- }
-};
-
-} // namespace details
-
-/**
- * Wrapper class for user-defined OpSchemaDef, providing functionality for determining
- * information about the signature and dispatching on that signature. This is the
- * "public" facing class.
- *
- * @tparam OpSchemaDef User-defined OpSchemaDef.
- * This struct is expected to define:
- * - a function type Signature
- * - a constexpr guts<const char*, n_args> parameter_names field (where n_args is
- * the number of arguments in Signature)
- */
-template <class OpSchemaDef>
-class CAFFE2_API OpSchema final {
- // TODO static_assert OpSchemaDef isn't an instantiation of OpSchema. If yes, the caller probably passed an OpSchema somewhere where an OpSchemaDef was expected and wants a good error message.
-public:
- using metadata = details::OpMetadataSchema<OpSchemaDef>;
- /**
- * Information about the signature
- */
- using signature = details::OpSignatureSchema<OpSchemaDef>;
- /**
- * Functionality for dispatching on that signature
- */
- using dispatch = details::OpDispatchKeySchema<OpSchemaDef>;
-};
-
-// TODO test OpSchema::dispatch stuff
-} // namespace c10
diff --git a/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp b/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp
deleted file mode 100644
index f153b5e198..0000000000
--- a/aten/src/ATen/core/dispatch/OpSchemaRegistration.cpp
+++ /dev/null
@@ -1 +0,0 @@
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
diff --git a/aten/src/ATen/core/dispatch/OpSchemaRegistration.h b/aten/src/ATen/core/dispatch/OpSchemaRegistration.h
index 29ecde2f7b..07af15027d 100644
--- a/aten/src/ATen/core/dispatch/OpSchemaRegistration.h
+++ b/aten/src/ATen/core/dispatch/OpSchemaRegistration.h
@@ -2,17 +2,38 @@
#include <ATen/core/dispatch/Dispatcher.h>
-// TODO Better error message when this definition is missing
+namespace c10 {
+namespace detail {
+class OpSchemaRegistrar final {
+public:
+ explicit OpSchemaRegistrar(FunctionSchema schema)
+ : opHandle_(c10::Dispatcher::singleton().registerSchema(std::move(schema))) {}
+
+ ~OpSchemaRegistrar() {
+ c10::Dispatcher::singleton().deregisterSchema(opHandle_);
+ }
+
+ const c10::OperatorHandle& opHandle() const {
+ return opHandle_;
+ }
+
+private:
+ c10::OperatorHandle opHandle_;
+};
+} // namespace detail
+} // namespace c10
/**
- * Macro for defining an operator schema. Every user-defined OpSchemaDef struct must
- * invoke this macro on it. Internally, this arranges for the dispatch table for
+ * Macros for declaring and defining an operator schema. Every operator schema
+ * must invoke C10_DECLARE_OP_SCHEMA in a header and C10_DEFINE_OP_SCHEMA in
+ * exactly one cpp file. Internally, this arranges for the dispatch table for
* the operator to be created.
*/
-#define C10_DEFINE_OP_SCHEMA(OpSchemaDef) \
- template<> \
- C10_EXPORT c10::DispatchTable<OpSchemaDef>& c10_dispatch_table<OpSchemaDef>() { \
- static c10::DispatchTable<OpSchemaDef> singleton; \
- return singleton; \
+#define C10_DECLARE_OP_SCHEMA(Name) \
+ CAFFE2_API const c10::OperatorHandle& Name(); \
+
+#define C10_DEFINE_OP_SCHEMA(Name, Schema) \
+ C10_EXPORT const c10::OperatorHandle& Name() { \
+ static ::c10::detail::OpSchemaRegistrar registrar(Schema); \
+ return registrar.opHandle(); \
}
-// TODO Also register unboxed calling API here
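A sketch of how the two macros fit together, mirroring the LayerNorm schema further down in this diff. The operator name `MyOp`, its namespace, and its arguments are placeholders, and namespace qualifications are approximate:

```cpp
// my_op.h -- declare the accessor for the operator handle (hypothetical example)
#include <ATen/core/dispatch/OpSchemaRegistration.h>

namespace my_namespace {
C10_DECLARE_OP_SCHEMA(MyOp);
}

// my_op.cpp -- define it in exactly one cpp file; the FunctionSchema is still
// built by hand (see the TODO about parsing schema strings in layer_norm.cpp below).
#include "my_op.h"

namespace my_namespace {
C10_DEFINE_OP_SCHEMA(MyOp, c10::FunctionSchema(
    "my_namespace::my_op",
    (std::vector<c10::Argument>{
        c10::Argument("input"),
        c10::Argument("output", c10::OptionalType::ofTensor(), c10::nullopt, c10::IValue())}),
    (std::vector<c10::Argument>{
        c10::Argument("output")})));
} // namespace my_namespace
```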
diff --git a/aten/src/ATen/core/dispatch/OpSchema_test.cpp b/aten/src/ATen/core/dispatch/OpSchema_test.cpp
deleted file mode 100644
index 1f5f16a9f4..0000000000
--- a/aten/src/ATen/core/dispatch/OpSchema_test.cpp
+++ /dev/null
@@ -1,29 +0,0 @@
-#include <ATen/core/dispatch/OpSchema.h>
-#include <c10/util/Array.h>
-#include <ATen/core/Tensor.h>
-
-using namespace c10;
-using at::Tensor;
-
-static_assert(details::is_tensor_arg<Tensor>::value, "");
-static_assert(details::is_tensor_arg<const Tensor&>::value, "");
-static_assert(details::is_tensor_arg<Tensor&&>::value, "");
-static_assert(!details::is_tensor_arg<int>::value, "");
-
-struct SchemaDef final {
- using Signature = bool(int, Tensor, float, Tensor, Tensor, unsigned int);
- static constexpr guts::array<const char*, 6> parameter_names = {{
- "1", "2", "3", "4", "5", "6"
- }};
- static constexpr size_t num_dispatch_args() {return 3;}
- static constexpr size_t num_outputs() {return 0;}
-};
-static_assert(6 == OpSchema<SchemaDef>::signature::num_args, "");
-static_assert(3 == OpSchema<SchemaDef>::signature::num_tensor_args, "");
-static_assert(std::is_same<bool, typename OpSchema<SchemaDef>::signature::return_type>::value, "");
-static_assert(
- std::is_same<
- guts::typelist::
- typelist<int, Tensor, float, Tensor, Tensor, unsigned int>,
- typename OpSchema<SchemaDef>::signature::parameter_types>::value,
- "");
diff --git a/aten/src/ATen/core/dispatch/README.md b/aten/src/ATen/core/dispatch/README.md
new file mode 100644
index 0000000000..1cb2af9056
--- /dev/null
+++ b/aten/src/ATen/core/dispatch/README.md
@@ -0,0 +1,12 @@
+This folder contains the c10 dispatcher: a single point through which we plan
+to route all kernel calls.
+The existing dispatch mechanisms in legacy PyTorch and caffe2 are planned to
+be replaced by it.
+
+This folder contains the following files:
+- Dispatcher.h: Main facade interface. Code using the dispatcher should only use this.
+- DispatchTable.h: Implementation of the actual dispatch mechanism. Hash table with kernels, lookup, ...
+- KernelCache.h: An interface operator kernels can inherit from if they need to keep a cache between invocations
+- KernelFunction.h: The core interface (i.e. function pointer) for calling a kernel
+- OpSchemaRegistration.h: The mechanisms to register new operators with the c10 dispatcher
+- KernelRegistration.h: The mechanisms to register kernels with the c10 dispatcher
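A rough sketch of the facade in action, based on how caffe2/core/operator_c10wrapper.h (further down in this diff) uses it. The accessor `my_op()` is a placeholder for a function generated by C10_DEFINE_OP_SCHEMA, and namespace qualifications are omitted for brevity:

```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <vector>

// Hypothetical operator-handle accessor, declared elsewhere via C10_DECLARE_OP_SCHEMA.
const OperatorHandle& my_op();

void call_my_op(std::vector<IValue>& stack) {
  // Look up the kernel registered for this operator; the dispatch key is
  // derived from the tensor arguments currently on the stack.
  OpKernel kernel = Dispatcher::singleton().lookup(my_op(), &stack);

  // Run it: the kernel pops its arguments off the stack and pushes the
  // schema's return values.
  kernel.call(&stack);
}
```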
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index bca3f3b46e..df7cdcd3b0 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -5,7 +5,7 @@
#include <ATen/core/functional.h>
#include <ATen/core/Type.h>
#include <ATen/core/TensorMethods.h>
-
+#include <c10/util/TypeList.h>
#include <caffe2/core/common.h>
#include <memory>
@@ -233,6 +233,7 @@ struct CAFFE2_API OptionalType: public SingleElementType<TypeKind::OptionalType,
ss << "Optional[" << getElementType()->python_str() << "]";
return ss.str();
}
+
// common cast Optional[Tensor] for undefined tensor type
static OptionalTypePtr ofTensor();
private:
@@ -982,26 +983,51 @@ CAFFE2_API c10::optional<TypePtr> unifyTypes(
const TypePtr& t1,
const TypePtr& t2);
-template <typename T>
-TypePtr getTypePtr() {
-#define TYPE_STR(Type) #Type, " ",
- AT_ERROR(
- "Type ",
- c10::demangle_type<T>(),
- " could not be converted to any of the known types { ",
- C10_FORALL_TYPES(TYPE_STR) "}");
-#undef TYPE_STR
-}
+namespace detail {
+template <typename T> struct getTypePtr_ final {
+ static_assert(guts::false_t<T>::value, "Type could not be converted to any of the known types.");
+};
-template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); }
-template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); }
-template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); }
-template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); }
-template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); }
-template<> inline TypePtr getTypePtr<std::string>() { return StringType::get(); }
-template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); }
-template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
-template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); }
+template<> struct getTypePtr_<at::Tensor> final {
+ static TypePtr call() { return DynamicType::get(); }
+};
+template<> struct getTypePtr_<double> final {
+ static TypePtr call() { return FloatType::get(); }
+};
+template<> struct getTypePtr_<int64_t> final {
+ static TypePtr call() { return IntType::get(); }
+};
+template<> struct getTypePtr_<bool> final {
+ static TypePtr call() { return BoolType::get(); }
+};
+template<> struct getTypePtr_<at::Scalar> final {
+ static TypePtr call() { return NumberType::get(); }
+};
+template<> struct getTypePtr_<std::string> final {
+ static TypePtr call() { return StringType::get(); }
+};
+template<class T> struct getTypePtr_<std::vector<T>> final {
+ static TypePtr call() {
+ static auto type = ListType::create(getTypePtr_<T>::call());
+ return type;
+ }
+};
+template<class T> struct getTypePtr_<ArrayRef<T>> final {
+ static TypePtr call() {
+ static auto type = ListType::create(getTypePtr_<T>::call());
+ return type;
+ }
+};
+template<class T> struct getTypePtr_<at::optional<T>> final {
+ static TypePtr call() {
+ static auto type = OptionalType::create(getTypePtr_<T>::call());
+ return type;
+ }
+};
+}
+template<class T> inline TypePtr getTypePtr() {
+ return detail::getTypePtr_<T>::call();
+}
CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value);
CAFFE2_API TypePtr attemptToRecoverType(const IValue& input_ivalue);
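Because the getTypePtr_ specializations for std::vector, ArrayRef and at::optional call getTypePtr_ on their element type, nested C++ types map to the corresponding JIT types without a dedicated specialization per combination. A small illustration (namespace qualifiers omitted):

```cpp
auto t1 = getTypePtr<double>();                            // FloatType
auto t2 = getTypePtr<std::vector<int64_t>>();              // ListType of IntType
auto t3 = getTypePtr<at::optional<at::Tensor>>();          // OptionalType of DynamicType
auto t4 = getTypePtr<std::vector<at::optional<double>>>(); // ListType of Optional[float]
// Unsupported types fail at compile time via the static_assert in the base template.
```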
diff --git a/aten/src/ATen/core/opschema/layer_norm.cpp b/aten/src/ATen/core/opschema/layer_norm.cpp
index be908a58fc..d0749d0b14 100644
--- a/aten/src/ATen/core/opschema/layer_norm.cpp
+++ b/aten/src/ATen/core/opschema/layer_norm.cpp
@@ -1,4 +1,25 @@
#include <ATen/core/opschema/layer_norm.h>
#include <ATen/core/dispatch/OpSchemaRegistration.h>
-C10_DEFINE_OP_SCHEMA(c10::core::opschema::LayerNorm);
+namespace c10 {
+namespace core {
+namespace opschema {
+ // TODO Parse schema string instead of creating FunctionSchema manually
+ C10_DEFINE_OP_SCHEMA(LayerNorm, FunctionSchema(
+ "caffe2::layer_norm_dont_use_this_op_yet",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("axis", IntType::get()),
+ c10::Argument("epsilon", FloatType::get()),
+ c10::Argument("output", OptionalType::ofTensor(), c10::nullopt, IValue()),
+ c10::Argument("output_mean", OptionalType::ofTensor(), c10::nullopt, IValue()),
+ c10::Argument("output_stdev", OptionalType::ofTensor(), c10::nullopt, IValue())
+ }), (std::vector<c10::Argument>{
+ c10::Argument("output"),
+ c10::Argument("mean"),
+ c10::Argument("stdev")
+ })
+ ));
+}
+}
+}
diff --git a/aten/src/ATen/core/opschema/layer_norm.h b/aten/src/ATen/core/opschema/layer_norm.h
index 58255d0b96..0ea81ae60b 100644
--- a/aten/src/ATen/core/opschema/layer_norm.h
+++ b/aten/src/ATen/core/opschema/layer_norm.h
@@ -1,35 +1,12 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include <ATen/core/blob.h>
-#include <ATen/core/dispatch/OpSchema.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace c10 {
namespace core {
namespace opschema {
-// TODO This op schema should probably not live in c10 since it's not a method
-// on Tensor. It's only here as a proof-of-concept op and for LATTE team
-// to be able to call caffe2 layer norm from PyTorch.
-struct LayerNorm final {
- static constexpr const char* name = "LayerNorm";
-
- using Signature = std::tuple<at::Tensor, int, float, at::Tensor, at::Tensor, at::Tensor> (
- const at::Tensor& input,
- int axis,
- float epsilon,
- const at::Tensor& output,
- const at::Tensor& output_mean,
- const at::Tensor& output_stdev);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 3;}
-
- static constexpr c10::guts::array<const char*, 6> parameter_names = {
- {"input", "axis", "epsilon", "output", "output_mean", "output_stdev"}};
-};
+C10_DECLARE_OP_SCHEMA(LayerNorm);
} // namespace opschema
} // namespace core
diff --git a/aten/src/ATen/core/stack.h b/aten/src/ATen/core/stack.h
index 32c07a4d53..e51b6b5ef2 100644
--- a/aten/src/ATen/core/stack.h
+++ b/aten/src/ATen/core/stack.h
@@ -29,6 +29,9 @@ using Operation = std::function<int(Stack&)>;
static inline IValue& peek(Stack& stack, size_t i, size_t N) {
return *(stack.end() - N + i);
}
+static inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
+ return *(stack.end() - N + i);
+}
// treat the last N elements of the stack as a list, looking up the
// slice starting at index i and having length len
static inline at::ArrayRef<IValue> peekSlice(
diff --git a/c10/test/util/TypeList_test.cpp b/c10/test/util/TypeList_test.cpp
index dba1476974..dd811b5771 100644
--- a/c10/test/util/TypeList_test.cpp
+++ b/c10/test/util/TypeList_test.cpp
@@ -137,5 +137,11 @@ namespace test_map_types_to_values {
static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
EXPECT_EQ(expected, result);
}
+}
+namespace test_find_if {
+ static_assert(0 == find_if<typelist<char&>, std::is_reference>::value, "");
+ static_assert(0 == find_if<typelist<char&, int, char&, int&>, std::is_reference>::value, "");
+ static_assert(2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value, "");
+ static_assert(3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value, "");
}
diff --git a/c10/util/TypeList.h b/c10/util/TypeList.h
index 4a55e9aaae..bb0d6bd7c6 100644
--- a/c10/util/TypeList.h
+++ b/c10/util/TypeList.h
@@ -3,11 +3,14 @@
#include <c10/util/C++17.h>
#include <c10/util/TypeTraits.h>
-namespace c10 { namespace guts { namespace typelist {
+namespace c10 { namespace guts {
-namespace detail {
template<class... T> struct false_t : std::false_type {};
-}
+template<template<class> class... T> struct false_higher_t : std::false_type {};
+
+namespace typelist {
+
+
/**
* Type holding a list of types for compile time type computations
@@ -25,7 +28,7 @@ private:
* 3 == size<typelist<int, int, double>>::value
*/
template<class TypeList> struct size final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::size<T>, T must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::size<T>, T must be typelist<...>.");
};
template<class... Types> struct size<typelist<Types...>> final {
static constexpr size_t value = sizeof...(Types);
@@ -39,7 +42,7 @@ template<class... Types> struct size<typelist<Types...>> final {
* std::tuple<int, string> == to_tuple_t<typelist<int, string>>
*/
template<class TypeList> struct to_tuple final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::to_tuple<T>, T must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::to_tuple<T>, T must be typelist<...>.");
};
template<class... Types> struct to_tuple<typelist<Types...>> final {
using type = std::tuple<Types...>;
@@ -55,7 +58,7 @@ template<class TypeList> using to_tuple_t = typename to_tuple<TypeList>::type;
* typelist<int, string> == from_tuple_t<std::tuple<int, string>>
*/
template<class Tuple> struct from_tuple final {
- static_assert(detail::false_t<Tuple>::value, "In typelist::from_tuple<T>, T must be std::tuple<...>.");
+ static_assert(false_t<Tuple>::value, "In typelist::from_tuple<T>, T must be std::tuple<...>.");
};
template<class... Types> struct from_tuple<std::tuple<Types...>> final {
using type = typelist<Types...>;
@@ -70,7 +73,7 @@ template<class Tuple> using from_tuple_t = typename from_tuple<Tuple>::type;
* typelist<int, string, int> == concat_t<typelist<int, string>, typelist<int>>
*/
template<class... TypeLists> struct concat final {
- static_assert(detail::false_t<TypeLists...>::value, "In typelist::concat<T1, ...>, the T arguments each must be typelist<...>.");
+ static_assert(false_t<TypeLists...>::value, "In typelist::concat<T1, ...>, the T arguments each must be typelist<...>.");
};
template<class... Head1Types, class... Head2Types, class... TailLists>
struct concat<typelist<Head1Types...>, typelist<Head2Types...>, TailLists...> final {
@@ -94,7 +97,7 @@ template<class... TypeLists> using concat_t = typename concat<TypeLists...>::typ
* typelist<int&, const string&&> == filter_t<std::is_reference, typelist<void, string, int&, bool, const string&&, int>>
*/
template<template <class> class Condition, class TypeList> struct filter final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::filter<Condition, TypeList>, the TypeList argument must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::filter<Condition, TypeList>, the TypeList argument must be typelist<...>.");
};
template<template <class> class Condition, class Head, class... Tail>
struct filter<Condition, typelist<Head, Tail...>> final {
@@ -137,7 +140,7 @@ struct count_if final {
* false == true_for_each_type<std::is_reference, typelist<int&, const float&&, MyClass>>::value
*/
template<template <class> class Condition, class TypeList> struct true_for_each_type final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::true_for_each_type<Condition, TypeList>, the TypeList argument must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::true_for_each_type<Condition, TypeList>, the TypeList argument must be typelist<...>.");
};
template<template <class> class Condition, class... Types>
struct true_for_each_type<Condition, typelist<Types...>> final
@@ -153,7 +156,7 @@ struct true_for_each_type<Condition, typelist<Types...>> final
* typelist<int&, double&, string&> == map_t<std::add_lvalue_reference_t, typelist<int, double, string>>
*/
template<template <class> class Mapper, class TypeList> struct map final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::map<Mapper, TypeList>, the TypeList argument must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::map<Mapper, TypeList>, the TypeList argument must be typelist<...>.");
};
template<template <class> class Mapper, class... Types>
struct map<Mapper, typelist<Types...>> final {
@@ -170,7 +173,7 @@ using map_t = typename map<Mapper, TypeList>::type;
* int == head_t<typelist<int, string>>
*/
template<class TypeList> struct head final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::head<T>, the T argument must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::head<T>, the T argument must be typelist<...>.");
};
template<class Head, class... Tail> struct head<typelist<Head, Tail...>> final {
using type = Head;
@@ -185,7 +188,7 @@ template<class TypeList> using head_t = typename head<TypeList>::type;
/// Base template.
template<size_t Index, class TypeList> struct element final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::element<T>, the T argument must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::element<T>, the T argument must be typelist<...>.");
};
/// Successful case, we have reached the zero index and can "return" the head type.
@@ -213,7 +216,7 @@ using element_t = typename element<Index, TypeList>::type;
template <class TypeList>
struct last final {
static_assert(
- detail::false_t<TypeList>::value,
+ false_t<TypeList>::value,
"In typelist::last<T>, the T argument must be typelist<...>.");
};
template <class Head, class... Tail>
@@ -236,7 +239,7 @@ static_assert(
* typelist<int, string> == reverse_t<typelist<string, int>>
*/
template<class TypeList> struct reverse final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::reverse<T>, the T argument must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::reverse<T>, the T argument must be typelist<...>.");
};
template<class Head, class... Tail> struct reverse<typelist<Head, Tail...>> final {
using type = concat_t<typename reverse<typelist<Tail...>>::type, typelist<Head>>;
@@ -247,6 +250,28 @@ template<> struct reverse<typelist<>> final {
template<class TypeList> using reverse_t = typename reverse<TypeList>::type;
+/**
+ * Find the index of the first type in a typelist fulfilling a type trait condition.
+ * Example:
+ *
+ * 2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value
+ */
+template<class TypeList, template<class> class Condition, class Enable = void> struct find_if final {
+ static_assert(false_t<TypeList>::value, "In typelist::find_if<TypeList, Condition>, the TypeList argument must be typelist<...>.");
+};
+template<template<class> class Condition> struct find_if<typelist<>, Condition, void> final {
+ static_assert(false_higher_t<Condition>::value, "In typelist::find_if<TypeList, Condition>, didn't find any type fulfilling the Condition.");
+};
+template<class Head, class... Tail, template<class> class Condition>
+struct find_if<typelist<Head, Tail...>, Condition, enable_if_t<Condition<Head>::value>> final {
+ static constexpr size_t value = 0;
+};
+template<class Head, class... Tail, template<class> class Condition>
+struct find_if<typelist<Head, Tail...>, Condition, enable_if_t<!Condition<Head>::value>> final {
+ static constexpr size_t value = 1 + find_if<typelist<Tail...>, Condition>::value;
+};
+
+
/**
* Maps a list of types into a list of values.
@@ -282,7 +307,7 @@ template<class T> struct type_ final {
using type = T;
};
template<class TypeList> struct map_types_to_values final {
- static_assert(detail::false_t<TypeList>::value, "In typelist::map_types_to_values<T>, the T argument must be typelist<...>.");
+ static_assert(false_t<TypeList>::value, "In typelist::map_types_to_values<T>, the T argument must be typelist<...>.");
};
template<class... Types> struct map_types_to_values<typelist<Types...>> final {
template<class Func>
diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h
index 253428645d..a12767c456 100644
--- a/caffe2/core/operator.h
+++ b/caffe2/core/operator.h
@@ -203,6 +203,16 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
return XBlobGetMutableTensor(outputs_.at(idx), dims, options);
}
+ void SetOutputTensor(int idx, Tensor tensor) {
+ // also update the tensor in the hack
+ if (!isLegacyOperator()) {
+ output_tensors_[idx] = tensor.UnsafeSharedInstance();
+ }
+
+ // update the tensor in the workspace
+ BlobSetTensor(outputs_.at(idx), std::move(tensor));
+ }
+
inline Tensor*
OutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
if (isLegacyOperator()) {
diff --git a/caffe2/core/operator_c10wrapper.h b/caffe2/core/operator_c10wrapper.h
index fb0c458ce9..ab0230d2af 100644
--- a/caffe2/core/operator_c10wrapper.h
+++ b/caffe2/core/operator_c10wrapper.h
@@ -23,19 +23,15 @@ using extract_type_t = typename ParameterDef::type;
*
* REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(C10Add, C2MyAddOpName)
*
- * Note: This wrapper currently only supports C10 ops that have exactly one
- * output and take that in the last parameter as "Tensor* output".
- * TODO: Figure out a better way to handle output parameters
*/
template <
- class OpSchemaDef,
+ const c10::OperatorHandle& (*OperatorHandle)(),
class Context,
bool use_array_input,
+ size_t num_output_parameters,
class ParameterDefTuple>
class C10OperatorWrapper final : public Operator<Context> {
- using Schema = c10::OpSchema<OpSchemaDef>;
-
public:
static_assert(
c10::guts::is_instantiation_of<std::tuple, ParameterDefTuple>::value,
@@ -49,28 +45,39 @@ class C10OperatorWrapper final : public Operator<Context> {
C10OperatorWrapper(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
+ op_(OperatorHandle()),
kernel_(at::nullopt),
parameters_(parse_parameters_(
operator_def,
- c10::guts::make_index_sequence<num_parameters()>())) {}
+ c10::guts::make_index_sequence<num_parameters()>())) {
- static constexpr size_t num_inputs() {
- return Schema::signature::num_args - num_outputs() - num_parameters();
+ AT_ASSERT(operator_def.output_size() == op_.schema().returns().size());
+ AT_ASSERT(operator_def.input_size() == num_inputs());
}
- static constexpr size_t num_parameters() {
- return std::tuple_size<ParameterDefTuple>::value;
+ size_t num_inputs() {
+ return op_.schema().arguments().size() - num_output_parameters - num_parameters();
}
- static constexpr size_t num_outputs() {
- return Schema::signature::num_outputs;
+ static constexpr size_t num_parameters() {
+ return std::tuple_size<ParameterDefTuple>::value;
}
bool RunOnDevice() override {
- RunOnDevice_(
- c10::guts::make_index_sequence<num_inputs()>(),
- c10::guts::make_index_sequence<num_outputs()>(),
- c10::guts::make_index_sequence<num_parameters()>());
+ // Because stack_ is cached as a member, concurrent calls are not allowed.
+ // TODO thread_local might fix this
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ AT_ASSERT(stack_.size() == 0);
+
+ pushInputs_();
+ pushParameters_(guts::make_index_sequence<num_parameters()>());
+ pushOutputParameters_();
+
+ callKernel_();
+
+ popOutputs_();
+
return true;
}
@@ -91,56 +98,44 @@ class C10OperatorWrapper final : public Operator<Context> {
return Parameter::parse(ArgumentHelper(operator_def));
}
- template <
- size_t... InputIndex,
- size_t... OutputIndex,
- size_t... ParameterIndex>
- c10::guts::enable_if_t<
- details::true_t<InputIndex...>::value &&
- !use_array_input,
- void>
- RunOnDevice_(
- c10::guts::index_sequence<InputIndex...>,
- c10::guts::index_sequence<OutputIndex...>,
- c10::guts::index_sequence<ParameterIndex...>) {
- Stack stack;
- torch::jit::push(stack,
- IValue(at::Tensor(C10Tensor(Input(InputIndex))))...,
- IValue(std::get<ParameterIndex>(parameters_))...,
- IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))...
- );
- call_(&stack);
- // TODO Do we have to Write outputs from stack back into the workspace?
+ void pushInputs_() {
+ if (use_array_input) {
+ stack_.emplace_back(ivalue::TensorList::create(array_inputs_()));
+ } else {
+ for (size_t i = 0; i < num_inputs(); ++i) {
+ stack_.emplace_back(at::Tensor(C10Tensor(Input(i))));
+ }
+ }
}
- template <
- size_t... InputIndex,
- size_t... OutputIndex,
- size_t... ParameterIndex>
- c10::guts::enable_if_t<
- details::true_t<InputIndex...>::value &&
- use_array_input,
- void>
- RunOnDevice_(
- c10::guts::index_sequence<InputIndex...>,
- c10::guts::index_sequence<OutputIndex...>,
- c10::guts::index_sequence<ParameterIndex...>) {
- Stack stack;
- torch::jit::push(stack,
- IValue(ivalue::TensorList::create(array_inputs_())),
- IValue(std::get<ParameterIndex>(parameters_))...,
- IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))...
- );
- call_(&stack);
- // TODO Do we have to Write outputs from stack back into the workspace?
+ template<size_t... ParameterIndex>
+ void pushParameters_(guts::index_sequence<ParameterIndex...>) {
+ (void)std::initializer_list<int>{(
+ stack_.emplace_back(std::get<ParameterIndex>(parameters_))
+ , 0)...};
+ }
+
+ void pushOutputParameters_() {
+ for (size_t i = 0; i < num_output_parameters; ++i) {
+ stack_.emplace_back(at::Tensor(C10Tensor(*Output(i))));
+ }
}
- void call_(Stack* stack) {
+ void callKernel_() {
+ AT_ASSERT(stack_.size() == op_.schema().arguments().size());
if (!kernel_.has_value()) {
// TODO if kernel is already set, try re-dispatch to assert it goes to the same kernel
- kernel_ = c10::Dispatcher<OpSchemaDef>::lookup(stack);
+ kernel_ = c10::Dispatcher::singleton().lookup(op_, &stack_);
}
- kernel_->call(stack);
+ kernel_->call(&stack_);
+ }
+
+ void popOutputs_() {
+ AT_ASSERT(stack_.size() == op_.schema().returns().size());
+ for (size_t i = 0; i < op_.schema().returns().size(); ++i) {
+ OperatorBase::SetOutputTensor(i, Tensor(C10Tensor(std::move(stack_[i]).toTensor())));
+ }
+ stack_.clear();
}
std::vector<at::Tensor> array_inputs_() {
@@ -152,8 +147,15 @@ class C10OperatorWrapper final : public Operator<Context> {
return result;
}
+ c10::OperatorHandle op_;
c10::optional<OpKernel> kernel_;
+ // The stack is stored as a member to avoid re-allocating it on every call.
+ // Between kernel calls, stack_.size() == 0, but its capacity should not
+ // need to grow again after the first call.
+ std::vector<IValue> stack_;
+ std::mutex mutex_;
+
ParameterTuple parameters_;
};
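For reference, the stack layout the wrapper builds and consumes around each kernel call, as implied by the push/pop helpers above (a summary of the existing code, not new API):

```cpp
// Before callKernel_(), built by pushInputs_/pushParameters_/pushOutputParameters_:
//   stack_ = [ inputs (num_inputs()) | parameters (num_parameters()) | output tensors (num_output_parameters) ]
//   and stack_.size() == op_.schema().arguments().size() is asserted.
//
// After the kernel ran:
//   stack_ = [ returns (op_.schema().returns().size()) ]
//   popOutputs_() copies these into the caffe2 outputs via SetOutputTensor()
//   and clears the stack so its capacity is reused on the next call.
```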
@@ -174,39 +176,41 @@ C10_DECLARE_REGISTRY(
// TODO Currently we only register the CPU variant. This is going to be fixed
// once the tensor detemplatization lands.
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OpSchemaDef, Name) \
- C10_REGISTER_CLASS( \
- C10OperatorRegistry, \
- Name, \
- C10OperatorWrapper<OpSchemaDef, CPUContext, false, std::tuple<>>)
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OperatorHandle, Name, NumOutputParameters) \
+ C10_REGISTER_CLASS( \
+ C10OperatorRegistry, \
+ Name, \
+ C10OperatorWrapper<OperatorHandle, CPUContext, false, NumOutputParameters, std::tuple<>>)
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( \
- OpSchemaDef, Name, ...) \
+ OperatorHandle, Name, NumOutputParameters, ...) \
C10_REGISTER_CLASS( \
C10OperatorRegistry, \
Name, \
C10OperatorWrapper< \
- OpSchemaDef, \
+ OperatorHandle, \
CPUContext, \
false, \
+ NumOutputParameters, \
std::tuple<__VA_ARGS__>>)
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT( \
- OpSchemaDef, Name) \
+ OperatorHandle, Name, NumOutputParameters) \
C10_REGISTER_CLASS( \
C10OperatorRegistry, \
Name, \
- C10OperatorWrapper<OpSchemaDef, CPUContext, true, std::tuple<>>)
+ C10OperatorWrapper<OperatorHandle, CPUContext, true, NumOutputParameters, std::tuple<>>)
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( \
- OpSchemaDef, Name, ...) \
+ OperatorHandle, Name, NumOutputParameters, ...) \
C10_REGISTER_CLASS( \
C10OperatorRegistry, \
Name, \
C10OperatorWrapper< \
- OpSchemaDef, \
+ OperatorHandle, \
CPUContext, \
true, \
+ NumOutputParameters, \
std::tuple<__VA_ARGS__>>)
} // namespace caffe2
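For illustration, a hypothetical invocation of the reworked macro; `my_namespace::MyOp` stands in for an operator-handle accessor declared with C10_DECLARE_OP_SCHEMA, and `C10MyOp_DontUseThisOpYet` for the caffe2-visible operator name:

```cpp
// Expose the c10 operator behind my_namespace::MyOp() as a caffe2 CPU operator.
// The last argument is NumOutputParameters: how many output tensors the wrapper
// appends to the stack after the inputs and parsed parameters.
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
    my_namespace::MyOp,
    C10MyOp_DontUseThisOpYet,
    1)
```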
diff --git a/caffe2/operators/experimental/c10/cpu/add_cpu.cc b/caffe2/operators/experimental/c10/cpu/add_cpu.cc
index 9b81641731..9ec1aa1a8b 100644
--- a/caffe2/operators/experimental/c10/cpu/add_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/add_cpu.cc
@@ -15,7 +15,7 @@ void add_op_cpu_impl(
const at::Tensor& B_,
const at::Tensor& C_,
bool legacy_broadcast,
- int axis) {
+ int64_t axis) {
Tensor A{C10Tensor(A_)};
Tensor B{C10Tensor(B_)};
Tensor C{C10Tensor(C_)};
@@ -74,13 +74,6 @@ void add_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::Add)
- .kernel<&caffe2::add_op_cpu_impl<float>>()
- .dispatchKey(c10::DispatchKey<2>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()}});
+ .kernel<decltype(caffe2::add_op_cpu_impl<float>), &caffe2::add_op_cpu_impl<float>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc
index 6cd8cbf94f..cc5823d9e3 100644
--- a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc
@@ -10,7 +10,7 @@ using std::vector;
namespace caffe2 {
namespace {
-struct State final : public c10::KernelState {
+struct Cache final : public c10::KernelCache {
at::Tensor scratch = at::Tensor(C10Tensor(empty({}, CPU)));
};
@@ -18,7 +18,7 @@ template <class T, class Context>
void averaged_loss_op_cpu_impl(
const at::Tensor& X_,
const at::Tensor& sum_,
- State* state) {
+ Cache* state) {
Tensor X{C10Tensor(X_)};
Tensor sum{C10Tensor(sum_)};
CPUContext context;
@@ -48,10 +48,7 @@ void averaged_loss_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::AveragedLoss)
- .withState<caffe2::State>()
- .kernel<&caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()}});
+ .withCache<caffe2::Cache>()
+ .kernel<decltype(caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc
index 3edb5ea841..a34c098d82 100644
--- a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc
@@ -53,29 +53,21 @@ void batch_gather_op_cpu_impl(
}
}
}
+
+void batch_gather_op_cpu(const at::Tensor& data,
+ const at::Tensor& indices,
+ const at::Tensor& output) {
+ switch (data.scalar_type()) {
+ case ScalarType::Int: return batch_gather_op_cpu_impl<int>(data, indices, output);
+ case ScalarType::Long: return batch_gather_op_cpu_impl<int64_t>(data, indices, output);
+ default: throw std::runtime_error(string() + "Unsupported dtype: " + toString(data.scalar_type()));
+ }
+}
} // namespace
} // namespace caffe2
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
- .kernel<&caffe2::batch_gather_op_cpu_impl<int64_t>>()
- .dispatchKey(c10::DispatchKey<2>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int64_t>()}});
-
-C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
- .kernel<&caffe2::batch_gather_op_cpu_impl<int32_t>>()
- .dispatchKey(c10::DispatchKey<2>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int32_t>()}});
+ .kernel<decltype(caffe2::batch_gather_op_cpu), &caffe2::batch_gather_op_cpu>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc
index 8d977466eb..dfe85237a4 100644
--- a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc
@@ -11,7 +11,7 @@ namespace math = caffe2::math;
namespace caffe2 {
namespace {
-struct State final : public c10::KernelState {
+struct Cache final : public c10::KernelCache {
std::shared_ptr<at::Tensor> scratch;
};
@@ -20,10 +20,10 @@ void batch_matmul_op_cpu_impl(
const at::Tensor& A_,
const at::Tensor& B_,
const at::Tensor& Y_,
- int trans_a,
- int trans_b,
- int broadcast,
- State* state) {
+ int64_t trans_a,
+ int64_t trans_b,
+ int64_t broadcast,
+ Cache* cache) {
Tensor A{C10Tensor(A_)};
Tensor B{C10Tensor(B_)};
Tensor Y{C10Tensor(Y_)};
@@ -273,14 +273,7 @@ void batch_matmul_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::BatchMatmul)
- .withState<caffe2::State>()
- .kernel<&caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>()
- .dispatchKey(c10::DispatchKey<2>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()}});
+ .withCache<caffe2::Cache>()
+ .kernel<decltype(caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
index 62c2a62fca..a2203c5927 100644
--- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
@@ -74,62 +74,22 @@ void cast_op_cpu_impl(
CAFFE_THROW("Unexpected 'to' argument value: ", to);
}
}
+void cast_op_cpu(
+ const at::Tensor& input,
+ const at::Tensor& output,
+ int64_t to) {
+ switch (input.scalar_type()) {
+#define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to);
+ AT_FORALL_SCALAR_TYPES(CASE)
+#undef CASE
+ default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type()));
+ }
+}
} // namespace
} // namespace caffe2
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<float>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()}});
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<int32_t>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int32_t>()}});
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<bool>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<bool>()}});
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<uint8_t>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<uint8_t>()}});
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<int8_t>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int8_t>()}});
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<uint16_t>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<uint16_t>()}});
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<int16_t>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int16_t>()}});
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<int64_t>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int64_t>()}});
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
- .kernel<&caffe2::cast_op_cpu_impl<double>>()
- .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<double>()}});
+ .kernel<decltype(caffe2::cast_op_cpu), &caffe2::cast_op_cpu>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
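A note on the AT_FORALL_SCALAR_TYPES pattern used in cast_op_cpu above: the CASE macro is stamped out once per (ctype, name, ...) triple, so the tensor's runtime dtype selects which template instantiation runs. Roughly:

```cpp
// AT_FORALL_SCALAR_TYPES(CASE) expands to one case per supported scalar type,
// e.g. for float:
//   case ScalarType::Float: return cast_op_cpu_impl<float>(input, output, to);
// Unsupported dtypes fall through to the runtime_error in the default branch.
```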
diff --git a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc
index 858675cfce..931a7d221e 100644
--- a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc
@@ -16,8 +16,8 @@ void concat_op_cpu_impl(
ArrayRef<at::Tensor> inputs,
const at::Tensor& output_,
const at::Tensor& split_,
- int axis,
- int add_axis) {
+ int64_t axis,
+ int64_t add_axis) {
Tensor output{C10Tensor(output_)};
Tensor split{C10Tensor(split_)};
CPUContext context;
@@ -108,6 +108,6 @@ void concat_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::Concat)
- .kernel<&caffe2::concat_op_cpu_impl<float, CPUContext>>()
- .dispatchKey(c10::DeviceTypeId::CPU);
+ .kernel<decltype(caffe2::concat_op_cpu_impl<float, CPUContext>), &caffe2::concat_op_cpu_impl<float, CPUContext>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc
index 1db6ed2c22..2e9f608bd1 100644
--- a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc
@@ -28,8 +28,6 @@ void enforce_finite_op_impl_cpu(const at::Tensor& input_) {
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::EnforceFinite)
- .kernel<&caffe2::enforce_finite_op_impl_cpu<float>>()
- .dispatchKey({DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()});
+ .kernel<decltype(caffe2::enforce_finite_op_impl_cpu<float>), &caffe2::enforce_finite_op_impl_cpu<float>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc
index 38c968163c..a7195009b8 100644
--- a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc
@@ -8,7 +8,7 @@ using caffe2::Tensor;
namespace caffe2 {
namespace {
-struct State final : public c10::KernelState {
+struct Cache final : public c10::KernelCache {
std::vector<int64_t> dims;
bool initialized = false;
};
@@ -18,38 +18,38 @@ void expand_dims_op_cpu_impl(
const at::Tensor& input_,
const at::Tensor& output_,
ArrayRef<int64_t> dims,
- State* state) {
+ Cache* cache) {
Tensor input{C10Tensor(input_)};
Tensor output{C10Tensor(output_)};
- if (!state->initialized) {
- state->dims = dims.vec();
- auto originalSize = state->dims.size();
+ if (!cache->initialized) {
+ cache->dims = dims.vec();
+ auto originalSize = cache->dims.size();
CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
- std::sort(state->dims.begin(), state->dims.end());
- state->dims.erase(
- std::unique(state->dims.begin(), state->dims.end()), state->dims.end());
- if (state->dims.size() < originalSize) {
+ std::sort(cache->dims.begin(), cache->dims.end());
+ cache->dims.erase(
+ std::unique(cache->dims.begin(), cache->dims.end()), cache->dims.end());
+ if (cache->dims.size() < originalSize) {
LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
}
CAFFE_ENFORCE(
- state->dims.front() >= 0, "Dimension ids must be non-negative.");
- state->initialized = true;
+ cache->dims.front() >= 0, "Dimension ids must be non-negative.");
+ cache->initialized = true;
}
output.CopyFrom(input);
- if (state->dims.empty()) {
+ if (cache->dims.empty()) {
return;
}
auto newDims = input.sizes().vec();
CAFFE_ENFORCE_GE(
- input.sizes().size() + state->dims.size(),
- state->dims.back() + 1,
+ input.sizes().size() + cache->dims.size(),
+ cache->dims.back() + 1,
"Input needs at least ",
- (1 + state->dims.back() - state->dims.size()),
+ (1 + cache->dims.back() - cache->dims.size()),
" dimensions given `dims`.");
- for (const auto dim : state->dims) {
+ for (const auto dim : cache->dims) {
newDims.insert(newDims.begin() + dim, 1);
}
output.Reshape(newDims);
@@ -59,9 +59,7 @@ void expand_dims_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::ExpandDims)
- .withState<caffe2::State>()
- .kernel<&caffe2::expand_dims_op_cpu_impl<float>>()
- .dispatchKey({DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()});
+ .withCache<caffe2::Cache>()
+ .kernel<decltype(caffe2::expand_dims_op_cpu_impl<float>), &caffe2::expand_dims_op_cpu_impl<float>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc
index 18bed5d868..99e05458db 100644
--- a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc
@@ -12,7 +12,7 @@ using caffe2::Tensor;
namespace caffe2 {
namespace {
-struct State final : public c10::KernelState {
+struct Cache final : public c10::KernelCache {
vector<int64_t> Y_shape_cache_;
at::Tensor bias_multiplier_ = at::Tensor(C10Tensor(Tensor()));
};
@@ -23,9 +23,9 @@ void fc_op_cpu_impl(
const at::Tensor& W_,
const at::Tensor& b_,
const at::Tensor& Y_,
- int axis,
- int axis_w,
- State* state) {
+ int64_t axis,
+ int64_t axis_w,
+ Cache* cache) {
Tensor X{C10Tensor(X_)};
Tensor W{C10Tensor(W_)};
Tensor b{C10Tensor(b_)};
@@ -68,12 +68,12 @@ void fc_op_cpu_impl(
CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
CAFFE_ENFORCE(N == b.numel(), dimErrorString());
- state->Y_shape_cache_ = X.sizes().vec();
+ cache->Y_shape_cache_ = X.sizes().vec();
// This is an invariant of canonical_axis, so we can DCHECK.
- DCHECK_LE(canonical_axis + 1, state->Y_shape_cache_.size());
- state->Y_shape_cache_.resize(canonical_axis + 1);
- state->Y_shape_cache_[canonical_axis] = N;
- Y.Resize(state->Y_shape_cache_);
+ DCHECK_LE(canonical_axis + 1, cache->Y_shape_cache_.size());
+ cache->Y_shape_cache_.resize(canonical_axis + 1);
+ cache->Y_shape_cache_[canonical_axis] = N;
+ Y.Resize(cache->Y_shape_cache_);
CAFFE_ENFORCE(M * N == Y.numel(), dimErrorString());
if (X.numel() == 0) {
@@ -103,7 +103,7 @@ void fc_op_cpu_impl(
static_cast<Context*>(&context),
math_type);
// Add bias term
- Tensor bias_multiplier(state->bias_multiplier_);
+ Tensor bias_multiplier(cache->bias_multiplier_);
ReinitializeTensor(&bias_multiplier, {M}, at::dtype<DataType>().device(CPU));
caffe2::math::Set<DataType, Context>(
M,
@@ -129,17 +129,7 @@ void fc_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::FullyConnected)
- .withState<caffe2::State>()
- .kernel<&caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>()
- .dispatchKey(c10::DispatchKey<3>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()}});
+ .withCache<caffe2::Cache>()
+ .kernel<decltype(caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc
index 376bc924f7..6db668547f 100644
--- a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc
@@ -76,8 +76,8 @@ void constant_fill_op_cpu_impl(
ArrayRef<int64_t> shape,
ArrayRef<int64_t> extra_shape,
bool input_as_shape,
- int dtype,
- c10::IValue value) {
+ int64_t dtype,
+ c10::Scalar value) {
Tensor output{C10Tensor(output_)};
CPUContext context;
@@ -102,12 +102,6 @@ void constant_fill_op_cpu_impl(
value.toInt(),
output.template mutable_data<int64_t>(),
static_cast<CPUContext*>(&context));
- } else if (dtype == caffe2::TensorProto_DataType_BOOL) {
- caffe2::math::Set<bool, CPUContext>(
- output.numel(),
- value.toBool(),
- output.template mutable_data<bool>(),
- static_cast<CPUContext*>(&context));
} else {
throw std::logic_error(
"Unimplemented data type for ConstantFill: " +
@@ -122,8 +116,8 @@ void uniform_fill_op_cpu_impl(
ArrayRef<int64_t> shape,
ArrayRef<int64_t> extra_shape,
bool input_as_shape,
- float min,
- float max) {
+ double min,
+ double max) {
Tensor output{C10Tensor(output_)};
CPUContext context;
@@ -154,22 +148,22 @@ void uniform_fill_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::ConstantFill)
- .kernel<&caffe2::constant_fill_op_cpu_impl>()
- .dispatchKey(c10::DeviceTypeId::CPU);
+ .kernel<decltype(caffe2::constant_fill_op_cpu_impl), &caffe2::constant_fill_op_cpu_impl>()
+ .dispatchKey(CPUTensorId());
C10_REGISTER_KERNEL(caffe2::ops::UniformFill)
- .kernel<&caffe2::uniform_fill_op_cpu_impl>()
- .dispatchKey(c10::DeviceTypeId::CPU);
+ .kernel<decltype(caffe2::uniform_fill_op_cpu_impl), &caffe2::uniform_fill_op_cpu_impl>()
+ .dispatchKey(CPUTensorId());
-C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<float>)
- .kernel<&caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>()
- .dispatchKey(c10::DeviceTypeId::CPU);
+C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill)
+ .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>()
+ .dispatchKey(CPUTensorId());
-C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<int>)
- .kernel<&caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>()
- .dispatchKey(c10::DeviceTypeId::CPU);
+C10_REGISTER_KERNEL(caffe2::ops::GivenTensorIntFill)
+ .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>()
+ .dispatchKey(CPUTensorId());
-C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<int64_t>)
- .kernel<&caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>()
- .dispatchKey(c10::DeviceTypeId::CPU);
+C10_REGISTER_KERNEL(caffe2::ops::GivenTensorInt64Fill)
+ .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
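Beyond re-registering the fillers, the hunk above also widens the kernels' scalar parameters: int becomes int64_t, float becomes double, and the ConstantFill value goes from a raw IValue to a c10::Scalar; the BOOL branch disappears with it, presumably because the value argument is now typed as a number in the schema. This is consistent with a boxed calling convention that stores integers as int64_t and floating point as double. A tiny sketch of that idea with a hypothetical tagged value type (BoxedScalar and fill_kernel are illustrative, not the real IValue/Scalar):

#include <cstdint>
#include <iostream>
#include <stdexcept>

// Hypothetical, simplified stand-in for a boxed scalar: integers are kept as int64_t and
// floating point as double, so the kernels take the widened types.
struct BoxedScalar {
  enum class Tag { Int, Double } tag;
  union { int64_t i; double d; };
  static BoxedScalar makeInt(int64_t v) { BoxedScalar s; s.tag = Tag::Int; s.i = v; return s; }
  static BoxedScalar makeDouble(double v) { BoxedScalar s; s.tag = Tag::Double; s.d = v; return s; }
  int64_t toInt() const { if (tag != Tag::Int) throw std::runtime_error("not an int"); return i; }
  double toDouble() const { if (tag != Tag::Double) throw std::runtime_error("not a double"); return d; }
};

// Mirrors the widened (int64_t dtype, Scalar value) signature of constant_fill_op_cpu_impl.
void fill_kernel(int64_t dtype, BoxedScalar value) {
  std::cout << "dtype=" << dtype << " value=" << value.toInt() << "\n";
}

int main() { fill_kernel(/*dtype=*/2, BoxedScalar::makeInt(7)); }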
diff --git a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc
index 26d16235db..23bbaf38dc 100644
--- a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc
@@ -12,7 +12,7 @@ template <class DataType, class Context>
void flatten_op_cpu_impl(
const at::Tensor& input_,
const at::Tensor& output_,
- int axis) {
+ int64_t axis) {
Tensor input{C10Tensor(input_)};
Tensor output{C10Tensor(output_)};
CPUContext context;
@@ -30,8 +30,6 @@ void flatten_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::Flatten)
- .kernel<&caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>>()
- .dispatchKey({DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()});
+ .kernel<decltype(caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc
index 38f7385ddf..247e1bb452 100644
--- a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc
@@ -16,7 +16,7 @@ void mul_op_cpu_impl(
const at::Tensor& B_,
const at::Tensor& C_,
bool legacy_broadcast,
- int axis) {
+ int64_t axis) {
Tensor A{C10Tensor(A_)};
Tensor B{C10Tensor(B_)};
Tensor C{C10Tensor(C_)};
@@ -75,13 +75,6 @@ void mul_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::Mul)
- .kernel<&caffe2::mul_op_cpu_impl<float>>()
- .dispatchKey(c10::DispatchKey<2>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()}});
+ .kernel<decltype(caffe2::mul_op_cpu_impl<float>), &caffe2::mul_op_cpu_impl<float>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc
index 7c5b8620ef..67c7ee2431 100644
--- a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc
@@ -44,8 +44,6 @@ void relu_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::Relu)
- .kernel<&caffe2::relu_op_cpu_impl<float>>()
- .dispatchKey({DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()});
+ .kernel<decltype(caffe2::relu_op_cpu_impl<float>), &caffe2::relu_op_cpu_impl<float>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc
index 78febf5de2..470d6332e9 100644
--- a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc
@@ -27,8 +27,6 @@ void sigmoid_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::Sigmoid)
- .kernel<&caffe2::sigmoid_op_cpu_impl<float>>()
- .dispatchKey({DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()});
+ .kernel<decltype(caffe2::sigmoid_op_cpu_impl<float>), &caffe2::sigmoid_op_cpu_impl<float>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc
index 32853fd82b..4255e2ae31 100644
--- a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc
@@ -74,13 +74,6 @@ void sigmoid_cross_entropy_with_logits_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::SigmoidCrossEntropyWithLogits)
- .kernel<&caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>()
- .dispatchKey(c10::DispatchKey<2>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()}});
+ .kernel<decltype(caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl), &caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc
index 94711168cf..1d8b1802e3 100644
--- a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc
@@ -10,7 +10,7 @@ namespace caffe2 {
namespace {
template <typename InputType, typename IndexType>
-void sparse_lengths_sum_op_cpu_impl(
+void sparse_lengths_sum_op_cpu_impl_(
const at::Tensor& dataInput_,
const at::Tensor& indicesInput_,
const at::Tensor& lengthsInput_,
@@ -55,62 +55,37 @@ void sparse_lengths_sum_op_cpu_impl(
USE_MEAN,
out_data);
}
+
+template<typename IndexType>
+void sparse_lengths_sum_op_cpu_impl(
+ const at::Tensor& dataInput,
+ const at::Tensor& indicesInput,
+ const at::Tensor& lengthsInput,
+ const at::Tensor& output) {
+ switch (dataInput.scalar_type()) {
+ case ScalarType::Float: return sparse_lengths_sum_op_cpu_impl_<float, IndexType>(dataInput, indicesInput, lengthsInput, output);
+ case ScalarType::Half: return sparse_lengths_sum_op_cpu_impl_<at::Half, IndexType>(dataInput, indicesInput, lengthsInput, output);
+ default: throw std::runtime_error(string() + "Unsupported dtype for input data " + toString(dataInput.scalar_type()));
+ }
+}
+
+void sparse_lengths_sum_op_cpu(
+ const at::Tensor& dataInput,
+ const at::Tensor& indicesInput,
+ const at::Tensor& lengthsInput,
+ const at::Tensor& output) {
+ switch (indicesInput.scalar_type()) {
+ case ScalarType::Int: return sparse_lengths_sum_op_cpu_impl<int>(dataInput, indicesInput, lengthsInput, output);
+ case ScalarType::Long: return sparse_lengths_sum_op_cpu_impl<int64_t>(dataInput, indicesInput, lengthsInput, output);
+ default: throw std::runtime_error(string() + "Unsupported dtype for input indices " + toString(indicesInput.scalar_type()));
+ }
+}
+
} // namespace
} // namespace caffe2
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
- .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<float, int32_t>>()
- .dispatchKey(c10::DispatchKey<3>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int32_t>()},
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int>()}});
-C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
- .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int32_t>>()
- .dispatchKey(c10::DispatchKey<3>{
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<at::Half>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int32_t>()},
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int>()}});
-C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
- .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<float, int64_t>>()
- .dispatchKey(c10::DispatchKey<3>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int64_t>()},
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int>()}});
-C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
- .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int64_t>>()
- .dispatchKey(c10::DispatchKey<3>{
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<at::Half>()},
- c10::details::TensorParameterDispatchKey{
- DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int64_t>()},
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<int>()}});
+ .kernel<decltype(caffe2::sparse_lengths_sum_op_cpu), &caffe2::sparse_lengths_sum_op_cpu>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
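SparseLengthsSum previously needed four registrations, one per (data dtype, index dtype) pair encoded in a DispatchKey<3>. With dispatch keyed only on CPUTensorId(), the single registered entry point now fans out over dtypes at runtime with two nested switches. A self-contained sketch of that two-level fan-out (Dtype, FakeTensor, and the sum_* functions are illustrative stand-ins, not the real at::ScalarType or caffe2 kernels):

#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class Dtype { Float, Half, Int32, Int64 };  // illustrative stand-in for at::ScalarType

struct FakeTensor { Dtype dtype; };  // hypothetical minimal tensor carrying only a dtype

template <typename DataT, typename IndexT>
void sum_impl(const FakeTensor&, const FakeTensor&) {
  std::cout << "instantiated for sizeof(DataT)=" << sizeof(DataT)
            << " sizeof(IndexT)=" << sizeof(IndexT) << "\n";
}

// Inner fan-out over the data dtype, mirroring sparse_lengths_sum_op_cpu_impl<IndexType>.
template <typename IndexT>
void sum_for_index(const FakeTensor& data, const FakeTensor& indices) {
  switch (data.dtype) {
    case Dtype::Float: return sum_impl<float, IndexT>(data, indices);
    default: throw std::runtime_error("unsupported data dtype");
  }
}

// Outer fan-out over the index dtype, mirroring sparse_lengths_sum_op_cpu.
void sum_entry(const FakeTensor& data, const FakeTensor& indices) {
  switch (indices.dtype) {
    case Dtype::Int32: return sum_for_index<int32_t>(data, indices);
    case Dtype::Int64: return sum_for_index<int64_t>(data, indices);
    default: throw std::runtime_error("unsupported index dtype");
  }
}

int main() { sum_entry(FakeTensor{Dtype::Float}, FakeTensor{Dtype::Int64}); }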
diff --git a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc
index 1165978f9e..4d3a9812ce 100644
--- a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc
@@ -23,8 +23,6 @@ void stop_gradient_op_cpu_impl(
namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::StopGradient)
- .kernel<&caffe2::stop_gradient_op_cpu_impl<float>>()
- .dispatchKey({DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()});
+ .kernel<decltype(caffe2::stop_gradient_op_cpu_impl<float>), &caffe2::stop_gradient_op_cpu_impl<float>>()
+ .dispatchKey(CPUTensorId());
} // namespace c10
diff --git a/caffe2/operators/experimental/c10/schemas/add.cc b/caffe2/operators/experimental/c10/schemas/add.cc
index cfb778df12..88906b7a52 100644
--- a/caffe2/operators/experimental/c10/schemas/add.cc
+++ b/caffe2/operators/experimental/c10/schemas/add.cc
@@ -1,10 +1,24 @@
#include "caffe2/operators/experimental/c10/schemas/add.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
#include "caffe2/core/operator_c10wrapper.h"
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::Add);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(Add, FunctionSchema(
+ "_c10_experimental::Add",
+ (std::vector<c10::Argument>{
+ c10::Argument("input1"),
+ c10::Argument("input2"),
+ c10::Argument("output"),
+ c10::Argument("legacy_broadcast", BoolType::get()),
+ c10::Argument("axis", IntType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
@@ -32,6 +46,7 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::Add,
C10Add_DontUseThisOpYet,
+ 1,
ParameterHelper<LegacyBroadcastParameter>,
ParameterHelper<AxisParameter>)
}
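The Add hunk above shows the schema migration applied to every op in this commit: the hand-written struct with name, Signature, num_dispatch_args(), num_outputs(), and parameter_names is replaced by a C10_DEFINE_OP_SCHEMA(...) that builds a FunctionSchema in the .cc file, with only a C10_DECLARE_OP_SCHEMA(...) left in the header. The extra integer added to the REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH* invocations (1 here, 2 for Concat, 3 for LayerNorm, 0 for EnforceFinite) appears to be the op's output count, which the wrapper previously read from the struct's num_outputs(). A minimal, compilable sketch of the declare-in-header / define-in-source split, using simplified stand-ins (MiniArgument, MiniFunctionSchema, and the MINI_* macros are hypothetical, not the real c10 types):

#include <iostream>
#include <string>
#include <vector>

// Hypothetical, simplified analogues of the declare/define split the diff adopts:
// the header only declares the schema object, the .cc file builds it once.
struct MiniArgument { std::string name; std::string type = "Tensor"; };
struct MiniFunctionSchema {
  std::string name;
  std::vector<MiniArgument> arguments;
  std::vector<MiniArgument> returns;
};

#define MINI_DECLARE_OP_SCHEMA(Name) extern const MiniFunctionSchema Name
#define MINI_DEFINE_OP_SCHEMA(Name, ...) const MiniFunctionSchema Name = __VA_ARGS__

// What would live in the header:
MINI_DECLARE_OP_SCHEMA(Add);

// What would live in the .cc file:
MINI_DEFINE_OP_SCHEMA(Add, (MiniFunctionSchema{
    "_sketch::Add",
    {{"input1"}, {"input2"}, {"output"}, {"legacy_broadcast", "bool"}, {"axis", "int"}},
    {}}));

int main() {
  std::cout << Add.name << " takes " << Add.arguments.size() << " arguments\n";
}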
diff --git a/caffe2/operators/experimental/c10/schemas/add.h b/caffe2/operators/experimental/c10/schemas/add.h
index fba907ab31..4dfa99ae77 100644
--- a/caffe2/operators/experimental/c10/schemas/add.h
+++ b/caffe2/operators/experimental/c10/schemas/add.h
@@ -1,29 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct Add final {
- static constexpr const char* name = "add";
-
- using Signature = void(
- const at::Tensor& input1,
- const at::Tensor& input2,
- const at::Tensor& output,
- bool legacy_broadcast,
- int axis);
-
- static constexpr size_t num_dispatch_args() {return 2;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 5> parameter_names = {
- {"input1", "input2", "output", "legacy_broadcast", "axis"}};
-};
+C10_DECLARE_OP_SCHEMA(Add);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/averaged_loss.cc b/caffe2/operators/experimental/c10/schemas/averaged_loss.cc
index c276f84a65..c40752c9f2 100644
--- a/caffe2/operators/experimental/c10/schemas/averaged_loss.cc
+++ b/caffe2/operators/experimental/c10/schemas/averaged_loss.cc
@@ -4,11 +4,25 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::AveragedLoss);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(AveragedLoss, FunctionSchema(
+ "_c10_experimental::AveragedLoss",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("output")
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::AveragedLoss,
- C10AveragedLoss_DontUseThisOpYet)
+ C10AveragedLoss_DontUseThisOpYet,
+ 1
+ )
}
diff --git a/caffe2/operators/experimental/c10/schemas/averaged_loss.h b/caffe2/operators/experimental/c10/schemas/averaged_loss.h
index 37e6906dbe..548bd075f2 100644
--- a/caffe2/operators/experimental/c10/schemas/averaged_loss.h
+++ b/caffe2/operators/experimental/c10/schemas/averaged_loss.h
@@ -1,29 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
-#include "caffe2/core/tensor.h"
-#include <ATen/core/blob.h>
-#include <ATen/core/dispatch/OpSchema.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct AveragedLoss final {
- static constexpr const char* name = "averaged_loss";
-
- using Signature = void(
- const at::Tensor& input,
- const at::Tensor& output);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 2> parameter_names = {
- {"input", "output"}};
-};
+C10_DECLARE_OP_SCHEMA(AveragedLoss);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/batch_gather.cc b/caffe2/operators/experimental/c10/schemas/batch_gather.cc
index 5a45dd1bbc..a659ee90c3 100644
--- a/caffe2/operators/experimental/c10/schemas/batch_gather.cc
+++ b/caffe2/operators/experimental/c10/schemas/batch_gather.cc
@@ -4,10 +4,25 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::BatchGather);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(BatchGather, FunctionSchema(
+ "_c10_experimental::BatchGather",
+ (std::vector<c10::Argument>{
+ c10::Argument("data"),
+ c10::Argument("indices"),
+ c10::Argument("output")
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
ops::BatchGather,
- C10BatchGather_DontUseThisOpYet)
+ C10BatchGather_DontUseThisOpYet,
+ 1
+ )
}
diff --git a/caffe2/operators/experimental/c10/schemas/batch_gather.h b/caffe2/operators/experimental/c10/schemas/batch_gather.h
index d745efa35f..214c67ff99 100644
--- a/caffe2/operators/experimental/c10/schemas/batch_gather.h
+++ b/caffe2/operators/experimental/c10/schemas/batch_gather.h
@@ -1,27 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct BatchGather final {
- static constexpr const char* name = "batch_gather";
-
- using Signature = void(
- const at::Tensor& data,
- const at::Tensor& indices,
- const at::Tensor& output);
-
- static constexpr size_t num_dispatch_args() {return 2;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 3> parameter_names = {
- {"data", "indices", "output"}};
-};
+C10_DECLARE_OP_SCHEMA(BatchGather);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/batch_matmul.cc b/caffe2/operators/experimental/c10/schemas/batch_matmul.cc
index 682609b7ea..70f8ba8194 100644
--- a/caffe2/operators/experimental/c10/schemas/batch_matmul.cc
+++ b/caffe2/operators/experimental/c10/schemas/batch_matmul.cc
@@ -4,7 +4,23 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::BatchMatmul);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(BatchMatmul, FunctionSchema(
+ "_c10_experimental::BatchMatmul",
+ (std::vector<c10::Argument>{
+ c10::Argument("A"),
+ c10::Argument("B"),
+ c10::Argument("output"),
+ c10::Argument("trans_a", IntType::get()),
+ c10::Argument("trans_b", IntType::get()),
+ c10::Argument("broadcast", IntType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
struct TransAParameter final {
@@ -41,6 +57,7 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::BatchMatmul,
C10BatchMatMul_DontUseThisOpYet,
+ 1,
ParameterHelper<TransAParameter>,
ParameterHelper<TransBParameter>,
ParameterHelper<BroadcastParameter>)
diff --git a/caffe2/operators/experimental/c10/schemas/batch_matmul.h b/caffe2/operators/experimental/c10/schemas/batch_matmul.h
index 2d815869f8..191e0e6a57 100644
--- a/caffe2/operators/experimental/c10/schemas/batch_matmul.h
+++ b/caffe2/operators/experimental/c10/schemas/batch_matmul.h
@@ -1,37 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
-#include <ATen/core/blob.h>
-#include <ATen/core/dispatch/OpSchema.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct BatchMatmul final {
- static constexpr const char* name = "batch_matmul";
-
- using Signature = void(
- const at::Tensor& A,
- const at::Tensor& B,
- const at::Tensor& output,
- int trans_a,
- int trans_b,
- int broadcast);
-
- static constexpr size_t num_dispatch_args() {return 2;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 6> parameter_names = {
- {"A",
- "B",
- "output",
- "trans_a",
- "trans_b",
- "broadcast"}};
-};
+C10_DECLARE_OP_SCHEMA(BatchMatmul);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/cast.cc b/caffe2/operators/experimental/c10/schemas/cast.cc
index d5615a6b76..6fce08764b 100644
--- a/caffe2/operators/experimental/c10/schemas/cast.cc
+++ b/caffe2/operators/experimental/c10/schemas/cast.cc
@@ -5,7 +5,20 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::Cast);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(Cast, FunctionSchema(
+ "_c10_experimental::Cast",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("output"),
+ c10::Argument("to_dtype", IntType::get()),
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
@@ -22,5 +35,6 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::Cast,
C10Cast_DontUseThisOpYet,
+ 1,
ToParameter)
}
diff --git a/caffe2/operators/experimental/c10/schemas/cast.h b/caffe2/operators/experimental/c10/schemas/cast.h
index 095348b768..979637b251 100644
--- a/caffe2/operators/experimental/c10/schemas/cast.h
+++ b/caffe2/operators/experimental/c10/schemas/cast.h
@@ -1,27 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct Cast final {
- static constexpr const char* name = "cast";
-
- using Signature = void(
- const at::Tensor& input1,
- const at::Tensor& output,
- int64_t to_dtype);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 3> parameter_names = {
- {"input", "output", "to"}};
-};
+C10_DECLARE_OP_SCHEMA(Cast);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/concat.cc b/caffe2/operators/experimental/c10/schemas/concat.cc
index 75c39b4ba7..dddd2a4e75 100644
--- a/caffe2/operators/experimental/c10/schemas/concat.cc
+++ b/caffe2/operators/experimental/c10/schemas/concat.cc
@@ -4,7 +4,22 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::Concat);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(Concat, FunctionSchema(
+ "_c10_experimental::Concat",
+ (std::vector<c10::Argument>{
+ c10::Argument("inputs", ListType::ofTensors()),
+ c10::Argument("output"),
+ c10::Argument("split_info", FloatType::get()),
+ c10::Argument("add", IntType::get()),
+ c10::Argument("add_axis", IntType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
struct AxisParameter final {
@@ -31,6 +46,7 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS(
ops::Concat,
C10Concat_DontUseThisOpYet,
+ 2,
ParameterHelper<AxisParameter>,
ParameterHelper<AddAxisParameter>)
}
diff --git a/caffe2/operators/experimental/c10/schemas/concat.h b/caffe2/operators/experimental/c10/schemas/concat.h
index 2309de2d82..aecaf4057a 100644
--- a/caffe2/operators/experimental/c10/schemas/concat.h
+++ b/caffe2/operators/experimental/c10/schemas/concat.h
@@ -1,36 +1,11 @@
#pragma once
-#include <ATen/core/dispatch/OpSchema.h>
-#include <ATen/core/dispatch/DeviceId.h>
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include <c10/util/ArrayRef.h>
-#include "caffe2/core/context_base.h"
-#include <ATen/core/ivalue.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct Concat final {
- static constexpr const char* name = "concat";
-
- using Signature = void(
- ArrayRef<at::Tensor> inputs,
- const at::Tensor& output,
- const at::Tensor& split_info,
- int add,
- int add_axis);
-
- static constexpr size_t num_outputs() {return 2;}
-
- static constexpr c10::guts::array<const char*, 5> parameter_names = {
- {"inputs", "output", "split_info_output", "add", "add_axis"}};
-
- static c10::DeviceTypeId dispatch_key(
- const Stack* arguments) {
- return c10::DeviceTypeId::CPU;
- }
-};
+C10_DECLARE_OP_SCHEMA(Concat);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/enforce_finite.cc b/caffe2/operators/experimental/c10/schemas/enforce_finite.cc
index 82d5f18755..170bc2592a 100644
--- a/caffe2/operators/experimental/c10/schemas/enforce_finite.cc
+++ b/caffe2/operators/experimental/c10/schemas/enforce_finite.cc
@@ -4,10 +4,22 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::EnforceFinite);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(EnforceFinite, FunctionSchema(
+ "_c10_experimental::EnforceFinite",
+ (std::vector<c10::Argument>{
+ c10::Argument("input")
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
ops::EnforceFinite,
- C10EnforceFinite_DontUseThisOpYet)
+ C10EnforceFinite_DontUseThisOpYet,
+ 0)
}
diff --git a/caffe2/operators/experimental/c10/schemas/enforce_finite.h b/caffe2/operators/experimental/c10/schemas/enforce_finite.h
index f811e2b6d8..704136c1f6 100644
--- a/caffe2/operators/experimental/c10/schemas/enforce_finite.h
+++ b/caffe2/operators/experimental/c10/schemas/enforce_finite.h
@@ -1,23 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct EnforceFinite final {
- static constexpr const char* name = "enforce_finite";
-
- using Signature = void(const at::Tensor& input);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 0;}
-
- static constexpr c10::guts::array<const char*, 1> parameter_names = {
- {"input"}};
-};
+C10_DECLARE_OP_SCHEMA(EnforceFinite);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/expand_dims.cc b/caffe2/operators/experimental/c10/schemas/expand_dims.cc
index 2145a1c1f9..939edd7263 100644
--- a/caffe2/operators/experimental/c10/schemas/expand_dims.cc
+++ b/caffe2/operators/experimental/c10/schemas/expand_dims.cc
@@ -6,7 +6,20 @@ using caffe2::CPUContext;
using c10::intrusive_ptr;
using c10::ivalue::IntList;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::ExpandDims);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(ExpandDims, FunctionSchema(
+ "_c10_experimental::ExpandDims",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("output"),
+ c10::Argument("dims", ListType::ofInts())
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
struct DimsParameter final {
@@ -22,5 +35,6 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::ExpandDims,
C10ExpandDims_DontUseThisOpYet,
+ 1,
DimsParameter)
}
diff --git a/caffe2/operators/experimental/c10/schemas/expand_dims.h b/caffe2/operators/experimental/c10/schemas/expand_dims.h
index f30c7dd2eb..fa3ab8f99f 100644
--- a/caffe2/operators/experimental/c10/schemas/expand_dims.h
+++ b/caffe2/operators/experimental/c10/schemas/expand_dims.h
@@ -1,30 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
-#include <ATen/core/ivalue.h>
-#include <ATen/core/blob.h>
-#include <ATen/core/dispatch/OpSchema.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct ExpandDims final {
- static constexpr const char* name = "expand_dims";
-
- using Signature = void(
- const at::Tensor& input,
- const at::Tensor& output,
- ArrayRef<int64_t> dims);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 3> parameter_names = {
- {"input", "output", "dims"}};
-};
+C10_DECLARE_OP_SCHEMA(ExpandDims);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/fc.cc b/caffe2/operators/experimental/c10/schemas/fc.cc
index 4a1cdb41ff..29784a587c 100644
--- a/caffe2/operators/experimental/c10/schemas/fc.cc
+++ b/caffe2/operators/experimental/c10/schemas/fc.cc
@@ -4,7 +4,23 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::FullyConnected);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(FullyConnected, FunctionSchema(
+ "_c10_experimental::FullyConnected",
+ (std::vector<c10::Argument>{
+ c10::Argument("X"),
+ c10::Argument("W"),
+ c10::Argument("b"),
+ c10::Argument("output"),
+ c10::Argument("axis", IntType::get()),
+ c10::Argument("axis_w", IntType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
struct AxisParameter final {
@@ -32,6 +48,7 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::FullyConnected,
C10FC_DontUseThisOpYet,
+ 1,
ParameterHelper<AxisParameter>,
ParameterHelper<AxisWParameter>)
}
diff --git a/caffe2/operators/experimental/c10/schemas/fc.h b/caffe2/operators/experimental/c10/schemas/fc.h
index 5b2a30fa2c..1aed0eb311 100644
--- a/caffe2/operators/experimental/c10/schemas/fc.h
+++ b/caffe2/operators/experimental/c10/schemas/fc.h
@@ -1,32 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/tensor.h"
-#include <ATen/core/blob.h>
-#include <ATen/core/dispatch/OpSchema.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct FullyConnected final {
- static constexpr const char* name = "FC";
-
- using Signature = void(
- const at::Tensor& X,
- const at::Tensor& W,
- const at::Tensor& b,
- const at::Tensor& output,
- int axis,
- int axis_w);
-
- static constexpr size_t num_dispatch_args() {return 3;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 6> parameter_names = {
- {"X", "W", "b", "output", "axis", "axis_w"}};
-};
+C10_DECLARE_OP_SCHEMA(FullyConnected);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/filler.cc b/caffe2/operators/experimental/c10/schemas/filler.cc
index 8dfd0655fd..3d7de21f6b 100644
--- a/caffe2/operators/experimental/c10/schemas/filler.cc
+++ b/caffe2/operators/experimental/c10/schemas/filler.cc
@@ -8,12 +8,73 @@ using c10::C10Tensor;
using c10::ivalue::IntList;
using c10::intrusive_ptr;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::ConstantFill);
-C10_DEFINE_OP_SCHEMA(caffe2::ops::UniformFill);
-
-C10_DEFINE_OP_SCHEMA(caffe2::ops::GivenTensorFill<float>);
-C10_DEFINE_OP_SCHEMA(caffe2::ops::GivenTensorFill<int>);
-C10_DEFINE_OP_SCHEMA(caffe2::ops::GivenTensorFill<int64_t>);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema strings instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(ConstantFill, FunctionSchema(
+ "_c10_experimental::ConstantFill",
+ (std::vector<c10::Argument>{
+ c10::Argument("inputs", ListType::ofTensors()),
+ c10::Argument("output"),
+ c10::Argument("shape", ListType::ofInts()),
+ c10::Argument("extra_shape", ListType::ofInts()),
+ c10::Argument("input_as_shape", BoolType::get()),
+ c10::Argument("dtype", IntType::get()),
+ c10::Argument("value", NumberType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+C10_DEFINE_OP_SCHEMA(UniformFill, FunctionSchema(
+ "_c10_experimental::ConstantFill",
+ (std::vector<c10::Argument>{
+ c10::Argument("inputs", ListType::ofTensors()),
+ c10::Argument("output"),
+ c10::Argument("shape", ListType::ofInts()),
+ c10::Argument("extra_shape", ListType::ofInts()),
+ c10::Argument("input_as_shape", BoolType::get()),
+ c10::Argument("min", FloatType::get()),
+ c10::Argument("max", FloatType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+C10_DEFINE_OP_SCHEMA(GivenTensorFill, FunctionSchema(
+ "_c10_experimental::ConstantFill",
+ (std::vector<c10::Argument>{
+ c10::Argument("inputs", ListType::ofTensors()),
+ c10::Argument("output"),
+ c10::Argument("shape", ListType::ofInts()),
+ c10::Argument("extra_shape", ListType::ofInts()),
+ c10::Argument("input_as_shape", BoolType::get()),
+ c10::Argument("values"),
+ }), (std::vector<c10::Argument>{
+ })
+));
+C10_DEFINE_OP_SCHEMA(GivenTensorIntFill, FunctionSchema(
+ "_c10_experimental::ConstantFill",
+ (std::vector<c10::Argument>{
+ c10::Argument("inputs", ListType::ofTensors()),
+ c10::Argument("output"),
+ c10::Argument("shape", ListType::ofInts()),
+ c10::Argument("extra_shape", ListType::ofInts()),
+ c10::Argument("input_as_shape", BoolType::get()),
+ c10::Argument("values"),
+ }), (std::vector<c10::Argument>{
+ })
+));
+C10_DEFINE_OP_SCHEMA(GivenTensorInt64Fill, FunctionSchema(
+ "_c10_experimental::ConstantFill",
+ (std::vector<c10::Argument>{
+ c10::Argument("inputs", ListType::ofTensors()),
+ c10::Argument("output"),
+ c10::Argument("shape", ListType::ofInts()),
+ c10::Argument("extra_shape", ListType::ofInts()),
+ c10::Argument("input_as_shape", BoolType::get()),
+ c10::Argument("values"),
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
struct ShapeParameter final {
@@ -136,6 +197,7 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS(
ops::ConstantFill,
C10ConstantFill_DontUseThisOpYet,
+ 1,
ShapeParameter,
ExtraShapeParameter,
InputAsShapeParameter,
@@ -144,6 +206,7 @@ REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS(
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS(
ops::UniformFill,
C10UniformFill_DontUseThisOpYet,
+ 1,
ShapeParameter,
ExtraShapeParameter,
InputAsShapeParameter,
@@ -151,22 +214,25 @@ REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS(
MaxParameter)
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS(
- ops::GivenTensorFill<float>,
+ ops::GivenTensorFill,
C10GivenTensorFill_DontUseThisOpYet,
+ 1,
ShapeParameter,
ExtraShapeParameter,
InputAsShapeParameter,
ValuesParameter<float>)
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS(
- ops::GivenTensorFill<int>,
+ ops::GivenTensorIntFill,
C10GivenTensorIntFill_DontUseThisOpYet,
+ 1,
ShapeParameter,
ExtraShapeParameter,
InputAsShapeParameter,
ValuesParameter<int>)
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS(
- ops::GivenTensorFill<int64_t>,
+ ops::GivenTensorInt64Fill,
C10GivenTensorInt64Fill_DontUseThisOpYet,
+ 1,
ShapeParameter,
ExtraShapeParameter,
InputAsShapeParameter,
diff --git a/caffe2/operators/experimental/c10/schemas/filler.h b/caffe2/operators/experimental/c10/schemas/filler.h
index ef0b04662e..616893d25a 100644
--- a/caffe2/operators/experimental/c10/schemas/filler.h
+++ b/caffe2/operators/experimental/c10/schemas/filler.h
@@ -1,105 +1,15 @@
#pragma once
-#include <ATen/core/dispatch/OpSchema.h>
-#include <ATen/core/dispatch/DeviceId.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/core/ivalue.h>
-#include <c10/util/Array.h>
-#include <c10/util/ArrayRef.h>
-#include "caffe2/core/context_base.h"
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-// GivenTensorFill
-// GivenTensorInt64Fill
-// GivenTensorIntFill
-
-template <class T>
-struct GivenTensorFill final {
- static constexpr const char* name = "given_tensor_fill";
-
- using Signature = void(
- ArrayRef<at::Tensor> inputs,
- const at::Tensor& output,
- ArrayRef<int64_t> shape,
- ArrayRef<int64_t> extra_shape,
- bool input_as_shape,
- const at::Tensor& values);
-
- static constexpr c10::guts::array<const char*, 6> parameter_names = {
- {"inputs",
- "output",
- "shape",
- "extra_shape",
- "input_as_shape",
- "values"}};
-
- static constexpr size_t num_outputs() {return 1;}
-
- static c10::DeviceTypeId dispatch_key(
- const Stack* stack) {
- return c10::DeviceTypeId::CPU;
- }
-};
-
-struct ConstantFill final {
- static constexpr const char* name = "constant_fill";
-
- using Signature = void(
- ArrayRef<at::Tensor> inputs,
- const at::Tensor& output,
- ArrayRef<int64_t> shape,
- ArrayRef<int64_t> extra_shape,
- bool input_as_shape,
- int dtype,
- IValue value);
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 7> parameter_names = {
- {"inputs",
- "output",
- "shape",
- "extra_shape",
- "input_as_shape",
- "dtype",
- "value"}};
-
- static c10::DeviceTypeId dispatch_key(
- const Stack* stack) {
- return c10::DeviceTypeId::CPU;
- }
-};
-
-struct UniformFill final {
- static constexpr const char* name = "uniform_fill";
-
- using Signature = void(
- ArrayRef<at::Tensor> inputs,
- const at::Tensor& output,
- ArrayRef<int64_t> shape,
- ArrayRef<int64_t> extra_shape,
- bool input_as_shape,
- float min,
- float max);
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 7> parameter_names = {
- {"inputs",
- "output",
- "shape",
- "extra_shape",
- "input_as_shape",
- "min",
- "max"}};
-
- static c10::DeviceTypeId dispatch_key(
- const Stack* stack) {
- return c10::DeviceTypeId::CPU;
- }
-};
+C10_DECLARE_OP_SCHEMA(GivenTensorFill);
+C10_DECLARE_OP_SCHEMA(GivenTensorIntFill);
+C10_DECLARE_OP_SCHEMA(GivenTensorInt64Fill);
+C10_DECLARE_OP_SCHEMA(ConstantFill);
+C10_DECLARE_OP_SCHEMA(UniformFill);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/flatten.cc b/caffe2/operators/experimental/c10/schemas/flatten.cc
index 8a14e58d07..c1918654c6 100644
--- a/caffe2/operators/experimental/c10/schemas/flatten.cc
+++ b/caffe2/operators/experimental/c10/schemas/flatten.cc
@@ -4,7 +4,20 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::Flatten);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(Flatten, FunctionSchema(
+ "_c10_experimental::Flatten",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("output"),
+ c10::Argument("axis", IntType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
struct AxisParameter final {
@@ -22,5 +35,6 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::Flatten,
C10Flatten_DontUseThisOpYet,
+ 1,
ParameterHelper<AxisParameter>)
}
diff --git a/caffe2/operators/experimental/c10/schemas/flatten.h b/caffe2/operators/experimental/c10/schemas/flatten.h
index 0ee1773cb6..9c53462528 100644
--- a/caffe2/operators/experimental/c10/schemas/flatten.h
+++ b/caffe2/operators/experimental/c10/schemas/flatten.h
@@ -1,27 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct Flatten final {
- static constexpr const char* name = "flatten";
-
- using Signature = void(
- const at::Tensor& input,
- const at::Tensor& output,
- int axis);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 3> parameter_names = {
- {"input", "output", "axis"}};
-};
+C10_DECLARE_OP_SCHEMA(Flatten);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/layer_norm.cc b/caffe2/operators/experimental/c10/schemas/layer_norm.cc
index e9dddfb2db..4fc6b18999 100644
--- a/caffe2/operators/experimental/c10/schemas/layer_norm.cc
+++ b/caffe2/operators/experimental/c10/schemas/layer_norm.cc
@@ -28,6 +28,7 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
c10::core::opschema::LayerNorm,
C10LayerNorm_DontUseThisOpYet,
+ 3,
ParameterHelper<AxisParameter>,
ParameterHelper<EpsilonParameter>)
}
diff --git a/caffe2/operators/experimental/c10/schemas/mul.cc b/caffe2/operators/experimental/c10/schemas/mul.cc
index 9111dfda29..723aeab92b 100644
--- a/caffe2/operators/experimental/c10/schemas/mul.cc
+++ b/caffe2/operators/experimental/c10/schemas/mul.cc
@@ -4,7 +4,22 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::Mul);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(Mul, FunctionSchema(
+ "_c10_experimental::Mul",
+ (std::vector<c10::Argument>{
+ c10::Argument("input1"),
+ c10::Argument("input2"),
+ c10::Argument("output"),
+ c10::Argument("legacy_broadcast", BoolType::get()),
+ c10::Argument("axis", IntType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
@@ -32,6 +47,7 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::Mul,
C10Mul_DontUseThisOpYet,
+ 1,
ParameterHelper<LegacyBroadcastParameter>,
ParameterHelper<AxisParameter>)
}
diff --git a/caffe2/operators/experimental/c10/schemas/mul.h b/caffe2/operators/experimental/c10/schemas/mul.h
index 12a1780330..54b64f4f74 100644
--- a/caffe2/operators/experimental/c10/schemas/mul.h
+++ b/caffe2/operators/experimental/c10/schemas/mul.h
@@ -1,29 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct Mul final {
- static constexpr const char* name = "mul";
-
- using Signature = void(
- const at::Tensor& input1,
- const at::Tensor& input2,
- const at::Tensor& output,
- bool legacy_broadcast,
- int axis);
-
- static constexpr size_t num_dispatch_args() {return 2;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 5> parameter_names = {
- {"input1", "input2", "output", "legacy_broadcast", "axis"}};
-};
+C10_DECLARE_OP_SCHEMA(Mul);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/relu.cc b/caffe2/operators/experimental/c10/schemas/relu.cc
index 1d08b20be5..91b95f1820 100644
--- a/caffe2/operators/experimental/c10/schemas/relu.cc
+++ b/caffe2/operators/experimental/c10/schemas/relu.cc
@@ -4,10 +4,24 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::Relu);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(Relu, FunctionSchema(
+ "_c10_experimental::Relu",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("output")
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
ops::Relu,
- C10Relu_DontUseThisOpYet)
+ C10Relu_DontUseThisOpYet,
+ 1
+ )
}
diff --git a/caffe2/operators/experimental/c10/schemas/relu.h b/caffe2/operators/experimental/c10/schemas/relu.h
index bf1b8fd03c..ea0aa89367 100644
--- a/caffe2/operators/experimental/c10/schemas/relu.h
+++ b/caffe2/operators/experimental/c10/schemas/relu.h
@@ -1,24 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct Relu final {
- static constexpr const char* name = "relu";
-
- using Signature =
- void(const at::Tensor& input, const at::Tensor& output);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 2> parameter_names = {
- {"input", "output"}};
-};
+C10_DECLARE_OP_SCHEMA(Relu);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid.cc b/caffe2/operators/experimental/c10/schemas/sigmoid.cc
index fbcfe1e449..ebc6fabfe5 100644
--- a/caffe2/operators/experimental/c10/schemas/sigmoid.cc
+++ b/caffe2/operators/experimental/c10/schemas/sigmoid.cc
@@ -4,10 +4,23 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::Sigmoid);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(Sigmoid, FunctionSchema(
+ "_c10_experimental::Sigmoid",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("output")
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
ops::Sigmoid,
- C10Sigmoid_DontUseThisOpYet)
+ C10Sigmoid_DontUseThisOpYet,
+ 1)
}
diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid.h b/caffe2/operators/experimental/c10/schemas/sigmoid.h
index 326dc078f4..5d5ff41a59 100644
--- a/caffe2/operators/experimental/c10/schemas/sigmoid.h
+++ b/caffe2/operators/experimental/c10/schemas/sigmoid.h
@@ -1,24 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct Sigmoid final {
- static constexpr const char* name = "sigmoid";
-
- using Signature =
- void(const at::Tensor& input, const at::Tensor& output);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 2> parameter_names = {
- {"input", "output"}};
-};
+C10_DECLARE_OP_SCHEMA(Sigmoid);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc
index 0e34417200..f7aa2acddb 100644
--- a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc
+++ b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc
@@ -4,7 +4,22 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::SigmoidCrossEntropyWithLogits);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(SigmoidCrossEntropyWithLogits, FunctionSchema(
+ "_c10_experimental::SigmoidCrossEntropyWithLogits",
+ (std::vector<c10::Argument>{
+ c10::Argument("input1"),
+ c10::Argument("input2"),
+ c10::Argument("output"),
+ c10::Argument("log_D_trick", BoolType::get()),
+ c10::Argument("unjoined_lr_loss", BoolType::get())
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace {
struct LogDTrickParameter final {
@@ -31,6 +46,7 @@ namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
ops::SigmoidCrossEntropyWithLogits,
C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet,
+ 1,
ParameterHelper<LogDTrickParameter>,
ParameterHelper<UnjoinedLRLossParameter>)
}
diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h
index 7fb7a88ece..671c2e2523 100644
--- a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h
+++ b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h
@@ -1,28 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct SigmoidCrossEntropyWithLogits final {
- static constexpr const char* name = "sigmoid_cross_entropy_with_logits";
-
- using Signature = void(
- const at::Tensor& input1,
- const at::Tensor& input2,
- const at::Tensor& output,
- bool log_D_trick,
- bool unjoined_lr_loss);
-
- static constexpr size_t num_dispatch_args() {return 2;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 5> parameter_names = {
- {"input1", "input2", "output", "log_d_trick", "unjoined_lr_loss"}};
-};
+C10_DECLARE_OP_SCHEMA(SigmoidCrossEntropyWithLogits);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc
index a2e70180c5..1dd558b4ee 100644
--- a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc
+++ b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc
@@ -4,10 +4,25 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::SparseLengthsSum);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(SparseLengthsSum, FunctionSchema(
+ "_c10_experimental::SparseLengthsSum",
+ (std::vector<c10::Argument>{
+ c10::Argument("data"),
+ c10::Argument("indices"),
+ c10::Argument("lengths"),
+ c10::Argument("output")
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
ops::SparseLengthsSum,
- C10SparseLengthsSum_DontUseThisOpYet)
+ C10SparseLengthsSum_DontUseThisOpYet,
+ 1)
}
diff --git a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h
index 16d23e733e..a4054e1aae 100644
--- a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h
+++ b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h
@@ -1,27 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct SparseLengthsSum final {
- static constexpr const char* name = "sparse_lengths_sum";
-
- using Signature = void(
- const at::Tensor& data,
- const at::Tensor& indices,
- const at::Tensor& lengths,
- const at::Tensor& output);
-
- static constexpr size_t num_dispatch_args() {return 3;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 4> parameter_names = {
- {"data", "indices", "lengths", "output"}};
-};
+C10_DECLARE_OP_SCHEMA(SparseLengthsSum);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/stop_gradient.cc b/caffe2/operators/experimental/c10/schemas/stop_gradient.cc
index 630e74056b..4c26785e29 100644
--- a/caffe2/operators/experimental/c10/schemas/stop_gradient.cc
+++ b/caffe2/operators/experimental/c10/schemas/stop_gradient.cc
@@ -4,10 +4,24 @@
using caffe2::CPUContext;
-C10_DEFINE_OP_SCHEMA(caffe2::ops::StopGradient);
+namespace caffe2 {
+namespace ops {
+// TODO Parse schema string instead of creating FunctionSchema manually
+C10_DEFINE_OP_SCHEMA(StopGradient, FunctionSchema(
+ "_c10_experimental::StopGradient",
+ (std::vector<c10::Argument>{
+ c10::Argument("input"),
+ c10::Argument("output")
+ }), (std::vector<c10::Argument>{
+ })
+));
+}
+}
namespace caffe2 {
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
ops::StopGradient,
- C10StopGradient_DontUseThisOpYet)
+ C10StopGradient_DontUseThisOpYet,
+ 1
+ )
}
diff --git a/caffe2/operators/experimental/c10/schemas/stop_gradient.h b/caffe2/operators/experimental/c10/schemas/stop_gradient.h
index f89f942eb7..bb130e2c8c 100644
--- a/caffe2/operators/experimental/c10/schemas/stop_gradient.h
+++ b/caffe2/operators/experimental/c10/schemas/stop_gradient.h
@@ -1,26 +1,11 @@
#pragma once
-#include <ATen/core/Tensor.h>
-#include <c10/util/Array.h>
-#include "caffe2/core/context_base.h"
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
namespace caffe2 {
namespace ops {
-struct StopGradient final {
- static constexpr const char* name = "stop_gradient";
-
- using Signature = void(
- const at::Tensor& input,
- const at::Tensor& output);
-
- static constexpr size_t num_dispatch_args() {return 1;}
-
- static constexpr size_t num_outputs() {return 1;}
-
- static constexpr c10::guts::array<const char*, 2> parameter_names = {
- {"input", "output"}};
-};
+C10_DECLARE_OP_SCHEMA(StopGradient);
} // namespace ops
} // namespace caffe2
diff --git a/caffe2/operators/layer_norm_op.cc b/caffe2/operators/layer_norm_op.cc
index 1b096f4dc8..f452f4bf3c 100644
--- a/caffe2/operators/layer_norm_op.cc
+++ b/caffe2/operators/layer_norm_op.cc
@@ -186,15 +186,15 @@ to the end.)
// Register layer norm with c10
namespace {
-struct State final : public c10::KernelState {
+struct Cache final : public c10::KernelCache {
at::optional<at::Tensor> scale = at::nullopt;
at::optional<at::Tensor> bias = at::nullopt;
};
template <class DataType>
-void layer_norm_c10(c10::Stack* stack, c10::KernelState* state) { // TODO Pass in correct state type
+void layer_norm_c10(c10::Stack* stack, c10::KernelCache* cache_) { // TODO Pass in correct cache type
c10::ArrayRef<c10::IValue> inputs = torch::jit::peekSlice(*stack, 0, 3, 6);
- c10::ArrayRef<c10::IValue> outputs = torch::jit::peekSlice(*stack, 0, 3, 3);
+ c10::ArrayRef<c10::IValue> outputs = torch::jit::peekSlice(*stack, 3, 3, 6);
caffe2::Tensor X{c10::C10Tensor(inputs[0].toTensor())};
int64_t axis = inputs[1].toInt();
@@ -204,7 +204,7 @@ void layer_norm_c10(c10::Stack* stack, c10::KernelState* state) { // TODO Pass i
caffe2::Tensor sig{c10::C10Tensor(outputs[2].toTensor())};
caffe2::CPUContext context;
- State* cache = static_cast<State*>(state);
+ Cache* cache = static_cast<Cache*>(cache_);
if (!cache->scale.has_value()) {
cache->scale = at::Tensor(caffe2::empty({0}, at::dtype<float>()));
}
@@ -224,19 +224,19 @@ void layer_norm_c10(c10::Stack* stack, c10::KernelState* state) { // TODO Pass i
X, &Y, &mean, &sig, canonical_axis, epsilon, &scale, &bias, static_cast<caffe2::CPUContext*>(&context)
);
- torch::jit::peek(*stack, 0, 3) = at::Tensor(c10::C10Tensor(std::move(Y)));
- torch::jit::peek(*stack, 1, 3) = at::Tensor(c10::C10Tensor(std::move(mean)));
- torch::jit::peek(*stack, 2, 3) = at::Tensor(c10::C10Tensor(std::move(sig)));
+ torch::jit::drop(*stack, 6);
+ torch::jit::push(*stack,
+ at::Tensor(c10::C10Tensor(std::move(Y))),
+ at::Tensor(c10::C10Tensor(std::move(mean))),
+ at::Tensor(c10::C10Tensor(std::move(sig)))
+ );
return;
}
}
namespace c10 {
C10_REGISTER_KERNEL(c10::core::opschema::LayerNorm)
- .withState<State>()
+ .withCache<Cache>()
.kernel<&layer_norm_c10<float>>()
- .dispatchKey(c10::DispatchKey<1>{
- c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
- LayoutId(0),
- caffe2::TypeMeta::Id<float>()}});
+ .dispatchKey(CPUTensorId());
} // namespace c10
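The layer-norm kernel stays boxed (it still takes a c10::Stack* plus the renamed KernelCache), but its epilogue changes: instead of overwriting the top three stack slots, it now drops all six packed arguments and pushes three fresh outputs. A self-contained sketch of that stack discipline, assuming the indexing convention peek(stack, i, N) == stack[size - N + i] used by torch::jit (the helpers below are simplified stand-ins operating on ints, not IValues):

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using Stack = std::vector<int>;  // stand-in for the real IValue stack

// Assumed indexing convention: peek(stack, i, N) is the i-th of the last N entries.
int& peek(Stack& s, std::size_t i, std::size_t n) { return s[s.size() - n + i]; }
void drop(Stack& s, std::size_t n) { s.erase(s.end() - static_cast<std::ptrdiff_t>(n), s.end()); }
void push(Stack& s, int a, int b, int c) { s.push_back(a); s.push_back(b); s.push_back(c); }

// A boxed kernel reads its 6 packed arguments (3 inputs + 3 preallocated outputs),
// then replaces all of them with the 3 real outputs, as the new layer_norm_c10 does.
void boxed_layer_norm_like_kernel(Stack& stack) {
  int x = peek(stack, 0, 6);         // first input
  int y = x * 2, mean = x, sig = 1;  // pretend computation
  drop(stack, 6);
  push(stack, y, mean, sig);
}

int main() {
  Stack stack = {10, 1, 0, /*placeholder outputs:*/ -1, -1, -1};
  boxed_layer_norm_like_kernel(stack);
  assert(stack.size() == 3);
  std::cout << stack[0] << " " << stack[1] << " " << stack[2] << "\n";  // prints: 20 10 1
}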
diff --git a/torch/csrc/jit/c10_ops/layer_norm.cpp b/torch/csrc/jit/c10_ops/layer_norm.cpp
index 7375645fac..02fdf89af4 100644
--- a/torch/csrc/jit/c10_ops/layer_norm.cpp
+++ b/torch/csrc/jit/c10_ops/layer_norm.cpp
@@ -27,27 +27,34 @@ namespace jit {
namespace {
RegisterOperators reg({
Operator(
- "caffe2::layer_norm_dont_use_this_op_yet(Tensor input, int axis, float epsilon) -> (Tensor, Tensor, Tensor)",
+ // Note: This schema is: caffe2::layer_norm_dont_use_this_op_yet(Tensor input, int axis, float epsilon, Tensor? output = None, Tensor? output_mean = None, Tensor? output_stdev = None) -> (Tensor, Tensor, Tensor)
+ c10::core::opschema::LayerNorm().schema(),
[](Stack& stack) {
- Tensor tensor_input = std::move(stack[stack.size()-3]).toTensor();
+ Tensor tensor_input = std::move(stack[stack.size()-6]).toTensor();
if (tensor_input.requires_grad()) {
throw std::runtime_error("Autograd not yet supported for c10 ops.");
}
auto device = tensor_input.device();
- torch::jit::peek(stack, 0, 3) = torch::autograd::Variable(std::move(tensor_input)).data();
- // push output fields as outputs to stack
- push(stack, at::empty({0}, device), at::empty({0}, device), at::empty({0}, device));
+ // unwrap inputs from variable
+ torch::jit::peek(stack, 0, 6) = torch::autograd::Variable(std::move(tensor_input)).data();
- c10::Dispatcher<c10::core::opschema::LayerNorm>::lookup(&stack).call(&stack);
+ // allocate the output tensors that aren't set yet
+ for (int i = 3; i < 6; ++i) {
+ // TODO this should just check for isNone, not for undefined tensor. @wanchaol is working on this.
+ if (torch::jit::peek(stack, i, 6).isNone() || !torch::jit::peek(stack, i, 6).toTensor().defined()) {
+ torch::jit::peek(stack, i, 6) = at::empty({0}, device);
+ }
+ }
+
+ // call caffe2 kernel
+ c10::Dispatcher::singleton().lookup(c10::core::opschema::LayerNorm(), &stack).call(&stack);
- // move outputs down the stack to where the inputs were before
+ // wrap outputs into Variable
for (int i = 0; i < 3; ++i) {
- torch::jit::peek(stack, i, 6) = torch::autograd::make_variable(std::move(torch::jit::peek(stack, i, 3)).toTensor(), false);
+ torch::jit::peek(stack, i, 3) = torch::autograd::make_variable(std::move(torch::jit::peek(stack, i, 3)).toTensor(), false);
}
- drop(stack, 3); // drop inputs
-
return 0;
})
});
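Finally, the JIT wrapper registers against the schema object itself (six arguments, the last three being optional output tensors), unwraps the input Variable, materializes any outputs the caller left unset, and then calls through the dispatcher singleton instead of the old schema-templated Dispatcher<OpSchema>. A minimal sketch of what such a singleton lookup-then-call registry of boxed kernels can look like (MiniDispatcher, BoxedKernel, and the "_sketch::LayerNorm" key are hypothetical, not the real c10::Dispatcher API):

#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

using Stack = std::vector<int>;                   // stand-in for the IValue stack
using BoxedKernel = std::function<void(Stack*)>;  // stand-in for a boxed op kernel

// Hypothetical, simplified analogue of "Dispatcher::singleton().lookup(schema, &stack).call(&stack)":
// a process-wide registry from schema name to boxed kernel, looked up at call time.
class MiniDispatcher {
 public:
  static MiniDispatcher& singleton() { static MiniDispatcher d; return d; }
  void registerKernel(const std::string& schema, BoxedKernel k) { kernels_[schema] = std::move(k); }
  const BoxedKernel& lookup(const std::string& schema) const {
    auto it = kernels_.find(schema);
    if (it == kernels_.end()) throw std::runtime_error("no kernel registered for " + schema);
    return it->second;
  }
 private:
  std::unordered_map<std::string, BoxedKernel> kernels_;
};

int main() {
  MiniDispatcher::singleton().registerKernel("_sketch::LayerNorm",
      [](Stack* s) { s->push_back(static_cast<int>(s->size())); });
  Stack stack = {1, 2, 3};
  MiniDispatcher::singleton().lookup("_sketch::LayerNorm")(&stack);  // call the boxed kernel
  std::cout << stack.back() << "\n";  // prints 3
}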