diff options
author | James Reed <jamesreed@fb.com> | 2018-02-18 01:53:13 -0800 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2018-02-18 09:53:13 +0000 |
commit | 5eefe87d4eab296c0eb28394f9daf4659d03d890 (patch) | |
tree | d0cb2981645c4beb1af1900cc3364394c412dbbb /torch | |
parent | 9193dfd185ac526d0d71c6fd100a52b0b2e10e58 (diff) | |
download | pytorch-5eefe87d4eab296c0eb28394f9daf4659d03d890.tar.gz pytorch-5eefe87d4eab296c0eb28394f9daf4659d03d890.tar.bz2 pytorch-5eefe87d4eab296c0eb28394f9daf4659d03d890.zip |
Emit ternary if in script compiler (#5291)
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/script/compiler.cpp | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 7ee58e8692..fa10124c1d 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -222,6 +222,32 @@ struct to_ir { return save_env; } + std::vector<Value*> emitTernaryIf(const TernaryIf& expr) { + Value* cond_value = emitExpr(expr.cond(), 1)[0]; + + Node* n = def.graph->insertNode(def.graph->create(kIf, 0)); + n->addInput(cond_value); + auto* true_block = n->addBlock(); + auto* false_block = n->addBlock(); + + auto emit_if_expr = [this](Block* b, const Expr& expr) { + environment_stack = std::make_shared<Environment>(b, environment_stack); + WithInsertPoint guard(*def.graph, b); + Value* out_val = emitExpr(expr, 1)[0]; + b->registerOutput(out_val); + + environment_stack = environment_stack->next; + }; + + emit_if_expr(true_block, expr.true_expr()); + emit_if_expr(false_block, expr.false_expr()); + + // Add op outputs + auto expr_value = n->addOutput(); // Resulting value + + return {expr_value}; + } + void emitIf(const If& stmt) { Value* cond_value = emitExpr(stmt.cond(), 1)[0]; @@ -510,8 +536,10 @@ struct to_ir { } break; case '.': // TODO: add support for "." - case TK_IF_EXPR: - // TODO: add support for conditional + case TK_IF_EXPR: { + expectOutputs(tree, output_size, 1); + return emitTernaryIf(TernaryIf(tree)); + } break; default: throw ErrorReport(tree) << "NYI: " << tree; break; |