diff options
author | Sebastian Messmer <messmer@fb.com> | 2019-04-22 16:16:30 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-22 16:31:28 -0700 |
commit | 969af4315a30e96205e125a16e67bd6e3c03e218 (patch) | |
tree | 20e9236a36fc2d484913107ebbe69accf9c1a542 | |
parent | 8abab61d396c8ced5fefbd1c8c804edb596662ec (diff) | |
download | pytorch-969af4315a30e96205e125a16e67bd6e3c03e218.tar.gz pytorch-969af4315a30e96205e125a16e67bd6e3c03e218.tar.bz2 pytorch-969af4315a30e96205e125a16e67bd6e3c03e218.zip |
Explicitly define supported types (#19516)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19516
Explicitly define types that are supported in kernel inputs and outputs.
Also, this allows us to show much nicer error messages if a user writes kernels with wrong argument types.
Reviewed By: ezyang
Differential Revision: D15020306
fbshipit-source-id: 55ebec81e075e874777acd59aa29a5578fc19ef7
-rw-r--r-- | aten/src/ATen/core/op_registration/kernel_functor.h | 92 | ||||
-rw-r--r-- | aten/src/ATen/core/op_registration/op_registration_test.cpp | 5 | ||||
-rw-r--r-- | c10/test/util/TypeList_test.cpp | 7 | ||||
-rw-r--r-- | c10/util/TypeList.h | 17 |
4 files changed, 107 insertions, 14 deletions
diff --git a/aten/src/ATen/core/op_registration/kernel_functor.h b/aten/src/ATen/core/op_registration/kernel_functor.h index d589362817..52cf47b1e2 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor.h +++ b/aten/src/ATen/core/op_registration/kernel_functor.h @@ -25,13 +25,29 @@ namespace c10 { class OperatorKernel : public KernelCache {}; namespace detail { + // supported_primitive_arg_types defines which primitive types we allow in + // kernel functions as arguments or returns. + // Additionally, we support lists, dicts and optionals containing these types. + using supported_primitive_arg_types = guts::typelist::typelist< + int64_t, + double, + bool, + std::string, + at::Tensor, + at::Scalar + >; + // ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and // cast it to the type that should be passed to the kernel function. // Examples: If the IValue contains a plain type like an int, return that. // If the IValue contains an IntList, return it as ArrayRef<int>. // TODO Should we move the IValue so we can avoid bumping the Tensor refcount? + template<class T, class Enable = void> struct ivalue_to_arg_type { + // This base case is hit whenever a type does not have a specialisation below. + static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type."); + }; template<class T> - struct ivalue_to_arg_type { + struct ivalue_to_arg_type<T, guts::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> { static T call(const IValue& v) { return std::move(v).to<T>(); } @@ -39,30 +55,51 @@ namespace detail { template<class T> struct ivalue_to_arg_type<ArrayRef<T>> { static ArrayRef<T> call(const IValue& v) { + // TODO Do we want to support ArrayRef<optional<T>> ? + static_assert(guts::typelist::contains<supported_primitive_arg_types, T>::value, "You tried to register a kernel with an unsupported argument type: c10::ArrayRef<T> and T is not one of the supported primitive types."); return v.to<intrusive_ptr<ivalue::List<T>>>()->elements(); } }; template<class T> - struct ivalue_to_arg_type<std::vector<T>> { - static ArrayRef<T> call(const IValue& v) { - // We don't support std::vector because that would prevent us from doing - // internal optimization to how we represent lists (e.g. SmallVector). - // Users should use ArrayRef instead. - static_assert(guts::false_t<std::vector<T>>::value, "You tried to register a kernel with an unsupported argument type: std::vector<T>. Please use c10::ArrayRef<T> instead."); - } - }; - template<class T> struct ivalue_to_arg_type<optional<T>> { static optional<T> call(const IValue& v) { if (v.isNone()) { return nullopt; } - return v.to<T>(); + return ivalue_to_arg_type<T>::call(v); } }; - + // The following specialisations of ivalue_to_arg_type are technically not + // necessary since we would hit the base case and show an error message + // there if they didn't exist, but we can show a better error message + // in some common error scenarios. template<class T> + struct ivalue_to_arg_type<std::vector<T>> { + // We don't support std::vector because that would prevent us from doing + // internal optimization to how we represent lists (e.g. SmallVector). + // Users should use ArrayRef instead. + static_assert(guts::false_t<std::vector<T>>::value, "You tried to register a kernel with an unsupported argument type: std::vector<T>. Please use c10::ArrayRef<T> instead."); + }; + template<class T> + struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_same<float, T>::value>> { + // There is no reason to support float when we have double. Keep the API lean. + static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type: float. Please use double instead."); + }; + template<class T> + struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_same<const char*, T>::value>> { + static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type: const char*. Please use std::string instead."); + }; + template<class T> + struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> { + static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported integral argument type. Please use int64_t instead."); + }; + + template<class T, class Enable = void> struct return_type_to_ivalue_ { + static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type."); + }; + template<class T> + struct return_type_to_ivalue_<T, guts::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> { static IValue call(T&& v) { return IValue(std::move(v)); } @@ -73,9 +110,38 @@ namespace detail { if (!v.has_value()) { return IValue(); } - return IValue(std::move(*v)); + return return_type_to_ivalue_<T>::call(std::move(*v)); + } + }; + template<class T> + struct return_type_to_ivalue_<std::vector<T>> { + static IValue call(std::vector<T>&& v) { + // TODO Do we want to support vector<optional<T>> ? + static_assert(guts::typelist::contains<supported_primitive_arg_types, T>::value, "You tried to register a kernel with an unsupported return type: vector<T> and T is not one of the supported primitive types."); + return IValue(std::move(v)); } }; + // The following specialisations of return_type_to_ivalue_ are technically not + // necessary since we would hit the base case and show an error message + // there if they didn't exist, but we can show a better error message + // in some common error scenarios. + template<class T> + struct return_type_to_ivalue_<c10::ArrayRef<T>> { + static_assert(guts::false_t<c10::ArrayRef<T>>::value, "You tried to register a kernel with an unsupported return type: c10::ArrayRef<T>. Please use std::vector<T> instead."); + }; + template<class T> + struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_same<float, T>::value>> { + static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type: float. Please use double instead."); + }; + template<class T> + struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_same<const char*, T>::value>> { + static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type: const char*. Please use std::string instead."); + }; + template<class T> + struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> { + static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported integral return argument type. Please use int64_t instead."); + }; + template<class T> IValue return_type_to_ivalue(T&& v) { return return_type_to_ivalue_<guts::decay_t<T>>::call(std::move(v)); diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 615555bf6c..cef5c2ab79 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -265,6 +265,8 @@ private: }; TEST(OperatorRegistrationTest, testAvailableArgTypes) { + // TODO Test Scalar + // primitive types ArgTypeTestKernel<double>::test( 1.5, [] (const double& v) {EXPECT_EQ(1.5, v);}, @@ -481,8 +483,9 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { }, "(Tensor[] a) -> Tensor[]"); + // TODO We support optional of list. Add test cases for it. - // TODO Do we want to support list of optional / optional of list ? + // TODO Do we want to support list of optional ? // TODO Add tests for dict types } diff --git a/c10/test/util/TypeList_test.cpp b/c10/test/util/TypeList_test.cpp index dd811b5771..9350c7a364 100644 --- a/c10/test/util/TypeList_test.cpp +++ b/c10/test/util/TypeList_test.cpp @@ -145,3 +145,10 @@ namespace test_find_if { static_assert(2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value, ""); static_assert(3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value, ""); } + +namespace test_contains { + static_assert(contains<typelist<double>, double>::value, ""); + static_assert(contains<typelist<int, double>, double>::value, ""); + static_assert(!contains<typelist<int, double>, float>::value, ""); + static_assert(!contains<typelist<>, double>::value, ""); +} diff --git a/c10/util/TypeList.h b/c10/util/TypeList.h index bb0d6bd7c6..373e669cc1 100644 --- a/c10/util/TypeList.h +++ b/c10/util/TypeList.h @@ -132,6 +132,23 @@ struct count_if final { }; +/** + * Checks if a typelist contains a certain type. + * Examples: + * contains<typelist<int, string>, string> == true_type + * contains<typelist<int, string>, double> == false_type + */ +namespace detail { +template<class TypeList, class Type, class Enable = void> struct contains {}; +template<class Type> struct contains<typelist<>, Type, void> : std::false_type {}; +template<class Type, class Head, class... Tail> +struct contains<typelist<Head, Tail...>, Type, guts::enable_if_t<std::is_same<Head, Type>::value>> : std::true_type {}; +template<class Type, class Head, class... Tail> +struct contains<typelist<Head, Tail...>, Type, guts::enable_if_t<!std::is_same<Head, Type>::value>> : contains<typelist<Tail...>, Type> {}; +} +template<class TypeList, class Type> +using contains = typename detail::contains<TypeList, Type>::type; + /** * Returns true iff the type trait is true for all types in the type list |