diff options
author | Elias Ellison <eellison@fb.com> | 2019-04-17 16:01:41 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-17 16:06:48 -0700 |
commit | 4371cb5e0193d2eaa8d23673eb153874113eab4e (patch) | |
tree | 47709446d44d0bbfcd03b2459bca357fb594d279 /torch | |
parent | d6b91075dc79af5022206dac730732fd1edcb488 (diff) | |
download | pytorch-4371cb5e0193d2eaa8d23673eb153874113eab4e.tar.gz pytorch-4371cb5e0193d2eaa8d23673eb153874113eab4e.tar.bz2 pytorch-4371cb5e0193d2eaa8d23673eb153874113eab4e.zip |
Cast not expressions to bool (#19361)
Summary:
As part of implicitly casting condition statements, we should be casting not expressions as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19361
Differential Revision: D14984275
Pulled By: eellison
fbshipit-source-id: f8dae64f74777154c25f7a6bcdac03cf44cbb60b
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/script/compiler.cpp | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 33c7168422..d17ce638e2 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -574,10 +574,7 @@ struct to_ir { ConstantPooling(to_clean); } - FunctionSchema emitDef( - const Def& def, - const Self& self, - Block* block) { + FunctionSchema emitDef(const Def& def, const Self& self, Block* block) { auto schema = extractSchemaFromDef(def, self); // TODO need guards on init returning none if (schema.returns().size() == 1) { @@ -630,9 +627,7 @@ struct to_ir { return stack.at(0).toTuple()->elements(); } - std::vector<Argument> parseArgsFromDecl( - const Decl& decl, - const Self& self) { + std::vector<Argument> parseArgsFromDecl(const Decl& decl, const Self& self) { auto params_begin = decl.params().begin(); auto params_end = decl.params().end(); if (self) { @@ -704,9 +699,7 @@ struct to_ir { /*default_value =*/c10::nullopt, /*kwarg_only =*/false)}; } - FunctionSchema extractSchemaFromDef( - const Def& def, - const Self& self) { + FunctionSchema extractSchemaFromDef(const Def& def, const Self& self) { const auto name = def.name().name(); std::vector<Argument> args = parseArgsFromDecl(def.decl(), self); std::vector<Argument> returns = parseReturnFromDecl(def.decl()); @@ -719,7 +712,6 @@ struct to_ir { const Self& self, const FunctionSchema& schema, Block* block) { - std::vector<Argument> arguments; // for schema // inputs auto it = def.decl().params().begin(); @@ -740,7 +732,8 @@ struct to_ir { AT_ASSERT(it != end); const auto& name = (*it).ident().name(); Value* new_input = block->addInput()->setUniqueName(name); - environment_stack->setSugaredVar((*it).ident().range(), name, self(new_input)); + environment_stack->setSugaredVar( + (*it).ident().range(), name, self(new_input)); arguments.emplace_back(name, new_input->type()); ++it; } @@ -2273,7 +2266,6 @@ struct to_ir { case TK_POW: case TK_IS: case TK_ISNOT: - case TK_NOT: case TK_NE: case TK_EQ: case '<': @@ -2301,6 +2293,18 @@ struct to_ir { {}, /*required=*/true); } + case TK_NOT: { + Value* input = emitCond(Expr(tree->trees()[0])); + return emitBuiltinCall( + tree->range(), + *method.graph(), + aten::__not__, + c10::nullopt, + {input}, + {}, + /*required=*/true); + } + case TK_UNARY_MINUS: { return emitNegate(tree); } |