summaryrefslogtreecommitdiff
path: root/c10
diff options
context:
space:
mode:
authorSebastian Messmer <messmer@fb.com>2019-01-08 20:22:41 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-08 20:31:43 -0800
commitd562840910b8743b6ea476a47c4df53a8531fd14 (patch)
tree71b9ddf1d85c8a697987697f325f807623be5acc /c10
parent8ac55a6812884d76d6116aa72aa7beb4a6bda832 (diff)
downloadpytorch-d562840910b8743b6ea476a47c4df53a8531fd14.tar.gz
pytorch-d562840910b8743b6ea476a47c4df53a8531fd14.tar.bz2
pytorch-d562840910b8743b6ea476a47c4df53a8531fd14.zip
Use C10Tensor in the dispatcher (#15195)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15195 This removes the use of caffe2::Tensor or at::Tensor in the c10 dispatcher and only uses C10::Tensor. It also changes output tensors to be passed as `const Tensor&` instead of `Tensor*` because we otherwise can't forward them in operator_c10wrapper.h. Reviewed By: ezyang Differential Revision: D13461640 fbshipit-source-id: 7f79925a7d60f01660a24bbfda47391af0c70ed3
Diffstat (limited to 'c10')
-rw-r--r--c10/CMakeLists.txt1
-rw-r--r--c10/core/Tensor.h4
-rw-r--r--c10/core/dispatch/OpSchema.h44
-rw-r--r--c10/test/dispatch/OpSchema_test.cpp23
4 files changed, 38 insertions, 34 deletions
diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt
index e80dd795c2..514aeee079 100644
--- a/c10/CMakeLists.txt
+++ b/c10/CMakeLists.txt
@@ -29,6 +29,7 @@ file(GLOB C10_SRCS
*.cpp
core/*.cpp
core/dispatch/*.cpp
+ core/opschema/*.cpp
impl/*.cpp
macros/*.cpp
util/*.cpp
diff --git a/c10/core/Tensor.h b/c10/core/Tensor.h
index c8f9c11258..461c1a8c6e 100644
--- a/c10/core/Tensor.h
+++ b/c10/core/Tensor.h
@@ -27,7 +27,7 @@ public:
C10Tensor& operator=(C10Tensor&&) noexcept = default;
const TensorImplPtr &impl() const & noexcept;
- TensorImplPtr impl() && noexcept;
+ TensorImplPtr&& impl() && noexcept;
TensorTypeId type_id() const;
@@ -42,7 +42,7 @@ inline const C10Tensor::TensorImplPtr &C10Tensor::impl() const & noexcept {
return impl_;
}
-inline C10Tensor::TensorImplPtr C10Tensor::impl() && noexcept {
+inline C10Tensor::TensorImplPtr&& C10Tensor::impl() && noexcept {
return std::move(impl_);
}
diff --git a/c10/core/dispatch/OpSchema.h b/c10/core/dispatch/OpSchema.h
index b9658d6fb9..0825e1b654 100644
--- a/c10/core/dispatch/OpSchema.h
+++ b/c10/core/dispatch/OpSchema.h
@@ -4,10 +4,7 @@
#include <c10/util/Array.h>
#include <c10/util/Metaprogramming.h>
#include <c10/DeviceType.h>
-
-namespace caffe2 {
-class Tensor;
-} // namespace caffe2
+#include <c10/core/Tensor.h>
namespace c10 {
@@ -19,7 +16,7 @@ namespace details {
*/
template <class Arg>
using is_tensor_arg = std::
- is_same<caffe2::Tensor, guts::remove_cv_t<guts::remove_reference_t<Arg>>>;
+ is_same<C10Tensor, guts::remove_cv_t<guts::remove_reference_t<Arg>>>;
inline DeviceTypeId to_device_type_id(DeviceType device_type) {
switch (device_type) {
@@ -32,16 +29,18 @@ inline DeviceTypeId to_device_type_id(DeviceType device_type) {
}
}
-// TODO get rid of tensor_to_dispatch_key once c2::Tensor is de-templatized. This then fits into a template lambda instead of a functor.
-struct tensor_to_dispatch_key final {
- template<class TensorType>
- TensorParameterDispatchKey operator()(const TensorType& tensor) const {
- return TensorParameterDispatchKey{
- to_device_type_id(tensor.GetDeviceType()),
- LayoutId(0),
- tensor.dtype().id()};
- }
-};
+inline TensorParameterDispatchKey tensor_to_dispatch_key(const C10Tensor& tensor) {
+ return TensorParameterDispatchKey{
+ to_device_type_id(tensor.impl()->device_type()),
+ LayoutId(0),
+ tensor.impl()->dtype().id()};
+}
+
+// Extract type ids for all tensors from an array of tensors
+template<size_t num_dispatch_args, size_t num_tensor_args, size_t... indices>
+guts::array<TensorParameterDispatchKey, num_dispatch_args> getDispatchTypeIds__(const guts::array<const C10Tensor*, num_tensor_args>& tensor_args, guts::index_sequence<indices...>) {
+ return {tensor_to_dispatch_key(*tensor_args[indices])...};
+}
/**
* Extract the type ids of all tensors in a variadic list of arguments
@@ -50,12 +49,13 @@ struct tensor_to_dispatch_key final {
* @param args List of arguments to get type ids from
* @return guts::array<TensorParameterDispatchKey, n>, where n is the number of tensor arguments (is_tensor_arg) in the class
*/
-template<class... Args> auto getTensorTypeIds_(const Args&... args)
--> guts::array<TensorParameterDispatchKey, guts::typelist::count_if<is_tensor_arg, guts::typelist::typelist<Args...>>::value> {
- return guts::filter_map<TensorParameterDispatchKey, is_tensor_arg>(tensor_to_dispatch_key(), args...);
+template<size_t num_dispatch_args, class... Args>
+guts::array<TensorParameterDispatchKey, num_dispatch_args> getDispatchTypeIds_(const Args&... args) {
+ auto tensor_args = guts::filter_map<const C10Tensor*, is_tensor_arg>([] (const C10Tensor& v){return &v;}, args...);
+ return getDispatchTypeIds__<num_dispatch_args>(tensor_args, guts::make_index_sequence<num_dispatch_args>());
}
-// TODO Test getTensorTypeIds_
+// TODO Test getDispatchTypeIds_
/**
* If T is a struct with a type field Signature, provides the member constant
@@ -121,6 +121,8 @@ public:
*/
static constexpr size_t num_tensor_args = guts::typelist::count_if<details::is_tensor_arg, parameter_types>::value;
+ static constexpr size_t num_outputs = OpSchemaDef::num_outputs();
+
private:
static_assert(details::has_parameter_names_defined<OpSchemaDef>::value, "Operator schema doesn't define parameter_names member.");
// TODO Allow simpler definition of parameter_names without having to spell out the guts::array type in the schema def.
@@ -165,7 +167,7 @@ class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<!has_function_dispatch_
// TODO Use an ADL-based debugString(DispatchKey) function instead of operator<< for debug printing.
public:
- using dispatch_key_type = DispatchKey<signature::num_tensor_args>;
+ using dispatch_key_type = DispatchKey<OpSchemaDef::num_dispatch_args()>;
template<class... Args>
static inline dispatch_key_type dispatch_key(const Args&... args) {
@@ -176,7 +178,7 @@ public:
map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>>
>::value, "Invalid argument types passed to OpSchema::dispatch_key()");
return dispatch_key_type {
- details::getTensorTypeIds_(args...)
+ details::getDispatchTypeIds_<OpSchemaDef::num_dispatch_args()>(args...)
};
}
};
diff --git a/c10/test/dispatch/OpSchema_test.cpp b/c10/test/dispatch/OpSchema_test.cpp
index 3a56ff9fc5..d10f7fc223 100644
--- a/c10/test/dispatch/OpSchema_test.cpp
+++ b/c10/test/dispatch/OpSchema_test.cpp
@@ -1,26 +1,27 @@
-#include "c10/core/dispatch/OpSchema.h"
+#include <c10/core/dispatch/OpSchema.h>
#include <c10/util/Array.h>
using namespace c10;
-using namespace caffe2;
-static_assert(details::is_tensor_arg<Tensor>::value, "");
-static_assert(details::is_tensor_arg<const Tensor&>::value, "");
-static_assert(details::is_tensor_arg<Tensor&&>::value, "");
+static_assert(details::is_tensor_arg<C10Tensor>::value, "");
+static_assert(details::is_tensor_arg<const C10Tensor&>::value, "");
+static_assert(details::is_tensor_arg<C10Tensor&&>::value, "");
static_assert(!details::is_tensor_arg<int>::value, "");
struct SchemaDef final {
- using Signature = bool(int, Tensor, float, Tensor, Tensor, unsigned int);
+ using Signature = bool(int, C10Tensor, float, C10Tensor, C10Tensor, unsigned int);
static constexpr guts::array<const char*, 6> parameter_names = {{
"1", "2", "3", "4", "5", "6"
}};
+ static constexpr size_t num_dispatch_args() {return 3;}
+ static constexpr size_t num_outputs() {return 0;}
};
-static_assert(6 == OpSchema<SchemaDef>::signature::num_args, "test num_dispatch_args");
-static_assert(3 == OpSchema<SchemaDef>::signature::num_tensor_args, "test num_dispatch_args");
-static_assert(std::is_same<bool, typename OpSchema<SchemaDef>::signature::return_type>::value, "test num_dispatch_args");
+static_assert(6 == OpSchema<SchemaDef>::signature::num_args, "");
+static_assert(3 == OpSchema<SchemaDef>::signature::num_tensor_args, "");
+static_assert(std::is_same<bool, typename OpSchema<SchemaDef>::signature::return_type>::value, "");
static_assert(
std::is_same<
guts::typelist::
- typelist<int, Tensor, float, Tensor, Tensor, unsigned int>,
+ typelist<int, C10Tensor, float, C10Tensor, C10Tensor, unsigned int>,
typename OpSchema<SchemaDef>::signature::parameter_types>::value,
- "test num_dispatch_args");
+ "");