summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Messmer <messmer@fb.com>2019-04-22 16:16:30 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-22 16:31:28 -0700
commit969af4315a30e96205e125a16e67bd6e3c03e218 (patch)
tree20e9236a36fc2d484913107ebbe69accf9c1a542
parent8abab61d396c8ced5fefbd1c8c804edb596662ec (diff)
downloadpytorch-969af4315a30e96205e125a16e67bd6e3c03e218.tar.gz
pytorch-969af4315a30e96205e125a16e67bd6e3c03e218.tar.bz2
pytorch-969af4315a30e96205e125a16e67bd6e3c03e218.zip
Explicitly define supported types (#19516)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19516 Explicitly define types that are supported in kernel inputs and outputs. Also, this allows us to show much nicer error messages if a user writes kernels with wrong argument types. Reviewed By: ezyang Differential Revision: D15020306 fbshipit-source-id: 55ebec81e075e874777acd59aa29a5578fc19ef7
-rw-r--r--aten/src/ATen/core/op_registration/kernel_functor.h92
-rw-r--r--aten/src/ATen/core/op_registration/op_registration_test.cpp5
-rw-r--r--c10/test/util/TypeList_test.cpp7
-rw-r--r--c10/util/TypeList.h17
4 files changed, 107 insertions, 14 deletions
diff --git a/aten/src/ATen/core/op_registration/kernel_functor.h b/aten/src/ATen/core/op_registration/kernel_functor.h
index d589362817..52cf47b1e2 100644
--- a/aten/src/ATen/core/op_registration/kernel_functor.h
+++ b/aten/src/ATen/core/op_registration/kernel_functor.h
@@ -25,13 +25,29 @@ namespace c10 {
class OperatorKernel : public KernelCache {};
namespace detail {
+ // supported_primitive_arg_types defines which primitive types we allow in
+ // kernel functions as arguments or returns.
+ // Additionally, we support lists, dicts and optionals containing these types.
+ using supported_primitive_arg_types = guts::typelist::typelist<
+ int64_t,
+ double,
+ bool,
+ std::string,
+ at::Tensor,
+ at::Scalar
+ >;
+
// ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and
// cast it to the type that should be passed to the kernel function.
// Examples: If the IValue contains a plain type like an int, return that.
// If the IValue contains an IntList, return it as ArrayRef<int>.
// TODO Should we move the IValue so we can avoid bumping the Tensor refcount?
+ template<class T, class Enable = void> struct ivalue_to_arg_type {
+ // This base case is hit whenever a type does not have a specialisation below.
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type.");
+ };
template<class T>
- struct ivalue_to_arg_type {
+ struct ivalue_to_arg_type<T, guts::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
static T call(const IValue& v) {
return std::move(v).to<T>();
}
@@ -39,30 +55,51 @@ namespace detail {
template<class T>
struct ivalue_to_arg_type<ArrayRef<T>> {
static ArrayRef<T> call(const IValue& v) {
+ // TODO Do we want to support ArrayRef<optional<T>> ?
+ static_assert(guts::typelist::contains<supported_primitive_arg_types, T>::value, "You tried to register a kernel with an unsupported argument type: c10::ArrayRef<T> and T is not one of the supported primitive types.");
return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
}
};
template<class T>
- struct ivalue_to_arg_type<std::vector<T>> {
- static ArrayRef<T> call(const IValue& v) {
- // We don't support std::vector because that would prevent us from doing
- // internal optimization to how we represent lists (e.g. SmallVector).
- // Users should use ArrayRef instead.
- static_assert(guts::false_t<std::vector<T>>::value, "You tried to register a kernel with an unsupported argument type: std::vector<T>. Please use c10::ArrayRef<T> instead.");
- }
- };
- template<class T>
struct ivalue_to_arg_type<optional<T>> {
static optional<T> call(const IValue& v) {
if (v.isNone()) {
return nullopt;
}
- return v.to<T>();
+ return ivalue_to_arg_type<T>::call(v);
}
};
-
+ // The following specialisations of ivalue_to_arg_type are technically not
+ // necessary since we would hit the base case and show an error message
+ // there if they didn't exist, but we can show a better error message
+ // in some common error scenarios.
template<class T>
+ struct ivalue_to_arg_type<std::vector<T>> {
+ // We don't support std::vector because that would prevent us from doing
+ // internal optimization to how we represent lists (e.g. SmallVector).
+ // Users should use ArrayRef instead.
+ static_assert(guts::false_t<std::vector<T>>::value, "You tried to register a kernel with an unsupported argument type: std::vector<T>. Please use c10::ArrayRef<T> instead.");
+ };
+ template<class T>
+ struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_same<float, T>::value>> {
+ // There is no reason to support float when we have double. Keep the API lean.
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type: float. Please use double instead.");
+ };
+ template<class T>
+ struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_same<const char*, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type: const char*. Please use std::string instead.");
+ };
+ template<class T>
+ struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported integral argument type. Please use int64_t instead.");
+ };
+
+ template<class T, class Enable = void>
struct return_type_to_ivalue_ {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type.");
+ };
+ template<class T>
+ struct return_type_to_ivalue_<T, guts::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
static IValue call(T&& v) {
return IValue(std::move(v));
}
@@ -73,9 +110,38 @@ namespace detail {
if (!v.has_value()) {
return IValue();
}
- return IValue(std::move(*v));
+ return return_type_to_ivalue_<T>::call(std::move(*v));
+ }
+ };
+ template<class T>
+ struct return_type_to_ivalue_<std::vector<T>> {
+ static IValue call(std::vector<T>&& v) {
+ // TODO Do we want to support vector<optional<T>> ?
+ static_assert(guts::typelist::contains<supported_primitive_arg_types, T>::value, "You tried to register a kernel with an unsupported return type: vector<T> and T is not one of the supported primitive types.");
+ return IValue(std::move(v));
}
};
+ // The following specialisations of return_type_to_ivalue_ are technically not
+ // necessary since we would hit the base case and show an error message
+ // there if they didn't exist, but we can show a better error message
+ // in some common error scenarios.
+ template<class T>
+ struct return_type_to_ivalue_<c10::ArrayRef<T>> {
+ static_assert(guts::false_t<c10::ArrayRef<T>>::value, "You tried to register a kernel with an unsupported return type: c10::ArrayRef<T>. Please use std::vector<T> instead.");
+ };
+ template<class T>
+ struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_same<float, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type: float. Please use double instead.");
+ };
+ template<class T>
+ struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_same<const char*, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type: const char*. Please use std::string instead.");
+ };
+ template<class T>
+ struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported integral return argument type. Please use int64_t instead.");
+ };
+
template<class T>
IValue return_type_to_ivalue(T&& v) {
return return_type_to_ivalue_<guts::decay_t<T>>::call(std::move(v));
diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp
index 615555bf6c..cef5c2ab79 100644
--- a/aten/src/ATen/core/op_registration/op_registration_test.cpp
+++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -265,6 +265,8 @@ private:
};
TEST(OperatorRegistrationTest, testAvailableArgTypes) {
+ // TODO Test Scalar
+
// primitive types
ArgTypeTestKernel<double>::test(
1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
@@ -481,8 +483,9 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
},
"(Tensor[] a) -> Tensor[]");
+ // TODO We support optional of list. Add test cases for it.
- // TODO Do we want to support list of optional / optional of list ?
+ // TODO Do we want to support list of optional ?
// TODO Add tests for dict types
}
diff --git a/c10/test/util/TypeList_test.cpp b/c10/test/util/TypeList_test.cpp
index dd811b5771..9350c7a364 100644
--- a/c10/test/util/TypeList_test.cpp
+++ b/c10/test/util/TypeList_test.cpp
@@ -145,3 +145,10 @@ namespace test_find_if {
static_assert(2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value, "");
static_assert(3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value, "");
}
+
+namespace test_contains {
+ static_assert(contains<typelist<double>, double>::value, "");
+ static_assert(contains<typelist<int, double>, double>::value, "");
+ static_assert(!contains<typelist<int, double>, float>::value, "");
+ static_assert(!contains<typelist<>, double>::value, "");
+}
diff --git a/c10/util/TypeList.h b/c10/util/TypeList.h
index bb0d6bd7c6..373e669cc1 100644
--- a/c10/util/TypeList.h
+++ b/c10/util/TypeList.h
@@ -132,6 +132,23 @@ struct count_if final {
};
+/**
+ * Checks if a typelist contains a certain type.
+ * Examples:
+ * contains<typelist<int, string>, string> == true_type
+ * contains<typelist<int, string>, double> == false_type
+ */
+namespace detail {
+template<class TypeList, class Type, class Enable = void> struct contains {};
+template<class Type> struct contains<typelist<>, Type, void> : std::false_type {};
+template<class Type, class Head, class... Tail>
+struct contains<typelist<Head, Tail...>, Type, guts::enable_if_t<std::is_same<Head, Type>::value>> : std::true_type {};
+template<class Type, class Head, class... Tail>
+struct contains<typelist<Head, Tail...>, Type, guts::enable_if_t<!std::is_same<Head, Type>::value>> : contains<typelist<Tail...>, Type> {};
+}
+template<class TypeList, class Type>
+using contains = typename detail::contains<TypeList, Type>::type;
+
/**
* Returns true iff the type trait is true for all types in the type list