diff options
author | Sebastian Messmer <messmer@fb.com> | 2019-01-08 20:22:41 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-08 20:31:43 -0800 |
commit | d562840910b8743b6ea476a47c4df53a8531fd14 (patch) | |
tree | 71b9ddf1d85c8a697987697f325f807623be5acc /c10 | |
parent | 8ac55a6812884d76d6116aa72aa7beb4a6bda832 (diff) | |
download | pytorch-d562840910b8743b6ea476a47c4df53a8531fd14.tar.gz pytorch-d562840910b8743b6ea476a47c4df53a8531fd14.tar.bz2 pytorch-d562840910b8743b6ea476a47c4df53a8531fd14.zip |
Use C10Tensor in the dispatcher (#15195)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15195
This removes the use of caffe2::Tensor or at::Tensor in the c10 dispatcher and only uses C10::Tensor.
It also changes output tensors to be passed as `const Tensor&` instead of `Tensor*` because we otherwise can't forward them in operator_c10wrapper.h.
Reviewed By: ezyang
Differential Revision: D13461640
fbshipit-source-id: 7f79925a7d60f01660a24bbfda47391af0c70ed3
Diffstat (limited to 'c10')
-rw-r--r-- | c10/CMakeLists.txt | 1 | ||||
-rw-r--r-- | c10/core/Tensor.h | 4 | ||||
-rw-r--r-- | c10/core/dispatch/OpSchema.h | 44 | ||||
-rw-r--r-- | c10/test/dispatch/OpSchema_test.cpp | 23 |
4 files changed, 38 insertions, 34 deletions
diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index e80dd795c2..514aeee079 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -29,6 +29,7 @@ file(GLOB C10_SRCS *.cpp core/*.cpp core/dispatch/*.cpp + core/opschema/*.cpp impl/*.cpp macros/*.cpp util/*.cpp diff --git a/c10/core/Tensor.h b/c10/core/Tensor.h index c8f9c11258..461c1a8c6e 100644 --- a/c10/core/Tensor.h +++ b/c10/core/Tensor.h @@ -27,7 +27,7 @@ public: C10Tensor& operator=(C10Tensor&&) noexcept = default; const TensorImplPtr &impl() const & noexcept; - TensorImplPtr impl() && noexcept; + TensorImplPtr&& impl() && noexcept; TensorTypeId type_id() const; @@ -42,7 +42,7 @@ inline const C10Tensor::TensorImplPtr &C10Tensor::impl() const & noexcept { return impl_; } -inline C10Tensor::TensorImplPtr C10Tensor::impl() && noexcept { +inline C10Tensor::TensorImplPtr&& C10Tensor::impl() && noexcept { return std::move(impl_); } diff --git a/c10/core/dispatch/OpSchema.h b/c10/core/dispatch/OpSchema.h index b9658d6fb9..0825e1b654 100644 --- a/c10/core/dispatch/OpSchema.h +++ b/c10/core/dispatch/OpSchema.h @@ -4,10 +4,7 @@ #include <c10/util/Array.h> #include <c10/util/Metaprogramming.h> #include <c10/DeviceType.h> - -namespace caffe2 { -class Tensor; -} // namespace caffe2 +#include <c10/core/Tensor.h> namespace c10 { @@ -19,7 +16,7 @@ namespace details { */ template <class Arg> using is_tensor_arg = std:: - is_same<caffe2::Tensor, guts::remove_cv_t<guts::remove_reference_t<Arg>>>; + is_same<C10Tensor, guts::remove_cv_t<guts::remove_reference_t<Arg>>>; inline DeviceTypeId to_device_type_id(DeviceType device_type) { switch (device_type) { @@ -32,16 +29,18 @@ inline DeviceTypeId to_device_type_id(DeviceType device_type) { } } -// TODO get rid of tensor_to_dispatch_key once c2::Tensor is de-templatized. This then fits into a template lambda instead of a functor. -struct tensor_to_dispatch_key final { - template<class TensorType> - TensorParameterDispatchKey operator()(const TensorType& tensor) const { - return TensorParameterDispatchKey{ - to_device_type_id(tensor.GetDeviceType()), - LayoutId(0), - tensor.dtype().id()}; - } -}; +inline TensorParameterDispatchKey tensor_to_dispatch_key(const C10Tensor& tensor) { + return TensorParameterDispatchKey{ + to_device_type_id(tensor.impl()->device_type()), + LayoutId(0), + tensor.impl()->dtype().id()}; +} + +// Extract type ids for all tensors from an array of tensors +template<size_t num_dispatch_args, size_t num_tensor_args, size_t... indices> +guts::array<TensorParameterDispatchKey, num_dispatch_args> getDispatchTypeIds__(const guts::array<const C10Tensor*, num_tensor_args>& tensor_args, guts::index_sequence<indices...>) { + return {tensor_to_dispatch_key(*tensor_args[indices])...}; +} /** * Extract the type ids of all tensors in a variadic list of arguments @@ -50,12 +49,13 @@ struct tensor_to_dispatch_key final { * @param args List of arguments to get type ids from * @return guts::array<TensorParameterDispatchKey, n>, where n is the number of tensor arguments (is_tensor_arg) in the class */ -template<class... Args> auto getTensorTypeIds_(const Args&... args) --> guts::array<TensorParameterDispatchKey, guts::typelist::count_if<is_tensor_arg, guts::typelist::typelist<Args...>>::value> { - return guts::filter_map<TensorParameterDispatchKey, is_tensor_arg>(tensor_to_dispatch_key(), args...); +template<size_t num_dispatch_args, class... Args> +guts::array<TensorParameterDispatchKey, num_dispatch_args> getDispatchTypeIds_(const Args&... args) { + auto tensor_args = guts::filter_map<const C10Tensor*, is_tensor_arg>([] (const C10Tensor& v){return &v;}, args...); + return getDispatchTypeIds__<num_dispatch_args>(tensor_args, guts::make_index_sequence<num_dispatch_args>()); } -// TODO Test getTensorTypeIds_ +// TODO Test getDispatchTypeIds_ /** * If T is a struct with a type field Signature, provides the member constant @@ -121,6 +121,8 @@ public: */ static constexpr size_t num_tensor_args = guts::typelist::count_if<details::is_tensor_arg, parameter_types>::value; + static constexpr size_t num_outputs = OpSchemaDef::num_outputs(); + private: static_assert(details::has_parameter_names_defined<OpSchemaDef>::value, "Operator schema doesn't define parameter_names member."); // TODO Allow simpler definition of parameter_names without having to spell out the guts::array type in the schema def. @@ -165,7 +167,7 @@ class OpDispatchKeySchema<OpSchemaDef, guts::enable_if_t<!has_function_dispatch_ // TODO Use an ADL-based debugString(DispatchKey) function instead of operator<< for debug printing. public: - using dispatch_key_type = DispatchKey<signature::num_tensor_args>; + using dispatch_key_type = DispatchKey<OpSchemaDef::num_dispatch_args()>; template<class... Args> static inline dispatch_key_type dispatch_key(const Args&... args) { @@ -176,7 +178,7 @@ public: map_t<guts::remove_cv_t, map_t<guts::remove_reference_t, typename signature::parameter_types>> >::value, "Invalid argument types passed to OpSchema::dispatch_key()"); return dispatch_key_type { - details::getTensorTypeIds_(args...) + details::getDispatchTypeIds_<OpSchemaDef::num_dispatch_args()>(args...) }; } }; diff --git a/c10/test/dispatch/OpSchema_test.cpp b/c10/test/dispatch/OpSchema_test.cpp index 3a56ff9fc5..d10f7fc223 100644 --- a/c10/test/dispatch/OpSchema_test.cpp +++ b/c10/test/dispatch/OpSchema_test.cpp @@ -1,26 +1,27 @@ -#include "c10/core/dispatch/OpSchema.h" +#include <c10/core/dispatch/OpSchema.h> #include <c10/util/Array.h> using namespace c10; -using namespace caffe2; -static_assert(details::is_tensor_arg<Tensor>::value, ""); -static_assert(details::is_tensor_arg<const Tensor&>::value, ""); -static_assert(details::is_tensor_arg<Tensor&&>::value, ""); +static_assert(details::is_tensor_arg<C10Tensor>::value, ""); +static_assert(details::is_tensor_arg<const C10Tensor&>::value, ""); +static_assert(details::is_tensor_arg<C10Tensor&&>::value, ""); static_assert(!details::is_tensor_arg<int>::value, ""); struct SchemaDef final { - using Signature = bool(int, Tensor, float, Tensor, Tensor, unsigned int); + using Signature = bool(int, C10Tensor, float, C10Tensor, C10Tensor, unsigned int); static constexpr guts::array<const char*, 6> parameter_names = {{ "1", "2", "3", "4", "5", "6" }}; + static constexpr size_t num_dispatch_args() {return 3;} + static constexpr size_t num_outputs() {return 0;} }; -static_assert(6 == OpSchema<SchemaDef>::signature::num_args, "test num_dispatch_args"); -static_assert(3 == OpSchema<SchemaDef>::signature::num_tensor_args, "test num_dispatch_args"); -static_assert(std::is_same<bool, typename OpSchema<SchemaDef>::signature::return_type>::value, "test num_dispatch_args"); +static_assert(6 == OpSchema<SchemaDef>::signature::num_args, ""); +static_assert(3 == OpSchema<SchemaDef>::signature::num_tensor_args, ""); +static_assert(std::is_same<bool, typename OpSchema<SchemaDef>::signature::return_type>::value, ""); static_assert( std::is_same< guts::typelist:: - typelist<int, Tensor, float, Tensor, Tensor, unsigned int>, + typelist<int, C10Tensor, float, C10Tensor, C10Tensor, unsigned int>, typename OpSchema<SchemaDef>::signature::parameter_types>::value, - "test num_dispatch_args"); + ""); |