#pragma once #include "torch/csrc/jit/assertions.h" #include "torch/csrc/WindowsTorchApiMacro.h" #include #include namespace torch { namespace jit { template using Shared = c10::intrusive_ptr; // string struct TORCH_API ConstantString : c10::intrusive_ptr_target { private: const std::string str_; public: ConstantString(std::string str) : str_(std::move(str)) {} static c10::intrusive_ptr create(const std::string str_) { return c10::make_intrusive(str_); } const std::string & string() const { return str_; } operator const std::string & () const { return string(); } TORCH_API friend std::ostream& operator<<(std::ostream& out, const ConstantString & v); }; // non-mutable list template struct TORCH_API ConstantList : c10::intrusive_ptr_target { private: std::vector elements_; public: ConstantList(std::vector elements_) : elements_(std::move(elements_)) {} static c10::intrusive_ptr> create(std::vector elements_) { return c10::make_intrusive>(std::move(elements_)); } const std::vector& elements() const { return elements_; } operator const std::vector&() const { return elements(); } }; struct IValue; using Tuple = ConstantList; using IntList = ConstantList; using TensorList = ConstantList; using DoubleList = ConstantList; // IValue is the generic tagged union used by the interpreter to hold // all value types. // It is a 16-byte object with an 8-byte payload and an 8-byte tag. // The tag is currently 4 bytes to determine the type, and 1 byte // to mark whether that type is a subtype of c10::intrusive_ptr_target and needs // retain/release calls. #define TORCH_FORALL_TAGS(_) \ _(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) _(TensorList) struct TORCH_API IValue { IValue() : payload(0) , tag(Tag::None) , is_intrusive_ptr(false) {} IValue(const IValue& rhs) : payload(rhs.payload), tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { if (is_intrusive_ptr) { c10::raw::intrusive_ptr::incref(as_intrusive_ptr); } } IValue(IValue&& rhs) noexcept : IValue() { swap(rhs); } ~IValue() { if (is_intrusive_ptr) { c10::raw::intrusive_ptr::decref(as_intrusive_ptr); } } IValue & operator=(IValue && rhs) & noexcept { rhs.swap(*this); return *this; } IValue & operator=(IValue const & rhs) & { IValue(rhs).swap(*this); return *this; } void swap(IValue & rhs) noexcept { std::swap(payload, rhs.payload); std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); std::swap(tag, rhs.tag); } // Accessors for subtypes are arranged together below // While some of these accessors could be generated through templates, // we prefer to write them manually for clarity // Tensor IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) { // Note: the undefined tensor is not refcounted, so while it // is tagged as a tensor, is_intrusive_ptr is set to false. // This is not an optional optimization: our incref call // *will not* do the right thing when called on an // undefined tensor. as_tensor_impl = t.unsafeReleaseTensorImpl(); } bool isTensor() const { return Tag::Tensor == tag; } at::Tensor toTensor() && { JIT_ASSERT(isTensor()); at::Tensor t(as_tensor_impl, /*retain=*/false); clearToNone(); return t; } at::Tensor toTensor() const & { JIT_ASSERT(isTensor()); return at::Tensor(as_tensor_impl, /*retain=*/true); } // Tuple IValue(c10::intrusive_ptr v); bool isTuple() const { return Tag::Tuple == tag; } c10::intrusive_ptr toTuple() && { JIT_ASSERT(isTuple()); return moveToIntrusivePtr(); } c10::intrusive_ptr toTuple() const & { JIT_ASSERT(isTuple()); return toIntrusivePtr(); } // Double IValue(double d) : tag(Tag::Double), is_intrusive_ptr(false) { as_double = d; } bool isDouble() const { return Tag::Double == tag; } double toDouble() const { JIT_ASSERT(isDouble()); return as_double; } // Int IValue(int64_t i) : tag(Tag::Int), is_intrusive_ptr(false) { as_int = i; } // allow you to pass literals (3, 4) without ambiguity IValue(int32_t i) : IValue(static_cast(i)) {} IValue(bool b) : IValue(static_cast(b)) {} bool isInt() const { return Tag::Int == tag; } int64_t toInt() const { JIT_ASSERT(isInt()); return as_int; } // IntList IValue(c10::intrusive_ptr v); IValue(std::vector v); IValue(at::ArrayRef v) : IValue(std::vector(v.begin(), v.end())) {} bool isIntList() const { return Tag::IntList == tag; } c10::intrusive_ptr toIntList() && { JIT_ASSERT(isIntList()); return moveToIntrusivePtr(); } c10::intrusive_ptr toIntList() const & { JIT_ASSERT(isIntList()); return toIntrusivePtr(); } const std::vector& toIntListRef() const; const std::vector& toDoubleListRef() const; const std::vector& toTensorListRef() const; // ConstantString IValue(c10::intrusive_ptr v); IValue(const std::string& v); bool isString() const { return Tag::String == tag; } c10::intrusive_ptr toString() && { JIT_ASSERT(isString()); return moveToIntrusivePtr(); } c10::intrusive_ptr toString() const & { JIT_ASSERT(isString()); return toIntrusivePtr(); } // DoubleList IValue(c10::intrusive_ptr v); IValue(std::vector v); bool isDoubleList() const { return Tag::DoubleList == tag; } c10::intrusive_ptr toDoubleList() && { JIT_ASSERT(isDoubleList()); return moveToIntrusivePtr(); } c10::intrusive_ptr toDoubleList() const & { JIT_ASSERT(isDoubleList()); return toIntrusivePtr(); } //TensorList IValue(c10::intrusive_ptr v); IValue(std::vector v); bool isTensorList() const { return Tag::TensorList == tag; } c10::intrusive_ptr toTensorList() && { JIT_ASSERT(isTensorList()); return moveToIntrusivePtr(); } c10::intrusive_ptr toTensorList() const & { JIT_ASSERT(isTensorList()); return toIntrusivePtr(); } // None bool isNone() { return Tag::None == tag; } std::string toNone() const { return "None"; } // Scalar, which gets encoded as either an Int or a Double IValue(at::Scalar s) : IValue() { if(s.isFloatingPoint()) { *this = s.toDouble(); } else { *this = s.toLong(); } } bool isScalar() { return isDouble() || isInt(); } at::Scalar toScalar() const { if(isDouble()) return toDouble(); else if(isInt()) return toInt(); else throw std::runtime_error("IValue is not a Scalar"); } // for debugging std::string tagKind() const { switch(tag) { #define DEFINE_CASE(x) case Tag::x: return #x; TORCH_FORALL_TAGS(DEFINE_CASE) #undef DEFINE_CASE } return "Invalid Tag"; } // generic v.to() implementations // that can be used in special functions like pop/push // that use template meta-programming. // prefer the directly named methods when you can, // since they are simpler to understand // Note: if you get linker errors saying one of these is missing, // change it to ... && = delete; and you will see better error messages for why // However, we cannot commit this because some compiler versions barf on it. template T to() &&; template T to() const &; TORCH_API friend std::ostream& operator<<(std::ostream & out, const IValue & v); private: // NOTE: IValue tags are intentionally private. In the future we may encode // this value different (e.g. using NaN boxing), and this would make it more // costly to determine the tag for all types vs just determining if something // is a particular type. Instead we want clients to use the `isX` methods when // possible. If for perf. reasons you really, absolutely, must have a jump // table, then we can revisit this. enum class Tag : uint32_t { #define DEFINE_TAG(x) x, TORCH_FORALL_TAGS(DEFINE_TAG) #undef DEFINE_TAG }; template c10::intrusive_ptr moveToIntrusivePtr() { auto t = c10::intrusive_ptr::reclaim(static_cast(as_intrusive_ptr)); clearToNone(); return t; } template c10::intrusive_ptr toIntrusivePtr() const { auto r = c10::intrusive_ptr::reclaim(static_cast(as_intrusive_ptr)); auto p = r; r.release(); return p; } void clearToNone() { payload = 0; tag = Tag::None; is_intrusive_ptr = false; } union { at::TensorImpl* as_tensor_impl; c10::intrusive_ptr_target* as_intrusive_ptr; double as_double; int64_t as_int; // this type should be as big as all the other types because it will // be used to copy the union's value in certain cases int64_t payload; }; Tag tag; bool is_intrusive_ptr; }; #undef TORCH_FORALL_TAGS #define DEFINE_TO(type, method_name) \ template<> \ inline type IValue::to() && { \ return std::move(*this).method_name(); \ } \ template<> \ inline type IValue::to() const & { \ return this->method_name(); \ } DEFINE_TO(at::Tensor, toTensor) DEFINE_TO(c10::intrusive_ptr, toTuple) DEFINE_TO(double, toDouble) DEFINE_TO(int64_t, toInt) DEFINE_TO(c10::intrusive_ptr, toDoubleList) DEFINE_TO(c10::intrusive_ptr, toIntList) DEFINE_TO(c10::intrusive_ptr, toTensorList) DEFINE_TO(c10::intrusive_ptr, toString) DEFINE_TO(at::Scalar, toScalar) DEFINE_TO(bool, toInt) DEFINE_TO(std::vector, toIntListRef) DEFINE_TO(std::vector, toDoubleListRef) DEFINE_TO(std::vector, toTensorListRef) #undef DEFINE_TO inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Tuple), is_intrusive_ptr(true) { as_intrusive_ptr = v.release(); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::IntList), is_intrusive_ptr(true) { as_intrusive_ptr = v.release(); } inline IValue::IValue(std::vector v) : IValue(IntList::create(std::move(v))) {} inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::String), is_intrusive_ptr(true) { as_intrusive_ptr = v.release(); } inline IValue::IValue(const std::string& v) : IValue(ConstantString::create(v)) {} inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::DoubleList), is_intrusive_ptr(true) { as_intrusive_ptr = v.release(); } inline IValue::IValue(std::vector v) : IValue(DoubleList::create(std::move(v))) {} inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::TensorList), is_intrusive_ptr(true) { as_intrusive_ptr = v.release(); } inline IValue::IValue(std::vector v) : IValue(TensorList::create(std::move(v))) {} inline const std::vector& IValue::toIntListRef() const { return toIntList()->elements(); } inline const std::vector& IValue::toDoubleListRef() const { return toDoubleList()->elements(); } inline const std::vector& IValue::toTensorListRef() const { return toTensorList()->elements(); } }}