summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorJames Reed <jamesreed@fb.com>2018-02-18 01:53:13 -0800
committerAdam Paszke <adam.paszke@gmail.com>2018-02-18 09:53:13 +0000
commit5eefe87d4eab296c0eb28394f9daf4659d03d890 (patch)
treed0cb2981645c4beb1af1900cc3364394c412dbbb /torch
parent9193dfd185ac526d0d71c6fd100a52b0b2e10e58 (diff)
downloadpytorch-5eefe87d4eab296c0eb28394f9daf4659d03d890.tar.gz
pytorch-5eefe87d4eab296c0eb28394f9daf4659d03d890.tar.bz2
pytorch-5eefe87d4eab296c0eb28394f9daf4659d03d890.zip
Emit ternary if in script compiler (#5291)
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/script/compiler.cpp32
1 files changed, 30 insertions, 2 deletions
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 7ee58e8692..fa10124c1d 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -222,6 +222,32 @@ struct to_ir {
return save_env;
}
+ std::vector<Value*> emitTernaryIf(const TernaryIf& expr) {
+ Value* cond_value = emitExpr(expr.cond(), 1)[0];
+
+ Node* n = def.graph->insertNode(def.graph->create(kIf, 0));
+ n->addInput(cond_value);
+ auto* true_block = n->addBlock();
+ auto* false_block = n->addBlock();
+
+ auto emit_if_expr = [this](Block* b, const Expr& expr) {
+ environment_stack = std::make_shared<Environment>(b, environment_stack);
+ WithInsertPoint guard(*def.graph, b);
+ Value* out_val = emitExpr(expr, 1)[0];
+ b->registerOutput(out_val);
+
+ environment_stack = environment_stack->next;
+ };
+
+ emit_if_expr(true_block, expr.true_expr());
+ emit_if_expr(false_block, expr.false_expr());
+
+ // Add op outputs
+ auto expr_value = n->addOutput(); // Resulting value
+
+ return {expr_value};
+ }
+
void emitIf(const If& stmt) {
Value* cond_value = emitExpr(stmt.cond(), 1)[0];
@@ -510,8 +536,10 @@ struct to_ir {
} break;
case '.':
// TODO: add support for "."
- case TK_IF_EXPR:
- // TODO: add support for conditional
+ case TK_IF_EXPR: {
+ expectOutputs(tree, output_size, 1);
+ return emitTernaryIf(TernaryIf(tree));
+ } break;
default:
throw ErrorReport(tree) << "NYI: " << tree;
break;