author     Lu Fang <lufang@fb.com>  2018-10-01 15:42:45 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-10-01 15:57:02 -0700
commit     35becd1879fae56fc417905c5154c402b9780a3f (patch)
tree       83a7a28e04290a5044c1f8d88d87d7ad734b6a4e
parent     8fa7de35f2700b8b23fcd8f98d2f5cdbabae9c95 (diff)
New version of PT1 model format (#12149)
Summary: Considered four different existing formats: 1) static graph, 2) torch script, 3) pickle files, 4) PyTorch C++ serialize APIs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12149
Reviewed By: BIT-silence
Differential Revision: D10098106
Pulled By: houseroad
fbshipit-source-id: 94ec7fc57c842e50fae5286ddeda657a4967a07a
-rw-r--r--  caffe2/core/blob_serialization.cc    15
-rw-r--r--  caffe2/proto/caffe2.proto            86
-rw-r--r--  caffe2/proto/torch.proto            564
-rw-r--r--  caffe2/python/convert.py             56
-rw-r--r--  caffe2/python/convert_test.py       234
-rw-r--r--  caffe2/python/pybind_state.cc        43
-rw-r--r--  caffe2/python/workspace.py            4
-rw-r--r--  caffe2/utils/proto_convert.cc       181
-rw-r--r--  caffe2/utils/proto_convert.h         14
9 files changed, 139 insertions, 1058 deletions
diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc
index 8126b3d594..f27d16adf3 100644
--- a/caffe2/core/blob_serialization.cc
+++ b/caffe2/core/blob_serialization.cc
@@ -309,6 +309,12 @@ void TensorSerializer::Serialize(
proto.add_string_data(SerializeBlob(temp_blob, ""));
}
} break;
+ case TensorProto_DataType_SPECIAL: {
+ CAFFE_THROW("SPECIAL Tensor is not handled yet.");
+ } break;
+ case TensorProto_DataType_NO_CONTENT: {
+ CAFFE_THROW("NO_CONTENT Tensor should not be serialized.");
+ } break;
// Note: we intentionally do not provide "default:" so if any new data types
// are added, the compiler should warn the user to add the case here.
}
@@ -520,7 +526,14 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
(i + chunkBegin) * temp_blob.meta().itemsize(),
1);
}
- }
+ } break;
+ case TensorProto_DataType_SPECIAL: {
+ CAFFE_THROW("SPECIAL Tensor is not handled yet.");
+ } break;
+ case TensorProto_DataType_NO_CONTENT: {
+ CAFFE_THROW("NO_CONTENT Tensor should not be deserialized.");
+ } break;
+ // Note: we intentionally do not provide "default:" so if any new data types
}
context->FinishDeviceComputation();
}
diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto
index 21bdec2c68..63a2a256de 100644
--- a/caffe2/proto/caffe2.proto
+++ b/caffe2/proto/caffe2.proto
@@ -15,23 +15,46 @@ package caffe2;
message TensorProto {
// The dimensions in the tensor.
repeated int64 dims = 1;
+ // The strides of the tensor.
+ repeated int64 strides = 12;
+
+ // Data type
enum DataType {
UNDEFINED = 0;
- FLOAT = 1; // float
- INT32 = 2; // int
- BYTE = 3; // BYTE, when deserialized, is going to be restored as uint8.
- STRING = 4; // string
- // Less-commonly used data types.
- BOOL = 5; // bool
- UINT8 = 6; // uint8_t
- INT8 = 7; // int8_t
- UINT16 = 8; // uint16_t
- INT16 = 9; // int16_t
- INT64 = 10; // int64_t
+
+ // Basic types
+ FLOAT = 1; // float
+ INT32 = 2; // int
+ BYTE = 3; // byte, when deserialized, is going to be restored as uint8
+ STRING = 4; // string
+
+ // Less-commonly used data types
+ BOOL = 5; // bool
+ UINT8 = 6; // uint8_t
+ INT8 = 7; // int8_t
+ UINT16 = 8; // uint16_t
+ INT16 = 9; // int16_t
+ INT64 = 10; // int64_t
FLOAT16 = 12; // at::Half
- DOUBLE = 13; // double
+ DOUBLE = 13; // double
+
+ // Special data type, type information is stored in the special type field
+ SPECIAL = 51;
+ // Use TensorProto to specify the shape and type
+ NO_CONTENT = 52;
}
optional DataType data_type = 2 [default = FLOAT];
+ // if data_type is SPECIAL, use this field to express the type info
+ optional SpecialType special_type = 13;
+
+ // Data storage
+ enum StorageType {
+ TYPED = 1;
+ RAW = 2;
+ EXTERNAL = 3;
+ ALIAS = 4;
+ }
+ optional StorageType storage_type = 14 [default = TYPED];
// For float
repeated float float_data = 3 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16
@@ -46,6 +69,13 @@ message TensorProto {
repeated double double_data = 9 [packed = true];
// For int64
repeated int64 int64_data = 10 [packed = true];
+ // For raw data
+ optional bytes raw_data = 15;
+ // External data by file name
+ optional string external_data = 16;
+ // For argument, to share the content
+ optional string alias = 17;
+
// Optionally, a name for the tensor.
optional string name = 7;
@@ -53,13 +83,23 @@ message TensorProto {
// it was serialized from. This is useful in cases like snapshotting a whole
// workspace in a multi-GPU environment.
optional DeviceOption device_detail = 8;
+
// When loading from chunks this is going to indicate where to put data in the
// full array. When not used full data have to be present
message Segment {
required int64 begin = 1;
required int64 end = 2;
+ optional int64 chunk_num = 51;
+ optional int64 chunk_id = 52;
}
optional Segment segment = 11;
+ optional string debug_info = 18;
+
+ // For PyTorch serialized tensor.
+ optional bool require_gradient = 19;
+ optional bool is_buffer = 20;
+
+ repeated Argument annotations = 21;
}
message QTensorProto {
@@ -86,7 +126,11 @@ message TensorShape {
repeated int32 unknown_dims = 3;
optional bool unknown_shape = 4 [default = false];
optional string name = 5;
+}
+// This is prepared for non-tensor types.
+message SpecialType {
+ optional string name = 1;
}
message TensorShapes {
@@ -97,13 +141,17 @@ message TensorShapes {
// values, or repeated float, int and string arrays.
message Argument {
optional string name = 1;
+
optional float f = 2;
optional int64 i = 3;
optional bytes s = 4;
+ optional TensorProto t = 10;
optional NetDef n = 8;
+
repeated float floats = 5;
repeated int64 ints = 6;
repeated bytes strings = 7;
+ repeated TensorProto tensors = 11;
repeated NetDef nets = 9;
}
@@ -152,7 +200,11 @@ message DeviceOption {
// Operator Definition.
message OperatorDef {
repeated string input = 1; // the name of the input blobs
+ // the input name in the schema, for named inputs
+ repeated string mapped_inputs = 11;
repeated string output = 2; // the name of output top blobs
+ // the output name in the schema, for named outputs
+ repeated string mapped_outputs = 12;
optional string name = 3; // the operator name. This is optional.
// the operator type. This is needed to create the object from the operator
// registry.
@@ -186,6 +238,16 @@ message OperatorDef {
// This is an optional string with no assumed characteristics as
// operators can be constructed in any language.
optional string debug_info = 10;
+
+ // additional annotations
+ repeated Argument annotations = 13;
+
+ // for jit ir exporting
+ optional string aten_function = 14;
+
+ // for operator versioning
+ optional string domain = 15;
+ optional string op_version = 16;
}
// Network definition.
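
The new TensorProto fields above can be exercised from Python; a minimal sketch, assuming the regenerated caffe2_pb2 bindings (the tensor name and values are illustrative only):

    import numpy as np
    from caffe2.proto import caffe2_pb2

    data = np.arange(6, dtype=np.float32).reshape(2, 3)

    tensor = caffe2_pb2.TensorProto()
    tensor.name = "example_weight"
    tensor.dims.extend(data.shape)
    tensor.strides.extend([3, 1])                      # new field 12, strides in elements
    tensor.data_type = caffe2_pb2.TensorProto.FLOAT
    tensor.storage_type = caffe2_pb2.TensorProto.RAW   # new field 14
    tensor.raw_data = data.tobytes()                   # new field 15, row-major bytes
    tensor.require_gradient = True                     # new field 19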
diff --git a/caffe2/proto/torch.proto b/caffe2/proto/torch.proto
index 43dfd02b14..f31c3b65ec 100644
--- a/caffe2/proto/torch.proto
+++ b/caffe2/proto/torch.proto
@@ -4,547 +4,77 @@ import "caffe2/proto/caffe2.proto";
package torch;
-// Overview
-//
-// ONNX is an open specification that is comprised of the following components:
-//
-// 1) A definition of an extensible computation graph model.
-// 2) Definitions of standard data types.
-// 3) Definitions of built-in operators.
-//
-// This document describes the syntax of models and their computation graphs,
-// as well as the standard data types. Together, they are referred to as the ONNX
-// Intermediate Representation, or 'IR' for short.
-//
-// The normative semantic specification of the ONNX IR is found in docs/IR.md.
-// Definitions of the built-in neural network operators may be found in docs/Operators.md.
-
-// Notes
-//
-// Release
-//
-// We are still in the very early stage of defining ONNX. The current
-// version of ONNX is a starting point. While we are actively working
-// towards a complete spec, we would like to get the community involved
-// by sharing our working version of ONNX.
-//
-// Protobuf compatibility
-//
-// To simplify framework compatibility, ONNX is defined using the subset of
-// protobuf that is compatible with both protobuf v2 and v3. This means that we
-// do not use any protobuf features that are only available in one of the two
-// versions.
-//
-// Here are the most notable contortions we have to carry out to work around
-// these limitations:
-//
-// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
-// of key-value pairs, where order does not matter and duplicates
-// are not allowed.
-
-// Versioning
-//
-// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
-//
-// To be compatible with both proto2 and proto3, we will use a version number
-// that is not defined by the default value but an explicit enum number.
-enum Version {
- // proto3 requires the first enum value to be zero.
- // We add this just to appease the compiler.
+enum ProtoVersion {
_START_VERSION = 0;
- // The version field is always serialized and we will use it to store the
- // version that the graph is generated from. This helps us set up version
- // control.
- // For the IR, we are using simple numbers starting with with 0x00000001,
- // which was the version we published on Oct 10, 2017.
- IR_VERSION_2017_10_10 = 0x0000000000000001;
-
- // IR_VERSION 2 published on Oct 30, 2017
- // - Added type discriminator to AttributeProto to support proto3 users
- IR_VERSION_2017_10_30 = 0x0000000000000002;
-
- // IR VERSION 3 published on Nov 3, 2017
- // - For operator versioning:
- // - Added new message OperatorSetIdProto
- // - Added opset_import in ModelProto
- // - For vendor extensions, added domain in NodeProto
- IR_VERSION_NEWEST_ONNX = 0x0000000000000003;
-
- // PYTORCH IR VERSION
- IR_VERSION_NEWEST = 0x0000000000000103;
+ IR_VERSION_NEWEST = 0x0000000000000101;
}
-// Attributes
-//
-// A named attribute containing either singular float, integer, string, graph,
-// and tensor values, or repeated float, integer, string, graph, and tensor values.
-// An AttributeProto MUST contain the name field, and *only one* of the
-// following content fields, effectively enforcing a C/C++ union equivalent.
-message AttributeProto {
-
- // Note: this enum is structurally identical to the OpSchema::AttrType
- // enum defined in schema.h. If you rev one, you likely need to rev the other.
- enum AttributeType {
- UNDEFINED = 0;
- FLOAT = 1;
- INT = 2;
- STRING = 3;
- TENSOR = 4;
- GRAPH = 5;
-
- FLOATS = 6;
- INTS = 7;
- STRINGS = 8;
- TENSORS = 9;
- GRAPHS = 10;
- }
-
- // The name field MUST be present for this version of the IR.
- optional string name = 1; // namespace Attribute
+message MethodDef {
+ // method name
+ optional string name = 1; // method name
- // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
- // In this case, this AttributeProto does not contain data, and it's a reference of attribute
- // in parent scope.
- // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
- optional string ref_attr_name = 21;
+ // static graph
+ optional caffe2.NetDef graph = 2;
+ // method is represented as torch script
+ optional string torch_script = 3;
- // A human-readable documentation for this attribute. Markdown is allowed.
- optional string doc_string = 13;
+ // the names of inputs and outputs
+ repeated string inputs = 4;
+ repeated string outputs = 5;
- // The type field MUST be present for this version of the IR.
- // For 0.0.1 versions of the IR, this field was not defined, and
- // implementations needed to use has_field hueristics to determine
- // which value field was in use. For IR_VERSION 0.0.2 or later, this
- // field MUST be set and match the f|i|s|t|... field in use. This
- // change was made to accomodate proto3 implementations.
- optional AttributeType type = 20; // discriminator that indicates which field below is in use
+ // whether this method is main or not.
+ // by default, `forward` should be the main method.
+ optional bool is_main = 6;
- // Exactly ONE of the following fields must be present for this version of the IR
- optional float f = 2; // float
- optional int64 i = 3; // int
- optional bytes s = 4; // UTF-8 string
- optional TensorProto t = 5; // tensor value
- optional GraphProto g = 6; // graph
- // Do not use field below, it's deprecated.
- // optional ValueProto v = 12; // value - subsumes everything but graph
+ optional string debug_info = 7;
- repeated float floats = 7; // list of floats
- repeated int64 ints = 8; // list of ints
- repeated bytes strings = 9; // list of UTF-8 strings
- repeated TensorProto tensors = 10; // list of tensors
- repeated GraphProto graphs = 11; // list of graph
+ repeated caffe2.Argument annotations = 8;
}
-// Defines information on value, including the name, the type, and
-// the shape of the value.
-message ValueInfoProto {
- // This field MUST be present in this version of the IR.
- optional string name = 1; // namespace Value
- // This field MUST be present in this version of the IR.
- optional TypeProto type = 2;
- // A human-readable documentation for this value. Markdown is allowed.
- optional string doc_string = 3;
-}
-
-// Nodes
-//
-// Computation graphs are made up of a DAG of nodes, which represent what is
-// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
-//
-// For example, it can be a node of type "Conv" that takes in an image, a filter
-// tensor and a bias tensor, and produces the convolved output.
-message NodeProto {
- repeated string input = 1; // namespace Value
- repeated string output = 2; // namespace Value
- // An optional identifier for this node in a graph.
- // This field MAY be absent in ths version of the IR.
- optional string name = 3; // namespace Node
+message ModuleDef {
+ repeated ModuleDef submodules = 1;
- // The symbolic identifier of the Operator to execute.
- optional string op_type = 4; // namespace Operator
- // The domain of the OperatorSet that specifies the operator named by op_type.
- optional string domain = 7; // namespace Domain
+ // We plan to store the modules in one of the following formats:
+ // - methods (static graph or torch script)
+ // - pickle
+ // - cpp_arena
+ repeated MethodDef methods = 2;
+ // because the old pickle modules may not be supported by torch_script,
+ // they have to be stored as pickle_arena at this moment.
+ optional bytes pickle_arena = 3;
+ // should be exposed by the Class Archive, so users can save
+ // module-specific data which cannot be stored in the graph or torch_script
+ optional bytes cpp_arena = 4;
- // Additional named attributes.
- repeated AttributeProto attribute = 5;
+ // the names of inputs and outputs of the module are inferred
+ // from the main method.
- // A human-readable documentation for this node. Markdown is allowed.
- // Equivalent to string debug_info
- optional string doc_string = 6;
+ optional string debug_info = 5;
- // Additional annotations, attributes are defined in Schema
- // To be added as annotations:
- // string engine
- // string list control_input
- // int64 is_gradient_op
- repeated AttributeProto annotations = 8;
-
- // Besides the node type, PyTorhc also serialize ATen function signature
- optional caffe2.DeviceOption device_option = 51;
- optional string aten_function = 52;
+ repeated caffe2.Argument annotations = 6;
}
-// Models
-//
-// ModelProto is a top-level file/container format for bundling a ML model and
-// associating its computation graph with metadata.
-//
-// The semantics of the model are described by the associated GraphProto.
-//
-// Model ==> Caffe2 MetaNetDef
-// ==> PyTorch Module
-message ModelProto {
- // The version of the IR this model targets. See Version enum above.
- // This field MUST be present.
+message ModelDef {
optional int64 ir_version = 1;
- // The OperatorSets this model relies on.
- // All ModelProtos MUST have at least one entry that
- // specifies which version of the ONNX OperatorSet is
- // being imported.
- //
- // All nodes in the ModelProto's graph will bind against the operator
- // with the same-domain/same-op_type operator with the HIGHEST version
- // in the referenced operator sets.
- repeated OperatorSetIdProto opset_import = 8;
-
- // The name of the framework or tool used to generate this model.
- // This field SHOULD be present to indicate which implementation/tool/framework
- // emitted the model.
- optional string producer_name = 2;
-
- // The version of the framework or tool used to generate this model.
- // This field SHOULD be present to indicate which implementation/tool/framework
- // emitted the model.
- optional string producer_version = 3;
-
- // Domain name of the model.
- // We use reverse domain names as name space indicators. For example:
- // `com.facebook.fair` or `com.microsoft.cognitiveservices`
- //
- // Together with `model_version` and GraphProto.name, this forms the unique identity of
- // the graph.
- optional string domain = 4;
-
- // The version of the graph encoded. See Version enum below.
- optional int64 model_version = 5;
-
- // A human-readable documentation for this model. Markdown is allowed.
- optional string doc_string = 6;
+ // main module of the model
+ optional ModuleDef main_module = 2;
- // The parameterized graph that is evaluated to execute the model.
- // The main graph, in single graph case, it is ONNX compatible.
- optional GraphProto graph = 7;
+ repeated caffe2.TensorProto parameters = 3;
+ repeated caffe2.TensorProto value_infos = 4;
- // The remaining nets in MetaNetDef.
- // Submodules and methods in PyTorch.
- repeated GraphProto methods = 15;
-
- // Named metadata values; keys should be distinct.
- // Many meta data in MetaNetDef and preditor are piggy backed here.
- // 1) project
- // 2) model_class
- // 3) internal_version
- // 4) predictor_type
- // 5) predictor_id
- // 6) execute_plan
- // 7) applicationSpecificInfo (another string map, need to verify it has no duplicate.)
- // 8) engine
- // 9) publish time
- repeated StringStringEntryProto metadata_props = 14;
-
- // Model name
- optional string name = 16;
-
- // Model name
- repeated AttributeProto annotations = 17;
-
- // Mapping from list name to blob name list, must be string list type.
- // Equivalent to blobs in MetaNetDef.
- repeated AttributeProto blob_lists = 51;
-
- // Mapping from plan name to serialized plan, must be string list type.
- // Equivalent to plans in MetaNetDef.
- repeated AttributeProto plans = 52;
-};
-
-// StringStringEntryProto follows the pattern for cross-proto-version maps.
-// See https://developers.google.com/protocol-buffers/docs/proto3#maps
-message StringStringEntryProto {
- optional string key = 1;
- optional string value= 2;
-};
-
-// Graphs
-//
-// A graph defines the computational logic of a model and is comprised of a parameterized
-// list of nodes that form a directed acyclic graph based on their inputs and outputs.
-// This is the equivalent of the "network" or "graph" in many deep learning
-// frameworks.
-// Graph ==> NetDef in Caffe2
-// ==> Submodule/Method in PyTorch
-message GraphProto {
- // The nodes in the graph, sorted topologically.
- repeated NodeProto node = 1;
-
- // The name of the graph.
- optional string name = 2; // namespace Graph
-
- // A list of named tensor values, used to specify constant inputs of the graph.
- // Each TensorProto entry must have a distinct name (within the list) that
- // also appears in the input list.
- repeated TensorProto initializer = 5;
-
- // A human-readable documentation for this graph. Markdown is allowed.
- optional string doc_string = 10;
-
- // The inputs and outputs of the graph.
- repeated ValueInfoProto input = 11;
- repeated ValueInfoProto output = 12;
-
- // Information for the values in the graph. The ValueInfoProto.name's
- // must be distinct. It is optional for a value to appear in value_info list.
- repeated ValueInfoProto value_info = 13;
-
- // Additional annotations.
- repeated AttributeProto annotations = 14;
-
- // DO NOT USE the following fields, they were deprecated from earlier versions.
- // repeated string input = 3;
- // repeated string output = 4;
- // optional int64 ir_version = 6;
- // optional int64 producer_version = 7;
- // optional string producer_tag = 8;
- // optional string domain = 9;
-}
+ // to distinguish whether exported from c2 or torch
+ optional string producer_name = 5;
-// Tensors
-//
-// A serialized tensor value.
-message TensorProto {
- enum DataType {
- UNDEFINED = 0;
- // Basic types.
- FLOAT = 1; // float
- UINT8 = 2; // uint8_t
- INT8 = 3; // int8_t
- UINT16 = 4; // uint16_t
- INT16 = 5; // int16_t
- INT32 = 6; // int32_t
- INT64 = 7; // int64_t
- STRING = 8; // string
- BOOL = 9; // bool
+ // put build version here
+ optional string producer_version = 6;
- // Advanced types
- FLOAT16 = 10;
- DOUBLE = 11;
- UINT32 = 12;
- UINT64 = 13;
- COMPLEX64 = 14; // complex with float32 real and imaginary components
- COMPLEX128 = 15; // complex with float64 real and imaginary components
- // Future extensions go here.
+ optional string name = 7;
- // Special data type, real type information is stored in ValueInfoProto.
- // If data_type is SPECIAL, raw_data should be used.
- SPECIAL = 51;
- }
-
- // The shape of the tensor.
- repeated int64 dims = 1;
- repeated int64 strides = 14;
-
- // The data type of the tensor.
- optional DataType data_type = 2;
-
- // For very large tensors, we may want to store them in chunks, in which
- // case the following fields will specify the segment that is stored in
- // the current TensorProto.
- message Segment {
- optional int64 begin = 1;
- optional int64 end = 2;
- optional int64 chuck_num = 51;
- optional int64 chuck_id = 52;
- }
- // Used as offset in the external shared data.
- optional Segment segment = 3;
-
- // Tensor content must be organized in row-major order.
- //
- // Depending on the data_type field, exactly one of the fields below with
- // name ending in _data is used to store the elements of the tensor.
-
- // For float and complex64 values
- // Complex64 tensors are encoded as a single array of floats,
- // with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
- // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
- // is encoded as [1.0, 2.0 ,3.0 ,4.0]
- // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
- repeated float float_data = 4 [packed = true];
-
- // For int32, uint8, int8, uint16, int16, bool, and Half values
- // float16 values must be bit-wise converted to an uint16_t prior
- // to writing to the buffer.
- // When this field is present, the data_type field MUST be
- // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT16
- repeated int32 int32_data = 5 [packed = true];
-
- // For strings.
- // Each element of string_data is a UTF-8 encoded Unicode
- // string. No trailing null, no leading BOM. The protobuf "string"
- // scalar type is not used to match ML community conventions.
- // When this field is present, the data_type field MUST be STRING
- repeated bytes string_data = 6;
-
- // For int64.
- // When this field is present, the data_type field MUST be INT64
- repeated int64 int64_data = 7 [packed = true];
-
- // Optionally, a name for the tensor.
- optional string name = 8; // namespace Value
-
- // A human-readable documentation for this tensor. Markdown is allowed.
- optional string doc_string = 12;
-
- // Serializations can either use one of the fields above, or use this
- // raw bytes field. The only exception is the string case, where one is
- // required to store the content in the repeated bytes string_data field.
- //
- // When this raw_data field is used to store tensor value, elements MUST
- // be stored in as fixed-width, little-endian order.
- // Floating-point data types MUST be stored in IEEE 754 format.
- // Complex64 elements must be written as two consecutive FLOAT values, real component first.
- // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
- // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
- //
- // Note: the advantage of specific field rather than the raw_data field is
- // that in some cases (e.g. int data), protobuf does a better packing via
- // variable length storage, and may lead to smaller binary footprint.
- // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
- optional bytes raw_data = 9;
-
- // For double
- // Complex64 tensors are encoded as a single array of doubles,
- // with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
- // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
- // is encoded as [1.0, 2.0 ,3.0 ,4.0]
- // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
- repeated double double_data = 10 [packed = true];
-
- // For uint64 and uint32 values
- // When this field is present, the data_type field MUST be
- // UINT32 or UINT64
- repeated uint64 uint64_data = 11 [packed = true];
-
- // External data by file name
- optional string external_data = 13;
-
- // If two tensors represent the same weights/content, use alias.
- // Must exist a TensorProto named alias in the initializer list.
- // To avoid the duplicate tensor in attribute, such as value in Constant node.
- // This is useful, if everything is stored just in the proto.
- optional string alias = 16;
-
- // Additional annotations.
- repeated AttributeProto annotations = 17;
-
- // Device info
- optional caffe2.DeviceOption device_option = 51;
-
- // For PyTorch serialized tensor.
- optional int64 require_gradient = 52;
- optional int64 is_buffer = 53;
-}
-
-// Defines a tensor shape. A dimension can be either an integer value
-// or a symbolic variable. A symbolic variable represents an unknown
-// dimension.
-message TensorShapeProto {
- message Dimension {
- oneof value {
- int64 dim_value = 1;
- string dim_param = 2; // namespace Shape
- };
- // Standard denotation can optionally be used to denote tensor
- // dimensions with standard semantic descriptions to ensure
- // that operations are applied to the correct axis of a tensor.
- // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
- // for pre-defined dimension denotations.
- optional string denotation = 3;
- };
- // To represent a scalar, using no dim to represent 0-d tensor.
- repeated Dimension dim = 1;
-
- repeated Dimension stride = 51;
-}
-
-// Types
-//
-// The standard ONNX data types.
-message TypeProto {
-
- message Tensor {
- // This field MUST NOT have the value of UNDEFINED
- // This field MUST be present for this version of the IR.
- optional TensorProto.DataType elem_type = 1;
- optional TensorShapeProto shape = 2;
- }
-
- // Sequence type: List, Tuple
- message Sequence {
- // elem_type and elem_type_list cannot appear together.
- // If all the element types are the same, we use elem_type,
- // otherwise, we specify the type of each element in elem_type_list.
- optional TypeProto elem_type = 1;
- repeated TypeProto elem_type_list = 51;
- enum SequenceType {
- UNDEFINED = 0;
- LIST = 1;
- TUPLE = 2;
- }
- optional SequenceType sequence_type = 52;
- }
-
- // Map<K, V>, (not necessary at this moment)
- message Map {
- optional TensorProto.DataType key_type = 1;
- optional TypeProto value_type = 2;
- }
-
- // Special type of blobs, based on the type_name, we can choose the right
- // serializer and deserialzier.
- message SpecialBlob {
- optional string type_name = 1;
- }
-
- oneof value {
- // The type of a tensor.
- Tensor tensor_type = 1;
- Sequence sequence_type = 4;
- Map map_type = 5;
- SpecialBlob special_type = 51;
- }
-
- // An optional denotation can be used to denote the whole
- // type with a standard semantic description as to what is
- // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition
- // for pre-defined type denotations.
- optional string denotation = 6;
-}
+ optional string debug_info = 8;
-// Operator Sets
-//
-// OperatorSets are uniquely identified by a (domain, opset_version) pair.
-message OperatorSetIdProto {
- // The domain of the operator set being identified.
- // The empty string ("") or absence of this field implies the operator
- // set that is defined as part of the ONNX specification.
- // This field MUST be present in this version of the IR when referring to any other operator set.
- optional string domain = 1;
+ // annotations, used to carry MetaNetDef's metadata
+ repeated caffe2.Argument annotations = 9;
- // The version of the operator set being identified.
- // This field MUST be present in this version of the IR.
- optional int64 version = 2;
}
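
The trimmed-down torch.proto above defines the PT1 container as ModelDef -> ModuleDef -> MethodDef; a minimal sketch of assembling one from Python, assuming regenerated torch_pb2 and caffe2_pb2 bindings (the module and model names are illustrative):

    from caffe2.proto import caffe2_pb2, torch_pb2

    # A method carries either a static graph (caffe2.NetDef) or torch script.
    method = torch_pb2.MethodDef()
    method.name = "forward"
    method.is_main = True                              # `forward` is the main method
    method.inputs.extend(["data"])
    method.outputs.extend(["prob"])
    method.graph.CopyFrom(caffe2_pb2.NetDef())

    module = torch_pb2.ModuleDef()
    module.methods.extend([method])

    model = torch_pb2.ModelDef()
    model.ir_version = torch_pb2.IR_VERSION_NEWEST
    model.producer_name = "pytorch"                    # distinguish c2 vs. torch exports
    model.name = "example_model"
    model.main_module.CopyFrom(module)
    serialized = model.SerializeToString()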
diff --git a/caffe2/python/convert.py b/caffe2/python/convert.py
index 50eaf220c7..44f81d6e2d 100644
--- a/caffe2/python/convert.py
+++ b/caffe2/python/convert.py
@@ -8,59 +8,3 @@ from __future__ import unicode_literals
from caffe2.proto import caffe2_pb2, torch_pb2
import caffe2.python._import_c_extension as C
-
-
-def ArgumentToAttributeProto(arg):
- serialized_arg = None
- if hasattr(arg, 'SerializeToString') and callable(arg.SerializeToString):
- serialized_arg = arg.SerializeToString()
- elif isinstance(arg, bytes):
- serialized_arg = arg
- else:
- raise ValueError('No SerializeToString method is detected. '
- 'neither arg is bytes.\ntype is {}'.format(type(arg)))
- attr = torch_pb2.AttributeProto()
- attr.ParseFromString(C.argument_to_attribute_proto(serialized_arg))
- return attr
-
-
-def AttributeProtoToArgument(attr):
- serialized_attr = None
- if hasattr(attr, 'SerializeToString') and callable(attr.SerializeToString):
- serialized_attr = attr.SerializeToString()
- elif isinstance(attr, bytes):
- serialized_attr = attr
- else:
- raise ValueError('No SerializeToString method is detected. '
- 'neither attr is bytes.\ntype is {}'.format(type(attr)))
- arg = caffe2_pb2.Argument()
- arg.ParseFromString(C.attribute_proto_to_argument(serialized_attr))
- return arg
-
-
-def OperatorDefToNodeProto(op_def):
- serialized_op_def = None
- if hasattr(op_def, 'SerializeToString') and callable(op_def.SerializeToString):
- serialized_op_def = op_def.SerializeToString()
- elif isinstance(op_def, bytes):
- serialized_op_def = op_def
- else:
- raise ValueError('No SerializeToString method is detected. '
- 'neither op_def is bytes.\ntype is {}'.format(type(op_def)))
- node = torch_pb2.NodeProto()
- node.ParseFromString(C.operator_def_to_node_proto(serialized_op_def))
- return node
-
-
-def NodeProtoToOperatorDef(node_proto):
- serialized_node_proto = None
- if hasattr(node_proto, 'SerializeToString') and callable(node_proto.SerializeToString):
- serialized_node_proto = node_proto.SerializeToString()
- elif isinstance(node_proto, bytes):
- serialized_node_proto = node_proto
- else:
- raise ValueError('No SerializeToString method is detected. '
- 'neither node_proto is bytes.\ntype is {}'.format(type(node_proto)))
- op_def = caffe2_pb2.OperatorDef()
- op_def.ParseFromString(C.node_proto_to_operator_def(serialized_node_proto))
- return op_def
diff --git a/caffe2/python/convert_test.py b/caffe2/python/convert_test.py
index c8de7e9750..82c969c901 100644
--- a/caffe2/python/convert_test.py
+++ b/caffe2/python/convert_test.py
@@ -12,239 +12,5 @@ class TestOperator(unittest.TestCase):
def setUp(self):
workspace.ResetWorkspace()
- def testArgument2AttributeProto(self):
- arg_f = caffe2_pb2.Argument()
- arg_f.name = "TestArgF"
- arg_f.f = 10.0
- attr_f = convert.ArgumentToAttributeProto(arg_f)
- self.assertEqual(attr_f.name, arg_f.name)
- self.assertEqual(attr_f.f, arg_f.f)
-
- arg_i = caffe2_pb2.Argument()
- arg_i.name = "TestArgI"
- arg_i.i = 100
- attr_i = convert.ArgumentToAttributeProto(arg_i)
- self.assertEqual(attr_i.name, arg_i.name)
- self.assertEqual(attr_i.i, arg_i.i)
-
- arg_s = caffe2_pb2.Argument()
- arg_s.name = "TestArgS"
- arg_s.s = "TestS".encode("utf-8")
- attr_s = convert.ArgumentToAttributeProto(arg_s)
- self.assertEqual(attr_s.name, arg_s.name)
- self.assertEqual(attr_s.s, arg_s.s)
-
- # TODO: test net arg
-
- arg_floats = caffe2_pb2.Argument()
- arg_floats.name = "TestArgFloats"
- arg_floats.floats.extend([10.0, 11.0, 12.0])
- attr_floats = convert.ArgumentToAttributeProto(arg_floats)
- self.assertEqual(attr_floats.name, arg_floats.name)
- self.assertEqual(attr_floats.floats, arg_floats.floats)
-
- arg_ints = caffe2_pb2.Argument()
- arg_ints.name = "TestArgInts"
- arg_ints.ints.extend([100, 101, 102])
- attr_ints = convert.ArgumentToAttributeProto(arg_ints)
- self.assertEqual(attr_ints.name, arg_ints.name)
- self.assertEqual(attr_ints.ints, arg_ints.ints)
-
- arg_strings = caffe2_pb2.Argument()
- arg_strings.name = "TestArgStrings"
- arg_strings.strings.extend([
- "TestStrings1".encode("utf-8"),
- "TestStrings2".encode("utf-8"),
- ])
- attr_strings = convert.ArgumentToAttributeProto(arg_strings)
- self.assertEqual(attr_strings.name, arg_strings.name)
- self.assertEqual(attr_strings.strings, arg_strings.strings)
-
- # TODO: test nets arg
-
- def testAttributeProto2Argument(self):
- attr_f = torch_pb2.AttributeProto()
- attr_f.type = torch_pb2.AttributeProto.FLOAT
- attr_f.name = "TestAttrF"
- attr_f.f = 10.0
- arg_f = convert.AttributeProtoToArgument(attr_f)
- self.assertEqual(arg_f.name, attr_f.name)
- self.assertEqual(arg_f.f, attr_f.f)
-
- attr_i = torch_pb2.AttributeProto()
- attr_i.type = torch_pb2.AttributeProto.INT
- attr_i.name = "TestArgI"
- attr_i.i = 100
- arg_i = convert.AttributeProtoToArgument(attr_i)
- self.assertEqual(arg_i.name, attr_i.name)
- self.assertEqual(arg_i.i, attr_i.i)
-
- attr_s = torch_pb2.AttributeProto()
- attr_s.type = torch_pb2.AttributeProto.STRING
- attr_s.name = "TestArgS"
- attr_s.s = "TestS".encode("utf-8")
- arg_s = convert.AttributeProtoToArgument(attr_s)
- self.assertEqual(arg_s.name, attr_s.name)
- self.assertEqual(arg_s.s, attr_s.s)
-
- # TODO: test graph attribute
-
- attr_floats = torch_pb2.AttributeProto()
- attr_floats.type = torch_pb2.AttributeProto.FLOATS
- attr_floats.name = "TestAttrFloats"
- attr_floats.floats.extend([10.0, 11.0, 12.0])
- arg_floats = convert.AttributeProtoToArgument(attr_floats)
- self.assertEqual(arg_floats.name, attr_floats.name)
- self.assertEqual(arg_floats.floats, attr_floats.floats)
-
- attr_ints = torch_pb2.AttributeProto()
- attr_ints.type = torch_pb2.AttributeProto.INTS
- attr_ints.name = "TestArgInts"
- attr_ints.ints.extend([100, 101, 102])
- arg_ints = convert.AttributeProtoToArgument(attr_ints)
- self.assertEqual(arg_ints.name, attr_ints.name)
- self.assertEqual(arg_ints.ints, attr_ints.ints)
-
- attr_strings = torch_pb2.AttributeProto()
- attr_strings.type = torch_pb2.AttributeProto.STRINGS
- attr_strings.name = "TestArgStrings"
- attr_strings.strings.extend([
- "TestStrings1".encode("utf-8"),
- "TestStrings2".encode("utf-8"),
- ])
- arg_strings = convert.AttributeProtoToArgument(attr_strings)
- self.assertEqual(arg_strings.name, attr_strings.name)
- self.assertEqual(arg_strings.strings, attr_strings.strings)
-
- # TODO: test graphs attribute
-
-
- def testOperatorDef2NodeProto(self):
- op_def = caffe2_pb2.OperatorDef()
- op_def.input.extend(["A", "B", "C"])
- op_def.output.extend(["X", "Y"])
- op_def.name = "TestOpName"
- op_def.type = "TestOp"
- arg1 = caffe2_pb2.Argument()
- arg1.name = "TestArg1"
- arg1.i = 1
- arg2 = caffe2_pb2.Argument()
- arg2.name = "TestArg2"
- arg1.s = "TestInfo".encode("utf-8")
- op_def.arg.extend([arg1, arg2])
- op_def.device_option.CopyFrom(caffe2_pb2.DeviceOption())
- op_def.engine = "TestEngine".encode("utf-8")
- op_def.control_input.extend(["input1", "input2"])
- op_def.is_gradient_op = True
- op_def.debug_info = "TestDebugInfo"
-
- node = convert.OperatorDefToNodeProto(op_def)
-
- self.assertEqual(node.input, op_def.input)
- self.assertEqual(node.output, op_def.output)
- self.assertEqual(node.name, op_def.name)
- self.assertEqual(node.op_type, op_def.type)
- self.assertEqual(node.attribute[0].name, op_def.arg[0].name)
- self.assertEqual(node.attribute[1].name, op_def.arg[1].name)
- self.assertEqual(node.device_option, op_def.device_option)
- node_engine = [a.s.decode("utf-8") for a in node.annotations if a.name == "engine"][0]
- self.assertEqual(node_engine, op_def.engine)
- node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0]
- self.assertEqual(len(node_control_input), len(op_def.control_input))
- for x, y in zip(node_control_input, op_def.control_input):
- self.assertEqual(x.decode("utf-8"), y)
- self.assertEqual(node.doc_string, op_def.debug_info)
- node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0]
- self.assertEqual(node_is_gradient_op, int(op_def.is_gradient_op))
-
- def testNodeProto2OperatorDef(self):
- node = torch_pb2.NodeProto()
- node.input.extend(["A", "B", "C"])
- node.output.extend(["X", "Y"])
- node.name = "TestOpName"
- node.op_type = "TestOp"
- attr1 = torch_pb2.AttributeProto()
- attr1.name = "TestAttr1"
- attr1.type = torch_pb2.AttributeProto.STRING
- attr1.s = "TestInfo".encode("utf-8")
- attr2 = torch_pb2.AttributeProto()
- attr2.name = "TestAttr2"
- attr2.type = torch_pb2.AttributeProto.INT
- attr2.i = 10
- node.attribute.extend([attr1, attr2])
- node.device_option.CopyFrom(caffe2_pb2.DeviceOption())
- anno1 = torch_pb2.AttributeProto()
- anno1.name = "engine"
- anno1.type = torch_pb2.AttributeProto.STRING
- anno1.s = "TestEngine".encode("utf-8")
- anno2 = torch_pb2.AttributeProto()
- anno2.name = "control_input"
- anno2.type = torch_pb2.AttributeProto.STRINGS
- anno2.strings.extend(["input1".encode("utf-8"), "input2".encode("utf-8")])
- anno3 = torch_pb2.AttributeProto()
- anno3.name = "is_gradient_op"
- anno3.type = torch_pb2.AttributeProto.INT
- anno3.i = 1
- node.annotations.extend([anno1, anno2, anno3])
- node.doc_string = "TestDocString".encode("utf-8")
-
- op_def = convert.NodeProtoToOperatorDef(node)
-
- self.assertEqual(op_def.input, node.input)
- self.assertEqual(op_def.output, node.output)
- self.assertEqual(op_def.name, node.name)
- self.assertEqual(op_def.type, node.op_type)
- self.assertEqual(op_def.arg[0].name, node.attribute[0].name)
- self.assertEqual(op_def.arg[1].name, node.attribute[1].name)
- self.assertEqual(op_def.device_option, node.device_option)
- node_engine = [a.s for a in node.annotations if a.name == "engine"][0]
- self.assertEqual(op_def.engine, node_engine.decode("utf-8"))
- node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0]
- for x, y in zip(op_def.control_input, node_control_input):
- self.assertEqual(x, y.decode("utf-8"))
- self.assertEqual(op_def.debug_info, node.doc_string)
- node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0]
- self.assertEqual(int(op_def.is_gradient_op), node_is_gradient_op)
-
- def testEnd2End(self):
- op_def = caffe2_pb2.OperatorDef()
- op_def.type = "Add"
- op_def.input.extend(["input1"])
- op_def.input.extend(["input2"])
- op_def.output.extend(["output1"])
- node = convert.OperatorDefToNodeProto(op_def)
-
- input1 = np.random.randn(1, 3, 1, 5).astype(np.float32)
- input2 = np.random.randn(2, 1, 4, 1).astype(np.float32)
- ref_output1 = input1 + input2
- workspace.FeedBlob("input1", input1)
- workspace.FeedBlob("input2", input2)
- self.assertEqual(workspace.RunOperatorOnce(node.SerializeToString(), legacy_proto=False), True)
-
- self.assertEqual(workspace.HasBlob("output1"), True)
- fetched_back = workspace.FetchBlob("output1")
- np.testing.assert_array_equal(fetched_back, ref_output1)
-
- def testRoundTrip(self):
- op_def = caffe2_pb2.OperatorDef()
- op_def.type = "Add"
- op_def.input.extend(["input1"])
- op_def.input.extend(["input2"])
- op_def.output.extend(["output1"])
- node = convert.OperatorDefToNodeProto(op_def)
- new_op_def = convert.NodeProtoToOperatorDef(node)
-
- input1 = np.random.randn(1, 3, 1, 5).astype(np.float32)
- input2 = np.random.randn(2, 1, 4, 1).astype(np.float32)
- ref_output1 = input1 + input2
- workspace.FeedBlob("input1", input1)
- workspace.FeedBlob("input2", input2)
- self.assertEqual(workspace.RunOperatorOnce(new_op_def.SerializeToString()), True)
-
- self.assertEqual(workspace.HasBlob("output1"), True)
- fetched_back = workspace.FetchBlob("output1")
- np.testing.assert_array_equal(fetched_back, ref_output1)
-
-
if __name__ == '__main__':
unittest.main()
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 7062ead045..7ebee57d49 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -1187,17 +1187,10 @@ void addGlobalMethods(py::module& m) {
return true;
});
m.def("nets", []() { return gWorkspace->Nets(); });
- m.def("run_operator_once", [](const py::bytes& op_def, bool legacy_proto=true) {
+ m.def("run_operator_once", [](const py::bytes& op_def) {
CAFFE_ENFORCE(gWorkspace);
OperatorDef def;
- if (legacy_proto) {
- CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast<std::string>(), &def));
- } else {
- ::torch::NodeProto node;
- CAFFE_ENFORCE(
- ParseProtoFromLargeString(op_def.cast<std::string>(), &node));
- NodeProtoToOperatorDef(node, &def);
- }
+ CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast<std::string>(), &def));
py::gil_scoped_release g;
CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def));
return true;
@@ -1534,38 +1527,6 @@ void addGlobalMethods(py::module& m) {
CAFFE_ENFORCE(blob);
return BlobStat::sizeBytes(*blob);
});
- m.def("argument_to_attribute_proto", [](py::bytes arg_str) -> py::bytes {
- Argument arg;
- CAFFE_ENFORCE(
- ParseProtoFromLargeString(arg_str.cast<std::string>(), &arg));
- ::torch::AttributeProto attr;
- ArgumentToAttributeProto(arg, &attr);
- return attr.SerializeAsString();
- });
- m.def("attribute_proto_to_argument", [](py::bytes attr_str) -> py::bytes {
- ::torch::AttributeProto attr;
- CAFFE_ENFORCE(
- ParseProtoFromLargeString(attr_str.cast<std::string>(), &attr));
- Argument arg;
- AttributeProtoToArgument(attr, &arg);
- return arg.SerializeAsString();
- });
- m.def("operator_def_to_node_proto", [](py::bytes op_str) -> py::bytes {
- OperatorDef op_def;
- CAFFE_ENFORCE(
- ParseProtoFromLargeString(op_str.cast<std::string>(), &op_def));
- ::torch::NodeProto node;
- OperatorDefToNodeProto(op_def, &node);
- return node.SerializeAsString();
- });
- m.def("node_proto_to_operator_def", [](py::bytes node_str) -> py::bytes {
- ::torch::NodeProto node_proto;
- CAFFE_ENFORCE(
- ParseProtoFromLargeString(node_str.cast<std::string>(), &node_proto));
- OperatorDef op_def;
- NodeProtoToOperatorDef(node_proto, &op_def);
- return op_def.SerializeAsString();
- });
m.def("support_onnx_export", [](const std::string& op) -> bool {
const OpSchema* schema = caffe2::OpSchemaRegistry::Schema(op);
if (!schema) {
diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py
index a41cc15317..ef02f64dc9 100644
--- a/caffe2/python/workspace.py
+++ b/caffe2/python/workspace.py
@@ -163,8 +163,8 @@ def GetOperatorCost(operator, blobs):
return C.get_operator_cost(StringifyProto(operator), blobs)
-def RunOperatorOnce(operator, legacy_proto=True):
- return C.run_operator_once(StringifyProto(operator), legacy_proto)
+def RunOperatorOnce(operator):
+ return C.run_operator_once(StringifyProto(operator))
def RunOperatorsOnce(operators):
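
With the legacy_proto flag gone, RunOperatorOnce again takes only an OperatorDef; a minimal usage sketch, mirroring the deleted end-to-end test above:

    import numpy as np
    from caffe2.proto import caffe2_pb2
    from caffe2.python import workspace

    op = caffe2_pb2.OperatorDef()
    op.type = "Add"
    op.input.extend(["input1", "input2"])
    op.output.extend(["output1"])

    workspace.FeedBlob("input1", np.ones((2, 3), dtype=np.float32))
    workspace.FeedBlob("input2", np.ones((2, 3), dtype=np.float32))
    workspace.RunOperatorOnce(op)                      # no legacy_proto argument anymore
    result = workspace.FetchBlob("output1")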
diff --git a/caffe2/utils/proto_convert.cc b/caffe2/utils/proto_convert.cc
index 790bd27429..1d69c8c80c 100644
--- a/caffe2/utils/proto_convert.cc
+++ b/caffe2/utils/proto_convert.cc
@@ -2,185 +2,4 @@
#include "caffe2/core/logging.h"
namespace caffe2 {
-
-C10_EXPORT void ArgumentToAttributeProto(
- const Argument& arg,
- ::torch::AttributeProto* attr) {
- CAFFE_ENFORCE(arg.has_name());
- attr->set_name(arg.name());
- if (arg.has_f()) {
- attr->set_f(arg.f());
- } else if (arg.has_i()) {
- attr->set_i(arg.i());
- } else if (arg.has_s()) {
- attr->set_s(arg.s());
- } else if (arg.has_n()) {
- // TODO
- CAFFE_THROW("NetDef conversion is not implemented yet.");
- } else if (arg.floats_size() > 0) {
- attr->mutable_floats()->CopyFrom(arg.floats());
- } else if (arg.ints_size() > 0) {
- attr->mutable_ints()->CopyFrom(arg.ints());
- } else if (arg.strings_size() > 0) {
- attr->mutable_strings()->CopyFrom(arg.strings());
- } else if (arg.nets_size() > 0) {
- // TODO
- CAFFE_THROW("NetDefs conversion is not implemented yet.");
- }
-}
-
-C10_EXPORT void AttributeProtoToArgument(
- const ::torch::AttributeProto& attr,
- Argument* arg) {
- CAFFE_ENFORCE(attr.has_name());
- arg->set_name(attr.name());
- CAFFE_ENFORCE(attr.has_type());
- const auto type = attr.type();
- if (type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_FLOAT) {
- CAFFE_ENFORCE(attr.has_f());
- arg->set_f(attr.f());
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::AttributeProto_AttributeType_INT) {
- CAFFE_ENFORCE(attr.has_i());
- arg->set_i(attr.i());
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_STRING) {
- CAFFE_ENFORCE(attr.has_s());
- arg->set_s(attr.s());
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_TENSOR) {
- CAFFE_THROW("Caffe2's Argument does not support tensor as attribute.");
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_GRAPH) {
- // TODO
- CAFFE_THROW("GraphProto conversion is not implemented yet.");
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_FLOATS) {
- arg->mutable_floats()->CopyFrom(attr.floats());
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_INTS) {
- arg->mutable_ints()->CopyFrom(attr.ints());
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_STRINGS) {
- arg->mutable_strings()->CopyFrom(attr.strings());
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_TENSORS) {
- CAFFE_THROW("Caffe2's Argument does not support tensors as attribute.");
- } else if (
- type ==
- ::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_GRAPHS) {
- // TODO
- CAFFE_THROW("GraphProtos conversion is not implemented yet.");
- } else {
- CAFFE_THROW("Unknow Attribute type.");
- }
-}
-
-C10_EXPORT void OperatorDefToNodeProto(
- const OperatorDef& def,
- ::torch::NodeProto* node) {
- node->mutable_input()->CopyFrom(def.input());
- node->mutable_output()->CopyFrom(def.output());
- if (def.has_name()) {
- node->set_name(def.name());
- }
- CAFFE_ENFORCE(def.has_type());
- node->set_op_type(def.type());
- for (int i = 0; i < def.arg_size(); ++i) {
- auto attr = node->add_attribute();
- ArgumentToAttributeProto(def.arg(i), attr);
- }
- if (def.has_device_option()) {
- node->mutable_device_option()->CopyFrom(def.device_option());
- }
- if (def.has_engine()) {
- auto attr = node->add_annotations();
- attr->set_name("engine");
- attr->set_type(::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_STRING);
- attr->set_s(def.engine());
- }
- if (def.control_input_size() > 0) {
- auto attr = node->add_annotations();
- attr->set_name("control_input");
- attr->set_type(::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_STRINGS);
- attr->mutable_strings()->CopyFrom(def.control_input());
- }
- if (def.has_is_gradient_op()) {
- auto attr = node->add_annotations();
- attr->set_name("is_gradient_op");
- attr->set_type(::torch::AttributeProto_AttributeType::
- AttributeProto_AttributeType_INT);
- if (def.is_gradient_op()) {
- attr->set_i(1);
- } else {
- attr->set_i(0);
- }
- }
- if (def.has_debug_info()) {
- node->set_doc_string(def.debug_info());
- }
-}
-
-C10_EXPORT void NodeProtoToOperatorDef(
- const ::torch::NodeProto& node,
- OperatorDef* def) {
- def->mutable_input()->CopyFrom(node.input());
- def->mutable_output()->CopyFrom(node.output());
- if (node.has_name()) {
- def->set_name(node.name());
- }
-
- CAFFE_ENFORCE(node.has_op_type());
- def->set_type(node.op_type());
- for (int i = 0; i < node.attribute_size(); ++i) {
- auto arg = def->add_arg();
- AttributeProtoToArgument(node.attribute(i), arg);
- }
- if (node.has_doc_string()) {
- def->set_debug_info(node.doc_string());
- }
- for (int i = 0; i < node.annotations_size(); ++i) {
- const auto& attr = node.annotations(i);
- CAFFE_ENFORCE(attr.has_name());
- if (attr.name() == "engine") {
- CAFFE_ENFORCE(attr.has_s());
- def->set_engine(attr.s());
- } else if (attr.name() == "control_input") {
- def->mutable_control_input()->CopyFrom(attr.strings());
- } else if (attr.name() == "is_gradient_op") {
- CAFFE_ENFORCE(attr.has_i());
- if (i == 0) {
- def->set_is_gradient_op(false);
- } else {
- def->set_is_gradient_op(true);
- }
- }
- auto arg = def->add_arg();
- AttributeProtoToArgument(node.annotations(i), arg);
- }
- if (node.has_device_option()) {
- def->mutable_device_option()->CopyFrom(node.device_option());
- }
-}
-
} // namespace caffe2
diff --git a/caffe2/utils/proto_convert.h b/caffe2/utils/proto_convert.h
index a9ca9c3ad4..91bcf1bafa 100644
--- a/caffe2/utils/proto_convert.h
+++ b/caffe2/utils/proto_convert.h
@@ -6,20 +6,6 @@
#include "caffe2/proto/torch_pb.h"
namespace caffe2 {
-
-CAFFE2_API void ArgumentToAttributeProto(
- const Argument& arg,
- ::torch::AttributeProto* attr);
-CAFFE2_API void AttributeProtoToArgument(
- const ::torch::AttributeProto& attr,
- Argument* arg);
-CAFFE2_API void OperatorDefToNodeProto(
- const OperatorDef& def,
- ::torch::NodeProto* node);
-CAFFE2_API void NodeProtoToOperatorDef(
- const ::torch::NodeProto& node,
- OperatorDef* def);
-
} // namespace caffe2
#endif // CAFFE2_UTILS_PROTO_CONVERT_H_