summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDmytro Dzhulgakov <dzhulgakov@fb.com>2018-10-02 00:31:42 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-10-02 00:43:40 -0700
commit1d3f650ce4f1b781f03e8b4f250a25d5a8f819cc (patch)
tree0cd6434d59568e469549a97524fffe0ae73129a9
parentff608a9ff3edded33764c8631427e92c7288bafb (diff)
downloadpytorch-1d3f650ce4f1b781f03e8b4f250a25d5a8f819cc.tar.gz
pytorch-1d3f650ce4f1b781f03e8b4f250a25d5a8f819cc.tar.bz2
pytorch-1d3f650ce4f1b781f03e8b4f250a25d5a8f819cc.zip
Revert D10098106: [pytorch][PR] [WIP] New version of PT1 model format
Differential Revision: D10098106 Original commit changeset: 94ec7fc57c84 fbshipit-source-id: 38f729b0970618f38359797b806cbbcd865f4715
-rw-r--r--caffe2/core/blob_serialization.cc15
-rw-r--r--caffe2/proto/caffe2.proto86
-rw-r--r--caffe2/proto/torch.proto564
-rw-r--r--caffe2/python/convert.py56
-rw-r--r--caffe2/python/convert_test.py234
-rw-r--r--caffe2/python/pybind_state.cc43
-rw-r--r--caffe2/python/workspace.py4
-rw-r--r--caffe2/utils/proto_convert.cc181
-rw-r--r--caffe2/utils/proto_convert.h14
9 files changed, 1058 insertions, 139 deletions
diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc
index f27d16adf3..8126b3d594 100644
--- a/caffe2/core/blob_serialization.cc
+++ b/caffe2/core/blob_serialization.cc
@@ -309,12 +309,6 @@ void TensorSerializer::Serialize(
proto.add_string_data(SerializeBlob(temp_blob, ""));
}
} break;
- case TensorProto_DataType_SPECIAL: {
- CAFFE_THROW("SPECIAL Tensor is not handled yet.");
- } break;
- case TensorProto_DataType_NO_CONTENT: {
- CAFFE_THROW("NO_CONTENT Tensor should not be serialized.");
- } break;
// Note: we intentially do not provide "default:" so if any new data types
// are added, the compiler should warn the user to add the case here.
}
@@ -526,14 +520,7 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
(i + chunkBegin) * temp_blob.meta().itemsize(),
1);
}
- } break;
- case TensorProto_DataType_SPECIAL: {
- CAFFE_THROW("SPECIAL Tensor is not handled yet.");
- } break;
- case TensorProto_DataType_NO_CONTENT: {
- CAFFE_THROW("NO_CONTENT Tensor should not be deserialized.");
- } break;
- // Note: we intentially do not provide "default:" so if any new data types
+ }
}
context->FinishDeviceComputation();
}
diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto
index 9dc745edbd..7187001029 100644
--- a/caffe2/proto/caffe2.proto
+++ b/caffe2/proto/caffe2.proto
@@ -15,46 +15,23 @@ package caffe2;
message TensorProto {
// The dimensions in the tensor.
repeated int64 dims = 1;
- // The strides of the tensor.
- repeated int64 strides = 12;
-
- // Data type
enum DataType {
UNDEFINED = 0;
-
- // Basic types
- FLOAT = 1; // float
- INT32 = 2; // int
- BYTE = 3; // byte, when deserialized, is going to be restored as uint8
- STRING = 4; // string
-
- // Less-commonly used data types
- BOOL = 5; // bool
- UINT8 = 6; // uint8_t
- INT8 = 7; // int8_t
- UINT16 = 8; // uint16_t
- INT16 = 9; // int16_t
- INT64 = 10; // int64_t
+ FLOAT = 1; // float
+ INT32 = 2; // int
+ BYTE = 3; // BYTE, when deserialized, is going to be restored as uint8.
+ STRING = 4; // string
+ // Less-commonly used data types.
+ BOOL = 5; // bool
+ UINT8 = 6; // uint8_t
+ INT8 = 7; // int8_t
+ UINT16 = 8; // uint16_t
+ INT16 = 9; // int16_t
+ INT64 = 10; // int64_t
FLOAT16 = 12; // at::Half
- DOUBLE = 13; // double
-
- // Special data type, type information is stored in the special type field
- SPECIAL = 51;
- // Use TensorProto to specify the shape and type
- NO_CONTENT = 52;
+ DOUBLE = 13; // double
}
optional DataType data_type = 2 [default = FLOAT];
- // if data_type is SPECIAL, use this field to express the type info
- optional SpecialType special_type = 13;
-
- // Data storage
- enum StorageType {
- TYPED = 1;
- RAW = 2;
- EXTERNAL = 3;
- ALIAS = 4;
- }
- optional StorageType storage_type = 14 [default = TYPED];
// For float
repeated float float_data = 3 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16
@@ -69,13 +46,6 @@ message TensorProto {
repeated double double_data = 9 [packed = true];
// For int64
repeated int64 int64_data = 10 [packed = true];
- // For raw data
- optional bytes raw_data = 15;
- // External data by file name
- optional string external_data = 16;
- // For argument, to share the content
- optional string alias = 17;
-
// Optionally, a name for the tensor.
optional string name = 7;
@@ -83,23 +53,13 @@ message TensorProto {
// it was serialized from. This is useful in cases like snapshotting a whole
// workspace in a multi-GPU environment.
optional DeviceOption device_detail = 8;
-
// When loading from chunks this is going to indicate where to put data in the
// full array. When not used full data have to be present
message Segment {
required int64 begin = 1;
required int64 end = 2;
- optional int64 chunk_num = 51;
- optional int64 chunk_id = 52;
}
optional Segment segment = 11;
- optional string debug_info = 18;
-
- // For PyTorch serialized tensor.
- optional bool require_gradient = 19;
- optional bool is_buffer = 20;
-
- repeated Argument annotations = 21;
}
message QTensorProto {
@@ -126,11 +86,7 @@ message TensorShape {
repeated int32 unknown_dims = 3;
optional bool unknown_shape = 4 [default = false];
optional string name = 5;
-}
-// This is prepared for non-tensor types.
-message SpecialType {
- optional string name = 1;
}
message TensorShapes {
@@ -141,17 +97,13 @@ message TensorShapes {
// values, or repeated float, int and string arrays.
message Argument {
optional string name = 1;
-
optional float f = 2;
optional int64 i = 3;
optional bytes s = 4;
- optional TensorProto t = 10;
optional NetDef n = 8;
-
repeated float floats = 5;
repeated int64 ints = 6;
repeated bytes strings = 7;
- repeated TensorProto tensors = 11;
repeated NetDef nets = 9;
}
@@ -200,11 +152,7 @@ message DeviceOption {
// Operator Definition.
message OperatorDef {
repeated string input = 1; // the name of the input blobs
- // the input name in the schema, for named inputs
- repeated string mapped_inputs = 11;
repeated string output = 2; // the name of output top blobs
- // the outputname in the schema, for named outputs
- repeated string mapped_outputs = 12;
optional string name = 3; // the operator name. This is optional.
// the operator type. This is needed to create the object from the operator
// registry.
@@ -238,16 +186,6 @@ message OperatorDef {
// This is an optional string with no assumed characteristics as
// operators can be constructed in any language.
optional string debug_info = 10;
-
- // additional annotations
- repeated Argument annotations = 13;
-
- // for jit ir exporting
- optional string aten_function = 14;
-
- // for operator versioning
- optional string domain = 15;
- optional string op_version = 16;
}
// Network definition.
diff --git a/caffe2/proto/torch.proto b/caffe2/proto/torch.proto
index f31c3b65ec..43dfd02b14 100644
--- a/caffe2/proto/torch.proto
+++ b/caffe2/proto/torch.proto
@@ -4,77 +4,547 @@ import "caffe2/proto/caffe2.proto";
package torch;
-enum ProtoVersion {
+// Overview
+//
+// ONNX is an open specification that is comprised of the following components:
+//
+// 1) A definition of an extensible computation graph model.
+// 2) Definitions of standard data types.
+// 3) Definitions of built-in operators.
+//
+// This document describes the syntax of models and their computation graphs,
+// as well as the standard data types. Together, they are referred to as the ONNX
+// Intermediate Representation, or 'IR' for short.
+//
+// The normative semantic specification of the ONNX IR is found in docs/IR.md.
+// Definitions of the built-in neural network operators may be found in docs/Operators.md.
+
+// Notes
+//
+// Release
+//
+// We are still in the very early stage of defining ONNX. The current
+// version of ONNX is a starting point. While we are actively working
+// towards a complete spec, we would like to get the community involved
+// by sharing our working version of ONNX.
+//
+// Protobuf compatibility
+//
+// To simplify framework compatibility, ONNX is defined using the subset of
+// protobuf that is compatible with both protobuf v2 and v3. This means that we
+// do not use any protobuf features that are only available in one of the two
+// versions.
+//
+// Here are the most notable contortions we have to carry out to work around
+// these limitations:
+//
+// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
+// of key-value pairs, where order does not matter and duplicates
+// are not allowed.
+
+// Versioning
+//
+// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
+//
+// To be compatible with both proto2 and proto3, we will use a version number
+// that is not defined by the default value but an explicit enum number.
+enum Version {
+ // proto3 requires the first enum value to be zero.
+ // We add this just to appease the compiler.
_START_VERSION = 0;
- IR_VERSION_NEWEST = 0x0000000000000101;
+ // The version field is always serialized and we will use it to store the
+ // version that the graph is generated from. This helps us set up version
+ // control.
+ // For the IR, we are using simple numbers starting with with 0x00000001,
+ // which was the version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x0000000000000001;
+
+ // IR_VERSION 2 published on Oct 30, 2017
+ // - Added type discriminator to AttributeProto to support proto3 users
+ IR_VERSION_2017_10_30 = 0x0000000000000002;
+
+ // IR VERSION 3 published on Nov 3, 2017
+ // - For operator versioning:
+ // - Added new message OperatorSetIdProto
+ // - Added opset_import in ModelProto
+ // - For vendor extensions, added domain in NodeProto
+ IR_VERSION_NEWEST_ONNX = 0x0000000000000003;
+
+ // PYTORCH IR VERSION
+ IR_VERSION_NEWEST = 0x0000000000000103;
}
-message MethodDef {
- // method name
- optional string name = 1; // method name
+// Attributes
+//
+// A named attribute containing either singular float, integer, string, graph,
+// and tensor values, or repeated float, integer, string, graph, and tensor values.
+// An AttributeProto MUST contain the name field, and *only one* of the
+// following content fields, effectively enforcing a C/C++ union equivalent.
+message AttributeProto {
+
+ // Note: this enum is structurally identical to the OpSchema::AttrType
+ // enum defined in schema.h. If you rev one, you likely need to rev the other.
+ enum AttributeType {
+ UNDEFINED = 0;
+ FLOAT = 1;
+ INT = 2;
+ STRING = 3;
+ TENSOR = 4;
+ GRAPH = 5;
+
+ FLOATS = 6;
+ INTS = 7;
+ STRINGS = 8;
+ TENSORS = 9;
+ GRAPHS = 10;
+ }
+
+ // The name field MUST be present for this version of the IR.
+ optional string name = 1; // namespace Attribute
- // static graph
- optional caffe2.NetDef graph = 2;
- // method is represented as torch script
- optional string torch_script = 3;
+ // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
+ // In this case, this AttributeProto does not contain data, and it's a reference of attribute
+ // in parent scope.
+ // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
+ optional string ref_attr_name = 21;
- // the names of inputs and outputs
- repeated string inputs = 4;
- repeated string outputs = 5;
+ // A human-readable documentation for this attribute. Markdown is allowed.
+ optional string doc_string = 13;
- // whether this method is main or not.
- // by default, `forward` should the main method.
- optional bool is_main = 6;
+ // The type field MUST be present for this version of the IR.
+ // For 0.0.1 versions of the IR, this field was not defined, and
+ // implementations needed to use has_field hueristics to determine
+ // which value field was in use. For IR_VERSION 0.0.2 or later, this
+ // field MUST be set and match the f|i|s|t|... field in use. This
+ // change was made to accomodate proto3 implementations.
+ optional AttributeType type = 20; // discriminator that indicates which field below is in use
- optional string debug_info = 7;
+ // Exactly ONE of the following fields must be present for this version of the IR
+ optional float f = 2; // float
+ optional int64 i = 3; // int
+ optional bytes s = 4; // UTF-8 string
+ optional TensorProto t = 5; // tensor value
+ optional GraphProto g = 6; // graph
+ // Do not use field below, it's deprecated.
+ // optional ValueProto v = 12; // value - subsumes everything but graph
- repeated caffe2.Argument annotations = 8;
+ repeated float floats = 7; // list of floats
+ repeated int64 ints = 8; // list of ints
+ repeated bytes strings = 9; // list of UTF-8 strings
+ repeated TensorProto tensors = 10; // list of tensors
+ repeated GraphProto graphs = 11; // list of graph
}
+// Defines information on value, including the name, the type, and
+// the shape of the value.
+message ValueInfoProto {
+ // This field MUST be present in this version of the IR.
+ optional string name = 1; // namespace Value
+ // This field MUST be present in this version of the IR.
+ optional TypeProto type = 2;
+ // A human-readable documentation for this value. Markdown is allowed.
+ optional string doc_string = 3;
+}
+
+// Nodes
+//
+// Computation graphs are made up of a DAG of nodes, which represent what is
+// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+//
+// For example, it can be a node of type "Conv" that takes in an image, a filter
+// tensor and a bias tensor, and produces the convolved output.
+message NodeProto {
+ repeated string input = 1; // namespace Value
+ repeated string output = 2; // namespace Value
-message ModuleDef {
- repeated ModuleDef submodules = 1;
+ // An optional identifier for this node in a graph.
+ // This field MAY be absent in ths version of the IR.
+ optional string name = 3; // namespace Node
- // We suppose to store the modules in one of the following format:
- // - methods (static graph or torch script)
- // - pickle
- // - cpp_arena
- repeated MethodDef methods = 2;
- // because the old pickle modules may not be supported by torch_script,
- // have to stored as pickle_arena at this moment.
- optional bytes pickle_arena = 3;
- // should be exposed by the Class Archive, so user can save
- // module specific data which cannot be store in the graph or torch_script
- optional bytes cpp_arena = 4;
+ // The symbolic identifier of the Operator to execute.
+ optional string op_type = 4; // namespace Operator
+ // The domain of the OperatorSet that specifies the operator named by op_type.
+ optional string domain = 7; // namespace Domain
- // the names of inputs and outputs of the module are inferred
- // from the main method.
+ // Additional named attributes.
+ repeated AttributeProto attribute = 5;
- optional string debug_info = 5;
+ // A human-readable documentation for this node. Markdown is allowed.
+ // Equivalent to string debug_info
+ optional string doc_string = 6;
- repeated caffe2.Argument annotations = 6;
+ // Additional annotations, attributes are defined in Schema
+ // To be added as annotations:
+ // string engine
+ // string list control_input
+ // int64 is_gradient_op
+ repeated AttributeProto annotations = 8;
+
+ // Besides the node type, PyTorhc also serialize ATen function signature
+ optional caffe2.DeviceOption device_option = 51;
+ optional string aten_function = 52;
}
-message ModelDef {
+// Models
+//
+// ModelProto is a top-level file/container format for bundling a ML model and
+// associating its computation graph with metadata.
+//
+// The semantics of the model are described by the associated GraphProto.
+//
+// Model ==> Caffe2 MetaNetDef
+// ==> PyTorch Module
+message ModelProto {
+ // The version of the IR this model targets. See Version enum above.
+ // This field MUST be present.
optional int64 ir_version = 1;
- // main module of the model
- optional ModuleDef main_module = 2;
+ // The OperatorSets this model relies on.
+ // All ModelProtos MUST have at least one entry that
+ // specifies which version of the ONNX OperatorSet is
+ // being imported.
+ //
+ // All nodes in the ModelProto's graph will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets.
+ repeated OperatorSetIdProto opset_import = 8;
+
+ // The name of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ optional string producer_name = 2;
+
+ // The version of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ optional string producer_version = 3;
+
+ // Domain name of the model.
+ // We use reverse domain names as name space indicators. For example:
+ // `com.facebook.fair` or `com.microsoft.cognitiveservices`
+ //
+ // Together with `model_version` and GraphProto.name, this forms the unique identity of
+ // the graph.
+ optional string domain = 4;
+
+ // The version of the graph encoded. See Version enum below.
+ optional int64 model_version = 5;
+
+ // A human-readable documentation for this model. Markdown is allowed.
+ optional string doc_string = 6;
- repeated caffe2.TensorProto parameters = 3;
- repeated caffe2.TensorProto value_infos = 4;
+ // The parameterized graph that is evaluated to execute the model.
+ // The main graph, in single graph case, it is ONNX compatible.
+ optional GraphProto graph = 7;
- // to distinguish whether exported from c2 or torch
- optional string producer_name = 5;
+ // The remaining nets in MetaNetDef.
+ // Submodules and methods in PyTorch.
+ repeated GraphProto methods = 15;
+
+ // Named metadata values; keys should be distinct.
+ // Many meta data in MetaNetDef and preditor are piggy backed here.
+ // 1) project
+ // 2) model_class
+ // 3) internal_version
+ // 4) predictor_type
+ // 5) predictor_id
+ // 6) execute_plan
+ // 7) applicationSpecificInfo (another string map, need to verify it has no duplicate.)
+ // 8) engine
+ // 9) publish time
+ repeated StringStringEntryProto metadata_props = 14;
+
+ // Model name
+ optional string name = 16;
+
+ // Model name
+ repeated AttributeProto annotations = 17;
+
+ // Mapping from list name to blob name list, must be string list type.
+ // Equivalent to blobs in MetaNetDef.
+ repeated AttributeProto blob_lists = 51;
+
+ // Mapping from plan name to serialized plan, must be string list type.
+ // Equivalent to plans in MetaNetDef.
+ repeated AttributeProto plans = 52;
+};
+
+// StringStringEntryProto follows the pattern for cross-proto-version maps.
+// See https://developers.google.com/protocol-buffers/docs/proto3#maps
+message StringStringEntryProto {
+ optional string key = 1;
+ optional string value= 2;
+};
+
+// Graphs
+//
+// A graph defines the computational logic of a model and is comprised of a parameterized
+// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+// This is the equivalent of the "network" or "graph" in many deep learning
+// frameworks.
+// Graph ==> NetDef in Caffe2
+// ==> Submodule/Method in PyTorch
+message GraphProto {
+ // The nodes in the graph, sorted topologically.
+ repeated NodeProto node = 1;
+
+ // The name of the graph.
+ optional string name = 2; // namespace Graph
+
+ // A list of named tensor values, used to specify constant inputs of the graph.
+ // Each TensorProto entry must have a distinct name (within the list) that
+ // also appears in the input list.
+ repeated TensorProto initializer = 5;
+
+ // A human-readable documentation for this graph. Markdown is allowed.
+ optional string doc_string = 10;
+
+ // The inputs and outputs of the graph.
+ repeated ValueInfoProto input = 11;
+ repeated ValueInfoProto output = 12;
+
+ // Information for the values in the graph. The ValueInfoProto.name's
+ // must be distinct. It is optional for a value to appear in value_info list.
+ repeated ValueInfoProto value_info = 13;
+
+ // Additional annotations.
+ repeated AttributeProto annotations = 14;
+
+ // DO NOT USE the following fields, they were deprecated from earlier versions.
+ // repeated string input = 3;
+ // repeated string output = 4;
+ // optional int64 ir_version = 6;
+ // optional int64 producer_version = 7;
+ // optional string producer_tag = 8;
+ // optional string domain = 9;
+}
- // put build version here
- optional string producer_version = 6;
+// Tensors
+//
+// A serialized tensor value.
+message TensorProto {
+ enum DataType {
+ UNDEFINED = 0;
+ // Basic types.
+ FLOAT = 1; // float
+ UINT8 = 2; // uint8_t
+ INT8 = 3; // int8_t
+ UINT16 = 4; // uint16_t
+ INT16 = 5; // int16_t
+ INT32 = 6; // int32_t
+ INT64 = 7; // int64_t
+ STRING = 8; // string
+ BOOL = 9; // bool
- optional string name = 7;
+ // Advanced types
+ FLOAT16 = 10;
+ DOUBLE = 11;
+ UINT32 = 12;
+ UINT64 = 13;
+ COMPLEX64 = 14; // complex with float32 real and imaginary components
+ COMPLEX128 = 15; // complex with float64 real and imaginary components
+ // Future extensions go here.
- optional string debug_info = 8;
+ // Special data type, real type information is stored in ValueInfoProto.
+ // If data_type is SPECIAL, raw_data should be used.
+ SPECIAL = 51;
+ }
+
+ // The shape of the tensor.
+ repeated int64 dims = 1;
+ repeated int64 strides = 14;
+
+ // The data type of the tensor.
+ optional DataType data_type = 2;
+
+ // For very large tensors, we may want to store them in chunks, in which
+ // case the following fields will specify the segment that is stored in
+ // the current TensorProto.
+ message Segment {
+ optional int64 begin = 1;
+ optional int64 end = 2;
+ optional int64 chuck_num = 51;
+ optional int64 chuck_id = 52;
+ }
+ // Used as offset in the external shared data.
+ optional Segment segment = 3;
+
+ // Tensor content must be organized in row-major order.
+ //
+ // Depending on the data_type field, exactly one of the fields below with
+ // name ending in _data is used to store the elements of the tensor.
+
+ // For float and complex64 values
+ // Complex64 tensors are encoded as a single array of floats,
+ // with the real components appearing in odd numbered positions,
+ // and the corresponding imaginary component apparing in the
+ // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+ // is encoded as [1.0, 2.0 ,3.0 ,4.0]
+ // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
+ repeated float float_data = 4 [packed = true];
+
+ // For int32, uint8, int8, uint16, int16, bool, and Half values
+ // float16 values must be bit-wise converted to an uint16_t prior
+ // to writing to the buffer.
+ // When this field is present, the data_type field MUST be
+ // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT16
+ repeated int32 int32_data = 5 [packed = true];
+
+ // For strings.
+ // Each element of string_data is a UTF-8 encoded Unicode
+ // string. No trailing null, no leading BOM. The protobuf "string"
+ // scalar type is not used to match ML community conventions.
+ // When this field is present, the data_type field MUST be STRING
+ repeated bytes string_data = 6;
+
+ // For int64.
+ // When this field is present, the data_type field MUST be INT64
+ repeated int64 int64_data = 7 [packed = true];
+
+ // Optionally, a name for the tensor.
+ optional string name = 8; // namespace Value
+
+ // A human-readable documentation for this tensor. Markdown is allowed.
+ optional string doc_string = 12;
+
+ // Serializations can either use one of the fields above, or use this
+ // raw bytes field. The only exception is the string case, where one is
+ // required to store the content in the repeated bytes string_data field.
+ //
+ // When this raw_data field is used to store tensor value, elements MUST
+ // be stored in as fixed-width, little-endian order.
+ // Floating-point data types MUST be stored in IEEE 754 format.
+ // Complex64 elements must be written as two consecutive FLOAT values, real component first.
+ // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
+ // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+ //
+ // Note: the advantage of specific field rather than the raw_data field is
+ // that in some cases (e.g. int data), protobuf does a better packing via
+ // variable length storage, and may lead to smaller binary footprint.
+ // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
+ optional bytes raw_data = 9;
+
+ // For double
+ // Complex64 tensors are encoded as a single array of doubles,
+ // with the real components appearing in odd numbered positions,
+ // and the corresponding imaginary component apparing in the
+ // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+ // is encoded as [1.0, 2.0 ,3.0 ,4.0]
+ // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
+ repeated double double_data = 10 [packed = true];
+
+ // For uint64 and uint32 values
+ // When this field is present, the data_type field MUST be
+ // UINT32 or UINT64
+ repeated uint64 uint64_data = 11 [packed = true];
+
+ // External data by file name
+ optional string external_data = 13;
+
+ // If two tensors represent the same weights/content, use alias.
+ // Must exist a TensorProto named alias in the initializer list.
+ // To avoid the duplicate tensor in attribute, such as value in Constant node.
+ // This is useful, if everything is stored just in the proto.
+ optional string alias = 16;
+
+ // Additional annotations.
+ repeated AttributeProto annotations = 17;
+
+ // Device info
+ optional caffe2.DeviceOption device_option = 51;
+
+ // For PyTorch serialized tensor.
+ optional int64 require_gradient = 52;
+ optional int64 is_buffer = 53;
+}
+
+// Defines a tensor shape. A dimension can be either an integer value
+// or a symbolic variable. A symbolic variable represents an unknown
+// dimension.
+message TensorShapeProto {
+ message Dimension {
+ oneof value {
+ int64 dim_value = 1;
+ string dim_param = 2; // namespace Shape
+ };
+ // Standard denotation can optionally be used to denote tensor
+ // dimensions with standard semantic descriptions to ensure
+ // that operations are applied to the correct axis of a tensor.
+ // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
+ // for pre-defined dimension denotations.
+ optional string denotation = 3;
+ };
+ // To represent a scalar, using no dim to represent 0-d tensor.
+ repeated Dimension dim = 1;
+
+ repeated Dimension stride = 51;
+}
+
+// Types
+//
+// The standard ONNX data types.
+message TypeProto {
+
+ message Tensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST be present for this version of the IR.
+ optional TensorProto.DataType elem_type = 1;
+ optional TensorShapeProto shape = 2;
+ }
+
+ // Sequence type: List, Tuple
+ message Sequence {
+ // elem_type and elem_type_list cannot appear together.
+ // If all the element types are the same, we use elem_type,
+ // otherwise, we specify the type of each element in elem_type_list.
+ optional TypeProto elem_type = 1;
+ repeated TypeProto elem_type_list = 51;
+ enum SequenceType {
+ UNDEFINED = 0;
+ LIST = 1;
+ TUPLE = 2;
+ }
+ optional SequenceType sequence_type = 52;
+ }
+
+ // Map<K, V>, (not necessary at this moment)
+ message Map {
+ optional TensorProto.DataType key_type = 1;
+ optional TypeProto value_type = 2;
+ }
+
+ // Special type of blobs, based on the type_name, we can choose the right
+ // serializer and deserialzier.
+ message SpecialBlob {
+ optional string type_name = 1;
+ }
+
+ oneof value {
+ // The type of a tensor.
+ Tensor tensor_type = 1;
+ Sequence sequence_type = 4;
+ Map map_type = 5;
+ SpecialBlob special_type = 51;
+ }
+
+ // An optional denotation can be used to denote the whole
+ // type with a standard semantic description as to what is
+ // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition
+ // for pre-defined type denotations.
+ optional string denotation = 6;
+}
- // annotations, it is used for MetaNetDef's metadata
- repeated caffe2.Argument annotations = 9;
+// Operator Sets
+//
+// OperatorSets are uniquely identified by a (domain, opset_version) pair.
+message OperatorSetIdProto {
+ // The domain of the operator set being identified.
+ // The empty string ("") or absence of this field implies the operator
+ // set that is defined as part of the ONNX specification.
+ // This field MUST be present in this version of the IR when referring to any other operator set.
+ optional string domain = 1;
+ // The version of the operator set being identified.
+ // This field MUST be present in this version of the IR.
+ optional int64 version = 2;
}
diff --git a/caffe2/python/convert.py b/caffe2/python/convert.py
index 44f81d6e2d..50eaf220c7 100644
--- a/caffe2/python/convert.py
+++ b/caffe2/python/convert.py
@@ -8,3 +8,59 @@ from __future__ import unicode_literals
from caffe2.proto import caffe2_pb2, torch_pb2
import caffe2.python._import_c_extension as C
+
+
+def ArgumentToAttributeProto(arg):
+ serialized_arg = None
+ if hasattr(arg, 'SerializeToString') and callable(arg.SerializeToString):
+ serialized_arg = arg.SerializeToString()
+ elif isinstance(arg, bytes):
+ serialized_arg = arg
+ else:
+ raise ValueError('No SerializeToString method is detected. '
+ 'neither arg is bytes.\ntype is {}'.format(type(arg)))
+ attr = torch_pb2.AttributeProto()
+ attr.ParseFromString(C.argument_to_attribute_proto(serialized_arg))
+ return attr
+
+
+def AttributeProtoToArgument(attr):
+ serialized_attr = None
+ if hasattr(attr, 'SerializeToString') and callable(attr.SerializeToString):
+ serialized_attr = attr.SerializeToString()
+ elif isinstance(attr, bytes):
+ serialized_attr = attr
+ else:
+ raise ValueError('No SerializeToString method is detected. '
+ 'neither attr is bytes.\ntype is {}'.format(type(attr)))
+ arg = caffe2_pb2.Argument()
+ arg.ParseFromString(C.attribute_proto_to_argument(serialized_attr))
+ return arg
+
+
+def OperatorDefToNodeProto(op_def):
+ serialized_op_def = None
+ if hasattr(op_def, 'SerializeToString') and callable(op_def.SerializeToString):
+ serialized_op_def = op_def.SerializeToString()
+ elif isinstance(op_def, bytes):
+ serialized_op_def = op_def
+ else:
+ raise ValueError('No SerializeToString method is detected. '
+ 'neither op_def is bytes.\ntype is {}'.format(type(op_def)))
+ node = torch_pb2.NodeProto()
+ node.ParseFromString(C.operator_def_to_node_proto(serialized_op_def))
+ return node
+
+
+def NodeProtoToOperatorDef(node_proto):
+ serialized_node_proto = None
+ if hasattr(node_proto, 'SerializeToString') and callable(node_proto.SerializeToString):
+ serialized_node_proto = node_proto.SerializeToString()
+ elif isinstance(node_proto, bytes):
+ serialized_node_proto = node_proto
+ else:
+ raise ValueError('No SerializeToString method is detected. '
+ 'neither node_proto is bytes.\ntype is {}'.format(type(node_proto)))
+ op_def = caffe2_pb2.OperatorDef()
+ op_def.ParseFromString(C.node_proto_to_operator_def(serialized_node_proto))
+ return op_def
diff --git a/caffe2/python/convert_test.py b/caffe2/python/convert_test.py
index 82c969c901..c8de7e9750 100644
--- a/caffe2/python/convert_test.py
+++ b/caffe2/python/convert_test.py
@@ -12,5 +12,239 @@ class TestOperator(unittest.TestCase):
def setUp(self):
workspace.ResetWorkspace()
+ def testArgument2AttributeProto(self):
+ arg_f = caffe2_pb2.Argument()
+ arg_f.name = "TestArgF"
+ arg_f.f = 10.0
+ attr_f = convert.ArgumentToAttributeProto(arg_f)
+ self.assertEqual(attr_f.name, arg_f.name)
+ self.assertEqual(attr_f.f, arg_f.f)
+
+ arg_i = caffe2_pb2.Argument()
+ arg_i.name = "TestArgI"
+ arg_i.i = 100
+ attr_i = convert.ArgumentToAttributeProto(arg_i)
+ self.assertEqual(attr_i.name, arg_i.name)
+ self.assertEqual(attr_i.i, arg_i.i)
+
+ arg_s = caffe2_pb2.Argument()
+ arg_s.name = "TestArgS"
+ arg_s.s = "TestS".encode("utf-8")
+ attr_s = convert.ArgumentToAttributeProto(arg_s)
+ self.assertEqual(attr_s.name, arg_s.name)
+ self.assertEqual(attr_s.s, arg_s.s)
+
+ # TODO: test net arg
+
+ arg_floats = caffe2_pb2.Argument()
+ arg_floats.name = "TestArgFloats"
+ arg_floats.floats.extend([10.0, 11.0, 12.0])
+ attr_floats = convert.ArgumentToAttributeProto(arg_floats)
+ self.assertEqual(attr_floats.name, arg_floats.name)
+ self.assertEqual(attr_floats.floats, arg_floats.floats)
+
+ arg_ints = caffe2_pb2.Argument()
+ arg_ints.name = "TestArgInts"
+ arg_ints.ints.extend([100, 101, 102])
+ attr_ints = convert.ArgumentToAttributeProto(arg_ints)
+ self.assertEqual(attr_ints.name, arg_ints.name)
+ self.assertEqual(attr_ints.ints, arg_ints.ints)
+
+ arg_strings = caffe2_pb2.Argument()
+ arg_strings.name = "TestArgStrings"
+ arg_strings.strings.extend([
+ "TestStrings1".encode("utf-8"),
+ "TestStrings2".encode("utf-8"),
+ ])
+ attr_strings = convert.ArgumentToAttributeProto(arg_strings)
+ self.assertEqual(attr_strings.name, arg_strings.name)
+ self.assertEqual(attr_strings.strings, arg_strings.strings)
+
+ # TODO: test nets arg
+
+ def testAttributeProto2Argument(self):
+ attr_f = torch_pb2.AttributeProto()
+ attr_f.type = torch_pb2.AttributeProto.FLOAT
+ attr_f.name = "TestAttrF"
+ attr_f.f = 10.0
+ arg_f = convert.AttributeProtoToArgument(attr_f)
+ self.assertEqual(arg_f.name, attr_f.name)
+ self.assertEqual(arg_f.f, attr_f.f)
+
+ attr_i = torch_pb2.AttributeProto()
+ attr_i.type = torch_pb2.AttributeProto.INT
+ attr_i.name = "TestArgI"
+ attr_i.i = 100
+ arg_i = convert.AttributeProtoToArgument(attr_i)
+ self.assertEqual(arg_i.name, attr_i.name)
+ self.assertEqual(arg_i.i, attr_i.i)
+
+ attr_s = torch_pb2.AttributeProto()
+ attr_s.type = torch_pb2.AttributeProto.STRING
+ attr_s.name = "TestArgS"
+ attr_s.s = "TestS".encode("utf-8")
+ arg_s = convert.AttributeProtoToArgument(attr_s)
+ self.assertEqual(arg_s.name, attr_s.name)
+ self.assertEqual(arg_s.s, attr_s.s)
+
+ # TODO: test graph attribute
+
+ attr_floats = torch_pb2.AttributeProto()
+ attr_floats.type = torch_pb2.AttributeProto.FLOATS
+ attr_floats.name = "TestAttrFloats"
+ attr_floats.floats.extend([10.0, 11.0, 12.0])
+ arg_floats = convert.AttributeProtoToArgument(attr_floats)
+ self.assertEqual(arg_floats.name, attr_floats.name)
+ self.assertEqual(arg_floats.floats, attr_floats.floats)
+
+ attr_ints = torch_pb2.AttributeProto()
+ attr_ints.type = torch_pb2.AttributeProto.INTS
+ attr_ints.name = "TestArgInts"
+ attr_ints.ints.extend([100, 101, 102])
+ arg_ints = convert.AttributeProtoToArgument(attr_ints)
+ self.assertEqual(arg_ints.name, attr_ints.name)
+ self.assertEqual(arg_ints.ints, attr_ints.ints)
+
+ attr_strings = torch_pb2.AttributeProto()
+ attr_strings.type = torch_pb2.AttributeProto.STRINGS
+ attr_strings.name = "TestArgStrings"
+ attr_strings.strings.extend([
+ "TestStrings1".encode("utf-8"),
+ "TestStrings2".encode("utf-8"),
+ ])
+ arg_strings = convert.AttributeProtoToArgument(attr_strings)
+ self.assertEqual(arg_strings.name, attr_strings.name)
+ self.assertEqual(arg_strings.strings, attr_strings.strings)
+
+ # TODO: test graphs attribute
+
+
+ def testOperatorDef2NodeProto(self):
+ op_def = caffe2_pb2.OperatorDef()
+ op_def.input.extend(["A", "B", "C"])
+ op_def.output.extend(["X", "Y"])
+ op_def.name = "TestOpName"
+ op_def.type = "TestOp"
+ arg1 = caffe2_pb2.Argument()
+ arg1.name = "TestArg1"
+ arg1.i = 1
+ arg2 = caffe2_pb2.Argument()
+ arg2.name = "TestArg2"
+ arg1.s = "TestInfo".encode("utf-8")
+ op_def.arg.extend([arg1, arg2])
+ op_def.device_option.CopyFrom(caffe2_pb2.DeviceOption())
+ op_def.engine = "TestEngine".encode("utf-8")
+ op_def.control_input.extend(["input1", "input2"])
+ op_def.is_gradient_op = True
+ op_def.debug_info = "TestDebugInfo"
+
+ node = convert.OperatorDefToNodeProto(op_def)
+
+ self.assertEqual(node.input, op_def.input)
+ self.assertEqual(node.output, op_def.output)
+ self.assertEqual(node.name, op_def.name)
+ self.assertEqual(node.op_type, op_def.type)
+ self.assertEqual(node.attribute[0].name, op_def.arg[0].name)
+ self.assertEqual(node.attribute[1].name, op_def.arg[1].name)
+ self.assertEqual(node.device_option, op_def.device_option)
+ node_engine = [a.s.decode("utf-8") for a in node.annotations if a.name == "engine"][0]
+ self.assertEqual(node_engine, op_def.engine)
+ node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0]
+ self.assertEqual(len(node_control_input), len(op_def.control_input))
+ for x, y in zip(node_control_input, op_def.control_input):
+ self.assertEqual(x.decode("utf-8"), y)
+ self.assertEqual(node.doc_string, op_def.debug_info)
+ node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0]
+ self.assertEqual(node_is_gradient_op, int(op_def.is_gradient_op))
+
+ def testNodeProto2OperatorDef(self):
+ node = torch_pb2.NodeProto()
+ node.input.extend(["A", "B", "C"])
+ node.output.extend(["X", "Y"])
+ node.name = "TestOpName"
+ node.op_type = "TestOp"
+ attr1 = torch_pb2.AttributeProto()
+ attr1.name = "TestAttr1"
+ attr1.type = torch_pb2.AttributeProto.STRING
+ attr1.s = "TestInfo".encode("utf-8")
+ attr2 = torch_pb2.AttributeProto()
+ attr2.name = "TestAttr2"
+ attr2.type = torch_pb2.AttributeProto.INT
+ attr2.i = 10
+ node.attribute.extend([attr1, attr2])
+ node.device_option.CopyFrom(caffe2_pb2.DeviceOption())
+ anno1 = torch_pb2.AttributeProto()
+ anno1.name = "engine"
+ anno1.type = torch_pb2.AttributeProto.STRING
+ anno1.s = "TestEngine".encode("utf-8")
+ anno2 = torch_pb2.AttributeProto()
+ anno2.name = "control_input"
+ anno2.type = torch_pb2.AttributeProto.STRINGS
+ anno2.strings.extend(["input1".encode("utf-8"), "input2".encode("utf-8")])
+ anno3 = torch_pb2.AttributeProto()
+ anno3.name = "is_gradient_op"
+ anno3.type = torch_pb2.AttributeProto.INT
+ anno3.i = 1
+ node.annotations.extend([anno1, anno2, anno3])
+ node.doc_string = "TestDocString".encode("utf-8")
+
+ op_def = convert.NodeProtoToOperatorDef(node)
+
+ self.assertEqual(op_def.input, node.input)
+ self.assertEqual(op_def.output, node.output)
+ self.assertEqual(op_def.name, node.name)
+ self.assertEqual(op_def.type, node.op_type)
+ self.assertEqual(op_def.arg[0].name, node.attribute[0].name)
+ self.assertEqual(op_def.arg[1].name, node.attribute[1].name)
+ self.assertEqual(op_def.device_option, node.device_option)
+ node_engine = [a.s for a in node.annotations if a.name == "engine"][0]
+ self.assertEqual(op_def.engine, node_engine.decode("utf-8"))
+ node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0]
+ for x, y in zip(op_def.control_input, node_control_input):
+ self.assertEqual(x, y.decode("utf-8"))
+ self.assertEqual(op_def.debug_info, node.doc_string)
+ node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0]
+ self.assertEqual(int(op_def.is_gradient_op), node_is_gradient_op)
+
+ def testEnd2End(self):
+ op_def = caffe2_pb2.OperatorDef()
+ op_def.type = "Add"
+ op_def.input.extend(["input1"])
+ op_def.input.extend(["input2"])
+ op_def.output.extend(["output1"])
+ node = convert.OperatorDefToNodeProto(op_def)
+
+ input1 = np.random.randn(1, 3, 1, 5).astype(np.float32)
+ input2 = np.random.randn(2, 1, 4, 1).astype(np.float32)
+ ref_output1 = input1 + input2
+ workspace.FeedBlob("input1", input1)
+ workspace.FeedBlob("input2", input2)
+ self.assertEqual(workspace.RunOperatorOnce(node.SerializeToString(), legacy_proto=False), True)
+
+ self.assertEqual(workspace.HasBlob("output1"), True)
+ fetched_back = workspace.FetchBlob("output1")
+ np.testing.assert_array_equal(fetched_back, ref_output1)
+
+ def testRoundTrip(self):
+ op_def = caffe2_pb2.OperatorDef()
+ op_def.type = "Add"
+ op_def.input.extend(["input1"])
+ op_def.input.extend(["input2"])
+ op_def.output.extend(["output1"])
+ node = convert.OperatorDefToNodeProto(op_def)
+ new_op_def = convert.NodeProtoToOperatorDef(node)
+
+ input1 = np.random.randn(1, 3, 1, 5).astype(np.float32)
+ input2 = np.random.randn(2, 1, 4, 1).astype(np.float32)
+ ref_output1 = input1 + input2
+ workspace.FeedBlob("input1", input1)
+ workspace.FeedBlob("input2", input2)
+ self.assertEqual(workspace.RunOperatorOnce(new_op_def.SerializeToString()), True)
+
+ self.assertEqual(workspace.HasBlob("output1"), True)
+ fetched_back = workspace.FetchBlob("output1")
+ np.testing.assert_array_equal(fetched_back, ref_output1)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 7ebee57d49..7062ead045 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -1187,10 +1187,17 @@ void addGlobalMethods(py::module& m) {
return true;
});
m.def("nets", []() { return gWorkspace->Nets(); });
- m.def("run_operator_once", [](const py::bytes& op_def) {
+ m.def("run_operator_once", [](const py::bytes& op_def, bool legacy_proto=true) {
CAFFE_ENFORCE(gWorkspace);
OperatorDef def;
- CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast<std::string>(), &def));
+ if (legacy_proto) {
+ CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast<std::string>(), &def));
+ } else {
+ ::torch::NodeProto node;
+ CAFFE_ENFORCE(
+ ParseProtoFromLargeString(op_def.cast<std::string>(), &node));
+ NodeProtoToOperatorDef(node, &def);
+ }
py::gil_scoped_release g;
CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def));
return true;
@@ -1527,6 +1534,38 @@ void addGlobalMethods(py::module& m) {
CAFFE_ENFORCE(blob);
return BlobStat::sizeBytes(*blob);
});
+ m.def("argument_to_attribute_proto", [](py::bytes arg_str) -> py::bytes {
+ Argument arg;
+ CAFFE_ENFORCE(
+ ParseProtoFromLargeString(arg_str.cast<std::string>(), &arg));
+ ::torch::AttributeProto attr;
+ ArgumentToAttributeProto(arg, &attr);
+ return attr.SerializeAsString();
+ });
+ m.def("attribute_proto_to_argument", [](py::bytes attr_str) -> py::bytes {
+ ::torch::AttributeProto attr;
+ CAFFE_ENFORCE(
+ ParseProtoFromLargeString(attr_str.cast<std::string>(), &attr));
+ Argument arg;
+ AttributeProtoToArgument(attr, &arg);
+ return arg.SerializeAsString();
+ });
+ m.def("operator_def_to_node_proto", [](py::bytes op_str) -> py::bytes {
+ OperatorDef op_def;
+ CAFFE_ENFORCE(
+ ParseProtoFromLargeString(op_str.cast<std::string>(), &op_def));
+ ::torch::NodeProto node;
+ OperatorDefToNodeProto(op_def, &node);
+ return node.SerializeAsString();
+ });
+ m.def("node_proto_to_operator_def", [](py::bytes node_str) -> py::bytes {
+ ::torch::NodeProto node_proto;
+ CAFFE_ENFORCE(
+ ParseProtoFromLargeString(node_str.cast<std::string>(), &node_proto));
+ OperatorDef op_def;
+ NodeProtoToOperatorDef(node_proto, &op_def);
+ return op_def.SerializeAsString();
+ });
m.def("support_onnx_export", [](const std::string& op) -> bool {
const OpSchema* schema = caffe2::OpSchemaRegistry::Schema(op);
if (!schema) {
diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py
index ef02f64dc9..a41cc15317 100644
--- a/caffe2/python/workspace.py
+++ b/caffe2/python/workspace.py
@@ -163,8 +163,8 @@ def GetOperatorCost(operator, blobs):
return C.get_operator_cost(StringifyProto(operator), blobs)
-def RunOperatorOnce(operator):
- return C.run_operator_once(StringifyProto(operator))
+def RunOperatorOnce(operator, legacy_proto=True):
+ return C.run_operator_once(StringifyProto(operator), legacy_proto)
def RunOperatorsOnce(operators):
diff --git a/caffe2/utils/proto_convert.cc b/caffe2/utils/proto_convert.cc
index 1d69c8c80c..790bd27429 100644
--- a/caffe2/utils/proto_convert.cc
+++ b/caffe2/utils/proto_convert.cc
@@ -2,4 +2,185 @@
#include "caffe2/core/logging.h"
namespace caffe2 {
+
+C10_EXPORT void ArgumentToAttributeProto(
+ const Argument& arg,
+ ::torch::AttributeProto* attr) {
+ CAFFE_ENFORCE(arg.has_name());
+ attr->set_name(arg.name());
+ if (arg.has_f()) {
+ attr->set_f(arg.f());
+ } else if (arg.has_i()) {
+ attr->set_i(arg.i());
+ } else if (arg.has_s()) {
+ attr->set_s(arg.s());
+ } else if (arg.has_n()) {
+ // TODO
+ CAFFE_THROW("NetDef conversion is not implemented yet.");
+ } else if (arg.floats_size() > 0) {
+ attr->mutable_floats()->CopyFrom(arg.floats());
+ } else if (arg.ints_size() > 0) {
+ attr->mutable_ints()->CopyFrom(arg.ints());
+ } else if (arg.strings_size() > 0) {
+ attr->mutable_strings()->CopyFrom(arg.strings());
+ } else if (arg.nets_size() > 0) {
+ // TODO
+ CAFFE_THROW("NetDefs conversion is not implemented yet.");
+ }
+}
+
+C10_EXPORT void AttributeProtoToArgument(
+ const ::torch::AttributeProto& attr,
+ Argument* arg) {
+ CAFFE_ENFORCE(attr.has_name());
+ arg->set_name(attr.name());
+ CAFFE_ENFORCE(attr.has_type());
+ const auto type = attr.type();
+ if (type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_FLOAT) {
+ CAFFE_ENFORCE(attr.has_f());
+ arg->set_f(attr.f());
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::AttributeProto_AttributeType_INT) {
+ CAFFE_ENFORCE(attr.has_i());
+ arg->set_i(attr.i());
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_STRING) {
+ CAFFE_ENFORCE(attr.has_s());
+ arg->set_s(attr.s());
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_TENSOR) {
+ CAFFE_THROW("Caffe2's Argument does not support tensor as attribute.");
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_GRAPH) {
+ // TODO
+ CAFFE_THROW("GraphProto conversion is not implemented yet.");
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_FLOATS) {
+ arg->mutable_floats()->CopyFrom(attr.floats());
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_INTS) {
+ arg->mutable_ints()->CopyFrom(attr.ints());
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_STRINGS) {
+ arg->mutable_strings()->CopyFrom(attr.strings());
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_TENSORS) {
+ CAFFE_THROW("Caffe2's Argument does not support tensors as attribute.");
+ } else if (
+ type ==
+ ::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_GRAPHS) {
+ // TODO
+ CAFFE_THROW("GraphProtos conversion is not implemented yet.");
+ } else {
+ CAFFE_THROW("Unknow Attribute type.");
+ }
+}
+
+C10_EXPORT void OperatorDefToNodeProto(
+ const OperatorDef& def,
+ ::torch::NodeProto* node) {
+ node->mutable_input()->CopyFrom(def.input());
+ node->mutable_output()->CopyFrom(def.output());
+ if (def.has_name()) {
+ node->set_name(def.name());
+ }
+ CAFFE_ENFORCE(def.has_type());
+ node->set_op_type(def.type());
+ for (int i = 0; i < def.arg_size(); ++i) {
+ auto attr = node->add_attribute();
+ ArgumentToAttributeProto(def.arg(i), attr);
+ }
+ if (def.has_device_option()) {
+ node->mutable_device_option()->CopyFrom(def.device_option());
+ }
+ if (def.has_engine()) {
+ auto attr = node->add_annotations();
+ attr->set_name("engine");
+ attr->set_type(::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_STRING);
+ attr->set_s(def.engine());
+ }
+ if (def.control_input_size() > 0) {
+ auto attr = node->add_annotations();
+ attr->set_name("control_input");
+ attr->set_type(::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_STRINGS);
+ attr->mutable_strings()->CopyFrom(def.control_input());
+ }
+ if (def.has_is_gradient_op()) {
+ auto attr = node->add_annotations();
+ attr->set_name("is_gradient_op");
+ attr->set_type(::torch::AttributeProto_AttributeType::
+ AttributeProto_AttributeType_INT);
+ if (def.is_gradient_op()) {
+ attr->set_i(1);
+ } else {
+ attr->set_i(0);
+ }
+ }
+ if (def.has_debug_info()) {
+ node->set_doc_string(def.debug_info());
+ }
+}
+
+C10_EXPORT void NodeProtoToOperatorDef(
+ const ::torch::NodeProto& node,
+ OperatorDef* def) {
+ def->mutable_input()->CopyFrom(node.input());
+ def->mutable_output()->CopyFrom(node.output());
+ if (node.has_name()) {
+ def->set_name(node.name());
+ }
+
+ CAFFE_ENFORCE(node.has_op_type());
+ def->set_type(node.op_type());
+ for (int i = 0; i < node.attribute_size(); ++i) {
+ auto arg = def->add_arg();
+ AttributeProtoToArgument(node.attribute(i), arg);
+ }
+ if (node.has_doc_string()) {
+ def->set_debug_info(node.doc_string());
+ }
+ for (int i = 0; i < node.annotations_size(); ++i) {
+ const auto& attr = node.annotations(i);
+ CAFFE_ENFORCE(attr.has_name());
+ if (attr.name() == "engine") {
+ CAFFE_ENFORCE(attr.has_s());
+ def->set_engine(attr.s());
+ } else if (attr.name() == "control_input") {
+ def->mutable_control_input()->CopyFrom(attr.strings());
+ } else if (attr.name() == "is_gradient_op") {
+ CAFFE_ENFORCE(attr.has_i());
+ if (i == 0) {
+ def->set_is_gradient_op(false);
+ } else {
+ def->set_is_gradient_op(true);
+ }
+ }
+ auto arg = def->add_arg();
+ AttributeProtoToArgument(node.annotations(i), arg);
+ }
+ if (node.has_device_option()) {
+ def->mutable_device_option()->CopyFrom(node.device_option());
+ }
+}
+
} // namespace caffe2
diff --git a/caffe2/utils/proto_convert.h b/caffe2/utils/proto_convert.h
index 91bcf1bafa..a9ca9c3ad4 100644
--- a/caffe2/utils/proto_convert.h
+++ b/caffe2/utils/proto_convert.h
@@ -6,6 +6,20 @@
#include "caffe2/proto/torch_pb.h"
namespace caffe2 {
+
+CAFFE2_API void ArgumentToAttributeProto(
+ const Argument& arg,
+ ::torch::AttributeProto* attr);
+CAFFE2_API void AttributeProtoToArgument(
+ const ::torch::AttributeProto& attr,
+ Argument* arg);
+CAFFE2_API void OperatorDefToNodeProto(
+ const OperatorDef& def,
+ ::torch::NodeProto* node);
+CAFFE2_API void NodeProtoToOperatorDef(
+ const ::torch::NodeProto& node,
+ OperatorDef* def);
+
} // namespace caffe2
#endif // CAFFE2_UTILS_PROTO_CONVERT_H_