diff options
author | Bram Wasti <bwasti@fb.com> | 2018-11-07 18:09:33 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-11-07 18:11:29 -0800 |
commit | 16165875406fc5224592fb778c96613161b78dca (patch) | |
tree | 6f2f69f5d04adbe0c15ef21aca1b24607495ab78 | |
parent | 87b47ff850428a546bbcb0b5909a24f4445b5a1b (diff) | |
download | pytorch-16165875406fc5224592fb778c96613161b78dca.tar.gz pytorch-16165875406fc5224592fb778c96613161b78dca.tar.bz2 pytorch-16165875406fc5224592fb778c96613161b78dca.zip |
Redo jit/type and utils/functional to ATen/core (#13455)
Summary:
This is a redo of the previous move which broke OS X and Windows tests -- RTTI seemed to be broken
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13455
Differential Revision: D12883775
Pulled By: bwasti
fbshipit-source-id: 2b6c65e8150e6f89624c6ee99c389335c6fb4bb8
-rw-r--r-- | aten/src/ATen/core/functional.h | 63 | ||||
-rw-r--r-- | aten/src/ATen/core/jit_type.h | 933 | ||||
-rw-r--r-- | aten/src/ATen/core/type.cpp (renamed from torch/csrc/jit/type.cpp) | 12 | ||||
-rw-r--r-- | torch/CMakeLists.txt | 1 | ||||
-rw-r--r-- | torch/csrc/autograd/python_variable_indexing.cpp | 4 | ||||
-rw-r--r-- | torch/csrc/jit/function_schema.h | 2 | ||||
-rw-r--r-- | torch/csrc/jit/python_ir.cpp | 5 | ||||
-rw-r--r-- | torch/csrc/jit/type.h | 938 | ||||
-rw-r--r-- | torch/csrc/utils/functional.h | 63 |
9 files changed, 1026 insertions, 995 deletions
diff --git a/aten/src/ATen/core/functional.h b/aten/src/ATen/core/functional.h new file mode 100644 index 0000000000..e0f8c84253 --- /dev/null +++ b/aten/src/ATen/core/functional.h @@ -0,0 +1,63 @@ +#pragma once + +#include <vector> +#include <ATen/core/ArrayRef.h> + +namespace c10 { + +// The passed in function must take T by value (T), or by +// const reference (const T&); taking T by non-const reference +// will result in an error like: +// +// error: no type named 'type' in 'class std::result_of<foobar::__lambda(T)>' +// +// No explicit template parameters are required. + +// Overload for explicit function and ArrayRef +template<typename F, typename T> +inline auto fmap(const T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> { + std::vector<decltype(fn(*inputs.begin()))> r; + r.reserve(inputs.size()); + for(const auto & input : inputs) + r.push_back(fn(input)); + return r; +} + +template<typename F, typename T> +inline auto fmap(T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> { + std::vector<decltype(fn(*inputs.begin()))> r; + r.reserve(inputs.size()); + for(auto & input : inputs) + r.push_back(fn(input)); + return r; +} + +// C++ forbids taking an address of a constructor, so here's a workaround... +// Overload for constructor (R) application +template<typename R, typename T> +inline std::vector<R> fmap(const T& inputs) { + std::vector<R> r; + r.reserve(inputs.size()); + for(auto & input : inputs) + r.push_back(R(input)); + return r; +} + +template<typename F, typename T> +inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) { + std::vector<T> r; + r.reserve(inputs.size()); + for(auto & input : inputs) { + if (fn(input)) { + r.push_back(input); + } + } + return r; +} + +template<typename F, typename T> +inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) { + return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn); +} + +} // namespace c10 diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h new file mode 100644 index 0000000000..9307b4fee4 --- /dev/null +++ b/aten/src/ATen/core/jit_type.h @@ -0,0 +1,933 @@ +#pragma once + +#include <ATen/core/ivalue.h> +#include <ATen/core/interned_strings.h> +#include <ATen/core/functional.h> +#include <ATen/core/Type.h> +#include <ATen/core/TensorMethods.h> + +#include <caffe2/core/common.h> + +#include <memory> +#include <iostream> +#include <type_traits> + +namespace c10 { + +#define C10_FORALL_TYPES(_) \ +_(DynamicType) \ +_(TensorType) \ +_(CompleteTensorType) \ +_(UndefinedTensorType) \ +_(TupleType) \ +_(ListType) \ +_(NumberType) \ +_(FloatType) \ +_(FutureType) \ +_(IntType) \ +_(NoneType) \ +_(StringType) \ +_(GeneratorType) \ +_(BoolType) \ +_(OptionalType) \ +_(VarType) \ + +enum class TypeKind { +#define DEFINE_TYPE(T) T, + C10_FORALL_TYPES(DEFINE_TYPE) +#undef DEFINE_TYPE +}; + +#define DEFINE_IS_SUBCLASS(_kind) \ + bool isSubclass(const TypeKind kind) const override { \ + return kind == TypeKind::_kind; \ + } + +struct Type; +using TypePtr = std::shared_ptr<Type>; + +struct CAFFE2_API Type : std::enable_shared_from_this<Type> { +private: + TypeKind kind_; + template<typename T> + static std::shared_ptr<T> sliceType(std::shared_ptr<const T> ptr) { + auto result = std::make_shared<typename std::remove_const<T>::type>(*ptr); + // XXX: the line above will correctly slice the struct, and make its runtype + // type exactly equal to T. However, kind_ is a field of Type, so it will simply + // be copied, and we need to fix it in here to match the dynamic type. + result->kind_ = T::Kind; + return result; + } + +protected: + Type(TypeKind kind) + : kind_(kind) {} + +public: + virtual bool operator==(const Type& rhs) const = 0; + + // subtyping relation. By default, we return true for the case + // when the type is exactly equal + virtual bool isSubtypeOf(const TypePtr rhs) const { + return *this == *rhs; + } + + // If this class can be cast to the kind passed in + // This removes the need for RTTI + virtual bool isSubclass(const TypeKind kind) const = 0; + + // How this type will appear in FunctionSchema declarations + virtual std::string str() const = 0; + + // How this type will appear as if it were a type annotation in Python + // which is sometimes different than how it appears in declarations (e.g. int[] vs List[int]) + virtual std::string python_str() const { + return str(); + } + + TypeKind kind() const { + return kind_; + } + + virtual bool requires_grad() const { return false; } + + // Dynamically cast this object to the subclass indicated by the + // template variable, returning nullptr if the cast is invalid. + // NOTE: if the cast succeeds, but the casted kind is not the + // run-time kind of the type, we also slice the structure, so + // that assignments of those types to values don't accidentally + // inherit more detailed information from subclasses. + template<typename T> + std::shared_ptr<T> cast() { + std::shared_ptr<T> r = nullptr; + if (isSubclass(T::Kind)) { + r = std::static_pointer_cast<T>(shared_from_this()); + } + if (!r || T::Kind == kind()) { + return r; + } else { + return sliceType<T>(r); + } + } + template<typename T> + std::shared_ptr<const T> cast() const { + std::shared_ptr<const T> r = nullptr; + if (isSubclass(T::Kind)) { + r = std::static_pointer_cast<const T>(shared_from_this()); + } + if (!r || T::Kind == kind()) { + return r; + } else { + return sliceType<T>(r); + } + } + template<typename T> + std::shared_ptr<T> expect() { + auto r = cast<T>(); + AT_ASSERT(r); + return r; + } + template<typename T> + std::shared_ptr<const T> expect() const { + auto r = cast<const T>(); + AT_ASSERT(r); + return r; + } + virtual ~Type() = default; + virtual bool hasFreeVariables() const { + return false; + } + // list of types this type contains, e.g. for a List then element type of a list + // for a tuple, the types of the tuple elements + virtual at::ArrayRef<TypePtr> containedTypes() const { + return {}; + } + // create a new version of this type, replacing its contained types with + // contained_types + TypePtr withContained(std::vector<TypePtr> contained_types) { + auto current_contained = containedTypes(); + AT_ASSERT(current_contained.size() == contained_types.size()); + if(current_contained.equals(contained_types)) { + return shared_from_this(); + } + return createWithContained(std::move(contained_types)); + } + // per-type constructor, you only need to override this if the containedTypes() + // is not empty + virtual TypePtr createWithContained(std::vector<TypePtr> contained_types) const { + AT_ERROR("type with contained types did not overload createWithContained: ", str()); + } +}; + +inline bool operator!=(const Type & lhs, const Type & rhs) { + return !(lhs == rhs); +} + +struct OptionalType; +using OptionalTypePtr = std::shared_ptr<OptionalType>; +// This type represents an optional type, for each element type. +struct OptionalType: public Type { + static OptionalTypePtr create(TypePtr element) { + return OptionalTypePtr(new OptionalType(std::move(element))); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(OptionalType); + bool operator==(const Type& rhs) const override { + if(auto rhs_ = rhs.cast<OptionalType>()) { + return *getElementType() == *rhs_->getElementType(); + } + return false; + } + bool requires_grad() const override { + return elem->requires_grad(); + } + + bool isSubtypeOf(const TypePtr rhs) const override { + if(auto rhs_ = rhs->cast<OptionalType>()) { + return getElementType()->isSubtypeOf(rhs_->getElementType()); + } + return false; + } + + std::string str() const override { + std::stringstream ss; + ss << getElementType()->str() << "?"; + return ss.str(); + } + std::string python_str() const override { + std::stringstream ss; + ss << "Optional[" << getElementType()->python_str() << "]"; + return ss.str(); + } + TypePtr getElementType() const { + return elem; + } + bool hasFreeVariables() const override { + return has_free_variables_; + } + + static const TypeKind Kind = TypeKind::OptionalType; +private: + OptionalType(TypePtr elem) + : Type(TypeKind::OptionalType) + , elem(std::move(elem)) + , has_free_variables_(getElementType()->hasFreeVariables()) {} + TypePtr elem; + bool has_free_variables_; + +}; + +struct DynamicType; +using DynamicTypePtr = std::shared_ptr<DynamicType>; +// This type represents a single Tensor, with an unknown shape. +struct CAFFE2_API DynamicType : public Type { + static DynamicTypePtr create() { + return DynamicTypePtr(new DynamicType()); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(DynamicType); + + bool requires_grad() const override { return true; } + + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Tensor"; + } + static const TypeKind Kind = TypeKind::DynamicType; + // global singleton + static DynamicTypePtr get(); +private: + DynamicType() + : Type(TypeKind::DynamicType) {} +}; + +struct UndefinedTensorType; +using UndefinedTensorTypePtr = std::shared_ptr<UndefinedTensorType>; +// This type represents an undefined tensor. +struct CAFFE2_API UndefinedTensorType : public Type { + static const TypeKind Kind = TypeKind::UndefinedTensorType; + static UndefinedTensorTypePtr create() { + return UndefinedTensorTypePtr(new UndefinedTensorType()); // NOLINT(modernize-make-shared) + } + + DEFINE_IS_SUBCLASS(UndefinedTensorType); + + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + bool isSubtypeOf(const TypePtr rhs) const override { + return rhs->kind() == TypeKind::DynamicType || + rhs->kind() == TypeKind::UndefinedTensorType; + } + std::string str() const override { + return "UndefinedTensor"; + } + static UndefinedTensorTypePtr get(); +protected: + UndefinedTensorType(): Type(TypeKind::UndefinedTensorType) {} +}; + +struct TensorType; +using TensorTypePtr = std::shared_ptr<TensorType>; +// This type represents a single Tensor with a specific size +struct CAFFE2_API TensorType : public Type { + static const TypeKind Kind = TypeKind::TensorType; + template<typename ... T> + static TensorTypePtr create( T&& ... all ) { + return TensorTypePtr(new TensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared) + } + + at::ScalarType scalarType() const { return scalar_type_; } + int device() const { return device_; } + int dim() const { return dim_; } + bool requires_grad() const override { return requires_grad_; } + + TensorTypePtr toScalarType(at::ScalarType type){ + auto t = TensorType::create(*this); + t->scalar_type_ = type; + return t; + } + TensorTypePtr withDim(int new_dim) { + auto t = TensorType::create(*this); + t->dim_ = new_dim; + return t; + } + TensorTypePtr withRequiresGrad(bool req) { + auto t = TensorType::create(*this); + t->requires_grad_ = req; + return t; + } + + bool operator==(const Type& rhs) const override { + if (rhs.kind() != TypeKind::TensorType) + return false; + auto rt = rhs.expect<TensorType>(); + return scalarType() == rt->scalarType() && + device() == rt->device() && + dim() == rt->dim(); + } + bool isSubtypeOf(const TypePtr rhs) const override { + if (rhs->kind() == TypeKind::DynamicType) + return true; + return rhs->kind() == TypeKind::TensorType && *this == *rhs; + } + bool isSubclass(const TypeKind kind) const override { + return kind == TypeKind::DynamicType || + kind == TypeKind::TensorType; + } + std::string str() const override { + // str is used for user-facing error messages, where we + // don't want to reveal underlying size information. + return "Tensor"; + } + +protected: + TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType) + : TensorType(tensor.type().scalarType(), + tensor.is_cuda() ? tensor.get_device() : -1, + tensor.dim(), + tensor.is_variable() && tensor.requires_grad(), + kind) {} + TensorType(at::ScalarType scalar_type, int device, int dim, bool requires_grad=true, TypeKind kind=TypeKind::TensorType) + : Type(kind) + , scalar_type_(scalar_type) + , requires_grad_(at::isFloatingType(scalar_type) && requires_grad) + , device_(device) + , dim_(dim) {} + + at::ScalarType scalar_type_; + bool requires_grad_; + int device_; + int dim_; +}; + +struct CompleteTensorType; +using CompleteTensorTypePtr = std::shared_ptr<CompleteTensorType>; +// This type represents a single Tensor with a specific size +struct CAFFE2_API CompleteTensorType : public TensorType { + template<typename ... T> + static CompleteTensorTypePtr create( T&& ... all ) { + return CompleteTensorTypePtr(new CompleteTensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared) + } + + // overloaded create variadic template argument as it could not distinguish initializer list + static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes) { + return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes)); // NOLINT(modernize-make-shared) + } + static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) { + return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared) + } + + static const TypeKind Kind = TypeKind::CompleteTensorType; + + const std::vector<int64_t>& sizes() const { return sizes_; } + const std::vector<int64_t>& strides() const { return strides_; } + + TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const { + return CompleteTensorType::create(scalar_type_, device_, sizes, strides); + } + + TypePtr withSizes(at::IntList sizes) const { + return withSizesStrides(sizes, CompleteTensorType::contiguousStridesOf(sizes)); + } + + CompleteTensorTypePtr contiguous() const { + auto t = CompleteTensorType::create(*this); + t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_); + return t; + } + + CompleteTensorTypePtr toScalarType(at::ScalarType type){ + auto t = CompleteTensorType::create(*this); + t->scalar_type_ = type; + return t; + } + + bool operator==(const Type& rhs) const override { + if(rhs.kind() != kind()) + return false; + auto rt = rhs.expect<CompleteTensorType>(); + return scalarType() == rt->scalarType() && + sizes() == rt->sizes() && + strides() == rt->strides() && + device() == rt->device(); + } + bool isSubtypeOf(const TypePtr rhs) const override { + if (rhs->kind() == TypeKind::DynamicType) + return true; + if (rhs->kind() == TypeKind::TensorType) + return *expect<TensorType>() == *rhs; + return *this == *rhs; + } + bool isSubclass(const TypeKind kind) const override { + return kind == TypeKind::DynamicType || + kind == TypeKind::TensorType || + kind == TypeKind::CompleteTensorType; + } + std::string str() const override { + // str is used for user-facing error messages, where we + // don't want to reveal underlying size information. + return "Tensor"; + } + bool numel() const { + size_t prod = 1; + for(auto s : sizes()) { + prod *= s; + } + return prod; + } + static TypePtr fromNumberType(TypePtr typ); + static TypePtr fromBoolType(); + +private: + CompleteTensorType(const at::Tensor& tensor) + : TensorType(tensor, TypeKind::CompleteTensorType) + , sizes_(tensor.sizes().vec()) + , strides_(tensor.strides().vec()) {} + CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, bool requires_grad=true) + : CompleteTensorType(scalar_type, device, sizes, CompleteTensorType::contiguousStridesOf(sizes), requires_grad) {} + CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides, bool requires_grad=true) + : TensorType(scalar_type, device, sizes.size(), requires_grad, TypeKind::CompleteTensorType) + , sizes_(sizes.vec()) + , strides_(strides.vec()) {} + + static std::vector<int64_t> contiguousStridesOf(at::IntList sizes) { + std::vector<int64_t> strides(sizes.size()); + if(sizes.empty()) // zero-dim case + return strides; + strides.back() = 1; + for(size_t i = strides.size() - 1; i > 0; i--) { + strides[i-1] = strides[i] * sizes[i]; + } + return strides; + } + + std::vector<int64_t> sizes_; + std::vector<int64_t> strides_; +}; + +// common base for all types that have a single sub element +// e.g. Future[T], Option[T], List[T] +template<TypeKind K, typename T> +struct SingleElementType : public Type { + static const TypeKind Kind = K; + TypePtr getElementType() const { + return elem; + } + bool hasFreeVariables() const override { + return has_free_variables_; + } + at::ArrayRef<TypePtr> containedTypes() const override { + return elem; + } + bool requires_grad() const override { + return elem->requires_grad(); + } + bool operator==(const Type& rhs) const override { + if(auto rhs_ = rhs.cast<T>()) { + return *getElementType() == *rhs_->getElementType(); + } + return false; + } +protected: + SingleElementType(TypePtr elem) + : Type(Kind) + , elem(std::move(elem)) + , has_free_variables_(getElementType()->hasFreeVariables()) {} +private: + TypePtr elem; + bool has_free_variables_; +}; + +struct ListType; +using ListTypePtr = std::shared_ptr<ListType>; +struct CAFFE2_API ListType : public SingleElementType<TypeKind::ListType, ListType> { + // It's not exactly a singleton, but there should be exactly once instance of + // List[T] for every T + friend struct Type; + template<typename ... T> + static ListTypePtr create( T&& ... all ) { + return ListTypePtr(new ListType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(ListType); + std::string str() const override { + std::stringstream ss; + ss << getElementType()->str() << "[]"; + return ss.str(); + } + std::string python_str() const override { + std::stringstream ss; + ss << "List[" << getElementType()->python_str() << "]"; + return ss.str(); + } + TypePtr createWithContained(std::vector<TypePtr> contained_types) const override { + return create(contained_types.at(0)); + } + // common cast List[Tensor] + static ListTypePtr ofTensors(); + static ListTypePtr ofInts(); + static ListTypePtr ofFloats(); + static ListTypePtr ofBools(); +private: + using SingleElementType::SingleElementType; +}; + +struct FutureType; +using FutureTypePtr = std::shared_ptr<FutureType>; + +struct CAFFE2_API FutureType : public Type { + friend struct Type; + template<typename ... T> + static FutureTypePtr create(TypePtr elem) { + return FutureTypePtr(new FutureType(std::move(elem))); // NOLINT(modernize-make-shared) + } + + DEFINE_IS_SUBCLASS(FutureType); + + bool operator==(const Type& rhs) const override { + if (auto rhs_ = rhs.cast<FutureType>()) { + return *getElementType() == *rhs_->getElementType(); + } + return false; + } + bool requires_grad() const override { + return elem->requires_grad(); + } + std::string str() const override { + std::stringstream ss; + ss << "Future(" << getElementType()->str() << ")"; + return ss.str(); + } + std::string python_str() const override { + std::stringstream ss; + ss << "Future[" << getElementType()->python_str() << "]"; + return ss.str(); + } + TypePtr getElementType() const { + return elem; + } + bool hasFreeVariables() const override { + return has_free_variables_; + } + + static const TypeKind Kind = TypeKind::FutureType; +private: + FutureType(TypePtr elem) + : Type(TypeKind::FutureType) + , elem(std::move(elem)) + , has_free_variables_(getElementType()->hasFreeVariables()) {} + TypePtr elem; + bool has_free_variables_; +}; + +struct TupleType; +using TupleTypePtr = std::shared_ptr<TupleType>; +// This type represents a Tuple +struct CAFFE2_API TupleType : public Type { + static TupleTypePtr create(std::vector<TypePtr> types) { + return TupleTypePtr(new TupleType( std::move(types) )); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(TupleType); + at::ArrayRef<TypePtr> elements() const { + return elements_; + } + bool operator==(const Type& rhs) const override { + return compare(rhs, [](const TypePtr a, const TypePtr b) { + return *a == *b; + }); + } + bool isSubtypeOf(const TypePtr rhs) const override { + // co-variant rules for tuples + return compare(*rhs, [](const TypePtr a, const TypePtr b) { + return a->isSubtypeOf(b); + }); + } + bool requires_grad() const override { + return std::any_of(elements_.begin(), elements_.end(), + [](const TypePtr& ptr) { return ptr->requires_grad(); }); + } + std::string str() const override { + std::stringstream ss; + ss << "("; + for(size_t i = 0; i < elements().size(); ++i) { + if(i > 0) + ss << ", "; + ss << elements()[i]->str(); + } + ss << ")"; + return ss.str(); + } + std::string python_str() const override { + std::stringstream ss; + ss << "Tuple["; + for(size_t i = 0; i < elements().size(); ++i) { + if(i > 0) + ss << ", "; + ss << elements()[i]->python_str(); + } + ss << "]"; + return ss.str(); + } + bool hasFreeVariables() const override { + return has_free_variables_; + } + + at::ArrayRef<TypePtr> containedTypes() const override { + return elements_; + } + TypePtr createWithContained(std::vector<TypePtr> contained_types) const override { + return create(std::move(contained_types)); + } + + static const TypeKind Kind = TypeKind::TupleType; +private: + TupleType(std::vector<TypePtr> elements_) + : Type(TypeKind::TupleType) + , elements_(std::move(elements_)) { + has_free_variables_ = + std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) { + return v->hasFreeVariables(); + }); + } + + bool compare(const Type& rhs, std::function<bool(const TypePtr, const TypePtr)> fn) const { + if(rhs.kind() != kind()) + return false; + const auto & l_elements = elements(); + const auto & r_elements = rhs.cast<TupleType>()->elements(); + if(l_elements.size() != r_elements.size()) + return false; + for(size_t i = 0; i < l_elements.size(); ++i) { + if(!fn(l_elements[i], r_elements[i])) + return false; + } + return true; + } + std::vector<TypePtr> elements_; + bool has_free_variables_; +}; + +struct NumberType; +using NumberTypePtr = std::shared_ptr<NumberType>; +// This type represents a Python number +struct CAFFE2_API NumberType : public Type { + static NumberTypePtr create() { + return NumberTypePtr(new NumberType()); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(NumberType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Scalar"; // match what PythonArgParser says for clarity + } + static const TypeKind Kind = TypeKind::NumberType; + // global singleton + static NumberTypePtr get(); +private: + NumberType() + : Type(TypeKind::NumberType) {} +}; + +struct FloatType; +using FloatTypePtr = std::shared_ptr<FloatType>; +// This type represents a Python float number +struct CAFFE2_API FloatType : public Type { + static FloatTypePtr create() { + return FloatTypePtr(new FloatType()); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(FloatType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "float"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + if(auto rhs_ = rhs->cast<OptionalType>()) { + return this->isSubtypeOf(rhs_->getElementType()); + } + return *this == *rhs || rhs->kind() == TypeKind::NumberType; + } + static const TypeKind Kind = TypeKind::FloatType; + // global singleton + static FloatTypePtr get(); +private: + FloatType() + : Type(TypeKind::FloatType) {} +}; + +struct IntType; +using IntTypePtr = std::shared_ptr<IntType>; +// This type represents a Python int number +struct CAFFE2_API IntType : public Type { + static IntTypePtr create() { + return IntTypePtr(new IntType()); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(IntType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "int"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + if(auto rhs_ = rhs->cast<OptionalType>()) { + return this->isSubtypeOf(rhs_->getElementType()); + } + return *this == *rhs || rhs->kind() == TypeKind::NumberType; + } + static const TypeKind Kind = TypeKind::IntType; + // global singleton + static IntTypePtr get(); +private: + IntType() + : Type(TypeKind::IntType) {} +}; + +struct BoolType; +using BoolTypePtr = std::shared_ptr<BoolType>; +// This node represents a Python bool value +struct CAFFE2_API BoolType : public Type { + static BoolTypePtr create( ) { + return BoolTypePtr(new BoolType()); + } + DEFINE_IS_SUBCLASS(BoolType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "bool"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::BoolType; + } + static const TypeKind Kind = TypeKind::BoolType; + // global singleton + static BoolTypePtr get(); +private: + BoolType() + : Type(TypeKind::BoolType) {} +}; + +struct StringType; +using StringTypePtr = std::shared_ptr<StringType>; +// This type represents a Python string +struct CAFFE2_API StringType : public Type { + static StringTypePtr create() { + return StringTypePtr(new StringType()); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(StringType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "string"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + if(auto rhs_ = rhs->cast<OptionalType>()) { + return this->isSubtypeOf(rhs_->getElementType()); + } + return *this == *rhs; + } + static const TypeKind Kind = TypeKind::StringType; + // global singleton + static StringTypePtr get(); +private: + StringType() + : Type(TypeKind::StringType) {} +}; + +struct NoneType; +using NoneTypePtr = std::shared_ptr<NoneType>; +// This type represents a Python None +struct CAFFE2_API NoneType : public Type { + static NoneTypePtr create() { + return NoneTypePtr(new NoneType()); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(NoneType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + + bool isSubtypeOf(const TypePtr rhs) const override { + return rhs->kind() == TypeKind::NoneType || + rhs->kind() == TypeKind::OptionalType; + } + + std::string str() const override { + return "None"; + } + static const TypeKind Kind = TypeKind::NoneType; + // global singleton + static NoneTypePtr get(); +private: + NoneType() + : Type(TypeKind::NoneType) {} +}; + +struct GeneratorType; +using GeneratorTypePtr = std::shared_ptr<GeneratorType>; +// This type represents a Generator +struct CAFFE2_API GeneratorType : public Type { + static GeneratorTypePtr create() { + return GeneratorTypePtr(new GeneratorType()); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(GeneratorType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Generator"; + } + static const TypeKind Kind = TypeKind::GeneratorType; + // global singleton + static GeneratorTypePtr get(); +private: + GeneratorType() + : Type(TypeKind::GeneratorType) {} +}; + + +struct VarType; +using VarTypePtr = std::shared_ptr<VarType>; +// This type represents a type variable, used in FunctionSchema +struct VarType : public Type { + static VarTypePtr create(std::string name_) { + return VarTypePtr(new VarType(std::move(name_))); + } + DEFINE_IS_SUBCLASS(VarType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return name(); + } + static const TypeKind Kind = TypeKind::VarType; + const std::string& name() const { + return name_; + } + bool hasFreeVariables() const override { + return true; + } +private: + VarType(std::string name_) + : Type(TypeKind::VarType), name_(std::move(name_)) {} + std::string name_; +}; + +CAFFE2_API std::ostream& operator<<(std::ostream & out, const Type & t); +// what is the type, ignoring extra size/shape information? +// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) + +inline TypePtr unshapedType(const TypePtr& type) { + if (type->kind() == TypeKind::TensorType || + type->kind() == TypeKind::CompleteTensorType) { + return DynamicType::get(); + } + return type->withContained(fmap(type->containedTypes(), unshapedType)); +} + +inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) { + AT_ASSERT(typ->isSubtypeOf(NumberType::get())); + if (typ->isSubtypeOf(IntType::get())) { + return CompleteTensorType::create(at::kLong, -1, {}); + } else if (typ->isSubtypeOf(FloatType::get())) { + return CompleteTensorType::create(at::kFloat, -1, {}); + } else if (typ->isSubtypeOf(BoolType::get())) { + return CompleteTensorType::create(at::kLong, -1, {}); + } + AT_ERROR("unknown number type", typ->str()); +} + +inline TypePtr CompleteTensorType::fromBoolType() { + return CompleteTensorType::create(at::kLong, -1, {}); +} + +// Attempt to find the correct supertype of t1 and t2. If none is found then +// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2, +// then t2 will be returned (and vice versa). +// Two different tensortypes will return dynamic. +// Currently we chose not to support returning a NumberType for a float & int +// input because of a lack of operator support for NumberType +CAFFE2_API c10::optional<TypePtr> unifyTypes( + const TypePtr& t1, + const TypePtr& t2); + +template <typename T> +TypePtr getTypePtr() { +#define TYPE_STR(Type) #Type, " ", + AT_ERROR( + "Type ", + c10::demangle_type<T>(), + " could not be converted to any of the known types { ", + C10_FORALL_TYPES(TYPE_STR) "}"); +#undef TYPE_STR + return nullptr; +} + +template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); } +template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); } +template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); } +template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); } +template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); } +template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); } +template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); } +template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); } + +CAFFE2_API TypePtr inferTypeFrom(const IValue& value); + +struct CAFFE2_API TypeMatchError : public std::exception { + TypeMatchError(std::string msg_) + : msg_(std::move(msg_)) {} + const char * what() const noexcept override { + return msg_.c_str(); + } +private: + std::string msg_; +}; +using TypeEnv = std::unordered_map<std::string, TypePtr>; +CAFFE2_API TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv & type_env); +CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env); + +} // namespace c10 diff --git a/torch/csrc/jit/type.cpp b/aten/src/ATen/core/type.cpp index 3ae2879b31..77191fdd52 100644 --- a/torch/csrc/jit/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -1,17 +1,15 @@ -#include "torch/csrc/jit/type.h" - -#include "torch/csrc/jit/assertions.h" +#include <ATen/core/jit_type.h> #include <iostream> -namespace torch { namespace jit { +namespace c10 { std::ostream& operator<<(std::ostream & out, const Type & t) { if(auto value = t.cast<CompleteTensorType>()) { out << at::toString(value->scalarType()) << "("; auto& sizes = value->sizes(); auto& strides = value->strides(); - JIT_ASSERT(sizes.size() == strides.size()); + AT_ASSERT(sizes.size() == strides.size()); for (size_t i = 0; i < sizes.size(); i++) { if (i > 0) { out << ", "; @@ -260,7 +258,7 @@ TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) { } // change return types like List[List[t]] into List[List[int]] -TORCH_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map<std::string, TypePtr>& type_env) { +CAFFE2_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map<std::string, TypePtr>& type_env) { if(!type->hasFreeVariables()) return type; @@ -276,4 +274,4 @@ TORCH_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map<std::string } } -}} // namespace torch::jit +} // namespace c10 diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 3b879b37c9..ebebaa3388 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -193,7 +193,6 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp - ${TORCH_SRC_DIR}/csrc/jit/type.cpp ${TORCH_SRC_DIR}/csrc/torch.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp ${TORCH_SRC_DIR}/csrc/utils/variadic.cpp diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 40c2dc9d9d..0ea8ce0bc8 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -106,12 +106,12 @@ static Variable applySelect(const Variable& self, int64_t dim, int64_t index) { return self.select(dim, index); } -static Variable sequenceToVariable(const Type& type, PyObject* seq) { +static Variable sequenceToVariable(const at::Type& type, PyObject* seq) { auto& idx_type = type.toScalarType(kLong); return torch::utils::legacy_new_from_data(idx_type, c10::nullopt, seq); } -static Variable valueToTensor(const Type & type, PyObject* value) { +static Variable valueToTensor(const at::Type & type, PyObject* value) { if (THPVariable_Check(value)) { return reinterpret_cast<THPVariable*>(value)->cdata; } diff --git a/torch/csrc/jit/function_schema.h b/torch/csrc/jit/function_schema.h index 9351f272be..a02c1da32e 100644 --- a/torch/csrc/jit/function_schema.h +++ b/torch/csrc/jit/function_schema.h @@ -2,8 +2,10 @@ #include "ATen/ATen.h" #include "torch/csrc/jit/type.h" +#include "torch/csrc/jit/interned_strings.h" #include "torch/csrc/jit/ivalue.h" #include "torch/csrc/jit/alias_info.h" +#include "torch/csrc/jit/assertions.h" namespace torch { namespace jit { diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index fee47eb9d5..32ba16b35f 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -16,6 +16,8 @@ namespace torch { namespace jit { +using c10::Type; + std::string getPythonName(const PyObject* obj_) { AutoGIL gil; PyObject* obj = const_cast<PyObject*>(obj_); @@ -413,6 +415,7 @@ void initPythonIRBindings(PyObject * module_) { }) ; + using ::c10::Type; py::class_<Type,std::shared_ptr<Type>>(m,"Type") .def("__repr__",[](Type & t) { return t.python_str(); @@ -479,7 +482,7 @@ void initPythonIRBindings(PyObject * module_) { .def("isSubtypeOf", [](std::shared_ptr<Type>& self, std::shared_ptr<Type> other) { return self->isSubtypeOf(other); }) - .def_static("inferFrom", inferTypeFrom); + .def_static("inferFrom", c10::inferTypeFrom); py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType") .def_static("get", &NumberType::get); diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h index 8a21ad11d5..384957da61 100644 --- a/torch/csrc/jit/type.h +++ b/torch/csrc/jit/type.h @@ -1,933 +1,21 @@ -#pragma once - -#include "torch/csrc/jit/ivalue.h" -#include "torch/csrc/jit/assertions.h" -#include "torch/csrc/jit/interned_strings.h" -#include "torch/csrc/WindowsTorchApiMacro.h" -#include "torch/csrc/utils/functional.h" - -#include <ATen/ATen.h> - -#include <memory> -#include <iostream> -#include <type_traits> +#include <ATen/core/jit_type.h> namespace torch { namespace jit { -#define TH_FORALL_TYPES(_) \ -_(DynamicType) \ -_(TensorType) \ -_(CompleteTensorType) \ -_(UndefinedTensorType) \ -_(TupleType) \ -_(ListType) \ -_(NumberType) \ -_(FloatType) \ -_(FutureType) \ -_(IntType) \ -_(NoneType) \ -_(StringType) \ -_(GeneratorType) \ -_(BoolType) \ -_(OptionalType) \ -_(VarType) \ - -enum class TypeKind { -#define DEFINE_TYPE(T) T, - TH_FORALL_TYPES(DEFINE_TYPE) -#undef DEFINE_TYPE -}; - -#define DEFINE_IS_SUBCLASS(_kind) \ - bool isSubclass(const TypeKind kind) const override { \ - return kind == TypeKind::_kind; \ - } - -struct Type; -using TypePtr = std::shared_ptr<Type>; - -struct TORCH_API Type : std::enable_shared_from_this<Type> { -private: - TypeKind kind_; - template<typename T> - static std::shared_ptr<T> sliceType(std::shared_ptr<const T> ptr) { - auto result = std::make_shared<typename std::remove_const<T>::type>(*ptr); - // XXX: the line above will correctly slice the struct, and make its runtype - // type exactly equal to T. However, kind_ is a field of Type, so it will simply - // be copied, and we need to fix it in here to match the dynamic type. - result->kind_ = T::Kind; - return result; - } - -protected: - Type(TypeKind kind) - : kind_(kind) {} - -public: - virtual bool operator==(const Type& rhs) const = 0; - - // subtyping relation. By default, we return true for the case - // when the type is exactly equal - virtual bool isSubtypeOf(const TypePtr rhs) const { - return *this == *rhs; - } - - // If this class can be cast to the kind passed in - // This removes the need for RTTI - virtual bool isSubclass(const TypeKind kind) const = 0; - - // How this type will appear in FunctionSchema declarations - virtual std::string str() const = 0; - - // How this type will appear as if it were a type annotation in Python - // which is sometimes different than how it appears in declarations (e.g. int[] vs List[int]) - virtual std::string python_str() const { - return str(); - } - - TypeKind kind() const { - return kind_; - } - - virtual bool requires_grad() const { return false; } - - // Dynamically cast this object to the subclass indicated by the - // template variable, returning nullptr if the cast is invalid. - // NOTE: if the cast succeeds, but the casted kind is not the - // run-time kind of the type, we also slice the structure, so - // that assignments of those types to values don't accidentally - // inherit more detailed information from subclasses. - template<typename T> - std::shared_ptr<T> cast() { - std::shared_ptr<T> r = nullptr; - if (isSubclass(T::Kind)) { - r = std::static_pointer_cast<T>(shared_from_this()); - } - if (!r || T::Kind == kind()) { - return r; - } else { - return sliceType<T>(r); - } - } - template<typename T> - std::shared_ptr<const T> cast() const { - std::shared_ptr<const T> r = nullptr; - if (isSubclass(T::Kind)) { - r = std::static_pointer_cast<const T>(shared_from_this()); - } - if (!r || T::Kind == kind()) { - return r; - } else { - return sliceType<T>(r); - } - } - template<typename T> - std::shared_ptr<T> expect() { - auto r = cast<T>(); - JIT_ASSERT(r); - return r; - } - template<typename T> - std::shared_ptr<const T> expect() const { - auto r = cast<const T>(); - JIT_ASSERT(r); - return r; - } - virtual ~Type() = default; - virtual bool hasFreeVariables() const { - return false; - } - // list of types this type contains, e.g. for a List then element type of a list - // for a tuple, the types of the tuple elements - virtual at::ArrayRef<TypePtr> containedTypes() const { - return {}; - } - // create a new version of this type, replacing its contained types with - // contained_types - TypePtr withContained(std::vector<TypePtr> contained_types) { - auto current_contained = containedTypes(); - JIT_ASSERT(current_contained.size() == contained_types.size()); - if(current_contained.equals(contained_types)) { - return shared_from_this(); - } - return createWithContained(std::move(contained_types)); - } - // per-type constructor, you only need to override this if the containedTypes() - // is not empty - virtual TypePtr createWithContained(std::vector<TypePtr> contained_types) const { - AT_ERROR("type with contained types did not overload createWithContained: ", str()); - } -}; - -inline bool operator!=(const Type & lhs, const Type & rhs) { - return !(lhs == rhs); -} - -struct OptionalType; -using OptionalTypePtr = std::shared_ptr<OptionalType>; -// This type represents an optional type, for each element type. -struct OptionalType: public Type { - static OptionalTypePtr create(TypePtr element) { - return OptionalTypePtr(new OptionalType(std::move(element))); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(OptionalType); - bool operator==(const Type& rhs) const override { - if(auto rhs_ = rhs.cast<OptionalType>()) { - return *getElementType() == *rhs_->getElementType(); - } - return false; - } - bool requires_grad() const override { - return elem->requires_grad(); - } - - bool isSubtypeOf(const TypePtr rhs) const override { - if(auto rhs_ = rhs->cast<OptionalType>()) { - return getElementType()->isSubtypeOf(rhs_->getElementType()); - } - return false; - } - - std::string str() const override { - std::stringstream ss; - ss << getElementType()->str() << "?"; - return ss.str(); - } - std::string python_str() const override { - std::stringstream ss; - ss << "Optional[" << getElementType()->python_str() << "]"; - return ss.str(); - } - TypePtr getElementType() const { - return elem; - } - bool hasFreeVariables() const override { - return has_free_variables_; - } - - static const TypeKind Kind = TypeKind::OptionalType; -private: - OptionalType(TypePtr elem) - : Type(TypeKind::OptionalType) - , elem(std::move(elem)) - , has_free_variables_(getElementType()->hasFreeVariables()) {} - TypePtr elem; - bool has_free_variables_; - -}; - -struct DynamicType; -using DynamicTypePtr = std::shared_ptr<DynamicType>; -// This type represents a single Tensor, with an unknown shape. -struct TORCH_API DynamicType : public Type { - static DynamicTypePtr create() { - return DynamicTypePtr(new DynamicType()); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(DynamicType); - - bool requires_grad() const override { return true; } - - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "Tensor"; - } - static const TypeKind Kind = TypeKind::DynamicType; - // global singleton - static DynamicTypePtr get(); -private: - DynamicType() - : Type(TypeKind::DynamicType) {} -}; - -struct UndefinedTensorType; -using UndefinedTensorTypePtr = std::shared_ptr<UndefinedTensorType>; -// This type represents an undefined tensor. -struct TORCH_API UndefinedTensorType : public Type { - static const TypeKind Kind = TypeKind::UndefinedTensorType; - static UndefinedTensorTypePtr create() { - return UndefinedTensorTypePtr(new UndefinedTensorType()); // NOLINT(modernize-make-shared) - } - - DEFINE_IS_SUBCLASS(UndefinedTensorType); - - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - bool isSubtypeOf(const TypePtr rhs) const override { - return rhs->kind() == TypeKind::DynamicType || - rhs->kind() == TypeKind::UndefinedTensorType; - } - std::string str() const override { - return "UndefinedTensor"; - } - static UndefinedTensorTypePtr get(); -protected: - UndefinedTensorType(): Type(TypeKind::UndefinedTensorType) {} -}; - -struct TensorType; -using TensorTypePtr = std::shared_ptr<TensorType>; -// This type represents a single Tensor with a specific size -struct TORCH_API TensorType : public Type { - static const TypeKind Kind = TypeKind::TensorType; - template<typename ... T> - static TensorTypePtr create( T&& ... all ) { - return TensorTypePtr(new TensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared) - } - - at::ScalarType scalarType() const { return scalar_type_; } - int device() const { return device_; } - int dim() const { return dim_; } - bool requires_grad() const override { return requires_grad_; } - - TensorTypePtr toScalarType(at::ScalarType type){ - auto t = TensorType::create(*this); - t->scalar_type_ = type; - return t; - } - TensorTypePtr withDim(int new_dim) { - auto t = TensorType::create(*this); - t->dim_ = new_dim; - return t; - } - TensorTypePtr withRequiresGrad(bool req) { - auto t = TensorType::create(*this); - t->requires_grad_ = req; - return t; - } - - bool operator==(const Type& rhs) const override { - if (rhs.kind() != TypeKind::TensorType) - return false; - auto rt = rhs.expect<TensorType>(); - return scalarType() == rt->scalarType() && - device() == rt->device() && - dim() == rt->dim(); - } - bool isSubtypeOf(const TypePtr rhs) const override { - if (rhs->kind() == TypeKind::DynamicType) - return true; - return rhs->kind() == TypeKind::TensorType && *this == *rhs; - } - bool isSubclass(const TypeKind kind) const override { - return kind == TypeKind::DynamicType || - kind == TypeKind::TensorType; - } - std::string str() const override { - // str is used for user-facing error messages, where we - // don't want to reveal underlying size information. - return "Tensor"; - } - -protected: - TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType) - : TensorType(tensor.type().scalarType(), - tensor.is_cuda() ? tensor.get_device() : -1, - tensor.dim(), - tensor.is_variable() && tensor.requires_grad(), - kind) {} - TensorType(at::ScalarType scalar_type, int device, int dim, bool requires_grad=true, TypeKind kind=TypeKind::TensorType) - : Type(kind) - , scalar_type_(scalar_type) - , requires_grad_(at::isFloatingType(scalar_type) && requires_grad) - , device_(device) - , dim_(dim) {} - - at::ScalarType scalar_type_; - bool requires_grad_; - int device_; - int dim_; -}; - -struct CompleteTensorType; -using CompleteTensorTypePtr = std::shared_ptr<CompleteTensorType>; -// This type represents a single Tensor with a specific size -struct TORCH_API CompleteTensorType : public TensorType { - template<typename ... T> - static CompleteTensorTypePtr create( T&& ... all ) { - return CompleteTensorTypePtr(new CompleteTensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared) - } - - // overloaded create variadic template argument as it could not distinguish initializer list - static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes) { - return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes)); // NOLINT(modernize-make-shared) - } - static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) { - return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared) - } - - static const TypeKind Kind = TypeKind::CompleteTensorType; - - const std::vector<int64_t>& sizes() const { return sizes_; } - const std::vector<int64_t>& strides() const { return strides_; } - - TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const { - return CompleteTensorType::create(scalar_type_, device_, sizes, strides); - } - - TypePtr withSizes(at::IntList sizes) const { - return withSizesStrides(sizes, CompleteTensorType::contiguousStridesOf(sizes)); - } - - CompleteTensorTypePtr contiguous() const { - auto t = CompleteTensorType::create(*this); - t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_); - return t; - } - - CompleteTensorTypePtr toScalarType(at::ScalarType type){ - auto t = CompleteTensorType::create(*this); - t->scalar_type_ = type; - return t; - } - - bool operator==(const Type& rhs) const override { - if(rhs.kind() != kind()) - return false; - auto rt = rhs.expect<CompleteTensorType>(); - return scalarType() == rt->scalarType() && - sizes() == rt->sizes() && - strides() == rt->strides() && - device() == rt->device(); - } - bool isSubtypeOf(const TypePtr rhs) const override { - if (rhs->kind() == TypeKind::DynamicType) - return true; - if (rhs->kind() == TypeKind::TensorType) - return *expect<TensorType>() == *rhs; - return *this == *rhs; - } - bool isSubclass(const TypeKind kind) const override { - return kind == TypeKind::DynamicType || - kind == TypeKind::TensorType || - kind == TypeKind::CompleteTensorType; - } - std::string str() const override { - // str is used for user-facing error messages, where we - // don't want to reveal underlying size information. - return "Tensor"; - } - bool numel() const { - size_t prod = 1; - for(auto s : sizes()) { - prod *= s; - } - return prod; - } - static TypePtr fromNumberType(TypePtr typ); - static TypePtr fromBoolType(); - -private: - CompleteTensorType(const at::Tensor& tensor) - : TensorType(tensor, TypeKind::CompleteTensorType) - , sizes_(tensor.sizes().vec()) - , strides_(tensor.strides().vec()) {} - CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, bool requires_grad=true) - : CompleteTensorType(scalar_type, device, sizes, CompleteTensorType::contiguousStridesOf(sizes), requires_grad) {} - CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides, bool requires_grad=true) - : TensorType(scalar_type, device, sizes.size(), requires_grad, TypeKind::CompleteTensorType) - , sizes_(sizes.vec()) - , strides_(strides.vec()) {} - - static std::vector<int64_t> contiguousStridesOf(at::IntList sizes) { - std::vector<int64_t> strides(sizes.size()); - if(sizes.empty()) // zero-dim case - return strides; - strides.back() = 1; - for(size_t i = strides.size() - 1; i > 0; i--) { - strides[i-1] = strides[i] * sizes[i]; - } - return strides; - } - - std::vector<int64_t> sizes_; - std::vector<int64_t> strides_; -}; - -// common base for all types that have a single sub element -// e.g. Future[T], Option[T], List[T] -template<TypeKind K, typename T> -struct SingleElementType : public Type { - static const TypeKind Kind = K; - TypePtr getElementType() const { - return elem; - } - bool hasFreeVariables() const override { - return has_free_variables_; - } - at::ArrayRef<TypePtr> containedTypes() const override { - return elem; - } - bool requires_grad() const override { - return elem->requires_grad(); - } - bool operator==(const Type& rhs) const override { - if(auto rhs_ = rhs.cast<T>()) { - return *getElementType() == *rhs_->getElementType(); - } - return false; - } -protected: - SingleElementType(TypePtr elem) - : Type(Kind) - , elem(std::move(elem)) - , has_free_variables_(getElementType()->hasFreeVariables()) {} -private: - TypePtr elem; - bool has_free_variables_; -}; - -struct ListType; -using ListTypePtr = std::shared_ptr<ListType>; -struct TORCH_API ListType : public SingleElementType<TypeKind::ListType, ListType> { - // It's not exactly a singleton, but there should be exactly once instance of - // List[T] for every T - friend struct Type; - template<typename ... T> - static ListTypePtr create( T&& ... all ) { - return ListTypePtr(new ListType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(ListType); - std::string str() const override { - std::stringstream ss; - ss << getElementType()->str() << "[]"; - return ss.str(); - } - std::string python_str() const override { - std::stringstream ss; - ss << "List[" << getElementType()->python_str() << "]"; - return ss.str(); - } - TypePtr createWithContained(std::vector<TypePtr> contained_types) const override { - return create(contained_types.at(0)); - } - // common cast List[Tensor] - static ListTypePtr ofTensors(); - static ListTypePtr ofInts(); - static ListTypePtr ofFloats(); - static ListTypePtr ofBools(); -private: - using SingleElementType::SingleElementType; -}; - -struct FutureType; -using FutureTypePtr = std::shared_ptr<FutureType>; - -struct TORCH_API FutureType : public Type { - friend struct Type; - template<typename ... T> - static FutureTypePtr create(TypePtr elem) { - return FutureTypePtr(new FutureType(std::move(elem))); // NOLINT(modernize-make-shared) - } - - DEFINE_IS_SUBCLASS(FutureType); - - bool operator==(const Type& rhs) const override { - if (auto rhs_ = rhs.cast<FutureType>()) { - return *getElementType() == *rhs_->getElementType(); - } - return false; - } - bool requires_grad() const override { - return elem->requires_grad(); - } - std::string str() const override { - std::stringstream ss; - ss << "Future(" << getElementType()->str() << ")"; - return ss.str(); - } - std::string python_str() const override { - std::stringstream ss; - ss << "Future[" << getElementType()->python_str() << "]"; - return ss.str(); - } - TypePtr getElementType() const { - return elem; - } - bool hasFreeVariables() const override { - return has_free_variables_; - } - - static const TypeKind Kind = TypeKind::FutureType; -private: - FutureType(TypePtr elem) - : Type(TypeKind::FutureType) - , elem(std::move(elem)) - , has_free_variables_(getElementType()->hasFreeVariables()) {} - TypePtr elem; - bool has_free_variables_; -}; - -struct TupleType; -using TupleTypePtr = std::shared_ptr<TupleType>; -// This type represents a Tuple -struct TORCH_API TupleType : public Type { - static TupleTypePtr create(std::vector<TypePtr> types) { - return TupleTypePtr(new TupleType( std::move(types) )); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(TupleType); - at::ArrayRef<TypePtr> elements() const { - return elements_; - } - bool operator==(const Type& rhs) const override { - return compare(rhs, [](const TypePtr a, const TypePtr b) { - return *a == *b; - }); - } - bool isSubtypeOf(const TypePtr rhs) const override { - // co-variant rules for tuples - return compare(*rhs, [](const TypePtr a, const TypePtr b) { - return a->isSubtypeOf(b); - }); - } - bool requires_grad() const override { - return std::any_of(elements_.begin(), elements_.end(), - [](const TypePtr& ptr) { return ptr->requires_grad(); }); - } - std::string str() const override { - std::stringstream ss; - ss << "("; - for(size_t i = 0; i < elements().size(); ++i) { - if(i > 0) - ss << ", "; - ss << elements()[i]->str(); - } - ss << ")"; - return ss.str(); - } - std::string python_str() const override { - std::stringstream ss; - ss << "Tuple["; - for(size_t i = 0; i < elements().size(); ++i) { - if(i > 0) - ss << ", "; - ss << elements()[i]->python_str(); - } - ss << "]"; - return ss.str(); - } - bool hasFreeVariables() const override { - return has_free_variables_; - } - - at::ArrayRef<TypePtr> containedTypes() const override { - return elements_; - } - TypePtr createWithContained(std::vector<TypePtr> contained_types) const override { - return create(std::move(contained_types)); - } - - static const TypeKind Kind = TypeKind::TupleType; -private: - TupleType(std::vector<TypePtr> elements_) - : Type(TypeKind::TupleType) - , elements_(std::move(elements_)) { - has_free_variables_ = - std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) { - return v->hasFreeVariables(); - }); - } - - bool compare(const Type& rhs, std::function<bool(const TypePtr, const TypePtr)> fn) const { - if(rhs.kind() != kind()) - return false; - const auto & l_elements = elements(); - const auto & r_elements = rhs.cast<TupleType>()->elements(); - if(l_elements.size() != r_elements.size()) - return false; - for(size_t i = 0; i < l_elements.size(); ++i) { - if(!fn(l_elements[i], r_elements[i])) - return false; - } - return true; - } - std::vector<TypePtr> elements_; - bool has_free_variables_; -}; - -struct NumberType; -using NumberTypePtr = std::shared_ptr<NumberType>; -// This type represents a Python number -struct TORCH_API NumberType : public Type { - static NumberTypePtr create() { - return NumberTypePtr(new NumberType()); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(NumberType); - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "Scalar"; // match what PythonArgParser says for clarity - } - static const TypeKind Kind = TypeKind::NumberType; - // global singleton - static NumberTypePtr get(); -private: - NumberType() - : Type(TypeKind::NumberType) {} -}; - -struct FloatType; -using FloatTypePtr = std::shared_ptr<FloatType>; -// This type represents a Python float number -struct TORCH_API FloatType : public Type { - static FloatTypePtr create() { - return FloatTypePtr(new FloatType()); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(FloatType); - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "float"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - if(auto rhs_ = rhs->cast<OptionalType>()) { - return this->isSubtypeOf(rhs_->getElementType()); - } - return *this == *rhs || rhs->kind() == TypeKind::NumberType; - } - static const TypeKind Kind = TypeKind::FloatType; - // global singleton - static FloatTypePtr get(); -private: - FloatType() - : Type(TypeKind::FloatType) {} -}; - -struct IntType; -using IntTypePtr = std::shared_ptr<IntType>; -// This type represents a Python int number -struct TORCH_API IntType : public Type { - static IntTypePtr create() { - return IntTypePtr(new IntType()); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(IntType); - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "int"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - if(auto rhs_ = rhs->cast<OptionalType>()) { - return this->isSubtypeOf(rhs_->getElementType()); - } - return *this == *rhs || rhs->kind() == TypeKind::NumberType; - } - static const TypeKind Kind = TypeKind::IntType; - // global singleton - static IntTypePtr get(); -private: - IntType() - : Type(TypeKind::IntType) {} -}; - -struct BoolType; -using BoolTypePtr = std::shared_ptr<BoolType>; -// This node represents a Python bool value -struct TORCH_API BoolType : public Type { - static BoolTypePtr create( ) { - return BoolTypePtr(new BoolType()); - } - DEFINE_IS_SUBCLASS(BoolType); - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "bool"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - return *this == *rhs || rhs->kind() == TypeKind::BoolType; - } - static const TypeKind Kind = TypeKind::BoolType; - // global singleton - static BoolTypePtr get(); -private: - BoolType() - : Type(TypeKind::BoolType) {} -}; - -struct StringType; -using StringTypePtr = std::shared_ptr<StringType>; -// This type represents a Python string -struct TORCH_API StringType : public Type { - static StringTypePtr create() { - return StringTypePtr(new StringType()); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(StringType); - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "string"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - if(auto rhs_ = rhs->cast<OptionalType>()) { - return this->isSubtypeOf(rhs_->getElementType()); - } - return *this == *rhs; - } - static const TypeKind Kind = TypeKind::StringType; - // global singleton - static StringTypePtr get(); -private: - StringType() - : Type(TypeKind::StringType) {} -}; - -struct NoneType; -using NoneTypePtr = std::shared_ptr<NoneType>; -// This type represents a Python None -struct NoneType : public Type { - static NoneTypePtr create() { - return NoneTypePtr(new NoneType()); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(NoneType); - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - - bool isSubtypeOf(const TypePtr rhs) const override { - return rhs->kind() == TypeKind::NoneType || - rhs->kind() == TypeKind::OptionalType; - } - - std::string str() const override { - return "None"; - } - static const TypeKind Kind = TypeKind::NoneType; - // global singleton - static NoneTypePtr get(); -private: - NoneType() - : Type(TypeKind::NoneType) {} -}; - -struct GeneratorType; -using GeneratorTypePtr = std::shared_ptr<GeneratorType>; -// This type represents a Generator -struct GeneratorType : public Type { - static GeneratorTypePtr create() { - return GeneratorTypePtr(new GeneratorType()); // NOLINT(modernize-make-shared) - } - DEFINE_IS_SUBCLASS(GeneratorType); - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "Generator"; - } - static const TypeKind Kind = TypeKind::GeneratorType; - // global singleton - static GeneratorTypePtr get(); -private: - GeneratorType() - : Type(TypeKind::GeneratorType) {} -}; - - -struct VarType; -using VarTypePtr = std::shared_ptr<VarType>; -// This type represents a type variable, used in FunctionSchema -struct VarType : public Type { - static VarTypePtr create(std::string name_) { - return VarTypePtr(new VarType(std::move(name_))); - } - DEFINE_IS_SUBCLASS(VarType); - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return name(); - } - static const TypeKind Kind = TypeKind::VarType; - const std::string& name() const { - return name_; - } - bool hasFreeVariables() const override { - return true; - } -private: - VarType(std::string name_) - : Type(TypeKind::VarType), name_(std::move(name_)) {} - std::string name_; -}; - -TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t); -// what is the type, ignoring extra size/shape information? -// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) - -inline TypePtr unshapedType(const TypePtr& type) { - if (type->kind() == TypeKind::TensorType || - type->kind() == TypeKind::CompleteTensorType) { - return DynamicType::get(); - } - return type->withContained(fmap(type->containedTypes(), unshapedType)); -} - -inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) { - JIT_ASSERT(typ->isSubtypeOf(NumberType::get())); - if (typ->isSubtypeOf(IntType::get())) { - return CompleteTensorType::create(at::kLong, -1, {}); - } else if (typ->isSubtypeOf(FloatType::get())) { - return CompleteTensorType::create(at::kFloat, -1, {}); - } else if (typ->isSubtypeOf(BoolType::get())) { - return CompleteTensorType::create(at::kLong, -1, {}); - } - AT_ERROR("unknown number type", typ->str()); -} - -inline TypePtr CompleteTensorType::fromBoolType() { - return CompleteTensorType::create(at::kLong, -1, {}); -} - -// Attempt to find the correct supertype of t1 and t2. If none is found then -// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2, -// then t2 will be returned (and vice versa). -// Two different tensortypes will return dynamic. -// Currently we chose not to support returning a NumberType for a float & int -// input because of a lack of operator support for NumberType -TORCH_API c10::optional<TypePtr> unifyTypes( - const TypePtr& t1, - const TypePtr& t2); - -template <typename T> -TypePtr getTypePtr() { -#define TYPE_STR(Type) #Type, " ", - AT_ERROR( - "Type ", - c10::demangle_type<T>(), - " could not be converted to any of the known types { ", - TH_FORALL_TYPES(TYPE_STR) "}"); -#undef TYPE_STR - return nullptr; -} +#define C10_USING(T) using ::c10::T; + C10_FORALL_TYPES(C10_USING) +#undef C10_USING -template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); } -template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); } -template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); } -template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); } -template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); } -template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); } -template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); } -template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); } +#define C10_USING(T) using ::c10::T##Ptr; + C10_FORALL_TYPES(C10_USING) +#undef C10_USING -TORCH_API TypePtr inferTypeFrom(const IValue& value); +using ::c10::Type; +using ::c10::TypePtr; +using ::c10::TypeEnv; +using ::c10::TypeMatchError; -struct TORCH_API TypeMatchError : public std::exception { - TypeMatchError(std::string msg_) - : msg_(std::move(msg_)) {} - const char * what() const noexcept override { - return msg_.c_str(); - } -private: - std::string msg_; -}; -using TypeEnv = std::unordered_map<std::string, TypePtr>; -TORCH_API TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv & type_env); -TORCH_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env); +using ::c10::getTypePtr; +using ::c10::TypeKind; }} // namespace torch::jit diff --git a/torch/csrc/utils/functional.h b/torch/csrc/utils/functional.h index af5099e7ce..7864d2a0ab 100644 --- a/torch/csrc/utils/functional.h +++ b/torch/csrc/utils/functional.h @@ -1,63 +1,8 @@ -#pragma once - -#include <vector> -#include <ATen/ATen.h> +#include <ATen/core/functional.h> namespace torch { -// The passed in function must take T by value (T), or by -// const reference (const T&); taking T by non-const reference -// will result in an error like: -// -// error: no type named 'type' in 'class std::result_of<foobar::__lambda(T)>' -// -// No explicit template parameters are required. - -// Overload for explicit function and ArrayRef -template<typename F, typename T> -inline auto fmap(const T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> { - std::vector<decltype(fn(*inputs.begin()))> r; - r.reserve(inputs.size()); - for(const auto & input : inputs) - r.push_back(fn(input)); - return r; -} - -template<typename F, typename T> -inline auto fmap(T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> { - std::vector<decltype(fn(*inputs.begin()))> r; - r.reserve(inputs.size()); - for(auto & input : inputs) - r.push_back(fn(input)); - return r; -} - -// C++ forbids taking an address of a constructor, so here's a workaround... -// Overload for constructor (R) application -template<typename R, typename T> -inline std::vector<R> fmap(const T& inputs) { - std::vector<R> r; - r.reserve(inputs.size()); - for(auto & input : inputs) - r.push_back(R(input)); - return r; -} - -template<typename F, typename T> -inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) { - std::vector<T> r; - r.reserve(inputs.size()); - for(auto & input : inputs) { - if (fn(input)) { - r.push_back(input); - } - } - return r; -} - -template<typename F, typename T> -inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) { - return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn); -} +using ::c10::fmap; +using ::c10::filter; -} +} // namespace torch |