-rw-r--r--  aten/src/ATen/core/ivalue.cpp | 2
-rw-r--r--  aten/src/ATen/core/ivalue.h | 52
-rw-r--r--  test/expect/TestBatched.test_for.expect | 4
-rw-r--r--  test/expect/TestBatched.test_if_else.expect | 8
-rw-r--r--  test/expect/TestBatched.test_if_else_with_scalar.expect | 8
-rw-r--r--  test/expect/TestBatched.test_if_noelse.expect | 8
-rw-r--r--  test/expect/TestBatched.test_if_noelse_with_scalar.expect | 8
-rw-r--r--  test/expect/TestBatched.test_while.expect | 14
-rw-r--r--  test/expect/TestJit.test_batchnorm.expect | 4
-rw-r--r--  test/expect/TestJit.test_constant_prop_if_constant.expect | 4
-rw-r--r--  test/expect/TestJit.test_constant_prop_loop_constant.expect | 8
-rw-r--r--  test/expect/TestJit.test_constant_prop_nested.expect | 2
-rw-r--r--  test/expect/TestJit.test_conv.expect | 8
-rw-r--r--  test/expect/TestJit.test_dropout.expect | 2
-rw-r--r--  test/expect/TestJit.test_inplace_copy.expect | 17
-rw-r--r--  test/expect/TestJit.test_pretty_printer-if_one.expect | 2
-rw-r--r--  test/expect/TestJit.test_pretty_printer-if_test.expect | 2
-rw-r--r--  test/expect/TestJit.test_pretty_printer-loop_use_test.expect | 4
-rw-r--r--  test/expect/TestJit.test_pretty_printer-while_if_test.expect | 6
-rw-r--r--  test/expect/TestJit.test_pretty_printer-while_test.expect | 4
-rw-r--r--  test/expect/TestJit.test_recursive_cse.expect | 2
-rw-r--r--  test/expect/TestScript.test_if_list.expect | 2
-rw-r--r--  test/expect/TestScript.test_if_supertype.expect | 2
-rw-r--r--  test/expect/TestScript.test_index_put_trace_with_view.expect | 2
-rw-r--r--  test/expect/TestScript.test_index_put_trace_without_view.expect | 2
-rw-r--r--  test/expect/TestScript.test_logical_short_circuit.expect | 14
-rw-r--r--  test/expect/TestScript.test_loop_unroll_unused_counter.expect | 6
-rw-r--r--  test/expect/TestScript.test_loop_unrolling.expect | 6
-rw-r--r--  test/expect/TestScript.test_loop_unrolling_nested.expect | 10
-rw-r--r--  test/expect/TestScript.test_sum-1.expect | 2
-rw-r--r--  test/expect/TestScript.test_sum-2.expect | 2
-rw-r--r--  test/expect/TestScript.test_type_cast-test_float_to_int.expect (renamed from test/expect/TestScript.test_type_cast-float_to_int.expect) | 2
-rw-r--r--  test/expect/TestScript.test_type_cast-test_int_to_float.expect (renamed from test/expect/TestScript.test_type_cast-int_to_float.expect) | 0
-rw-r--r--  test/test_jit.py | 54
-rw-r--r--  tools/jit/gen_jit_dispatch.py | 2
-rw-r--r--  torch/csrc/jit/attributes.h | 1
-rw-r--r--  torch/csrc/jit/autodiff.cpp | 2
-rw-r--r--  torch/csrc/jit/constants.cpp | 19
-rw-r--r--  torch/csrc/jit/export.cpp | 2
-rw-r--r--  torch/csrc/jit/import.cpp | 2
-rw-r--r--  torch/csrc/jit/interned_strings.h | 2
-rw-r--r--  torch/csrc/jit/interpreter.cpp | 27
-rw-r--r--  torch/csrc/jit/ir.cpp | 10
-rw-r--r--  torch/csrc/jit/ir.h | 14
-rw-r--r--  torch/csrc/jit/operator.cpp | 7
-rw-r--r--  torch/csrc/jit/passes/erase_number_types.cpp | 7
-rw-r--r--  torch/csrc/jit/passes/peephole.cpp | 2
-rw-r--r--  torch/csrc/jit/passes/remove_expands.cpp | 3
-rw-r--r--  torch/csrc/jit/passes/shape_analysis.cpp | 102
-rw-r--r--  torch/csrc/jit/passes/to_batch.cpp | 24
-rw-r--r--  torch/csrc/jit/pybind_utils.h | 6
-rw-r--r--  torch/csrc/jit/python_ir.cpp | 4
-rw-r--r--  torch/csrc/jit/register_prim_ops.cpp | 82
-rw-r--r--  torch/csrc/jit/script/compiler.cpp | 9
-rw-r--r--  torch/csrc/jit/script/init.cpp | 6
-rw-r--r--  torch/csrc/jit/type.cpp | 14
-rw-r--r--  torch/csrc/jit/type.h | 37
-rw-r--r--  torch/jit/batchop.py | 10
58 files changed, 440 insertions, 227 deletions
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 5df0a5b49c..58e90516c9 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -6,9 +6,11 @@
_(Tensor) \
_(Double) \
_(Int) \
+ _(Bool) \
_(Tuple) \
_(IntList) \
_(DoubleList) \
+ _(BoolList) \
_(String) \
_(TensorList) \
_(Blob) \
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 5e210d638d..0bdd85a587 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -74,6 +74,7 @@ struct C10_EXPORT Tuple : public List<IValue> {
using IntList = List<int64_t>;
using TensorList = List<at::Tensor>;
using DoubleList = List<double>;
+using BoolList = List<bool>;
using GenericList = List<IValue>;
// IValue is the generic tagged union used by the interpreter to hold
@@ -88,9 +89,11 @@ using GenericList = List<IValue>;
_(Tensor) \
_(Double) \
_(Int) \
+ _(Bool) \
_(Tuple) \
_(IntList) \
_(DoubleList) \
+ _(BoolList) \
_(String) \
_(TensorList) \
_(Blob) \
@@ -224,8 +227,6 @@ struct CAFFE2_API IValue final {
// allow you to pass literals (3, 4) without ambiguity
IValue(int32_t i)
: IValue(static_cast<int64_t>(i)) {}
- IValue(bool b)
- : IValue(static_cast<int64_t>(b)) {}
bool isInt() const { return Tag::Int == tag; }
@@ -234,6 +235,17 @@ struct CAFFE2_API IValue final {
return payload.as_int;
}
+ // Bool
+ IValue(bool b)
+ : tag(Tag::Bool), is_intrusive_ptr(false) {
+ payload.as_bool = b;
+ }
+ bool isBool() const { return Tag::Bool == tag; }
+ bool toBool() const {
+ AT_ASSERT(isBool());
+ return payload.as_bool;
+ }
+
// IntList
IValue(c10::intrusive_ptr<IntList> v);
IValue(std::vector<int64_t> v);
@@ -251,6 +263,7 @@ struct CAFFE2_API IValue final {
const std::vector<int64_t>& toIntListRef() const;
const std::vector<double>& toDoubleListRef() const;
+ const std::vector<bool>& toBoolListRef() const;
const std::vector<at::Tensor>& toTensorListRef() const;
const std::vector<IValue>& toGenericListRef() const;
@@ -280,6 +293,19 @@ struct CAFFE2_API IValue final {
return toIntrusivePtr<DoubleList>();
}
+ // BoolList
+ IValue(c10::intrusive_ptr<BoolList> v);
+ IValue(std::vector<bool> v);
+ bool isBoolList() const { return Tag::BoolList == tag; }
+ c10::intrusive_ptr<BoolList> toBoolList() && {
+ AT_ASSERT(isBoolList());
+ return moveToIntrusivePtr<BoolList>();
+ }
+ c10::intrusive_ptr<BoolList> toBoolList() const & {
+ AT_ASSERT(isBoolList());
+ return toIntrusivePtr<BoolList>();
+ }
+
//TensorList
IValue(c10::intrusive_ptr<TensorList> v);
IValue(std::vector<at::Tensor> v);
@@ -323,15 +349,16 @@ struct CAFFE2_API IValue final {
}
}
bool isScalar() {
- return isDouble() || isInt();
+ return isDouble() || isInt() || isBool();
}
at::Scalar toScalar() const {
if(isDouble())
return toDouble();
else if(isInt())
return toInt();
- else
- throw std::runtime_error("IValue is not a Scalar");
+ else if (isBool())
+ return int(toBool());
+ throw std::runtime_error("IValue is not a Scalar");
}
// for debugging
@@ -396,6 +423,7 @@ struct CAFFE2_API IValue final {
union {
int64_t as_int;
double as_double;
+ bool as_bool;
c10::intrusive_ptr_target* as_intrusive_ptr;
World as_world;
} payload;
@@ -419,15 +447,16 @@ DEFINE_TO(at::Tensor, toTensor)
DEFINE_TO(c10::intrusive_ptr<Tuple>, toTuple)
DEFINE_TO(double, toDouble)
DEFINE_TO(int64_t, toInt)
+DEFINE_TO(bool, toBool)
DEFINE_TO(c10::intrusive_ptr<DoubleList>, toDoubleList)
DEFINE_TO(c10::intrusive_ptr<IntList>, toIntList)
DEFINE_TO(c10::intrusive_ptr<TensorList>, toTensorList)
DEFINE_TO(c10::intrusive_ptr<GenericList>, toGenericList)
DEFINE_TO(c10::intrusive_ptr<ConstantString>, toString)
DEFINE_TO(at::Scalar, toScalar)
-DEFINE_TO(bool, toInt)
DEFINE_TO(std::vector<int64_t>, toIntListRef)
DEFINE_TO(std::vector<double>, toDoubleListRef)
+DEFINE_TO(std::vector<bool>, toBoolListRef)
DEFINE_TO(std::vector<at::Tensor>, toTensorListRef)
DEFINE_TO(std::vector<IValue>, toGenericListRef)
DEFINE_TO(World, toWorld)
@@ -490,6 +519,13 @@ inline IValue::IValue(c10::intrusive_ptr<DoubleList> v)
inline IValue::IValue(std::vector<double> v)
: IValue(DoubleList::create(std::move(v))) {}
+inline IValue::IValue(c10::intrusive_ptr<BoolList> v)
+: tag(Tag::BoolList), is_intrusive_ptr(true) {
+ payload.as_intrusive_ptr = v.release();
+}
+inline IValue::IValue(std::vector<bool> v)
+: IValue(BoolList::create(std::move(v))) {}
+
inline IValue::IValue(c10::intrusive_ptr<TensorList> v)
: tag(Tag::TensorList), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
@@ -517,6 +553,10 @@ inline const std::vector<at::Tensor>& IValue::toTensorListRef() const {
return toTensorList()->elements();
}
+inline const std::vector<bool>& IValue::toBoolListRef() const {
+ return toBoolList()->elements();
+}
+
inline const std::vector<IValue>& IValue::toGenericListRef() const {
return toGenericList()->elements();
}
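
For reference, here is a minimal usage sketch of the boolean IValue API added in the ivalue.h hunk above. It is not part of the patch: the enclosing function is illustrative, and it assumes a using-directive for whichever namespace IValue lives in at this revision.

#include <ATen/core/ivalue.h>
#include <vector>

// Assumes something like `using namespace torch::jit;` so that IValue, AT_ASSERT,
// and at::Scalar are in scope (an assumption about this revision, not part of the patch).
void bool_ivalue_example() {
  // A bool now gets its own tag (Tag::Bool) instead of being widened to Int.
  IValue flag(true);
  AT_ASSERT(flag.isBool());
  bool b = flag.toBool();

  // std::vector<bool> is wrapped as a BoolList and can be read back by reference.
  IValue flags(std::vector<bool>{true, false, true});
  const std::vector<bool>& elems = flags.toBoolListRef();

  // Bools still participate in Scalar conversion: toScalar() returns int(toBool()).
  at::Scalar s = flag.toScalar();
  (void)b; (void)elems; (void)s;
}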
diff --git a/test/expect/TestBatched.test_for.expect b/test/expect/TestBatched.test_for.expect
index 8932957402..6b15b3e799 100644
--- a/test/expect/TestBatched.test_for.expect
+++ b/test/expect/TestBatched.test_for.expect
@@ -5,7 +5,7 @@ graph(%x.1_data : Dynamic
%y_mask : Dynamic
%y_dims : Dynamic) {
%6 : int = prim::Constant[value=10]()
- %7 : int = prim::Constant[value=1]()
+ %7 : bool = prim::Constant[value=1]()
%x : Dynamic, %9 : Dynamic, %10 : Dynamic = prim::Loop(%6, %7, %x.1_data, %x.1_mask, %x.1_dims)
block0(%loop_num : int, %5_data : Dynamic, %5_mask : Dynamic, %5_dims : Dynamic) {
%15 : int = prim::Constant[value=1]()
@@ -14,7 +14,7 @@ graph(%x.1_data : Dynamic
%data.1 : Dynamic = aten::add(%5_data, %y_data, %alpha)
%mask : Dynamic = aten::mul(%5_mask, %y_mask)
%dims : Dynamic = aten::__or__(%5_dims, %y_dims)
- %21 : int = prim::Constant[value=1]()
+ %21 : bool = prim::Constant[value=1]()
%data : Dynamic = aten::where(%mask, %data.1, %5_data)
-> (%21, %data, %mask, %dims)
}
diff --git a/test/expect/TestBatched.test_if_else.expect b/test/expect/TestBatched.test_if_else.expect
index ddb6763469..86475a61e3 100644
--- a/test/expect/TestBatched.test_if_else.expect
+++ b/test/expect/TestBatched.test_if_else.expect
@@ -7,7 +7,7 @@ graph(%a.1_data : Dynamic
%6 : Dynamic = aten::gt(%a.1_data, %b_data)
%7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
- %9 : int = prim::TensorToNum(%6)
+ %9 : bool = prim::TensorToBool(%6)
%10 : int = prim::Constant[value=1]()
%11 : Long() = prim::NumToTensor(%10)
%alpha.1 : float = prim::TensorToNum(%11)
@@ -24,17 +24,17 @@ graph(%a.1_data : Dynamic
%23 : Dynamic = aten::type_as(%7, %6)
%cond_mask.1 : Dynamic = aten::mul(%6, %23)
%25 : int = aten::dim(%cond_mask.1)
- %26 : int = aten::eq(%25, %22)
+ %26 : bool = aten::eq(%25, %22)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%26)
block0() {
%30 : int = aten::dim(%data.1)
%31 : int = aten::sub(%30, %22)
- %32 : int = prim::Constant[value=1]()
+ %32 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%31, %32, %cond_mask.1)
block0(%_ : int, %35 : Dynamic) {
%36 : int = aten::dim(%35)
%data.2 : Dynamic = aten::unsqueeze(%35, %36)
- %38 : int = prim::Constant[value=1]()
+ %38 : bool = prim::Constant[value=1]()
-> (%38, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
diff --git a/test/expect/TestBatched.test_if_else_with_scalar.expect b/test/expect/TestBatched.test_if_else_with_scalar.expect
index 08057e4d10..cbe4a9f05b 100644
--- a/test/expect/TestBatched.test_if_else_with_scalar.expect
+++ b/test/expect/TestBatched.test_if_else_with_scalar.expect
@@ -8,7 +8,7 @@ graph(%a.1_data : Dynamic
%7 : Float() = prim::NumToTensor(%6)
%other : float = prim::TensorToNum(%7)
%9 : Dynamic = aten::gt(%a.1_data, %other)
- %10 : int = prim::TensorToNum(%9)
+ %10 : bool = prim::TensorToBool(%9)
%11 : int = prim::Constant[value=1]()
%12 : Long() = prim::NumToTensor(%11)
%alpha.1 : float = prim::TensorToNum(%12)
@@ -25,17 +25,17 @@ graph(%a.1_data : Dynamic
%24 : Dynamic = aten::type_as(%a.1_mask, %9)
%cond_mask.1 : Dynamic = aten::mul(%9, %24)
%26 : int = aten::dim(%cond_mask.1)
- %27 : int = aten::eq(%26, %23)
+ %27 : bool = aten::eq(%26, %23)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27)
block0() {
%31 : int = aten::dim(%data.1)
%32 : int = aten::sub(%31, %23)
- %33 : int = prim::Constant[value=1]()
+ %33 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%32, %33, %cond_mask.1)
block0(%_ : int, %36 : Dynamic) {
%37 : int = aten::dim(%36)
%data.2 : Dynamic = aten::unsqueeze(%36, %37)
- %39 : int = prim::Constant[value=1]()
+ %39 : bool = prim::Constant[value=1]()
-> (%39, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
diff --git a/test/expect/TestBatched.test_if_noelse.expect b/test/expect/TestBatched.test_if_noelse.expect
index a408563916..8dbed77571 100644
--- a/test/expect/TestBatched.test_if_noelse.expect
+++ b/test/expect/TestBatched.test_if_noelse.expect
@@ -7,7 +7,7 @@ graph(%a.1_data : Dynamic
%6 : Dynamic = aten::gt(%a.1_data, %b_data)
%7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
- %9 : int = prim::TensorToNum(%6)
+ %9 : bool = prim::TensorToBool(%6)
%10 : int = prim::Constant[value=1]()
%11 : Long() = prim::NumToTensor(%10)
%alpha : float = prim::TensorToNum(%11)
@@ -18,17 +18,17 @@ graph(%a.1_data : Dynamic
%17 : Dynamic = aten::type_as(%7, %6)
%cond_mask.1 : Dynamic = aten::mul(%6, %17)
%19 : int = aten::dim(%cond_mask.1)
- %20 : int = aten::eq(%19, %16)
+ %20 : bool = aten::eq(%19, %16)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%20)
block0() {
%24 : int = aten::dim(%data.1)
%25 : int = aten::sub(%24, %16)
- %26 : int = prim::Constant[value=1]()
+ %26 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%25, %26, %cond_mask.1)
block0(%_ : int, %29 : Dynamic) {
%30 : int = aten::dim(%29)
%data.2 : Dynamic = aten::unsqueeze(%29, %30)
- %32 : int = prim::Constant[value=1]()
+ %32 : bool = prim::Constant[value=1]()
-> (%32, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
diff --git a/test/expect/TestBatched.test_if_noelse_with_scalar.expect b/test/expect/TestBatched.test_if_noelse_with_scalar.expect
index 087868ea16..d8f453c965 100644
--- a/test/expect/TestBatched.test_if_noelse_with_scalar.expect
+++ b/test/expect/TestBatched.test_if_noelse_with_scalar.expect
@@ -8,7 +8,7 @@ graph(%a.1_data : Dynamic
%7 : Float() = prim::NumToTensor(%6)
%other : float = prim::TensorToNum(%7)
%9 : Dynamic = aten::gt(%a.1_data, %other)
- %10 : int = prim::TensorToNum(%9)
+ %10 : bool = prim::TensorToBool(%9)
%11 : int = prim::Constant[value=1]()
%12 : Long() = prim::NumToTensor(%11)
%alpha : float = prim::TensorToNum(%12)
@@ -19,17 +19,17 @@ graph(%a.1_data : Dynamic
%18 : Dynamic = aten::type_as(%a.1_mask, %9)
%cond_mask.1 : Dynamic = aten::mul(%9, %18)
%20 : int = aten::dim(%cond_mask.1)
- %21 : int = aten::eq(%20, %17)
+ %21 : bool = aten::eq(%20, %17)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21)
block0() {
%25 : int = aten::dim(%data.1)
%26 : int = aten::sub(%25, %17)
- %27 : int = prim::Constant[value=1]()
+ %27 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%26, %27, %cond_mask.1)
block0(%_ : int, %30 : Dynamic) {
%31 : int = aten::dim(%30)
%data.2 : Dynamic = aten::unsqueeze(%30, %31)
- %33 : int = prim::Constant[value=1]()
+ %33 : bool = prim::Constant[value=1]()
-> (%33, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
diff --git a/test/expect/TestBatched.test_while.expect b/test/expect/TestBatched.test_while.expect
index 7aba7a89ac..5cd196b56f 100644
--- a/test/expect/TestBatched.test_while.expect
+++ b/test/expect/TestBatched.test_while.expect
@@ -8,12 +8,12 @@ graph(%a.1_data : Dynamic
%7 : Dynamic = aten::gt(%a.1_data, %b_data)
%8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
- %10 : int = prim::TensorToNum(%7)
+ %10 : bool = prim::TensorToBool(%7)
%11 : int = prim::Constant[value=0]()
%12 : Dynamic = aten::mul(%7, %8)
%13 : Dynamic = aten::sum(%12)
%14 : Dynamic = aten::gt(%13, %11)
- %15 : int = prim::TensorToNum(%14)
+ %15 : bool = prim::TensorToBool(%14)
%16 : Dynamic, %17 : Dynamic, %18 : Dynamic, %a : Dynamic, %20 : Dynamic, %21 : Dynamic = prim::Loop(%6, %15, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
%29 : int = prim::Constant[value=1]()
@@ -25,22 +25,22 @@ graph(%a.1_data : Dynamic
%35 : Dynamic = aten::gt(%data.1, %b_data)
%36 : Dynamic = aten::mul(%mask, %b_mask)
%37 : Dynamic = aten::__or__(%dims, %b_dims)
- %38 : int = prim::TensorToNum(%35)
+ %38 : bool = prim::TensorToBool(%35)
%39 : int = prim::Constant[value=1]()
%40 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)
%cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %40)
%42 : int = aten::dim(%cond_mask.1)
- %43 : int = aten::eq(%42, %39)
+ %43 : bool = aten::eq(%42, %39)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%43)
block0() {
%47 : int = aten::dim(%data.1)
%48 : int = aten::sub(%47, %39)
- %49 : int = prim::Constant[value=1]()
+ %49 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%48, %49, %cond_mask.1)
block0(%_ : int, %52 : Dynamic) {
%53 : int = aten::dim(%52)
%data.2 : Dynamic = aten::unsqueeze(%52, %53)
- %55 : int = prim::Constant[value=1]()
+ %55 : bool = prim::Constant[value=1]()
-> (%55, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
@@ -57,7 +57,7 @@ graph(%a.1_data : Dynamic
%62 : Dynamic = aten::mul(%35, %36)
%63 : Dynamic = aten::sum(%62)
%64 : Dynamic = aten::gt(%63, %61)
- %65 : int = prim::TensorToNum(%64)
+ %65 : bool = prim::TensorToBool(%64)
-> (%65, %35, %36, %37, %res_data, %res_mask, %res_dims)
}
return (%a, %20, %21);
diff --git a/test/expect/TestJit.test_batchnorm.expect b/test/expect/TestJit.test_batchnorm.expect
index c61390578d..e1fc75d5a4 100644
--- a/test/expect/TestJit.test_batchnorm.expect
+++ b/test/expect/TestJit.test_batchnorm.expect
@@ -4,10 +4,10 @@ graph(%0 : Double(2, 2, 2, 2)
%3 : Double(2)
%4 : Double(2)
%5 : Long()) {
- %6 : int = prim::Constant[value=1](), scope: BatchNorm2d
+ %6 : bool = prim::Constant[value=1](), scope: BatchNorm2d
%7 : float = prim::Constant[value=0.1](), scope: BatchNorm2d
%8 : float = prim::Constant[value=1e-05](), scope: BatchNorm2d
- %9 : int = prim::Constant[value=1](), scope: BatchNorm2d
+ %9 : bool = prim::Constant[value=1](), scope: BatchNorm2d
%10 : Double(2, 2, 2, 2) = aten::batch_norm(%0, %1, %2, %3, %4, %6, %7, %8, %9), scope: BatchNorm2d
return (%10);
}
diff --git a/test/expect/TestJit.test_constant_prop_if_constant.expect b/test/expect/TestJit.test_constant_prop_if_constant.expect
index fa1cb8053a..d373c1ee12 100644
--- a/test/expect/TestJit.test_constant_prop_if_constant.expect
+++ b/test/expect/TestJit.test_constant_prop_if_constant.expect
@@ -1,10 +1,10 @@
graph(%a : Dynamic
%b : Dynamic) {
%c2.1 : int = prim::Constant[value=1]()
- %3 : int = prim::TensorToNum(%a)
+ %3 : bool = prim::TensorToBool(%a)
%c0.4 : int, %c1 : int = prim::If(%3)
block0() {
- %6 : int = prim::TensorToNum(%b)
+ %6 : bool = prim::TensorToBool(%b)
%c0.3 : int = prim::If(%6)
block0() {
%8 : int = prim::Constant[value=2]()
diff --git a/test/expect/TestJit.test_constant_prop_loop_constant.expect b/test/expect/TestJit.test_constant_prop_loop_constant.expect
index d0d4c8ed7d..f5bc3f5720 100644
--- a/test/expect/TestJit.test_constant_prop_loop_constant.expect
+++ b/test/expect/TestJit.test_constant_prop_loop_constant.expect
@@ -3,16 +3,16 @@ graph() {
%b.2 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2147483647]()
%b.1 : int = prim::Constant[value=0]()
- %4 : int = prim::Constant[value=1]()
+ %4 : bool = prim::Constant[value=1]()
%b.3 : int = prim::Loop(%2, %4, %b.1)
block0(%6 : int, %7 : int) {
- %8 : int = prim::Constant[value=1]()
+ %8 : bool = prim::Constant[value=1]()
-> (%8, %b.2)
}
- %9 : int = prim::Constant[value=0]()
+ %9 : bool = prim::Constant[value=0]()
%b : int = prim::Loop(%2, %9, %b.3)
block0(%11 : int, %12 : int) {
- %13 : int = prim::Constant[value=0]()
+ %13 : bool = prim::Constant[value=0]()
-> (%13, %b.4)
}
return (%b);
diff --git a/test/expect/TestJit.test_constant_prop_nested.expect b/test/expect/TestJit.test_constant_prop_nested.expect
index 09ef82076e..4a644a0d00 100644
--- a/test/expect/TestJit.test_constant_prop_nested.expect
+++ b/test/expect/TestJit.test_constant_prop_nested.expect
@@ -1,7 +1,7 @@
graph(%a : Dynamic) {
%1 : int = prim::Constant[value=2]()
%2 : Dynamic = aten::lt(%a, %1)
- %3 : int = prim::TensorToNum(%2)
+ %3 : bool = prim::TensorToBool(%2)
%c : int = prim::If(%3)
block0() {
%5 : int = prim::Constant[value=5]()
diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect
index fcb53bad14..20bb072015 100644
--- a/test/expect/TestJit.test_conv.expect
+++ b/test/expect/TestJit.test_conv.expect
@@ -10,14 +10,14 @@ graph(%0 : Double(20, 16, 50, 40)
%9 : int = prim::Constant[value=1](), scope: Conv2d
%10 : int = prim::Constant[value=1](), scope: Conv2d
%11 : int[] = prim::ListConstruct(%9, %10), scope: Conv2d
- %12 : int = prim::Constant[value=0](), scope: Conv2d
+ %12 : bool = prim::Constant[value=0](), scope: Conv2d
%13 : int = prim::Constant[value=0](), scope: Conv2d
%14 : int = prim::Constant[value=0](), scope: Conv2d
%15 : int[] = prim::ListConstruct(%13, %14), scope: Conv2d
%16 : int = prim::Constant[value=1](), scope: Conv2d
- %17 : int = prim::Constant[value=0](), scope: Conv2d
- %18 : int = prim::Constant[value=0](), scope: Conv2d
- %19 : int = prim::Constant[value=1](), scope: Conv2d
+ %17 : bool = prim::Constant[value=0](), scope: Conv2d
+ %18 : bool = prim::Constant[value=0](), scope: Conv2d
+ %19 : bool = prim::Constant[value=1](), scope: Conv2d
%20 : Double(20, 13, 48, 38) = aten::_convolution(%0, %1, %2, %5, %8, %11, %12, %15, %16, %17, %18, %19), scope: Conv2d
return (%20);
}
diff --git a/test/expect/TestJit.test_dropout.expect b/test/expect/TestJit.test_dropout.expect
index 3daa3484c6..3d0d7d312d 100644
--- a/test/expect/TestJit.test_dropout.expect
+++ b/test/expect/TestJit.test_dropout.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(2, 2)) {
%1 : float = prim::Constant[value=0.6](), scope: Dropout
- %2 : int = prim::Constant[value=1](), scope: Dropout
+ %2 : bool = prim::Constant[value=1](), scope: Dropout
%3 : Double(2, 2) = aten::dropout(%0, %1, %2), scope: Dropout
return (%3);
}
diff --git a/test/expect/TestJit.test_inplace_copy.expect b/test/expect/TestJit.test_inplace_copy.expect
new file mode 100644
index 0000000000..f046063c97
--- /dev/null
+++ b/test/expect/TestJit.test_inplace_copy.expect
@@ -0,0 +1,17 @@
+graph(%0 : Double(4, 4)) {
+ %1 : int = prim::Constant[value=0]()
+ %2 : int = aten::size(%0, %1)
+ %3 : Long() = prim::NumToTensor(%2)
+ %4 : int = prim::TensorToNum(%3)
+ %5 : int = prim::Constant[value=1]()
+ %6 : int = aten::size(%0, %5)
+ %7 : Long() = prim::NumToTensor(%6)
+ %8 : int = prim::TensorToNum(%7)
+ %9 : int[] = prim::ListConstruct(%4, %8)
+ %10 : int = prim::Constant[value=7]()
+ %11 : int = prim::Constant[value=0]()
+ %12 : int[] = prim::Constant[value=[0, -1]]()
+ %13 : Double(4, 4) = aten::zeros(%9, %10, %11, %12)
+ %14 : Double(4, 4) = aten::expand_as(%0, %13)
+ return (%14);
+}
diff --git a/test/expect/TestJit.test_pretty_printer-if_one.expect b/test/expect/TestJit.test_pretty_printer-if_one.expect
index 3a9254be45..cbede6c812 100644
--- a/test/expect/TestJit.test_pretty_printer-if_one.expect
+++ b/test/expect/TestJit.test_pretty_printer-if_one.expect
@@ -1,6 +1,6 @@
def script(c2, c1):
t2 = aten::lt(c2, c1)
- t3 = prim::TensorToNum(t2)
+ t3 = prim::TensorToBool(t2)
if t3:
c = c2
else:
diff --git a/test/expect/TestJit.test_pretty_printer-if_test.expect b/test/expect/TestJit.test_pretty_printer-if_test.expect
index 130c9b0c5a..c7e2249078 100644
--- a/test/expect/TestJit.test_pretty_printer-if_test.expect
+++ b/test/expect/TestJit.test_pretty_printer-if_test.expect
@@ -1,6 +1,6 @@
def script(c2, c1):
t2 = aten::lt(c2, c1)
- t3 = prim::TensorToNum(t2)
+ t3 = prim::TensorToBool(t2)
if t3:
c = c1
else:
diff --git a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect
index 4e35ad2150..01934775d8 100644
--- a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect
+++ b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect
@@ -2,14 +2,14 @@ def script(y1):
x = aten::add(y1, 1, 1)
z1 = aten::add(x, 5, 1)
t9 = aten::lt(y1, 8)
- t10 = prim::TensorToNum(t9)
+ t10 = prim::TensorToBool(t9)
y = y1
z = z1
t11 = t10
while t11:
y2 = aten::add(y, 1, 1)
t17 = aten::lt(y2, 8)
- t18 = prim::TensorToNum(t17)
+ t18 = prim::TensorToBool(t17)
t11 = t18
y = y2
z = x
diff --git a/test/expect/TestJit.test_pretty_printer-while_if_test.expect b/test/expect/TestJit.test_pretty_printer-while_if_test.expect
index c830784510..70d4af6150 100644
--- a/test/expect/TestJit.test_pretty_printer-while_if_test.expect
+++ b/test/expect/TestJit.test_pretty_printer-while_if_test.expect
@@ -1,6 +1,6 @@
def script(a1, b1):
t5 = aten::lt(a1, 10)
- t6 = prim::TensorToNum(t5)
+ t6 = prim::TensorToBool(t5)
a = a1
b = b1
c = 0
@@ -9,13 +9,13 @@ def script(a1, b1):
a2 = aten::add(a, 1, 1)
b2 = aten::add(b, 1, 1)
t15 = aten::gt(a2, b2)
- t16 = prim::TensorToNum(t15)
+ t16 = prim::TensorToBool(t15)
if t16:
c4 = 2
else:
c4 = 3
t21 = aten::lt(a2, 10)
- t22 = prim::TensorToNum(t21)
+ t22 = prim::TensorToBool(t21)
t7 = t22
a = a2
b = b2
diff --git a/test/expect/TestJit.test_pretty_printer-while_test.expect b/test/expect/TestJit.test_pretty_printer-while_test.expect
index 487087ad56..f99e4721f0 100644
--- a/test/expect/TestJit.test_pretty_printer-while_test.expect
+++ b/test/expect/TestJit.test_pretty_printer-while_test.expect
@@ -1,6 +1,6 @@
def script(a1, i1):
t4 = aten::lt(i1, 3)
- t5 = prim::TensorToNum(t4)
+ t5 = prim::TensorToBool(t4)
a = a1
i = i1
t6 = t5
@@ -8,7 +8,7 @@ def script(a1, i1):
a2 = aten::mul(a, a)
i2 = aten::add(i, 1, 1)
t13 = aten::lt(i2, 3)
- t14 = prim::TensorToNum(t13)
+ t14 = prim::TensorToBool(t13)
t6 = t14
a = a2
i = i2
diff --git a/test/expect/TestJit.test_recursive_cse.expect b/test/expect/TestJit.test_recursive_cse.expect
index a2aa84b5d6..dd56a51193 100644
--- a/test/expect/TestJit.test_recursive_cse.expect
+++ b/test/expect/TestJit.test_recursive_cse.expect
@@ -3,7 +3,7 @@ graph(%z.1 : Dynamic
%2 : int = prim::Constant[value=1]()
%3 : Dynamic = aten::add(%z.1, %y, %2)
%4 : Dynamic = aten::gt(%3, %z.1)
- %5 : int = prim::TensorToNum(%4)
+ %5 : bool = prim::TensorToBool(%4)
%z : Dynamic = prim::If(%5)
block0() {
-> (%3)
diff --git a/test/expect/TestScript.test_if_list.expect b/test/expect/TestScript.test_if_list.expect
index fddb6a173a..93c942d4dc 100644
--- a/test/expect/TestScript.test_if_list.expect
+++ b/test/expect/TestScript.test_if_list.expect
@@ -1,5 +1,5 @@
graph(%x : Double(*, *)) {
- %1 : int = prim::Constant[value=1]()
+ %1 : bool = prim::Constant[value=1]()
%c : Dynamic[] = prim::If(%1)
block0() {
%c.1 : Dynamic[] = prim::ListConstruct(%x, %x)
diff --git a/test/expect/TestScript.test_if_supertype.expect b/test/expect/TestScript.test_if_supertype.expect
index 3b58ca8928..e57f49a27c 100644
--- a/test/expect/TestScript.test_if_supertype.expect
+++ b/test/expect/TestScript.test_if_supertype.expect
@@ -1,7 +1,7 @@
graph(%y.1 : Float(*, *)
%z.2 : Long(*, *)
%z.1 : Float(*, *)) {
- %3 : int = prim::Constant[value=1]()
+ %3 : bool = prim::Constant[value=1]()
%x : Float(*, *), %y : Dynamic, %z : Dynamic = prim::If(%3)
block0() {
-> (%y.1, %z.2, %z.1)
diff --git a/test/expect/TestScript.test_index_put_trace_with_view.expect b/test/expect/TestScript.test_index_put_trace_with_view.expect
index cc03d3d529..12bae3f8b4 100644
--- a/test/expect/TestScript.test_index_put_trace_with_view.expect
+++ b/test/expect/TestScript.test_index_put_trace_with_view.expect
@@ -4,7 +4,7 @@ graph(%0 : Double(100)
%3 : int = prim::Constant[value=4]()
%4 : int[] = prim::ListConstruct(%3)
%5 : Double(4) = aten::view(%2, %4)
- %6 : int = prim::Constant[value=0]()
+ %6 : bool = prim::Constant[value=0]()
%7 : Long(4) = aten::_cast_Long(%1, %6)
%8 : Dynamic[] = prim::ListConstruct(%7)
%9 : Double(100) = aten::index_put(%0, %8, %5)
diff --git a/test/expect/TestScript.test_index_put_trace_without_view.expect b/test/expect/TestScript.test_index_put_trace_without_view.expect
index c725067960..8e5da7efd4 100644
--- a/test/expect/TestScript.test_index_put_trace_without_view.expect
+++ b/test/expect/TestScript.test_index_put_trace_without_view.expect
@@ -1,7 +1,7 @@
graph(%0 : Double(100)
%1 : Long(4)
%2 : Double(4)) {
- %3 : int = prim::Constant[value=0]()
+ %3 : bool = prim::Constant[value=0]()
%4 : Long(4) = aten::_cast_Long(%1, %3)
%5 : Dynamic[] = prim::ListConstruct(%4)
%6 : Double(100) = aten::index_put(%0, %5, %2)
diff --git a/test/expect/TestScript.test_logical_short_circuit.expect b/test/expect/TestScript.test_logical_short_circuit.expect
index bcd34a7cb0..0a4569c2b4 100644
--- a/test/expect/TestScript.test_logical_short_circuit.expect
+++ b/test/expect/TestScript.test_logical_short_circuit.expect
@@ -1,31 +1,31 @@
graph(%t : Dynamic) {
%c1.2 : int = prim::Constant[value=0]()
%c1.1 : int = prim::Constant[value=1]()
- %3 : int = prim::Constant[value=0]()
- %4 : int = prim::If(%3)
+ %3 : bool = prim::Constant[value=0]()
+ %4 : bool = prim::If(%3)
block0() {
%5 : int = prim::Constant[value=0]()
%6 : Dynamic = aten::select(%t, %5, %c1.1)
- %7 : int = prim::TensorToNum(%6)
+ %7 : bool = prim::TensorToBool(%6)
-> (%7)
}
block1() {
-> (%3)
}
- %8 : int = prim::If(%4)
+ %8 : bool = prim::If(%4)
block0() {
-> (%4)
}
block1() {
- %9 : int = prim::Constant[value=1]()
- %10 : int = prim::If(%9)
+ %9 : bool = prim::Constant[value=1]()
+ %10 : bool = prim::If(%9)
block0() {
-> (%9)
}
block1() {
%11 : int = prim::Constant[value=0]()
%12 : Dynamic = aten::select(%t, %11, %c1.1)
- %13 : int = prim::TensorToNum(%12)
+ %13 : bool = prim::TensorToBool(%12)
-> (%13)
}
-> (%10)
diff --git a/test/expect/TestScript.test_loop_unroll_unused_counter.expect b/test/expect/TestScript.test_loop_unroll_unused_counter.expect
index 292b251c75..feb204a925 100644
--- a/test/expect/TestScript.test_loop_unroll_unused_counter.expect
+++ b/test/expect/TestScript.test_loop_unroll_unused_counter.expect
@@ -2,7 +2,7 @@ graph(%x : Dynamic) {
%1 : int = prim::Constant[value=1]()
%y.1 : int = prim::Constant[value=0]()
%3 : int = prim::TensorToNum(%x)
- %4 : int = prim::Constant[value=1]()
+ %4 : bool = prim::Constant[value=1]()
%5 : int = prim::Constant[value=8]()
%6 : int = aten::floordiv(%3, %5)
%7 : int = prim::Constant[value=8]()
@@ -18,13 +18,13 @@ graph(%x : Dynamic) {
%y.9 : int = aten::add(%y.8, %1)
%y.10 : int = aten::add(%y.9, %1)
%y.11 : int = aten::add(%y.10, %1)
- %21 : int = prim::Constant[value=1]()
+ %21 : bool = prim::Constant[value=1]()
-> (%21, %y.11)
}
%y : int = prim::Loop(%9, %4, %y.3)
block0(%i : int, %24 : int) {
%y.4 : int = aten::add(%24, %1)
- %26 : int = prim::Constant[value=1]()
+ %26 : bool = prim::Constant[value=1]()
-> (%26, %y.4)
}
return (%y);
diff --git a/test/expect/TestScript.test_loop_unrolling.expect b/test/expect/TestScript.test_loop_unrolling.expect
index 29546b5c1e..5f3e61a4ce 100644
--- a/test/expect/TestScript.test_loop_unrolling.expect
+++ b/test/expect/TestScript.test_loop_unrolling.expect
@@ -1,7 +1,7 @@
graph(%x : Dynamic) {
%y.1 : int = prim::Constant[value=0]()
%2 : int = prim::TensorToNum(%x)
- %3 : int = prim::Constant[value=1]()
+ %3 : bool = prim::Constant[value=1]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=8]()
%6 : int = aten::floordiv(%2, %5)
@@ -32,7 +32,7 @@ graph(%x : Dynamic) {
%34 : int = prim::Constant[value=1]()
%35 : int = aten::add(%32, %34)
%y.11 : int = aten::add(%y.10, %35)
- %37 : int = prim::Constant[value=1]()
+ %37 : bool = prim::Constant[value=1]()
%38 : int = prim::Constant[value=1]()
%39 : int = aten::add(%35, %38)
-> (%37, %39, %y.11)
@@ -40,7 +40,7 @@ graph(%x : Dynamic) {
%40 : Dynamic, %y : int = prim::Loop(%9, %3, %10, %y.3)
block0(%i : int, %43 : int, %44 : int) {
%y.4 : int = aten::add(%44, %43)
- %46 : int = prim::Constant[value=1]()
+ %46 : bool = prim::Constant[value=1]()
%47 : int = prim::Constant[value=1]()
%48 : int = aten::add(%43, %47)
-> (%46, %48, %y.4)
diff --git a/test/expect/TestScript.test_loop_unrolling_nested.expect b/test/expect/TestScript.test_loop_unrolling_nested.expect
index a107eb81f7..2668b111ab 100644
--- a/test/expect/TestScript.test_loop_unrolling_nested.expect
+++ b/test/expect/TestScript.test_loop_unrolling_nested.expect
@@ -1,11 +1,11 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=10]()
%y.1 : int = prim::Constant[value=0]()
- %3 : int = prim::Constant[value=1]()
+ %3 : bool = prim::Constant[value=1]()
%y : int = prim::Loop(%1, %3, %y.1)
block0(%i : int, %6 : int) {
%7 : int = prim::TensorToNum(%x)
- %8 : int = prim::Constant[value=1]()
+ %8 : bool = prim::Constant[value=1]()
%9 : int = prim::Constant[value=0]()
%10 : int = prim::Constant[value=8]()
%11 : int = aten::floordiv(%7, %10)
@@ -36,7 +36,7 @@ graph(%x : Dynamic) {
%39 : int = prim::Constant[value=1]()
%40 : int = aten::add(%37, %39)
%y.12 : int = aten::add(%y.11, %40)
- %42 : int = prim::Constant[value=1]()
+ %42 : bool = prim::Constant[value=1]()
%43 : int = prim::Constant[value=1]()
%44 : int = aten::add(%40, %43)
-> (%42, %44, %y.12)
@@ -44,12 +44,12 @@ graph(%x : Dynamic) {
%45 : Dynamic, %y.3 : int = prim::Loop(%14, %8, %15, %y.4)
block0(%j : int, %48 : int, %49 : int) {
%y.5 : int = aten::add(%49, %48)
- %51 : int = prim::Constant[value=1]()
+ %51 : bool = prim::Constant[value=1]()
%52 : int = prim::Constant[value=1]()
%53 : int = aten::add(%48, %52)
-> (%51, %53, %y.5)
}
- %54 : int = prim::Constant[value=1]()
+ %54 : bool = prim::Constant[value=1]()
-> (%54, %y.3)
}
return (%y);
diff --git a/test/expect/TestScript.test_sum-1.expect b/test/expect/TestScript.test_sum-1.expect
index a2bb9d4417..ce0d1fe0b5 100644
--- a/test/expect/TestScript.test_sum-1.expect
+++ b/test/expect/TestScript.test_sum-1.expect
@@ -1,7 +1,7 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=4]()
%2 : int[] = prim::ListConstruct(%1)
- %3 : int = prim::Constant[value=0]()
+ %3 : bool = prim::Constant[value=0]()
%4 : Dynamic = aten::sum(%x, %2, %3)
return (%4);
}
diff --git a/test/expect/TestScript.test_sum-2.expect b/test/expect/TestScript.test_sum-2.expect
index 4b2352de81..a952901438 100644
--- a/test/expect/TestScript.test_sum-2.expect
+++ b/test/expect/TestScript.test_sum-2.expect
@@ -1,7 +1,7 @@
graph(%x : Double(*, *, *, *, *)) {
%1 : int = prim::Constant[value=4]()
%2 : int[] = prim::ListConstruct(%1)
- %3 : int = prim::Constant[value=0]()
+ %3 : bool = prim::Constant[value=0]()
%4 : Dynamic = aten::sum(%x, %2, %3)
return (%4);
}
diff --git a/test/expect/TestScript.test_type_cast-float_to_int.expect b/test/expect/TestScript.test_type_cast-test_float_to_int.expect
index 138b6e4913..ffc45ac565 100644
--- a/test/expect/TestScript.test_type_cast-float_to_int.expect
+++ b/test/expect/TestScript.test_type_cast-test_float_to_int.expect
@@ -1,6 +1,6 @@
graph() {
%0 : int = prim::Constant[value=1]()
- %1 : float = prim::Constant[value=2]()
+ %1 : float = prim::Constant[value=5]()
%b : int = prim::FloatToInt(%1)
%3 : int = aten::add(%b, %0)
return (%3);
diff --git a/test/expect/TestScript.test_type_cast-int_to_float.expect b/test/expect/TestScript.test_type_cast-test_int_to_float.expect
index 719cc1e881..719cc1e881 100644
--- a/test/expect/TestScript.test_type_cast-int_to_float.expect
+++ b/test/expect/TestScript.test_type_cast-test_int_to_float.expect
diff --git a/test/test_jit.py b/test/test_jit.py
index 24d8076d31..08ebca8af1 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1365,7 +1365,6 @@ class TestJit(JitTestCase):
# as all backwards functions of views are implemented
# as a zero filled tensor with a gradient fill on the
# viewed portion.
- @unittest.expectedFailure
def test_inplace_copy(self):
x = torch.randn(4, 4, requires_grad=True)
@@ -3763,7 +3762,7 @@ a")
self.checkScript(func, inputs, optimize=True)
def test_explicit_bool_cast(self):
- with self.assertRaisesRegex(RuntimeError, "expected an integer"):
+ with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
@torch.jit.script
def test_bool_cast(a):
if a:
@@ -3987,18 +3986,38 @@ a")
throwsAnd(t)
def test_type_cast(self):
+ @torch.jit.script
def test_int_to_float():
b = float(2)
return b + 1.0
+ with self.assertRaisesRegex(RuntimeError, "Cannot cast type"):
+ @torch.jit.script
+ def test_int_to_bool():
+ return bool(5)
+
+ @torch.jit.script
def test_float_to_int():
- b = int(2.0)
+ b = int(5.0)
return b + 1
- graph1 = torch.jit.script(test_int_to_float).graph
- self.assertExpectedGraph(graph1, subname="int_to_float")
- graph2 = torch.jit.script(test_float_to_int).graph
- self.assertExpectedGraph(graph2, subname="float_to_int")
+ with self.assertRaisesRegex(RuntimeError, "Cannot cast type"):
+ @torch.jit.script
+ def test_float_to_bool():
+ return bool(5.0)
+
+ with self.assertRaisesRegex(RuntimeError, "Cannot cast type"):
+ @torch.jit.script
+ def test_bool_to_float():
+ return float(True)
+
+ with self.assertRaisesRegex(RuntimeError, "Cannot cast type"):
+ @torch.jit.script
+ def test_bool_to_int():
+ return int(True)
+
+ self.assertExpectedGraph(test_int_to_float.graph, "test_int_to_float")
+ self.assertExpectedGraph(test_float_to_int.graph, "test_float_to_int")
def test_multiple_assignment(self):
def outer_func(x):
@@ -7757,27 +7776,6 @@ EXCLUDE_TYPE_CHECK = {
# known to be failing in script
EXCLUDE_SCRIPT = {
- # TODO: Fix var/std
- # there are two schemas for var (and std):
- # (1) var(Tensor, int, *, bool, bool, Tensor)
- # (2) var(Tensor, *, bool)
- #
- # Right now, the following is happening:
- # - Shorter schemas come before longer schemas
- # - bool, int are treated as IntType rather than DynamicType like before
- # So the schemas look like the following in operator:
- # (2) var(DynamicType, IntType)
- # (1) var(DynamicType, IntType, IntType, DynamicType)
- # Now, when one calls torch.var(tensor, dim=1), the compiler mistakingly
- # matches it with (2) instead of (1), which is a problem.
- 'test_std_dim',
- 'test_std_dim_1d',
- 'test_std_dim_1d_neg0',
- 'test_std_dim_neg0',
- 'test_var_dim',
- 'test_var_dim_1d',
- 'test_var_dim_1d_neg0',
- 'test_var_dim_neg0',
'test_norm_fro',
'test_norm_fro_default',
'test_norm_nuc',
diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py
index f6fdc7505d..4fc923589d 100644
--- a/tools/jit/gen_jit_dispatch.py
+++ b/tools/jit/gen_jit_dispatch.py
@@ -64,7 +64,7 @@ FROM_IVALUE = {
'ScalarType': '{}.to<at::ScalarType>()',
'Tensor': '{}.toTensor()',
'TensorList': '{}.toTensorList()->elements()',
- 'bool': 'bool({}.toInt())',
+ 'bool': '{}.toBool()',
'double': '{}.toDouble()',
'int64_t': '{}.toInt()',
'std::string': '{}.toString()->string()',
diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h
index 2610643918..af3a16393f 100644
--- a/torch/csrc/jit/attributes.h
+++ b/torch/csrc/jit/attributes.h
@@ -167,6 +167,7 @@ struct Attributes {
const Kind##Attr::ValueType& method(Symbol name) const { \
return get<Kind##Attr>(name); \
}
+
CREATE_ACCESSOR(Float,f)
CREATE_ACCESSOR(Floats,fs)
CREATE_ACCESSOR(String,s)
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index a59c856eab..46840fc99a 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -280,7 +280,7 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
} else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))};
- } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) {
+ } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
const auto& input_sizes = inputs.at(0).sizes();
if (input_sizes.size() == 0)
return {grads.at(0).sum(), nullptr, nullptr};
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index 1633ac0de4..c62cdbd1fa 100644
--- a/torch/csrc/jit/constants.cpp
+++ b/torch/csrc/jit/constants.cpp
@@ -27,6 +27,13 @@ Value* insertConstant(
} else if(val.isDouble()) {
n->f_(attr::value, val.toDouble());
n->output()->setType(FloatType::get());
+ } else if (val.isBool()) {
+ n->i_(attr::value, val.toBool());
+ n->output()->setType(BoolType::get());
+ } else if (val.isBoolList()) {
+ auto bool_list = val.toBoolList()->elements();
+ n->is_(attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
+ n->output()->setType(ListType::ofBools());
} else if(val.isIntList()) {
n->is_(attr::value, val.toIntList()->elements());
n->output()->setType(ListType::ofInts());
@@ -64,6 +71,12 @@ RegisterOperators reg({
stack.push_back(t);
return 0;
};
+ } else if (type->isSubtypeOf(BoolType::get())) {
+ bool b = node->i(attr::value);
+ return [b](Stack& stack) {
+ push(stack, b);
+ return 0;
+ };
} else if (
type->isSubtypeOf(NumberType::get()) &&
node->kindOf(attr::value) == AttributeKind::i) {
@@ -86,6 +99,12 @@ RegisterOperators reg({
push(stack, is);
return 0;
};
+ } else if(type->isSubtypeOf(ListType::ofBools())) {
+ auto bs = node->is(attr::value);
+ return [bs](Stack& stack) {
+ push(stack, bs);
+ return 0;
+ };
} else if(type->isSubtypeOf(ListType::ofTensors())) {
auto ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor {
return autograd::make_variable(t);
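
A minimal sketch of what the constants.cpp change above buys at graph-construction time. It is illustrative only, not part of the patch; the insertConstant member call and the type names come from other hunks in this diff, while the enclosing function and include are assumptions.

#include "torch/csrc/jit/ir.h"

void insert_bool_constant_example(torch::jit::Graph& graph) {
  using namespace torch::jit;
  // Same call that the shape_analysis.cpp hunk below switches to; with this patch
  // the resulting prim::Constant output is typed as bool rather than int.
  Value* cond = graph.insertConstant(false);
  JIT_ASSERT(cond->type()->isSubtypeOf(BoolType::get()));
}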
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 973780b7d7..574dae168a 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -559,6 +559,8 @@ void ModuleEncoder::EncodeTypeInfo(
type_proto->set_denotation("FloatType");
} else if (kind == TypeKind::IntType) {
type_proto->set_denotation("IntType");
+ } else if (kind == TypeKind::BoolType) {
+ type_proto->set_denotation("BoolType");
} else if (kind == TypeKind::NoneType) {
type_proto->set_denotation("NoneType");
} else if (kind == TypeKind::GeneratorType) {
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 4574addb3a..45053a81db 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -257,6 +257,8 @@ TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
return FloatType::get();
} else if (kind == "IntType") {
return IntType::get();
+ } else if (kind == "BoolType") {
+ return BoolType::get();
} else if (kind == "NoneType") {
return NoneType::get();
} else if (kind == "GeneratorType") {
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index b4e6b7c139..32544aa682 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -46,9 +46,11 @@ namespace torch { namespace jit {
_(prim, TupleUnpack) \
_(prim, ListConstruct) \
_(prim, ListUnpack) \
+ _(prim, BoolToTensor) \
_(prim, NumToTensor) \
_(prim, TensorToNum) \
_(prim, ImplicitTensorToNum) \
+ _(prim, TensorToBool) \
_(prim, IntToFloat) \
_(prim, FloatToInt) \
_(prim, StringToFloat) \
diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp
index 14e7fab54d..065054a0fc 100644
--- a/torch/csrc/jit/interpreter.cpp
+++ b/torch/csrc/jit/interpreter.cpp
@@ -61,14 +61,14 @@ Value* createTripCountConjunctiveCondition(
// Emit initial comparison -- initial_trip_count < max_trip_count
Value* initial_comparison_value =
g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1))
- ->output()->setType(IntType::get());
+ ->output()->setType(BoolType::get());
// Replace initial condition with logical `and` of trip count and
// initial condition
Value* new_cond =
g->insertNode(
g->create(aten::__and__, {initial_comparison_value, cond}, 1))
- ->output()->setType(IntType::get());
+ ->output()->setType(BoolType::get());
return new_cond;
}
@@ -388,30 +388,29 @@ struct CodeImpl {
CodeImpl(std::shared_ptr<Graph>& graph_)
: preprocess(*graph_) {
graph = preprocess.graph;
- // std::cout << "into code graph:\n" << *graph << "\n";
insertNodesFromBlock(graph->block());
}
- // jump when input is 0
- void createJumpZ(int from_inst, int to_inst) {
+ // jump when input is false
+ void createJumpFalse(int from_inst, int to_inst) {
auto & inst = instructions[from_inst];
JIT_ASSERT(inst.debug_name == prim::Placeholder);
auto offset = relativeJump(from_inst, to_inst);
inst.callback = [offset](Stack & stack) {
- auto t = pop(stack).toInt();
- return (t == 0) ? offset : 0;
+ auto t = pop(stack).toBool();
+ return t ? 0 : offset;
};
inst.debug_name = prim::JumpZ;
}
- // jump when input is not 0
- void createJumpNZ(int from_inst, int to_inst) {
+ // jump when input is true
+ void createJumpTrue(int from_inst, int to_inst) {
auto & inst = instructions[from_inst];
JIT_ASSERT(inst.debug_name == prim::Placeholder);
auto offset = relativeJump(from_inst, to_inst);
inst.callback = [offset](Stack & stack) {
- auto t = pop(stack).toInt();
- return (t != 0) ? offset : 0;
+ auto t = pop(stack).toBool();
+ return t ? offset : 0;
};
inst.debug_name = prim::JumpNZ;
}
@@ -460,7 +459,7 @@ struct CodeImpl {
insertNodesFromBlock(then_block);
insertAssign(source_location, then_block->outputs(), moveFlags(then_block), node->outputs());
createJump(jump, instructions.size());
- createJumpNZ(cond_branch, then_block_start);
+ createJumpTrue(cond_branch, then_block_start);
} break;
case prim::Loop: {
// o0 = while c i0
@@ -495,8 +494,8 @@ struct CodeImpl {
// after branch: stack: ...
aliasRegistersTo(node->outputs(), body_block->inputs());
- createJumpZ(cond_branch, instructions.size());
- createJumpNZ(cond_branch_end, entry);
+ createJumpFalse(cond_branch, instructions.size());
+ createJumpTrue(cond_branch_end, entry);
} break;
default: {
insertInstruction(node);
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 90451494ba..4e4bb0b68d 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -562,18 +562,18 @@ namespace {
const OperatorSet& nondeterminstic_aten_ops() {
static OperatorSet nondeterministic_ops = {
- "aten::dropout(Tensor input, float p, int train) -> Tensor",
+ "aten::dropout(Tensor input, float p, bool train) -> Tensor",
"aten::_fused_dropout(Tensor self, float p, Generator generator) -> (Tensor, Tensor)",
"aten::_standard_gamma(Tensor self, Generator generator) -> Tensor",
"aten::bernoulli(Tensor self, *, Generator generator) -> Tensor",
"aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor",
- "aten::multinomial(Tensor self, int num_samples, int replacement, *, Generator generator) -> Tensor",
+ "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator generator) -> Tensor",
"aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor",
"aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor",
"aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor",
"aten::poisson(Tensor self, Generator generator) -> Tensor",
- "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor",
- "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor",
+ "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
+ "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
"aten::rand(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
"aten::rand_like(Tensor self) -> Tensor",
"aten::rand_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
@@ -598,7 +598,7 @@ bool Node::isNondeterministic() const {
return false;
}
// Dropout with train = False is deterministic
- if (matches("aten::dropout(Tensor input, float p, int train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) {
+ if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) {
return false;
}
return true;
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 0bb5c899c7..6fdb49e7d9 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1050,6 +1050,15 @@ public:
result->output()->setType(CompleteTensorType::fromNumberType(typ));
return result;
}
+ Node* createBoolToTensor(Value* value) {
+ auto typ = value->type();
+ Node * result = create(prim::BoolToTensor, {value});
+ if (!typ->isSubtypeOf(BoolType::get())) {
+ AT_ERROR("Cannot create bool type from ", typ->str());
+ }
+ result->output()->setType(CompleteTensorType::fromBoolType());
+ return result;
+ }
Node* createTensorToNum(const TypePtr& type, Value* value) {
auto* result = create(prim::TensorToNum, {value});
result->output()->setType(type);
@@ -1060,6 +1069,11 @@ public:
result->output()->setType(type);
return result;
}
+ Node* createTensorToBool(Value* value) {
+ auto* result = create(prim::TensorToBool, {value});
+ result->output()->setType(BoolType::get());
+ return result;
+ }
Node* createIntToFloat(Value* value) {
JIT_ASSERT(*value->type() == *IntType::get());
auto* result = create(prim::IntToFloat, {value});
diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp
index d701b536c4..818a58c965 100644
--- a/torch/csrc/jit/operator.cpp
+++ b/torch/csrc/jit/operator.cpp
@@ -57,7 +57,7 @@ struct SchemaParser {
{"str", StringType::get() },
{"float", FloatType::get() },
{"int", IntType::get() },
- {"bool", IntType::get() }, // TODO: add separate bool type
+ {"bool", BoolType::get() },
{"World", WorldType::get() },
};
auto tok = L.expect(TK_IDENT);
@@ -162,6 +162,10 @@ struct SchemaParser {
return fmap(vs, [](IValue v) {
return v.toInt();
});
+ case TypeKind::BoolType:
+ return fmap(vs, [](IValue v) {
+ return v.toBool();
+ });
default:
throw ErrorReport(range) << "lists are only supported for float or int types.";
}
@@ -191,6 +195,7 @@ struct SchemaParser {
} break;
case TypeKind::NumberType:
case TypeKind::IntType:
+ case TypeKind::BoolType:
case TypeKind::FloatType:
arg.default_value = parseSingleConstant(arg.type->kind());
break;
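
A small sketch of what the schema-parser change above means in practice. It is illustrative only, not part of the patch; parseSchema as the entry point and the include path are assumptions about this revision, while the schema text is taken verbatim from the ir.cpp hunk above.

#include "torch/csrc/jit/operator.h"

void bool_schema_example() {
  // "bool" in a schema string now maps to BoolType instead of IntType, so the
  // train argument below is typed as a genuine boolean rather than an int.
  auto schema = torch::jit::parseSchema(
      "aten::dropout(Tensor input, float p, bool train) -> Tensor");
  (void)schema;
}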
diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp
index 7d2ba5bfeb..115d4b6a68 100644
--- a/torch/csrc/jit/passes/erase_number_types.cpp
+++ b/torch/csrc/jit/passes/erase_number_types.cpp
@@ -13,13 +13,16 @@ static void EraseNumberTypesOnBlock(Block* block) {
case prim::Constant: {
// remove primitive constants, replacing with tensor equivalent
// ONNX does not support non-tensor constants
- if(it->output()->type()->isSubtypeOf(NumberType::get())) {
+ if (it->output()->type()->isSubtypeOf(NumberType::get()) ||
+ it->output()->type()->isSubtypeOf(BoolType::get())) {
auto s = *constant_as<at::Scalar>(it->output());
WithInsertPoint guard(*it);
Value* r = block->owningGraph()->insertConstant(scalar_to_tensor(s));
it->output()->replaceAllUsesWith(r);
}
} break;
+ case prim::TensorToBool:
+ case prim::BoolToTensor:
case prim::TensorToNum:
case prim::ImplicitTensorToNum:
case prim::NumToTensor: {
@@ -30,6 +33,8 @@ static void EraseNumberTypesOnBlock(Block* block) {
for(auto o : it->outputs()) {
if (o->type()->isSubtypeOf(NumberType::get())) {
o->setType(CompleteTensorType::fromNumberType(o->type()));
+ } else if (o->type()->isSubtypeOf(BoolType::get())) {
+ o->setType(CompleteTensorType::fromBoolType());
}
}
} break;
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index 176166218f..7f2312696e 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -24,7 +24,7 @@ void PeepholeOptimize(Block * block) {
// XXX: remember that if you want to simplify an expression by combining multiple nodes
// into a different one, then you need to check that they all belong to the given block
- if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor",
+ if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
/*with_const=*/attr::size)) {
// x.expand(x.size()) == x
if (auto input_type = node->namedInput(attr::self)->type()->cast<CompleteTensorType>()) {
diff --git a/torch/csrc/jit/passes/remove_expands.cpp b/torch/csrc/jit/passes/remove_expands.cpp
index 93d53e5481..1c67e68507 100644
--- a/torch/csrc/jit/passes/remove_expands.cpp
+++ b/torch/csrc/jit/passes/remove_expands.cpp
@@ -8,7 +8,8 @@ static void RemoveExpands(Block* block) {
++it) {
for (auto sub : it->blocks())
RemoveExpands(sub);
- if (it->kind() == aten::expand && it->get<int64_t>(attr::implicit) != static_cast<int64_t>(0)) {
+
+ if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) {
it->output()->replaceAllUsesWith(it->namedInput(attr::self));
it.destroyCurrent();
}
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index eedc7fd0a8..af5e8cd96f 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -120,7 +120,7 @@ void broadcastBinary(Node *node, std::vector<CompleteTensorTypePtr>& types, size
Node *expand = graph->create(aten::expand,
{node->inputs().at(input_idx),
graph->insertConstant(expected_size),
- graph->insertConstant(0)})
+ graph->insertConstant(false)})
->insertBefore(node);
PropagateShapeOnNode(expand);
node->replaceInput(input_idx, expand->output());
@@ -441,12 +441,12 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::clamp(Tensor self, Scalar min, Scalar max) -> Tensor",
"aten::clamp_max(Tensor self, Scalar max) -> Tensor",
"aten::clamp_min(Tensor self, Scalar min) -> Tensor",
- "aten::alpha_dropout(Tensor input, float p, int train) -> Tensor",
+ "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
"aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor",
"aten::cos(Tensor self) -> Tensor",
"aten::cosh(Tensor self) -> Tensor",
"aten::digamma(Tensor self) -> Tensor",
- "aten::dropout(Tensor input, float p, int train) -> Tensor",
+ "aten::dropout(Tensor input, float p, bool train) -> Tensor",
"aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
"aten::erf(Tensor self) -> Tensor",
"aten::erfc(Tensor self) -> Tensor",
@@ -462,8 +462,8 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::floor(Tensor self) -> Tensor",
"aten::frac(Tensor self) -> Tensor",
"aten::flip(Tensor self, int[] dims) -> Tensor",
- "aten::feature_alpha_dropout(Tensor input, float p, int train) -> Tensor",
- "aten::feature_dropout(Tensor input, float p, int train) -> Tensor",
+ "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
+ "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
"aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
"aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
"aten::glu(Tensor self, int dim) -> Tensor",
@@ -479,7 +479,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
"aten::round(Tensor self) -> Tensor",
- "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor",
+ "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
"aten::rsqrt(Tensor self) -> Tensor",
"aten::selu(Tensor self) -> Tensor",
"aten::sigmoid(Tensor self) -> Tensor",
@@ -645,7 +645,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
// Knowing the type and device of weights or biases is usually enough to
// infer the output type.
static const register_formula_for nn_ops_first_input_preserving {{
- "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, int training, float momentum, float eps, int cudnn_enabled) -> Tensor",
+ "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
"aten::conv1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
"aten::conv2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
"aten::conv3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
@@ -653,16 +653,16 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::conv_transpose1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
"aten::conv_transpose2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
"aten::conv_transpose3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
- "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int transposed, int[] output_padding, int groups) -> Tensor",
+ "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
"aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
"aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
"aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor",
- "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
- "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
- "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
- "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
- "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
- "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
+ "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+ "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+ "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+ "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
+ "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
+ "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
"aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor",
"aten::max_unpool3d(Tensor self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor",
"aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor",
@@ -670,12 +670,12 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::replication_pad1d(Tensor self, int[] padding) -> Tensor",
"aten::replication_pad2d(Tensor self, int[] padding) -> Tensor",
"aten::replication_pad3d(Tensor self, int[] padding) -> Tensor",
- "aten::upsample_bilinear2d(Tensor self, int[] output_size, int align_corners) -> Tensor",
- "aten::upsample_linear1d(Tensor self, int[] output_size, int align_corners) -> Tensor",
+ "aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
+ "aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
"aten::upsample_nearest1d(Tensor self, int[] output_size) -> Tensor",
"aten::upsample_nearest2d(Tensor self, int[] output_size) -> Tensor",
"aten::upsample_nearest3d(Tensor self, int[] output_size) -> Tensor",
- "aten::upsample_trilinear3d(Tensor self, int[] output_size, int align_corners) -> Tensor",
+ "aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
"aten::prelu(Tensor self, Tensor weight) -> Tensor",
}, [](Node * node) -> type_vec_t {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
@@ -702,10 +702,10 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::mean(Tensor self) -> Tensor",
"aten::median(Tensor self) -> Tensor",
"aten::norm(Tensor self, Scalar p) -> Tensor",
- "aten::std(Tensor self, int unbiased) -> Tensor",
+ "aten::std(Tensor self, bool unbiased) -> Tensor",
"aten::sum(Tensor self) -> Tensor",
"aten::trace(Tensor self) -> Tensor",
- "aten::var(Tensor self, int unbiased) -> Tensor",
+ "aten::var(Tensor self, bool unbiased) -> Tensor",
"aten::all(Tensor self) -> Tensor",
"aten::any(Tensor self) -> Tensor",
}, [](Node * node) -> type_vec_t {
@@ -735,7 +735,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
static const auto multidim_reduce_with_postprocess =
[](Node * node, size_t num_reduced_dim, bool upcast_integer) -> type_vec_t {
- auto maybe_keepdim = node->get<int64_t>(attr::keepdim);
+ auto maybe_keepdim = node->get<bool>(attr::keepdim);
if (!maybe_keepdim) return {};
if (auto type = node->input(0)->type()->cast<TensorType>()) {
if (upcast_integer && !at::isFloatingType(type->scalarType())) {
@@ -760,24 +760,24 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
// - First input should be the only tensor input
// - Has a bool keepdim argument
static const register_formula_for dim_reduce_ops {{
- "aten::argmax(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::argmin(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::max_values(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::min_values(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::mean(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::norm(Tensor self, Scalar p, int dim, int keepdim) -> Tensor",
- "aten::std(Tensor self, int dim, int unbiased, int keepdim) -> Tensor",
- "aten::var(Tensor self, int dim, int unbiased, int keepdim) -> Tensor",
- "aten::logsumexp(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::all(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::any(Tensor self, int dim, int keepdim) -> Tensor",
+ "aten::argmax(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::mean(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor",
+ "aten::std(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
+ "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
+ "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
// Ops returning indices as second output
- "aten::kthvalue(Tensor self, int k, int dim, int keepdim) -> (Tensor, Tensor)",
- "aten::max(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
- "aten::min(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
- "aten::median(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
- "aten::mode(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
+ "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
+ "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+ "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+ "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+ "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
}, [](Node * node) -> type_vec_t {
// NB: Note that while this function is generally meant to be used with ops that
// have a single output, we will fix up its return right below.
@@ -798,7 +798,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
// - First input should be the only tensor input
// - has a bool keepdim argument
static const register_formula_for dim_reduce_ops_with_integer_upcast {{
- "aten::prod(Tensor self, int dim, int keepdim) -> Tensor",
+ "aten::prod(Tensor self, int dim, bool keepdim) -> Tensor",
}, [](Node * node) -> type_vec_t {
    return multidim_reduce_with_postprocess(node, /*num_reduced_dim=*/1, /*upcast_integer=*/true);
}};
@@ -812,7 +812,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
// Additionally:
// - has bool keepdim and int[] dim arguments
static const register_formula_for multidim_reduce_ops_with_integer_upcast {{
- "aten::sum(Tensor self, int[] dim, int keepdim) -> Tensor",
+ "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
}, [](Node * node) -> type_vec_t {
if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
// TODO: can dim contain duplicates?
@@ -900,14 +900,14 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
}
};
static const register_formula_for cast_ops {{
- "aten::_cast_Byte(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Char(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Double(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Float(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Half(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Int(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Long(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Short(Tensor self, int non_blocking) -> Tensor",
+ "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
}, [](Node * node) -> type_vec_t {
if (auto type = node->namedInput(attr::self)->type()->cast<TensorType>()) {
return {type->toScalarType(get_cast_scalar_type(node))};
@@ -1003,7 +1003,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
}
return true;
}
- } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, int scale_grad_by_freq, int sparse) -> Tensor")) {
+ } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
auto weight_type = input_type(0);
auto indices_type = input_type(1);
if (weight_type && indices_type) {
@@ -1044,7 +1044,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
node->matches("aten::reshape_as(Tensor self, Tensor other) -> Tensor")) {
return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
} else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
- node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor") ||
+ node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
node->matches("aten::as_strided(Tensor self, int[] size, int[] stride) -> Tensor") ||
node->matches("aten::as_strided(Tensor self, int[] size, int[] stride, int storage_offset) -> Tensor")) {
return reshape_prop(node, attr::size, tensor_types);
@@ -1189,12 +1189,12 @@ bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands,
} else if (node->matches("aten::sum(Tensor self) -> Tensor")) {
node->output()->setType(tensor_types.at(0)->withSizes({}));
return true;
- } else if (node->matches("aten::sum(Tensor self, int[] dim, int keepdim) -> Tensor",
+ } else if (node->matches("aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
/*with_const=*/{attr::dim, attr::keepdim})) {
auto & tp = tensor_types.at(0);
auto sizes = tp->sizes();
auto dims = node->get<std::vector<int64_t>>(attr::dim).value();
- bool keepdim = node->get<int64_t>(attr::keepdim).value();
+ bool keepdim = node->get<bool>(attr::keepdim).value();
std::reverse(dims.begin(), dims.end());
for (int64_t dim : dims) {
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
@@ -1262,7 +1262,7 @@ bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands,
node->output()->setType(tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes()));
}
return true;
- } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor",
+ } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
/*with_const=*/attr::size)) {
auto tp = tensor_types.at(0);
std::vector<int64_t> sizes, strides;
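
The schemas above are the compiler-visible signatures, so flags such as keepdim, unbiased and ceil_mode now type-check as bool rather than int. A minimal TorchScript sketch of the new convention (illustrative only, not taken from this commit's tests):

    import torch

    @torch.jit.script
    def row_sum(x):
        # keepdim is matched against the `bool keepdim` argument in the aten::sum schema
        return torch.sum(x, dim=[0], keepdim=True)

    print(row_sum(torch.ones(2, 3)).shape)  # torch.Size([1, 3])
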
diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp
index 0d56ca2255..f1edf80d41 100644
--- a/torch/csrc/jit/passes/to_batch.cpp
+++ b/torch/csrc/jit/passes/to_batch.cpp
@@ -41,6 +41,10 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
auto to_tensor_node = res_graph->createNumToTensor(input);
res_graph->insertNode(to_tensor_node);
new_inputs[i] = to_tensor_node->output();
+ } else if(input->type() == BoolType::get()) {
+ auto to_tensor_node = res_graph->createBoolToTensor(input);
+ res_graph->insertNode(to_tensor_node);
+ new_inputs[i] = to_tensor_node->output();
}
}
@@ -58,8 +62,11 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
else if(n->outputs()[0]->type() == FloatType::get()){
to_scalar_node = res_graph->createTensorToNum(FloatType::get(), outputs[0]);
}
+ else if(n->outputs()[0]->type() == BoolType::get()){
+ to_scalar_node = res_graph->createTensorToBool(outputs[0]);
+ }
else{
- throw std::runtime_error("NYI: scalar type other than int, float is not supported yet");
+ throw std::runtime_error("NYI: scalar types other than int, float, and bool are not supported yet");
}
res_graph->insertNode(to_scalar_node);
rn_env[n->outputs()[0]] = to_scalar_node->output();
@@ -348,9 +355,10 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
if(cond_is_tensor){
auto cond = batch_map.at(n->inputs()[1]);
auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
- auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]);
- res_graph->insertNode(to_int_node);
- rn_env[n->inputs()[1]] = to_int_node->output();
+ auto to_bool_node =
+ res_graph->createTensorToBool(cond_any[0]);
+ res_graph->insertNode(to_bool_node);
+ rn_env[n->inputs()[1]] = to_bool_node->output();
}
for(size_t i = 2; i < n->inputs().size(); i++){
auto input = n->inputs()[i];
@@ -432,9 +440,10 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
if(cond_is_tensor){
auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
- auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]);
- res_graph->insertNode(to_int_node);
- loop_block->insertOutput(0, to_int_node->output());
+ auto to_bool_node =
+ res_graph->createTensorToBool(cond_any[0]);
+ res_graph->insertNode(to_bool_node);
+ loop_block->insertOutput(0, to_bool_node->output());
for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
loop_block->insertOutput(i + 1, cond[i]);
}
@@ -491,6 +500,7 @@ void ToBatch::toBatch(Block* block, Block* res_block) {
case prim::NumToTensor:
visitNumToTensor(n, block, res_block);
break;
+ case prim::TensorToBool:
case prim::TensorToNum:
visitTensorToNum(n, block, res_block);
break;
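
For reference, the kind of scripted loop this pass rewrites now carries a bool condition: the batched version reduces the per-example conditions with the `any` batch operator and converts the result through prim::TensorToBool instead of TensorToNum. A rough sketch of such a loop (hypothetical example, assuming the usual torch.jit.script entry point):

    import torch

    @torch.jit.script
    def count_down(x):
        # the comparison yields a tensor; bool(...) produces the loop condition
        while bool(x.sum() > 0):
            x = x - 1
        return x
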
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 192cabca3f..bd81a8cc05 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -107,6 +107,8 @@ inline IValue toIValue(py::handle obj, const TypePtr& type) {
return py::cast<int64_t>(obj);
case TypeKind::NoneType:
return {};
+ case TypeKind::BoolType:
+ return py::cast<bool>(obj);
case TypeKind::TupleType: {
if(!PyTuple_Check(obj.ptr()))
throw py::cast_error(); // note: the py::cast does not throw cast_error
@@ -196,10 +198,14 @@ inline py::object toPyObject(IValue&& ivalue) {
return py::cast(ivalue.toDouble());
} else if (ivalue.isInt()) {
return py::cast(ivalue.toInt());
+  } else if (ivalue.isBool()) {
+ return py::cast(ivalue.toBool());
} else if (ivalue.isIntList()) {
return py::cast(ivalue.toIntListRef());
} else if (ivalue.isDoubleList()) {
return py::cast(ivalue.toDoubleListRef());
+ } else if (ivalue.isBoolList()) {
+ return py::cast(ivalue.toBoolListRef());
} else if (ivalue.isTensorList()) {
return py::cast(ivalue.toTensorListRef());
} else if (ivalue.isGenericList()) {
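
With these two branches, Python bools round-trip through the binding layer as bools instead of being folded into ints. A small sketch of the observable effect (illustrative, not from the test suite):

    import torch

    @torch.jit.script
    def is_positive(x):
        return bool(x.sum() > 0)

    print(type(is_positive(torch.ones(2))))  # <class 'bool'>
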
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index ad03ac556c..979b07d369 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -455,6 +455,8 @@ void initPythonIRBindings(PyObject * module_) {
return "StringType";
case TypeKind::GeneratorType:
return "GeneratorType";
+ case TypeKind::BoolType:
+ return "BoolType";
case TypeKind::VarType:
return "VarType";
case TypeKind::WorldType:
@@ -491,6 +493,8 @@ void initPythonIRBindings(PyObject * module_) {
.def_static("get", &FloatType::get);
py::class_<DynamicType, Type, std::shared_ptr<DynamicType>>(m, "DynamicType")
.def_static("get", &DynamicType::get);
+ py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType")
+ .def_static("get", &BoolType::get);
py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType")
.def(py::init([](std::vector<TypePtr> a){ return TupleType::create(a); }))
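
Assuming these bindings land on torch._C alongside the existing Type classes, the new singleton is reachable from Python; a minimal probe:

    import torch

    bool_ty = torch._C.BoolType.get()
    print(isinstance(bool_ty, torch._C.Type))  # True
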
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index cdea4ab894..c7bf050dc2 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -70,6 +70,17 @@ RegisterOperators reg({
};
}),
Operator(
+ prim::TensorToBool,
+ [](Node* node) -> Operation {
+ return [](Stack& stack) {
+ at::Tensor a;
+ pop(stack, a);
+ at::DeviceGuard guard(a);
+ push(stack, a.item<int64_t>() != 0);
+ return 0;
+ };
+ }),
+ Operator(
prim::TensorToNum,
[](Node* node) -> Operation {
if(node->output()->type() == IntType::get()) {
@@ -124,6 +135,18 @@ RegisterOperators reg({
};
}),
Operator(
+ prim::BoolToTensor,
+ [](Node* node) -> Operation {
+ return [](Stack& stack) {
+ bool b;
+ pop(stack, b);
+ push(
+ stack,
+ autograd::make_variable(at::scalar_to_tensor(b)));
+ return 0;
+ };
+ }),
+ Operator(
prim::IntToFloat,
[](Node* node) -> Operation {
return [](Stack& stack) {
@@ -446,25 +469,25 @@ RegisterOperators reg({
});
// define implementations for primitive number ops
-#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, float_result) \
- Operator( \
- #aten_op "(int a, int b) -> int", \
- [](Node* node) { \
- return [=](Stack& stack) { \
- int64_t a, b; \
- pop(stack, a, b); \
- push(stack, int_op); \
- return 0; \
- }; \
- }), \
- Operator( \
- #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \
- return [=](Stack& stack) { \
- double a, b; \
- pop(stack, a, b); \
- push(stack, float_op); \
- return 0; \
- }; \
+#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
+ Operator( \
+ #aten_op "(int a, int b) -> " #int_result, \
+ [](Node* node) { \
+ return [=](Stack& stack) { \
+ int64_t a, b; \
+ pop(stack, a, b); \
+ push(stack, int_op); \
+ return 0; \
+ }; \
+ }), \
+ Operator( \
+ #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \
+ return [=](Stack& stack) { \
+ double a, b; \
+ pop(stack, a, b); \
+ push(stack, float_op); \
+ return 0; \
+ }; \
}),
#define DEFINE_INT_OP(aten_op, op) \
@@ -477,8 +500,19 @@ RegisterOperators reg({
}; \
}),
-#define DEFINE_BINARY_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, op, float)
-#define DEFINE_COMPARISON_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, op, int)
+#define DEFINE_BINARY_OP(aten_op, op) \
+ DEFINE_GENERIC_OP(aten_op, op, op, int, float)
+#define DEFINE_COMPARISON_OP(aten_op, op) \
+ DEFINE_GENERIC_OP(aten_op, op, op, bool, bool)
+#define DEFINE_BOOL_OP(aten_op, op) \
+ Operator(#aten_op "(bool a, bool b) -> bool", [](Node* node) { \
+ return [=](Stack& stack) { \
+ bool a, b; \
+ pop(stack, a, b); \
+ push(stack, op); \
+ return 0; \
+ }; \
+ }),
// Convert a Python index (which may be negative) into an index usable for a
// C++ container
@@ -663,7 +697,7 @@ RegisterOperators reg2({
  // Pass in two ops to handle int and float separately, since % in C++ only works for int.
  // The modulus calculation differs between C++ and Python for negative operands; we preserve
  // the Python behavior, as it is more common and matches Python syntax, hence the conversion.
- DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), float)
+ DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), int, float)
// TODO: Support python floordiv (//)
// Right now aten::floordiv is only used by loop unrolling
@@ -696,8 +730,8 @@ RegisterOperators reg2({
DEFINE_COMPARISON_OP(aten::le, a <= b)
DEFINE_COMPARISON_OP(aten::ge, a >= b)
- DEFINE_INT_OP(aten::__and__, a&& b)
- DEFINE_INT_OP(aten::__or__, a || b)
+ DEFINE_BOOL_OP(aten::__and__, a && b)
+ DEFINE_BOOL_OP(aten::__or__, a || b)
Operator("aten::_construct_empty_int_list() -> int[]",
[](Node* node) -> Operation {
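
Taken together, the new macros mean that scripted comparisons on numbers produce bool, and the boolean connectives operate on bool operands. A hedged sketch of how that looks at the surface level (example is illustrative, not from the commit):

    import torch

    @torch.jit.script
    def in_range(x):
        v = int(x.sum())  # prim::TensorToNum
        lo = 1
        hi = 5
        # `>=` and `<=` now come back as bool, so the connective combines bools rather than ints
        return (v >= lo) and (v <= hi)

    print(in_range(torch.ones(3)))  # True
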
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 28aa735fc3..474722f536 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -74,6 +74,8 @@ static Value* typeCast(const SourceRange& loc, Value* value, TypePtr dst) {
n = graph.createNumToTensor(value);
} else if (dst->isSubtypeOf(NumberType::get()) && orig->isSubtypeOf(DynamicType::get())) {
n = graph.createTensorToNum(dst, value);
+ } else if (dst->isSubtypeOf(BoolType::get()) && orig->isSubtypeOf(DynamicType::get())) {
+ n = graph.createTensorToBool(value);
} else if(dst->isSubtypeOf(IntType::get()) && orig->isSubtypeOf(FloatType::get())) {
n = graph.createFloatToInt(value);
} else if(dst->isSubtypeOf(FloatType::get()) && orig->isSubtypeOf(IntType::get())) {
@@ -324,7 +326,7 @@ struct Environment {
{"print", std::make_shared<PrintValue>()},
{"float", std::make_shared<CastValue>(FloatType::get())},
{"int", std::make_shared<CastValue>(IntType::get())},
- {"bool", std::make_shared<CastValue>(IntType::get())},
+ {"bool", std::make_shared<CastValue>(BoolType::get())},
// todo(zach): remove when we can correctly export torch.full via ONNX
// or we have implicit conversion that can convert numbers to tensors
{"_to_tensor", std::make_shared<CastValue>(DynamicType::get()) },
@@ -1048,9 +1050,9 @@ private:
Value* emitCond(Expr cond) {
Value* v = emitExpr(cond);
- if (!v->type()->isSubtypeOf(IntType::get())) {
+ if (!v->type()->isSubtypeOf(BoolType::get())) {
ErrorReport error(cond);
- error << "expected an integer expression for condition but found "
+ error << "expected a boolean expression for condition but found "
<< v->type()->str();
if (v->type()->isSubtypeOf(DynamicType::get())) {
error << ", to use a tensor in a boolean"
@@ -1928,6 +1930,7 @@ const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
{"Tensor", DynamicType::get()},
{"int", IntType::get()},
{"float", FloatType::get()},
+ {"bool", BoolType::get()},
};
return map;
}
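
A short sketch of the stricter condition check above: an `if` or `while` now expects a bool, and a tensor condition has to be wrapped in the `bool(...)` cast, which the compiler lowers through prim::TensorToBool (illustrative example, not from the commit's tests):

    import torch

    @torch.jit.script
    def clamp_to_zero(x):
        if bool(x.sum() > 0):
            y = x
        else:
            y = x - x
        return y
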
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 4c7df820b1..fec69ac317 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -278,12 +278,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
// f = f + 1
auto& g = *m.graph();
if (is_constant) {
- if (py::isinstance<py::int_>(obj)) {
+ if (py::isinstance<py::bool_>(obj)) {
+ return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
+ } else if (py::isinstance<py::int_>(obj)) {
return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
} else if (py::isinstance<py::float_>(obj)) {
return toSimple(g.insertConstant(py::cast<float>(obj), loc));
- } else if (py::isinstance<py::bool_>(obj)) {
- return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
} else if (THPDevice_Check(obj.ptr())) {
auto device = (THPDevice*)obj.ptr();
std::vector<int64_t> v = {static_cast<int64_t>(device->device.type()),
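
The reordering matters because Python's bool is a subclass of int, so the py::int_ isinstance check also matches True and False; testing for bool first keeps such constants from being inserted as integers. In plain Python terms:

    print(isinstance(True, int))  # True: a bool would also satisfy the int check
    print(isinstance(1, bool))    # False
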
diff --git a/torch/csrc/jit/type.cpp b/torch/csrc/jit/type.cpp
index 855adad429..ce3cd8c1f2 100644
--- a/torch/csrc/jit/type.cpp
+++ b/torch/csrc/jit/type.cpp
@@ -46,6 +46,8 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << "float";
} else if(t.kind() == TypeKind::IntType) {
out << "int";
+ } else if(t.kind() == TypeKind::BoolType) {
+ out << "bool";
} else if(t.kind() == TypeKind::ListType) {
auto prim = t.cast<ListType>()->getElementType();
out << *prim << "[]";
@@ -85,6 +87,10 @@ FloatTypePtr FloatType::get() {
static auto value = FloatType::create();
return value;
}
+BoolTypePtr BoolType::get() {
+ static auto value = BoolType::create();
+ return value;
+}
NoneTypePtr NoneType::get() {
static auto value = NoneType::create();
return value;
@@ -113,6 +119,10 @@ ListTypePtr ListType::ofFloats() {
static auto value = ListType::create(FloatType::get());
return value;
}
+ListTypePtr ListType::ofBools() {
+ static auto value = ListType::create(BoolType::get());
+ return value;
+}
TypePtr inferTypeFrom(const IValue& value) {
if (value.isTensor()) {
@@ -121,12 +131,16 @@ TypePtr inferTypeFrom(const IValue& value) {
return FloatType::get();
} else if (value.isInt()) {
return IntType::get();
+ } else if (value.isBool()) {
+ return BoolType::get();
} else if (value.isString()) {
return StringType::get();
} else if (value.isIntList()) {
return ListType::ofInts();
} else if (value.isTensorList()) {
return ListType::ofTensors();
+ } else if (value.isBoolList()) {
+ return ListType::ofBools();
} else if (value.isDoubleList()) {
return ListType::ofFloats();
} else if (value.isTuple()) {
diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h
index 49748de239..69133f6be9 100644
--- a/torch/csrc/jit/type.h
+++ b/torch/csrc/jit/type.h
@@ -27,6 +27,7 @@ _(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
+_(BoolType) \
_(VarType) \
_(WorldType) \
@@ -343,6 +344,7 @@ struct TORCH_API CompleteTensorType : public TensorType {
return prod;
}
static TypePtr fromNumberType(TypePtr typ);
+ static TypePtr fromBoolType();
private:
CompleteTensorType(const at::Tensor& tensor)
@@ -438,6 +440,7 @@ struct TORCH_API ListType : public Type {
static ListTypePtr ofTensors();
static ListTypePtr ofInts();
static ListTypePtr ofFloats();
+ static ListTypePtr ofBools();
static const TypeKind Kind = TypeKind::ListType;
private:
@@ -605,6 +608,31 @@ private:
: Type(TypeKind::IntType) {}
};
+struct BoolType;
+using BoolTypePtr = std::shared_ptr<BoolType>;
+// This node represents a Python bool value
+struct TORCH_API BoolType : public Type {
+ template<typename ... T>
+ static BoolTypePtr create( T&& ... all ) {
+ return BoolTypePtr(new BoolType(std::forward<T>(all)... ));
+ }
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "bool";
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ return *this == *rhs || rhs->kind() == TypeKind::BoolType;
+ }
+ static const TypeKind Kind = TypeKind::BoolType;
+ // global singleton
+ static BoolTypePtr get();
+private:
+ BoolType()
+ : Type(TypeKind::BoolType) {}
+};
+
struct StringType;
using StringTypePtr = std::shared_ptr<StringType>;
// This node represents a Python string value
@@ -728,10 +756,17 @@ inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) {
return CompleteTensorType::create(at::kLong, -1, {});
} else if (typ->isSubtypeOf(FloatType::get())) {
return CompleteTensorType::create(at::kFloat, -1, {});
+ } else if (typ->isSubtypeOf(BoolType::get())) {
+ return CompleteTensorType::create(at::kLong, -1, {});
}
AT_ERROR("unknown number type", typ->str());
}
+inline TypePtr CompleteTensorType::fromBoolType() {
+ return CompleteTensorType::create(at::kLong, -1, {});
+}
+
+
// Attempt to find the correct supertype of t1 and t2. If none is found then
// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2,
// then t2 will be returned (and vice versa).
@@ -755,7 +790,7 @@ TypePtr getTypePtr() {
template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); }
template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); }
template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); }
-template<> inline TypePtr getTypePtr<bool>() { return IntType::get(); }
+template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); }
template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); }
template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); }
template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
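
With getTypePtr<bool>() now mapping to BoolType, bool values appear as their own type in graph dumps; a quick way to see it, assuming the usual .graph attribute on scripted functions:

    import torch

    @torch.jit.script
    def is_nonzero_sum(x):
        return int(x.sum()) > 0

    print(is_nonzero_sum.graph)  # the comparison's output is printed with type `bool`
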
diff --git a/torch/jit/batchop.py b/torch/jit/batchop.py
index 229cafbb94..fed8d2a7c6 100644
--- a/torch/jit/batchop.py
+++ b/torch/jit/batchop.py
@@ -214,8 +214,8 @@ def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
@torch.jit.script
-def batch_where_scalar(cond_, data1, mask1, dims1, data2, mask2, dims2):
- cond = torch.zeros([1], dtype=torch.uint8) * cond_
+def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
+ cond = torch.zeros([1], dtype=torch.uint8)
res_data = torch.where(cond, data1, data2)
res_mask = torch.where(cond, mask1, mask2)
res_dims = torch.where(cond, dims1, dims2)
@@ -304,7 +304,7 @@ def batch_unsqueeze(data, mask, dims, dim_):
@torch.jit.script
def batch_argmax(data, mask, dims, dim_, keepdim_):
dim = int(dim_)
- keepdim = int(keepdim_)
+ keepdim = bool(keepdim_)
# if dim == 0:
# raise ValueError("cannot do argmax along batch_dim")
batch_size = data.size(0)
@@ -338,8 +338,8 @@ def batch_argmax(data, mask, dims, dim_, keepdim_):
def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
k = int(k_)
dim = int(dim_)
- largest = int(largest_)
- sorted = int(sorted_)
+ largest = bool(largest_)
+ sorted = bool(sorted_)
# if dim == 0:
# raise ValueError("cannot do topk along batch_dim")
batch_size = data.size(0)