diff options
author | Artem Volkhin <volkhin@fb.com> | 2017-08-25 23:56:05 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-08-26 00:07:58 -0700 |
commit | d3c8e68004c118b4bd00cd1d33a3fb549ce9a605 (patch) | |
tree | 495c0550b1c5ae9f98e0b1499c0fa2ddf5e5ceff /caffe2 | |
parent | 9f693b39aa747188c500d753de20fbf0db39d0a1 (diff) | |
download | pytorch-d3c8e68004c118b4bd00cd1d33a3fb549ce9a605.tar.gz pytorch-d3c8e68004c118b4bd00cd1d33a3fb549ce9a605.tar.bz2 pytorch-d3c8e68004c118b4bd00cd1d33a3fb549ce9a605.zip |
Revert D5641588: [caffe2] Control flow operators
Summary:
This reverts commit f9e04429961c3da7da4ebca3e8163bfcc2a09ec9
bypass-lint
Differential Revision: D5641588
fbshipit-source-id: bb23b213d08e9c3ea509216fce9367625943d007
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/operators/do_op.cc | 32 | ||||
-rw-r--r-- | caffe2/operators/do_op.h | 99 | ||||
-rw-r--r-- | caffe2/operators/if_op.cc | 48 | ||||
-rw-r--r-- | caffe2/operators/if_op.h | 44 | ||||
-rw-r--r-- | caffe2/operators/while_op.cc | 58 | ||||
-rw-r--r-- | caffe2/operators/while_op.h | 47 | ||||
-rw-r--r-- | caffe2/python/control_ops_util.py | 223 | ||||
-rw-r--r-- | caffe2/python/core.py | 14 | ||||
-rw-r--r-- | caffe2/python/net_builder.py | 265 | ||||
-rw-r--r-- | caffe2/python/net_builder_test.py | 85 |
10 files changed, 3 insertions, 912 deletions
diff --git a/caffe2/operators/do_op.cc b/caffe2/operators/do_op.cc deleted file mode 100644 index d74246ff2c..0000000000 --- a/caffe2/operators/do_op.cc +++ /dev/null @@ -1,32 +0,0 @@ -#include "caffe2/operators/do_op.h" - -namespace caffe2 { - -template <> -bool DoOp<CPUContext>::RunOnDevice() { - return net_->Run(); -} - -REGISTER_CPU_OPERATOR(Do, DoOp<CPUContext>); - -OPERATOR_SCHEMA(Do) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .SetDoc(R"DOC( -'Do' control operator, creates a new workspace and executes a subnet in it. -Accepts 'net' argument for a subnet, arguments 'inner_blobs' and 'outer_blobs_idx' -provide a mapping between selected inner blob names and corresponding outer blobs -indices: [0..NumInputs-1] indices correspond to input blobs and [NumInputs..NumOutputs+NumInputs-1] - -output blobs, in the order specified in 'Do' operator definition. - )DOC") - .Arg("net", "Subnet with blob bindings") - .Arg( - "inner_blobs", - "List of inner net blob names to bind to outer workspace") - .Arg( - "outer_blobs_idx", - "Indices of corresponding outer workspace blobs, " - "in order: operator inputs, operator outputs") - .AllowInplace([](int in, int out) -> bool { return true; }); - -} // namespace caffe2 diff --git a/caffe2/operators/do_op.h b/caffe2/operators/do_op.h deleted file mode 100644 index 386cff033b..0000000000 --- a/caffe2/operators/do_op.h +++ /dev/null @@ -1,99 +0,0 @@ -#ifndef CAFFE2_OPERATORS_DO_OP_H_ -#define CAFFE2_OPERATORS_DO_OP_H_ - -#include <string> -#include <unordered_map> -#include <vector> - -#include "caffe2/core/context.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/operator.h" -#include "caffe2/proto/caffe2.pb.h" - -namespace caffe2 { - -template <class Context> -class DoOp final : public Operator<Context> { - public: - DoOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws) { - CAFFE_ENFORCE( - this->template HasSingleArgumentOfType<NetDef>("net"), - "net must be specified in Do operator"); - net_def_ = this->template GetSingleArgument<NetDef>("net", NetDef()); - - const auto& input_names = getInputBlobNames(operator_def); - const auto& output_names = getOutputBlobNames(operator_def); - std::vector<std::string> outer_blob_names; - outer_blob_names.reserve(input_names.size() + output_names.size()); - outer_blob_names.insert( - outer_blob_names.end(), input_names.begin(), input_names.end()); - outer_blob_names.insert( - outer_blob_names.end(), output_names.begin(), output_names.end()); - - const auto& inner_blobs = - this->template GetRepeatedArgument<std::string>("inner_blobs"); - // [0..input_names.size()-1] indices encode input blobs; - // [input_names.size()..output_names.size()+input_names.size()-1] - - // encode output blobs - const auto& outer_blobs = - this->template GetRepeatedArgument<int>("outer_blobs_idx"); - CAFFE_ENFORCE_EQ( - inner_blobs.size(), - outer_blobs.size(), - "Invalid blob bindings: different inner/outer blobs lengths"); - std::unordered_map<std::string, std::string> blob_bindings; - for (size_t blob_idx = 0; blob_idx < inner_blobs.size(); ++blob_idx) { - CAFFE_ENFORCE( - !blob_bindings.count(inner_blobs[blob_idx]), - "Invalid blob bindings: redefinition of inner blob " + - inner_blobs[blob_idx]); - CAFFE_ENFORCE( - outer_blobs[blob_idx] >= 0 && - outer_blobs[blob_idx] < outer_blob_names.size(), - "Invalid blob bindings: outer blob index (" + - caffe2::to_string(outer_blobs[blob_idx]) + ", inner name: " + - inner_blobs[blob_idx] + ") is out of bounds [0, " + - caffe2::to_string(outer_blob_names.size() - 1) + "]"); - blob_bindings[inner_blobs[blob_idx]] = - outer_blob_names[outer_blobs[blob_idx]]; - } - - net_workspace_.reset(new Workspace(ws, blob_bindings)); - CAFFE_ENFORCE(net_workspace_, "Failed to initialize subnet workspace"); - net_ = net_workspace_->CreateNet(net_def_, true); - CAFFE_ENFORCE(net_, "Failed to initialize subnet"); - } - - USE_OPERATOR_CONTEXT_FUNCTIONS; - bool RunOnDevice() override; - - private: - static std::vector<std::string> getInputBlobNames( - const OperatorDef& operator_def) { - std::vector<std::string> names; - names.reserve(operator_def.input_size()); - for (auto idx = 0; idx < operator_def.input_size(); ++idx) { - names.push_back(operator_def.input(idx)); - } - return names; - } - - static std::vector<std::string> getOutputBlobNames( - const OperatorDef& operator_def) { - std::vector<std::string> names; - names.reserve(operator_def.output_size()); - for (auto idx = 0; idx < operator_def.output_size(); ++idx) { - names.push_back(operator_def.output(idx)); - } - return names; - } - - NetDef net_def_; - NetBase* net_; - std::unique_ptr<Workspace> net_workspace_; -}; - -} // namespace caffe2 - -#endif // CAFFE2_OPERATORS_DO_OP_H_ diff --git a/caffe2/operators/if_op.cc b/caffe2/operators/if_op.cc deleted file mode 100644 index 803d3a944b..0000000000 --- a/caffe2/operators/if_op.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "caffe2/operators/if_op.h" - -namespace caffe2 { - -template <> -bool IfOp<CPUContext>::RunOnDevice() { - CAFFE_ENFORCE_GT( - InputSize(), 0, "Condition must be specified in If operator"); - CAFFE_ENFORCE( - InputIsType<Tensor<CPUContext>>(0), - "Invalid condition in If operator: tensor expected"); - - const auto& condition = Input(0); - CAFFE_ENFORCE( - condition.IsType<bool>(), - "Invalid condition tensor in If operator: boolean expected"); - CAFFE_ENFORCE_EQ( - condition.size(), - 1, - "Invalid condition tensor in If operator: single value expected"); - CAFFE_ENFORCE_EQ( - condition.ndim(), - 0, - "Invalid condition tensor in If operator: scalar expected"); - - auto conditionValue = *condition.data<bool>(); - auto* netToExecute = - conditionValue ? then_net_ : (else_net_ ? else_net_ : nullptr); - - return netToExecute ? netToExecute->Run() : true; -} - -REGISTER_CPU_OPERATOR(If, IfOp<CPUContext>); - -OPERATOR_SCHEMA(If) - .NumInputs(1, INT_MAX) - .NumOutputs(0, INT_MAX) - .SetDoc(R"DOC( -'If' control operator, first input is a scalar boolean blob that stores condition -value. Accepts 'then_net' (required) and 'else_net' (optional) arguments for 'then' and -'else' subnets respectively. Subnets are executed in the same workspace as 'If'. - )DOC") - .Arg("then_net", "Net executed when condition is true") - .Arg("else_net", "Net executed when condition is false (optional)") - .Input(0, "condition", "Scalar boolean condition") - .AllowInplace([](int in, int out) -> bool { return true; }); - -} // namespace caffe2 diff --git a/caffe2/operators/if_op.h b/caffe2/operators/if_op.h deleted file mode 100644 index 5b3eb8efbb..0000000000 --- a/caffe2/operators/if_op.h +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef CAFFE2_OPERATORS_IF_OP_H_ -#define CAFFE2_OPERATORS_IF_OP_H_ - -#include "caffe2/core/context.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -template <class Context> -class IfOp final : public Operator<Context> { - public: - IfOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws) { - CAFFE_ENFORCE( - this->template HasSingleArgumentOfType<NetDef>("then_net"), - "then_net must be specified in If operator"); - then_net_def_ = - this->template GetSingleArgument<NetDef>("then_net", NetDef()); - then_net_ = ws->CreateNet(then_net_def_, true); - CAFFE_ENFORCE(then_net_, "Failed to initialize then subnet"); - - if (this->template HasSingleArgumentOfType<NetDef>("else_net")) { - else_net_def_ = - this->template GetSingleArgument<NetDef>("else_net", NetDef()); - else_net_ = ws->CreateNet(else_net_def_, true); - CAFFE_ENFORCE(else_net_, "Failed to initialize else subnet"); - } - } - - USE_OPERATOR_CONTEXT_FUNCTIONS; - bool RunOnDevice() override; - - private: - NetDef then_net_def_; - NetBase* then_net_ = nullptr; - - NetDef else_net_def_; - NetBase* else_net_ = nullptr; -}; - -} // namespace caffe2 - -#endif // CAFFE2_OPERATORS_IF_OP_H_ diff --git a/caffe2/operators/while_op.cc b/caffe2/operators/while_op.cc deleted file mode 100644 index 53f992f8ac..0000000000 --- a/caffe2/operators/while_op.cc +++ /dev/null @@ -1,58 +0,0 @@ -#include "caffe2/operators/while_op.h" - -namespace caffe2 { - -template <> -bool WhileOp<CPUContext>::RunOnDevice() { - CAFFE_ENFORCE_GT( - InputSize(), 0, "Condition must be specified in While operator"); - CAFFE_ENFORCE( - InputIsType<Tensor<CPUContext>>(0), - "Invalid condition in While operator: tensor expected"); - - const auto& condition = Input(0); - CAFFE_ENFORCE( - condition.IsType<bool>(), - "Invalid condition tensor in While operator: boolean expected"); - CAFFE_ENFORCE_EQ( - condition.size(), - 1, - "Invalid condition tensor in While operator: single value expected"); - CAFFE_ENFORCE_EQ( - condition.ndim(), - 0, - "Invalid condition tensor in While operator: scalar expected"); - - while (true) { - if (cond_net_ && !cond_net_->Run()) { - return false; - } - if (!*condition.data<bool>()) { - return true; - } - if (!loop_net_->Run()) { - return false; - } - } - - return true; -} - -REGISTER_CPU_OPERATOR(While, WhileOp<CPUContext>); - -OPERATOR_SCHEMA(While) - .NumInputs(1, INT_MAX) - .NumOutputs(0, INT_MAX) - .SetDoc(R"DOC( -'While' control operator, first input is a scalar boolean blob that stores loop's -condition value. Accepts 'loop_net' (required) and 'cond_net' (optional) arguments for -loop's body and condition subnets respectively. If condition subnet is specified, -it is executed before the first and after each iteration. Subnets are executed in -the same workspace as 'While'. - )DOC") - .Arg("loop_net", "Net executed on each iteration") - .Arg("cond_net", "Net to (re)compute condition value") - .Input(0, "condition", "Scalar boolean condition") - .AllowInplace([](int in, int out) -> bool { return true; }); - -} // namespace caffe2 diff --git a/caffe2/operators/while_op.h b/caffe2/operators/while_op.h deleted file mode 100644 index 1f7352d77e..0000000000 --- a/caffe2/operators/while_op.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef CAFFE2_OPERATORS_WHILE_OP_H_ -#define CAFFE2_OPERATORS_WHILE_OP_H_ - -#include "caffe2/core/context.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -template <class Context> -class WhileOp final : public Operator<Context> { - public: - WhileOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws) { - CAFFE_ENFORCE( - this->template HasSingleArgumentOfType<NetDef>("loop_net"), - "loop_net must be specified in While operator"); - loop_net_def_ = - this->template GetSingleArgument<NetDef>("loop_net", NetDef()); - loop_net_ = ws->CreateNet(loop_net_def_, true); - CAFFE_ENFORCE(loop_net_, "Failed to initialize loop subnet"); - - cond_net_ = nullptr; - bool has_cond_net = - this->template HasSingleArgumentOfType<NetDef>("cond_net"); - if (has_cond_net) { - cond_net_def_ = - this->template GetSingleArgument<NetDef>("cond_net", NetDef()); - cond_net_ = ws->CreateNet(cond_net_def_, true); - CAFFE_ENFORCE(cond_net_, "Failed to initialize condition subnet"); - } - } - - USE_OPERATOR_CONTEXT_FUNCTIONS; - bool RunOnDevice() override; - - private: - NetDef loop_net_def_; - NetBase* loop_net_; - - NetDef cond_net_def_; - NetBase* cond_net_; -}; - -} // namespace caffe2 - -#endif // CAFFE2_OPERATORS_WHILE_OP_H_ diff --git a/caffe2/python/control_ops_util.py b/caffe2/python/control_ops_util.py deleted file mode 100644 index 8cd1441025..0000000000 --- a/caffe2/python/control_ops_util.py +++ /dev/null @@ -1,223 +0,0 @@ -## @package control_ops_util -# Module caffe2.python.control_ops_util -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -from caffe2.python import core - - -def get_external_blob_names(net, lexical_scope): - """ - Returns a set of blobs a given net depends on and a set of - output blobs that are written by the net - Inputs: - net - net to return input/output blobs for; - lexical_scope - all external blob names visible to the net - """ - net_proto = net.Proto() - net_ssa, _ = core.get_ssa(net_proto) - input_names = core.get_undefined_blobs(net_ssa) - if net_proto.external_input: - input_names |= set(net_proto.external_input) - - output_names = set() - if net_proto.external_output: - output_names = set(net_proto.external_output) - for op in net_proto.op: - for output in op.output: - if output in lexical_scope: - output_names.add(output) - - return input_names, output_names - - -def add_if_op(if_net, cond_blob, lexical_scope, then_net, else_net=None): - """ - A helper function to add an If op to the net. - Automatically determines whether blobs in the then/else subnets are external - (from the outer workspace) or local (visible only inside subnet's workspace) - based on lexical scope - set of all outer blob names visible to the 'If' - operator. All the blobs in then/else subnets with names matching a name in lexical - scope and all the blobs that are first used as the operators' inputs are - considered outer blobs - these blobs must exist in the outer workspace, - then/else subnets can read their values and new values written into these blobs - will be visible outside of the 'If' operator. All other blobs are local - exist - only within inner workspaces for then/else. - Inputs: - if_net - net to add an If op to; - cond_blob - scalar bool blob reference, used as If condition; - lexical_scope - a set of outer blob names visible to then/else branches; - then_net/else_net - nets (core.Net) for then/else branches - """ - then_input_blob_names, then_output_blob_names = get_external_blob_names( - then_net, lexical_scope) - - else_input_blob_names = set() - else_output_blob_names = set() - if else_net: - else_input_blob_names, else_output_blob_names = get_external_blob_names( - else_net, lexical_scope) - - input_blob_names = then_input_blob_names | else_input_blob_names - - # find outputs that are not produced by both then and else branches and - # add them into inputs - outputs_to_inputs = then_output_blob_names ^ else_output_blob_names - input_blob_names |= outputs_to_inputs - - output_blob_names = then_output_blob_names | else_output_blob_names - - ext_then_input_blob_names = then_input_blob_names | ( - then_output_blob_names - else_output_blob_names) - ext_else_input_blob_names = else_input_blob_names | ( - else_output_blob_names - then_output_blob_names) - - if_inputs = [cond_blob] - if_inputs += [core.BlobReference(name=b, net=None) for b in input_blob_names] - if_outputs = [core.BlobReference(name=b, net=None) for b in output_blob_names] - - do_then_net = core.Net('do_then_net') - - ext_then_input_blobs = \ - [core.BlobReference(name=b, net=None) for b in ext_then_input_blob_names] - then_output_blobs = \ - [core.BlobReference(name=b, net=None) for b in then_output_blob_names] - then_input_output_names_ordered = [ - str(b) for b in (ext_then_input_blobs + then_output_blobs)] - - then_outer_blob_names = list(ext_then_input_blob_names | then_output_blob_names) - then_outer_blob_names_idx = [ - then_input_output_names_ordered.index(b) for b in then_outer_blob_names] - - do_then_net.Do( - ext_then_input_blobs, - then_output_blobs, - net=then_net.Proto(), - inner_blobs=then_outer_blob_names, - outer_blobs_idx=then_outer_blob_names_idx) - do_then_net.AddExternalOutput(*then_output_blobs) - - if_args = {} - if_args['then_net'] = do_then_net.Proto() - - if else_net: - do_else_net = core.Net('do_else_net') - - ext_else_input_blobs = \ - [core.BlobReference(name=b, net=None) for b in ext_else_input_blob_names] - else_output_blobs = \ - [core.BlobReference(name=b, net=None) for b in else_output_blob_names] - else_input_output_names_ordered = [ - str(b) for b in (ext_else_input_blobs + else_output_blobs)] - - else_outer_blob_names = list(ext_else_input_blob_names | else_output_blob_names) - else_outer_blob_names_idx = [ - else_input_output_names_ordered.index(b) for b in else_outer_blob_names] - - do_else_net.Do( - ext_else_input_blobs, - else_output_blobs, - net=else_net.Proto(), - inner_blobs=else_outer_blob_names, - outer_blobs_idx=else_outer_blob_names_idx) - do_else_net.AddExternalOutput(*else_output_blobs) - if_args['else_net'] = do_else_net.Proto() - - if_net.If(if_inputs, if_outputs, **if_args) - if_net.AddExternalOutput(*if_outputs) - - -def add_while_op( - while_net, cond_blob, lexical_scope, loop_body_net, condition_body_net=None): - """ - A helper function to add a While op to the net. Same rules for determining - outer and inner blobs as for the 'If' operator apply for the 'While' operator - loop and condition subnets. If specified, condition net is executed in a separate - workspace before the first and after each iteration, the last operator must have - a single scalar boolean output that is written into the condition blob. - Inputs: - while_net - net to add a While op to; - cond_blob - scalar bool blob reference, used as a stop condition; - lexical_scope - a set of outer blob names visible to the loop's body; - loop_body_net - net to execute on each iteration; - condition_body_net - net to compute condition value - """ - input_blob_names, output_blob_names = get_external_blob_names( - loop_body_net, lexical_scope) - - # Since it's possible that loop is not going to run even once - # we have to add loop's external outputs into inputs - input_blob_names |= output_blob_names - - while_inputs = [cond_blob] - while_inputs += [core.BlobReference(name=b, net=None) for b in input_blob_names] - while_outputs = [core.BlobReference(name=b, net=None) for b in output_blob_names] - - do_loop_body_net = core.Net('do_loop_body_net') - - loop_input_output_names_ordered = [ - str(b) for b in (while_inputs + while_outputs)] - loop_body_outer_blob_names = list(input_blob_names | output_blob_names) - loop_body_outer_blob_names_idx = [ - loop_input_output_names_ordered.index(b) for b in loop_body_outer_blob_names] - do_loop_body_net.Do( - while_inputs, - while_outputs, - net=loop_body_net.Proto(), - inner_blobs=loop_body_outer_blob_names, - outer_blobs_idx=loop_body_outer_blob_names_idx) - do_loop_body_net.AddExternalOutput(*while_outputs) - - while_args = {} - while_args['loop_net'] = do_loop_body_net.Proto() - - condition_net = None - if condition_body_net: - # make sure condition blob is visible outside of condition net - if str(cond_blob) not in condition_body_net.Proto().external_output: - condition_body_net.AddExternalOutput(cond_blob) - - cond_input_blob_names, cond_output_blob_names = get_external_blob_names( - condition_body_net, lexical_scope) - - cond_inputs = [core.BlobReference(name=b, net=None) - for b in cond_input_blob_names] - assert str(cond_blob) in cond_output_blob_names, \ - 'Condition blob expected in condition net output' - cond_outputs = [core.BlobReference(name=b, net=None) - for b in cond_output_blob_names] - - cond_input_output_names_ordered = [ - str(b) for b in (cond_inputs + cond_outputs)] - cond_body_outer_blob_names = \ - list(cond_input_blob_names | cond_output_blob_names) - cond_body_outer_blob_names_idx = [ - cond_input_output_names_ordered.index(b) - for b in cond_body_outer_blob_names] - condition_net = core.Net('do_loop_condition_net') - condition_net.Do( - cond_inputs, - cond_outputs, - net=condition_body_net.Proto(), - inner_blobs=cond_body_outer_blob_names, - outer_blobs_idx=cond_body_outer_blob_names_idx) - condition_net.AddExternalOutput(*cond_outputs) - - while_args['cond_net'] = condition_net.Proto() - - while_inputs += [b for b in cond_inputs - if str(b) not in input_blob_names] - while_outputs += [b for b in cond_outputs - if str(b) not in output_blob_names] - - if str(cond_blob) not in lexical_scope: - while_net.ConstantFill( - [], - cond_blob, - dtype=core.DataType.BOOL, - value=False) - - while_net.While(while_inputs, while_outputs, **while_args) - while_net.AddExternalOutput(*while_outputs) diff --git a/caffe2/python/core.py b/caffe2/python/core.py index 0a33b6da1a..42bc14e840 100644 --- a/caffe2/python/core.py +++ b/caffe2/python/core.py @@ -1492,20 +1492,6 @@ class Net(object): return True return blob_name in self._external_input_map - def UsedBlobNames(self): - """ - Returns a set of blob names used in the net - """ - blob_names = set() - for op in self._net.op: - blob_names |= set(op.input) - blob_names |= set(op.output) - if self._net.external_input: - blob_names |= set(self._net.external_input) - if self._net.external_output: - blob_names |= set(self._net.external_output) - return blob_names - def GetBlobRef(self, blob_name): """ Given the name of a blob produced by this net, return a BlobReference diff --git a/caffe2/python/net_builder.py b/caffe2/python/net_builder.py index 7e628480f4..fbec58ad8b 100644 --- a/caffe2/python/net_builder.py +++ b/caffe2/python/net_builder.py @@ -7,7 +7,6 @@ from __future__ import unicode_literals from caffe2.python import core, context from caffe2.python.task import Task, TaskGroup -from caffe2.python.control_ops_util import add_if_op, add_while_op @context.define_context() @@ -29,22 +28,17 @@ class NetBuilder(object): step = core.to_execution_step(nb) """ def __init__(self, name=None, _stop_blob_required=False, - _stop_blob=None, _fullname=None, _use_control_ops=False): - self._parent = NetBuilder.current(required=False) + _stop_blob=None, _fullname=None): + nb = NetBuilder.current(required=False) assert not _fullname or not name, 'Cannot set both _fullname and name' - assert not _use_control_ops or \ - (not _stop_blob_required and not _stop_blob), \ - 'Stop blobs are not used with control operators' self.name = _fullname or '/'.join( - n for n in (self._parent.name if self._parent else None, name) if n + n for n in (nb.name if nb else None, name) if n ) self._frozen = False self._current_net = None self._children = [] - self._lexical_scope = set() self._stop_blob = _stop_blob self._stop_blob_required = _stop_blob_required - self._use_control_ops = _use_control_ops def stop_blob(self): """ @@ -54,8 +48,6 @@ class NetBuilder(object): in the current net, so it doesn't initialize it if the current net is the first of the builder. """ - assert not self._use_control_ops, \ - 'Stop blobs are not used with control operators' if self._stop_blob is None: net = self.current_net() self._stop_blob = core.BlobReference( @@ -66,8 +58,6 @@ class NetBuilder(object): return self._stop_blob def stop_if(self, blob): - assert not self._use_control_ops, \ - 'Stop blobs are not used with control operators' ops.Copy(blob, self.stop_blob()) self._current_net = None @@ -75,49 +65,13 @@ class NetBuilder(object): assert not self._frozen, ( 'This NetBuilder (%s) has been built already.' % self.name) - def _update_lexical_scope(self): - """ - Updates lexical scope based on the current list of children. - Lexical scope contains names of blobs that are currently available - and were introduced in the net builder - """ - self._lexical_scope = set() - for child in self._children: - if isinstance(child, core.Net): - self._lexical_scope |= child.UsedBlobNames() - elif isinstance(child, NetBuilder) and child._use_control_ops: - self._lexical_scope |= child._lexical_scope - - def _collect_lexical_scopes(self): - """ - Collects the names of all blobs currently visible in the net builder - """ - scope = set(self._lexical_scope) - parent = self._parent - while parent: - scope |= parent._lexical_scope - parent = parent._parent - return scope - - def _reset_children(self): - self._current_net = None - self._children = [] - self._lexical_scope = set() - def add(self, child): self._assert_mutable() - - if self._use_control_ops: - assert isinstance(child, core.Net) or ( - isinstance(child, NetBuilder) and child._use_control_ops), \ - "Expected Net or NetBuilder with control ops" - self._current_net = None self._children.append(child) # to-do : check it's not a dag net if isinstance(child, core.Net): self._current_net = child - self._update_lexical_scope() return child def current_net(self, name=None): @@ -138,7 +92,6 @@ class NetBuilder(object): return self._children def __exit__(self, etype, *args): - self._update_lexical_scope() self.freeze() if etype is not None: return @@ -146,51 +99,6 @@ class NetBuilder(object): 'This NetBuilder (%s) requires a stop condition ' % self.name + 'to be set with `stop` or `stop_if`') - @staticmethod - def merge_nets(nets_or_builders, outer_blob_names): - # Only nets or builders with control ops are allowed. - # Need to pay attention to external outputs, e.g. - # ... - # IfNet1 (cond_blob): - # (Net1) - # X = 1 - # IfNet2 (...): - # X = X + 1 - # ... - # In this example there're two children in then branch of IfNet1: - # a subnet Net1 that creates blob X and sets its value to one, and - # a net builder IfNet2 that (conditionally) increments X. - # From IfNet2's point of view X is an external input - # and output blob, it will be put into IfNet2 net's external_output. - # At the same time, from the point of view of IfNet1 X is purely local. - # Net.AppendNet just merges external outputs of the networks, so - # without checking this the result of Net1.AppendNet(IfNet2's net) - # would have blob X in external_output - - net = None - for n in nets_or_builders: - cur = None - if isinstance(n, NetBuilder): - assert n._use_control_ops, \ - "Merging of NetBuilder supported only for control ops" - nets = n.get() - assert len(nets) == 1 and isinstance(nets[0], core.Net), \ - "Invalid control op net builder" - cur = nets[0] - else: - assert isinstance(n, core.Net) - cur = n - if net: - net.AppendNet(cur) - else: - net = cur - if net: - # correct external output - external_outputs = [o for o in net.Proto().external_output - if o in outer_blob_names] - net.Proto().external_output[:] = external_outputs - return net - def __str__(self): return self.name or 'Un-named NetBuilder' @@ -325,37 +233,6 @@ class Operations(object): """ return NetBuilder.current().add(_RunIf(cond, name=name)) - def IfNet(self, cond, name=None): - """ - Same as If, but uses 'If' operator instead of execution step logic - """ - return NetBuilder.current().add(_RunIfNet(cond, name=name)) - - def Else(self, name=None): - """ - Else branch of IfNet, has to be specified immediately after IfNet. - Example: - with ops.IfNet(ops.LT([x, y])): - ... - with ops.Else(): - ... - """ - return _RunElseNet(name=name) - - def WhileNet(self, name=None): - """ - NetBuilder for 'While' control operator - """ - return NetBuilder.current().add(_RunWhileNet(name=name)) - - def Condition(self, name=None): - """ - Loop's condition, executed within WhileNet context - """ - assert isinstance(NetBuilder.current(), _RunWhileNet), \ - "Use of Condition outside of WhileNet" - return _RunWhileCondition(name=name) - def task_init(self): """ Defines operations that will be executed once at task startup. @@ -599,139 +476,3 @@ class _RunIf(_RunOnce): assert not self._is_else, 'Elif not allowed for an Else.' return NetBuilder.current().add( _RunIf(name=name or self.name, _already_ran=self._already_ran)) - - -class _RunIfNet(NetBuilder): - """ - Generates a single net that uses If operator - """ - def __init__(self, cond_blob, name=None): - NetBuilder.__init__(self, name=name, _use_control_ops=True) - assert cond_blob, 'Conditional blob is not specified for an If net' - self._cond_blob = cond_blob - self._then_net = None - self._else_net = None - - def add(self, child): - return NetBuilder.add(self, child) - - def __exit__(self, type, *args): - if type is None: - _then_nets = self._children - self._reset_children() - - self._then_net = NetBuilder.merge_nets( - _then_nets, self._collect_lexical_scopes()) - if not self._then_net: - self._then_net = core.Net('empty_then_net') - - if_net = core.Net(self.name + '/if_net') - add_if_op(if_net, self._cond_blob, self._collect_lexical_scopes(), - self._then_net, self._else_net) - - self._current_net = if_net - self._children = [if_net] - NetBuilder.__exit__(self, type, *args) - - -class _RunElseNet(NetBuilder): - """ - Else branch for _RunIfNet builder - """ - def __init__(self, name=None): - NetBuilder.__init__(self, name=name, _use_control_ops=True) - assert self._parent and len(self._parent._children) > 0 and \ - isinstance(self._parent._children[-1], _RunIfNet), \ - 'Invalid use of Else builder' - self._if_builder = self._parent._children[-1] - - def __exit__(self, type, *args): - if type is None: - _else_nets = self._children - self._reset_children() - - self._if_builder._else_net = NetBuilder.merge_nets( - _else_nets, self._collect_lexical_scopes()) - if self._if_builder._else_net: - if_else_net = core.Net(self.name + '/if_else_net') - add_if_op( - if_else_net, - self._if_builder._cond_blob, - self._collect_lexical_scopes(), - self._if_builder._then_net, - self._if_builder._else_net) - self._if_builder._current_net = if_else_net - self._if_builder._children = [if_else_net] - NetBuilder.__exit__(self, type, *args) - - -class _RunWhileNet(NetBuilder): - """ - Generates a single net that uses While operator - """ - def __init__(self, name=None): - NetBuilder.__init__(self, name=name, _use_control_ops=True) - self._cond_builder = None - - def __exit__(self, type, *args): - if type is None: - assert self._cond_builder, \ - 'Condition builder must be specified in While op' - - _cond_blob = self._cond_builder._cond_blob - _cond_net = self._cond_builder._cond_net - - loop_body = self._children - self._reset_children() - loop_body_net = NetBuilder.merge_nets( - loop_body, self._collect_lexical_scopes()) - if not loop_body_net: - loop_body_net = core.Net('empty_loop_body_net') - - while_net = core.Net(self.name + '/while_net') - add_while_op(while_net, _cond_blob, self._collect_lexical_scopes(), - loop_body_net, _cond_net) - - self._current_net = while_net - self._children = [while_net] - NetBuilder.__exit__(self, type, *args) - - -class _RunWhileCondition(NetBuilder): - """ - Computes loop's condition, used in the context of WhileNet. - Last operator must have a single scalar boolean output that will be used - as a condition value, no other blobs created in the condition net are - visible outside of it - """ - def __init__(self, name=None): - NetBuilder.__init__(self, name=name, _use_control_ops=True) - assert self._parent and isinstance(self._parent, _RunWhileNet), \ - 'Invalid use of loop condition builder' - self._cond_blob = None - self._cond_net = None - - def __enter__(self): - builder = NetBuilder.__enter__(self) - assert not self._parent._cond_builder, \ - 'Multiple loop condition builders specified' - assert len(self._parent._children) == 0, \ - 'Condition definition must be specified before the loop\'s body' - self._parent._cond_builder = self - return builder - - def __exit__(self, type, *args): - if type is None: - condition_body = self._children - self._reset_children() - self._cond_net = NetBuilder.merge_nets( - condition_body, self._collect_lexical_scopes()) - assert self._cond_net, 'Invalid loop condition specified' - assert len(self._cond_net.Proto().op) > 0, 'Invalid condition net' - last_op = self._cond_net.Proto().op[-1] - assert len(last_op.output) == 1, 'Invalid condition net' - self._cond_blob = core.BlobReference(name=last_op.output[0], net=None) - - self._current_net = self._cond_net - self._children = [self._cond_net] - NetBuilder.__exit__(self, type, *args) diff --git a/caffe2/python/net_builder_test.py b/caffe2/python/net_builder_test.py index 169419c5c1..1d29711ec9 100644 --- a/caffe2/python/net_builder_test.py +++ b/caffe2/python/net_builder_test.py @@ -245,88 +245,3 @@ class TestNetBuilder(unittest.TestCase): self.assertEquals(total1.fetch(), NUM_INSTANCES * NUM_ITERS) self.assertEquals(total2.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2)) self.assertEquals(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2)) - - def test_if_net(self): - with NetBuilder() as nb: - x0 = ops.Const(0) - x1 = ops.Const(1) - x2 = ops.Const(2) - y0 = ops.Const(0) - y1 = ops.Const(1) - y2 = ops.Const(2) - - # basic logic - first_res = ops.Const(0) - with ops.IfNet(ops.Const(True)): - ops.Const(1, blob_out=first_res) - with ops.Else(): - ops.Const(2, blob_out=first_res) - - second_res = ops.Const(0) - with ops.IfNet(ops.Const(False)): - ops.Const(1, blob_out=second_res) - with ops.Else(): - ops.Const(2, blob_out=second_res) - - # nested and sequential ifs, - # empty then/else, - # passing outer blobs into branches, - # writing into outer blobs, incl. into input blob - # using local blobs - with ops.IfNet(ops.LT([x0, x1])): - local_blob = ops.Const(900) - ops.Add([ops.Const(100), local_blob], [y0]) - - gt = ops.GT([x1, x2]) - with ops.IfNet(gt): - # empty then - pass - with ops.Else(): - ops.Add([y1, local_blob], [local_blob]) - ops.Add([ops.Const(100), y1], [y1]) - - with ops.IfNet(ops.EQ([local_blob, ops.Const(901)])): - ops.Const(7, blob_out=y2) - ops.Add([y1, y2], [y2]) - with ops.Else(): - # empty else - pass - - plan = Plan('if_net_test') - plan.AddStep(to_execution_step(nb)) - ws = workspace.C.Workspace() - ws.run(plan) - - first_res_value = ws.blobs[str(first_res)].fetch() - second_res_value = ws.blobs[str(second_res)].fetch() - y0_value = ws.blobs[str(y0)].fetch() - y1_value = ws.blobs[str(y1)].fetch() - y2_value = ws.blobs[str(y2)].fetch() - - self.assertEquals(first_res_value, 1) - self.assertEquals(second_res_value, 2) - self.assertEquals(y0_value, 1000) - self.assertEquals(y1_value, 101) - self.assertEquals(y2_value, 108) - self.assertTrue(str(local_blob) not in ws.blobs) - - def test_while_net(self): - with NetBuilder() as nb: - x = ops.Const(0) - y = ops.Const(0) - with ops.WhileNet(): - with ops.Condition(): - ops.Add([x, ops.Const(1)], [x]) - ops.LT([x, ops.Const(7)]) - ops.Add([x, y], [y]) - - plan = Plan('while_net_test') - plan.AddStep(to_execution_step(nb)) - ws = workspace.C.Workspace() - ws.run(plan) - - x_value = ws.blobs[str(x)].fetch() - y_value = ws.blobs[str(y)].fetch() - - self.assertEqual(x_value, 7) - self.assertEqual(y_value, 21) |