58 files changed, 440 insertions, 227 deletions
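At a glance, the patch gives the JIT a first-class boolean: IValue gains a Bool tag and a BoolList, the script type system gains BoolType, prim::TensorToBool / prim::BoolToTensor replace the old int-based lowering of conditions, and the touched aten schemas switch their flag arguments from int to bool. As a rough illustration (not part of the patch) of how this surfaces, a scripted loop such as the one below now prints its condition as a bool value in the graph, matching the updated expect files (e.g. TestJit.test_pretty_printer-while_test):

    import torch

    @torch.jit.script
    def while_test(a, i):
        # The comparison result is a tensor; with this patch the compiler
        # lowers it through prim::TensorToBool (bool) rather than
        # prim::TensorToNum (int) before feeding it to prim::Loop.
        while i < 3:
            a = a * a
            i = i + 1
        return a

    print(while_test.graph)
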
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 5df0a5b49c..58e90516c9 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -6,9 +6,11 @@ _(Tensor) \ _(Double) \ _(Int) \ + _(Bool) \ _(Tuple) \ _(IntList) \ _(DoubleList) \ + _(BoolList) \ _(String) \ _(TensorList) \ _(Blob) \ diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 5e210d638d..0bdd85a587 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -74,6 +74,7 @@ struct C10_EXPORT Tuple : public List<IValue> { using IntList = List<int64_t>; using TensorList = List<at::Tensor>; using DoubleList = List<double>; +using BoolList = List<bool>; using GenericList = List<IValue>; // IValue is the generic tagged union used by the interpreter to hold @@ -88,9 +89,11 @@ using GenericList = List<IValue>; _(Tensor) \ _(Double) \ _(Int) \ + _(Bool) \ _(Tuple) \ _(IntList) \ _(DoubleList) \ + _(BoolList) \ _(String) \ _(TensorList) \ _(Blob) \ @@ -224,8 +227,6 @@ struct CAFFE2_API IValue final { // allow you to pass literals (3, 4) without ambiguity IValue(int32_t i) : IValue(static_cast<int64_t>(i)) {} - IValue(bool b) - : IValue(static_cast<int64_t>(b)) {} bool isInt() const { return Tag::Int == tag; } @@ -234,6 +235,17 @@ struct CAFFE2_API IValue final { return payload.as_int; } + // Bool + IValue(bool b) + : tag(Tag::Bool), is_intrusive_ptr(false) { + payload.as_bool = b; + } + bool isBool() const { return Tag::Bool == tag; } + bool toBool() const { + AT_ASSERT(isBool()); + return payload.as_bool; + } + // IntList IValue(c10::intrusive_ptr<IntList> v); IValue(std::vector<int64_t> v); @@ -251,6 +263,7 @@ struct CAFFE2_API IValue final { const std::vector<int64_t>& toIntListRef() const; const std::vector<double>& toDoubleListRef() const; + const std::vector<bool>& toBoolListRef() const; const std::vector<at::Tensor>& toTensorListRef() const; const std::vector<IValue>& toGenericListRef() const; @@ -280,6 +293,19 @@ struct CAFFE2_API IValue final { return toIntrusivePtr<DoubleList>(); } + // BoolList + IValue(c10::intrusive_ptr<BoolList> v); + IValue(std::vector<bool> v); + bool isBoolList() const { return Tag::BoolList == tag; } + c10::intrusive_ptr<BoolList> toBoolList() && { + AT_ASSERT(isBoolList()); + return moveToIntrusivePtr<BoolList>(); + } + c10::intrusive_ptr<BoolList> toBoolList() const & { + AT_ASSERT(isBoolList()); + return toIntrusivePtr<BoolList>(); + } + //TensorList IValue(c10::intrusive_ptr<TensorList> v); IValue(std::vector<at::Tensor> v); @@ -323,15 +349,16 @@ struct CAFFE2_API IValue final { } } bool isScalar() { - return isDouble() || isInt(); + return isDouble() || isInt() || isBool(); } at::Scalar toScalar() const { if(isDouble()) return toDouble(); else if(isInt()) return toInt(); - else - throw std::runtime_error("IValue is not a Scalar"); + else if (isBool()) + return int(toBool()); + throw std::runtime_error("IValue is not a Scalar"); } // for debugging @@ -396,6 +423,7 @@ struct CAFFE2_API IValue final { union { int64_t as_int; double as_double; + bool as_bool; c10::intrusive_ptr_target* as_intrusive_ptr; World as_world; } payload; @@ -419,15 +447,16 @@ DEFINE_TO(at::Tensor, toTensor) DEFINE_TO(c10::intrusive_ptr<Tuple>, toTuple) DEFINE_TO(double, toDouble) DEFINE_TO(int64_t, toInt) +DEFINE_TO(bool, toBool) DEFINE_TO(c10::intrusive_ptr<DoubleList>, toDoubleList) DEFINE_TO(c10::intrusive_ptr<IntList>, toIntList) DEFINE_TO(c10::intrusive_ptr<TensorList>, toTensorList) DEFINE_TO(c10::intrusive_ptr<GenericList>, 
toGenericList) DEFINE_TO(c10::intrusive_ptr<ConstantString>, toString) DEFINE_TO(at::Scalar, toScalar) -DEFINE_TO(bool, toInt) DEFINE_TO(std::vector<int64_t>, toIntListRef) DEFINE_TO(std::vector<double>, toDoubleListRef) +DEFINE_TO(std::vector<bool>, toBoolListRef) DEFINE_TO(std::vector<at::Tensor>, toTensorListRef) DEFINE_TO(std::vector<IValue>, toGenericListRef) DEFINE_TO(World, toWorld) @@ -490,6 +519,13 @@ inline IValue::IValue(c10::intrusive_ptr<DoubleList> v) inline IValue::IValue(std::vector<double> v) : IValue(DoubleList::create(std::move(v))) {} +inline IValue::IValue(c10::intrusive_ptr<BoolList> v) +: tag(Tag::BoolList), is_intrusive_ptr(true) { + payload.as_intrusive_ptr = v.release(); +} +inline IValue::IValue(std::vector<bool> v) +: IValue(BoolList::create(std::move(v))) {} + inline IValue::IValue(c10::intrusive_ptr<TensorList> v) : tag(Tag::TensorList), is_intrusive_ptr(true) { payload.as_intrusive_ptr = v.release(); @@ -517,6 +553,10 @@ inline const std::vector<at::Tensor>& IValue::toTensorListRef() const { return toTensorList()->elements(); } +inline const std::vector<bool>& IValue::toBoolListRef() const { + return toBoolList()->elements(); +} + inline const std::vector<IValue>& IValue::toGenericListRef() const { return toGenericList()->elements(); } diff --git a/test/expect/TestBatched.test_for.expect b/test/expect/TestBatched.test_for.expect index 8932957402..6b15b3e799 100644 --- a/test/expect/TestBatched.test_for.expect +++ b/test/expect/TestBatched.test_for.expect @@ -5,7 +5,7 @@ graph(%x.1_data : Dynamic %y_mask : Dynamic %y_dims : Dynamic) { %6 : int = prim::Constant[value=10]() - %7 : int = prim::Constant[value=1]() + %7 : bool = prim::Constant[value=1]() %x : Dynamic, %9 : Dynamic, %10 : Dynamic = prim::Loop(%6, %7, %x.1_data, %x.1_mask, %x.1_dims) block0(%loop_num : int, %5_data : Dynamic, %5_mask : Dynamic, %5_dims : Dynamic) { %15 : int = prim::Constant[value=1]() @@ -14,7 +14,7 @@ graph(%x.1_data : Dynamic %data.1 : Dynamic = aten::add(%5_data, %y_data, %alpha) %mask : Dynamic = aten::mul(%5_mask, %y_mask) %dims : Dynamic = aten::__or__(%5_dims, %y_dims) - %21 : int = prim::Constant[value=1]() + %21 : bool = prim::Constant[value=1]() %data : Dynamic = aten::where(%mask, %data.1, %5_data) -> (%21, %data, %mask, %dims) } diff --git a/test/expect/TestBatched.test_if_else.expect b/test/expect/TestBatched.test_if_else.expect index ddb6763469..86475a61e3 100644 --- a/test/expect/TestBatched.test_if_else.expect +++ b/test/expect/TestBatched.test_if_else.expect @@ -7,7 +7,7 @@ graph(%a.1_data : Dynamic %6 : Dynamic = aten::gt(%a.1_data, %b_data) %7 : Dynamic = aten::mul(%a.1_mask, %b_mask) %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims) - %9 : int = prim::TensorToNum(%6) + %9 : bool = prim::TensorToBool(%6) %10 : int = prim::Constant[value=1]() %11 : Long() = prim::NumToTensor(%10) %alpha.1 : float = prim::TensorToNum(%11) @@ -24,17 +24,17 @@ graph(%a.1_data : Dynamic %23 : Dynamic = aten::type_as(%7, %6) %cond_mask.1 : Dynamic = aten::mul(%6, %23) %25 : int = aten::dim(%cond_mask.1) - %26 : int = aten::eq(%25, %22) + %26 : bool = aten::eq(%25, %22) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%26) block0() { %30 : int = aten::dim(%data.1) %31 : int = aten::sub(%30, %22) - %32 : int = prim::Constant[value=1]() + %32 : bool = prim::Constant[value=1]() %data.3 : Dynamic = prim::Loop(%31, %32, %cond_mask.1) block0(%_ : int, %35 : Dynamic) { %36 : int = aten::dim(%35) %data.2 : Dynamic = aten::unsqueeze(%35, %36) - %38 : int = 
prim::Constant[value=1]() + %38 : bool = prim::Constant[value=1]() -> (%38, %data.2) } %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) diff --git a/test/expect/TestBatched.test_if_else_with_scalar.expect b/test/expect/TestBatched.test_if_else_with_scalar.expect index 08057e4d10..cbe4a9f05b 100644 --- a/test/expect/TestBatched.test_if_else_with_scalar.expect +++ b/test/expect/TestBatched.test_if_else_with_scalar.expect @@ -8,7 +8,7 @@ graph(%a.1_data : Dynamic %7 : Float() = prim::NumToTensor(%6) %other : float = prim::TensorToNum(%7) %9 : Dynamic = aten::gt(%a.1_data, %other) - %10 : int = prim::TensorToNum(%9) + %10 : bool = prim::TensorToBool(%9) %11 : int = prim::Constant[value=1]() %12 : Long() = prim::NumToTensor(%11) %alpha.1 : float = prim::TensorToNum(%12) @@ -25,17 +25,17 @@ graph(%a.1_data : Dynamic %24 : Dynamic = aten::type_as(%a.1_mask, %9) %cond_mask.1 : Dynamic = aten::mul(%9, %24) %26 : int = aten::dim(%cond_mask.1) - %27 : int = aten::eq(%26, %23) + %27 : bool = aten::eq(%26, %23) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27) block0() { %31 : int = aten::dim(%data.1) %32 : int = aten::sub(%31, %23) - %33 : int = prim::Constant[value=1]() + %33 : bool = prim::Constant[value=1]() %data.3 : Dynamic = prim::Loop(%32, %33, %cond_mask.1) block0(%_ : int, %36 : Dynamic) { %37 : int = aten::dim(%36) %data.2 : Dynamic = aten::unsqueeze(%36, %37) - %39 : int = prim::Constant[value=1]() + %39 : bool = prim::Constant[value=1]() -> (%39, %data.2) } %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) diff --git a/test/expect/TestBatched.test_if_noelse.expect b/test/expect/TestBatched.test_if_noelse.expect index a408563916..8dbed77571 100644 --- a/test/expect/TestBatched.test_if_noelse.expect +++ b/test/expect/TestBatched.test_if_noelse.expect @@ -7,7 +7,7 @@ graph(%a.1_data : Dynamic %6 : Dynamic = aten::gt(%a.1_data, %b_data) %7 : Dynamic = aten::mul(%a.1_mask, %b_mask) %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims) - %9 : int = prim::TensorToNum(%6) + %9 : bool = prim::TensorToBool(%6) %10 : int = prim::Constant[value=1]() %11 : Long() = prim::NumToTensor(%10) %alpha : float = prim::TensorToNum(%11) @@ -18,17 +18,17 @@ graph(%a.1_data : Dynamic %17 : Dynamic = aten::type_as(%7, %6) %cond_mask.1 : Dynamic = aten::mul(%6, %17) %19 : int = aten::dim(%cond_mask.1) - %20 : int = aten::eq(%19, %16) + %20 : bool = aten::eq(%19, %16) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%20) block0() { %24 : int = aten::dim(%data.1) %25 : int = aten::sub(%24, %16) - %26 : int = prim::Constant[value=1]() + %26 : bool = prim::Constant[value=1]() %data.3 : Dynamic = prim::Loop(%25, %26, %cond_mask.1) block0(%_ : int, %29 : Dynamic) { %30 : int = aten::dim(%29) %data.2 : Dynamic = aten::unsqueeze(%29, %30) - %32 : int = prim::Constant[value=1]() + %32 : bool = prim::Constant[value=1]() -> (%32, %data.2) } %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) diff --git a/test/expect/TestBatched.test_if_noelse_with_scalar.expect b/test/expect/TestBatched.test_if_noelse_with_scalar.expect index 087868ea16..d8f453c965 100644 --- a/test/expect/TestBatched.test_if_noelse_with_scalar.expect +++ b/test/expect/TestBatched.test_if_noelse_with_scalar.expect @@ -8,7 +8,7 @@ graph(%a.1_data : Dynamic %7 : Float() = prim::NumToTensor(%6) %other : float = prim::TensorToNum(%7) %9 : Dynamic = aten::gt(%a.1_data, %other) - %10 : int = prim::TensorToNum(%9) + %10 : bool = prim::TensorToBool(%9) %11 : int = prim::Constant[value=1]() %12 : 
Long() = prim::NumToTensor(%11) %alpha : float = prim::TensorToNum(%12) @@ -19,17 +19,17 @@ graph(%a.1_data : Dynamic %18 : Dynamic = aten::type_as(%a.1_mask, %9) %cond_mask.1 : Dynamic = aten::mul(%9, %18) %20 : int = aten::dim(%cond_mask.1) - %21 : int = aten::eq(%20, %17) + %21 : bool = aten::eq(%20, %17) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21) block0() { %25 : int = aten::dim(%data.1) %26 : int = aten::sub(%25, %17) - %27 : int = prim::Constant[value=1]() + %27 : bool = prim::Constant[value=1]() %data.3 : Dynamic = prim::Loop(%26, %27, %cond_mask.1) block0(%_ : int, %30 : Dynamic) { %31 : int = aten::dim(%30) %data.2 : Dynamic = aten::unsqueeze(%30, %31) - %33 : int = prim::Constant[value=1]() + %33 : bool = prim::Constant[value=1]() -> (%33, %data.2) } %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) diff --git a/test/expect/TestBatched.test_while.expect b/test/expect/TestBatched.test_while.expect index 7aba7a89ac..5cd196b56f 100644 --- a/test/expect/TestBatched.test_while.expect +++ b/test/expect/TestBatched.test_while.expect @@ -8,12 +8,12 @@ graph(%a.1_data : Dynamic %7 : Dynamic = aten::gt(%a.1_data, %b_data) %8 : Dynamic = aten::mul(%a.1_mask, %b_mask) %9 : Dynamic = aten::__or__(%a.1_dims, %b_dims) - %10 : int = prim::TensorToNum(%7) + %10 : bool = prim::TensorToBool(%7) %11 : int = prim::Constant[value=0]() %12 : Dynamic = aten::mul(%7, %8) %13 : Dynamic = aten::sum(%12) %14 : Dynamic = aten::gt(%13, %11) - %15 : int = prim::TensorToNum(%14) + %15 : bool = prim::TensorToBool(%14) %16 : Dynamic, %17 : Dynamic, %18 : Dynamic, %a : Dynamic, %20 : Dynamic, %21 : Dynamic = prim::Loop(%6, %15, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims) block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) { %29 : int = prim::Constant[value=1]() @@ -25,22 +25,22 @@ graph(%a.1_data : Dynamic %35 : Dynamic = aten::gt(%data.1, %b_data) %36 : Dynamic = aten::mul(%mask, %b_mask) %37 : Dynamic = aten::__or__(%dims, %b_dims) - %38 : int = prim::TensorToNum(%35) + %38 : bool = prim::TensorToBool(%35) %39 : int = prim::Constant[value=1]() %40 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2) %cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %40) %42 : int = aten::dim(%cond_mask.1) - %43 : int = aten::eq(%42, %39) + %43 : bool = aten::eq(%42, %39) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%43) block0() { %47 : int = aten::dim(%data.1) %48 : int = aten::sub(%47, %39) - %49 : int = prim::Constant[value=1]() + %49 : bool = prim::Constant[value=1]() %data.3 : Dynamic = prim::Loop(%48, %49, %cond_mask.1) block0(%_ : int, %52 : Dynamic) { %53 : int = aten::dim(%52) %data.2 : Dynamic = aten::unsqueeze(%52, %53) - %55 : int = prim::Constant[value=1]() + %55 : bool = prim::Constant[value=1]() -> (%55, %data.2) } %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) @@ -57,7 +57,7 @@ graph(%a.1_data : Dynamic %62 : Dynamic = aten::mul(%35, %36) %63 : Dynamic = aten::sum(%62) %64 : Dynamic = aten::gt(%63, %61) - %65 : int = prim::TensorToNum(%64) + %65 : bool = prim::TensorToBool(%64) -> (%65, %35, %36, %37, %res_data, %res_mask, %res_dims) } return (%a, %20, %21); diff --git a/test/expect/TestJit.test_batchnorm.expect b/test/expect/TestJit.test_batchnorm.expect index c61390578d..e1fc75d5a4 100644 --- a/test/expect/TestJit.test_batchnorm.expect +++ b/test/expect/TestJit.test_batchnorm.expect @@ -4,10 +4,10 @@ graph(%0 : Double(2, 2, 2, 2) %3 : 
Double(2) %4 : Double(2) %5 : Long()) { - %6 : int = prim::Constant[value=1](), scope: BatchNorm2d + %6 : bool = prim::Constant[value=1](), scope: BatchNorm2d %7 : float = prim::Constant[value=0.1](), scope: BatchNorm2d %8 : float = prim::Constant[value=1e-05](), scope: BatchNorm2d - %9 : int = prim::Constant[value=1](), scope: BatchNorm2d + %9 : bool = prim::Constant[value=1](), scope: BatchNorm2d %10 : Double(2, 2, 2, 2) = aten::batch_norm(%0, %1, %2, %3, %4, %6, %7, %8, %9), scope: BatchNorm2d return (%10); } diff --git a/test/expect/TestJit.test_constant_prop_if_constant.expect b/test/expect/TestJit.test_constant_prop_if_constant.expect index fa1cb8053a..d373c1ee12 100644 --- a/test/expect/TestJit.test_constant_prop_if_constant.expect +++ b/test/expect/TestJit.test_constant_prop_if_constant.expect @@ -1,10 +1,10 @@ graph(%a : Dynamic %b : Dynamic) { %c2.1 : int = prim::Constant[value=1]() - %3 : int = prim::TensorToNum(%a) + %3 : bool = prim::TensorToBool(%a) %c0.4 : int, %c1 : int = prim::If(%3) block0() { - %6 : int = prim::TensorToNum(%b) + %6 : bool = prim::TensorToBool(%b) %c0.3 : int = prim::If(%6) block0() { %8 : int = prim::Constant[value=2]() diff --git a/test/expect/TestJit.test_constant_prop_loop_constant.expect b/test/expect/TestJit.test_constant_prop_loop_constant.expect index d0d4c8ed7d..f5bc3f5720 100644 --- a/test/expect/TestJit.test_constant_prop_loop_constant.expect +++ b/test/expect/TestJit.test_constant_prop_loop_constant.expect @@ -3,16 +3,16 @@ graph() { %b.2 : int = prim::Constant[value=1]() %2 : int = prim::Constant[value=2147483647]() %b.1 : int = prim::Constant[value=0]() - %4 : int = prim::Constant[value=1]() + %4 : bool = prim::Constant[value=1]() %b.3 : int = prim::Loop(%2, %4, %b.1) block0(%6 : int, %7 : int) { - %8 : int = prim::Constant[value=1]() + %8 : bool = prim::Constant[value=1]() -> (%8, %b.2) } - %9 : int = prim::Constant[value=0]() + %9 : bool = prim::Constant[value=0]() %b : int = prim::Loop(%2, %9, %b.3) block0(%11 : int, %12 : int) { - %13 : int = prim::Constant[value=0]() + %13 : bool = prim::Constant[value=0]() -> (%13, %b.4) } return (%b); diff --git a/test/expect/TestJit.test_constant_prop_nested.expect b/test/expect/TestJit.test_constant_prop_nested.expect index 09ef82076e..4a644a0d00 100644 --- a/test/expect/TestJit.test_constant_prop_nested.expect +++ b/test/expect/TestJit.test_constant_prop_nested.expect @@ -1,7 +1,7 @@ graph(%a : Dynamic) { %1 : int = prim::Constant[value=2]() %2 : Dynamic = aten::lt(%a, %1) - %3 : int = prim::TensorToNum(%2) + %3 : bool = prim::TensorToBool(%2) %c : int = prim::If(%3) block0() { %5 : int = prim::Constant[value=5]() diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect index fcb53bad14..20bb072015 100644 --- a/test/expect/TestJit.test_conv.expect +++ b/test/expect/TestJit.test_conv.expect @@ -10,14 +10,14 @@ graph(%0 : Double(20, 16, 50, 40) %9 : int = prim::Constant[value=1](), scope: Conv2d %10 : int = prim::Constant[value=1](), scope: Conv2d %11 : int[] = prim::ListConstruct(%9, %10), scope: Conv2d - %12 : int = prim::Constant[value=0](), scope: Conv2d + %12 : bool = prim::Constant[value=0](), scope: Conv2d %13 : int = prim::Constant[value=0](), scope: Conv2d %14 : int = prim::Constant[value=0](), scope: Conv2d %15 : int[] = prim::ListConstruct(%13, %14), scope: Conv2d %16 : int = prim::Constant[value=1](), scope: Conv2d - %17 : int = prim::Constant[value=0](), scope: Conv2d - %18 : int = prim::Constant[value=0](), scope: Conv2d - %19 : int = 
prim::Constant[value=1](), scope: Conv2d + %17 : bool = prim::Constant[value=0](), scope: Conv2d + %18 : bool = prim::Constant[value=0](), scope: Conv2d + %19 : bool = prim::Constant[value=1](), scope: Conv2d %20 : Double(20, 13, 48, 38) = aten::_convolution(%0, %1, %2, %5, %8, %11, %12, %15, %16, %17, %18, %19), scope: Conv2d return (%20); } diff --git a/test/expect/TestJit.test_dropout.expect b/test/expect/TestJit.test_dropout.expect index 3daa3484c6..3d0d7d312d 100644 --- a/test/expect/TestJit.test_dropout.expect +++ b/test/expect/TestJit.test_dropout.expect @@ -1,6 +1,6 @@ graph(%0 : Double(2, 2)) { %1 : float = prim::Constant[value=0.6](), scope: Dropout - %2 : int = prim::Constant[value=1](), scope: Dropout + %2 : bool = prim::Constant[value=1](), scope: Dropout %3 : Double(2, 2) = aten::dropout(%0, %1, %2), scope: Dropout return (%3); } diff --git a/test/expect/TestJit.test_inplace_copy.expect b/test/expect/TestJit.test_inplace_copy.expect new file mode 100644 index 0000000000..f046063c97 --- /dev/null +++ b/test/expect/TestJit.test_inplace_copy.expect @@ -0,0 +1,17 @@ +graph(%0 : Double(4, 4)) { + %1 : int = prim::Constant[value=0]() + %2 : int = aten::size(%0, %1) + %3 : Long() = prim::NumToTensor(%2) + %4 : int = prim::TensorToNum(%3) + %5 : int = prim::Constant[value=1]() + %6 : int = aten::size(%0, %5) + %7 : Long() = prim::NumToTensor(%6) + %8 : int = prim::TensorToNum(%7) + %9 : int[] = prim::ListConstruct(%4, %8) + %10 : int = prim::Constant[value=7]() + %11 : int = prim::Constant[value=0]() + %12 : int[] = prim::Constant[value=[0, -1]]() + %13 : Double(4, 4) = aten::zeros(%9, %10, %11, %12) + %14 : Double(4, 4) = aten::expand_as(%0, %13) + return (%14); +} diff --git a/test/expect/TestJit.test_pretty_printer-if_one.expect b/test/expect/TestJit.test_pretty_printer-if_one.expect index 3a9254be45..cbede6c812 100644 --- a/test/expect/TestJit.test_pretty_printer-if_one.expect +++ b/test/expect/TestJit.test_pretty_printer-if_one.expect @@ -1,6 +1,6 @@ def script(c2, c1): t2 = aten::lt(c2, c1) - t3 = prim::TensorToNum(t2) + t3 = prim::TensorToBool(t2) if t3: c = c2 else: diff --git a/test/expect/TestJit.test_pretty_printer-if_test.expect b/test/expect/TestJit.test_pretty_printer-if_test.expect index 130c9b0c5a..c7e2249078 100644 --- a/test/expect/TestJit.test_pretty_printer-if_test.expect +++ b/test/expect/TestJit.test_pretty_printer-if_test.expect @@ -1,6 +1,6 @@ def script(c2, c1): t2 = aten::lt(c2, c1) - t3 = prim::TensorToNum(t2) + t3 = prim::TensorToBool(t2) if t3: c = c1 else: diff --git a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect index 4e35ad2150..01934775d8 100644 --- a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect +++ b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect @@ -2,14 +2,14 @@ def script(y1): x = aten::add(y1, 1, 1) z1 = aten::add(x, 5, 1) t9 = aten::lt(y1, 8) - t10 = prim::TensorToNum(t9) + t10 = prim::TensorToBool(t9) y = y1 z = z1 t11 = t10 while t11: y2 = aten::add(y, 1, 1) t17 = aten::lt(y2, 8) - t18 = prim::TensorToNum(t17) + t18 = prim::TensorToBool(t17) t11 = t18 y = y2 z = x diff --git a/test/expect/TestJit.test_pretty_printer-while_if_test.expect b/test/expect/TestJit.test_pretty_printer-while_if_test.expect index c830784510..70d4af6150 100644 --- a/test/expect/TestJit.test_pretty_printer-while_if_test.expect +++ b/test/expect/TestJit.test_pretty_printer-while_if_test.expect @@ -1,6 +1,6 @@ def script(a1, b1): t5 = aten::lt(a1, 10) - t6 = 
prim::TensorToNum(t5) + t6 = prim::TensorToBool(t5) a = a1 b = b1 c = 0 @@ -9,13 +9,13 @@ def script(a1, b1): a2 = aten::add(a, 1, 1) b2 = aten::add(b, 1, 1) t15 = aten::gt(a2, b2) - t16 = prim::TensorToNum(t15) + t16 = prim::TensorToBool(t15) if t16: c4 = 2 else: c4 = 3 t21 = aten::lt(a2, 10) - t22 = prim::TensorToNum(t21) + t22 = prim::TensorToBool(t21) t7 = t22 a = a2 b = b2 diff --git a/test/expect/TestJit.test_pretty_printer-while_test.expect b/test/expect/TestJit.test_pretty_printer-while_test.expect index 487087ad56..f99e4721f0 100644 --- a/test/expect/TestJit.test_pretty_printer-while_test.expect +++ b/test/expect/TestJit.test_pretty_printer-while_test.expect @@ -1,6 +1,6 @@ def script(a1, i1): t4 = aten::lt(i1, 3) - t5 = prim::TensorToNum(t4) + t5 = prim::TensorToBool(t4) a = a1 i = i1 t6 = t5 @@ -8,7 +8,7 @@ def script(a1, i1): a2 = aten::mul(a, a) i2 = aten::add(i, 1, 1) t13 = aten::lt(i2, 3) - t14 = prim::TensorToNum(t13) + t14 = prim::TensorToBool(t13) t6 = t14 a = a2 i = i2 diff --git a/test/expect/TestJit.test_recursive_cse.expect b/test/expect/TestJit.test_recursive_cse.expect index a2aa84b5d6..dd56a51193 100644 --- a/test/expect/TestJit.test_recursive_cse.expect +++ b/test/expect/TestJit.test_recursive_cse.expect @@ -3,7 +3,7 @@ graph(%z.1 : Dynamic %2 : int = prim::Constant[value=1]() %3 : Dynamic = aten::add(%z.1, %y, %2) %4 : Dynamic = aten::gt(%3, %z.1) - %5 : int = prim::TensorToNum(%4) + %5 : bool = prim::TensorToBool(%4) %z : Dynamic = prim::If(%5) block0() { -> (%3) diff --git a/test/expect/TestScript.test_if_list.expect b/test/expect/TestScript.test_if_list.expect index fddb6a173a..93c942d4dc 100644 --- a/test/expect/TestScript.test_if_list.expect +++ b/test/expect/TestScript.test_if_list.expect @@ -1,5 +1,5 @@ graph(%x : Double(*, *)) { - %1 : int = prim::Constant[value=1]() + %1 : bool = prim::Constant[value=1]() %c : Dynamic[] = prim::If(%1) block0() { %c.1 : Dynamic[] = prim::ListConstruct(%x, %x) diff --git a/test/expect/TestScript.test_if_supertype.expect b/test/expect/TestScript.test_if_supertype.expect index 3b58ca8928..e57f49a27c 100644 --- a/test/expect/TestScript.test_if_supertype.expect +++ b/test/expect/TestScript.test_if_supertype.expect @@ -1,7 +1,7 @@ graph(%y.1 : Float(*, *) %z.2 : Long(*, *) %z.1 : Float(*, *)) { - %3 : int = prim::Constant[value=1]() + %3 : bool = prim::Constant[value=1]() %x : Float(*, *), %y : Dynamic, %z : Dynamic = prim::If(%3) block0() { -> (%y.1, %z.2, %z.1) diff --git a/test/expect/TestScript.test_index_put_trace_with_view.expect b/test/expect/TestScript.test_index_put_trace_with_view.expect index cc03d3d529..12bae3f8b4 100644 --- a/test/expect/TestScript.test_index_put_trace_with_view.expect +++ b/test/expect/TestScript.test_index_put_trace_with_view.expect @@ -4,7 +4,7 @@ graph(%0 : Double(100) %3 : int = prim::Constant[value=4]() %4 : int[] = prim::ListConstruct(%3) %5 : Double(4) = aten::view(%2, %4) - %6 : int = prim::Constant[value=0]() + %6 : bool = prim::Constant[value=0]() %7 : Long(4) = aten::_cast_Long(%1, %6) %8 : Dynamic[] = prim::ListConstruct(%7) %9 : Double(100) = aten::index_put(%0, %8, %5) diff --git a/test/expect/TestScript.test_index_put_trace_without_view.expect b/test/expect/TestScript.test_index_put_trace_without_view.expect index c725067960..8e5da7efd4 100644 --- a/test/expect/TestScript.test_index_put_trace_without_view.expect +++ b/test/expect/TestScript.test_index_put_trace_without_view.expect @@ -1,7 +1,7 @@ graph(%0 : Double(100) %1 : Long(4) %2 : Double(4)) { - %3 : int = 
prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() %4 : Long(4) = aten::_cast_Long(%1, %3) %5 : Dynamic[] = prim::ListConstruct(%4) %6 : Double(100) = aten::index_put(%0, %5, %2) diff --git a/test/expect/TestScript.test_logical_short_circuit.expect b/test/expect/TestScript.test_logical_short_circuit.expect index bcd34a7cb0..0a4569c2b4 100644 --- a/test/expect/TestScript.test_logical_short_circuit.expect +++ b/test/expect/TestScript.test_logical_short_circuit.expect @@ -1,31 +1,31 @@ graph(%t : Dynamic) { %c1.2 : int = prim::Constant[value=0]() %c1.1 : int = prim::Constant[value=1]() - %3 : int = prim::Constant[value=0]() - %4 : int = prim::If(%3) + %3 : bool = prim::Constant[value=0]() + %4 : bool = prim::If(%3) block0() { %5 : int = prim::Constant[value=0]() %6 : Dynamic = aten::select(%t, %5, %c1.1) - %7 : int = prim::TensorToNum(%6) + %7 : bool = prim::TensorToBool(%6) -> (%7) } block1() { -> (%3) } - %8 : int = prim::If(%4) + %8 : bool = prim::If(%4) block0() { -> (%4) } block1() { - %9 : int = prim::Constant[value=1]() - %10 : int = prim::If(%9) + %9 : bool = prim::Constant[value=1]() + %10 : bool = prim::If(%9) block0() { -> (%9) } block1() { %11 : int = prim::Constant[value=0]() %12 : Dynamic = aten::select(%t, %11, %c1.1) - %13 : int = prim::TensorToNum(%12) + %13 : bool = prim::TensorToBool(%12) -> (%13) } -> (%10) diff --git a/test/expect/TestScript.test_loop_unroll_unused_counter.expect b/test/expect/TestScript.test_loop_unroll_unused_counter.expect index 292b251c75..feb204a925 100644 --- a/test/expect/TestScript.test_loop_unroll_unused_counter.expect +++ b/test/expect/TestScript.test_loop_unroll_unused_counter.expect @@ -2,7 +2,7 @@ graph(%x : Dynamic) { %1 : int = prim::Constant[value=1]() %y.1 : int = prim::Constant[value=0]() %3 : int = prim::TensorToNum(%x) - %4 : int = prim::Constant[value=1]() + %4 : bool = prim::Constant[value=1]() %5 : int = prim::Constant[value=8]() %6 : int = aten::floordiv(%3, %5) %7 : int = prim::Constant[value=8]() @@ -18,13 +18,13 @@ graph(%x : Dynamic) { %y.9 : int = aten::add(%y.8, %1) %y.10 : int = aten::add(%y.9, %1) %y.11 : int = aten::add(%y.10, %1) - %21 : int = prim::Constant[value=1]() + %21 : bool = prim::Constant[value=1]() -> (%21, %y.11) } %y : int = prim::Loop(%9, %4, %y.3) block0(%i : int, %24 : int) { %y.4 : int = aten::add(%24, %1) - %26 : int = prim::Constant[value=1]() + %26 : bool = prim::Constant[value=1]() -> (%26, %y.4) } return (%y); diff --git a/test/expect/TestScript.test_loop_unrolling.expect b/test/expect/TestScript.test_loop_unrolling.expect index 29546b5c1e..5f3e61a4ce 100644 --- a/test/expect/TestScript.test_loop_unrolling.expect +++ b/test/expect/TestScript.test_loop_unrolling.expect @@ -1,7 +1,7 @@ graph(%x : Dynamic) { %y.1 : int = prim::Constant[value=0]() %2 : int = prim::TensorToNum(%x) - %3 : int = prim::Constant[value=1]() + %3 : bool = prim::Constant[value=1]() %4 : int = prim::Constant[value=0]() %5 : int = prim::Constant[value=8]() %6 : int = aten::floordiv(%2, %5) @@ -32,7 +32,7 @@ graph(%x : Dynamic) { %34 : int = prim::Constant[value=1]() %35 : int = aten::add(%32, %34) %y.11 : int = aten::add(%y.10, %35) - %37 : int = prim::Constant[value=1]() + %37 : bool = prim::Constant[value=1]() %38 : int = prim::Constant[value=1]() %39 : int = aten::add(%35, %38) -> (%37, %39, %y.11) @@ -40,7 +40,7 @@ graph(%x : Dynamic) { %40 : Dynamic, %y : int = prim::Loop(%9, %3, %10, %y.3) block0(%i : int, %43 : int, %44 : int) { %y.4 : int = aten::add(%44, %43) - %46 : int = prim::Constant[value=1]() + %46 : 
bool = prim::Constant[value=1]() %47 : int = prim::Constant[value=1]() %48 : int = aten::add(%43, %47) -> (%46, %48, %y.4) diff --git a/test/expect/TestScript.test_loop_unrolling_nested.expect b/test/expect/TestScript.test_loop_unrolling_nested.expect index a107eb81f7..2668b111ab 100644 --- a/test/expect/TestScript.test_loop_unrolling_nested.expect +++ b/test/expect/TestScript.test_loop_unrolling_nested.expect @@ -1,11 +1,11 @@ graph(%x : Dynamic) { %1 : int = prim::Constant[value=10]() %y.1 : int = prim::Constant[value=0]() - %3 : int = prim::Constant[value=1]() + %3 : bool = prim::Constant[value=1]() %y : int = prim::Loop(%1, %3, %y.1) block0(%i : int, %6 : int) { %7 : int = prim::TensorToNum(%x) - %8 : int = prim::Constant[value=1]() + %8 : bool = prim::Constant[value=1]() %9 : int = prim::Constant[value=0]() %10 : int = prim::Constant[value=8]() %11 : int = aten::floordiv(%7, %10) @@ -36,7 +36,7 @@ graph(%x : Dynamic) { %39 : int = prim::Constant[value=1]() %40 : int = aten::add(%37, %39) %y.12 : int = aten::add(%y.11, %40) - %42 : int = prim::Constant[value=1]() + %42 : bool = prim::Constant[value=1]() %43 : int = prim::Constant[value=1]() %44 : int = aten::add(%40, %43) -> (%42, %44, %y.12) @@ -44,12 +44,12 @@ graph(%x : Dynamic) { %45 : Dynamic, %y.3 : int = prim::Loop(%14, %8, %15, %y.4) block0(%j : int, %48 : int, %49 : int) { %y.5 : int = aten::add(%49, %48) - %51 : int = prim::Constant[value=1]() + %51 : bool = prim::Constant[value=1]() %52 : int = prim::Constant[value=1]() %53 : int = aten::add(%48, %52) -> (%51, %53, %y.5) } - %54 : int = prim::Constant[value=1]() + %54 : bool = prim::Constant[value=1]() -> (%54, %y.3) } return (%y); diff --git a/test/expect/TestScript.test_sum-1.expect b/test/expect/TestScript.test_sum-1.expect index a2bb9d4417..ce0d1fe0b5 100644 --- a/test/expect/TestScript.test_sum-1.expect +++ b/test/expect/TestScript.test_sum-1.expect @@ -1,7 +1,7 @@ graph(%x : Dynamic) { %1 : int = prim::Constant[value=4]() %2 : int[] = prim::ListConstruct(%1) - %3 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() %4 : Dynamic = aten::sum(%x, %2, %3) return (%4); } diff --git a/test/expect/TestScript.test_sum-2.expect b/test/expect/TestScript.test_sum-2.expect index 4b2352de81..a952901438 100644 --- a/test/expect/TestScript.test_sum-2.expect +++ b/test/expect/TestScript.test_sum-2.expect @@ -1,7 +1,7 @@ graph(%x : Double(*, *, *, *, *)) { %1 : int = prim::Constant[value=4]() %2 : int[] = prim::ListConstruct(%1) - %3 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() %4 : Dynamic = aten::sum(%x, %2, %3) return (%4); } diff --git a/test/expect/TestScript.test_type_cast-float_to_int.expect b/test/expect/TestScript.test_type_cast-test_float_to_int.expect index 138b6e4913..ffc45ac565 100644 --- a/test/expect/TestScript.test_type_cast-float_to_int.expect +++ b/test/expect/TestScript.test_type_cast-test_float_to_int.expect @@ -1,6 +1,6 @@ graph() { %0 : int = prim::Constant[value=1]() - %1 : float = prim::Constant[value=2]() + %1 : float = prim::Constant[value=5]() %b : int = prim::FloatToInt(%1) %3 : int = aten::add(%b, %0) return (%3); diff --git a/test/expect/TestScript.test_type_cast-int_to_float.expect b/test/expect/TestScript.test_type_cast-test_int_to_float.expect index 719cc1e881..719cc1e881 100644 --- a/test/expect/TestScript.test_type_cast-int_to_float.expect +++ b/test/expect/TestScript.test_type_cast-test_int_to_float.expect diff --git a/test/test_jit.py b/test/test_jit.py index 24d8076d31..08ebca8af1 100644 --- 
a/test/test_jit.py +++ b/test/test_jit.py @@ -1365,7 +1365,6 @@ class TestJit(JitTestCase): # as all backwards functions of views are implemented # as a zero filled tensor with a gradient fill on the # viewed portion. - @unittest.expectedFailure def test_inplace_copy(self): x = torch.randn(4, 4, requires_grad=True) @@ -3763,7 +3762,7 @@ a") self.checkScript(func, inputs, optimize=True) def test_explicit_bool_cast(self): - with self.assertRaisesRegex(RuntimeError, "expected an integer"): + with self.assertRaisesRegex(RuntimeError, "expected a boolean"): @torch.jit.script def test_bool_cast(a): if a: @@ -3987,18 +3986,38 @@ a") throwsAnd(t) def test_type_cast(self): + @torch.jit.script def test_int_to_float(): b = float(2) return b + 1.0 + with self.assertRaisesRegex(RuntimeError, "Cannot cast type"): + @torch.jit.script + def test_int_to_bool(): + return bool(5) + + @torch.jit.script def test_float_to_int(): - b = int(2.0) + b = int(5.0) return b + 1 - graph1 = torch.jit.script(test_int_to_float).graph - self.assertExpectedGraph(graph1, subname="int_to_float") - graph2 = torch.jit.script(test_float_to_int).graph - self.assertExpectedGraph(graph2, subname="float_to_int") + with self.assertRaisesRegex(RuntimeError, "Cannot cast type"): + @torch.jit.script + def test_float_to_bool(): + return bool(5.0) + + with self.assertRaisesRegex(RuntimeError, "Cannot cast type"): + @torch.jit.script + def test_bool_to_float(): + return float(True) + + with self.assertRaisesRegex(RuntimeError, "Cannot cast type"): + @torch.jit.script + def test_bool_to_int(): + return int(True) + + self.assertExpectedGraph(test_int_to_float.graph, "test_int_to_float") + self.assertExpectedGraph(test_float_to_int.graph, "test_float_to_int") def test_multiple_assignment(self): def outer_func(x): @@ -7757,27 +7776,6 @@ EXCLUDE_TYPE_CHECK = { # known to be failing in script EXCLUDE_SCRIPT = { - # TODO: Fix var/std - # there are two schemas for var (and std): - # (1) var(Tensor, int, *, bool, bool, Tensor) - # (2) var(Tensor, *, bool) - # - # Right now, the following is happening: - # - Shorter schemas come before longer schemas - # - bool, int are treated as IntType rather than DynamicType like before - # So the schemas look like the following in operator: - # (2) var(DynamicType, IntType) - # (1) var(DynamicType, IntType, IntType, DynamicType) - # Now, when one calls torch.var(tensor, dim=1), the compiler mistakingly - # matches it with (2) instead of (1), which is a problem. 
- 'test_std_dim', - 'test_std_dim_1d', - 'test_std_dim_1d_neg0', - 'test_std_dim_neg0', - 'test_var_dim', - 'test_var_dim_1d', - 'test_var_dim_1d_neg0', - 'test_var_dim_neg0', 'test_norm_fro', 'test_norm_fro_default', 'test_norm_nuc', diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index f6fdc7505d..4fc923589d 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -64,7 +64,7 @@ FROM_IVALUE = { 'ScalarType': '{}.to<at::ScalarType>()', 'Tensor': '{}.toTensor()', 'TensorList': '{}.toTensorList()->elements()', - 'bool': 'bool({}.toInt())', + 'bool': '{}.toBool()', 'double': '{}.toDouble()', 'int64_t': '{}.toInt()', 'std::string': '{}.toString()->string()', diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h index 2610643918..af3a16393f 100644 --- a/torch/csrc/jit/attributes.h +++ b/torch/csrc/jit/attributes.h @@ -167,6 +167,7 @@ struct Attributes { const Kind##Attr::ValueType& method(Symbol name) const { \ return get<Kind##Attr>(name); \ } + CREATE_ACCESSOR(Float,f) CREATE_ACCESSOR(Floats,fs) CREATE_ACCESSOR(String,s) diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index a59c856eab..46840fc99a 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -280,7 +280,7 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))}; - } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) { + } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) { const auto& input_sizes = inputs.at(0).sizes(); if (input_sizes.size() == 0) return {grads.at(0).sum(), nullptr, nullptr}; diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index 1633ac0de4..c62cdbd1fa 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -27,6 +27,13 @@ Value* insertConstant( } else if(val.isDouble()) { n->f_(attr::value, val.toDouble()); n->output()->setType(FloatType::get()); + } else if (val.isBool()) { + n->i_(attr::value, val.toBool()); + n->output()->setType(BoolType::get()); + } else if (val.isBoolList()) { + auto bool_list = val.toBoolList()->elements(); + n->is_(attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end())); + n->output()->setType(ListType::ofBools()); } else if(val.isIntList()) { n->is_(attr::value, val.toIntList()->elements()); n->output()->setType(ListType::ofInts()); @@ -64,6 +71,12 @@ RegisterOperators reg({ stack.push_back(t); return 0; }; + } else if (type->isSubtypeOf(BoolType::get())) { + bool b = node->i(attr::value); + return [b](Stack& stack) { + push(stack, b); + return 0; + }; } else if ( type->isSubtypeOf(NumberType::get()) && node->kindOf(attr::value) == AttributeKind::i) { @@ -86,6 +99,12 @@ RegisterOperators reg({ push(stack, is); return 0; }; + } else if(type->isSubtypeOf(ListType::ofBools())) { + auto bs = node->is(attr::value); + return [bs](Stack& stack) { + push(stack, bs); + return 0; + }; } else if(type->isSubtypeOf(ListType::ofTensors())) { auto ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor { return autograd::make_variable(t); diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 973780b7d7..574dae168a 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -559,6 +559,8 @@ void 
ModuleEncoder::EncodeTypeInfo( type_proto->set_denotation("FloatType"); } else if (kind == TypeKind::IntType) { type_proto->set_denotation("IntType"); + } else if (kind == TypeKind::BoolType) { + type_proto->set_denotation("BoolType"); } else if (kind == TypeKind::NoneType) { type_proto->set_denotation("NoneType"); } else if (kind == TypeKind::GeneratorType) { diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 4574addb3a..45053a81db 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -257,6 +257,8 @@ TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) { return FloatType::get(); } else if (kind == "IntType") { return IntType::get(); + } else if (kind == "BoolType") { + return BoolType::get(); } else if (kind == "NoneType") { return NoneType::get(); } else if (kind == "GeneratorType") { diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index b4e6b7c139..32544aa682 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -46,9 +46,11 @@ namespace torch { namespace jit { _(prim, TupleUnpack) \ _(prim, ListConstruct) \ _(prim, ListUnpack) \ + _(prim, BoolToTensor) \ _(prim, NumToTensor) \ _(prim, TensorToNum) \ _(prim, ImplicitTensorToNum) \ + _(prim, TensorToBool) \ _(prim, IntToFloat) \ _(prim, FloatToInt) \ _(prim, StringToFloat) \ diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 14e7fab54d..065054a0fc 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -61,14 +61,14 @@ Value* createTripCountConjunctiveCondition( // Emit initial comparison -- initial_trip_count < max_trip_count Value* initial_comparison_value = g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1)) - ->output()->setType(IntType::get()); + ->output()->setType(BoolType::get()); // Replace initial condition with logical `and` of trip count and // initial condition Value* new_cond = g->insertNode( g->create(aten::__and__, {initial_comparison_value, cond}, 1)) - ->output()->setType(IntType::get()); + ->output()->setType(BoolType::get()); return new_cond; } @@ -388,30 +388,29 @@ struct CodeImpl { CodeImpl(std::shared_ptr<Graph>& graph_) : preprocess(*graph_) { graph = preprocess.graph; - // std::cout << "into code graph:\n" << *graph << "\n"; insertNodesFromBlock(graph->block()); } - // jump when input is 0 - void createJumpZ(int from_inst, int to_inst) { + // jump when input is false + void createJumpFalse(int from_inst, int to_inst) { auto & inst = instructions[from_inst]; JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); inst.callback = [offset](Stack & stack) { - auto t = pop(stack).toInt(); - return (t == 0) ? offset : 0; + auto t = pop(stack).toBool(); + return t ? 0 : offset; }; inst.debug_name = prim::JumpZ; } - // jump when input is not 0 - void createJumpNZ(int from_inst, int to_inst) { + // jump when input is true + void createJumpTrue(int from_inst, int to_inst) { auto & inst = instructions[from_inst]; JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); inst.callback = [offset](Stack & stack) { - auto t = pop(stack).toInt(); - return (t != 0) ? offset : 0; + auto t = pop(stack).toBool(); + return t ? 
offset : 0; }; inst.debug_name = prim::JumpNZ; } @@ -460,7 +459,7 @@ struct CodeImpl { insertNodesFromBlock(then_block); insertAssign(source_location, then_block->outputs(), moveFlags(then_block), node->outputs()); createJump(jump, instructions.size()); - createJumpNZ(cond_branch, then_block_start); + createJumpTrue(cond_branch, then_block_start); } break; case prim::Loop: { // o0 = while c i0 @@ -495,8 +494,8 @@ struct CodeImpl { // after branch: stack: ... aliasRegistersTo(node->outputs(), body_block->inputs()); - createJumpZ(cond_branch, instructions.size()); - createJumpNZ(cond_branch_end, entry); + createJumpFalse(cond_branch, instructions.size()); + createJumpTrue(cond_branch_end, entry); } break; default: { insertInstruction(node); diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 90451494ba..4e4bb0b68d 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -562,18 +562,18 @@ namespace { const OperatorSet& nondeterminstic_aten_ops() { static OperatorSet nondeterministic_ops = { - "aten::dropout(Tensor input, float p, int train) -> Tensor", + "aten::dropout(Tensor input, float p, bool train) -> Tensor", "aten::_fused_dropout(Tensor self, float p, Generator generator) -> (Tensor, Tensor)", "aten::_standard_gamma(Tensor self, Generator generator) -> Tensor", "aten::bernoulli(Tensor self, *, Generator generator) -> Tensor", "aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor", - "aten::multinomial(Tensor self, int num_samples, int replacement, *, Generator generator) -> Tensor", + "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator generator) -> Tensor", "aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor", "aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor", "aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor", "aten::poisson(Tensor self, Generator generator) -> Tensor", - "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor", - "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor", + "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", + "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", "aten::rand(int[] size, *, int dtype, int layout, int[] device) -> Tensor", "aten::rand_like(Tensor self) -> Tensor", "aten::rand_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor", @@ -598,7 +598,7 @@ bool Node::isNondeterministic() const { return false; } // Dropout with train = False is deterministic - if (matches("aten::dropout(Tensor input, float p, int train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) { + if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) { return false; } return true; diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 0bb5c899c7..6fdb49e7d9 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1050,6 +1050,15 @@ public: result->output()->setType(CompleteTensorType::fromNumberType(typ)); return result; } + Node* createBoolToTensor(Value* value) { + auto typ = value->type(); + Node * result = create(prim::BoolToTensor, {value}); + if (!typ->isSubtypeOf(BoolType::get())) { + AT_ERROR("Cannot create bool type from ", typ->str()); + } + 
result->output()->setType(CompleteTensorType::fromBoolType()); + return result; + } Node* createTensorToNum(const TypePtr& type, Value* value) { auto* result = create(prim::TensorToNum, {value}); result->output()->setType(type); @@ -1060,6 +1069,11 @@ public: result->output()->setType(type); return result; } + Node* createTensorToBool(Value* value) { + auto* result = create(prim::TensorToBool, {value}); + result->output()->setType(BoolType::get()); + return result; + } Node* createIntToFloat(Value* value) { JIT_ASSERT(*value->type() == *IntType::get()); auto* result = create(prim::IntToFloat, {value}); diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index d701b536c4..818a58c965 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -57,7 +57,7 @@ struct SchemaParser { {"str", StringType::get() }, {"float", FloatType::get() }, {"int", IntType::get() }, - {"bool", IntType::get() }, // TODO: add separate bool type + {"bool", BoolType::get() }, {"World", WorldType::get() }, }; auto tok = L.expect(TK_IDENT); @@ -162,6 +162,10 @@ struct SchemaParser { return fmap(vs, [](IValue v) { return v.toInt(); }); + case TypeKind::BoolType: + return fmap(vs, [](IValue v) { + return v.toBool(); + }); default: throw ErrorReport(range) << "lists are only supported for float or int types."; } @@ -191,6 +195,7 @@ struct SchemaParser { } break; case TypeKind::NumberType: case TypeKind::IntType: + case TypeKind::BoolType: case TypeKind::FloatType: arg.default_value = parseSingleConstant(arg.type->kind()); break; diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 7d2ba5bfeb..115d4b6a68 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -13,13 +13,16 @@ static void EraseNumberTypesOnBlock(Block* block) { case prim::Constant: { // remove primitive constants, replacing with tensor equivalent // ONNX does not support non-tensor constants - if(it->output()->type()->isSubtypeOf(NumberType::get())) { + if (it->output()->type()->isSubtypeOf(NumberType::get()) || + it->output()->type()->isSubtypeOf(BoolType::get())) { auto s = *constant_as<at::Scalar>(it->output()); WithInsertPoint guard(*it); Value* r = block->owningGraph()->insertConstant(scalar_to_tensor(s)); it->output()->replaceAllUsesWith(r); } } break; + case prim::TensorToBool: + case prim::BoolToTensor: case prim::TensorToNum: case prim::ImplicitTensorToNum: case prim::NumToTensor: { @@ -30,6 +33,8 @@ static void EraseNumberTypesOnBlock(Block* block) { for(auto o : it->outputs()) { if (o->type()->isSubtypeOf(NumberType::get())) { o->setType(CompleteTensorType::fromNumberType(o->type())); + } else if (o->type()->isSubtypeOf(BoolType::get())) { + o->setType(CompleteTensorType::fromBoolType()); } } } break; diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 176166218f..7f2312696e 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -24,7 +24,7 @@ void PeepholeOptimize(Block * block) { // XXX: remember that if you want to simplify an expression by combining multiple nodes // into a different one, then you need to check that they all belong to the given block - if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor", + if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", /*with_const=*/attr::size)) { // x.expand(x.size()) == x if (auto input_type = 
node->namedInput(attr::self)->type()->cast<CompleteTensorType>()) { diff --git a/torch/csrc/jit/passes/remove_expands.cpp b/torch/csrc/jit/passes/remove_expands.cpp index 93d53e5481..1c67e68507 100644 --- a/torch/csrc/jit/passes/remove_expands.cpp +++ b/torch/csrc/jit/passes/remove_expands.cpp @@ -8,7 +8,8 @@ static void RemoveExpands(Block* block) { ++it) { for (auto sub : it->blocks()) RemoveExpands(sub); - if (it->kind() == aten::expand && it->get<int64_t>(attr::implicit) != static_cast<int64_t>(0)) { + + if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) { it->output()->replaceAllUsesWith(it->namedInput(attr::self)); it.destroyCurrent(); } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index eedc7fd0a8..af5e8cd96f 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -120,7 +120,7 @@ void broadcastBinary(Node *node, std::vector<CompleteTensorTypePtr>& types, size Node *expand = graph->create(aten::expand, {node->inputs().at(input_idx), graph->insertConstant(expected_size), - graph->insertConstant(0)}) + graph->insertConstant(false)}) ->insertBefore(node); PropagateShapeOnNode(expand); node->replaceInput(input_idx, expand->output()); @@ -441,12 +441,12 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::clamp(Tensor self, Scalar min, Scalar max) -> Tensor", "aten::clamp_max(Tensor self, Scalar max) -> Tensor", "aten::clamp_min(Tensor self, Scalar min) -> Tensor", - "aten::alpha_dropout(Tensor input, float p, int train) -> Tensor", + "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor", "aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor", "aten::cos(Tensor self) -> Tensor", "aten::cosh(Tensor self) -> Tensor", "aten::digamma(Tensor self) -> Tensor", - "aten::dropout(Tensor input, float p, int train) -> Tensor", + "aten::dropout(Tensor input, float p, bool train) -> Tensor", "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor", "aten::erf(Tensor self) -> Tensor", "aten::erfc(Tensor self) -> Tensor", @@ -462,8 +462,8 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::floor(Tensor self) -> Tensor", "aten::frac(Tensor self) -> Tensor", "aten::flip(Tensor self, int[] dims) -> Tensor", - "aten::feature_alpha_dropout(Tensor input, float p, int train) -> Tensor", - "aten::feature_dropout(Tensor input, float p, int train) -> Tensor", + "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor", + "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor", "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor", "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor", "aten::glu(Tensor self, int dim) -> Tensor", @@ -479,7 +479,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::reciprocal(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", "aten::round(Tensor self) -> Tensor", - "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor", + "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", "aten::rsqrt(Tensor self) -> Tensor", "aten::selu(Tensor self) -> Tensor", "aten::sigmoid(Tensor self) -> Tensor", @@ -645,7 +645,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { // Knowing the type and device of weights or biases is usually enough to // infer the output type. 
     static const register_formula_for nn_ops_first_input_preserving {{
-      "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, int training, float momentum, float eps, int cudnn_enabled) -> Tensor",
+      "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
       "aten::conv1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
       "aten::conv2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
       "aten::conv3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
@@ -653,16 +653,16 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
       "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
       "aten::conv_transpose2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
       "aten::conv_transpose3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
-      "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int transposed, int[] output_padding, int groups) -> Tensor",
+      "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
       "aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
       "aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
       "aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor",
-      "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
-      "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
-      "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
-      "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
-      "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
-      "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
+      "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+      "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+      "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+      "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
+      "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
+      "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
       "aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor",
       "aten::max_unpool3d(Tensor self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor",
       "aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor",
@@ -670,12 +670,12 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
       "aten::replication_pad1d(Tensor self, int[] padding) -> Tensor",
       "aten::replication_pad2d(Tensor self, int[] padding) -> Tensor",
       "aten::replication_pad3d(Tensor self, int[] padding) -> Tensor",
-      "aten::upsample_bilinear2d(Tensor self, int[] output_size, int align_corners) -> Tensor",
-      "aten::upsample_linear1d(Tensor self, int[] output_size, int align_corners) -> Tensor",
+      "aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
+      "aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
       "aten::upsample_nearest1d(Tensor self, int[] output_size) -> Tensor",
       "aten::upsample_nearest2d(Tensor self, int[] output_size) -> Tensor",
       "aten::upsample_nearest3d(Tensor self, int[] output_size) -> Tensor",
-      "aten::upsample_trilinear3d(Tensor self, int[] output_size, int align_corners) -> Tensor",
+      "aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
       "aten::prelu(Tensor self, Tensor weight) -> Tensor",
     }, [](Node * node) -> type_vec_t {
       if (auto type = node->input(0)->type()->cast<TensorType>()) {
@@ -702,10 +702,10 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
       "aten::mean(Tensor self) -> Tensor",
       "aten::median(Tensor self) -> Tensor",
       "aten::norm(Tensor self, Scalar p) -> Tensor",
-      "aten::std(Tensor self, int unbiased) -> Tensor",
+      "aten::std(Tensor self, bool unbiased) -> Tensor",
       "aten::sum(Tensor self) -> Tensor",
       "aten::trace(Tensor self) -> Tensor",
-      "aten::var(Tensor self, int unbiased) -> Tensor",
+      "aten::var(Tensor self, bool unbiased) -> Tensor",
       "aten::all(Tensor self) -> Tensor",
       "aten::any(Tensor self) -> Tensor",
     }, [](Node * node) -> type_vec_t {
@@ -735,7 +735,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
     static const auto multidim_reduce_with_postprocess = [](Node * node, size_t num_reduced_dim, bool upcast_integer) -> type_vec_t {
-      auto maybe_keepdim = node->get<int64_t>(attr::keepdim);
+      auto maybe_keepdim = node->get<bool>(attr::keepdim);
       if (!maybe_keepdim) return {};
       if (auto type = node->input(0)->type()->cast<TensorType>()) {
         if (upcast_integer && !at::isFloatingType(type->scalarType())) {
@@ -760,24 +760,24 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
     //   - First input should be the only tensor input
     //   - Has a bool keepdim argument
     static const register_formula_for dim_reduce_ops {{
-      "aten::argmax(Tensor self, int dim, int keepdim) -> Tensor",
-      "aten::argmin(Tensor self, int dim, int keepdim) -> Tensor",
-      "aten::max_values(Tensor self, int dim, int keepdim) -> Tensor",
-      "aten::min_values(Tensor self, int dim, int keepdim) -> Tensor",
-      "aten::mean(Tensor self, int dim, int keepdim) -> Tensor",
-      "aten::norm(Tensor self, Scalar p, int dim, int keepdim) -> Tensor",
-      "aten::std(Tensor self, int dim, int unbiased, int keepdim) -> Tensor",
-      "aten::var(Tensor self, int dim, int unbiased, int keepdim) -> Tensor",
-      "aten::logsumexp(Tensor self, int dim, int keepdim) -> Tensor",
-      "aten::all(Tensor self, int dim, int keepdim) -> Tensor",
-      "aten::any(Tensor self, int dim, int keepdim) -> Tensor",
+      "aten::argmax(Tensor self, int dim, bool keepdim) -> Tensor",
+      "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
+      "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
+      "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
+      "aten::mean(Tensor self, int dim, bool keepdim) -> Tensor",
+      "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor",
+      "aten::std(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
+      "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
+      "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
+      "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
+      "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
       // Ops returning indices as second output
-      "aten::kthvalue(Tensor self, int k, int dim, int keepdim) -> (Tensor, Tensor)",
-      "aten::max(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
-      "aten::min(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
-      "aten::median(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
-      "aten::mode(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
+      "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
+      "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+      "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+      "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+      "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
     }, [](Node * node) -> type_vec_t {
       // NB: Note that while this function is generally meant to be used with ops that
       // have a single output, we will fix up its return right below.
@@ -798,7 +798,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
     //   - First input should be the only tensor input
     //   - has a bool keepdim argument
     static const register_formula_for dim_reduce_ops_with_integer_upcast {{
-      "aten::prod(Tensor self, int dim, int keepdim) -> Tensor",
+      "aten::prod(Tensor self, int dim, bool keepdim) -> Tensor",
     }, [](Node * node) -> type_vec_t {
       return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/1, /*integer_upcast=*/true);
     }};
@@ -812,7 +812,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
     // Additionally:
     //   - has bool keepdim and int[] dim arguments
     static const register_formula_for multidim_reduce_ops_with_integer_upcast {{
-      "aten::sum(Tensor self, int[] dim, int keepdim) -> Tensor",
+      "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
     }, [](Node * node) -> type_vec_t {
       if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
         // TODO: can dim contain duplicates?
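The schemas above now declare flags such as keepdim, ceil_mode and unbiased as bool, and the shape-propagation formulas read them back with node->get<bool>(...). A minimal sketch of a formula written against the new schemas, assuming the register_formula_for helper, the type_vec_t alias, and TensorType::withDim/dim from this file; the formula name and the exact shape arithmetic are illustrative only:

    // Sketch: a shape formula for a single-dim reduction with a bool keepdim flag.
    static const register_formula_for example_bool_keepdim_op {{
        "aten::mean(Tensor self, int dim, bool keepdim) -> Tensor",
      }, [](Node* node) -> type_vec_t {
        // The constant is now a bool IValue, so it is fetched as bool, not int64_t.
        auto maybe_keepdim = node->get<bool>(attr::keepdim);
        if (!maybe_keepdim)
          return {};  // keepdim is not a constant; skip propagation
        if (auto type = node->input(0)->type()->cast<TensorType>()) {
          if (*maybe_keepdim)
            return {type};                          // rank preserved
          return {type->withDim(type->dim() - 1)};  // reduced dimension dropped
        }
        return {};
      }};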
@@ -900,14 +900,14 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
       }
     };
     static const register_formula_for cast_ops {{
-      "aten::_cast_Byte(Tensor self, int non_blocking) -> Tensor",
-      "aten::_cast_Char(Tensor self, int non_blocking) -> Tensor",
-      "aten::_cast_Double(Tensor self, int non_blocking) -> Tensor",
-      "aten::_cast_Float(Tensor self, int non_blocking) -> Tensor",
-      "aten::_cast_Half(Tensor self, int non_blocking) -> Tensor",
-      "aten::_cast_Int(Tensor self, int non_blocking) -> Tensor",
-      "aten::_cast_Long(Tensor self, int non_blocking) -> Tensor",
-      "aten::_cast_Short(Tensor self, int non_blocking) -> Tensor",
+      "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
+      "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
+      "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
+      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+      "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
+      "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
+      "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
+      "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
     }, [](Node * node) -> type_vec_t {
       if (auto type = node->namedInput(attr::self)->type()->cast<TensorType>()) {
         return {type->toScalarType(get_cast_scalar_type(node))};
@@ -1003,7 +1003,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
         }
         return true;
       }
-    } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, int scale_grad_by_freq, int sparse) -> Tensor")) {
+    } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
       auto weight_type = input_type(0);
       auto indices_type = input_type(1);
       if (weight_type && indices_type) {
@@ -1044,7 +1044,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
                node->matches("aten::reshape_as(Tensor self, Tensor other) -> Tensor")) {
       return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
     } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
-               node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor") ||
+               node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
                node->matches("aten::as_strided(Tensor self, int[] size, int[] stride) -> Tensor") ||
                node->matches("aten::as_strided(Tensor self, int[] size, int[] stride, int storage_offset) -> Tensor")) {
       return reshape_prop(node, attr::size, tensor_types);
@@ -1189,12 +1189,12 @@ bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands,
   } else if (node->matches("aten::sum(Tensor self) -> Tensor")) {
     node->output()->setType(tensor_types.at(0)->withSizes({}));
     return true;
-  } else if (node->matches("aten::sum(Tensor self, int[] dim, int keepdim) -> Tensor",
+  } else if (node->matches("aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
              /*with_const=*/{attr::dim, attr::keepdim})) {
     auto & tp = tensor_types.at(0);
     auto sizes = tp->sizes();
     auto dims = node->get<std::vector<int64_t>>(attr::dim).value();
-    bool keepdim = node->get<int64_t>(attr::keepdim).value();
+    bool keepdim = node->get<bool>(attr::keepdim).value();
     std::reverse(dims.begin(), dims.end());
     for (int64_t dim : dims) {
       SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
@@ -1262,7 +1262,7 @@ bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands,
       node->output()->setType(tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes()));
     }
     return true;
-  } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor",
+  } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
              /*with_const=*/attr::size)) {
     auto tp = tensor_types.at(0);
     std::vector<int64_t> sizes, strides;
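Node::matches compares against the literal schema string, so schemas and their call sites have to flip to bool together, as the aten::sum and aten::expand cases above do. The pattern, condensed (every call here appears in the hunks above; the surrounding shape arithmetic is omitted):

    // Sketch: matching the bool-typed overload and reading its constant arguments.
    if (node->matches(
            "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
            /*with_const=*/{attr::dim, attr::keepdim})) {
      auto dims = node->get<std::vector<int64_t>>(attr::dim).value();
      bool keepdim = node->get<bool>(attr::keepdim).value();  // was get<int64_t>
      // ... compute the output sizes from dims/keepdim as in the code above ...
    }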
diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp
index 0d56ca2255..f1edf80d41 100644
--- a/torch/csrc/jit/passes/to_batch.cpp
+++ b/torch/csrc/jit/passes/to_batch.cpp
@@ -41,6 +41,10 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
       auto to_tensor_node = res_graph->createNumToTensor(input);
       res_graph->insertNode(to_tensor_node);
       new_inputs[i] = to_tensor_node->output();
+    } else if(input->type() == BoolType::get()) {
+      auto to_tensor_node = res_graph->createBoolToTensor(input);
+      res_graph->insertNode(to_tensor_node);
+      new_inputs[i] = to_tensor_node->output();
     }
   }
@@ -58,8 +62,11 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
     else if(n->outputs()[0]->type() == FloatType::get()){
       to_scalar_node = res_graph->createTensorToNum(FloatType::get(), outputs[0]);
     }
+    else if(n->outputs()[0]->type() == BoolType::get()){
+      to_scalar_node = res_graph->createTensorToBool(outputs[0]);
+    }
     else{
-      throw std::runtime_error("NYI: scalar type other than int, float is not supported yet");
+      throw std::runtime_error("NYI: scalar types other than int, float, and bool are not supported yet");
     }
     res_graph->insertNode(to_scalar_node);
     rn_env[n->outputs()[0]] = to_scalar_node->output();
@@ -348,9 +355,10 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
   if(cond_is_tensor){
     auto cond = batch_map.at(n->inputs()[1]);
     auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
-    auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]);
-    res_graph->insertNode(to_int_node);
-    rn_env[n->inputs()[1]] = to_int_node->output();
+    auto to_bool_node =
+        res_graph->createTensorToBool(cond_any[0]);
+    res_graph->insertNode(to_bool_node);
+    rn_env[n->inputs()[1]] = to_bool_node->output();
   }
   for(size_t i = 2; i < n->inputs().size(); i++){
     auto input = n->inputs()[i];
@@ -432,9 +440,10 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
   if(cond_is_tensor){
     auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
     auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
-    auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]);
-    res_graph->insertNode(to_int_node);
-    loop_block->insertOutput(0, to_int_node->output());
+    auto to_bool_node =
+        res_graph->createTensorToBool(cond_any[0]);
+    res_graph->insertNode(to_bool_node);
+    loop_block->insertOutput(0, to_bool_node->output());
     for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
       loop_block->insertOutput(i + 1, cond[i]);
     }
@@ -491,6 +500,7 @@ void ToBatch::toBatch(Block* block, Block* res_block) {
         case prim::NumToTensor:
           visitNumToTensor(n, block, res_block);
           break;
+        case prim::TensorToBool:
         case prim::TensorToNum:
           visitTensorToNum(n, block, res_block);
           break;
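The batching pass now round-trips bool scalars through tensors the same way it already does for int and float: inputs are lifted with prim::BoolToTensor and batched results (for example a loop condition reduced with the "any" batch operator) are lowered back with prim::TensorToBool. A condensed sketch of the two directions, assuming the usual Graph/Node/Value IR types and the createBoolToTensor/createTensorToBool factories used above:

    // Sketch: lift a scalar bool Value into a tensor for batching, and lower it back.
    Value* liftBool(Graph* res_graph, Value* input) {
      Node* to_tensor = res_graph->createBoolToTensor(input);  // bool -> tensor
      res_graph->insertNode(to_tensor);
      return to_tensor->output();
    }

    Value* lowerBool(Graph* res_graph, Value* batched) {
      Node* to_bool = res_graph->createTensorToBool(batched);  // tensor -> bool
      res_graph->insertNode(to_bool);
      return to_bool->output();
    }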
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 192cabca3f..bd81a8cc05 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -107,6 +107,8 @@ inline IValue toIValue(py::handle obj, const TypePtr& type) {
       return py::cast<int64_t>(obj);
     case TypeKind::NoneType:
       return {};
+    case TypeKind::BoolType:
+      return py::cast<bool>(obj);
     case TypeKind::TupleType: {
       if(!PyTuple_Check(obj.ptr())) throw py::cast_error(); // note: the py::cast does not throw cast_error
@@ -196,10 +198,14 @@ inline py::object toPyObject(IValue&& ivalue) {
     return py::cast(ivalue.toDouble());
   } else if (ivalue.isInt()) {
     return py::cast(ivalue.toInt());
+  }else if (ivalue.isBool()) {
+    return py::cast(ivalue.toBool());
   } else if (ivalue.isIntList()) {
     return py::cast(ivalue.toIntListRef());
   } else if (ivalue.isDoubleList()) {
     return py::cast(ivalue.toDoubleListRef());
+  } else if (ivalue.isBoolList()) {
+    return py::cast(ivalue.toBoolListRef());
   } else if (ivalue.isTensorList()) {
     return py::cast(ivalue.toTensorListRef());
   } else if (ivalue.isGenericList()) {
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index ad03ac556c..979b07d369 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -455,6 +455,8 @@ void initPythonIRBindings(PyObject * module_) {
         return "StringType";
       case TypeKind::GeneratorType:
         return "GeneratorType";
+      case TypeKind::BoolType:
+        return "BoolType";
       case TypeKind::VarType:
         return "VarType";
       case TypeKind::WorldType:
@@ -491,6 +493,8 @@ void initPythonIRBindings(PyObject * module_) {
     .def_static("get", &FloatType::get);
   py::class_<DynamicType, Type, std::shared_ptr<DynamicType>>(m, "DynamicType")
     .def_static("get", &DynamicType::get);
+  py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType")
+    .def_static("get", &BoolType::get);
   py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType")
     .def(py::init([](std::vector<TypePtr> a){ return TupleType::create(a); }))
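With the branches added above, a Python True/False crosses the binding layer as a real bool in both directions instead of degrading to an int. A small round-trip sketch, assuming pybind11 and the toIValue/toPyObject helpers from pybind_utils.h:

    // Sketch: Python bool <-> IValue through the patched helpers.
    py::object flag = py::bool_(true);
    IValue iv = toIValue(flag, BoolType::get());   // hits the new BoolType case
    AT_ASSERT(iv.isBool() && iv.toBool());
    py::object back = toPyObject(std::move(iv));   // isBool() branch -> py::cast(bool)
    AT_ASSERT(py::isinstance<py::bool_>(back));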
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index cdea4ab894..c7bf050dc2 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -70,6 +70,17 @@ RegisterOperators reg({
         };
       }),
     Operator(
+        prim::TensorToBool,
+        [](Node* node) -> Operation {
+          return [](Stack& stack) {
+            at::Tensor a;
+            pop(stack, a);
+            at::DeviceGuard guard(a);
+            push(stack, a.item<int64_t>() != 0);
+            return 0;
+          };
+        }),
+    Operator(
         prim::TensorToNum,
         [](Node* node) -> Operation {
           if(node->output()->type() == IntType::get()) {
@@ -124,6 +135,18 @@ RegisterOperators reg({
         };
       }),
     Operator(
+        prim::BoolToTensor,
+        [](Node* node) -> Operation {
+          return [](Stack& stack) {
+            bool b;
+            pop(stack, b);
+            push(
+                stack,
+                autograd::make_variable(at::scalar_to_tensor(b)));
+            return 0;
+          };
+        }),
+    Operator(
         prim::IntToFloat,
         [](Node* node) -> Operation {
           return [](Stack& stack) {
@@ -446,25 +469,25 @@ RegisterOperators reg({
 });
 // define implementations for primitive number ops
-#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, float_result) \
-  Operator( \
-      #aten_op "(int a, int b) -> int", \
-      [](Node* node) { \
-        return [=](Stack& stack) { \
-          int64_t a, b; \
-          pop(stack, a, b); \
-          push(stack, int_op); \
-          return 0; \
-        }; \
-      }), \
-  Operator( \
-      #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \
-        return [=](Stack& stack) { \
-          double a, b; \
-          pop(stack, a, b); \
-          push(stack, float_op); \
-          return 0; \
-        }; \
+#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
+  Operator( \
+      #aten_op "(int a, int b) -> " #int_result, \
+      [](Node* node) { \
+        return [=](Stack& stack) { \
+          int64_t a, b; \
+          pop(stack, a, b); \
+          push(stack, int_op); \
+          return 0; \
+        }; \
+      }), \
+  Operator( \
+      #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \
+        return [=](Stack& stack) { \
+          double a, b; \
+          pop(stack, a, b); \
+          push(stack, float_op); \
+          return 0; \
+        }; \
       }),
 #define DEFINE_INT_OP(aten_op, op) \
@@ -477,8 +500,19 @@ RegisterOperators reg({
       }; \
     }),
-#define DEFINE_BINARY_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, op, float)
-#define DEFINE_COMPARISON_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, op, int)
+#define DEFINE_BINARY_OP(aten_op, op) \
+  DEFINE_GENERIC_OP(aten_op, op, op, int, float)
+#define DEFINE_COMPARISON_OP(aten_op, op) \
+  DEFINE_GENERIC_OP(aten_op, op, op, bool, bool)
+#define DEFINE_BOOL_OP(aten_op, op) \
+  Operator(#aten_op "(bool a, bool b) -> bool", [](Node* node) { \
+    return [=](Stack& stack) { \
+      bool a, b; \
+      pop(stack, a, b); \
+      push(stack, op); \
+      return 0; \
+    }; \
+  }),
 // Convert an python index (which may be negative) into an index usable for a
 // C++ container
@@ -663,7 +697,7 @@ RegisterOperators reg2({
     // Pass in two ops for handling int and float separately as % in C++ only works for int
     // The modulus calculation is different between C++ and Python (on negative), we preserve
     // the python behavior as it's more common and match python syntax, hence the conversion.
-    DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), float)
+    DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), int, float)
     // TODO: Support python floordiv (//)
     // Right now aten::floordiv is only used by loop unrolling
@@ -696,8 +730,8 @@ RegisterOperators reg2({
     DEFINE_COMPARISON_OP(aten::le, a <= b)
     DEFINE_COMPARISON_OP(aten::ge, a >= b)
-    DEFINE_INT_OP(aten::__and__, a&& b)
-    DEFINE_INT_OP(aten::__or__, a || b)
+    DEFINE_BOOL_OP(aten::__and__, a && b)
+    DEFINE_BOOL_OP(aten::__or__, a || b)
     Operator("aten::_construct_empty_int_list() -> int[]",
       [](Node* node) -> Operation {
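DEFINE_GENERIC_OP now takes the int result type as well, comparisons are declared as returning bool, and the logical operators move from DEFINE_INT_OP to the new DEFINE_BOOL_OP. For reference, DEFINE_BOOL_OP(aten::__and__, a && b) expands to roughly the following (whitespace normalized):

    // Approximate expansion of DEFINE_BOOL_OP(aten::__and__, a && b).
    Operator("aten::__and__(bool a, bool b) -> bool", [](Node* node) {
      return [=](Stack& stack) {
        bool a, b;
        pop(stack, a, b);
        push(stack, a && b);
        return 0;
      };
    }),

Comparison ops such as aten::le now advertise "(int a, int b) -> bool" and "(float a, float b) -> bool" through DEFINE_COMPARISON_OP, which is what lets their results feed directly into the bool-typed conditions handled in compiler.cpp below.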
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 28aa735fc3..474722f536 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -74,6 +74,8 @@ static Value* typeCast(const SourceRange& loc, Value* value, TypePtr dst) {
     n = graph.createNumToTensor(value);
   } else if (dst->isSubtypeOf(NumberType::get()) && orig->isSubtypeOf(DynamicType::get())) {
     n = graph.createTensorToNum(dst, value);
+  } else if (dst->isSubtypeOf(BoolType::get()) && orig->isSubtypeOf(DynamicType::get())) {
+    n = graph.createTensorToBool(value);
   } else if(dst->isSubtypeOf(IntType::get()) && orig->isSubtypeOf(FloatType::get())) {
     n = graph.createFloatToInt(value);
   } else if(dst->isSubtypeOf(FloatType::get()) && orig->isSubtypeOf(IntType::get())) {
@@ -324,7 +326,7 @@ struct Environment {
       {"print", std::make_shared<PrintValue>()},
       {"float", std::make_shared<CastValue>(FloatType::get())},
       {"int", std::make_shared<CastValue>(IntType::get())},
-      {"bool", std::make_shared<CastValue>(IntType::get())},
+      {"bool", std::make_shared<CastValue>(BoolType::get())},
       // todo(zach): remove when we can correctly export torch.full via ONNX
       // or we have implicit conversion that can convert numbers to tensors
       {"_to_tensor", std::make_shared<CastValue>(DynamicType::get()) },
@@ -1048,9 +1050,9 @@ private:
   Value* emitCond(Expr cond) {
     Value* v = emitExpr(cond);
-    if (!v->type()->isSubtypeOf(IntType::get())) {
+    if (!v->type()->isSubtypeOf(BoolType::get())) {
       ErrorReport error(cond);
-      error << "expected an integer expression for condition but found "
+      error << "expected a boolean expression for condition but found "
            << v->type()->str();
       if (v->type()->isSubtypeOf(DynamicType::get())) {
         error << ", to use a tensor in a boolean"
@@ -1928,6 +1930,7 @@ const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
     {"Tensor", DynamicType::get()},
     {"int", IntType::get()},
     {"float", FloatType::get()},
+    {"bool", BoolType::get()},
   };
   return map;
 }
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 4c7df820b1..fec69ac317 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -278,12 +278,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
   //    f = f + 1
   auto& g = *m.graph();
   if (is_constant) {
-    if (py::isinstance<py::int_>(obj)) {
+    if (py::isinstance<py::bool_>(obj)) {
+      return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
+    } else if (py::isinstance<py::int_>(obj)) {
       return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
     } else if (py::isinstance<py::float_>(obj)) {
       return toSimple(g.insertConstant(py::cast<float>(obj), loc));
-    } else if (py::isinstance<py::bool_>(obj)) {
-      return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
     } else if (THPDevice_Check(obj.ptr())) {
       auto device = (THPDevice*)obj.ptr();
       std::vector<int64_t> v = {static_cast<int64_t>(device->device.type()),
diff --git a/torch/csrc/jit/type.cpp b/torch/csrc/jit/type.cpp
index 855adad429..ce3cd8c1f2 100644
--- a/torch/csrc/jit/type.cpp
+++ b/torch/csrc/jit/type.cpp
@@ -46,6 +46,8 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
     out << "float";
   } else if(t.kind() == TypeKind::IntType) {
     out << "int";
+  } else if(t.kind() == TypeKind::BoolType) {
+    out << "bool";
   } else if(t.kind() == TypeKind::ListType) {
     auto prim = t.cast<ListType>()->getElementType();
     out << *prim << "[]";
@@ -85,6 +87,10 @@ FloatTypePtr FloatType::get() {
   static auto value = FloatType::create();
   return value;
 }
+BoolTypePtr BoolType::get() {
+  static auto value = BoolType::create();
+  return value;
+}
 NoneTypePtr NoneType::get() {
   static auto value = NoneType::create();
   return value;
@@ -113,6 +119,10 @@ ListTypePtr ListType::ofFloats() {
   static auto value = ListType::create(FloatType::get());
   return value;
 }
+ListTypePtr ListType::ofBools() {
+  static auto value = ListType::create(BoolType::get());
+  return value;
+}
 TypePtr inferTypeFrom(const IValue& value) {
   if (value.isTensor()) {
@@ -121,12 +131,16 @@ TypePtr inferTypeFrom(const IValue& value) {
     return FloatType::get();
   } else if (value.isInt()) {
     return IntType::get();
+  } else if (value.isBool()) {
+    return BoolType::get();
   } else if (value.isString()) {
     return StringType::get();
   } else if (value.isIntList()) {
     return ListType::ofInts();
   } else if (value.isTensorList()) {
     return ListType::ofTensors();
+  } else if (value.isBoolList()) {
+    return ListType::ofBools();
   } else if (value.isDoubleList()) {
     return ListType::ofFloats();
   } else if (value.isTuple()) {
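compiler.cpp now maps the bool() builtin to BoolType and requires if/while conditions to type-check as bool, while type.cpp adds the BoolType and ListType::ofBools() singletons and teaches inferTypeFrom about bool IValues. A minimal sketch of the resulting type-level invariants (AT_ASSERT used purely for illustration):

    // Sketch: BoolType is a cached singleton and bool IValues infer it.
    TypePtr t = inferTypeFrom(IValue(true));
    AT_ASSERT(t->kind() == TypeKind::BoolType);
    AT_ASSERT(t == BoolType::get());                 // same singleton instance

    // Bool lists likewise resolve to the cached bool[] list type.
    TypePtr lt = inferTypeFrom(IValue(std::vector<bool>{true, false}));
    AT_ASSERT(lt == ListType::ofBools());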
diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h
index 49748de239..69133f6be9 100644
--- a/torch/csrc/jit/type.h
+++ b/torch/csrc/jit/type.h
@@ -27,6 +27,7 @@ _(IntType) \
 _(NoneType) \
 _(StringType) \
 _(GeneratorType) \
+_(BoolType) \
 _(VarType) \
 _(WorldType) \
@@ -343,6 +344,7 @@ struct TORCH_API CompleteTensorType : public TensorType {
     return prod;
   }
   static TypePtr fromNumberType(TypePtr typ);
+  static TypePtr fromBoolType();
 private:
   CompleteTensorType(const at::Tensor& tensor)
@@ -438,6 +440,7 @@ struct TORCH_API ListType : public Type {
   static ListTypePtr ofTensors();
   static ListTypePtr ofInts();
   static ListTypePtr ofFloats();
+  static ListTypePtr ofBools();
   static const TypeKind Kind = TypeKind::ListType;
 private:
@@ -605,6 +608,31 @@ private:
   : Type(TypeKind::IntType) {}
 };
+struct BoolType;
+using BoolTypePtr = std::shared_ptr<BoolType>;
+// This node represents a Python bool value
+struct TORCH_API BoolType : public Type {
+  template<typename ... T>
+  static BoolTypePtr create( T&& ... all ) {
+    return BoolTypePtr(new BoolType(std::forward<T>(all)... ));
+  }
+  bool operator==(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "bool";
+  }
+  bool isSubtypeOf(const TypePtr rhs) const override {
+    return *this == *rhs || rhs->kind() == TypeKind::BoolType;
+  }
+  static const TypeKind Kind = TypeKind::BoolType;
+  // global singleton
+  static BoolTypePtr get();
+private:
+  BoolType()
+  : Type(TypeKind::BoolType) {}
+};
+
 struct StringType;
 using StringTypePtr = std::shared_ptr<StringType>;
 // This node represents a Python string value
@@ -728,10 +756,17 @@ inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) {
     return CompleteTensorType::create(at::kLong, -1, {});
   } else if (typ->isSubtypeOf(FloatType::get())) {
     return CompleteTensorType::create(at::kFloat, -1, {});
+  } else if (typ->isSubtypeOf(BoolType::get())) {
+    return CompleteTensorType::create(at::kLong, -1, {});
   }
   AT_ERROR("unknown number type", typ->str());
 }
+inline TypePtr CompleteTensorType::fromBoolType() {
+  return CompleteTensorType::create(at::kLong, -1, {});
+}
+
+
 // Attempt to find the correct supertype of t1 and t2. If none is found then
 // nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2,
 // then t2 will be returned (and vice versa).
@@ -755,7 +790,7 @@ TypePtr getTypePtr() {
 template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); }
 template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); }
 template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); }
-template<> inline TypePtr getTypePtr<bool>() { return IntType::get(); }
+template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); }
 template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); }
 template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); }
 template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
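type.h is where C++ bool stops aliasing to IntType: getTypePtr<bool>() now returns BoolType, BoolType::isSubtypeOf only accepts BoolType, and a bool that has to be materialized as a tensor is typed as a long tensor via fromBoolType(), matching the int case in fromNumberType. A small sketch of what callers can rely on after this change (AT_ASSERT for illustration):

    // Sketch: the bool type mappings introduced above.
    TypePtr t = getTypePtr<bool>();
    AT_ASSERT(t->kind() == TypeKind::BoolType);      // previously IntType
    AT_ASSERT(t->isSubtypeOf(BoolType::get()));
    AT_ASSERT(!t->isSubtypeOf(IntType::get()));

    // Materializing a bool as a tensor reuses the long-tensor convention.
    TypePtr as_tensor = CompleteTensorType::fromBoolType();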
diff --git a/torch/jit/batchop.py b/torch/jit/batchop.py
index 229cafbb94..fed8d2a7c6 100644
--- a/torch/jit/batchop.py
+++ b/torch/jit/batchop.py
@@ -214,8 +214,8 @@ def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
 @torch.jit.script
-def batch_where_scalar(cond_, data1, mask1, dims1, data2, mask2, dims2):
-    cond = torch.zeros([1], dtype=torch.uint8) * cond_
+def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
+    cond = torch.zeros([1], dtype=torch.uint8)
     res_data = torch.where(cond, data1, data2)
     res_mask = torch.where(cond, mask1, mask2)
     res_dims = torch.where(cond, dims1, dims2)
@@ -304,7 +304,7 @@ def batch_unsqueeze(data, mask, dims, dim_):
 @torch.jit.script
 def batch_argmax(data, mask, dims, dim_, keepdim_):
     dim = int(dim_)
-    keepdim = int(keepdim_)
+    keepdim = bool(keepdim_)
     # if dim == 0:
     #     raise ValueError("cannot do argmax along batch_dim")
     batch_size = data.size(0)
@@ -338,8 +338,8 @@ def batch_argmax(data, mask, dims, dim_, keepdim_):
 def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
     k = int(k_)
     dim = int(dim_)
-    largest = int(largest_)
-    sorted = int(sorted_)
+    largest = bool(largest_)
+    sorted = bool(sorted_)
     # if dim == 0:
     #     raise ValueError("cannot do topk along batch_dim")
     batch_size = data.size(0)