diff options
author | Dmytro Dzhulgakov <dzhulgakov@fb.com> | 2018-10-02 00:31:42 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-10-02 00:43:40 -0700 |
commit | 1d3f650ce4f1b781f03e8b4f250a25d5a8f819cc (patch) | |
tree | 0cd6434d59568e469549a97524fffe0ae73129a9 | |
parent | ff608a9ff3edded33764c8631427e92c7288bafb (diff) | |
download | pytorch-1d3f650ce4f1b781f03e8b4f250a25d5a8f819cc.tar.gz pytorch-1d3f650ce4f1b781f03e8b4f250a25d5a8f819cc.tar.bz2 pytorch-1d3f650ce4f1b781f03e8b4f250a25d5a8f819cc.zip |
Revert D10098106: [pytorch][PR] [WIP] New version of PT1 model format
Differential Revision:
D10098106
Original commit changeset: 94ec7fc57c84
fbshipit-source-id: 38f729b0970618f38359797b806cbbcd865f4715
-rw-r--r-- | caffe2/core/blob_serialization.cc | 15 | ||||
-rw-r--r-- | caffe2/proto/caffe2.proto | 86 | ||||
-rw-r--r-- | caffe2/proto/torch.proto | 564 | ||||
-rw-r--r-- | caffe2/python/convert.py | 56 | ||||
-rw-r--r-- | caffe2/python/convert_test.py | 234 | ||||
-rw-r--r-- | caffe2/python/pybind_state.cc | 43 | ||||
-rw-r--r-- | caffe2/python/workspace.py | 4 | ||||
-rw-r--r-- | caffe2/utils/proto_convert.cc | 181 | ||||
-rw-r--r-- | caffe2/utils/proto_convert.h | 14 |
9 files changed, 1058 insertions, 139 deletions
diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index f27d16adf3..8126b3d594 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -309,12 +309,6 @@ void TensorSerializer::Serialize( proto.add_string_data(SerializeBlob(temp_blob, "")); } } break; - case TensorProto_DataType_SPECIAL: { - CAFFE_THROW("SPECIAL Tensor is not handled yet."); - } break; - case TensorProto_DataType_NO_CONTENT: { - CAFFE_THROW("NO_CONTENT Tensor should not be serialized."); - } break; // Note: we intentially do not provide "default:" so if any new data types // are added, the compiler should warn the user to add the case here. } @@ -526,14 +520,7 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) { (i + chunkBegin) * temp_blob.meta().itemsize(), 1); } - } break; - case TensorProto_DataType_SPECIAL: { - CAFFE_THROW("SPECIAL Tensor is not handled yet."); - } break; - case TensorProto_DataType_NO_CONTENT: { - CAFFE_THROW("NO_CONTENT Tensor should not be deserialized."); - } break; - // Note: we intentially do not provide "default:" so if any new data types + } } context->FinishDeviceComputation(); } diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto index 9dc745edbd..7187001029 100644 --- a/caffe2/proto/caffe2.proto +++ b/caffe2/proto/caffe2.proto @@ -15,46 +15,23 @@ package caffe2; message TensorProto { // The dimensions in the tensor. repeated int64 dims = 1; - // The strides of the tensor. - repeated int64 strides = 12; - - // Data type enum DataType { UNDEFINED = 0; - - // Basic types - FLOAT = 1; // float - INT32 = 2; // int - BYTE = 3; // byte, when deserialized, is going to be restored as uint8 - STRING = 4; // string - - // Less-commonly used data types - BOOL = 5; // bool - UINT8 = 6; // uint8_t - INT8 = 7; // int8_t - UINT16 = 8; // uint16_t - INT16 = 9; // int16_t - INT64 = 10; // int64_t + FLOAT = 1; // float + INT32 = 2; // int + BYTE = 3; // BYTE, when deserialized, is going to be restored as uint8. + STRING = 4; // string + // Less-commonly used data types. + BOOL = 5; // bool + UINT8 = 6; // uint8_t + INT8 = 7; // int8_t + UINT16 = 8; // uint16_t + INT16 = 9; // int16_t + INT64 = 10; // int64_t FLOAT16 = 12; // at::Half - DOUBLE = 13; // double - - // Special data type, type information is stored in the special type field - SPECIAL = 51; - // Use TensorProto to specify the shape and type - NO_CONTENT = 52; + DOUBLE = 13; // double } optional DataType data_type = 2 [default = FLOAT]; - // if data_type is SPECIAL, use this field to express the type info - optional SpecialType special_type = 13; - - // Data storage - enum StorageType { - TYPED = 1; - RAW = 2; - EXTERNAL = 3; - ALIAS = 4; - } - optional StorageType storage_type = 14 [default = TYPED]; // For float repeated float float_data = 3 [packed = true]; // For int32, uint8, int8, uint16, int16, bool, and float16 @@ -69,13 +46,6 @@ message TensorProto { repeated double double_data = 9 [packed = true]; // For int64 repeated int64 int64_data = 10 [packed = true]; - // For raw data - optional bytes raw_data = 15; - // External data by file name - optional string external_data = 16; - // For argument, to share the content - optional string alias = 17; - // Optionally, a name for the tensor. optional string name = 7; @@ -83,23 +53,13 @@ message TensorProto { // it was serialized from. This is useful in cases like snapshotting a whole // workspace in a multi-GPU environment. optional DeviceOption device_detail = 8; - // When loading from chunks this is going to indicate where to put data in the // full array. When not used full data have to be present message Segment { required int64 begin = 1; required int64 end = 2; - optional int64 chunk_num = 51; - optional int64 chunk_id = 52; } optional Segment segment = 11; - optional string debug_info = 18; - - // For PyTorch serialized tensor. - optional bool require_gradient = 19; - optional bool is_buffer = 20; - - repeated Argument annotations = 21; } message QTensorProto { @@ -126,11 +86,7 @@ message TensorShape { repeated int32 unknown_dims = 3; optional bool unknown_shape = 4 [default = false]; optional string name = 5; -} -// This is prepared for non-tensor types. -message SpecialType { - optional string name = 1; } message TensorShapes { @@ -141,17 +97,13 @@ message TensorShapes { // values, or repeated float, int and string arrays. message Argument { optional string name = 1; - optional float f = 2; optional int64 i = 3; optional bytes s = 4; - optional TensorProto t = 10; optional NetDef n = 8; - repeated float floats = 5; repeated int64 ints = 6; repeated bytes strings = 7; - repeated TensorProto tensors = 11; repeated NetDef nets = 9; } @@ -200,11 +152,7 @@ message DeviceOption { // Operator Definition. message OperatorDef { repeated string input = 1; // the name of the input blobs - // the input name in the schema, for named inputs - repeated string mapped_inputs = 11; repeated string output = 2; // the name of output top blobs - // the outputname in the schema, for named outputs - repeated string mapped_outputs = 12; optional string name = 3; // the operator name. This is optional. // the operator type. This is needed to create the object from the operator // registry. @@ -238,16 +186,6 @@ message OperatorDef { // This is an optional string with no assumed characteristics as // operators can be constructed in any language. optional string debug_info = 10; - - // additional annotations - repeated Argument annotations = 13; - - // for jit ir exporting - optional string aten_function = 14; - - // for operator versioning - optional string domain = 15; - optional string op_version = 16; } // Network definition. diff --git a/caffe2/proto/torch.proto b/caffe2/proto/torch.proto index f31c3b65ec..43dfd02b14 100644 --- a/caffe2/proto/torch.proto +++ b/caffe2/proto/torch.proto @@ -4,77 +4,547 @@ import "caffe2/proto/caffe2.proto"; package torch; -enum ProtoVersion { +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of +// protobuf that is compatible with both protobuf v2 and v3. This means that we +// do not use any protobuf features that are only available in one of the two +// versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. _START_VERSION = 0; - IR_VERSION_NEWEST = 0x0000000000000101; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_NEWEST_ONNX = 0x0000000000000003; + + // PYTORCH IR VERSION + IR_VERSION_NEWEST = 0x0000000000000103; } -message MethodDef { - // method name - optional string name = 1; // method name +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute - // static graph - optional caffe2.NetDef graph = 2; - // method is represented as torch script - optional string torch_script = 3; + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; - // the names of inputs and outputs - repeated string inputs = 4; - repeated string outputs = 5; + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; - // whether this method is main or not. - // by default, `forward` should the main method. - optional bool is_main = 6; + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use - optional string debug_info = 7; + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph - repeated caffe2.Argument annotations = 8; + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph } +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value -message ModuleDef { - repeated ModuleDef submodules = 1; + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + optional string name = 3; // namespace Node - // We suppose to store the modules in one of the following format: - // - methods (static graph or torch script) - // - pickle - // - cpp_arena - repeated MethodDef methods = 2; - // because the old pickle modules may not be supported by torch_script, - // have to stored as pickle_arena at this moment. - optional bytes pickle_arena = 3; - // should be exposed by the Class Archive, so user can save - // module specific data which cannot be store in the graph or torch_script - optional bytes cpp_arena = 4; + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain - // the names of inputs and outputs of the module are inferred - // from the main method. + // Additional named attributes. + repeated AttributeProto attribute = 5; - optional string debug_info = 5; + // A human-readable documentation for this node. Markdown is allowed. + // Equivalent to string debug_info + optional string doc_string = 6; - repeated caffe2.Argument annotations = 6; + // Additional annotations, attributes are defined in Schema + // To be added as annotations: + // string engine + // string list control_input + // int64 is_gradient_op + repeated AttributeProto annotations = 8; + + // Besides the node type, PyTorhc also serialize ATen function signature + optional caffe2.DeviceOption device_option = 51; + optional string aten_function = 52; } -message ModelDef { +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +// +// Model ==> Caffe2 MetaNetDef +// ==> PyTorch Module +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. optional int64 ir_version = 1; - // main module of the model - optional ModuleDef main_module = 2; + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; - repeated caffe2.TensorProto parameters = 3; - repeated caffe2.TensorProto value_infos = 4; + // The parameterized graph that is evaluated to execute the model. + // The main graph, in single graph case, it is ONNX compatible. + optional GraphProto graph = 7; - // to distinguish whether exported from c2 or torch - optional string producer_name = 5; + // The remaining nets in MetaNetDef. + // Submodules and methods in PyTorch. + repeated GraphProto methods = 15; + + // Named metadata values; keys should be distinct. + // Many meta data in MetaNetDef and preditor are piggy backed here. + // 1) project + // 2) model_class + // 3) internal_version + // 4) predictor_type + // 5) predictor_id + // 6) execute_plan + // 7) applicationSpecificInfo (another string map, need to verify it has no duplicate.) + // 8) engine + // 9) publish time + repeated StringStringEntryProto metadata_props = 14; + + // Model name + optional string name = 16; + + // Model name + repeated AttributeProto annotations = 17; + + // Mapping from list name to blob name list, must be string list type. + // Equivalent to blobs in MetaNetDef. + repeated AttributeProto blob_lists = 51; + + // Mapping from plan name to serialized plan, must be string list type. + // Equivalent to plans in MetaNetDef. + repeated AttributeProto plans = 52; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value= 2; +}; + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +// Graph ==> NetDef in Caffe2 +// ==> Submodule/Method in PyTorch +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // also appears in the input list. + repeated TensorProto initializer = 5; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // Additional annotations. + repeated AttributeProto annotations = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} - // put build version here - optional string producer_version = 6; +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool - optional string name = 7; + // Advanced types + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + // Future extensions go here. - optional string debug_info = 8; + // Special data type, real type information is stored in ValueInfoProto. + // If data_type is SPECIAL, raw_data should be used. + SPECIAL = 51; + } + + // The shape of the tensor. + repeated int64 dims = 1; + repeated int64 strides = 14; + + // The data type of the tensor. + optional DataType data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + optional int64 chuck_num = 51; + optional int64 chuck_id = 52; + } + // Used as offset in the external shared data. + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and Half values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // For double + // Complex64 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; + + // External data by file name + optional string external_data = 13; + + // If two tensors represent the same weights/content, use alias. + // Must exist a TensorProto named alias in the initializer list. + // To avoid the duplicate tensor in attribute, such as value in Constant node. + // This is useful, if everything is stored just in the proto. + optional string alias = 16; + + // Additional annotations. + repeated AttributeProto annotations = 17; + + // Device info + optional caffe2.DeviceOption device_option = 51; + + // For PyTorch serialized tensor. + optional int64 require_gradient = 52; + optional int64 is_buffer = 53; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + optional string denotation = 3; + }; + // To represent a scalar, using no dim to represent 0-d tensor. + repeated Dimension dim = 1; + + repeated Dimension stride = 51; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST be present for this version of the IR. + optional TensorProto.DataType elem_type = 1; + optional TensorShapeProto shape = 2; + } + + // Sequence type: List, Tuple + message Sequence { + // elem_type and elem_type_list cannot appear together. + // If all the element types are the same, we use elem_type, + // otherwise, we specify the type of each element in elem_type_list. + optional TypeProto elem_type = 1; + repeated TypeProto elem_type_list = 51; + enum SequenceType { + UNDEFINED = 0; + LIST = 1; + TUPLE = 2; + } + optional SequenceType sequence_type = 52; + } + + // Map<K, V>, (not necessary at this moment) + message Map { + optional TensorProto.DataType key_type = 1; + optional TypeProto value_type = 2; + } + + // Special type of blobs, based on the type_name, we can choose the right + // serializer and deserialzier. + message SpecialBlob { + optional string type_name = 1; + } + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + Sequence sequence_type = 4; + Map map_type = 5; + SpecialBlob special_type = 51; + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; +} - // annotations, it is used for MetaNetDef's metadata - repeated caffe2.Argument annotations = 9; +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; } diff --git a/caffe2/python/convert.py b/caffe2/python/convert.py index 44f81d6e2d..50eaf220c7 100644 --- a/caffe2/python/convert.py +++ b/caffe2/python/convert.py @@ -8,3 +8,59 @@ from __future__ import unicode_literals from caffe2.proto import caffe2_pb2, torch_pb2 import caffe2.python._import_c_extension as C + + +def ArgumentToAttributeProto(arg): + serialized_arg = None + if hasattr(arg, 'SerializeToString') and callable(arg.SerializeToString): + serialized_arg = arg.SerializeToString() + elif isinstance(arg, bytes): + serialized_arg = arg + else: + raise ValueError('No SerializeToString method is detected. ' + 'neither arg is bytes.\ntype is {}'.format(type(arg))) + attr = torch_pb2.AttributeProto() + attr.ParseFromString(C.argument_to_attribute_proto(serialized_arg)) + return attr + + +def AttributeProtoToArgument(attr): + serialized_attr = None + if hasattr(attr, 'SerializeToString') and callable(attr.SerializeToString): + serialized_attr = attr.SerializeToString() + elif isinstance(attr, bytes): + serialized_attr = attr + else: + raise ValueError('No SerializeToString method is detected. ' + 'neither attr is bytes.\ntype is {}'.format(type(attr))) + arg = caffe2_pb2.Argument() + arg.ParseFromString(C.attribute_proto_to_argument(serialized_attr)) + return arg + + +def OperatorDefToNodeProto(op_def): + serialized_op_def = None + if hasattr(op_def, 'SerializeToString') and callable(op_def.SerializeToString): + serialized_op_def = op_def.SerializeToString() + elif isinstance(op_def, bytes): + serialized_op_def = op_def + else: + raise ValueError('No SerializeToString method is detected. ' + 'neither op_def is bytes.\ntype is {}'.format(type(op_def))) + node = torch_pb2.NodeProto() + node.ParseFromString(C.operator_def_to_node_proto(serialized_op_def)) + return node + + +def NodeProtoToOperatorDef(node_proto): + serialized_node_proto = None + if hasattr(node_proto, 'SerializeToString') and callable(node_proto.SerializeToString): + serialized_node_proto = node_proto.SerializeToString() + elif isinstance(node_proto, bytes): + serialized_node_proto = node_proto + else: + raise ValueError('No SerializeToString method is detected. ' + 'neither node_proto is bytes.\ntype is {}'.format(type(node_proto))) + op_def = caffe2_pb2.OperatorDef() + op_def.ParseFromString(C.node_proto_to_operator_def(serialized_node_proto)) + return op_def diff --git a/caffe2/python/convert_test.py b/caffe2/python/convert_test.py index 82c969c901..c8de7e9750 100644 --- a/caffe2/python/convert_test.py +++ b/caffe2/python/convert_test.py @@ -12,5 +12,239 @@ class TestOperator(unittest.TestCase): def setUp(self): workspace.ResetWorkspace() + def testArgument2AttributeProto(self): + arg_f = caffe2_pb2.Argument() + arg_f.name = "TestArgF" + arg_f.f = 10.0 + attr_f = convert.ArgumentToAttributeProto(arg_f) + self.assertEqual(attr_f.name, arg_f.name) + self.assertEqual(attr_f.f, arg_f.f) + + arg_i = caffe2_pb2.Argument() + arg_i.name = "TestArgI" + arg_i.i = 100 + attr_i = convert.ArgumentToAttributeProto(arg_i) + self.assertEqual(attr_i.name, arg_i.name) + self.assertEqual(attr_i.i, arg_i.i) + + arg_s = caffe2_pb2.Argument() + arg_s.name = "TestArgS" + arg_s.s = "TestS".encode("utf-8") + attr_s = convert.ArgumentToAttributeProto(arg_s) + self.assertEqual(attr_s.name, arg_s.name) + self.assertEqual(attr_s.s, arg_s.s) + + # TODO: test net arg + + arg_floats = caffe2_pb2.Argument() + arg_floats.name = "TestArgFloats" + arg_floats.floats.extend([10.0, 11.0, 12.0]) + attr_floats = convert.ArgumentToAttributeProto(arg_floats) + self.assertEqual(attr_floats.name, arg_floats.name) + self.assertEqual(attr_floats.floats, arg_floats.floats) + + arg_ints = caffe2_pb2.Argument() + arg_ints.name = "TestArgInts" + arg_ints.ints.extend([100, 101, 102]) + attr_ints = convert.ArgumentToAttributeProto(arg_ints) + self.assertEqual(attr_ints.name, arg_ints.name) + self.assertEqual(attr_ints.ints, arg_ints.ints) + + arg_strings = caffe2_pb2.Argument() + arg_strings.name = "TestArgStrings" + arg_strings.strings.extend([ + "TestStrings1".encode("utf-8"), + "TestStrings2".encode("utf-8"), + ]) + attr_strings = convert.ArgumentToAttributeProto(arg_strings) + self.assertEqual(attr_strings.name, arg_strings.name) + self.assertEqual(attr_strings.strings, arg_strings.strings) + + # TODO: test nets arg + + def testAttributeProto2Argument(self): + attr_f = torch_pb2.AttributeProto() + attr_f.type = torch_pb2.AttributeProto.FLOAT + attr_f.name = "TestAttrF" + attr_f.f = 10.0 + arg_f = convert.AttributeProtoToArgument(attr_f) + self.assertEqual(arg_f.name, attr_f.name) + self.assertEqual(arg_f.f, attr_f.f) + + attr_i = torch_pb2.AttributeProto() + attr_i.type = torch_pb2.AttributeProto.INT + attr_i.name = "TestArgI" + attr_i.i = 100 + arg_i = convert.AttributeProtoToArgument(attr_i) + self.assertEqual(arg_i.name, attr_i.name) + self.assertEqual(arg_i.i, attr_i.i) + + attr_s = torch_pb2.AttributeProto() + attr_s.type = torch_pb2.AttributeProto.STRING + attr_s.name = "TestArgS" + attr_s.s = "TestS".encode("utf-8") + arg_s = convert.AttributeProtoToArgument(attr_s) + self.assertEqual(arg_s.name, attr_s.name) + self.assertEqual(arg_s.s, attr_s.s) + + # TODO: test graph attribute + + attr_floats = torch_pb2.AttributeProto() + attr_floats.type = torch_pb2.AttributeProto.FLOATS + attr_floats.name = "TestAttrFloats" + attr_floats.floats.extend([10.0, 11.0, 12.0]) + arg_floats = convert.AttributeProtoToArgument(attr_floats) + self.assertEqual(arg_floats.name, attr_floats.name) + self.assertEqual(arg_floats.floats, attr_floats.floats) + + attr_ints = torch_pb2.AttributeProto() + attr_ints.type = torch_pb2.AttributeProto.INTS + attr_ints.name = "TestArgInts" + attr_ints.ints.extend([100, 101, 102]) + arg_ints = convert.AttributeProtoToArgument(attr_ints) + self.assertEqual(arg_ints.name, attr_ints.name) + self.assertEqual(arg_ints.ints, attr_ints.ints) + + attr_strings = torch_pb2.AttributeProto() + attr_strings.type = torch_pb2.AttributeProto.STRINGS + attr_strings.name = "TestArgStrings" + attr_strings.strings.extend([ + "TestStrings1".encode("utf-8"), + "TestStrings2".encode("utf-8"), + ]) + arg_strings = convert.AttributeProtoToArgument(attr_strings) + self.assertEqual(arg_strings.name, attr_strings.name) + self.assertEqual(arg_strings.strings, attr_strings.strings) + + # TODO: test graphs attribute + + + def testOperatorDef2NodeProto(self): + op_def = caffe2_pb2.OperatorDef() + op_def.input.extend(["A", "B", "C"]) + op_def.output.extend(["X", "Y"]) + op_def.name = "TestOpName" + op_def.type = "TestOp" + arg1 = caffe2_pb2.Argument() + arg1.name = "TestArg1" + arg1.i = 1 + arg2 = caffe2_pb2.Argument() + arg2.name = "TestArg2" + arg1.s = "TestInfo".encode("utf-8") + op_def.arg.extend([arg1, arg2]) + op_def.device_option.CopyFrom(caffe2_pb2.DeviceOption()) + op_def.engine = "TestEngine".encode("utf-8") + op_def.control_input.extend(["input1", "input2"]) + op_def.is_gradient_op = True + op_def.debug_info = "TestDebugInfo" + + node = convert.OperatorDefToNodeProto(op_def) + + self.assertEqual(node.input, op_def.input) + self.assertEqual(node.output, op_def.output) + self.assertEqual(node.name, op_def.name) + self.assertEqual(node.op_type, op_def.type) + self.assertEqual(node.attribute[0].name, op_def.arg[0].name) + self.assertEqual(node.attribute[1].name, op_def.arg[1].name) + self.assertEqual(node.device_option, op_def.device_option) + node_engine = [a.s.decode("utf-8") for a in node.annotations if a.name == "engine"][0] + self.assertEqual(node_engine, op_def.engine) + node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0] + self.assertEqual(len(node_control_input), len(op_def.control_input)) + for x, y in zip(node_control_input, op_def.control_input): + self.assertEqual(x.decode("utf-8"), y) + self.assertEqual(node.doc_string, op_def.debug_info) + node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0] + self.assertEqual(node_is_gradient_op, int(op_def.is_gradient_op)) + + def testNodeProto2OperatorDef(self): + node = torch_pb2.NodeProto() + node.input.extend(["A", "B", "C"]) + node.output.extend(["X", "Y"]) + node.name = "TestOpName" + node.op_type = "TestOp" + attr1 = torch_pb2.AttributeProto() + attr1.name = "TestAttr1" + attr1.type = torch_pb2.AttributeProto.STRING + attr1.s = "TestInfo".encode("utf-8") + attr2 = torch_pb2.AttributeProto() + attr2.name = "TestAttr2" + attr2.type = torch_pb2.AttributeProto.INT + attr2.i = 10 + node.attribute.extend([attr1, attr2]) + node.device_option.CopyFrom(caffe2_pb2.DeviceOption()) + anno1 = torch_pb2.AttributeProto() + anno1.name = "engine" + anno1.type = torch_pb2.AttributeProto.STRING + anno1.s = "TestEngine".encode("utf-8") + anno2 = torch_pb2.AttributeProto() + anno2.name = "control_input" + anno2.type = torch_pb2.AttributeProto.STRINGS + anno2.strings.extend(["input1".encode("utf-8"), "input2".encode("utf-8")]) + anno3 = torch_pb2.AttributeProto() + anno3.name = "is_gradient_op" + anno3.type = torch_pb2.AttributeProto.INT + anno3.i = 1 + node.annotations.extend([anno1, anno2, anno3]) + node.doc_string = "TestDocString".encode("utf-8") + + op_def = convert.NodeProtoToOperatorDef(node) + + self.assertEqual(op_def.input, node.input) + self.assertEqual(op_def.output, node.output) + self.assertEqual(op_def.name, node.name) + self.assertEqual(op_def.type, node.op_type) + self.assertEqual(op_def.arg[0].name, node.attribute[0].name) + self.assertEqual(op_def.arg[1].name, node.attribute[1].name) + self.assertEqual(op_def.device_option, node.device_option) + node_engine = [a.s for a in node.annotations if a.name == "engine"][0] + self.assertEqual(op_def.engine, node_engine.decode("utf-8")) + node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0] + for x, y in zip(op_def.control_input, node_control_input): + self.assertEqual(x, y.decode("utf-8")) + self.assertEqual(op_def.debug_info, node.doc_string) + node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0] + self.assertEqual(int(op_def.is_gradient_op), node_is_gradient_op) + + def testEnd2End(self): + op_def = caffe2_pb2.OperatorDef() + op_def.type = "Add" + op_def.input.extend(["input1"]) + op_def.input.extend(["input2"]) + op_def.output.extend(["output1"]) + node = convert.OperatorDefToNodeProto(op_def) + + input1 = np.random.randn(1, 3, 1, 5).astype(np.float32) + input2 = np.random.randn(2, 1, 4, 1).astype(np.float32) + ref_output1 = input1 + input2 + workspace.FeedBlob("input1", input1) + workspace.FeedBlob("input2", input2) + self.assertEqual(workspace.RunOperatorOnce(node.SerializeToString(), legacy_proto=False), True) + + self.assertEqual(workspace.HasBlob("output1"), True) + fetched_back = workspace.FetchBlob("output1") + np.testing.assert_array_equal(fetched_back, ref_output1) + + def testRoundTrip(self): + op_def = caffe2_pb2.OperatorDef() + op_def.type = "Add" + op_def.input.extend(["input1"]) + op_def.input.extend(["input2"]) + op_def.output.extend(["output1"]) + node = convert.OperatorDefToNodeProto(op_def) + new_op_def = convert.NodeProtoToOperatorDef(node) + + input1 = np.random.randn(1, 3, 1, 5).astype(np.float32) + input2 = np.random.randn(2, 1, 4, 1).astype(np.float32) + ref_output1 = input1 + input2 + workspace.FeedBlob("input1", input1) + workspace.FeedBlob("input2", input2) + self.assertEqual(workspace.RunOperatorOnce(new_op_def.SerializeToString()), True) + + self.assertEqual(workspace.HasBlob("output1"), True) + fetched_back = workspace.FetchBlob("output1") + np.testing.assert_array_equal(fetched_back, ref_output1) + + if __name__ == '__main__': unittest.main() diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 7ebee57d49..7062ead045 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -1187,10 +1187,17 @@ void addGlobalMethods(py::module& m) { return true; }); m.def("nets", []() { return gWorkspace->Nets(); }); - m.def("run_operator_once", [](const py::bytes& op_def) { + m.def("run_operator_once", [](const py::bytes& op_def, bool legacy_proto=true) { CAFFE_ENFORCE(gWorkspace); OperatorDef def; - CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast<std::string>(), &def)); + if (legacy_proto) { + CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast<std::string>(), &def)); + } else { + ::torch::NodeProto node; + CAFFE_ENFORCE( + ParseProtoFromLargeString(op_def.cast<std::string>(), &node)); + NodeProtoToOperatorDef(node, &def); + } py::gil_scoped_release g; CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def)); return true; @@ -1527,6 +1534,38 @@ void addGlobalMethods(py::module& m) { CAFFE_ENFORCE(blob); return BlobStat::sizeBytes(*blob); }); + m.def("argument_to_attribute_proto", [](py::bytes arg_str) -> py::bytes { + Argument arg; + CAFFE_ENFORCE( + ParseProtoFromLargeString(arg_str.cast<std::string>(), &arg)); + ::torch::AttributeProto attr; + ArgumentToAttributeProto(arg, &attr); + return attr.SerializeAsString(); + }); + m.def("attribute_proto_to_argument", [](py::bytes attr_str) -> py::bytes { + ::torch::AttributeProto attr; + CAFFE_ENFORCE( + ParseProtoFromLargeString(attr_str.cast<std::string>(), &attr)); + Argument arg; + AttributeProtoToArgument(attr, &arg); + return arg.SerializeAsString(); + }); + m.def("operator_def_to_node_proto", [](py::bytes op_str) -> py::bytes { + OperatorDef op_def; + CAFFE_ENFORCE( + ParseProtoFromLargeString(op_str.cast<std::string>(), &op_def)); + ::torch::NodeProto node; + OperatorDefToNodeProto(op_def, &node); + return node.SerializeAsString(); + }); + m.def("node_proto_to_operator_def", [](py::bytes node_str) -> py::bytes { + ::torch::NodeProto node_proto; + CAFFE_ENFORCE( + ParseProtoFromLargeString(node_str.cast<std::string>(), &node_proto)); + OperatorDef op_def; + NodeProtoToOperatorDef(node_proto, &op_def); + return op_def.SerializeAsString(); + }); m.def("support_onnx_export", [](const std::string& op) -> bool { const OpSchema* schema = caffe2::OpSchemaRegistry::Schema(op); if (!schema) { diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py index ef02f64dc9..a41cc15317 100644 --- a/caffe2/python/workspace.py +++ b/caffe2/python/workspace.py @@ -163,8 +163,8 @@ def GetOperatorCost(operator, blobs): return C.get_operator_cost(StringifyProto(operator), blobs) -def RunOperatorOnce(operator): - return C.run_operator_once(StringifyProto(operator)) +def RunOperatorOnce(operator, legacy_proto=True): + return C.run_operator_once(StringifyProto(operator), legacy_proto) def RunOperatorsOnce(operators): diff --git a/caffe2/utils/proto_convert.cc b/caffe2/utils/proto_convert.cc index 1d69c8c80c..790bd27429 100644 --- a/caffe2/utils/proto_convert.cc +++ b/caffe2/utils/proto_convert.cc @@ -2,4 +2,185 @@ #include "caffe2/core/logging.h" namespace caffe2 { + +C10_EXPORT void ArgumentToAttributeProto( + const Argument& arg, + ::torch::AttributeProto* attr) { + CAFFE_ENFORCE(arg.has_name()); + attr->set_name(arg.name()); + if (arg.has_f()) { + attr->set_f(arg.f()); + } else if (arg.has_i()) { + attr->set_i(arg.i()); + } else if (arg.has_s()) { + attr->set_s(arg.s()); + } else if (arg.has_n()) { + // TODO + CAFFE_THROW("NetDef conversion is not implemented yet."); + } else if (arg.floats_size() > 0) { + attr->mutable_floats()->CopyFrom(arg.floats()); + } else if (arg.ints_size() > 0) { + attr->mutable_ints()->CopyFrom(arg.ints()); + } else if (arg.strings_size() > 0) { + attr->mutable_strings()->CopyFrom(arg.strings()); + } else if (arg.nets_size() > 0) { + // TODO + CAFFE_THROW("NetDefs conversion is not implemented yet."); + } +} + +C10_EXPORT void AttributeProtoToArgument( + const ::torch::AttributeProto& attr, + Argument* arg) { + CAFFE_ENFORCE(attr.has_name()); + arg->set_name(attr.name()); + CAFFE_ENFORCE(attr.has_type()); + const auto type = attr.type(); + if (type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_FLOAT) { + CAFFE_ENFORCE(attr.has_f()); + arg->set_f(attr.f()); + } else if ( + type == + ::torch::AttributeProto_AttributeType::AttributeProto_AttributeType_INT) { + CAFFE_ENFORCE(attr.has_i()); + arg->set_i(attr.i()); + } else if ( + type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_STRING) { + CAFFE_ENFORCE(attr.has_s()); + arg->set_s(attr.s()); + } else if ( + type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_TENSOR) { + CAFFE_THROW("Caffe2's Argument does not support tensor as attribute."); + } else if ( + type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_GRAPH) { + // TODO + CAFFE_THROW("GraphProto conversion is not implemented yet."); + } else if ( + type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_FLOATS) { + arg->mutable_floats()->CopyFrom(attr.floats()); + } else if ( + type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_INTS) { + arg->mutable_ints()->CopyFrom(attr.ints()); + } else if ( + type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_STRINGS) { + arg->mutable_strings()->CopyFrom(attr.strings()); + } else if ( + type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_TENSORS) { + CAFFE_THROW("Caffe2's Argument does not support tensors as attribute."); + } else if ( + type == + ::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_GRAPHS) { + // TODO + CAFFE_THROW("GraphProtos conversion is not implemented yet."); + } else { + CAFFE_THROW("Unknow Attribute type."); + } +} + +C10_EXPORT void OperatorDefToNodeProto( + const OperatorDef& def, + ::torch::NodeProto* node) { + node->mutable_input()->CopyFrom(def.input()); + node->mutable_output()->CopyFrom(def.output()); + if (def.has_name()) { + node->set_name(def.name()); + } + CAFFE_ENFORCE(def.has_type()); + node->set_op_type(def.type()); + for (int i = 0; i < def.arg_size(); ++i) { + auto attr = node->add_attribute(); + ArgumentToAttributeProto(def.arg(i), attr); + } + if (def.has_device_option()) { + node->mutable_device_option()->CopyFrom(def.device_option()); + } + if (def.has_engine()) { + auto attr = node->add_annotations(); + attr->set_name("engine"); + attr->set_type(::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_STRING); + attr->set_s(def.engine()); + } + if (def.control_input_size() > 0) { + auto attr = node->add_annotations(); + attr->set_name("control_input"); + attr->set_type(::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_STRINGS); + attr->mutable_strings()->CopyFrom(def.control_input()); + } + if (def.has_is_gradient_op()) { + auto attr = node->add_annotations(); + attr->set_name("is_gradient_op"); + attr->set_type(::torch::AttributeProto_AttributeType:: + AttributeProto_AttributeType_INT); + if (def.is_gradient_op()) { + attr->set_i(1); + } else { + attr->set_i(0); + } + } + if (def.has_debug_info()) { + node->set_doc_string(def.debug_info()); + } +} + +C10_EXPORT void NodeProtoToOperatorDef( + const ::torch::NodeProto& node, + OperatorDef* def) { + def->mutable_input()->CopyFrom(node.input()); + def->mutable_output()->CopyFrom(node.output()); + if (node.has_name()) { + def->set_name(node.name()); + } + + CAFFE_ENFORCE(node.has_op_type()); + def->set_type(node.op_type()); + for (int i = 0; i < node.attribute_size(); ++i) { + auto arg = def->add_arg(); + AttributeProtoToArgument(node.attribute(i), arg); + } + if (node.has_doc_string()) { + def->set_debug_info(node.doc_string()); + } + for (int i = 0; i < node.annotations_size(); ++i) { + const auto& attr = node.annotations(i); + CAFFE_ENFORCE(attr.has_name()); + if (attr.name() == "engine") { + CAFFE_ENFORCE(attr.has_s()); + def->set_engine(attr.s()); + } else if (attr.name() == "control_input") { + def->mutable_control_input()->CopyFrom(attr.strings()); + } else if (attr.name() == "is_gradient_op") { + CAFFE_ENFORCE(attr.has_i()); + if (i == 0) { + def->set_is_gradient_op(false); + } else { + def->set_is_gradient_op(true); + } + } + auto arg = def->add_arg(); + AttributeProtoToArgument(node.annotations(i), arg); + } + if (node.has_device_option()) { + def->mutable_device_option()->CopyFrom(node.device_option()); + } +} + } // namespace caffe2 diff --git a/caffe2/utils/proto_convert.h b/caffe2/utils/proto_convert.h index 91bcf1bafa..a9ca9c3ad4 100644 --- a/caffe2/utils/proto_convert.h +++ b/caffe2/utils/proto_convert.h @@ -6,6 +6,20 @@ #include "caffe2/proto/torch_pb.h" namespace caffe2 { + +CAFFE2_API void ArgumentToAttributeProto( + const Argument& arg, + ::torch::AttributeProto* attr); +CAFFE2_API void AttributeProtoToArgument( + const ::torch::AttributeProto& attr, + Argument* arg); +CAFFE2_API void OperatorDefToNodeProto( + const OperatorDef& def, + ::torch::NodeProto* node); +CAFFE2_API void NodeProtoToOperatorDef( + const ::torch::NodeProto& node, + OperatorDef* def); + } // namespace caffe2 #endif // CAFFE2_UTILS_PROTO_CONVERT_H_ |