#ifndef CAFFE2_OPERATORS_DO_OP_H_ #define CAFFE2_OPERATORS_DO_OP_H_ #include #include #include #include #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/operators/create_scope_op.h" #include "caffe2/proto/caffe2_pb.h" namespace caffe2 { template class DoOp final : public Operator { public: DoOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), parent_ws_(ws) { CAFFE_ENFORCE( this->template HasSingleArgumentOfType("net"), "net must be specified in Do operator"); net_def_ = this->template GetSingleArgument("net", NetDef()); is_gradient_op_ = operator_def.is_gradient_op(); copy_external_blobs_ = this->template GetSingleArgument("copy_external_blobs", false); reuse_workspace_ = this->template GetSingleArgument("reuse_workspace", false); CAFFE_ENFORCE( !(is_gradient_op_ && reuse_workspace_), "Gradient Do op requires use of stacked workspaces"); CAFFE_ENFORCE( !(copy_external_blobs_ && reuse_workspace_), "Reuse workspace and copy external blobs simultaneously in Do op"); const auto& inner_blobs = this->template GetRepeatedArgument("inner_blobs"); const auto& outer_blobs_idx = this->template GetRepeatedArgument("outer_blobs_idx"); CAFFE_ENFORCE_EQ( inner_blobs.size(), outer_blobs_idx.size(), "Invalid blob bindings: different inner/outer blobs lengths"); const auto& outer_blob_names = checkAndGetOuterNames(operator_def); std::unordered_set used_outer_names; for (size_t blob_idx = 0; blob_idx < inner_blobs.size(); ++blob_idx) { CAFFE_ENFORCE( !blob_bindings_.count(inner_blobs[blob_idx]), "Invalid blob bindings: redefinition of inner blob " + inner_blobs[blob_idx]); CAFFE_ENFORCE( outer_blobs_idx[blob_idx] >= 0 && outer_blobs_idx[blob_idx] < outer_blob_names.size(), "Invalid blob bindings: outer blob index (" + c10::to_string(outer_blobs_idx[blob_idx]) + ", inner name: " + inner_blobs[blob_idx] + ") is out of bounds [0, " + c10::to_string(outer_blob_names.size() - 1) + "]"); const auto& outer_name = outer_blob_names[outer_blobs_idx[blob_idx]]; CAFFE_ENFORCE( !used_outer_names.count(outer_name), "Reusage of outer name: " + outer_name); used_outer_names.insert(outer_name); blob_bindings_[inner_blobs[blob_idx]] = outer_name; forwarded_inner_blobs_.insert(inner_blobs[blob_idx]); } std::unordered_set all_outer_names( outer_blob_names.begin(), outer_blob_names.end()); CAFFE_ENFORCE_EQ( used_outer_names.size(), all_outer_names.size(), "Not all outer names are used in blob bindings"); } USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override { auto* ws_stack = this->template Output(OutputSize() - 1); std::shared_ptr net_workspace; if (is_gradient_op_) { net_workspace = ws_stack->popGradientWorkspace(parent_ws_, blob_bindings_); } else { if (reuse_workspace_ && !ws_stack->empty()) { net_workspace = ws_stack->reuseLastForwardWorkspace(parent_ws_, blob_bindings_); } else { net_workspace = ws_stack->pushForwardWorkspace(parent_ws_, blob_bindings_); } } CAFFE_ENFORCE(net_workspace, "Failed to initialize Do op workspace"); // TODO(iliacher): figure how to reuse existing net with a new workspace auto* net = net_workspace->GetNet(net_def_.name()); if (!net) { net = net_workspace->CreateNet(net_def_, true); } CAFFE_ENFORCE(net, "Failed to initialize subnet"); auto success = net->Run(); if (!is_gradient_op_ && copy_external_blobs_) { net_workspace->template CopyForwardedTensors( forwarded_inner_blobs_); } return success; } private: // returns vector of input blob names followed by output blob names in // operator definition order; ensures that input (output) names are unique, // checks number of input (output) blobs std::vector checkAndGetOuterNames( const OperatorDef& operator_def) const { auto input_names = getInputBlobNames(operator_def); CAFFE_ENFORCE(!input_names.empty(), "Expected at least one input blob"); std::string input_ws_blob = input_names.back(); // copy // removing blob that holds pointer op workspace input_names.pop_back(); std::unordered_set all_input_names( input_names.begin(), input_names.end()); CAFFE_ENFORCE_EQ( input_names.size(), all_input_names.size(), "Duplicate input blobs"); auto output_names = getOutputBlobNames(operator_def); CAFFE_ENFORCE(!output_names.empty(), "Expected at least one output blob"); const auto& output_ws_blob = output_names.back(); CAFFE_ENFORCE_EQ( input_ws_blob, output_ws_blob, "Expected same input/output workspace blob"); // remove blob that holds pointer to op workspace output_names.pop_back(); std::unordered_set all_output_names( output_names.begin(), output_names.end()); CAFFE_ENFORCE_EQ( output_names.size(), all_output_names.size(), "Duplicate output blobs"); std::vector outer_blob_names; outer_blob_names.reserve(input_names.size() + output_names.size()); outer_blob_names.insert( outer_blob_names.end(), input_names.begin(), input_names.end()); outer_blob_names.insert( outer_blob_names.end(), output_names.begin(), output_names.end()); return outer_blob_names; } std::vector getInputBlobNames( const OperatorDef& operator_def) const { std::vector names; names.reserve(operator_def.input_size()); for (auto idx = 0; idx < operator_def.input_size(); ++idx) { names.push_back(operator_def.input(idx)); } return names; } std::vector getOutputBlobNames( const OperatorDef& operator_def) const { std::vector names; names.reserve(operator_def.output_size()); for (auto idx = 0; idx < operator_def.output_size(); ++idx) { names.push_back(operator_def.output(idx)); } return names; } std::unordered_map blob_bindings_; std::unordered_set forwarded_inner_blobs_; bool is_gradient_op_; bool copy_external_blobs_; bool reuse_workspace_; NetDef net_def_; Workspace* parent_ws_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_DO_OP_H_