summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorElias Ellison <eellison@fb.com>2019-04-17 16:01:41 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-17 16:06:48 -0700
commit4371cb5e0193d2eaa8d23673eb153874113eab4e (patch)
tree47709446d44d0bbfcd03b2459bca357fb594d279 /torch
parentd6b91075dc79af5022206dac730732fd1edcb488 (diff)
downloadpytorch-4371cb5e0193d2eaa8d23673eb153874113eab4e.tar.gz
pytorch-4371cb5e0193d2eaa8d23673eb153874113eab4e.tar.bz2
pytorch-4371cb5e0193d2eaa8d23673eb153874113eab4e.zip
Cast not expressions to bool (#19361)
Summary: As part of implicitly casting condition statements, we should be casting not expressions as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19361 Differential Revision: D14984275 Pulled By: eellison fbshipit-source-id: f8dae64f74777154c25f7a6bcdac03cf44cbb60b
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/script/compiler.cpp30
1 files changed, 17 insertions, 13 deletions
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 33c7168422..d17ce638e2 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -574,10 +574,7 @@ struct to_ir {
ConstantPooling(to_clean);
}
- FunctionSchema emitDef(
- const Def& def,
- const Self& self,
- Block* block) {
+ FunctionSchema emitDef(const Def& def, const Self& self, Block* block) {
auto schema = extractSchemaFromDef(def, self);
// TODO need guards on init returning none
if (schema.returns().size() == 1) {
@@ -630,9 +627,7 @@ struct to_ir {
return stack.at(0).toTuple()->elements();
}
- std::vector<Argument> parseArgsFromDecl(
- const Decl& decl,
- const Self& self) {
+ std::vector<Argument> parseArgsFromDecl(const Decl& decl, const Self& self) {
auto params_begin = decl.params().begin();
auto params_end = decl.params().end();
if (self) {
@@ -704,9 +699,7 @@ struct to_ir {
/*default_value =*/c10::nullopt,
/*kwarg_only =*/false)};
}
- FunctionSchema extractSchemaFromDef(
- const Def& def,
- const Self& self) {
+ FunctionSchema extractSchemaFromDef(const Def& def, const Self& self) {
const auto name = def.name().name();
std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
std::vector<Argument> returns = parseReturnFromDecl(def.decl());
@@ -719,7 +712,6 @@ struct to_ir {
const Self& self,
const FunctionSchema& schema,
Block* block) {
-
std::vector<Argument> arguments; // for schema
// inputs
auto it = def.decl().params().begin();
@@ -740,7 +732,8 @@ struct to_ir {
AT_ASSERT(it != end);
const auto& name = (*it).ident().name();
Value* new_input = block->addInput()->setUniqueName(name);
- environment_stack->setSugaredVar((*it).ident().range(), name, self(new_input));
+ environment_stack->setSugaredVar(
+ (*it).ident().range(), name, self(new_input));
arguments.emplace_back(name, new_input->type());
++it;
}
@@ -2273,7 +2266,6 @@ struct to_ir {
case TK_POW:
case TK_IS:
case TK_ISNOT:
- case TK_NOT:
case TK_NE:
case TK_EQ:
case '<':
@@ -2301,6 +2293,18 @@ struct to_ir {
{},
/*required=*/true);
}
+ case TK_NOT: {
+ Value* input = emitCond(Expr(tree->trees()[0]));
+ return emitBuiltinCall(
+ tree->range(),
+ *method.graph(),
+ aten::__not__,
+ c10::nullopt,
+ {input},
+ {},
+ /*required=*/true);
+ }
+
case TK_UNARY_MINUS: {
return emitNegate(tree);
}