author     David Riazati <davidriazati@fb.com>    2018-10-03 12:25:39 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2018-10-03 12:40:03 -0700
commit     d1ac1eba3b53f67d8d12eb20002b06893a2a4d2e (patch)
tree       50f41b71264f05355a8a255fda94b582406c45aa /torch
parent     c029c839a1900cb6caa0cd32c3db76fbe56460a0 (diff)
download   pytorch-d1ac1eba3b53f67d8d12eb20002b06893a2a4d2e.tar.gz
           pytorch-d1ac1eba3b53f67d8d12eb20002b06893a2a4d2e.tar.bz2
           pytorch-d1ac1eba3b53f67d8d12eb20002b06893a2a4d2e.zip
Add `bool` type to IR (#11834)
Summary:
This PR adds a bool type to `IValue` and puts it into place.
* changes conds for `prim::If` and `prim::Loop` to use `bool` type
* changes operators that take `bool`s to match their native ops
* fixes ambiguous `aten` ops `aten::std` and `aten::var`
* fixes tests in `test_jit.py TestJitGenerated`
```
'test_std_dim',
'test_std_dim_1d',
'test_std_dim_1d_neg0',
'test_std_dim_neg0',
'test_var_dim',
'test_var_dim_1d',
'test_var_dim_1d_neg0',
'test_var_dim_neg0'
```
* adds `prim::BoolToTensor` and `prim::TensorToBool`

apaszke zdevito

Pull Request resolved: https://github.com/pytorch/pytorch/pull/11834
Differential Revision: D9928570
Pulled By: driazati
fbshipit-source-id: 373c53df2f1a8ffa9e33d9a517002fbeef25f3eb
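To make the change concrete, here is a minimal TorchScript sketch (not part of this commit, written against the post-PR behavior): the `bool(...)` cast on a tensor lowers to `prim::TensorToBool`, and the `prim::If` condition carries `BoolType` rather than `IntType`.

```python
import torch

@torch.jit.script
def relu_or_neg(x):
    # The tensor comparison `x.sum() > 0` is explicitly cast to a Python
    # bool; after this PR that cast lowers to prim::TensorToBool and the
    # prim::If condition below is typed as bool instead of int.
    if bool(x.sum() > 0):
        return torch.relu(x)
    return -x

# Inspecting the graph should show a bool-typed value feeding prim::If.
print(relu_or_neg.graph)
```

Before this PR, the same program would have represented the condition as an `int`, which is what the `emitCond` and interpreter jump changes below replace.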
Diffstat (limited to 'torch')
-rw-r--r--   torch/csrc/jit/attributes.h                        1
-rw-r--r--   torch/csrc/jit/autodiff.cpp                        2
-rw-r--r--   torch/csrc/jit/constants.cpp                      19
-rw-r--r--   torch/csrc/jit/export.cpp                          2
-rw-r--r--   torch/csrc/jit/import.cpp                          2
-rw-r--r--   torch/csrc/jit/interned_strings.h                  2
-rw-r--r--   torch/csrc/jit/interpreter.cpp                    27
-rw-r--r--   torch/csrc/jit/ir.cpp                             10
-rw-r--r--   torch/csrc/jit/ir.h                               14
-rw-r--r--   torch/csrc/jit/operator.cpp                        7
-rw-r--r--   torch/csrc/jit/passes/erase_number_types.cpp       7
-rw-r--r--   torch/csrc/jit/passes/peephole.cpp                 2
-rw-r--r--   torch/csrc/jit/passes/remove_expands.cpp           3
-rw-r--r--   torch/csrc/jit/passes/shape_analysis.cpp         102
-rw-r--r--   torch/csrc/jit/passes/to_batch.cpp                24
-rw-r--r--   torch/csrc/jit/pybind_utils.h                      6
-rw-r--r--   torch/csrc/jit/python_ir.cpp                       4
-rw-r--r--   torch/csrc/jit/register_prim_ops.cpp              82
-rw-r--r--   torch/csrc/jit/script/compiler.cpp                 9
-rw-r--r--   torch/csrc/jit/script/init.cpp                     6
-rw-r--r--   torch/csrc/jit/type.cpp                           14
-rw-r--r--   torch/csrc/jit/type.h                             37
-rw-r--r--   torch/jit/batchop.py                              10
23 files changed, 274 insertions(+), 118 deletions(-)
diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h
index 2610643918..af3a16393f 100644
--- a/torch/csrc/jit/attributes.h
+++ b/torch/csrc/jit/attributes.h
@@ -167,6 +167,7 @@ struct Attributes {
const Kind##Attr::ValueType& method(Symbol name) const { \
return get<Kind##Attr>(name); \
}
+
CREATE_ACCESSOR(Float,f)
CREATE_ACCESSOR(Floats,fs)
CREATE_ACCESSOR(String,s)
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index a59c856eab..46840fc99a 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -280,7 +280,7 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
} else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))};
- } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) {
+ } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
const auto& input_sizes = inputs.at(0).sizes();
if (input_sizes.size() == 0)
return {grads.at(0).sum(), nullptr, nullptr};
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index 1633ac0de4..c62cdbd1fa 100644
--- a/torch/csrc/jit/constants.cpp
+++ b/torch/csrc/jit/constants.cpp
@@ -27,6 +27,13 @@ Value* insertConstant(
} else if(val.isDouble()) {
n->f_(attr::value, val.toDouble());
n->output()->setType(FloatType::get());
+ } else if (val.isBool()) {
+ n->i_(attr::value, val.toBool());
+ n->output()->setType(BoolType::get());
+ } else if (val.isBoolList()) {
+ auto bool_list = val.toBoolList()->elements();
+ n->is_(attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
+ n->output()->setType(ListType::ofBools());
} else if(val.isIntList()) {
n->is_(attr::value, val.toIntList()->elements());
n->output()->setType(ListType::ofInts());
@@ -64,6 +71,12 @@ RegisterOperators reg({
stack.push_back(t);
return 0;
};
+ } else if (type->isSubtypeOf(BoolType::get())) {
+ bool b = node->i(attr::value);
+ return [b](Stack& stack) {
+ push(stack, b);
+ return 0;
+ };
} else if (
type->isSubtypeOf(NumberType::get()) &&
node->kindOf(attr::value) == AttributeKind::i) {
@@ -86,6 +99,12 @@ RegisterOperators reg({
push(stack, is);
return 0;
};
+ } else if(type->isSubtypeOf(ListType::ofBools())) {
+ auto bs = node->is(attr::value);
+ return [bs](Stack& stack) {
+ push(stack, bs);
+ return 0;
+ };
} else if(type->isSubtypeOf(ListType::ofTensors())) {
auto ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor {
return autograd::make_variable(t);
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 973780b7d7..574dae168a 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -559,6 +559,8 @@ void ModuleEncoder::EncodeTypeInfo(
type_proto->set_denotation("FloatType");
} else if (kind == TypeKind::IntType) {
type_proto->set_denotation("IntType");
+ } else if (kind == TypeKind::BoolType) {
+ type_proto->set_denotation("BoolType");
} else if (kind == TypeKind::NoneType) {
type_proto->set_denotation("NoneType");
} else if (kind == TypeKind::GeneratorType) {
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 4574addb3a..45053a81db 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -257,6 +257,8 @@ TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
return FloatType::get();
} else if (kind == "IntType") {
return IntType::get();
+ } else if (kind == "BoolType") {
+ return BoolType::get();
} else if (kind == "NoneType") {
return NoneType::get();
} else if (kind == "GeneratorType") {
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index b4e6b7c139..32544aa682 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -46,9 +46,11 @@ namespace torch { namespace jit {
_(prim, TupleUnpack) \
_(prim, ListConstruct) \
_(prim, ListUnpack) \
+ _(prim, BoolToTensor) \
_(prim, NumToTensor) \
_(prim, TensorToNum) \
_(prim, ImplicitTensorToNum) \
+ _(prim, TensorToBool) \
_(prim, IntToFloat) \
_(prim, FloatToInt) \
_(prim, StringToFloat) \
diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp
index 14e7fab54d..065054a0fc 100644
--- a/torch/csrc/jit/interpreter.cpp
+++ b/torch/csrc/jit/interpreter.cpp
@@ -61,14 +61,14 @@ Value* createTripCountConjunctiveCondition(
// Emit initial comparison -- initial_trip_count < max_trip_count
Value* initial_comparison_value =
g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1))
- ->output()->setType(IntType::get());
+ ->output()->setType(BoolType::get());
// Replace initial condition with logical `and` of trip count and
// initial condition
Value* new_cond =
g->insertNode(
g->create(aten::__and__, {initial_comparison_value, cond}, 1))
- ->output()->setType(IntType::get());
+ ->output()->setType(BoolType::get());
return new_cond;
}
@@ -388,30 +388,29 @@ struct CodeImpl {
CodeImpl(std::shared_ptr<Graph>& graph_)
: preprocess(*graph_) {
graph = preprocess.graph;
- // std::cout << "into code graph:\n" << *graph << "\n";
insertNodesFromBlock(graph->block());
}
- // jump when input is 0
- void createJumpZ(int from_inst, int to_inst) {
+ // jump when input is false
+ void createJumpFalse(int from_inst, int to_inst) {
auto & inst = instructions[from_inst];
JIT_ASSERT(inst.debug_name == prim::Placeholder);
auto offset = relativeJump(from_inst, to_inst);
inst.callback = [offset](Stack & stack) {
- auto t = pop(stack).toInt();
- return (t == 0) ? offset : 0;
+ auto t = pop(stack).toBool();
+ return t ? 0 : offset;
};
inst.debug_name = prim::JumpZ;
}
- // jump when input is not 0
- void createJumpNZ(int from_inst, int to_inst) {
+ // jump when input is true
+ void createJumpTrue(int from_inst, int to_inst) {
auto & inst = instructions[from_inst];
JIT_ASSERT(inst.debug_name == prim::Placeholder);
auto offset = relativeJump(from_inst, to_inst);
inst.callback = [offset](Stack & stack) {
- auto t = pop(stack).toInt();
- return (t != 0) ? offset : 0;
+ auto t = pop(stack).toBool();
+ return t ? offset : 0;
};
inst.debug_name = prim::JumpNZ;
}
@@ -460,7 +459,7 @@ struct CodeImpl {
insertNodesFromBlock(then_block);
insertAssign(source_location, then_block->outputs(), moveFlags(then_block), node->outputs());
createJump(jump, instructions.size());
- createJumpNZ(cond_branch, then_block_start);
+ createJumpTrue(cond_branch, then_block_start);
} break;
case prim::Loop: {
// o0 = while c i0
@@ -495,8 +494,8 @@ struct CodeImpl {
// after branch: stack: ...
aliasRegistersTo(node->outputs(), body_block->inputs());
- createJumpZ(cond_branch, instructions.size());
- createJumpNZ(cond_branch_end, entry);
+ createJumpFalse(cond_branch, instructions.size());
+ createJumpTrue(cond_branch_end, entry);
} break;
default: {
insertInstruction(node);
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 90451494ba..4e4bb0b68d 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -562,18 +562,18 @@ namespace {
const OperatorSet& nondeterminstic_aten_ops() {
static OperatorSet nondeterministic_ops = {
- "aten::dropout(Tensor input, float p, int train) -> Tensor",
+ "aten::dropout(Tensor input, float p, bool train) -> Tensor",
"aten::_fused_dropout(Tensor self, float p, Generator generator) -> (Tensor, Tensor)",
"aten::_standard_gamma(Tensor self, Generator generator) -> Tensor",
"aten::bernoulli(Tensor self, *, Generator generator) -> Tensor",
"aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor",
- "aten::multinomial(Tensor self, int num_samples, int replacement, *, Generator generator) -> Tensor",
+ "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator generator) -> Tensor",
"aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor",
"aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor",
"aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor",
"aten::poisson(Tensor self, Generator generator) -> Tensor",
- "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor",
- "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor",
+ "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
+ "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
"aten::rand(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
"aten::rand_like(Tensor self) -> Tensor",
"aten::rand_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
@@ -598,7 +598,7 @@ bool Node::isNondeterministic() const {
return false;
}
// Dropout with train = False is deterministic
- if (matches("aten::dropout(Tensor input, float p, int train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) {
+ if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) {
return false;
}
return true;
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 0bb5c899c7..6fdb49e7d9 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1050,6 +1050,15 @@ public:
result->output()->setType(CompleteTensorType::fromNumberType(typ));
return result;
}
+ Node* createBoolToTensor(Value* value) {
+ auto typ = value->type();
+ Node * result = create(prim::BoolToTensor, {value});
+ if (!typ->isSubtypeOf(BoolType::get())) {
+ AT_ERROR("Cannot create bool type from ", typ->str());
+ }
+ result->output()->setType(CompleteTensorType::fromBoolType());
+ return result;
+ }
Node* createTensorToNum(const TypePtr& type, Value* value) {
auto* result = create(prim::TensorToNum, {value});
result->output()->setType(type);
@@ -1060,6 +1069,11 @@ public:
result->output()->setType(type);
return result;
}
+ Node* createTensorToBool(Value* value) {
+ auto* result = create(prim::TensorToBool, {value});
+ result->output()->setType(BoolType::get());
+ return result;
+ }
Node* createIntToFloat(Value* value) {
JIT_ASSERT(*value->type() == *IntType::get());
auto* result = create(prim::IntToFloat, {value});
diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp
index d701b536c4..818a58c965 100644
--- a/torch/csrc/jit/operator.cpp
+++ b/torch/csrc/jit/operator.cpp
@@ -57,7 +57,7 @@ struct SchemaParser {
{"str", StringType::get() },
{"float", FloatType::get() },
{"int", IntType::get() },
- {"bool", IntType::get() }, // TODO: add separate bool type
+ {"bool", BoolType::get() },
{"World", WorldType::get() },
};
auto tok = L.expect(TK_IDENT);
@@ -162,6 +162,10 @@ struct SchemaParser {
return fmap(vs, [](IValue v) {
return v.toInt();
});
+ case TypeKind::BoolType:
+ return fmap(vs, [](IValue v) {
+ return v.toBool();
+ });
default:
throw ErrorReport(range) << "lists are only supported for float or int types.";
}
@@ -191,6 +195,7 @@ struct SchemaParser {
} break;
case TypeKind::NumberType:
case TypeKind::IntType:
+ case TypeKind::BoolType:
case TypeKind::FloatType:
arg.default_value = parseSingleConstant(arg.type->kind());
break;
diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp
index 7d2ba5bfeb..115d4b6a68 100644
--- a/torch/csrc/jit/passes/erase_number_types.cpp
+++ b/torch/csrc/jit/passes/erase_number_types.cpp
@@ -13,13 +13,16 @@ static void EraseNumberTypesOnBlock(Block* block) {
case prim::Constant: {
// remove primitive constants, replacing with tensor equivalent
// ONNX does not support non-tensor constants
- if(it->output()->type()->isSubtypeOf(NumberType::get())) {
+ if (it->output()->type()->isSubtypeOf(NumberType::get()) ||
+ it->output()->type()->isSubtypeOf(BoolType::get())) {
auto s = *constant_as<at::Scalar>(it->output());
WithInsertPoint guard(*it);
Value* r = block->owningGraph()->insertConstant(scalar_to_tensor(s));
it->output()->replaceAllUsesWith(r);
}
} break;
+ case prim::TensorToBool:
+ case prim::BoolToTensor:
case prim::TensorToNum:
case prim::ImplicitTensorToNum:
case prim::NumToTensor: {
@@ -30,6 +33,8 @@ static void EraseNumberTypesOnBlock(Block* block) {
for(auto o : it->outputs()) {
if (o->type()->isSubtypeOf(NumberType::get())) {
o->setType(CompleteTensorType::fromNumberType(o->type()));
+ } else if (o->type()->isSubtypeOf(BoolType::get())) {
+ o->setType(CompleteTensorType::fromBoolType());
}
}
} break;
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index 176166218f..7f2312696e 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -24,7 +24,7 @@ void PeepholeOptimize(Block * block) {
// XXX: remember that if you want to simplify an expression by combining multiple nodes
// into a different one, then you need to check that they all belong to the given block
- if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor",
+ if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
/*with_const=*/attr::size)) {
// x.expand(x.size()) == x
if (auto input_type = node->namedInput(attr::self)->type()->cast<CompleteTensorType>()) {
diff --git a/torch/csrc/jit/passes/remove_expands.cpp b/torch/csrc/jit/passes/remove_expands.cpp
index 93d53e5481..1c67e68507 100644
--- a/torch/csrc/jit/passes/remove_expands.cpp
+++ b/torch/csrc/jit/passes/remove_expands.cpp
@@ -8,7 +8,8 @@ static void RemoveExpands(Block* block) {
++it) {
for (auto sub : it->blocks())
RemoveExpands(sub);
- if (it->kind() == aten::expand && it->get<int64_t>(attr::implicit) != static_cast<int64_t>(0)) {
+
+ if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) {
it->output()->replaceAllUsesWith(it->namedInput(attr::self));
it.destroyCurrent();
}
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index eedc7fd0a8..af5e8cd96f 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -120,7 +120,7 @@ void broadcastBinary(Node *node, std::vector<CompleteTensorTypePtr>& types, size
Node *expand = graph->create(aten::expand,
{node->inputs().at(input_idx),
graph->insertConstant(expected_size),
- graph->insertConstant(0)})
+ graph->insertConstant(false)})
->insertBefore(node);
PropagateShapeOnNode(expand);
node->replaceInput(input_idx, expand->output());
@@ -441,12 +441,12 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::clamp(Tensor self, Scalar min, Scalar max) -> Tensor",
"aten::clamp_max(Tensor self, Scalar max) -> Tensor",
"aten::clamp_min(Tensor self, Scalar min) -> Tensor",
- "aten::alpha_dropout(Tensor input, float p, int train) -> Tensor",
+ "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
"aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor",
"aten::cos(Tensor self) -> Tensor",
"aten::cosh(Tensor self) -> Tensor",
"aten::digamma(Tensor self) -> Tensor",
- "aten::dropout(Tensor input, float p, int train) -> Tensor",
+ "aten::dropout(Tensor input, float p, bool train) -> Tensor",
"aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
"aten::erf(Tensor self) -> Tensor",
"aten::erfc(Tensor self) -> Tensor",
@@ -462,8 +462,8 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::floor(Tensor self) -> Tensor",
"aten::frac(Tensor self) -> Tensor",
"aten::flip(Tensor self, int[] dims) -> Tensor",
- "aten::feature_alpha_dropout(Tensor input, float p, int train) -> Tensor",
- "aten::feature_dropout(Tensor input, float p, int train) -> Tensor",
+ "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
+ "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
"aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
"aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
"aten::glu(Tensor self, int dim) -> Tensor",
@@ -479,7 +479,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
"aten::round(Tensor self) -> Tensor",
- "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor",
+ "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
"aten::rsqrt(Tensor self) -> Tensor",
"aten::selu(Tensor self) -> Tensor",
"aten::sigmoid(Tensor self) -> Tensor",
@@ -645,7 +645,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
// Knowing the type and device of weights or biases is usually enough to
// infer the output type.
static const register_formula_for nn_ops_first_input_preserving {{
- "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, int training, float momentum, float eps, int cudnn_enabled) -> Tensor",
+ "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
"aten::conv1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
"aten::conv2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
"aten::conv3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
@@ -653,16 +653,16 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::conv_transpose1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
"aten::conv_transpose2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
"aten::conv_transpose3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
- "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int transposed, int[] output_padding, int groups) -> Tensor",
+ "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
"aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
"aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
"aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor",
- "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
- "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
- "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int ceil_mode, int count_include_pad) -> Tensor",
- "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
- "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
- "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int ceil_mode) -> Tensor",
+ "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+ "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+ "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+ "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
+ "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
+ "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
"aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor",
"aten::max_unpool3d(Tensor self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor",
"aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor",
@@ -670,12 +670,12 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::replication_pad1d(Tensor self, int[] padding) -> Tensor",
"aten::replication_pad2d(Tensor self, int[] padding) -> Tensor",
"aten::replication_pad3d(Tensor self, int[] padding) -> Tensor",
- "aten::upsample_bilinear2d(Tensor self, int[] output_size, int align_corners) -> Tensor",
- "aten::upsample_linear1d(Tensor self, int[] output_size, int align_corners) -> Tensor",
+ "aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
+ "aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
"aten::upsample_nearest1d(Tensor self, int[] output_size) -> Tensor",
"aten::upsample_nearest2d(Tensor self, int[] output_size) -> Tensor",
"aten::upsample_nearest3d(Tensor self, int[] output_size) -> Tensor",
- "aten::upsample_trilinear3d(Tensor self, int[] output_size, int align_corners) -> Tensor",
+ "aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
"aten::prelu(Tensor self, Tensor weight) -> Tensor",
}, [](Node * node) -> type_vec_t {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
@@ -702,10 +702,10 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
"aten::mean(Tensor self) -> Tensor",
"aten::median(Tensor self) -> Tensor",
"aten::norm(Tensor self, Scalar p) -> Tensor",
- "aten::std(Tensor self, int unbiased) -> Tensor",
+ "aten::std(Tensor self, bool unbiased) -> Tensor",
"aten::sum(Tensor self) -> Tensor",
"aten::trace(Tensor self) -> Tensor",
- "aten::var(Tensor self, int unbiased) -> Tensor",
+ "aten::var(Tensor self, bool unbiased) -> Tensor",
"aten::all(Tensor self) -> Tensor",
"aten::any(Tensor self) -> Tensor",
}, [](Node * node) -> type_vec_t {
@@ -735,7 +735,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
static const auto multidim_reduce_with_postprocess =
[](Node * node, size_t num_reduced_dim, bool upcast_integer) -> type_vec_t {
- auto maybe_keepdim = node->get<int64_t>(attr::keepdim);
+ auto maybe_keepdim = node->get<bool>(attr::keepdim);
if (!maybe_keepdim) return {};
if (auto type = node->input(0)->type()->cast<TensorType>()) {
if (upcast_integer && !at::isFloatingType(type->scalarType())) {
@@ -760,24 +760,24 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
// - First input should be the only tensor input
// - Has a bool keepdim argument
static const register_formula_for dim_reduce_ops {{
- "aten::argmax(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::argmin(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::max_values(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::min_values(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::mean(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::norm(Tensor self, Scalar p, int dim, int keepdim) -> Tensor",
- "aten::std(Tensor self, int dim, int unbiased, int keepdim) -> Tensor",
- "aten::var(Tensor self, int dim, int unbiased, int keepdim) -> Tensor",
- "aten::logsumexp(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::all(Tensor self, int dim, int keepdim) -> Tensor",
- "aten::any(Tensor self, int dim, int keepdim) -> Tensor",
+ "aten::argmax(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::mean(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor",
+ "aten::std(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
+ "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
+ "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
+ "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
// Ops returning indices as second output
- "aten::kthvalue(Tensor self, int k, int dim, int keepdim) -> (Tensor, Tensor)",
- "aten::max(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
- "aten::min(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
- "aten::median(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
- "aten::mode(Tensor self, int dim, int keepdim) -> (Tensor, Tensor)",
+ "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
+ "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+ "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+ "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
+ "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
}, [](Node * node) -> type_vec_t {
// NB: Note that while this function is generally meant to be used with ops that
// have a single output, we will fix up its return right below.
@@ -798,7 +798,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
// - First input should be the only tensor input
// - has a bool keepdim argument
static const register_formula_for dim_reduce_ops_with_integer_upcast {{
- "aten::prod(Tensor self, int dim, int keepdim) -> Tensor",
+ "aten::prod(Tensor self, int dim, bool keepdim) -> Tensor",
}, [](Node * node) -> type_vec_t {
return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/1, /*integer_upcast=*/true);
}};
@@ -812,7 +812,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
// Additionally:
// - has bool keepdim and int[] dim arguments
static const register_formula_for multidim_reduce_ops_with_integer_upcast {{
- "aten::sum(Tensor self, int[] dim, int keepdim) -> Tensor",
+ "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
}, [](Node * node) -> type_vec_t {
if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
// TODO: can dim contain duplicates?
@@ -900,14 +900,14 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
}
};
static const register_formula_for cast_ops {{
- "aten::_cast_Byte(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Char(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Double(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Float(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Half(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Int(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Long(Tensor self, int non_blocking) -> Tensor",
- "aten::_cast_Short(Tensor self, int non_blocking) -> Tensor",
+ "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
+ "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
}, [](Node * node) -> type_vec_t {
if (auto type = node->namedInput(attr::self)->type()->cast<TensorType>()) {
return {type->toScalarType(get_cast_scalar_type(node))};
@@ -1003,7 +1003,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
}
return true;
}
- } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, int scale_grad_by_freq, int sparse) -> Tensor")) {
+ } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
auto weight_type = input_type(0);
auto indices_type = input_type(1);
if (weight_type && indices_type) {
@@ -1044,7 +1044,7 @@ bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
node->matches("aten::reshape_as(Tensor self, Tensor other) -> Tensor")) {
return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
} else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
- node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor") ||
+ node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
node->matches("aten::as_strided(Tensor self, int[] size, int[] stride) -> Tensor") ||
node->matches("aten::as_strided(Tensor self, int[] size, int[] stride, int storage_offset) -> Tensor")) {
return reshape_prop(node, attr::size, tensor_types);
@@ -1189,12 +1189,12 @@ bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands,
} else if (node->matches("aten::sum(Tensor self) -> Tensor")) {
node->output()->setType(tensor_types.at(0)->withSizes({}));
return true;
- } else if (node->matches("aten::sum(Tensor self, int[] dim, int keepdim) -> Tensor",
+ } else if (node->matches("aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
/*with_const=*/{attr::dim, attr::keepdim})) {
auto & tp = tensor_types.at(0);
auto sizes = tp->sizes();
auto dims = node->get<std::vector<int64_t>>(attr::dim).value();
- bool keepdim = node->get<int64_t>(attr::keepdim).value();
+ bool keepdim = node->get<bool>(attr::keepdim).value();
std::reverse(dims.begin(), dims.end());
for (int64_t dim : dims) {
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
@@ -1262,7 +1262,7 @@ bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands,
node->output()->setType(tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes()));
}
return true;
- } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor",
+ } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
/*with_const=*/attr::size)) {
auto tp = tensor_types.at(0);
std::vector<int64_t> sizes, strides;
diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp
index 0d56ca2255..f1edf80d41 100644
--- a/torch/csrc/jit/passes/to_batch.cpp
+++ b/torch/csrc/jit/passes/to_batch.cpp
@@ -41,6 +41,10 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
auto to_tensor_node = res_graph->createNumToTensor(input);
res_graph->insertNode(to_tensor_node);
new_inputs[i] = to_tensor_node->output();
+ } else if(input->type() == BoolType::get()) {
+ auto to_tensor_node = res_graph->createBoolToTensor(input);
+ res_graph->insertNode(to_tensor_node);
+ new_inputs[i] = to_tensor_node->output();
}
}
@@ -58,8 +62,11 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
else if(n->outputs()[0]->type() == FloatType::get()){
to_scalar_node = res_graph->createTensorToNum(FloatType::get(), outputs[0]);
}
+ else if(n->outputs()[0]->type() == BoolType::get()){
+ to_scalar_node = res_graph->createTensorToBool(outputs[0]);
+ }
else{
- throw std::runtime_error("NYI: scalar type other than int, float is not supported yet");
+ throw std::runtime_error("NYI: scalar types other than int, float, and bool are not supported yet");
}
res_graph->insertNode(to_scalar_node);
rn_env[n->outputs()[0]] = to_scalar_node->output();
@@ -348,9 +355,10 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
if(cond_is_tensor){
auto cond = batch_map.at(n->inputs()[1]);
auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
- auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]);
- res_graph->insertNode(to_int_node);
- rn_env[n->inputs()[1]] = to_int_node->output();
+ auto to_bool_node =
+ res_graph->createTensorToBool(cond_any[0]);
+ res_graph->insertNode(to_bool_node);
+ rn_env[n->inputs()[1]] = to_bool_node->output();
}
for(size_t i = 2; i < n->inputs().size(); i++){
auto input = n->inputs()[i];
@@ -432,9 +440,10 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
if(cond_is_tensor){
auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
- auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]);
- res_graph->insertNode(to_int_node);
- loop_block->insertOutput(0, to_int_node->output());
+ auto to_bool_node =
+ res_graph->createTensorToBool(cond_any[0]);
+ res_graph->insertNode(to_bool_node);
+ loop_block->insertOutput(0, to_bool_node->output());
for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
loop_block->insertOutput(i + 1, cond[i]);
}
@@ -491,6 +500,7 @@ void ToBatch::toBatch(Block* block, Block* res_block) {
case prim::NumToTensor:
visitNumToTensor(n, block, res_block);
break;
+ case prim::TensorToBool:
case prim::TensorToNum:
visitTensorToNum(n, block, res_block);
break;
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 192cabca3f..bd81a8cc05 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -107,6 +107,8 @@ inline IValue toIValue(py::handle obj, const TypePtr& type) {
return py::cast<int64_t>(obj);
case TypeKind::NoneType:
return {};
+ case TypeKind::BoolType:
+ return py::cast<bool>(obj);
case TypeKind::TupleType: {
if(!PyTuple_Check(obj.ptr()))
throw py::cast_error(); // note: the py::cast does not throw cast_error
@@ -196,10 +198,14 @@ inline py::object toPyObject(IValue&& ivalue) {
return py::cast(ivalue.toDouble());
} else if (ivalue.isInt()) {
return py::cast(ivalue.toInt());
+ }else if (ivalue.isBool()) {
+ return py::cast(ivalue.toBool());
} else if (ivalue.isIntList()) {
return py::cast(ivalue.toIntListRef());
} else if (ivalue.isDoubleList()) {
return py::cast(ivalue.toDoubleListRef());
+ } else if (ivalue.isBoolList()) {
+ return py::cast(ivalue.toBoolListRef());
} else if (ivalue.isTensorList()) {
return py::cast(ivalue.toTensorListRef());
} else if (ivalue.isGenericList()) {
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index ad03ac556c..979b07d369 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -455,6 +455,8 @@ void initPythonIRBindings(PyObject * module_) {
return "StringType";
case TypeKind::GeneratorType:
return "GeneratorType";
+ case TypeKind::BoolType:
+ return "BoolType";
case TypeKind::VarType:
return "VarType";
case TypeKind::WorldType:
@@ -491,6 +493,8 @@ void initPythonIRBindings(PyObject * module_) {
.def_static("get", &FloatType::get);
py::class_<DynamicType, Type, std::shared_ptr<DynamicType>>(m, "DynamicType")
.def_static("get", &DynamicType::get);
+ py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType")
+ .def_static("get", &BoolType::get);
py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType")
.def(py::init([](std::vector<TypePtr> a){ return TupleType::create(a); }))
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index cdea4ab894..c7bf050dc2 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -70,6 +70,17 @@ RegisterOperators reg({
};
}),
Operator(
+ prim::TensorToBool,
+ [](Node* node) -> Operation {
+ return [](Stack& stack) {
+ at::Tensor a;
+ pop(stack, a);
+ at::DeviceGuard guard(a);
+ push(stack, a.item<int64_t>() != 0);
+ return 0;
+ };
+ }),
+ Operator(
prim::TensorToNum,
[](Node* node) -> Operation {
if(node->output()->type() == IntType::get()) {
@@ -124,6 +135,18 @@ RegisterOperators reg({
};
}),
Operator(
+ prim::BoolToTensor,
+ [](Node* node) -> Operation {
+ return [](Stack& stack) {
+ bool b;
+ pop(stack, b);
+ push(
+ stack,
+ autograd::make_variable(at::scalar_to_tensor(b)));
+ return 0;
+ };
+ }),
+ Operator(
prim::IntToFloat,
[](Node* node) -> Operation {
return [](Stack& stack) {
@@ -446,25 +469,25 @@ RegisterOperators reg({
});
// define implementations for primitive number ops
-#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, float_result) \
- Operator( \
- #aten_op "(int a, int b) -> int", \
- [](Node* node) { \
- return [=](Stack& stack) { \
- int64_t a, b; \
- pop(stack, a, b); \
- push(stack, int_op); \
- return 0; \
- }; \
- }), \
- Operator( \
- #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \
- return [=](Stack& stack) { \
- double a, b; \
- pop(stack, a, b); \
- push(stack, float_op); \
- return 0; \
- }; \
+#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
+ Operator( \
+ #aten_op "(int a, int b) -> " #int_result, \
+ [](Node* node) { \
+ return [=](Stack& stack) { \
+ int64_t a, b; \
+ pop(stack, a, b); \
+ push(stack, int_op); \
+ return 0; \
+ }; \
+ }), \
+ Operator( \
+ #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \
+ return [=](Stack& stack) { \
+ double a, b; \
+ pop(stack, a, b); \
+ push(stack, float_op); \
+ return 0; \
+ }; \
}),
#define DEFINE_INT_OP(aten_op, op) \
@@ -477,8 +500,19 @@ RegisterOperators reg({
}; \
}),
-#define DEFINE_BINARY_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, op, float)
-#define DEFINE_COMPARISON_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, op, int)
+#define DEFINE_BINARY_OP(aten_op, op) \
+ DEFINE_GENERIC_OP(aten_op, op, op, int, float)
+#define DEFINE_COMPARISON_OP(aten_op, op) \
+ DEFINE_GENERIC_OP(aten_op, op, op, bool, bool)
+#define DEFINE_BOOL_OP(aten_op, op) \
+ Operator(#aten_op "(bool a, bool b) -> bool", [](Node* node) { \
+ return [=](Stack& stack) { \
+ bool a, b; \
+ pop(stack, a, b); \
+ push(stack, op); \
+ return 0; \
+ }; \
+ }),
// Convert an python index (which may be negative) into an index usable for a
// C++ container
@@ -663,7 +697,7 @@ RegisterOperators reg2({
// Pass in two ops for handling int and float separately as % in C++ only works for int
// The modulus calculation is different between C++ and Python (on negative), we preserve
// the python behavior as it's more common and match python syntax, hence the conversion.
- DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), float)
+ DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), int, float)
// TODO: Support python floordiv (//)
// Right now aten::floordiv is only used by loop unrolling
@@ -696,8 +730,8 @@ RegisterOperators reg2({
DEFINE_COMPARISON_OP(aten::le, a <= b)
DEFINE_COMPARISON_OP(aten::ge, a >= b)
- DEFINE_INT_OP(aten::__and__, a&& b)
- DEFINE_INT_OP(aten::__or__, a || b)
+ DEFINE_BOOL_OP(aten::__and__, a && b)
+ DEFINE_BOOL_OP(aten::__or__, a || b)
Operator("aten::_construct_empty_int_list() -> int[]",
[](Node* node) -> Operation {
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 28aa735fc3..474722f536 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -74,6 +74,8 @@ static Value* typeCast(const SourceRange& loc, Value* value, TypePtr dst) {
n = graph.createNumToTensor(value);
} else if (dst->isSubtypeOf(NumberType::get()) && orig->isSubtypeOf(DynamicType::get())) {
n = graph.createTensorToNum(dst, value);
+ } else if (dst->isSubtypeOf(BoolType::get()) && orig->isSubtypeOf(DynamicType::get())) {
+ n = graph.createTensorToBool(value);
} else if(dst->isSubtypeOf(IntType::get()) && orig->isSubtypeOf(FloatType::get())) {
n = graph.createFloatToInt(value);
} else if(dst->isSubtypeOf(FloatType::get()) && orig->isSubtypeOf(IntType::get())) {
@@ -324,7 +326,7 @@ struct Environment {
{"print", std::make_shared<PrintValue>()},
{"float", std::make_shared<CastValue>(FloatType::get())},
{"int", std::make_shared<CastValue>(IntType::get())},
- {"bool", std::make_shared<CastValue>(IntType::get())},
+ {"bool", std::make_shared<CastValue>(BoolType::get())},
// todo(zach): remove when we can correctly export torch.full via ONNX
// or we have implicit conversion that can convert numbers to tensors
{"_to_tensor", std::make_shared<CastValue>(DynamicType::get()) },
@@ -1048,9 +1050,9 @@ private:
Value* emitCond(Expr cond) {
Value* v = emitExpr(cond);
- if (!v->type()->isSubtypeOf(IntType::get())) {
+ if (!v->type()->isSubtypeOf(BoolType::get())) {
ErrorReport error(cond);
- error << "expected an integer expression for condition but found "
+ error << "expected a boolean expression for condition but found "
<< v->type()->str();
if (v->type()->isSubtypeOf(DynamicType::get())) {
error << ", to use a tensor in a boolean"
@@ -1928,6 +1930,7 @@ const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
{"Tensor", DynamicType::get()},
{"int", IntType::get()},
{"float", FloatType::get()},
+ {"bool", BoolType::get()},
};
return map;
}
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 4c7df820b1..fec69ac317 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -278,12 +278,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
// f = f + 1
auto& g = *m.graph();
if (is_constant) {
- if (py::isinstance<py::int_>(obj)) {
+ if (py::isinstance<py::bool_>(obj)) {
+ return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
+ } else if (py::isinstance<py::int_>(obj)) {
return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
} else if (py::isinstance<py::float_>(obj)) {
return toSimple(g.insertConstant(py::cast<float>(obj), loc));
- } else if (py::isinstance<py::bool_>(obj)) {
- return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
} else if (THPDevice_Check(obj.ptr())) {
auto device = (THPDevice*)obj.ptr();
std::vector<int64_t> v = {static_cast<int64_t>(device->device.type()),
diff --git a/torch/csrc/jit/type.cpp b/torch/csrc/jit/type.cpp
index 855adad429..ce3cd8c1f2 100644
--- a/torch/csrc/jit/type.cpp
+++ b/torch/csrc/jit/type.cpp
@@ -46,6 +46,8 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << "float";
} else if(t.kind() == TypeKind::IntType) {
out << "int";
+ } else if(t.kind() == TypeKind::BoolType) {
+ out << "bool";
} else if(t.kind() == TypeKind::ListType) {
auto prim = t.cast<ListType>()->getElementType();
out << *prim << "[]";
@@ -85,6 +87,10 @@ FloatTypePtr FloatType::get() {
static auto value = FloatType::create();
return value;
}
+BoolTypePtr BoolType::get() {
+ static auto value = BoolType::create();
+ return value;
+}
NoneTypePtr NoneType::get() {
static auto value = NoneType::create();
return value;
@@ -113,6 +119,10 @@ ListTypePtr ListType::ofFloats() {
static auto value = ListType::create(FloatType::get());
return value;
}
+ListTypePtr ListType::ofBools() {
+ static auto value = ListType::create(BoolType::get());
+ return value;
+}
TypePtr inferTypeFrom(const IValue& value) {
if (value.isTensor()) {
@@ -121,12 +131,16 @@ TypePtr inferTypeFrom(const IValue& value) {
return FloatType::get();
} else if (value.isInt()) {
return IntType::get();
+ } else if (value.isBool()) {
+ return BoolType::get();
} else if (value.isString()) {
return StringType::get();
} else if (value.isIntList()) {
return ListType::ofInts();
} else if (value.isTensorList()) {
return ListType::ofTensors();
+ } else if (value.isBoolList()) {
+ return ListType::ofBools();
} else if (value.isDoubleList()) {
return ListType::ofFloats();
} else if (value.isTuple()) {
diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h
index 49748de239..69133f6be9 100644
--- a/torch/csrc/jit/type.h
+++ b/torch/csrc/jit/type.h
@@ -27,6 +27,7 @@ _(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
+_(BoolType) \
_(VarType) \
_(WorldType) \
@@ -343,6 +344,7 @@ struct TORCH_API CompleteTensorType : public TensorType {
return prod;
}
static TypePtr fromNumberType(TypePtr typ);
+ static TypePtr fromBoolType();
private:
CompleteTensorType(const at::Tensor& tensor)
@@ -438,6 +440,7 @@ struct TORCH_API ListType : public Type {
static ListTypePtr ofTensors();
static ListTypePtr ofInts();
static ListTypePtr ofFloats();
+ static ListTypePtr ofBools();
static const TypeKind Kind = TypeKind::ListType;
private:
@@ -605,6 +608,31 @@ private:
: Type(TypeKind::IntType) {}
};
+struct BoolType;
+using BoolTypePtr = std::shared_ptr<BoolType>;
+// This node represents a Python bool value
+struct TORCH_API BoolType : public Type {
+ template<typename ... T>
+ static BoolTypePtr create( T&& ... all ) {
+ return BoolTypePtr(new BoolType(std::forward<T>(all)... ));
+ }
+ bool operator==(const Type& rhs) const override {
+ return rhs.kind() == kind();
+ }
+ std::string str() const override {
+ return "bool";
+ }
+ bool isSubtypeOf(const TypePtr rhs) const override {
+ return *this == *rhs || rhs->kind() == TypeKind::BoolType;
+ }
+ static const TypeKind Kind = TypeKind::BoolType;
+ // global singleton
+ static BoolTypePtr get();
+private:
+ BoolType()
+ : Type(TypeKind::BoolType) {}
+};
+
struct StringType;
using StringTypePtr = std::shared_ptr<StringType>;
// This node represents a Python string value
@@ -728,10 +756,17 @@ inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) {
return CompleteTensorType::create(at::kLong, -1, {});
} else if (typ->isSubtypeOf(FloatType::get())) {
return CompleteTensorType::create(at::kFloat, -1, {});
+ } else if (typ->isSubtypeOf(BoolType::get())) {
+ return CompleteTensorType::create(at::kLong, -1, {});
}
AT_ERROR("unknown number type", typ->str());
}
+inline TypePtr CompleteTensorType::fromBoolType() {
+ return CompleteTensorType::create(at::kLong, -1, {});
+}
+
+
// Attempt to find the correct supertype of t1 and t2. If none is found then
// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2,
// then t2 will be returned (and vice versa).
@@ -755,7 +790,7 @@ TypePtr getTypePtr() {
template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); }
template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); }
template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); }
-template<> inline TypePtr getTypePtr<bool>() { return IntType::get(); }
+template<> inline TypePtr getTypePtr<bool>() { return BoolType::get(); }
template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); }
template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); }
template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
diff --git a/torch/jit/batchop.py b/torch/jit/batchop.py
index 229cafbb94..fed8d2a7c6 100644
--- a/torch/jit/batchop.py
+++ b/torch/jit/batchop.py
@@ -214,8 +214,8 @@ def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
@torch.jit.script
-def batch_where_scalar(cond_, data1, mask1, dims1, data2, mask2, dims2):
- cond = torch.zeros([1], dtype=torch.uint8) * cond_
+def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
+ cond = torch.zeros([1], dtype=torch.uint8)
res_data = torch.where(cond, data1, data2)
res_mask = torch.where(cond, mask1, mask2)
res_dims = torch.where(cond, dims1, dims2)
@@ -304,7 +304,7 @@ def batch_unsqueeze(data, mask, dims, dim_):
@torch.jit.script
def batch_argmax(data, mask, dims, dim_, keepdim_):
dim = int(dim_)
- keepdim = int(keepdim_)
+ keepdim = bool(keepdim_)
# if dim == 0:
# raise ValueError("cannot do argmax along batch_dim")
batch_size = data.size(0)
@@ -338,8 +338,8 @@ def batch_argmax(data, mask, dims, dim_, keepdim_):
def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
k = int(k_)
dim = int(dim_)
- largest = int(largest_)
- sorted = int(sorted_)
+ largest = bool(largest_)
+ sorted = bool(sorted_)
# if dim == 0:
# raise ValueError("cannot do topk along batch_dim")
batch_size = data.size(0)