author | David Riazati <davidriazati@fb.com> | 2018-10-03 12:25:39 -0700
---|---|---
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-10-03 12:40:03 -0700
commit | d1ac1eba3b53f67d8d12eb20002b06893a2a4d2e (patch)
tree | 50f41b71264f05355a8a255fda94b582406c45aa /torch
parent | c029c839a1900cb6caa0cd32c3db76fbe56460a0 (diff)
Add `bool` type to IR (#11834)
Summary:
This PR adds a `bool` type to `IValue` and integrates it throughout the JIT:
* changes the conditions of `prim::If` and `prim::Loop` to use the `bool` type
* changes operators that take `bool`s so their schemas match the native ops
* fixes the ambiguous `aten` ops `aten::std` and `aten::var`
* fixes these tests in `test_jit.py` (`TestJitGenerated`):
```
'test_std_dim',
'test_std_dim_1d',
'test_std_dim_1d_neg0',
'test_std_dim_neg0',
'test_var_dim',
'test_var_dim_1d',
'test_var_dim_1d_neg0',
'test_var_dim_neg0'
```
* adds `prim::BoolToTensor` and `prim::TensorToBool` (a short usage sketch follows this list)
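
A minimal TorchScript sketch of the user-visible effect (illustrative only, not part of the diff; the function name and inputs are hypothetical): after this change, `if`/`while` conditions are typed as `bool` in the IR, and a `bool(...)` cast on a tensor lowers through the new `prim::TensorToBool`.
```python
import torch

@torch.jit.script
def relu_or_neg(x):
    # `x.sum() > 0.0` produces a tensor; the explicit bool() cast lowers to
    # prim::TensorToBool, so the prim::If condition is a bool value in the IR.
    if bool(x.sum() > 0.0):
        y = x
    else:
        y = -x
    return y

# Printing the graph should show a bool-typed condition feeding prim::If.
print(relu_or_neg.graph)
```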
apaszke zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11834
Differential Revision: D9928570
Pulled By: driazati
fbshipit-source-id: 373c53df2f1a8ffa9e33d9a517002fbeef25f3eb
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/attributes.h | 1
-rw-r--r-- | torch/csrc/jit/autodiff.cpp | 2
-rw-r--r-- | torch/csrc/jit/constants.cpp | 19
-rw-r--r-- | torch/csrc/jit/export.cpp | 2
-rw-r--r-- | torch/csrc/jit/import.cpp | 2
-rw-r--r-- | torch/csrc/jit/interned_strings.h | 2
-rw-r--r-- | torch/csrc/jit/interpreter.cpp | 27
-rw-r--r-- | torch/csrc/jit/ir.cpp | 10
-rw-r--r-- | torch/csrc/jit/ir.h | 14
-rw-r--r-- | torch/csrc/jit/operator.cpp | 7
-rw-r--r-- | torch/csrc/jit/passes/erase_number_types.cpp | 7
-rw-r--r-- | torch/csrc/jit/passes/peephole.cpp | 2
-rw-r--r-- | torch/csrc/jit/passes/remove_expands.cpp | 3
-rw-r--r-- | torch/csrc/jit/passes/shape_analysis.cpp | 102
-rw-r--r-- | torch/csrc/jit/passes/to_batch.cpp | 24
-rw-r--r-- | torch/csrc/jit/pybind_utils.h | 6
-rw-r--r-- | torch/csrc/jit/python_ir.cpp | 4
-rw-r--r-- | torch/csrc/jit/register_prim_ops.cpp | 82
-rw-r--r-- | torch/csrc/jit/script/compiler.cpp | 9
-rw-r--r-- | torch/csrc/jit/script/init.cpp | 6
-rw-r--r-- | torch/csrc/jit/type.cpp | 14
-rw-r--r-- | torch/csrc/jit/type.h | 37
-rw-r--r-- | torch/jit/batchop.py | 10
23 files changed, 274 insertions, 118 deletions
diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h index 2610643918..af3a16393f 100644 --- a/torch/csrc/jit/attributes.h +++ b/torch/csrc/jit/attributes.h @@ -167,6 +167,7 @@ struct Attributes { const Kind##Attr::ValueType& method(Symbol name) const { \ return get<Kind##Attr>(name); \ } + CREATE_ACCESSOR(Float,f) CREATE_ACCESSOR(Floats,fs) CREATE_ACCESSOR(String,s) diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index a59c856eab..46840fc99a 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -280,7 +280,7 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))}; - } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) { + } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) { const auto& input_sizes = inputs.at(0).sizes(); if (input_sizes.size() == 0) return {grads.at(0).sum(), nullptr, nullptr}; diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index 1633ac0de4..c62cdbd1fa 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -27,6 +27,13 @@ Value* insertConstant( } else if(val.isDouble()) { n->f_(attr::value, val.toDouble()); n->output()->setType(FloatType::get()); + } else if (val.isBool()) { + n->i_(attr::value, val.toBool()); + n->output()->setType(BoolType::get()); + } else if (val.isBoolList()) { + auto bool_list = val.toBoolList()->elements(); + n->is_(attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end())); + n->output()->setType(ListType::ofBools()); } else if(val.isIntList()) { n->is_(attr::value, val.toIntList()->elements()); n->output()->setType(ListType::ofInts()); @@ -64,6 +71,12 @@ RegisterOperators reg({ stack.push_back(t); return 0; }; + } else if (type->isSubtypeOf(BoolType::get())) { + bool b = node->i(attr::value); + return [b](Stack& stack) { + push(stack, b); + return 0; + }; } else if ( type->isSubtypeOf(NumberType::get()) && node->kindOf(attr::value) == AttributeKind::i) { @@ -86,6 +99,12 @@ RegisterOperators reg({ push(stack, is); return 0; }; + } else if(type->isSubtypeOf(ListType::ofBools())) { + auto bs = node->is(attr::value); + return [bs](Stack& stack) { + push(stack, bs); + return 0; + }; } else if(type->isSubtypeOf(ListType::ofTensors())) { auto ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor { return autograd::make_variable(t); diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 973780b7d7..574dae168a 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -559,6 +559,8 @@ void ModuleEncoder::EncodeTypeInfo( type_proto->set_denotation("FloatType"); } else if (kind == TypeKind::IntType) { type_proto->set_denotation("IntType"); + } else if (kind == TypeKind::BoolType) { + type_proto->set_denotation("BoolType"); } else if (kind == TypeKind::NoneType) { type_proto->set_denotation("NoneType"); } else if (kind == TypeKind::GeneratorType) { diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 4574addb3a..45053a81db 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -257,6 +257,8 @@ TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) { return FloatType::get(); } else if (kind == "IntType") { return IntType::get(); + } else if (kind == "BoolType") { + return 
BoolType::get(); } else if (kind == "NoneType") { return NoneType::get(); } else if (kind == "GeneratorType") { diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index b4e6b7c139..32544aa682 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -46,9 +46,11 @@ namespace torch { namespace jit { _(prim, TupleUnpack) \ _(prim, ListConstruct) \ _(prim, ListUnpack) \ + _(prim, BoolToTensor) \ _(prim, NumToTensor) \ _(prim, TensorToNum) \ _(prim, ImplicitTensorToNum) \ + _(prim, TensorToBool) \ _(prim, IntToFloat) \ _(prim, FloatToInt) \ _(prim, StringToFloat) \ diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 14e7fab54d..065054a0fc 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -61,14 +61,14 @@ Value* createTripCountConjunctiveCondition( // Emit initial comparison -- initial_trip_count < max_trip_count Value* initial_comparison_value = g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1)) - ->output()->setType(IntType::get()); + ->output()->setType(BoolType::get()); // Replace initial condition with logical `and` of trip count and // initial condition Value* new_cond = g->insertNode( g->create(aten::__and__, {initial_comparison_value, cond}, 1)) - ->output()->setType(IntType::get()); + ->output()->setType(BoolType::get()); return new_cond; } @@ -388,30 +388,29 @@ struct CodeImpl { CodeImpl(std::shared_ptr<Graph>& graph_) : preprocess(*graph_) { graph = preprocess.graph; - // std::cout << "into code graph:\n" << *graph << "\n"; insertNodesFromBlock(graph->block()); } - // jump when input is 0 - void createJumpZ(int from_inst, int to_inst) { + // jump when input is false + void createJumpFalse(int from_inst, int to_inst) { auto & inst = instructions[from_inst]; JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); inst.callback = [offset](Stack & stack) { - auto t = pop(stack).toInt(); - return (t == 0) ? offset : 0; + auto t = pop(stack).toBool(); + return t ? 0 : offset; }; inst.debug_name = prim::JumpZ; } - // jump when input is not 0 - void createJumpNZ(int from_inst, int to_inst) { + // jump when input is true + void createJumpTrue(int from_inst, int to_inst) { auto & inst = instructions[from_inst]; JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); inst.callback = [offset](Stack & stack) { - auto t = pop(stack).toInt(); - return (t != 0) ? offset : 0; + auto t = pop(stack).toBool(); + return t ? offset : 0; }; inst.debug_name = prim::JumpNZ; } @@ -460,7 +459,7 @@ struct CodeImpl { insertNodesFromBlock(then_block); insertAssign(source_location, then_block->outputs(), moveFlags(then_block), node->outputs()); createJump(jump, instructions.size()); - createJumpNZ(cond_branch, then_block_start); + createJumpTrue(cond_branch, then_block_start); } break; case prim::Loop: { // o0 = while c i0 @@ -495,8 +494,8 @@ struct CodeImpl { // after branch: stack: ... 
aliasRegistersTo(node->outputs(), body_block->inputs()); - createJumpZ(cond_branch, instructions.size()); - createJumpNZ(cond_branch_end, entry); + createJumpFalse(cond_branch, instructions.size()); + createJumpTrue(cond_branch_end, entry); } break; default: { insertInstruction(node); diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 90451494ba..4e4bb0b68d 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -562,18 +562,18 @@ namespace { const OperatorSet& nondeterminstic_aten_ops() { static OperatorSet nondeterministic_ops = { - "aten::dropout(Tensor input, float p, int train) -> Tensor", + "aten::dropout(Tensor input, float p, bool train) -> Tensor", "aten::_fused_dropout(Tensor self, float p, Generator generator) -> (Tensor, Tensor)", "aten::_standard_gamma(Tensor self, Generator generator) -> Tensor", "aten::bernoulli(Tensor self, *, Generator generator) -> Tensor", "aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor", - "aten::multinomial(Tensor self, int num_samples, int replacement, *, Generator generator) -> Tensor", + "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator generator) -> Tensor", "aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor", "aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor", "aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor", "aten::poisson(Tensor self, Generator generator) -> Tensor", - "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor", - "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor", + "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", + "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", "aten::rand(int[] size, *, int dtype, int layout, int[] device) -> Tensor", "aten::rand_like(Tensor self) -> Tensor", "aten::rand_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor", @@ -598,7 +598,7 @@ bool Node::isNondeterministic() const { return false; } // Dropout with train = False is deterministic - if (matches("aten::dropout(Tensor input, float p, int train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) { + if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) { return false; } return true; diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 0bb5c899c7..6fdb49e7d9 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1050,6 +1050,15 @@ public: result->output()->setType(CompleteTensorType::fromNumberType(typ)); return result; } + Node* createBoolToTensor(Value* value) { + auto typ = value->type(); + Node * result = create(prim::BoolToTensor, {value}); + if (!typ->isSubtypeOf(BoolType::get())) { + AT_ERROR("Cannot create bool type from ", typ->str()); + } + result->output()->setType(CompleteTensorType::fromBoolType()); + return result; + } Node* createTensorToNum(const TypePtr& type, Value* value) { auto* result = create(prim::TensorToNum, {value}); result->output()->setType(type); @@ -1060,6 +1069,11 @@ public: result->output()->setType(type); return result; } + Node* createTensorToBool(Value* value) { + auto* result = create(prim::TensorToBool, {value}); + result->output()->setType(BoolType::get()); + return result; + } Node* 
createIntToFloat(Value* value) { JIT_ASSERT(*value->type() == *IntType::get()); auto* result = create(prim::IntToFloat, {value}); diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index d701b536c4..818a58c965 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -57,7 +57,7 @@ struct SchemaParser { {"str", StringType::get() }, {"float", FloatType::get() }, {"int", IntType::get() }, - {"bool", IntType::get() }, // TODO: add separate bool type + {"bool", BoolType::get() }, {"World", WorldType::get() }, }; auto tok = L.expect(TK_IDENT); @@ -162,6 +162,10 @@ struct SchemaParser { return fmap(vs, [](IValue v) { return v.toInt(); }); + case TypeKind::BoolType: + return fmap(vs, [](IValue v) { + return v.toBool(); + }); default: throw ErrorReport(range) << "lists are only supported for float or int types."; } @@ -191,6 +195,7 @@ struct SchemaParser { } break; case TypeKind::NumberType: case TypeKind::IntType: + case TypeKind::BoolType: case TypeKind::FloatType: arg.default_value = parseSingleConstant(arg.type->kind()); break; diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 7d2ba5bfeb..115d4b6a68 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -13,13 +13,16 @@ static void EraseNumberTypesOnBlock(Block* block) { case prim::Constant: { // remove primitive constants, replacing with tensor equivalent // ONNX does not support non-tensor constants - if(it->output()->type()->isSubtypeOf(NumberType::get())) { + if (it->output()->type()->isSubtypeOf(NumberType::get()) || + it->output()->type()->isSubtypeOf(BoolType::get())) { auto s = *constant_as<at::Scalar>(it->output()); WithInsertPoint guard(*it); Value* r = block->owningGraph()->insertConstant(scalar_to_tensor(s)); it->output()->replaceAllUsesWith(r); } } break; + case prim::TensorToBool: + case prim::BoolToTensor: case prim::TensorToNum: case prim::ImplicitTensorToNum: case prim::NumToTensor: { @@ -30,6 +33,8 @@ static void EraseNumberTypesOnBlock(Block* block) { for(auto o : it->outputs()) { if (o->type()->isSubtypeOf(NumberType::get())) { o->setType(CompleteTensorType::fromNumberType(o->type())); + } else if (o->type()->isSubtypeOf(BoolType::get())) { + o->setType(CompleteTensorType::fromBoolType()); } } } break; diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 176166218f..7f2312696e 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -24,7 +24,7 @@ void PeepholeOptimize(Block * block) { // XXX: remember that if you want to simplify an expression by combining multiple nodes // into a different one, then you need to check that they all belong to the given block - if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor", + if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", /*with_const=*/attr::size)) { // x.expand(x.size()) == x if (auto input_type = node->namedInput(attr::self)->type()->cast<CompleteTensorType>()) { diff --git a/torch/csrc/jit/passes/remove_expands.cpp b/torch/csrc/jit/passes/remove_expands.cpp index 93d53e5481..1c67e68507 100644 --- a/torch/csrc/jit/passes/remove_expands.cpp +++ b/torch/csrc/jit/passes/remove_expands.cpp @@ -8,7 +8,8 @@ static void RemoveExpands(Block* block) { ++it) { for (auto sub : it->blocks()) RemoveExpands(sub); - if (it->kind() == aten::expand && it->get<int64_t>(attr::implicit) != 
static_cast<int64_t>(0)) { + + if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) { it->output()->replaceAllUsesWith(it->namedInput(attr::self)); it.destroyCurrent(); } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index eedc7fd0a8..af5e8cd96f 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -120,7 +120,7 @@ void broadcastBinary(Node *node, std::vector<CompleteTensorTypePtr>& types, size Node *expand = graph->create(aten::expand, {node->inputs().at(input_idx), graph->insertConstant(expected_size), - graph->insertConstant(0)}) + graph->insertConstant(false)}) ->insertBefore(node); PropagateShapeOnNode(expand); node->replaceInput(input_idx, expand->output()); @@ -441,12 +441,12 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::clamp(Tensor self, Scalar min, Scalar max) -> Tensor", "aten::clamp_max(Tensor self, Scalar max) -> Tensor", "aten::clamp_min(Tensor self, Scalar min) -> Tensor", - "aten::alpha_dropout(Tensor input, float p, int train) -> Tensor", + "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor", "aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor", "aten::cos(Tensor self) -> Tensor", "aten::cosh(Tensor self) -> Tensor", "aten::digamma(Tensor self) -> Tensor", - "aten::dropout(Tensor input, float p, int train) -> Tensor", + "aten::dropout(Tensor input, float p, bool train) -> Tensor", "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor", "aten::erf(Tensor self) -> Tensor", "aten::erfc(Tensor self) -> Tensor", @@ -462,8 +462,8 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::floor(Tensor self) -> Tensor", "aten::frac(Tensor self) -> Tensor", "aten::flip(Tensor self, int[] dims) -> Tensor", - "aten::feature_alpha_dropout(Tensor input, float p, int train) -> Tensor", - "aten::feature_dropout(Tensor input, float p, int train) -> Tensor", + "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor", + "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor", "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor", "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor", "aten::glu(Tensor self, int dim) -> Tensor", @@ -479,7 +479,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::reciprocal(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", "aten::round(Tensor self) -> Tensor", - "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor", + "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", "aten::rsqrt(Tensor self) -> Tensor", "aten::selu(Tensor self) -> Tensor", "aten::sigmoid(Tensor self) -> Tensor", @@ -645,7 +645,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { // Knowing the type and device of weights or biases is usually enough to // infer the output type. 
static const register_formula_for nn_ops_first_input_preserving {{ - "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, int training, float momentum, float eps, int cudnn_enabled) -> Tensor", + "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "aten::conv1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor", "aten::conv2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor", "aten::conv3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor", @@ -653,16 +653,16 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor", "aten::conv_transpose2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor", "aten::conv_transpose3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor", - "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int transposed, int[] output_padding, int groups) -> Tensor", + "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", "aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor", "aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor", "aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor", - "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor", - "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor", - "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor", - "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor", - "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor", - "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor", + "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor", + "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor", + "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor", + "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor", + "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor", + "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor", "aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor", "aten::max_unpool3d(Tensor 
self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor", "aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor", @@ -670,12 +670,12 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::replication_pad1d(Tensor self, int[] padding) -> Tensor", "aten::replication_pad2d(Tensor self, int[] padding) -> Tensor", "aten::replication_pad3d(Tensor self, int[] padding) -> Tensor", - "aten::upsample_bilinear2d(Tensor self, int[] output_size, int align_corners) -> Tensor", - "aten::upsample_linear1d(Tensor self, int[] output_size, int align_corners) -> Tensor", + "aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners) -> Tensor", + "aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners) -> Tensor", "aten::upsample_nearest1d(Tensor self, int[] output_size) -> Tensor", "aten::upsample_nearest2d(Tensor self, int[] output_size) -> Tensor", "aten::upsample_nearest3d(Tensor self, int[] output_size) -> Tensor", - "aten::upsample_trilinear3d(Tensor self, int[] output_size, int align_corners) -> Tensor", + "aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners) -> Tensor", "aten::prelu(Tensor self, Tensor weight) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto type = node->input(0)->type()->cast<TensorType>()) { @@ -702,10 +702,10 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { "aten::mean(Tensor self) -> Tensor", "aten::median(Tensor self) -> Tensor", "aten::norm(Tensor self, Scalar p) -> Tensor", - "aten::std(Tensor self, int unbiased) -> Tensor", + "aten::std(Tensor self, bool unbiased) -> Tensor", "aten::sum(Tensor self) -> Tensor", "aten::trace(Tensor self) -> Tensor", - "aten::var(Tensor self, int unbiased) -> Tensor", + "aten::var(Tensor self, bool unbiased) -> Tensor", "aten::all(Tensor self) -> Tensor", "aten::any(Tensor self) -> Tensor", }, [](Node * node) -> type_vec_t { @@ -735,7 +735,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { static const auto multidim_reduce_with_postprocess = [](Node * node, size_t num_reduced_dim, bool upcast_integer) -> type_vec_t { - auto maybe_keepdim = node->get<int64_t>(attr::keepdim); + auto maybe_keepdim = node->get<bool>(attr::keepdim); if (!maybe_keepdim) return {}; if (auto type = node->input(0)->type()->cast<TensorType>()) { if (upcast_integer && !at::isFloatingType(type->scalarType())) { @@ -760,24 +760,24 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { // - First input should be the only tensor input // - Has a bool keepdim argument static const register_formula_for dim_reduce_ops {{ - "aten::argmax(Tensor self, int dim, int keepdim) -> Tensor", - "aten::argmin(Tensor self, int dim, int keepdim) -> Tensor", - "aten::max_values(Tensor self, int dim, int keepdim) -> Tensor", - "aten::min_values(Tensor self, int dim, int keepdim) -> Tensor", - "aten::mean(Tensor self, int dim, int keepdim) -> Tensor", - "aten::norm(Tensor self, Scalar p, int dim, int keepdim) -> Tensor", - "aten::std(Tensor self, int dim, int unbiased, int keepdim) -> Tensor", - "aten::var(Tensor self, int dim, int unbiased, int keepdim) -> Tensor", - "aten::logsumexp(Tensor self, int dim, int keepdim) -> Tensor", - "aten::all(Tensor self, int dim, int keepdim) -> Tensor", - "aten::any(Tensor self, int dim, int keepdim) -> Tensor", + "aten::argmax(Tensor self, int dim, bool keepdim) -> Tensor", + "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor", + "aten::max_values(Tensor self, int 
dim, bool keepdim) -> Tensor", + "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor", + "aten::mean(Tensor self, int dim, bool keepdim) -> Tensor", + "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor", + "aten::std(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor", + "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor", + "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor", + "aten::all(Tensor self, int dim, bool keepdim) -> Tensor", + "aten::any(Tensor self, int dim, bool keepdim) -> Tensor", // Ops returning indices as second output - "aten::kthvalue(Tensor self, int k, int dim, int keepdim) -> (Tensor, Tensor)", - "aten::max(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)", - "aten::min(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)", - "aten::median(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)", - "aten::mode(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)", + "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)", + "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", + "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", + "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", + "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", }, [](Node * node) -> type_vec_t { // NB: Note that while this function is generally meant to be used with ops that // have a single output, we will fix up its return right below. @@ -798,7 +798,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { // - First input should be the only tensor input // - has a bool keepdim argument static const register_formula_for dim_reduce_ops_with_integer_upcast {{ - "aten::prod(Tensor self, int dim, int keepdim) -> Tensor", + "aten::prod(Tensor self, int dim, bool keepdim) -> Tensor", }, [](Node * node) -> type_vec_t { return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/1, /*integer_upcast=*/true); }}; @@ -812,7 +812,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { // Additionally: // - has bool keepdim and int[] dim arguments static const register_formula_for multidim_reduce_ops_with_integer_upcast {{ - "aten::sum(Tensor self, int[] dim, int keepdim) -> Tensor", + "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) { // TODO: can dim contain duplicates? 
@@ -900,14 +900,14 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { } }; static const register_formula_for cast_ops {{ - "aten::_cast_Byte(Tensor self, int non_blocking) -> Tensor", - "aten::_cast_Char(Tensor self, int non_blocking) -> Tensor", - "aten::_cast_Double(Tensor self, int non_blocking) -> Tensor", - "aten::_cast_Float(Tensor self, int non_blocking) -> Tensor", - "aten::_cast_Half(Tensor self, int non_blocking) -> Tensor", - "aten::_cast_Int(Tensor self, int non_blocking) -> Tensor", - "aten::_cast_Long(Tensor self, int non_blocking) -> Tensor", - "aten::_cast_Short(Tensor self, int non_blocking) -> Tensor", + "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor", + "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor", + "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor", + "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", + "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor", + "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor", + "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor", + "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto type = node->namedInput(attr::self)->type()->cast<TensorType>()) { return {type->toScalarType(get_cast_scalar_type(node))}; @@ -1003,7 +1003,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { } return true; } - } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, int scale_grad_by_freq, int sparse) -> Tensor")) { + } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) { auto weight_type = input_type(0); auto indices_type = input_type(1); if (weight_type && indices_type) { @@ -1044,7 +1044,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { node->matches("aten::reshape_as(Tensor self, Tensor other) -> Tensor")) { return tensor_types.at(0)->withDim(tensor_types.at(1)->dim()); } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") || - node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor") || + node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") || node->matches("aten::as_strided(Tensor self, int[] size, int[] stride) -> Tensor") || node->matches("aten::as_strided(Tensor self, int[] size, int[] stride, int storage_offset) -> Tensor")) { return reshape_prop(node, attr::size, tensor_types); @@ -1189,12 +1189,12 @@ bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands, } else if (node->matches("aten::sum(Tensor self) -> Tensor")) { node->output()->setType(tensor_types.at(0)->withSizes({})); return true; - } else if (node->matches("aten::sum(Tensor self, int[] dim, int keepdim) -> Tensor", + } else if (node->matches("aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor", /*with_const=*/{attr::dim, attr::keepdim})) { auto & tp = tensor_types.at(0); auto sizes = tp->sizes(); auto dims = node->get<std::vector<int64_t>>(attr::dim).value(); - bool keepdim = node->get<int64_t>(attr::keepdim).value(); + bool keepdim = node->get<bool>(attr::keepdim).value(); std::reverse(dims.begin(), dims.end()); for (int64_t dim : dims) { SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size()); @@ -1262,7 +1262,7 @@ bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands, node->output()->setType(tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes())); } 
return true; - } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor", + } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", /*with_const=*/attr::size)) { auto tp = tensor_types.at(0); std::vector<int64_t> sizes, strides; diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp index 0d56ca2255..f1edf80d41 100644 --- a/torch/csrc/jit/passes/to_batch.cpp +++ b/torch/csrc/jit/passes/to_batch.cpp @@ -41,6 +41,10 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){ auto to_tensor_node = res_graph->createNumToTensor(input); res_graph->insertNode(to_tensor_node); new_inputs[i] = to_tensor_node->output(); + } else if(input->type() == BoolType::get()) { + auto to_tensor_node = res_graph->createBoolToTensor(input); + res_graph->insertNode(to_tensor_node); + new_inputs[i] = to_tensor_node->output(); } } @@ -58,8 +62,11 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){ else if(n->outputs()[0]->type() == FloatType::get()){ to_scalar_node = res_graph->createTensorToNum(FloatType::get(), outputs[0]); } + else if(n->outputs()[0]->type() == BoolType::get()){ + to_scalar_node = res_graph->createTensorToBool(outputs[0]); + } else{ - throw std::runtime_error("NYI: scalar type other than int, float is not supported yet"); + throw std::runtime_error("NYI: scalar types other than int, float, and bool are not supported yet"); } res_graph->insertNode(to_scalar_node); rn_env[n->outputs()[0]] = to_scalar_node->output(); @@ -348,9 +355,10 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){ if(cond_is_tensor){ auto cond = batch_map.at(n->inputs()[1]); auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond); - auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]); - res_graph->insertNode(to_int_node); - rn_env[n->inputs()[1]] = to_int_node->output(); + auto to_bool_node = + res_graph->createTensorToBool(cond_any[0]); + res_graph->insertNode(to_bool_node); + rn_env[n->inputs()[1]] = to_bool_node->output(); } for(size_t i = 2; i < n->inputs().size(); i++){ auto input = n->inputs()[i]; @@ -432,9 +440,10 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){ if(cond_is_tensor){ auto cond = batch_map.at(n->blocks()[0]->outputs()[0]); auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond); - auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]); - res_graph->insertNode(to_int_node); - loop_block->insertOutput(0, to_int_node->output()); + auto to_bool_node = + res_graph->createTensorToBool(cond_any[0]); + res_graph->insertNode(to_bool_node); + loop_block->insertOutput(0, to_bool_node->output()); for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){ loop_block->insertOutput(i + 1, cond[i]); } @@ -491,6 +500,7 @@ void ToBatch::toBatch(Block* block, Block* res_block) { case prim::NumToTensor: visitNumToTensor(n, block, res_block); break; + case prim::TensorToBool: case prim::TensorToNum: visitTensorToNum(n, block, res_block); break; diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 192cabca3f..bd81a8cc05 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -107,6 +107,8 @@ inline IValue toIValue(py::handle obj, const TypePtr& type) { return py::cast<int64_t>(obj); case TypeKind::NoneType: return {}; + case TypeKind::BoolType: + return py::cast<bool>(obj); case TypeKind::TupleType: { 
if(!PyTuple_Check(obj.ptr())) throw py::cast_error(); // note: the py::cast does not throw cast_error @@ -196,10 +198,14 @@ inline py::object toPyObject(IValue&& ivalue) { return py::cast(ivalue.toDouble()); } else if (ivalue.isInt()) { return py::cast(ivalue.toInt()); + }else if (ivalue.isBool()) { + return py::cast(ivalue.toBool()); } else if (ivalue.isIntList()) { return py::cast(ivalue.toIntListRef()); } else if (ivalue.isDoubleList()) { return py::cast(ivalue.toDoubleListRef()); + } else if (ivalue.isBoolList()) { + return py::cast(ivalue.toBoolListRef()); } else if (ivalue.isTensorList()) { return py::cast(ivalue.toTensorListRef()); } else if (ivalue.isGenericList()) { diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index ad03ac556c..979b07d369 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -455,6 +455,8 @@ void initPythonIRBindings(PyObject * module_) { return "StringType"; case TypeKind::GeneratorType: return "GeneratorType"; + case TypeKind::BoolType: + return "BoolType"; case TypeKind::VarType: return "VarType"; case TypeKind::WorldType: @@ -491,6 +493,8 @@ void initPythonIRBindings(PyObject * module_) { .def_static("get", &FloatType::get); py::class_<DynamicType, Type, std::shared_ptr<DynamicType>>(m, "DynamicType") .def_static("get", &DynamicType::get); + py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType") + .def_static("get", &BoolType::get); py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType") .def(py::init([](std::vector<TypePtr> a){ return TupleType::create(a); })) diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index cdea4ab894..c7bf050dc2 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -70,6 +70,17 @@ RegisterOperators reg({ }; }), Operator( + prim::TensorToBool, + [](Node* node) -> Operation { + return [](Stack& stack) { + at::Tensor a; + pop(stack, a); + at::DeviceGuard guard(a); + push(stack, a.item<int64_t>() != 0); + return 0; + }; + }), + Operator( prim::TensorToNum, [](Node* node) -> Operation { if(node->output()->type() == IntType::get()) { @@ -124,6 +135,18 @@ RegisterOperators reg({ }; }), Operator( + prim::BoolToTensor, + [](Node* node) -> Operation { + return [](Stack& stack) { + bool b; + pop(stack, b); + push( + stack, + autograd::make_variable(at::scalar_to_tensor(b))); + return 0; + }; + }), + Operator( prim::IntToFloat, [](Node* node) -> Operation { return [](Stack& stack) { @@ -446,25 +469,25 @@ RegisterOperators reg({ }); // define implementations for primitive number ops -#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, float_result) \ - Operator( \ - #aten_op "(int a, int b) -> int", \ - [](Node* node) { \ - return [=](Stack& stack) { \ - int64_t a, b; \ - pop(stack, a, b); \ - push(stack, int_op); \ - return 0; \ - }; \ - }), \ - Operator( \ - #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \ - return [=](Stack& stack) { \ - double a, b; \ - pop(stack, a, b); \ - push(stack, float_op); \ - return 0; \ - }; \ +#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \ + Operator( \ + #aten_op "(int a, int b) -> " #int_result, \ + [](Node* node) { \ + return [=](Stack& stack) { \ + int64_t a, b; \ + pop(stack, a, b); \ + push(stack, int_op); \ + return 0; \ + }; \ + }), \ + Operator( \ + #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \ + return [=](Stack& stack) { \ + double a, b; \ + pop(stack, a, b); \ + 
push(stack, float_op); \ + return 0; \ + }; \ }), #define DEFINE_INT_OP(aten_op, op) \ @@ -477,8 +500,19 @@ RegisterOperators reg({ }; \ }), -#define DEFINE_BINARY_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, op, float) -#define DEFINE_COMPARISON_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, op, int) +#define DEFINE_BINARY_OP(aten_op, op) \ + DEFINE_GENERIC_OP(aten_op, op, op, int, float) +#define DEFINE_COMPARISON_OP(aten_op, op) \ + DEFINE_GENERIC_OP(aten_op, op, op, bool, bool) +#define DEFINE_BOOL_OP(aten_op, op) \ + Operator(#aten_op "(bool a, bool b) -> bool", [](Node* node) { \ + return [=](Stack& stack) { \ + bool a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + return 0; \ + }; \ + }), // Convert an python index (which may be negative) into an index usable for a // C++ container @@ -663,7 +697,7 @@ RegisterOperators reg2({ // Pass in two ops for handling int and float separately as % in C++ only works for int // The modulus calculation is different between C++ and Python (on negative), we preserve // the python behavior as it's more common and match python syntax, hence the conversion. - DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), float) + DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), int, float) // TODO: Support python floordiv (//) // Right now aten::floordiv is only used by loop unrolling @@ -696,8 +730,8 @@ RegisterOperators reg2({ DEFINE_COMPARISON_OP(aten::le, a <= b) DEFINE_COMPARISON_OP(aten::ge, a >= b) - DEFINE_INT_OP(aten::__and__, a&& b) - DEFINE_INT_OP(aten::__or__, a || b) + DEFINE_BOOL_OP(aten::__and__, a && b) + DEFINE_BOOL_OP(aten::__or__, a || b) Operator("aten::_construct_empty_int_list() -> int[]", [](Node* node) -> Operation { diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 28aa735fc3..474722f536 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -74,6 +74,8 @@ static Value* typeCast(const SourceRange& loc, Value* value, TypePtr dst) { n = graph.createNumToTensor(value); } else if (dst->isSubtypeOf(NumberType::get()) && orig->isSubtypeOf(DynamicType::get())) { n = graph.createTensorToNum(dst, value); + } else if (dst->isSubtypeOf(BoolType::get()) && orig->isSubtypeOf(DynamicType::get())) { + n = graph.createTensorToBool(value); } else if(dst->isSubtypeOf(IntType::get()) && orig->isSubtypeOf(FloatType::get())) { n = graph.createFloatToInt(value); } else if(dst->isSubtypeOf(FloatType::get()) && orig->isSubtypeOf(IntType::get())) { @@ -324,7 +326,7 @@ struct Environment { {"print", std::make_shared<PrintValue>()}, {"float", std::make_shared<CastValue>(FloatType::get())}, {"int", std::make_shared<CastValue>(IntType::get())}, - {"bool", std::make_shared<CastValue>(IntType::get())}, + {"bool", std::make_shared<CastValue>(BoolType::get())}, // todo(zach): remove when we can correctly export torch.full via ONNX // or we have implicit conversion that can convert numbers to tensors {"_to_tensor", std::make_shared<CastValue>(DynamicType::get()) }, @@ -1048,9 +1050,9 @@ private: Value* emitCond(Expr cond) { Value* v = emitExpr(cond); - if (!v->type()->isSubtypeOf(IntType::get())) { + if (!v->type()->isSubtypeOf(BoolType::get())) { ErrorReport error(cond); - error << "expected an integer expression for condition but found " + error << "expected a boolean expression for condition but found " << v->type()->str(); if (v->type()->isSubtypeOf(DynamicType::get())) { error << ", to use a tensor in a boolean" @@ -1928,6 
+1930,7 @@ const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() { {"Tensor", DynamicType::get()}, {"int", IntType::get()}, {"float", FloatType::get()}, + {"bool", BoolType::get()}, }; return map; } diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 4c7df820b1..fec69ac317 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -278,12 +278,12 @@ std::shared_ptr<SugaredValue> toSugaredValue( // f = f + 1 auto& g = *m.graph(); if (is_constant) { - if (py::isinstance<py::int_>(obj)) { + if (py::isinstance<py::bool_>(obj)) { + return toSimple(g.insertConstant(py::cast<bool>(obj), loc)); + } else if (py::isinstance<py::int_>(obj)) { return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc)); } else if (py::isinstance<py::float_>(obj)) { return toSimple(g.insertConstant(py::cast<float>(obj), loc)); - } else if (py::isinstance<py::bool_>(obj)) { - return toSimple(g.insertConstant(py::cast<bool>(obj), loc)); } else if (THPDevice_Check(obj.ptr())) { auto device = (THPDevice*)obj.ptr(); std::vector<int64_t> v = {static_cast<int64_t>(device->device.type()), diff --git a/torch/csrc/jit/type.cpp b/torch/csrc/jit/type.cpp index 855adad429..ce3cd8c1f2 100644 --- a/torch/csrc/jit/type.cpp +++ b/torch/csrc/jit/type.cpp @@ -46,6 +46,8 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { out << "float"; } else if(t.kind() == TypeKind::IntType) { out << "int"; + } else if(t.kind() == TypeKind::BoolType) { + out << "bool"; } else if(t.kind() == TypeKind::ListType) { auto prim = t.cast<ListType>()->getElementType(); out << *prim << "[]"; @@ -85,6 +87,10 @@ FloatTypePtr FloatType::get() { static auto value = FloatType::create(); return value; } +BoolTypePtr BoolType::get() { + static auto value = BoolType::create(); + return value; +} NoneTypePtr NoneType::get() { static auto value = NoneType::create(); return value; @@ -113,6 +119,10 @@ ListTypePtr ListType::ofFloats() { static auto value = ListType::create(FloatType::get()); return value; } +ListTypePtr ListType::ofBools() { + static auto value = ListType::create(BoolType::get()); + return value; +} TypePtr inferTypeFrom(const IValue& value) { if (value.isTensor()) { @@ -121,12 +131,16 @@ TypePtr inferTypeFrom(const IValue& value) { return FloatType::get(); } else if (value.isInt()) { return IntType::get(); + } else if (value.isBool()) { + return BoolType::get(); } else if (value.isString()) { return StringType::get(); } else if (value.isIntList()) { return ListType::ofInts(); } else if (value.isTensorList()) { return ListType::ofTensors(); + } else if (value.isBoolList()) { + return ListType::ofBools(); } else if (value.isDoubleList()) { return ListType::ofFloats(); } else if (value.isTuple()) { diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h index 49748de239..69133f6be9 100644 --- a/torch/csrc/jit/type.h +++ b/torch/csrc/jit/type.h @@ -27,6 +27,7 @@ _(IntType) \ _(NoneType) \ _(StringType) \ _(GeneratorType) \ +_(BoolType) \ _(VarType) \ _(WorldType) \ @@ -343,6 +344,7 @@ struct TORCH_API CompleteTensorType : public TensorType { return prod; } static TypePtr fromNumberType(TypePtr typ); + static TypePtr fromBoolType(); private: CompleteTensorType(const at::Tensor& tensor) @@ -438,6 +440,7 @@ struct TORCH_API ListType : public Type { static ListTypePtr ofTensors(); static ListTypePtr ofInts(); static ListTypePtr ofFloats(); + static ListTypePtr ofBools(); static const TypeKind Kind = TypeKind::ListType; private: @@ -605,6 +608,31 @@ private: : 
Type(TypeKind::IntType) {} }; +struct BoolType; +using BoolTypePtr = std::shared_ptr<BoolType>; +// This node represents a Python bool value +struct TORCH_API BoolType : public Type { + template<typename ... T> + static BoolTypePtr create( T&& ... all ) { + return BoolTypePtr(new BoolType(std::forward<T>(all)... )); + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "bool"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::BoolType; + } + static const TypeKind Kind = TypeKind::BoolType; + // global singleton + static BoolTypePtr get(); +private: + BoolType() + : Type(TypeKind::BoolType) {} +}; + struct StringType; using StringTypePtr = std::shared_ptr<StringType>; // This node represents a Python string value @@ -728,10 +756,17 @@ inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) { return CompleteTensorType::create(at::kLong, -1, {}); } else if (typ->isSubtypeOf(FloatType::get())) { return CompleteTensorType::create(at::kFloat, -1, {}); + } else if (typ->isSubtypeOf(BoolType::get())) { + return CompleteTensorType::create(at::kLong, -1, {}); } AT_ERROR("unknown number type", typ->str()); } +inline TypePtr CompleteTensorType::fromBoolType() { + return CompleteTensorType::create(at::kLong, -1, {}); +} + + // Attempt to find the correct supertype of t1 and t2. If none is found then // nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2, // then t2 will be returned (and vice versa). @@ -755,7 +790,7 @@ TypePtr getTypePtr() { template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); } template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); } template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); } -template<> inline TypePtr getTypePtr<bool>() { return IntType::get(); } +template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); } template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); } template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); } template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); } diff --git a/torch/jit/batchop.py b/torch/jit/batchop.py index 229cafbb94..fed8d2a7c6 100644 --- a/torch/jit/batchop.py +++ b/torch/jit/batchop.py @@ -214,8 +214,8 @@ def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2): @torch.jit.script -def batch_where_scalar(cond_, data1, mask1, dims1, data2, mask2, dims2): - cond = torch.zeros([1], dtype=torch.uint8) * cond_ +def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2): + cond = torch.zeros([1], dtype=torch.uint8) res_data = torch.where(cond, data1, data2) res_mask = torch.where(cond, mask1, mask2) res_dims = torch.where(cond, dims1, dims2) @@ -304,7 +304,7 @@ def batch_unsqueeze(data, mask, dims, dim_): @torch.jit.script def batch_argmax(data, mask, dims, dim_, keepdim_): dim = int(dim_) - keepdim = int(keepdim_) + keepdim = bool(keepdim_) # if dim == 0: # raise ValueError("cannot do argmax along batch_dim") batch_size = data.size(0) @@ -338,8 +338,8 @@ def batch_argmax(data, mask, dims, dim_, keepdim_): def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_): k = int(k_) dim = int(dim_) - largest = int(largest_) - sorted = int(sorted_) + largest = bool(largest_) + sorted = bool(sorted_) # if dim == 0: # raise ValueError("cannot do topk along batch_dim") batch_size 
= data.size(0) |