summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorArtem Volkhin <volkhin@fb.com>2017-08-25 23:56:05 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2017-08-26 00:07:58 -0700
commitd3c8e68004c118b4bd00cd1d33a3fb549ce9a605 (patch)
tree495c0550b1c5ae9f98e0b1499c0fa2ddf5e5ceff /caffe2
parent9f693b39aa747188c500d753de20fbf0db39d0a1 (diff)
downloadpytorch-d3c8e68004c118b4bd00cd1d33a3fb549ce9a605.tar.gz
pytorch-d3c8e68004c118b4bd00cd1d33a3fb549ce9a605.tar.bz2
pytorch-d3c8e68004c118b4bd00cd1d33a3fb549ce9a605.zip
Revert D5641588: [caffe2] Control flow operators
Summary: This reverts commit f9e04429961c3da7da4ebca3e8163bfcc2a09ec9 bypass-lint Differential Revision: D5641588 fbshipit-source-id: bb23b213d08e9c3ea509216fce9367625943d007
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/operators/do_op.cc32
-rw-r--r--caffe2/operators/do_op.h99
-rw-r--r--caffe2/operators/if_op.cc48
-rw-r--r--caffe2/operators/if_op.h44
-rw-r--r--caffe2/operators/while_op.cc58
-rw-r--r--caffe2/operators/while_op.h47
-rw-r--r--caffe2/python/control_ops_util.py223
-rw-r--r--caffe2/python/core.py14
-rw-r--r--caffe2/python/net_builder.py265
-rw-r--r--caffe2/python/net_builder_test.py85
10 files changed, 3 insertions, 912 deletions
diff --git a/caffe2/operators/do_op.cc b/caffe2/operators/do_op.cc
deleted file mode 100644
index d74246ff2c..0000000000
--- a/caffe2/operators/do_op.cc
+++ /dev/null
@@ -1,32 +0,0 @@
-#include "caffe2/operators/do_op.h"
-
-namespace caffe2 {
-
-template <>
-bool DoOp<CPUContext>::RunOnDevice() {
- return net_->Run();
-}
-
-REGISTER_CPU_OPERATOR(Do, DoOp<CPUContext>);
-
-OPERATOR_SCHEMA(Do)
- .NumInputs(0, INT_MAX)
- .NumOutputs(0, INT_MAX)
- .SetDoc(R"DOC(
-'Do' control operator, creates a new workspace and executes a subnet in it.
-Accepts 'net' argument for a subnet, arguments 'inner_blobs' and 'outer_blobs_idx'
-provide a mapping between selected inner blob names and corresponding outer blobs
-indices: [0..NumInputs-1] indices correspond to input blobs and [NumInputs..NumOutputs+NumInputs-1] -
-output blobs, in the order specified in 'Do' operator definition.
- )DOC")
- .Arg("net", "Subnet with blob bindings")
- .Arg(
- "inner_blobs",
- "List of inner net blob names to bind to outer workspace")
- .Arg(
- "outer_blobs_idx",
- "Indices of corresponding outer workspace blobs, "
- "in order: operator inputs, operator outputs")
- .AllowInplace([](int in, int out) -> bool { return true; });
-
-} // namespace caffe2
diff --git a/caffe2/operators/do_op.h b/caffe2/operators/do_op.h
deleted file mode 100644
index 386cff033b..0000000000
--- a/caffe2/operators/do_op.h
+++ /dev/null
@@ -1,99 +0,0 @@
-#ifndef CAFFE2_OPERATORS_DO_OP_H_
-#define CAFFE2_OPERATORS_DO_OP_H_
-
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "caffe2/core/context.h"
-#include "caffe2/core/logging.h"
-#include "caffe2/core/operator.h"
-#include "caffe2/proto/caffe2.pb.h"
-
-namespace caffe2 {
-
-template <class Context>
-class DoOp final : public Operator<Context> {
- public:
- DoOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws) {
- CAFFE_ENFORCE(
- this->template HasSingleArgumentOfType<NetDef>("net"),
- "net must be specified in Do operator");
- net_def_ = this->template GetSingleArgument<NetDef>("net", NetDef());
-
- const auto& input_names = getInputBlobNames(operator_def);
- const auto& output_names = getOutputBlobNames(operator_def);
- std::vector<std::string> outer_blob_names;
- outer_blob_names.reserve(input_names.size() + output_names.size());
- outer_blob_names.insert(
- outer_blob_names.end(), input_names.begin(), input_names.end());
- outer_blob_names.insert(
- outer_blob_names.end(), output_names.begin(), output_names.end());
-
- const auto& inner_blobs =
- this->template GetRepeatedArgument<std::string>("inner_blobs");
- // [0..input_names.size()-1] indices encode input blobs;
- // [input_names.size()..output_names.size()+input_names.size()-1] -
- // encode output blobs
- const auto& outer_blobs =
- this->template GetRepeatedArgument<int>("outer_blobs_idx");
- CAFFE_ENFORCE_EQ(
- inner_blobs.size(),
- outer_blobs.size(),
- "Invalid blob bindings: different inner/outer blobs lengths");
- std::unordered_map<std::string, std::string> blob_bindings;
- for (size_t blob_idx = 0; blob_idx < inner_blobs.size(); ++blob_idx) {
- CAFFE_ENFORCE(
- !blob_bindings.count(inner_blobs[blob_idx]),
- "Invalid blob bindings: redefinition of inner blob " +
- inner_blobs[blob_idx]);
- CAFFE_ENFORCE(
- outer_blobs[blob_idx] >= 0 &&
- outer_blobs[blob_idx] < outer_blob_names.size(),
- "Invalid blob bindings: outer blob index (" +
- caffe2::to_string(outer_blobs[blob_idx]) + ", inner name: " +
- inner_blobs[blob_idx] + ") is out of bounds [0, " +
- caffe2::to_string(outer_blob_names.size() - 1) + "]");
- blob_bindings[inner_blobs[blob_idx]] =
- outer_blob_names[outer_blobs[blob_idx]];
- }
-
- net_workspace_.reset(new Workspace(ws, blob_bindings));
- CAFFE_ENFORCE(net_workspace_, "Failed to initialize subnet workspace");
- net_ = net_workspace_->CreateNet(net_def_, true);
- CAFFE_ENFORCE(net_, "Failed to initialize subnet");
- }
-
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- bool RunOnDevice() override;
-
- private:
- static std::vector<std::string> getInputBlobNames(
- const OperatorDef& operator_def) {
- std::vector<std::string> names;
- names.reserve(operator_def.input_size());
- for (auto idx = 0; idx < operator_def.input_size(); ++idx) {
- names.push_back(operator_def.input(idx));
- }
- return names;
- }
-
- static std::vector<std::string> getOutputBlobNames(
- const OperatorDef& operator_def) {
- std::vector<std::string> names;
- names.reserve(operator_def.output_size());
- for (auto idx = 0; idx < operator_def.output_size(); ++idx) {
- names.push_back(operator_def.output(idx));
- }
- return names;
- }
-
- NetDef net_def_;
- NetBase* net_;
- std::unique_ptr<Workspace> net_workspace_;
-};
-
-} // namespace caffe2
-
-#endif // CAFFE2_OPERATORS_DO_OP_H_
diff --git a/caffe2/operators/if_op.cc b/caffe2/operators/if_op.cc
deleted file mode 100644
index 803d3a944b..0000000000
--- a/caffe2/operators/if_op.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-#include "caffe2/operators/if_op.h"
-
-namespace caffe2 {
-
-template <>
-bool IfOp<CPUContext>::RunOnDevice() {
- CAFFE_ENFORCE_GT(
- InputSize(), 0, "Condition must be specified in If operator");
- CAFFE_ENFORCE(
- InputIsType<Tensor<CPUContext>>(0),
- "Invalid condition in If operator: tensor expected");
-
- const auto& condition = Input(0);
- CAFFE_ENFORCE(
- condition.IsType<bool>(),
- "Invalid condition tensor in If operator: boolean expected");
- CAFFE_ENFORCE_EQ(
- condition.size(),
- 1,
- "Invalid condition tensor in If operator: single value expected");
- CAFFE_ENFORCE_EQ(
- condition.ndim(),
- 0,
- "Invalid condition tensor in If operator: scalar expected");
-
- auto conditionValue = *condition.data<bool>();
- auto* netToExecute =
- conditionValue ? then_net_ : (else_net_ ? else_net_ : nullptr);
-
- return netToExecute ? netToExecute->Run() : true;
-}
-
-REGISTER_CPU_OPERATOR(If, IfOp<CPUContext>);
-
-OPERATOR_SCHEMA(If)
- .NumInputs(1, INT_MAX)
- .NumOutputs(0, INT_MAX)
- .SetDoc(R"DOC(
-'If' control operator, first input is a scalar boolean blob that stores condition
-value. Accepts 'then_net' (required) and 'else_net' (optional) arguments for 'then' and
-'else' subnets respectively. Subnets are executed in the same workspace as 'If'.
- )DOC")
- .Arg("then_net", "Net executed when condition is true")
- .Arg("else_net", "Net executed when condition is false (optional)")
- .Input(0, "condition", "Scalar boolean condition")
- .AllowInplace([](int in, int out) -> bool { return true; });
-
-} // namespace caffe2
diff --git a/caffe2/operators/if_op.h b/caffe2/operators/if_op.h
deleted file mode 100644
index 5b3eb8efbb..0000000000
--- a/caffe2/operators/if_op.h
+++ /dev/null
@@ -1,44 +0,0 @@
-#ifndef CAFFE2_OPERATORS_IF_OP_H_
-#define CAFFE2_OPERATORS_IF_OP_H_
-
-#include "caffe2/core/context.h"
-#include "caffe2/core/logging.h"
-#include "caffe2/core/operator.h"
-
-namespace caffe2 {
-
-template <class Context>
-class IfOp final : public Operator<Context> {
- public:
- IfOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws) {
- CAFFE_ENFORCE(
- this->template HasSingleArgumentOfType<NetDef>("then_net"),
- "then_net must be specified in If operator");
- then_net_def_ =
- this->template GetSingleArgument<NetDef>("then_net", NetDef());
- then_net_ = ws->CreateNet(then_net_def_, true);
- CAFFE_ENFORCE(then_net_, "Failed to initialize then subnet");
-
- if (this->template HasSingleArgumentOfType<NetDef>("else_net")) {
- else_net_def_ =
- this->template GetSingleArgument<NetDef>("else_net", NetDef());
- else_net_ = ws->CreateNet(else_net_def_, true);
- CAFFE_ENFORCE(else_net_, "Failed to initialize else subnet");
- }
- }
-
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- bool RunOnDevice() override;
-
- private:
- NetDef then_net_def_;
- NetBase* then_net_ = nullptr;
-
- NetDef else_net_def_;
- NetBase* else_net_ = nullptr;
-};
-
-} // namespace caffe2
-
-#endif // CAFFE2_OPERATORS_IF_OP_H_
diff --git a/caffe2/operators/while_op.cc b/caffe2/operators/while_op.cc
deleted file mode 100644
index 53f992f8ac..0000000000
--- a/caffe2/operators/while_op.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-#include "caffe2/operators/while_op.h"
-
-namespace caffe2 {
-
-template <>
-bool WhileOp<CPUContext>::RunOnDevice() {
- CAFFE_ENFORCE_GT(
- InputSize(), 0, "Condition must be specified in While operator");
- CAFFE_ENFORCE(
- InputIsType<Tensor<CPUContext>>(0),
- "Invalid condition in While operator: tensor expected");
-
- const auto& condition = Input(0);
- CAFFE_ENFORCE(
- condition.IsType<bool>(),
- "Invalid condition tensor in While operator: boolean expected");
- CAFFE_ENFORCE_EQ(
- condition.size(),
- 1,
- "Invalid condition tensor in While operator: single value expected");
- CAFFE_ENFORCE_EQ(
- condition.ndim(),
- 0,
- "Invalid condition tensor in While operator: scalar expected");
-
- while (true) {
- if (cond_net_ && !cond_net_->Run()) {
- return false;
- }
- if (!*condition.data<bool>()) {
- return true;
- }
- if (!loop_net_->Run()) {
- return false;
- }
- }
-
- return true;
-}
-
-REGISTER_CPU_OPERATOR(While, WhileOp<CPUContext>);
-
-OPERATOR_SCHEMA(While)
- .NumInputs(1, INT_MAX)
- .NumOutputs(0, INT_MAX)
- .SetDoc(R"DOC(
-'While' control operator, first input is a scalar boolean blob that stores loop's
-condition value. Accepts 'loop_net' (required) and 'cond_net' (optional) arguments for
-loop's body and condition subnets respectively. If condition subnet is specified,
-it is executed before the first and after each iteration. Subnets are executed in
-the same workspace as 'While'.
- )DOC")
- .Arg("loop_net", "Net executed on each iteration")
- .Arg("cond_net", "Net to (re)compute condition value")
- .Input(0, "condition", "Scalar boolean condition")
- .AllowInplace([](int in, int out) -> bool { return true; });
-
-} // namespace caffe2
diff --git a/caffe2/operators/while_op.h b/caffe2/operators/while_op.h
deleted file mode 100644
index 1f7352d77e..0000000000
--- a/caffe2/operators/while_op.h
+++ /dev/null
@@ -1,47 +0,0 @@
-#ifndef CAFFE2_OPERATORS_WHILE_OP_H_
-#define CAFFE2_OPERATORS_WHILE_OP_H_
-
-#include "caffe2/core/context.h"
-#include "caffe2/core/logging.h"
-#include "caffe2/core/operator.h"
-
-namespace caffe2 {
-
-template <class Context>
-class WhileOp final : public Operator<Context> {
- public:
- WhileOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws) {
- CAFFE_ENFORCE(
- this->template HasSingleArgumentOfType<NetDef>("loop_net"),
- "loop_net must be specified in While operator");
- loop_net_def_ =
- this->template GetSingleArgument<NetDef>("loop_net", NetDef());
- loop_net_ = ws->CreateNet(loop_net_def_, true);
- CAFFE_ENFORCE(loop_net_, "Failed to initialize loop subnet");
-
- cond_net_ = nullptr;
- bool has_cond_net =
- this->template HasSingleArgumentOfType<NetDef>("cond_net");
- if (has_cond_net) {
- cond_net_def_ =
- this->template GetSingleArgument<NetDef>("cond_net", NetDef());
- cond_net_ = ws->CreateNet(cond_net_def_, true);
- CAFFE_ENFORCE(cond_net_, "Failed to initialize condition subnet");
- }
- }
-
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- bool RunOnDevice() override;
-
- private:
- NetDef loop_net_def_;
- NetBase* loop_net_;
-
- NetDef cond_net_def_;
- NetBase* cond_net_;
-};
-
-} // namespace caffe2
-
-#endif // CAFFE2_OPERATORS_WHILE_OP_H_
diff --git a/caffe2/python/control_ops_util.py b/caffe2/python/control_ops_util.py
deleted file mode 100644
index 8cd1441025..0000000000
--- a/caffe2/python/control_ops_util.py
+++ /dev/null
@@ -1,223 +0,0 @@
-## @package control_ops_util
-# Module caffe2.python.control_ops_util
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-from caffe2.python import core
-
-
-def get_external_blob_names(net, lexical_scope):
- """
- Returns a set of blobs a given net depends on and a set of
- output blobs that are written by the net
- Inputs:
- net - net to return input/output blobs for;
- lexical_scope - all external blob names visible to the net
- """
- net_proto = net.Proto()
- net_ssa, _ = core.get_ssa(net_proto)
- input_names = core.get_undefined_blobs(net_ssa)
- if net_proto.external_input:
- input_names |= set(net_proto.external_input)
-
- output_names = set()
- if net_proto.external_output:
- output_names = set(net_proto.external_output)
- for op in net_proto.op:
- for output in op.output:
- if output in lexical_scope:
- output_names.add(output)
-
- return input_names, output_names
-
-
-def add_if_op(if_net, cond_blob, lexical_scope, then_net, else_net=None):
- """
- A helper function to add an If op to the net.
- Automatically determines whether blobs in the then/else subnets are external
- (from the outer workspace) or local (visible only inside subnet's workspace)
- based on lexical scope - set of all outer blob names visible to the 'If'
- operator. All the blobs in then/else subnets with names matching a name in lexical
- scope and all the blobs that are first used as the operators' inputs are
- considered outer blobs - these blobs must exist in the outer workspace,
- then/else subnets can read their values and new values written into these blobs
- will be visible outside of the 'If' operator. All other blobs are local - exist
- only within inner workspaces for then/else.
- Inputs:
- if_net - net to add an If op to;
- cond_blob - scalar bool blob reference, used as If condition;
- lexical_scope - a set of outer blob names visible to then/else branches;
- then_net/else_net - nets (core.Net) for then/else branches
- """
- then_input_blob_names, then_output_blob_names = get_external_blob_names(
- then_net, lexical_scope)
-
- else_input_blob_names = set()
- else_output_blob_names = set()
- if else_net:
- else_input_blob_names, else_output_blob_names = get_external_blob_names(
- else_net, lexical_scope)
-
- input_blob_names = then_input_blob_names | else_input_blob_names
-
- # find outputs that are not produced by both then and else branches and
- # add them into inputs
- outputs_to_inputs = then_output_blob_names ^ else_output_blob_names
- input_blob_names |= outputs_to_inputs
-
- output_blob_names = then_output_blob_names | else_output_blob_names
-
- ext_then_input_blob_names = then_input_blob_names | (
- then_output_blob_names - else_output_blob_names)
- ext_else_input_blob_names = else_input_blob_names | (
- else_output_blob_names - then_output_blob_names)
-
- if_inputs = [cond_blob]
- if_inputs += [core.BlobReference(name=b, net=None) for b in input_blob_names]
- if_outputs = [core.BlobReference(name=b, net=None) for b in output_blob_names]
-
- do_then_net = core.Net('do_then_net')
-
- ext_then_input_blobs = \
- [core.BlobReference(name=b, net=None) for b in ext_then_input_blob_names]
- then_output_blobs = \
- [core.BlobReference(name=b, net=None) for b in then_output_blob_names]
- then_input_output_names_ordered = [
- str(b) for b in (ext_then_input_blobs + then_output_blobs)]
-
- then_outer_blob_names = list(ext_then_input_blob_names | then_output_blob_names)
- then_outer_blob_names_idx = [
- then_input_output_names_ordered.index(b) for b in then_outer_blob_names]
-
- do_then_net.Do(
- ext_then_input_blobs,
- then_output_blobs,
- net=then_net.Proto(),
- inner_blobs=then_outer_blob_names,
- outer_blobs_idx=then_outer_blob_names_idx)
- do_then_net.AddExternalOutput(*then_output_blobs)
-
- if_args = {}
- if_args['then_net'] = do_then_net.Proto()
-
- if else_net:
- do_else_net = core.Net('do_else_net')
-
- ext_else_input_blobs = \
- [core.BlobReference(name=b, net=None) for b in ext_else_input_blob_names]
- else_output_blobs = \
- [core.BlobReference(name=b, net=None) for b in else_output_blob_names]
- else_input_output_names_ordered = [
- str(b) for b in (ext_else_input_blobs + else_output_blobs)]
-
- else_outer_blob_names = list(ext_else_input_blob_names | else_output_blob_names)
- else_outer_blob_names_idx = [
- else_input_output_names_ordered.index(b) for b in else_outer_blob_names]
-
- do_else_net.Do(
- ext_else_input_blobs,
- else_output_blobs,
- net=else_net.Proto(),
- inner_blobs=else_outer_blob_names,
- outer_blobs_idx=else_outer_blob_names_idx)
- do_else_net.AddExternalOutput(*else_output_blobs)
- if_args['else_net'] = do_else_net.Proto()
-
- if_net.If(if_inputs, if_outputs, **if_args)
- if_net.AddExternalOutput(*if_outputs)
-
-
-def add_while_op(
- while_net, cond_blob, lexical_scope, loop_body_net, condition_body_net=None):
- """
- A helper function to add a While op to the net. Same rules for determining
- outer and inner blobs as for the 'If' operator apply for the 'While' operator
- loop and condition subnets. If specified, condition net is executed in a separate
- workspace before the first and after each iteration, the last operator must have
- a single scalar boolean output that is written into the condition blob.
- Inputs:
- while_net - net to add a While op to;
- cond_blob - scalar bool blob reference, used as a stop condition;
- lexical_scope - a set of outer blob names visible to the loop's body;
- loop_body_net - net to execute on each iteration;
- condition_body_net - net to compute condition value
- """
- input_blob_names, output_blob_names = get_external_blob_names(
- loop_body_net, lexical_scope)
-
- # Since it's possible that loop is not going to run even once
- # we have to add loop's external outputs into inputs
- input_blob_names |= output_blob_names
-
- while_inputs = [cond_blob]
- while_inputs += [core.BlobReference(name=b, net=None) for b in input_blob_names]
- while_outputs = [core.BlobReference(name=b, net=None) for b in output_blob_names]
-
- do_loop_body_net = core.Net('do_loop_body_net')
-
- loop_input_output_names_ordered = [
- str(b) for b in (while_inputs + while_outputs)]
- loop_body_outer_blob_names = list(input_blob_names | output_blob_names)
- loop_body_outer_blob_names_idx = [
- loop_input_output_names_ordered.index(b) for b in loop_body_outer_blob_names]
- do_loop_body_net.Do(
- while_inputs,
- while_outputs,
- net=loop_body_net.Proto(),
- inner_blobs=loop_body_outer_blob_names,
- outer_blobs_idx=loop_body_outer_blob_names_idx)
- do_loop_body_net.AddExternalOutput(*while_outputs)
-
- while_args = {}
- while_args['loop_net'] = do_loop_body_net.Proto()
-
- condition_net = None
- if condition_body_net:
- # make sure condition blob is visible outside of condition net
- if str(cond_blob) not in condition_body_net.Proto().external_output:
- condition_body_net.AddExternalOutput(cond_blob)
-
- cond_input_blob_names, cond_output_blob_names = get_external_blob_names(
- condition_body_net, lexical_scope)
-
- cond_inputs = [core.BlobReference(name=b, net=None)
- for b in cond_input_blob_names]
- assert str(cond_blob) in cond_output_blob_names, \
- 'Condition blob expected in condition net output'
- cond_outputs = [core.BlobReference(name=b, net=None)
- for b in cond_output_blob_names]
-
- cond_input_output_names_ordered = [
- str(b) for b in (cond_inputs + cond_outputs)]
- cond_body_outer_blob_names = \
- list(cond_input_blob_names | cond_output_blob_names)
- cond_body_outer_blob_names_idx = [
- cond_input_output_names_ordered.index(b)
- for b in cond_body_outer_blob_names]
- condition_net = core.Net('do_loop_condition_net')
- condition_net.Do(
- cond_inputs,
- cond_outputs,
- net=condition_body_net.Proto(),
- inner_blobs=cond_body_outer_blob_names,
- outer_blobs_idx=cond_body_outer_blob_names_idx)
- condition_net.AddExternalOutput(*cond_outputs)
-
- while_args['cond_net'] = condition_net.Proto()
-
- while_inputs += [b for b in cond_inputs
- if str(b) not in input_blob_names]
- while_outputs += [b for b in cond_outputs
- if str(b) not in output_blob_names]
-
- if str(cond_blob) not in lexical_scope:
- while_net.ConstantFill(
- [],
- cond_blob,
- dtype=core.DataType.BOOL,
- value=False)
-
- while_net.While(while_inputs, while_outputs, **while_args)
- while_net.AddExternalOutput(*while_outputs)
diff --git a/caffe2/python/core.py b/caffe2/python/core.py
index 0a33b6da1a..42bc14e840 100644
--- a/caffe2/python/core.py
+++ b/caffe2/python/core.py
@@ -1492,20 +1492,6 @@ class Net(object):
return True
return blob_name in self._external_input_map
- def UsedBlobNames(self):
- """
- Returns a set of blob names used in the net
- """
- blob_names = set()
- for op in self._net.op:
- blob_names |= set(op.input)
- blob_names |= set(op.output)
- if self._net.external_input:
- blob_names |= set(self._net.external_input)
- if self._net.external_output:
- blob_names |= set(self._net.external_output)
- return blob_names
-
def GetBlobRef(self, blob_name):
"""
Given the name of a blob produced by this net, return a BlobReference
diff --git a/caffe2/python/net_builder.py b/caffe2/python/net_builder.py
index 7e628480f4..fbec58ad8b 100644
--- a/caffe2/python/net_builder.py
+++ b/caffe2/python/net_builder.py
@@ -7,7 +7,6 @@ from __future__ import unicode_literals
from caffe2.python import core, context
from caffe2.python.task import Task, TaskGroup
-from caffe2.python.control_ops_util import add_if_op, add_while_op
@context.define_context()
@@ -29,22 +28,17 @@ class NetBuilder(object):
step = core.to_execution_step(nb)
"""
def __init__(self, name=None, _stop_blob_required=False,
- _stop_blob=None, _fullname=None, _use_control_ops=False):
- self._parent = NetBuilder.current(required=False)
+ _stop_blob=None, _fullname=None):
+ nb = NetBuilder.current(required=False)
assert not _fullname or not name, 'Cannot set both _fullname and name'
- assert not _use_control_ops or \
- (not _stop_blob_required and not _stop_blob), \
- 'Stop blobs are not used with control operators'
self.name = _fullname or '/'.join(
- n for n in (self._parent.name if self._parent else None, name) if n
+ n for n in (nb.name if nb else None, name) if n
)
self._frozen = False
self._current_net = None
self._children = []
- self._lexical_scope = set()
self._stop_blob = _stop_blob
self._stop_blob_required = _stop_blob_required
- self._use_control_ops = _use_control_ops
def stop_blob(self):
"""
@@ -54,8 +48,6 @@ class NetBuilder(object):
in the current net, so it doesn't initialize it if the current net is
the first of the builder.
"""
- assert not self._use_control_ops, \
- 'Stop blobs are not used with control operators'
if self._stop_blob is None:
net = self.current_net()
self._stop_blob = core.BlobReference(
@@ -66,8 +58,6 @@ class NetBuilder(object):
return self._stop_blob
def stop_if(self, blob):
- assert not self._use_control_ops, \
- 'Stop blobs are not used with control operators'
ops.Copy(blob, self.stop_blob())
self._current_net = None
@@ -75,49 +65,13 @@ class NetBuilder(object):
assert not self._frozen, (
'This NetBuilder (%s) has been built already.' % self.name)
- def _update_lexical_scope(self):
- """
- Updates lexical scope based on the current list of children.
- Lexical scope contains names of blobs that are currently available
- and were introduced in the net builder
- """
- self._lexical_scope = set()
- for child in self._children:
- if isinstance(child, core.Net):
- self._lexical_scope |= child.UsedBlobNames()
- elif isinstance(child, NetBuilder) and child._use_control_ops:
- self._lexical_scope |= child._lexical_scope
-
- def _collect_lexical_scopes(self):
- """
- Collects the names of all blobs currently visible in the net builder
- """
- scope = set(self._lexical_scope)
- parent = self._parent
- while parent:
- scope |= parent._lexical_scope
- parent = parent._parent
- return scope
-
- def _reset_children(self):
- self._current_net = None
- self._children = []
- self._lexical_scope = set()
-
def add(self, child):
self._assert_mutable()
-
- if self._use_control_ops:
- assert isinstance(child, core.Net) or (
- isinstance(child, NetBuilder) and child._use_control_ops), \
- "Expected Net or NetBuilder with control ops"
-
self._current_net = None
self._children.append(child)
# to-do : check it's not a dag net
if isinstance(child, core.Net):
self._current_net = child
- self._update_lexical_scope()
return child
def current_net(self, name=None):
@@ -138,7 +92,6 @@ class NetBuilder(object):
return self._children
def __exit__(self, etype, *args):
- self._update_lexical_scope()
self.freeze()
if etype is not None:
return
@@ -146,51 +99,6 @@ class NetBuilder(object):
'This NetBuilder (%s) requires a stop condition ' % self.name +
'to be set with `stop` or `stop_if`')
- @staticmethod
- def merge_nets(nets_or_builders, outer_blob_names):
- # Only nets or builders with control ops are allowed.
- # Need to pay attention to external outputs, e.g.
- # ...
- # IfNet1 (cond_blob):
- # (Net1)
- # X = 1
- # IfNet2 (...):
- # X = X + 1
- # ...
- # In this example there're two children in then branch of IfNet1:
- # a subnet Net1 that creates blob X and sets its value to one, and
- # a net builder IfNet2 that (conditionally) increments X.
- # From IfNet2's point of view X is an external input
- # and output blob, it will be put into IfNet2 net's external_output.
- # At the same time, from the point of view of IfNet1 X is purely local.
- # Net.AppendNet just merges external outputs of the networks, so
- # without checking this the result of Net1.AppendNet(IfNet2's net)
- # would have blob X in external_output
-
- net = None
- for n in nets_or_builders:
- cur = None
- if isinstance(n, NetBuilder):
- assert n._use_control_ops, \
- "Merging of NetBuilder supported only for control ops"
- nets = n.get()
- assert len(nets) == 1 and isinstance(nets[0], core.Net), \
- "Invalid control op net builder"
- cur = nets[0]
- else:
- assert isinstance(n, core.Net)
- cur = n
- if net:
- net.AppendNet(cur)
- else:
- net = cur
- if net:
- # correct external output
- external_outputs = [o for o in net.Proto().external_output
- if o in outer_blob_names]
- net.Proto().external_output[:] = external_outputs
- return net
-
def __str__(self):
return self.name or 'Un-named NetBuilder'
@@ -325,37 +233,6 @@ class Operations(object):
"""
return NetBuilder.current().add(_RunIf(cond, name=name))
- def IfNet(self, cond, name=None):
- """
- Same as If, but uses 'If' operator instead of execution step logic
- """
- return NetBuilder.current().add(_RunIfNet(cond, name=name))
-
- def Else(self, name=None):
- """
- Else branch of IfNet, has to be specified immediately after IfNet.
- Example:
- with ops.IfNet(ops.LT([x, y])):
- ...
- with ops.Else():
- ...
- """
- return _RunElseNet(name=name)
-
- def WhileNet(self, name=None):
- """
- NetBuilder for 'While' control operator
- """
- return NetBuilder.current().add(_RunWhileNet(name=name))
-
- def Condition(self, name=None):
- """
- Loop's condition, executed within WhileNet context
- """
- assert isinstance(NetBuilder.current(), _RunWhileNet), \
- "Use of Condition outside of WhileNet"
- return _RunWhileCondition(name=name)
-
def task_init(self):
"""
Defines operations that will be executed once at task startup.
@@ -599,139 +476,3 @@ class _RunIf(_RunOnce):
assert not self._is_else, 'Elif not allowed for an Else.'
return NetBuilder.current().add(
_RunIf(name=name or self.name, _already_ran=self._already_ran))
-
-
-class _RunIfNet(NetBuilder):
- """
- Generates a single net that uses If operator
- """
- def __init__(self, cond_blob, name=None):
- NetBuilder.__init__(self, name=name, _use_control_ops=True)
- assert cond_blob, 'Conditional blob is not specified for an If net'
- self._cond_blob = cond_blob
- self._then_net = None
- self._else_net = None
-
- def add(self, child):
- return NetBuilder.add(self, child)
-
- def __exit__(self, type, *args):
- if type is None:
- _then_nets = self._children
- self._reset_children()
-
- self._then_net = NetBuilder.merge_nets(
- _then_nets, self._collect_lexical_scopes())
- if not self._then_net:
- self._then_net = core.Net('empty_then_net')
-
- if_net = core.Net(self.name + '/if_net')
- add_if_op(if_net, self._cond_blob, self._collect_lexical_scopes(),
- self._then_net, self._else_net)
-
- self._current_net = if_net
- self._children = [if_net]
- NetBuilder.__exit__(self, type, *args)
-
-
-class _RunElseNet(NetBuilder):
- """
- Else branch for _RunIfNet builder
- """
- def __init__(self, name=None):
- NetBuilder.__init__(self, name=name, _use_control_ops=True)
- assert self._parent and len(self._parent._children) > 0 and \
- isinstance(self._parent._children[-1], _RunIfNet), \
- 'Invalid use of Else builder'
- self._if_builder = self._parent._children[-1]
-
- def __exit__(self, type, *args):
- if type is None:
- _else_nets = self._children
- self._reset_children()
-
- self._if_builder._else_net = NetBuilder.merge_nets(
- _else_nets, self._collect_lexical_scopes())
- if self._if_builder._else_net:
- if_else_net = core.Net(self.name + '/if_else_net')
- add_if_op(
- if_else_net,
- self._if_builder._cond_blob,
- self._collect_lexical_scopes(),
- self._if_builder._then_net,
- self._if_builder._else_net)
- self._if_builder._current_net = if_else_net
- self._if_builder._children = [if_else_net]
- NetBuilder.__exit__(self, type, *args)
-
-
-class _RunWhileNet(NetBuilder):
- """
- Generates a single net that uses While operator
- """
- def __init__(self, name=None):
- NetBuilder.__init__(self, name=name, _use_control_ops=True)
- self._cond_builder = None
-
- def __exit__(self, type, *args):
- if type is None:
- assert self._cond_builder, \
- 'Condition builder must be specified in While op'
-
- _cond_blob = self._cond_builder._cond_blob
- _cond_net = self._cond_builder._cond_net
-
- loop_body = self._children
- self._reset_children()
- loop_body_net = NetBuilder.merge_nets(
- loop_body, self._collect_lexical_scopes())
- if not loop_body_net:
- loop_body_net = core.Net('empty_loop_body_net')
-
- while_net = core.Net(self.name + '/while_net')
- add_while_op(while_net, _cond_blob, self._collect_lexical_scopes(),
- loop_body_net, _cond_net)
-
- self._current_net = while_net
- self._children = [while_net]
- NetBuilder.__exit__(self, type, *args)
-
-
-class _RunWhileCondition(NetBuilder):
- """
- Computes loop's condition, used in the context of WhileNet.
- Last operator must have a single scalar boolean output that will be used
- as a condition value, no other blobs created in the condition net are
- visible outside of it
- """
- def __init__(self, name=None):
- NetBuilder.__init__(self, name=name, _use_control_ops=True)
- assert self._parent and isinstance(self._parent, _RunWhileNet), \
- 'Invalid use of loop condition builder'
- self._cond_blob = None
- self._cond_net = None
-
- def __enter__(self):
- builder = NetBuilder.__enter__(self)
- assert not self._parent._cond_builder, \
- 'Multiple loop condition builders specified'
- assert len(self._parent._children) == 0, \
- 'Condition definition must be specified before the loop\'s body'
- self._parent._cond_builder = self
- return builder
-
- def __exit__(self, type, *args):
- if type is None:
- condition_body = self._children
- self._reset_children()
- self._cond_net = NetBuilder.merge_nets(
- condition_body, self._collect_lexical_scopes())
- assert self._cond_net, 'Invalid loop condition specified'
- assert len(self._cond_net.Proto().op) > 0, 'Invalid condition net'
- last_op = self._cond_net.Proto().op[-1]
- assert len(last_op.output) == 1, 'Invalid condition net'
- self._cond_blob = core.BlobReference(name=last_op.output[0], net=None)
-
- self._current_net = self._cond_net
- self._children = [self._cond_net]
- NetBuilder.__exit__(self, type, *args)
diff --git a/caffe2/python/net_builder_test.py b/caffe2/python/net_builder_test.py
index 169419c5c1..1d29711ec9 100644
--- a/caffe2/python/net_builder_test.py
+++ b/caffe2/python/net_builder_test.py
@@ -245,88 +245,3 @@ class TestNetBuilder(unittest.TestCase):
self.assertEquals(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
self.assertEquals(total2.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
self.assertEquals(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
-
- def test_if_net(self):
- with NetBuilder() as nb:
- x0 = ops.Const(0)
- x1 = ops.Const(1)
- x2 = ops.Const(2)
- y0 = ops.Const(0)
- y1 = ops.Const(1)
- y2 = ops.Const(2)
-
- # basic logic
- first_res = ops.Const(0)
- with ops.IfNet(ops.Const(True)):
- ops.Const(1, blob_out=first_res)
- with ops.Else():
- ops.Const(2, blob_out=first_res)
-
- second_res = ops.Const(0)
- with ops.IfNet(ops.Const(False)):
- ops.Const(1, blob_out=second_res)
- with ops.Else():
- ops.Const(2, blob_out=second_res)
-
- # nested and sequential ifs,
- # empty then/else,
- # passing outer blobs into branches,
- # writing into outer blobs, incl. into input blob
- # using local blobs
- with ops.IfNet(ops.LT([x0, x1])):
- local_blob = ops.Const(900)
- ops.Add([ops.Const(100), local_blob], [y0])
-
- gt = ops.GT([x1, x2])
- with ops.IfNet(gt):
- # empty then
- pass
- with ops.Else():
- ops.Add([y1, local_blob], [local_blob])
- ops.Add([ops.Const(100), y1], [y1])
-
- with ops.IfNet(ops.EQ([local_blob, ops.Const(901)])):
- ops.Const(7, blob_out=y2)
- ops.Add([y1, y2], [y2])
- with ops.Else():
- # empty else
- pass
-
- plan = Plan('if_net_test')
- plan.AddStep(to_execution_step(nb))
- ws = workspace.C.Workspace()
- ws.run(plan)
-
- first_res_value = ws.blobs[str(first_res)].fetch()
- second_res_value = ws.blobs[str(second_res)].fetch()
- y0_value = ws.blobs[str(y0)].fetch()
- y1_value = ws.blobs[str(y1)].fetch()
- y2_value = ws.blobs[str(y2)].fetch()
-
- self.assertEquals(first_res_value, 1)
- self.assertEquals(second_res_value, 2)
- self.assertEquals(y0_value, 1000)
- self.assertEquals(y1_value, 101)
- self.assertEquals(y2_value, 108)
- self.assertTrue(str(local_blob) not in ws.blobs)
-
- def test_while_net(self):
- with NetBuilder() as nb:
- x = ops.Const(0)
- y = ops.Const(0)
- with ops.WhileNet():
- with ops.Condition():
- ops.Add([x, ops.Const(1)], [x])
- ops.LT([x, ops.Const(7)])
- ops.Add([x, y], [y])
-
- plan = Plan('while_net_test')
- plan.AddStep(to_execution_step(nb))
- ws = workspace.C.Workspace()
- ws.run(plan)
-
- x_value = ws.blobs[str(x)].fetch()
- y_value = ws.blobs[str(y)].fetch()
-
- self.assertEqual(x_value, 7)
- self.assertEqual(y_value, 21)