summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBram Wasti <bwasti@fb.com>2018-11-07 18:09:33 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-11-07 18:11:29 -0800
commit16165875406fc5224592fb778c96613161b78dca (patch)
tree6f2f69f5d04adbe0c15ef21aca1b24607495ab78
parent87b47ff850428a546bbcb0b5909a24f4445b5a1b (diff)
downloadpytorch-16165875406fc5224592fb778c96613161b78dca.tar.gz
pytorch-16165875406fc5224592fb778c96613161b78dca.tar.bz2
pytorch-16165875406fc5224592fb778c96613161b78dca.zip
Redo jit/type and utils/functional to ATen/core (#13455)
Summary: This is a redo of the previous move which broke OS X and Windows tests -- RTTI seemed to be broken Pull Request resolved: https://github.com/pytorch/pytorch/pull/13455 Differential Revision: D12883775 Pulled By: bwasti fbshipit-source-id: 2b6c65e8150e6f89624c6ee99c389335c6fb4bb8
-rw-r--r--aten/src/ATen/core/functional.h63
-rw-r--r--aten/src/ATen/core/jit_type.h933
-rw-r--r--aten/src/ATen/core/type.cpp (renamed from torch/csrc/jit/type.cpp)12
-rw-r--r--torch/CMakeLists.txt1
-rw-r--r--torch/csrc/autograd/python_variable_indexing.cpp4
-rw-r--r--torch/csrc/jit/function_schema.h2
-rw-r--r--torch/csrc/jit/python_ir.cpp5
-rw-r--r--torch/csrc/jit/type.h938
-rw-r--r--torch/csrc/utils/functional.h63
9 files changed, 1026 insertions, 995 deletions
diff --git a/aten/src/ATen/core/functional.h b/aten/src/ATen/core/functional.h
new file mode 100644
index 0000000000..e0f8c84253
--- /dev/null
+++ b/aten/src/ATen/core/functional.h
@@ -0,0 +1,63 @@
+#pragma once
+
+#include <vector>
+#include <ATen/core/ArrayRef.h>
+
+namespace c10 {
+
+// The passed in function must take T by value (T), or by
+// const reference (const T&); taking T by non-const reference
+// will result in an error like:
+//
+// error: no type named 'type' in 'class std::result_of<foobar::__lambda(T)>'
+//
+// No explicit template parameters are required.
+
+// Overload for explicit function and ArrayRef
+template<typename F, typename T>
+inline auto fmap(const T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> {
+ std::vector<decltype(fn(*inputs.begin()))> r;
+ r.reserve(inputs.size());
+ for(const auto & input : inputs)
+ r.push_back(fn(input));
+ return r;
+}
+
+template<typename F, typename T>
+inline auto fmap(T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> {
+ std::vector<decltype(fn(*inputs.begin()))> r;
+ r.reserve(inputs.size());
+ for(auto & input : inputs)
+ r.push_back(fn(input));
+ return r;
+}
+
+// C++ forbids taking an address of a constructor, so here's a workaround...
+// Overload for constructor (R) application
+template<typename R, typename T>
+inline std::vector<R> fmap(const T& inputs) {
+ std::vector<R> r;
+ r.reserve(inputs.size());
+ for(auto & input : inputs)
+ r.push_back(R(input));
+ return r;
+}
+
+template<typename F, typename T>
+inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) {
+ std::vector<T> r;
+ r.reserve(inputs.size());
+ for(auto & input : inputs) {
+ if (fn(input)) {
+ r.push_back(input);
+ }
+ }
+ return r;
+}
+
+template<typename F, typename T>
+inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) {
+ return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn);
+}
+
+} // namespace c10
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
new file mode 100644
index 0000000000..9307b4fee4
--- /dev/null
+++ b/aten/src/ATen/core/jit_type.h
@@ -0,0 +1,933 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <ATen/core/interned_strings.h>
+#include <ATen/core/functional.h>
+#include <ATen/core/Type.h>
+#include <ATen/core/TensorMethods.h>
+
+#include <caffe2/core/common.h>
+
+#include <memory>
+#include <iostream>
+#include <type_traits>
+
+namespace c10 {
+
+#define C10_FORALL_TYPES(_) \
+_(DynamicType) \
+_(TensorType) \
+_(CompleteTensorType) \
+_(UndefinedTensorType) \
+_(TupleType) \
+_(ListType) \
+_(NumberType) \
+_(FloatType) \
+_(FutureType) \
+_(IntType) \
+_(NoneType) \
+_(StringType) \
+_(GeneratorType) \
+_(BoolType) \
+_(OptionalType) \
+_(VarType) \
+
+enum class TypeKind {
+#define DEFINE_TYPE(T) T,
+ C10_FORALL_TYPES(DEFINE_TYPE)
+#undef DEFINE_TYPE
+};
+
+#define DEFINE_IS_SUBCLASS(_kind) \
+ bool isSubclass(const TypeKind kind) const override { \
+ return kind == TypeKind::_kind; \
+ }
+
+struct Type;
+using TypePtr = std::shared_ptr<Type>;
+
+struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
+private:
+ TypeKind kind_;
+ template<typename T>
+ static std::shared_ptr<T> sliceType(std::shared_ptr<const T> ptr) {
+ auto result = std::make_shared<typename std::remove_const<T>::type>(*ptr);
+ // XXX: the line above will correctly slice the struct, and make its runtype
+ // type exactly equal to T. However, kind_ is a field of Type, so it will simply
+ // be copied, and we need to fix it in here to match the dynamic type.
+ result->kind_ = T::Kind;
+ return result;
+ }
+
+protected:
+ Type(TypeKind kind)
+ : kind_(kind) {}
+
+public:
+ virtual bool operator==(const Type& rhs) const = 0;
+
+ // subtyping relation. By default, we return true for the case
+ // when the type is exactly equal
+ virtual bool isSubtypeOf(const TypePtr rhs) const {
+ return *this == *rhs;
+ }
+
+ // If this class can be cast to the kind passed in
+ // This removes the need for RTTI
+ virtual bool isSubclass(const TypeKind kind) const = 0;
+
+ // How this type will appear in FunctionSchema declarations
+ virtual std::string str() const = 0;
+
+ // How this type will appear as if it were a type annotation in Python
+ // which is sometimes different than how it appears in declarations (e.g. int[] vs List[int])
+ virtual std::string python_str() const {
+ return str();
+ }
+
+ TypeKind kind() const {
+ return kind_;
+ }
+
+ virtual bool requires_grad() const { return false; }
+
+ // Dynamically cast this object to the subclass indicated by the
+ // template variable, returning nullptr if the cast is invalid.
+ // NOTE: if the cast succeeds, but the casted kind is not the
+ // run-time kind of the type, we also slice the structure, so
+ // that assignments of those types to values don't accidentally
+ // inherit more detailed information from subclasses.
+ template<typename T>
+ std::shared_ptr<T> cast() {
+ std::shared_ptr<T> r = nullptr;
+ if (isSubclass(T::Kind)) {
+ r = std::static_pointer_cast<T>(shared_from_this());
+ }
+ if (!r || T::Kind == kind()) {
+ return r;
+ } else {
+ return sliceType<T>(r);
+ }
+ }
+ template<typename T>
+ std::shared_ptr<const T> cast() const {
+ std::shared_ptr<const T> r = nullptr;
+ if (isSubclass(T::Kind)) {
+ r = std::static_pointer_cast<const T>(shared_from_this());
+ }
+ if (!r || T::Kind == kind()) {
+ return r;
+ } else {
+ return sliceType<T>(r);
+ }
+ }
+ template<typename T>
+ std::shared_ptr<T> expect() {
+ auto r = cast<T>();
+ AT_ASSERT(r);
+ return r;
+ }
+ template<typename T>
+ std::shared_ptr<const T> expect() const {
+ auto r = cast<const T>();
+ AT_ASSERT(r);
+ return r;
+ }
+ virtual ~Type() = default;
+ virtual bool hasFreeVariables() const {
+ return false;
+ }
+ // list of types this type contains, e.g. for a List then element type of a list
+ // for a tuple, the types of the tuple elements
+ virtual at::ArrayRef<TypePtr> containedTypes() const {
+ return {};
+ }
+ // create a new version of this type, replacing its contained types with
+ // contained_types
+ TypePtr withContained(std::vector<TypePtr> contained_types) {
+ auto current_contained = containedTypes();
+ AT_ASSERT(current_contained.size() == contained_types.size());
+ if(current_contained.equals(contained_types)) {
+ return shared_from_this();
+ }
+ return createWithContained(std::move(contained_types));
+ }
+ // per-type constructor, you only need to override this if the containedTypes()
+ // is not empty
+ virtual TypePtr createWithContained(std::vector<TypePtr> contained_types) const {
+ AT_ERROR("type with contained types did not overload createWithContained: ", str());
+ }
+};
+
+inline bool operator!=(const Type & lhs, const Type & rhs) {
+ return !(lhs == rhs);
+}
+
+struct OptionalType;
+using OptionalTypePtr = std::shared_ptr<OptionalType>;
+// This type represents an optional type, for each element type.
+struct OptionalType: public Type {
+ static OptionalTypePtr create(TypePtr element) {
+ return OptionalTypePtr(new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(OptionalType);
+ bool operator==(const Type& rhs) const override {
+ if(auto rhs_ = rhs.cast<OptionalType>()) {
+ return *getElementType() == *rhs_->getElementType();
+ }
+ return false;
+ }
+ bool requires_grad() const override {
+ return elem->requires_grad();
+ }
+
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ if(auto rhs_ = rhs->cast<OptionalType>()) {
+ return getElementType()->isSubtypeOf(rhs_->getElementType());
+ }
+ return false;
+ }
+
+ std::string str() const override {
+ std::stringstream ss;
+ ss << getElementType()->str() << "?";
+ return ss.str();
+ }
+ std::string python_str() const override {
+ std::stringstream ss;
+ ss << "Optional[" << getElementType()->python_str() << "]";
+ return ss.str();
+ }
+ TypePtr getElementType() const {
+ return elem;
+ }
+ bool hasFreeVariables() const override {
+ return has_free_variables_;
+ }
+
+ static const TypeKind Kind = TypeKind::OptionalType;
+private:
+ OptionalType(TypePtr elem)
+ : Type(TypeKind::OptionalType)
+ , elem(std::move(elem))
+ , has_free_variables_(getElementType()->hasFreeVariables()) {}
+ TypePtr elem;
+ bool has_free_variables_;
+
+};
+
+struct DynamicType;
+using DynamicTypePtr = std::shared_ptr<DynamicType>;
+// This type represents a single Tensor, with an unknown shape.
+struct CAFFE2_API DynamicType : public Type {
+ static DynamicTypePtr create() {
+ return DynamicTypePtr(new DynamicType()); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(DynamicType);
+
+ bool requires_grad() const override { return true; }
+
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "Tensor";
+ }
+ static const TypeKind Kind = TypeKind::DynamicType;
+ // global singleton
+ static DynamicTypePtr get();
+private:
+ DynamicType()
+ : Type(TypeKind::DynamicType) {}
+};
+
+struct UndefinedTensorType;
+using UndefinedTensorTypePtr = std::shared_ptr<UndefinedTensorType>;
+// This type represents an undefined tensor.
+struct CAFFE2_API UndefinedTensorType : public Type {
+ static const TypeKind Kind = TypeKind::UndefinedTensorType;
+ static UndefinedTensorTypePtr create() {
+ return UndefinedTensorTypePtr(new UndefinedTensorType()); // NOLINT(modernize-make-shared)
+ }
+
+ DEFINE_IS_SUBCLASS(UndefinedTensorType);
+
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ return rhs->kind() == TypeKind::DynamicType ||
+ rhs->kind() == TypeKind::UndefinedTensorType;
+ }
+ std::string str() const override {
+ return "UndefinedTensor";
+ }
+ static UndefinedTensorTypePtr get();
+protected:
+ UndefinedTensorType(): Type(TypeKind::UndefinedTensorType) {}
+};
+
+struct TensorType;
+using TensorTypePtr = std::shared_ptr<TensorType>;
+// This type represents a single Tensor with a specific size
+struct CAFFE2_API TensorType : public Type {
+ static const TypeKind Kind = TypeKind::TensorType;
+ template<typename ... T>
+ static TensorTypePtr create( T&& ... all ) {
+ return TensorTypePtr(new TensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
+ }
+
+ at::ScalarType scalarType() const { return scalar_type_; }
+ int device() const { return device_; }
+ int dim() const { return dim_; }
+ bool requires_grad() const override { return requires_grad_; }
+
+ TensorTypePtr toScalarType(at::ScalarType type){
+ auto t = TensorType::create(*this);
+ t->scalar_type_ = type;
+ return t;
+ }
+ TensorTypePtr withDim(int new_dim) {
+ auto t = TensorType::create(*this);
+ t->dim_ = new_dim;
+ return t;
+ }
+ TensorTypePtr withRequiresGrad(bool req) {
+ auto t = TensorType::create(*this);
+ t->requires_grad_ = req;
+ return t;
+ }
+
+ bool operator==(const Type& rhs) const override {
+ if (rhs.kind() != TypeKind::TensorType)
+ return false;
+ auto rt = rhs.expect<TensorType>();
+ return scalarType() == rt->scalarType() &&
+ device() == rt->device() &&
+ dim() == rt->dim();
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ if (rhs->kind() == TypeKind::DynamicType)
+ return true;
+ return rhs->kind() == TypeKind::TensorType && *this == *rhs;
+ }
+ bool isSubclass(const TypeKind kind) const override {
+ return kind == TypeKind::DynamicType ||
+ kind == TypeKind::TensorType;
+ }
+ std::string str() const override {
+ // str is used for user-facing error messages, where we
+ // don't want to reveal underlying size information.
+ return "Tensor";
+ }
+
+protected:
+ TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType)
+ : TensorType(tensor.type().scalarType(),
+ tensor.is_cuda() ? tensor.get_device() : -1,
+ tensor.dim(),
+ tensor.is_variable() && tensor.requires_grad(),
+ kind) {}
+ TensorType(at::ScalarType scalar_type, int device, int dim, bool requires_grad=true, TypeKind kind=TypeKind::TensorType)
+ : Type(kind)
+ , scalar_type_(scalar_type)
+ , requires_grad_(at::isFloatingType(scalar_type) && requires_grad)
+ , device_(device)
+ , dim_(dim) {}
+
+ at::ScalarType scalar_type_;
+ bool requires_grad_;
+ int device_;
+ int dim_;
+};
+
+struct CompleteTensorType;
+using CompleteTensorTypePtr = std::shared_ptr<CompleteTensorType>;
+// This type represents a single Tensor with a specific size
+struct CAFFE2_API CompleteTensorType : public TensorType {
+ template<typename ... T>
+ static CompleteTensorTypePtr create( T&& ... all ) {
+ return CompleteTensorTypePtr(new CompleteTensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
+ }
+
+ // overloaded create variadic template argument as it could not distinguish initializer list
+ static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes) {
+ return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes)); // NOLINT(modernize-make-shared)
+ }
+ static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) {
+ return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared)
+ }
+
+ static const TypeKind Kind = TypeKind::CompleteTensorType;
+
+ const std::vector<int64_t>& sizes() const { return sizes_; }
+ const std::vector<int64_t>& strides() const { return strides_; }
+
+ TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const {
+ return CompleteTensorType::create(scalar_type_, device_, sizes, strides);
+ }
+
+ TypePtr withSizes(at::IntList sizes) const {
+ return withSizesStrides(sizes, CompleteTensorType::contiguousStridesOf(sizes));
+ }
+
+ CompleteTensorTypePtr contiguous() const {
+ auto t = CompleteTensorType::create(*this);
+ t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_);
+ return t;
+ }
+
+ CompleteTensorTypePtr toScalarType(at::ScalarType type){
+ auto t = CompleteTensorType::create(*this);
+ t->scalar_type_ = type;
+ return t;
+ }
+
+ bool operator==(const Type& rhs) const override {
+ if(rhs.kind() != kind())
+ return false;
+ auto rt = rhs.expect<CompleteTensorType>();
+ return scalarType() == rt->scalarType() &&
+ sizes() == rt->sizes() &&
+ strides() == rt->strides() &&
+ device() == rt->device();
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ if (rhs->kind() == TypeKind::DynamicType)
+ return true;
+ if (rhs->kind() == TypeKind::TensorType)
+ return *expect<TensorType>() == *rhs;
+ return *this == *rhs;
+ }
+ bool isSubclass(const TypeKind kind) const override {
+ return kind == TypeKind::DynamicType ||
+ kind == TypeKind::TensorType ||
+ kind == TypeKind::CompleteTensorType;
+ }
+ std::string str() const override {
+ // str is used for user-facing error messages, where we
+ // don't want to reveal underlying size information.
+ return "Tensor";
+ }
+ bool numel() const {
+ size_t prod = 1;
+ for(auto s : sizes()) {
+ prod *= s;
+ }
+ return prod;
+ }
+ static TypePtr fromNumberType(TypePtr typ);
+ static TypePtr fromBoolType();
+
+private:
+ CompleteTensorType(const at::Tensor& tensor)
+ : TensorType(tensor, TypeKind::CompleteTensorType)
+ , sizes_(tensor.sizes().vec())
+ , strides_(tensor.strides().vec()) {}
+ CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, bool requires_grad=true)
+ : CompleteTensorType(scalar_type, device, sizes, CompleteTensorType::contiguousStridesOf(sizes), requires_grad) {}
+ CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides, bool requires_grad=true)
+ : TensorType(scalar_type, device, sizes.size(), requires_grad, TypeKind::CompleteTensorType)
+ , sizes_(sizes.vec())
+ , strides_(strides.vec()) {}
+
+ static std::vector<int64_t> contiguousStridesOf(at::IntList sizes) {
+ std::vector<int64_t> strides(sizes.size());
+ if(sizes.empty()) // zero-dim case
+ return strides;
+ strides.back() = 1;
+ for(size_t i = strides.size() - 1; i > 0; i--) {
+ strides[i-1] = strides[i] * sizes[i];
+ }
+ return strides;
+ }
+
+ std::vector<int64_t> sizes_;
+ std::vector<int64_t> strides_;
+};
+
+// common base for all types that have a single sub element
+// e.g. Future[T], Option[T], List[T]
+template<TypeKind K, typename T>
+struct SingleElementType : public Type {
+ static const TypeKind Kind = K;
+ TypePtr getElementType() const {
+ return elem;
+ }
+ bool hasFreeVariables() const override {
+ return has_free_variables_;
+ }
+ at::ArrayRef<TypePtr> containedTypes() const override {
+ return elem;
+ }
+ bool requires_grad() const override {
+ return elem->requires_grad();
+ }
+ bool operator==(const Type& rhs) const override {
+ if(auto rhs_ = rhs.cast<T>()) {
+ return *getElementType() == *rhs_->getElementType();
+ }
+ return false;
+ }
+protected:
+ SingleElementType(TypePtr elem)
+ : Type(Kind)
+ , elem(std::move(elem))
+ , has_free_variables_(getElementType()->hasFreeVariables()) {}
+private:
+ TypePtr elem;
+ bool has_free_variables_;
+};
+
+struct ListType;
+using ListTypePtr = std::shared_ptr<ListType>;
+struct CAFFE2_API ListType : public SingleElementType<TypeKind::ListType, ListType> {
+ // It's not exactly a singleton, but there should be exactly once instance of
+ // List[T] for every T
+ friend struct Type;
+ template<typename ... T>
+ static ListTypePtr create( T&& ... all ) {
+ return ListTypePtr(new ListType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(ListType);
+ std::string str() const override {
+ std::stringstream ss;
+ ss << getElementType()->str() << "[]";
+ return ss.str();
+ }
+ std::string python_str() const override {
+ std::stringstream ss;
+ ss << "List[" << getElementType()->python_str() << "]";
+ return ss.str();
+ }
+ TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
+ return create(contained_types.at(0));
+ }
+ // common cast List[Tensor]
+ static ListTypePtr ofTensors();
+ static ListTypePtr ofInts();
+ static ListTypePtr ofFloats();
+ static ListTypePtr ofBools();
+private:
+ using SingleElementType::SingleElementType;
+};
+
+struct FutureType;
+using FutureTypePtr = std::shared_ptr<FutureType>;
+
+struct CAFFE2_API FutureType : public Type {
+ friend struct Type;
+ template<typename ... T>
+ static FutureTypePtr create(TypePtr elem) {
+ return FutureTypePtr(new FutureType(std::move(elem))); // NOLINT(modernize-make-shared)
+ }
+
+ DEFINE_IS_SUBCLASS(FutureType);
+
+ bool operator==(const Type& rhs) const override {
+ if (auto rhs_ = rhs.cast<FutureType>()) {
+ return *getElementType() == *rhs_->getElementType();
+ }
+ return false;
+ }
+ bool requires_grad() const override {
+ return elem->requires_grad();
+ }
+ std::string str() const override {
+ std::stringstream ss;
+ ss << "Future(" << getElementType()->str() << ")";
+ return ss.str();
+ }
+ std::string python_str() const override {
+ std::stringstream ss;
+ ss << "Future[" << getElementType()->python_str() << "]";
+ return ss.str();
+ }
+ TypePtr getElementType() const {
+ return elem;
+ }
+ bool hasFreeVariables() const override {
+ return has_free_variables_;
+ }
+
+ static const TypeKind Kind = TypeKind::FutureType;
+private:
+ FutureType(TypePtr elem)
+ : Type(TypeKind::FutureType)
+ , elem(std::move(elem))
+ , has_free_variables_(getElementType()->hasFreeVariables()) {}
+ TypePtr elem;
+ bool has_free_variables_;
+};
+
+struct TupleType;
+using TupleTypePtr = std::shared_ptr<TupleType>;
+// This type represents a Tuple
+struct CAFFE2_API TupleType : public Type {
+ static TupleTypePtr create(std::vector<TypePtr> types) {
+ return TupleTypePtr(new TupleType( std::move(types) )); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(TupleType);
+ at::ArrayRef<TypePtr> elements() const {
+ return elements_;
+ }
+ bool operator==(const Type& rhs) const override {
+ return compare(rhs, [](const TypePtr a, const TypePtr b) {
+ return *a == *b;
+ });
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ // co-variant rules for tuples
+ return compare(*rhs, [](const TypePtr a, const TypePtr b) {
+ return a->isSubtypeOf(b);
+ });
+ }
+ bool requires_grad() const override {
+ return std::any_of(elements_.begin(), elements_.end(),
+ [](const TypePtr& ptr) { return ptr->requires_grad(); });
+ }
+ std::string str() const override {
+ std::stringstream ss;
+ ss << "(";
+ for(size_t i = 0; i < elements().size(); ++i) {
+ if(i > 0)
+ ss << ", ";
+ ss << elements()[i]->str();
+ }
+ ss << ")";
+ return ss.str();
+ }
+ std::string python_str() const override {
+ std::stringstream ss;
+ ss << "Tuple[";
+ for(size_t i = 0; i < elements().size(); ++i) {
+ if(i > 0)
+ ss << ", ";
+ ss << elements()[i]->python_str();
+ }
+ ss << "]";
+ return ss.str();
+ }
+ bool hasFreeVariables() const override {
+ return has_free_variables_;
+ }
+
+ at::ArrayRef<TypePtr> containedTypes() const override {
+ return elements_;
+ }
+ TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
+ return create(std::move(contained_types));
+ }
+
+ static const TypeKind Kind = TypeKind::TupleType;
+private:
+ TupleType(std::vector<TypePtr> elements_)
+ : Type(TypeKind::TupleType)
+ , elements_(std::move(elements_)) {
+ has_free_variables_ =
+ std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
+ return v->hasFreeVariables();
+ });
+ }
+
+ bool compare(const Type& rhs, std::function<bool(const TypePtr, const TypePtr)> fn) const {
+ if(rhs.kind() != kind())
+ return false;
+ const auto & l_elements = elements();
+ const auto & r_elements = rhs.cast<TupleType>()->elements();
+ if(l_elements.size() != r_elements.size())
+ return false;
+ for(size_t i = 0; i < l_elements.size(); ++i) {
+ if(!fn(l_elements[i], r_elements[i]))
+ return false;
+ }
+ return true;
+ }
+ std::vector<TypePtr> elements_;
+ bool has_free_variables_;
+};
+
+struct NumberType;
+using NumberTypePtr = std::shared_ptr<NumberType>;
+// This type represents a Python number
+struct CAFFE2_API NumberType : public Type {
+ static NumberTypePtr create() {
+ return NumberTypePtr(new NumberType()); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(NumberType);
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "Scalar"; // match what PythonArgParser says for clarity
+ }
+ static const TypeKind Kind = TypeKind::NumberType;
+ // global singleton
+ static NumberTypePtr get();
+private:
+ NumberType()
+ : Type(TypeKind::NumberType) {}
+};
+
+struct FloatType;
+using FloatTypePtr = std::shared_ptr<FloatType>;
+// This type represents a Python float number
+struct CAFFE2_API FloatType : public Type {
+ static FloatTypePtr create() {
+ return FloatTypePtr(new FloatType()); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(FloatType);
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "float";
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ if(auto rhs_ = rhs->cast<OptionalType>()) {
+ return this->isSubtypeOf(rhs_->getElementType());
+ }
+ return *this == *rhs || rhs->kind() == TypeKind::NumberType;
+ }
+ static const TypeKind Kind = TypeKind::FloatType;
+ // global singleton
+ static FloatTypePtr get();
+private:
+ FloatType()
+ : Type(TypeKind::FloatType) {}
+};
+
+struct IntType;
+using IntTypePtr = std::shared_ptr<IntType>;
+// This type represents a Python int number
+struct CAFFE2_API IntType : public Type {
+ static IntTypePtr create() {
+ return IntTypePtr(new IntType()); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(IntType);
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "int";
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ if(auto rhs_ = rhs->cast<OptionalType>()) {
+ return this->isSubtypeOf(rhs_->getElementType());
+ }
+ return *this == *rhs || rhs->kind() == TypeKind::NumberType;
+ }
+ static const TypeKind Kind = TypeKind::IntType;
+ // global singleton
+ static IntTypePtr get();
+private:
+ IntType()
+ : Type(TypeKind::IntType) {}
+};
+
+struct BoolType;
+using BoolTypePtr = std::shared_ptr<BoolType>;
+// This node represents a Python bool value
+struct CAFFE2_API BoolType : public Type {
+ static BoolTypePtr create( ) {
+ return BoolTypePtr(new BoolType());
+ }
+ DEFINE_IS_SUBCLASS(BoolType);
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "bool";
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ return *this == *rhs || rhs->kind() == TypeKind::BoolType;
+ }
+ static const TypeKind Kind = TypeKind::BoolType;
+ // global singleton
+ static BoolTypePtr get();
+private:
+ BoolType()
+ : Type(TypeKind::BoolType) {}
+};
+
+struct StringType;
+using StringTypePtr = std::shared_ptr<StringType>;
+// This type represents a Python string
+struct CAFFE2_API StringType : public Type {
+ static StringTypePtr create() {
+ return StringTypePtr(new StringType()); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(StringType);
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "string";
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ if(auto rhs_ = rhs->cast<OptionalType>()) {
+ return this->isSubtypeOf(rhs_->getElementType());
+ }
+ return *this == *rhs;
+ }
+ static const TypeKind Kind = TypeKind::StringType;
+ // global singleton
+ static StringTypePtr get();
+private:
+ StringType()
+ : Type(TypeKind::StringType) {}
+};
+
+struct NoneType;
+using NoneTypePtr = std::shared_ptr<NoneType>;
+// This type represents a Python None
+struct CAFFE2_API NoneType : public Type {
+ static NoneTypePtr create() {
+ return NoneTypePtr(new NoneType()); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(NoneType);
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ return rhs->kind() == TypeKind::NoneType ||
+ rhs->kind() == TypeKind::OptionalType;
+ }
+
+ std::string str() const override {
+ return "None";
+ }
+ static const TypeKind Kind = TypeKind::NoneType;
+ // global singleton
+ static NoneTypePtr get();
+private:
+ NoneType()
+ : Type(TypeKind::NoneType) {}
+};
+
+struct GeneratorType;
+using GeneratorTypePtr = std::shared_ptr<GeneratorType>;
+// This type represents a Generator
+struct CAFFE2_API GeneratorType : public Type {
+ static GeneratorTypePtr create() {
+ return GeneratorTypePtr(new GeneratorType()); // NOLINT(modernize-make-shared)
+ }
+ DEFINE_IS_SUBCLASS(GeneratorType);
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "Generator";
+ }
+ static const TypeKind Kind = TypeKind::GeneratorType;
+ // global singleton
+ static GeneratorTypePtr get();
+private:
+ GeneratorType()
+ : Type(TypeKind::GeneratorType) {}
+};
+
+
+struct VarType;
+using VarTypePtr = std::shared_ptr<VarType>;
+// This type represents a type variable, used in FunctionSchema
+struct VarType : public Type {
+ static VarTypePtr create(std::string name_) {
+ return VarTypePtr(new VarType(std::move(name_)));
+ }
+ DEFINE_IS_SUBCLASS(VarType);
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return name();
+ }
+ static const TypeKind Kind = TypeKind::VarType;
+ const std::string& name() const {
+ return name_;
+ }
+ bool hasFreeVariables() const override {
+ return true;
+ }
+private:
+ VarType(std::string name_)
+ : Type(TypeKind::VarType), name_(std::move(name_)) {}
+ std::string name_;
+};
+
+CAFFE2_API std::ostream& operator<<(std::ostream & out, const Type & t);
+// what is the type, ignoring extra size/shape information?
+// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
+
+inline TypePtr unshapedType(const TypePtr& type) {
+ if (type->kind() == TypeKind::TensorType ||
+ type->kind() == TypeKind::CompleteTensorType) {
+ return DynamicType::get();
+ }
+ return type->withContained(fmap(type->containedTypes(), unshapedType));
+}
+
+inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) {
+ AT_ASSERT(typ->isSubtypeOf(NumberType::get()));
+ if (typ->isSubtypeOf(IntType::get())) {
+ return CompleteTensorType::create(at::kLong, -1, {});
+ } else if (typ->isSubtypeOf(FloatType::get())) {
+ return CompleteTensorType::create(at::kFloat, -1, {});
+ } else if (typ->isSubtypeOf(BoolType::get())) {
+ return CompleteTensorType::create(at::kLong, -1, {});
+ }
+ AT_ERROR("unknown number type", typ->str());
+}
+
+inline TypePtr CompleteTensorType::fromBoolType() {
+ return CompleteTensorType::create(at::kLong, -1, {});
+}
+
+// Attempt to find the correct supertype of t1 and t2. If none is found then
+// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2,
+// then t2 will be returned (and vice versa).
+// Two different tensortypes will return dynamic.
+// Currently we chose not to support returning a NumberType for a float & int
+// input because of a lack of operator support for NumberType
+CAFFE2_API c10::optional<TypePtr> unifyTypes(
+ const TypePtr& t1,
+ const TypePtr& t2);
+
+template <typename T>
+TypePtr getTypePtr() {
+#define TYPE_STR(Type) #Type, " ",
+ AT_ERROR(
+ "Type ",
+ c10::demangle_type<T>(),
+ " could not be converted to any of the known types { ",
+ C10_FORALL_TYPES(TYPE_STR) "}");
+#undef TYPE_STR
+ return nullptr;
+}
+
+template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); }
+template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); }
+template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); }
+template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); }
+template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); }
+template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); }
+template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
+template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); }
+
+CAFFE2_API TypePtr inferTypeFrom(const IValue& value);
+
+struct CAFFE2_API TypeMatchError : public std::exception {
+ TypeMatchError(std::string msg_)
+ : msg_(std::move(msg_)) {}
+ const char * what() const noexcept override {
+ return msg_.c_str();
+ }
+private:
+ std::string msg_;
+};
+using TypeEnv = std::unordered_map<std::string, TypePtr>;
+CAFFE2_API TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv & type_env);
+CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env);
+
+} // namespace c10
diff --git a/torch/csrc/jit/type.cpp b/aten/src/ATen/core/type.cpp
index 3ae2879b31..77191fdd52 100644
--- a/torch/csrc/jit/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -1,17 +1,15 @@
-#include "torch/csrc/jit/type.h"
-
-#include "torch/csrc/jit/assertions.h"
+#include <ATen/core/jit_type.h>
#include <iostream>
-namespace torch { namespace jit {
+namespace c10 {
std::ostream& operator<<(std::ostream & out, const Type & t) {
if(auto value = t.cast<CompleteTensorType>()) {
out << at::toString(value->scalarType()) << "(";
auto& sizes = value->sizes();
auto& strides = value->strides();
- JIT_ASSERT(sizes.size() == strides.size());
+ AT_ASSERT(sizes.size() == strides.size());
for (size_t i = 0; i < sizes.size(); i++) {
if (i > 0) {
out << ", ";
@@ -260,7 +258,7 @@ TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) {
}
// change return types like List[List[t]] into List[List[int]]
-TORCH_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map<std::string, TypePtr>& type_env) {
+CAFFE2_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map<std::string, TypePtr>& type_env) {
if(!type->hasFreeVariables())
return type;
@@ -276,4 +274,4 @@ TORCH_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map<std::string
}
}
-}} // namespace torch::jit
+} // namespace c10
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 3b879b37c9..ebebaa3388 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -193,7 +193,6 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
- ${TORCH_SRC_DIR}/csrc/jit/type.cpp
${TORCH_SRC_DIR}/csrc/torch.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
${TORCH_SRC_DIR}/csrc/utils/variadic.cpp
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
index 40c2dc9d9d..0ea8ce0bc8 100644
--- a/torch/csrc/autograd/python_variable_indexing.cpp
+++ b/torch/csrc/autograd/python_variable_indexing.cpp
@@ -106,12 +106,12 @@ static Variable applySelect(const Variable& self, int64_t dim, int64_t index) {
return self.select(dim, index);
}
-static Variable sequenceToVariable(const Type& type, PyObject* seq) {
+static Variable sequenceToVariable(const at::Type& type, PyObject* seq) {
auto& idx_type = type.toScalarType(kLong);
return torch::utils::legacy_new_from_data(idx_type, c10::nullopt, seq);
}
-static Variable valueToTensor(const Type & type, PyObject* value) {
+static Variable valueToTensor(const at::Type & type, PyObject* value) {
if (THPVariable_Check(value)) {
return reinterpret_cast<THPVariable*>(value)->cdata;
}
diff --git a/torch/csrc/jit/function_schema.h b/torch/csrc/jit/function_schema.h
index 9351f272be..a02c1da32e 100644
--- a/torch/csrc/jit/function_schema.h
+++ b/torch/csrc/jit/function_schema.h
@@ -2,8 +2,10 @@
#include "ATen/ATen.h"
#include "torch/csrc/jit/type.h"
+#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/ivalue.h"
#include "torch/csrc/jit/alias_info.h"
+#include "torch/csrc/jit/assertions.h"
namespace torch { namespace jit {
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index fee47eb9d5..32ba16b35f 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -16,6 +16,8 @@
namespace torch { namespace jit {
+using c10::Type;
+
std::string getPythonName(const PyObject* obj_) {
AutoGIL gil;
PyObject* obj = const_cast<PyObject*>(obj_);
@@ -413,6 +415,7 @@ void initPythonIRBindings(PyObject * module_) {
})
;
+ using ::c10::Type;
py::class_<Type,std::shared_ptr<Type>>(m,"Type")
.def("__repr__",[](Type & t) {
return t.python_str();
@@ -479,7 +482,7 @@ void initPythonIRBindings(PyObject * module_) {
.def("isSubtypeOf", [](std::shared_ptr<Type>& self, std::shared_ptr<Type> other) {
return self->isSubtypeOf(other);
})
- .def_static("inferFrom", inferTypeFrom);
+ .def_static("inferFrom", c10::inferTypeFrom);
py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType")
.def_static("get", &NumberType::get);
diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h
index 8a21ad11d5..384957da61 100644
--- a/torch/csrc/jit/type.h
+++ b/torch/csrc/jit/type.h
@@ -1,933 +1,21 @@
-#pragma once
-
-#include "torch/csrc/jit/ivalue.h"
-#include "torch/csrc/jit/assertions.h"
-#include "torch/csrc/jit/interned_strings.h"
-#include "torch/csrc/WindowsTorchApiMacro.h"
-#include "torch/csrc/utils/functional.h"
-
-#include <ATen/ATen.h>
-
-#include <memory>
-#include <iostream>
-#include <type_traits>
+#include <ATen/core/jit_type.h>
namespace torch { namespace jit {
-#define TH_FORALL_TYPES(_) \
-_(DynamicType) \
-_(TensorType) \
-_(CompleteTensorType) \
-_(UndefinedTensorType) \
-_(TupleType) \
-_(ListType) \
-_(NumberType) \
-_(FloatType) \
-_(FutureType) \
-_(IntType) \
-_(NoneType) \
-_(StringType) \
-_(GeneratorType) \
-_(BoolType) \
-_(OptionalType) \
-_(VarType) \
-
-enum class TypeKind {
-#define DEFINE_TYPE(T) T,
- TH_FORALL_TYPES(DEFINE_TYPE)
-#undef DEFINE_TYPE
-};
-
-#define DEFINE_IS_SUBCLASS(_kind) \
- bool isSubclass(const TypeKind kind) const override { \
- return kind == TypeKind::_kind; \
- }
-
-struct Type;
-using TypePtr = std::shared_ptr<Type>;
-
-struct TORCH_API Type : std::enable_shared_from_this<Type> {
-private:
- TypeKind kind_;
- template<typename T>
- static std::shared_ptr<T> sliceType(std::shared_ptr<const T> ptr) {
- auto result = std::make_shared<typename std::remove_const<T>::type>(*ptr);
- // XXX: the line above will correctly slice the struct, and make its runtype
- // type exactly equal to T. However, kind_ is a field of Type, so it will simply
- // be copied, and we need to fix it in here to match the dynamic type.
- result->kind_ = T::Kind;
- return result;
- }
-
-protected:
- Type(TypeKind kind)
- : kind_(kind) {}
-
-public:
- virtual bool operator==(const Type& rhs) const = 0;
-
- // subtyping relation. By default, we return true for the case
- // when the type is exactly equal
- virtual bool isSubtypeOf(const TypePtr rhs) const {
- return *this == *rhs;
- }
-
- // If this class can be cast to the kind passed in
- // This removes the need for RTTI
- virtual bool isSubclass(const TypeKind kind) const = 0;
-
- // How this type will appear in FunctionSchema declarations
- virtual std::string str() const = 0;
-
- // How this type will appear as if it were a type annotation in Python
- // which is sometimes different than how it appears in declarations (e.g. int[] vs List[int])
- virtual std::string python_str() const {
- return str();
- }
-
- TypeKind kind() const {
- return kind_;
- }
-
- virtual bool requires_grad() const { return false; }
-
- // Dynamically cast this object to the subclass indicated by the
- // template variable, returning nullptr if the cast is invalid.
- // NOTE: if the cast succeeds, but the casted kind is not the
- // run-time kind of the type, we also slice the structure, so
- // that assignments of those types to values don't accidentally
- // inherit more detailed information from subclasses.
- template<typename T>
- std::shared_ptr<T> cast() {
- std::shared_ptr<T> r = nullptr;
- if (isSubclass(T::Kind)) {
- r = std::static_pointer_cast<T>(shared_from_this());
- }
- if (!r || T::Kind == kind()) {
- return r;
- } else {
- return sliceType<T>(r);
- }
- }
- template<typename T>
- std::shared_ptr<const T> cast() const {
- std::shared_ptr<const T> r = nullptr;
- if (isSubclass(T::Kind)) {
- r = std::static_pointer_cast<const T>(shared_from_this());
- }
- if (!r || T::Kind == kind()) {
- return r;
- } else {
- return sliceType<T>(r);
- }
- }
- template<typename T>
- std::shared_ptr<T> expect() {
- auto r = cast<T>();
- JIT_ASSERT(r);
- return r;
- }
- template<typename T>
- std::shared_ptr<const T> expect() const {
- auto r = cast<const T>();
- JIT_ASSERT(r);
- return r;
- }
- virtual ~Type() = default;
- virtual bool hasFreeVariables() const {
- return false;
- }
- // list of types this type contains, e.g. for a List then element type of a list
- // for a tuple, the types of the tuple elements
- virtual at::ArrayRef<TypePtr> containedTypes() const {
- return {};
- }
- // create a new version of this type, replacing its contained types with
- // contained_types
- TypePtr withContained(std::vector<TypePtr> contained_types) {
- auto current_contained = containedTypes();
- JIT_ASSERT(current_contained.size() == contained_types.size());
- if(current_contained.equals(contained_types)) {
- return shared_from_this();
- }
- return createWithContained(std::move(contained_types));
- }
- // per-type constructor, you only need to override this if the containedTypes()
- // is not empty
- virtual TypePtr createWithContained(std::vector<TypePtr> contained_types) const {
- AT_ERROR("type with contained types did not overload createWithContained: ", str());
- }
-};
-
-inline bool operator!=(const Type & lhs, const Type & rhs) {
- return !(lhs == rhs);
-}
-
-struct OptionalType;
-using OptionalTypePtr = std::shared_ptr<OptionalType>;
-// This type represents an optional type, for each element type.
-struct OptionalType: public Type {
- static OptionalTypePtr create(TypePtr element) {
- return OptionalTypePtr(new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(OptionalType);
- bool operator==(const Type& rhs) const override {
- if(auto rhs_ = rhs.cast<OptionalType>()) {
- return *getElementType() == *rhs_->getElementType();
- }
- return false;
- }
- bool requires_grad() const override {
- return elem->requires_grad();
- }
-
- bool isSubtypeOf(const TypePtr rhs) const override {
- if(auto rhs_ = rhs->cast<OptionalType>()) {
- return getElementType()->isSubtypeOf(rhs_->getElementType());
- }
- return false;
- }
-
- std::string str() const override {
- std::stringstream ss;
- ss << getElementType()->str() << "?";
- return ss.str();
- }
- std::string python_str() const override {
- std::stringstream ss;
- ss << "Optional[" << getElementType()->python_str() << "]";
- return ss.str();
- }
- TypePtr getElementType() const {
- return elem;
- }
- bool hasFreeVariables() const override {
- return has_free_variables_;
- }
-
- static const TypeKind Kind = TypeKind::OptionalType;
-private:
- OptionalType(TypePtr elem)
- : Type(TypeKind::OptionalType)
- , elem(std::move(elem))
- , has_free_variables_(getElementType()->hasFreeVariables()) {}
- TypePtr elem;
- bool has_free_variables_;
-
-};
-
-struct DynamicType;
-using DynamicTypePtr = std::shared_ptr<DynamicType>;
-// This type represents a single Tensor, with an unknown shape.
-struct TORCH_API DynamicType : public Type {
- static DynamicTypePtr create() {
- return DynamicTypePtr(new DynamicType()); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(DynamicType);
-
- bool requires_grad() const override { return true; }
-
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- std::string str() const override {
- return "Tensor";
- }
- static const TypeKind Kind = TypeKind::DynamicType;
- // global singleton
- static DynamicTypePtr get();
-private:
- DynamicType()
- : Type(TypeKind::DynamicType) {}
-};
-
-struct UndefinedTensorType;
-using UndefinedTensorTypePtr = std::shared_ptr<UndefinedTensorType>;
-// This type represents an undefined tensor.
-struct TORCH_API UndefinedTensorType : public Type {
- static const TypeKind Kind = TypeKind::UndefinedTensorType;
- static UndefinedTensorTypePtr create() {
- return UndefinedTensorTypePtr(new UndefinedTensorType()); // NOLINT(modernize-make-shared)
- }
-
- DEFINE_IS_SUBCLASS(UndefinedTensorType);
-
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- bool isSubtypeOf(const TypePtr rhs) const override {
- return rhs->kind() == TypeKind::DynamicType ||
- rhs->kind() == TypeKind::UndefinedTensorType;
- }
- std::string str() const override {
- return "UndefinedTensor";
- }
- static UndefinedTensorTypePtr get();
-protected:
- UndefinedTensorType(): Type(TypeKind::UndefinedTensorType) {}
-};
-
-struct TensorType;
-using TensorTypePtr = std::shared_ptr<TensorType>;
-// This type represents a single Tensor with a specific size
-struct TORCH_API TensorType : public Type {
- static const TypeKind Kind = TypeKind::TensorType;
- template<typename ... T>
- static TensorTypePtr create( T&& ... all ) {
- return TensorTypePtr(new TensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
- }
-
- at::ScalarType scalarType() const { return scalar_type_; }
- int device() const { return device_; }
- int dim() const { return dim_; }
- bool requires_grad() const override { return requires_grad_; }
-
- TensorTypePtr toScalarType(at::ScalarType type){
- auto t = TensorType::create(*this);
- t->scalar_type_ = type;
- return t;
- }
- TensorTypePtr withDim(int new_dim) {
- auto t = TensorType::create(*this);
- t->dim_ = new_dim;
- return t;
- }
- TensorTypePtr withRequiresGrad(bool req) {
- auto t = TensorType::create(*this);
- t->requires_grad_ = req;
- return t;
- }
-
- bool operator==(const Type& rhs) const override {
- if (rhs.kind() != TypeKind::TensorType)
- return false;
- auto rt = rhs.expect<TensorType>();
- return scalarType() == rt->scalarType() &&
- device() == rt->device() &&
- dim() == rt->dim();
- }
- bool isSubtypeOf(const TypePtr rhs) const override {
- if (rhs->kind() == TypeKind::DynamicType)
- return true;
- return rhs->kind() == TypeKind::TensorType && *this == *rhs;
- }
- bool isSubclass(const TypeKind kind) const override {
- return kind == TypeKind::DynamicType ||
- kind == TypeKind::TensorType;
- }
- std::string str() const override {
- // str is used for user-facing error messages, where we
- // don't want to reveal underlying size information.
- return "Tensor";
- }
-
-protected:
- TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType)
- : TensorType(tensor.type().scalarType(),
- tensor.is_cuda() ? tensor.get_device() : -1,
- tensor.dim(),
- tensor.is_variable() && tensor.requires_grad(),
- kind) {}
- TensorType(at::ScalarType scalar_type, int device, int dim, bool requires_grad=true, TypeKind kind=TypeKind::TensorType)
- : Type(kind)
- , scalar_type_(scalar_type)
- , requires_grad_(at::isFloatingType(scalar_type) && requires_grad)
- , device_(device)
- , dim_(dim) {}
-
- at::ScalarType scalar_type_;
- bool requires_grad_;
- int device_;
- int dim_;
-};
-
-struct CompleteTensorType;
-using CompleteTensorTypePtr = std::shared_ptr<CompleteTensorType>;
-// This type represents a single Tensor with a specific size
-struct TORCH_API CompleteTensorType : public TensorType {
- template<typename ... T>
- static CompleteTensorTypePtr create( T&& ... all ) {
- return CompleteTensorTypePtr(new CompleteTensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
- }
-
- // overloaded create variadic template argument as it could not distinguish initializer list
- static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes) {
- return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes)); // NOLINT(modernize-make-shared)
- }
- static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) {
- return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared)
- }
-
- static const TypeKind Kind = TypeKind::CompleteTensorType;
-
- const std::vector<int64_t>& sizes() const { return sizes_; }
- const std::vector<int64_t>& strides() const { return strides_; }
-
- TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const {
- return CompleteTensorType::create(scalar_type_, device_, sizes, strides);
- }
-
- TypePtr withSizes(at::IntList sizes) const {
- return withSizesStrides(sizes, CompleteTensorType::contiguousStridesOf(sizes));
- }
-
- CompleteTensorTypePtr contiguous() const {
- auto t = CompleteTensorType::create(*this);
- t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_);
- return t;
- }
-
- CompleteTensorTypePtr toScalarType(at::ScalarType type){
- auto t = CompleteTensorType::create(*this);
- t->scalar_type_ = type;
- return t;
- }
-
- bool operator==(const Type& rhs) const override {
- if(rhs.kind() != kind())
- return false;
- auto rt = rhs.expect<CompleteTensorType>();
- return scalarType() == rt->scalarType() &&
- sizes() == rt->sizes() &&
- strides() == rt->strides() &&
- device() == rt->device();
- }
- bool isSubtypeOf(const TypePtr rhs) const override {
- if (rhs->kind() == TypeKind::DynamicType)
- return true;
- if (rhs->kind() == TypeKind::TensorType)
- return *expect<TensorType>() == *rhs;
- return *this == *rhs;
- }
- bool isSubclass(const TypeKind kind) const override {
- return kind == TypeKind::DynamicType ||
- kind == TypeKind::TensorType ||
- kind == TypeKind::CompleteTensorType;
- }
- std::string str() const override {
- // str is used for user-facing error messages, where we
- // don't want to reveal underlying size information.
- return "Tensor";
- }
- bool numel() const {
- size_t prod = 1;
- for(auto s : sizes()) {
- prod *= s;
- }
- return prod;
- }
- static TypePtr fromNumberType(TypePtr typ);
- static TypePtr fromBoolType();
-
-private:
- CompleteTensorType(const at::Tensor& tensor)
- : TensorType(tensor, TypeKind::CompleteTensorType)
- , sizes_(tensor.sizes().vec())
- , strides_(tensor.strides().vec()) {}
- CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, bool requires_grad=true)
- : CompleteTensorType(scalar_type, device, sizes, CompleteTensorType::contiguousStridesOf(sizes), requires_grad) {}
- CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides, bool requires_grad=true)
- : TensorType(scalar_type, device, sizes.size(), requires_grad, TypeKind::CompleteTensorType)
- , sizes_(sizes.vec())
- , strides_(strides.vec()) {}
-
- static std::vector<int64_t> contiguousStridesOf(at::IntList sizes) {
- std::vector<int64_t> strides(sizes.size());
- if(sizes.empty()) // zero-dim case
- return strides;
- strides.back() = 1;
- for(size_t i = strides.size() - 1; i > 0; i--) {
- strides[i-1] = strides[i] * sizes[i];
- }
- return strides;
- }
-
- std::vector<int64_t> sizes_;
- std::vector<int64_t> strides_;
-};
-
-// common base for all types that have a single sub element
-// e.g. Future[T], Option[T], List[T]
-template<TypeKind K, typename T>
-struct SingleElementType : public Type {
- static const TypeKind Kind = K;
- TypePtr getElementType() const {
- return elem;
- }
- bool hasFreeVariables() const override {
- return has_free_variables_;
- }
- at::ArrayRef<TypePtr> containedTypes() const override {
- return elem;
- }
- bool requires_grad() const override {
- return elem->requires_grad();
- }
- bool operator==(const Type& rhs) const override {
- if(auto rhs_ = rhs.cast<T>()) {
- return *getElementType() == *rhs_->getElementType();
- }
- return false;
- }
-protected:
- SingleElementType(TypePtr elem)
- : Type(Kind)
- , elem(std::move(elem))
- , has_free_variables_(getElementType()->hasFreeVariables()) {}
-private:
- TypePtr elem;
- bool has_free_variables_;
-};
-
-struct ListType;
-using ListTypePtr = std::shared_ptr<ListType>;
-struct TORCH_API ListType : public SingleElementType<TypeKind::ListType, ListType> {
- // It's not exactly a singleton, but there should be exactly once instance of
- // List[T] for every T
- friend struct Type;
- template<typename ... T>
- static ListTypePtr create( T&& ... all ) {
- return ListTypePtr(new ListType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(ListType);
- std::string str() const override {
- std::stringstream ss;
- ss << getElementType()->str() << "[]";
- return ss.str();
- }
- std::string python_str() const override {
- std::stringstream ss;
- ss << "List[" << getElementType()->python_str() << "]";
- return ss.str();
- }
- TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
- return create(contained_types.at(0));
- }
- // common cast List[Tensor]
- static ListTypePtr ofTensors();
- static ListTypePtr ofInts();
- static ListTypePtr ofFloats();
- static ListTypePtr ofBools();
-private:
- using SingleElementType::SingleElementType;
-};
-
-struct FutureType;
-using FutureTypePtr = std::shared_ptr<FutureType>;
-
-struct TORCH_API FutureType : public Type {
- friend struct Type;
- template<typename ... T>
- static FutureTypePtr create(TypePtr elem) {
- return FutureTypePtr(new FutureType(std::move(elem))); // NOLINT(modernize-make-shared)
- }
-
- DEFINE_IS_SUBCLASS(FutureType);
-
- bool operator==(const Type& rhs) const override {
- if (auto rhs_ = rhs.cast<FutureType>()) {
- return *getElementType() == *rhs_->getElementType();
- }
- return false;
- }
- bool requires_grad() const override {
- return elem->requires_grad();
- }
- std::string str() const override {
- std::stringstream ss;
- ss << "Future(" << getElementType()->str() << ")";
- return ss.str();
- }
- std::string python_str() const override {
- std::stringstream ss;
- ss << "Future[" << getElementType()->python_str() << "]";
- return ss.str();
- }
- TypePtr getElementType() const {
- return elem;
- }
- bool hasFreeVariables() const override {
- return has_free_variables_;
- }
-
- static const TypeKind Kind = TypeKind::FutureType;
-private:
- FutureType(TypePtr elem)
- : Type(TypeKind::FutureType)
- , elem(std::move(elem))
- , has_free_variables_(getElementType()->hasFreeVariables()) {}
- TypePtr elem;
- bool has_free_variables_;
-};
-
-struct TupleType;
-using TupleTypePtr = std::shared_ptr<TupleType>;
-// This type represents a Tuple
-struct TORCH_API TupleType : public Type {
- static TupleTypePtr create(std::vector<TypePtr> types) {
- return TupleTypePtr(new TupleType( std::move(types) )); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(TupleType);
- at::ArrayRef<TypePtr> elements() const {
- return elements_;
- }
- bool operator==(const Type& rhs) const override {
- return compare(rhs, [](const TypePtr a, const TypePtr b) {
- return *a == *b;
- });
- }
- bool isSubtypeOf(const TypePtr rhs) const override {
- // co-variant rules for tuples
- return compare(*rhs, [](const TypePtr a, const TypePtr b) {
- return a->isSubtypeOf(b);
- });
- }
- bool requires_grad() const override {
- return std::any_of(elements_.begin(), elements_.end(),
- [](const TypePtr& ptr) { return ptr->requires_grad(); });
- }
- std::string str() const override {
- std::stringstream ss;
- ss << "(";
- for(size_t i = 0; i < elements().size(); ++i) {
- if(i > 0)
- ss << ", ";
- ss << elements()[i]->str();
- }
- ss << ")";
- return ss.str();
- }
- std::string python_str() const override {
- std::stringstream ss;
- ss << "Tuple[";
- for(size_t i = 0; i < elements().size(); ++i) {
- if(i > 0)
- ss << ", ";
- ss << elements()[i]->python_str();
- }
- ss << "]";
- return ss.str();
- }
- bool hasFreeVariables() const override {
- return has_free_variables_;
- }
-
- at::ArrayRef<TypePtr> containedTypes() const override {
- return elements_;
- }
- TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
- return create(std::move(contained_types));
- }
-
- static const TypeKind Kind = TypeKind::TupleType;
-private:
- TupleType(std::vector<TypePtr> elements_)
- : Type(TypeKind::TupleType)
- , elements_(std::move(elements_)) {
- has_free_variables_ =
- std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
- return v->hasFreeVariables();
- });
- }
-
- bool compare(const Type& rhs, std::function<bool(const TypePtr, const TypePtr)> fn) const {
- if(rhs.kind() != kind())
- return false;
- const auto & l_elements = elements();
- const auto & r_elements = rhs.cast<TupleType>()->elements();
- if(l_elements.size() != r_elements.size())
- return false;
- for(size_t i = 0; i < l_elements.size(); ++i) {
- if(!fn(l_elements[i], r_elements[i]))
- return false;
- }
- return true;
- }
- std::vector<TypePtr> elements_;
- bool has_free_variables_;
-};
-
-struct NumberType;
-using NumberTypePtr = std::shared_ptr<NumberType>;
-// This type represents a Python number
-struct TORCH_API NumberType : public Type {
- static NumberTypePtr create() {
- return NumberTypePtr(new NumberType()); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(NumberType);
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- std::string str() const override {
- return "Scalar"; // match what PythonArgParser says for clarity
- }
- static const TypeKind Kind = TypeKind::NumberType;
- // global singleton
- static NumberTypePtr get();
-private:
- NumberType()
- : Type(TypeKind::NumberType) {}
-};
-
-struct FloatType;
-using FloatTypePtr = std::shared_ptr<FloatType>;
-// This type represents a Python float number
-struct TORCH_API FloatType : public Type {
- static FloatTypePtr create() {
- return FloatTypePtr(new FloatType()); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(FloatType);
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- std::string str() const override {
- return "float";
- }
- bool isSubtypeOf(const TypePtr rhs) const override {
- if(auto rhs_ = rhs->cast<OptionalType>()) {
- return this->isSubtypeOf(rhs_->getElementType());
- }
- return *this == *rhs || rhs->kind() == TypeKind::NumberType;
- }
- static const TypeKind Kind = TypeKind::FloatType;
- // global singleton
- static FloatTypePtr get();
-private:
- FloatType()
- : Type(TypeKind::FloatType) {}
-};
-
-struct IntType;
-using IntTypePtr = std::shared_ptr<IntType>;
-// This type represents a Python int number
-struct TORCH_API IntType : public Type {
- static IntTypePtr create() {
- return IntTypePtr(new IntType()); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(IntType);
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- std::string str() const override {
- return "int";
- }
- bool isSubtypeOf(const TypePtr rhs) const override {
- if(auto rhs_ = rhs->cast<OptionalType>()) {
- return this->isSubtypeOf(rhs_->getElementType());
- }
- return *this == *rhs || rhs->kind() == TypeKind::NumberType;
- }
- static const TypeKind Kind = TypeKind::IntType;
- // global singleton
- static IntTypePtr get();
-private:
- IntType()
- : Type(TypeKind::IntType) {}
-};
-
-struct BoolType;
-using BoolTypePtr = std::shared_ptr<BoolType>;
-// This node represents a Python bool value
-struct TORCH_API BoolType : public Type {
- static BoolTypePtr create( ) {
- return BoolTypePtr(new BoolType());
- }
- DEFINE_IS_SUBCLASS(BoolType);
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- std::string str() const override {
- return "bool";
- }
- bool isSubtypeOf(const TypePtr rhs) const override {
- return *this == *rhs || rhs->kind() == TypeKind::BoolType;
- }
- static const TypeKind Kind = TypeKind::BoolType;
- // global singleton
- static BoolTypePtr get();
-private:
- BoolType()
- : Type(TypeKind::BoolType) {}
-};
-
-struct StringType;
-using StringTypePtr = std::shared_ptr<StringType>;
-// This type represents a Python string
-struct TORCH_API StringType : public Type {
- static StringTypePtr create() {
- return StringTypePtr(new StringType()); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(StringType);
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- std::string str() const override {
- return "string";
- }
- bool isSubtypeOf(const TypePtr rhs) const override {
- if(auto rhs_ = rhs->cast<OptionalType>()) {
- return this->isSubtypeOf(rhs_->getElementType());
- }
- return *this == *rhs;
- }
- static const TypeKind Kind = TypeKind::StringType;
- // global singleton
- static StringTypePtr get();
-private:
- StringType()
- : Type(TypeKind::StringType) {}
-};
-
-struct NoneType;
-using NoneTypePtr = std::shared_ptr<NoneType>;
-// This type represents a Python None
-struct NoneType : public Type {
- static NoneTypePtr create() {
- return NoneTypePtr(new NoneType()); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(NoneType);
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
-
- bool isSubtypeOf(const TypePtr rhs) const override {
- return rhs->kind() == TypeKind::NoneType ||
- rhs->kind() == TypeKind::OptionalType;
- }
-
- std::string str() const override {
- return "None";
- }
- static const TypeKind Kind = TypeKind::NoneType;
- // global singleton
- static NoneTypePtr get();
-private:
- NoneType()
- : Type(TypeKind::NoneType) {}
-};
-
-struct GeneratorType;
-using GeneratorTypePtr = std::shared_ptr<GeneratorType>;
-// This type represents a Generator
-struct GeneratorType : public Type {
- static GeneratorTypePtr create() {
- return GeneratorTypePtr(new GeneratorType()); // NOLINT(modernize-make-shared)
- }
- DEFINE_IS_SUBCLASS(GeneratorType);
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- std::string str() const override {
- return "Generator";
- }
- static const TypeKind Kind = TypeKind::GeneratorType;
- // global singleton
- static GeneratorTypePtr get();
-private:
- GeneratorType()
- : Type(TypeKind::GeneratorType) {}
-};
-
-
-struct VarType;
-using VarTypePtr = std::shared_ptr<VarType>;
-// This type represents a type variable, used in FunctionSchema
-struct VarType : public Type {
- static VarTypePtr create(std::string name_) {
- return VarTypePtr(new VarType(std::move(name_)));
- }
- DEFINE_IS_SUBCLASS(VarType);
- bool operator==(const Type& rhs) const override {
- return rhs.kind() == kind();
- }
- std::string str() const override {
- return name();
- }
- static const TypeKind Kind = TypeKind::VarType;
- const std::string& name() const {
- return name_;
- }
- bool hasFreeVariables() const override {
- return true;
- }
-private:
- VarType(std::string name_)
- : Type(TypeKind::VarType), name_(std::move(name_)) {}
- std::string name_;
-};
-
-TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t);
-// what is the type, ignoring extra size/shape information?
-// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
-
-inline TypePtr unshapedType(const TypePtr& type) {
- if (type->kind() == TypeKind::TensorType ||
- type->kind() == TypeKind::CompleteTensorType) {
- return DynamicType::get();
- }
- return type->withContained(fmap(type->containedTypes(), unshapedType));
-}
-
-inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) {
- JIT_ASSERT(typ->isSubtypeOf(NumberType::get()));
- if (typ->isSubtypeOf(IntType::get())) {
- return CompleteTensorType::create(at::kLong, -1, {});
- } else if (typ->isSubtypeOf(FloatType::get())) {
- return CompleteTensorType::create(at::kFloat, -1, {});
- } else if (typ->isSubtypeOf(BoolType::get())) {
- return CompleteTensorType::create(at::kLong, -1, {});
- }
- AT_ERROR("unknown number type", typ->str());
-}
-
-inline TypePtr CompleteTensorType::fromBoolType() {
- return CompleteTensorType::create(at::kLong, -1, {});
-}
-
-// Attempt to find the correct supertype of t1 and t2. If none is found then
-// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2,
-// then t2 will be returned (and vice versa).
-// Two different tensortypes will return dynamic.
-// Currently we chose not to support returning a NumberType for a float & int
-// input because of a lack of operator support for NumberType
-TORCH_API c10::optional<TypePtr> unifyTypes(
- const TypePtr& t1,
- const TypePtr& t2);
-
-template <typename T>
-TypePtr getTypePtr() {
-#define TYPE_STR(Type) #Type, " ",
- AT_ERROR(
- "Type ",
- c10::demangle_type<T>(),
- " could not be converted to any of the known types { ",
- TH_FORALL_TYPES(TYPE_STR) "}");
-#undef TYPE_STR
- return nullptr;
-}
+#define C10_USING(T) using ::c10::T;
+ C10_FORALL_TYPES(C10_USING)
+#undef C10_USING
-template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); }
-template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); }
-template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); }
-template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); }
-template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); }
-template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); }
-template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
-template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); }
+#define C10_USING(T) using ::c10::T##Ptr;
+ C10_FORALL_TYPES(C10_USING)
+#undef C10_USING
-TORCH_API TypePtr inferTypeFrom(const IValue& value);
+using ::c10::Type;
+using ::c10::TypePtr;
+using ::c10::TypeEnv;
+using ::c10::TypeMatchError;
-struct TORCH_API TypeMatchError : public std::exception {
- TypeMatchError(std::string msg_)
- : msg_(std::move(msg_)) {}
- const char * what() const noexcept override {
- return msg_.c_str();
- }
-private:
- std::string msg_;
-};
-using TypeEnv = std::unordered_map<std::string, TypePtr>;
-TORCH_API TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv & type_env);
-TORCH_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env);
+using ::c10::getTypePtr;
+using ::c10::TypeKind;
}} // namespace torch::jit
diff --git a/torch/csrc/utils/functional.h b/torch/csrc/utils/functional.h
index af5099e7ce..7864d2a0ab 100644
--- a/torch/csrc/utils/functional.h
+++ b/torch/csrc/utils/functional.h
@@ -1,63 +1,8 @@
-#pragma once
-
-#include <vector>
-#include <ATen/ATen.h>
+#include <ATen/core/functional.h>
namespace torch {
-// The passed in function must take T by value (T), or by
-// const reference (const T&); taking T by non-const reference
-// will result in an error like:
-//
-// error: no type named 'type' in 'class std::result_of<foobar::__lambda(T)>'
-//
-// No explicit template parameters are required.
-
-// Overload for explicit function and ArrayRef
-template<typename F, typename T>
-inline auto fmap(const T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> {
- std::vector<decltype(fn(*inputs.begin()))> r;
- r.reserve(inputs.size());
- for(const auto & input : inputs)
- r.push_back(fn(input));
- return r;
-}
-
-template<typename F, typename T>
-inline auto fmap(T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> {
- std::vector<decltype(fn(*inputs.begin()))> r;
- r.reserve(inputs.size());
- for(auto & input : inputs)
- r.push_back(fn(input));
- return r;
-}
-
-// C++ forbids taking an address of a constructor, so here's a workaround...
-// Overload for constructor (R) application
-template<typename R, typename T>
-inline std::vector<R> fmap(const T& inputs) {
- std::vector<R> r;
- r.reserve(inputs.size());
- for(auto & input : inputs)
- r.push_back(R(input));
- return r;
-}
-
-template<typename F, typename T>
-inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) {
- std::vector<T> r;
- r.reserve(inputs.size());
- for(auto & input : inputs) {
- if (fn(input)) {
- r.push_back(input);
- }
- }
- return r;
-}
-
-template<typename F, typename T>
-inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) {
- return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn);
-}
+using ::c10::fmap;
+using ::c10::filter;
-}
+} // namespace torch