From 35becd1879fae56fc417905c5154c402b9780a3f Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Mon, 1 Oct 2018 15:42:45 -0700 Subject: New version of PT1 model format (#12149) Summary: Considered four different existing formats: 1) static graph, 2) torch script, 3) pickle files, 4) PyTorch C++ serialize APIs Pull Request resolved: https://github.com/pytorch/pytorch/pull/12149 Reviewed By: BIT-silence Differential Revision: D10098106 Pulled By: houseroad fbshipit-source-id: 94ec7fc57c842e50fae5286ddeda657a4967a07a --- caffe2/core/blob_serialization.cc | 15 +- caffe2/proto/caffe2.proto | 86 +++++- caffe2/proto/torch.proto | 564 ++++---------------------------------- caffe2/python/convert.py | 56 ---- caffe2/python/convert_test.py | 234 ---------------- caffe2/python/pybind_state.cc | 43 +-- caffe2/python/workspace.py | 4 +- caffe2/utils/proto_convert.cc | 181 ------------ caffe2/utils/proto_convert.h | 14 - 9 files changed, 139 insertions(+), 1058 deletions(-) diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index 8126b3d594..f27d16adf3 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -308,6 +308,12 @@ void TensorSerializer::Serialize( const_cast(raw_data + i * input.itemsize()), input.meta()); proto.add_string_data(SerializeBlob(temp_blob, "")); } + } break; + case TensorProto_DataType_SPECIAL: { + CAFFE_THROW("SPECIAL Tensor is not handled yet."); + } break; + case TensorProto_DataType_NO_CONTENT: { + CAFFE_THROW("NO_CONTENT Tensor should not be serialized."); } break; // Note: we intentially do not provide "default:" so if any new data types // are added, the compiler should warn the user to add the case here. @@ -520,7 +526,14 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) { (i + chunkBegin) * temp_blob.meta().itemsize(), 1); } - } + } break; + case TensorProto_DataType_SPECIAL: { + CAFFE_THROW("SPECIAL Tensor is not handled yet."); + } break; + case TensorProto_DataType_NO_CONTENT: { + CAFFE_THROW("NO_CONTENT Tensor should not be deserialized."); + } break; + // Note: we intentially do not provide "default:" so if any new data types } context->FinishDeviceComputation(); } diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto index 21bdec2c68..63a2a256de 100644 --- a/caffe2/proto/caffe2.proto +++ b/caffe2/proto/caffe2.proto @@ -15,23 +15,46 @@ package caffe2; message TensorProto { // The dimensions in the tensor. repeated int64 dims = 1; + // The strides of the tensor. + repeated int64 strides = 12; + + // Data type enum DataType { UNDEFINED = 0; - FLOAT = 1; // float - INT32 = 2; // int - BYTE = 3; // BYTE, when deserialized, is going to be restored as uint8. - STRING = 4; // string - // Less-commonly used data types. - BOOL = 5; // bool - UINT8 = 6; // uint8_t - INT8 = 7; // int8_t - UINT16 = 8; // uint16_t - INT16 = 9; // int16_t - INT64 = 10; // int64_t + + // Basic types + FLOAT = 1; // float + INT32 = 2; // int + BYTE = 3; // byte, when deserialized, is going to be restored as uint8 + STRING = 4; // string + + // Less-commonly used data types + BOOL = 5; // bool + UINT8 = 6; // uint8_t + INT8 = 7; // int8_t + UINT16 = 8; // uint16_t + INT16 = 9; // int16_t + INT64 = 10; // int64_t FLOAT16 = 12; // at::Half - DOUBLE = 13; // double + DOUBLE = 13; // double + + // Special data type, type information is stored in the special type field + SPECIAL = 51; + // Use TensorProto to specify the shape and type + NO_CONTENT = 52; } optional DataType data_type = 2 [default = FLOAT]; + // if data_type is SPECIAL, use this field to express the type info + optional SpecialType special_type = 13; + + // Data storage + enum StorageType { + TYPED = 1; + RAW = 2; + EXTERNAL = 3; + ALIAS = 4; + } + optional StorageType storage_type = 14 [default = TYPED]; // For float repeated float float_data = 3 [packed = true]; // For int32, uint8, int8, uint16, int16, bool, and float16 @@ -46,6 +69,13 @@ message TensorProto { repeated double double_data = 9 [packed = true]; // For int64 repeated int64 int64_data = 10 [packed = true]; + // For raw data + optional bytes raw_data = 15; + // External data by file name + optional string external_data = 16; + // For argument, to share the content + optional string alias = 17; + // Optionally, a name for the tensor. optional string name = 7; @@ -53,13 +83,23 @@ message TensorProto { // it was serialized from. This is useful in cases like snapshotting a whole // workspace in a multi-GPU environment. optional DeviceOption device_detail = 8; + // When loading from chunks this is going to indicate where to put data in the // full array. When not used full data have to be present message Segment { required int64 begin = 1; required int64 end = 2; + optional int64 chunk_num = 51; + optional int64 chunk_id = 52; } optional Segment segment = 11; + optional string debug_info = 18; + + // For PyTorch serialized tensor. + optional bool require_gradient = 19; + optional bool is_buffer = 20; + + repeated Argument annotations = 21; } message QTensorProto { @@ -86,7 +126,11 @@ message TensorShape { repeated int32 unknown_dims = 3; optional bool unknown_shape = 4 [default = false]; optional string name = 5; +} +// This is prepared for non-tensor types. +message SpecialType { + optional string name = 1; } message TensorShapes { @@ -97,13 +141,17 @@ message TensorShapes { // values, or repeated float, int and string arrays. message Argument { optional string name = 1; + optional float f = 2; optional int64 i = 3; optional bytes s = 4; + optional TensorProto t = 10; optional NetDef n = 8; + repeated float floats = 5; repeated int64 ints = 6; repeated bytes strings = 7; + repeated TensorProto tensors = 11; repeated NetDef nets = 9; } @@ -152,7 +200,11 @@ message DeviceOption { // Operator Definition. message OperatorDef { repeated string input = 1; // the name of the input blobs + // the input name in the schema, for named inputs + repeated string mapped_inputs = 11; repeated string output = 2; // the name of output top blobs + // the outputname in the schema, for named outputs + repeated string mapped_outputs = 12; optional string name = 3; // the operator name. This is optional. // the operator type. This is needed to create the object from the operator // registry. @@ -186,6 +238,16 @@ message OperatorDef { // This is an optional string with no assumed characteristics as // operators can be constructed in any language. optional string debug_info = 10; + + // additional annotations + repeated Argument annotations = 13; + + // for jit ir exporting + optional string aten_function = 14; + + // for operator versioning + optional string domain = 15; + optional string op_version = 16; } // Network definition. diff --git a/caffe2/proto/torch.proto b/caffe2/proto/torch.proto index 43dfd02b14..f31c3b65ec 100644 --- a/caffe2/proto/torch.proto +++ b/caffe2/proto/torch.proto @@ -4,547 +4,77 @@ import "caffe2/proto/caffe2.proto"; package torch; -// Overview -// -// ONNX is an open specification that is comprised of the following components: -// -// 1) A definition of an extensible computation graph model. -// 2) Definitions of standard data types. -// 3) Definitions of built-in operators. -// -// This document describes the syntax of models and their computation graphs, -// as well as the standard data types. Together, they are referred to as the ONNX -// Intermediate Representation, or 'IR' for short. -// -// The normative semantic specification of the ONNX IR is found in docs/IR.md. -// Definitions of the built-in neural network operators may be found in docs/Operators.md. - -// Notes -// -// Release -// -// We are still in the very early stage of defining ONNX. The current -// version of ONNX is a starting point. While we are actively working -// towards a complete spec, we would like to get the community involved -// by sharing our working version of ONNX. -// -// Protobuf compatibility -// -// To simplify framework compatibility, ONNX is defined using the subset of -// protobuf that is compatible with both protobuf v2 and v3. This means that we -// do not use any protobuf features that are only available in one of the two -// versions. -// -// Here are the most notable contortions we have to carry out to work around -// these limitations: -// -// - No 'map' (added protobuf 3.0). We instead represent mappings as lists -// of key-value pairs, where order does not matter and duplicates -// are not allowed. - -// Versioning -// -// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md -// -// To be compatible with both proto2 and proto3, we will use a version number -// that is not defined by the default value but an explicit enum number. -enum Version { - // proto3 requires the first enum value to be zero. - // We add this just to appease the compiler. +enum ProtoVersion { _START_VERSION = 0; - // The version field is always serialized and we will use it to store the - // version that the graph is generated from. This helps us set up version - // control. - // For the IR, we are using simple numbers starting with with 0x00000001, - // which was the version we published on Oct 10, 2017. - IR_VERSION_2017_10_10 = 0x0000000000000001; - - // IR_VERSION 2 published on Oct 30, 2017 - // - Added type discriminator to AttributeProto to support proto3 users - IR_VERSION_2017_10_30 = 0x0000000000000002; - - // IR VERSION 3 published on Nov 3, 2017 - // - For operator versioning: - // - Added new message OperatorSetIdProto - // - Added opset_import in ModelProto - // - For vendor extensions, added domain in NodeProto - IR_VERSION_NEWEST_ONNX = 0x0000000000000003; - - // PYTORCH IR VERSION - IR_VERSION_NEWEST = 0x0000000000000103; + IR_VERSION_NEWEST = 0x0000000000000101; } -// Attributes -// -// A named attribute containing either singular float, integer, string, graph, -// and tensor values, or repeated float, integer, string, graph, and tensor values. -// An AttributeProto MUST contain the name field, and *only one* of the -// following content fields, effectively enforcing a C/C++ union equivalent. -message AttributeProto { - - // Note: this enum is structurally identical to the OpSchema::AttrType - // enum defined in schema.h. If you rev one, you likely need to rev the other. - enum AttributeType { - UNDEFINED = 0; - FLOAT = 1; - INT = 2; - STRING = 3; - TENSOR = 4; - GRAPH = 5; - - FLOATS = 6; - INTS = 7; - STRINGS = 8; - TENSORS = 9; - GRAPHS = 10; - } - - // The name field MUST be present for this version of the IR. - optional string name = 1; // namespace Attribute +message MethodDef { + // method name + optional string name = 1; // method name - // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. - // In this case, this AttributeProto does not contain data, and it's a reference of attribute - // in parent scope. - // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. - optional string ref_attr_name = 21; + // static graph + optional caffe2.NetDef graph = 2; + // method is represented as torch script + optional string torch_script = 3; - // A human-readable documentation for this attribute. Markdown is allowed. - optional string doc_string = 13; + // the names of inputs and outputs + repeated string inputs = 4; + repeated string outputs = 5; - // The type field MUST be present for this version of the IR. - // For 0.0.1 versions of the IR, this field was not defined, and - // implementations needed to use has_field hueristics to determine - // which value field was in use. For IR_VERSION 0.0.2 or later, this - // field MUST be set and match the f|i|s|t|... field in use. This - // change was made to accomodate proto3 implementations. - optional AttributeType type = 20; // discriminator that indicates which field below is in use + // whether this method is main or not. + // by default, `forward` should the main method. + optional bool is_main = 6; - // Exactly ONE of the following fields must be present for this version of the IR - optional float f = 2; // float - optional int64 i = 3; // int - optional bytes s = 4; // UTF-8 string - optional TensorProto t = 5; // tensor value - optional GraphProto g = 6; // graph - // Do not use field below, it's deprecated. - // optional ValueProto v = 12; // value - subsumes everything but graph + optional string debug_info = 7; - repeated float floats = 7; // list of floats - repeated int64 ints = 8; // list of ints - repeated bytes strings = 9; // list of UTF-8 strings - repeated TensorProto tensors = 10; // list of tensors - repeated GraphProto graphs = 11; // list of graph + repeated caffe2.Argument annotations = 8; } -// Defines information on value, including the name, the type, and -// the shape of the value. -message ValueInfoProto { - // This field MUST be present in this version of the IR. - optional string name = 1; // namespace Value - // This field MUST be present in this version of the IR. - optional TypeProto type = 2; - // A human-readable documentation for this value. Markdown is allowed. - optional string doc_string = 3; -} - -// Nodes -// -// Computation graphs are made up of a DAG of nodes, which represent what is -// commonly called a "layer" or "pipeline stage" in machine learning frameworks. -// -// For example, it can be a node of type "Conv" that takes in an image, a filter -// tensor and a bias tensor, and produces the convolved output. -message NodeProto { - repeated string input = 1; // namespace Value - repeated string output = 2; // namespace Value - // An optional identifier for this node in a graph. - // This field MAY be absent in ths version of the IR. - optional string name = 3; // namespace Node +message ModuleDef { + repeated ModuleDef submodules = 1; - // The symbolic identifier of the Operator to execute. - optional string op_type = 4; // namespace Operator - // The domain of the OperatorSet that specifies the operator named by op_type. - optional string domain = 7; // namespace Domain + // We suppose to store the modules in one of the following format: + // - methods (static graph or torch script) + // - pickle + // - cpp_arena + repeated MethodDef methods = 2; + // because the old pickle modules may not be supported by torch_script, + // have to stored as pickle_arena at this moment. + optional bytes pickle_arena = 3; + // should be exposed by the Class Archive, so user can save + // module specific data which cannot be store in the graph or torch_script + optional bytes cpp_arena = 4; - // Additional named attributes. - repeated AttributeProto attribute = 5; + // the names of inputs and outputs of the module are inferred + // from the main method. - // A human-readable documentation for this node. Markdown is allowed. - // Equivalent to string debug_info - optional string doc_string = 6; + optional string debug_info = 5; - // Additional annotations, attributes are defined in Schema - // To be added as annotations: - // string engine - // string list control_input - // int64 is_gradient_op - repeated AttributeProto annotations = 8; - - // Besides the node type, PyTorhc also serialize ATen function signature - optional caffe2.DeviceOption device_option = 51; - optional string aten_function = 52; + repeated caffe2.Argument annotations = 6; } -// Models -// -// ModelProto is a top-level file/container format for bundling a ML model and -// associating its computation graph with metadata. -// -// The semantics of the model are described by the associated GraphProto. -// -// Model ==> Caffe2 MetaNetDef -// ==> PyTorch Module -message ModelProto { - // The version of the IR this model targets. See Version enum above. - // This field MUST be present. +message ModelDef { optional int64 ir_version = 1; - // The OperatorSets this model relies on. - // All ModelProtos MUST have at least one entry that - // specifies which version of the ONNX OperatorSet is - // being imported. - // - // All nodes in the ModelProto's graph will bind against the operator - // with the same-domain/same-op_type operator with the HIGHEST version - // in the referenced operator sets. - repeated OperatorSetIdProto opset_import = 8; - - // The name of the framework or tool used to generate this model. - // This field SHOULD be present to indicate which implementation/tool/framework - // emitted the model. - optional string producer_name = 2; - - // The version of the framework or tool used to generate this model. - // This field SHOULD be present to indicate which implementation/tool/framework - // emitted the model. - optional string producer_version = 3; - - // Domain name of the model. - // We use reverse domain names as name space indicators. For example: - // `com.facebook.fair` or `com.microsoft.cognitiveservices` - // - // Together with `model_version` and GraphProto.name, this forms the unique identity of - // the graph. - optional string domain = 4; - - // The version of the graph encoded. See Version enum below. - optional int64 model_version = 5; - - // A human-readable documentation for this model. Markdown is allowed. - optional string doc_string = 6; + // main module of the model + optional ModuleDef main_module = 2; - // The parameterized graph that is evaluated to execute the model. - // The main graph, in single graph case, it is ONNX compatible. - optional GraphProto graph = 7; + repeated caffe2.TensorProto parameters = 3; + repeated caffe2.TensorProto value_infos = 4; - // The remaining nets in MetaNetDef. - // Submodules and methods in PyTorch. - repeated GraphProto methods = 15; - - // Named metadata values; keys should be distinct. - // Many meta data in MetaNetDef and preditor are piggy backed here. - // 1) project - // 2) model_class - // 3) internal_version - // 4) predictor_type - // 5) predictor_id - // 6) execute_plan - // 7) applicationSpecificInfo (another string map, need to verify it has no duplicate.) - // 8) engine - // 9) publish time - repeated StringStringEntryProto metadata_props = 14; - - // Model name - optional string name = 16; - - // Model name - repeated AttributeProto annotations = 17; - - // Mapping from list name to blob name list, must be string list type. - // Equivalent to blobs in MetaNetDef. - repeated AttributeProto blob_lists = 51; - - // Mapping from plan name to serialized plan, must be string list type. - // Equivalent to plans in MetaNetDef. - repeated AttributeProto plans = 52; -}; - -// StringStringEntryProto follows the pattern for cross-proto-version maps. -// See https://developers.google.com/protocol-buffers/docs/proto3#maps -message StringStringEntryProto { - optional string key = 1; - optional string value= 2; -}; - -// Graphs -// -// A graph defines the computational logic of a model and is comprised of a parameterized -// list of nodes that form a directed acyclic graph based on their inputs and outputs. -// This is the equivalent of the "network" or "graph" in many deep learning -// frameworks. -// Graph ==> NetDef in Caffe2 -// ==> Submodule/Method in PyTorch -message GraphProto { - // The nodes in the graph, sorted topologically. - repeated NodeProto node = 1; - - // The name of the graph. - optional string name = 2; // namespace Graph - - // A list of named tensor values, used to specify constant inputs of the graph. - // Each TensorProto entry must have a distinct name (within the list) that - // also appears in the input list. - repeated TensorProto initializer = 5; - - // A human-readable documentation for this graph. Markdown is allowed. - optional string doc_string = 10; - - // The inputs and outputs of the graph. - repeated ValueInfoProto input = 11; - repeated ValueInfoProto output = 12; - - // Information for the values in the graph. The ValueInfoProto.name's - // must be distinct. It is optional for a value to appear in value_info list. - repeated ValueInfoProto value_info = 13; - - // Additional annotations. - repeated AttributeProto annotations = 14; - - // DO NOT USE the following fields, they were deprecated from earlier versions. - // repeated string input = 3; - // repeated string output = 4; - // optional int64 ir_version = 6; - // optional int64 producer_version = 7; - // optional string producer_tag = 8; - // optional string domain = 9; -} + // to distinguish whether exported from c2 or torch + optional string producer_name = 5; -// Tensors -// -// A serialized tensor value. -message TensorProto { - enum DataType { - UNDEFINED = 0; - // Basic types. - FLOAT = 1; // float - UINT8 = 2; // uint8_t - INT8 = 3; // int8_t - UINT16 = 4; // uint16_t - INT16 = 5; // int16_t - INT32 = 6; // int32_t - INT64 = 7; // int64_t - STRING = 8; // string - BOOL = 9; // bool + // put build version here + optional string producer_version = 6; - // Advanced types - FLOAT16 = 10; - DOUBLE = 11; - UINT32 = 12; - UINT64 = 13; - COMPLEX64 = 14; // complex with float32 real and imaginary components - COMPLEX128 = 15; // complex with float64 real and imaginary components - // Future extensions go here. + optional string name = 7; - // Special data type, real type information is stored in ValueInfoProto. - // If data_type is SPECIAL, raw_data should be used. - SPECIAL = 51; - } - - // The shape of the tensor. - repeated int64 dims = 1; - repeated int64 strides = 14; - - // The data type of the tensor. - optional DataType data_type = 2; - - // For very large tensors, we may want to store them in chunks, in which - // case the following fields will specify the segment that is stored in - // the current TensorProto. - message Segment { - optional int64 begin = 1; - optional int64 end = 2; - optional int64 chuck_num = 51; - optional int64 chuck_id = 52; - } - // Used as offset in the external shared data. - optional Segment segment = 3; - - // Tensor content must be organized in row-major order. - // - // Depending on the data_type field, exactly one of the fields below with - // name ending in _data is used to store the elements of the tensor. - - // For float and complex64 values - // Complex64 tensors are encoded as a single array of floats, - // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the - // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - // is encoded as [1.0, 2.0 ,3.0 ,4.0] - // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. - repeated float float_data = 4 [packed = true]; - - // For int32, uint8, int8, uint16, int16, bool, and Half values - // float16 values must be bit-wise converted to an uint16_t prior - // to writing to the buffer. - // When this field is present, the data_type field MUST be - // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT16 - repeated int32 int32_data = 5 [packed = true]; - - // For strings. - // Each element of string_data is a UTF-8 encoded Unicode - // string. No trailing null, no leading BOM. The protobuf "string" - // scalar type is not used to match ML community conventions. - // When this field is present, the data_type field MUST be STRING - repeated bytes string_data = 6; - - // For int64. - // When this field is present, the data_type field MUST be INT64 - repeated int64 int64_data = 7 [packed = true]; - - // Optionally, a name for the tensor. - optional string name = 8; // namespace Value - - // A human-readable documentation for this tensor. Markdown is allowed. - optional string doc_string = 12; - - // Serializations can either use one of the fields above, or use this - // raw bytes field. The only exception is the string case, where one is - // required to store the content in the repeated bytes string_data field. - // - // When this raw_data field is used to store tensor value, elements MUST - // be stored in as fixed-width, little-endian order. - // Floating-point data types MUST be stored in IEEE 754 format. - // Complex64 elements must be written as two consecutive FLOAT values, real component first. - // Complex128 elements must be written as two consecutive DOUBLE values, real component first. - // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). - // - // Note: the advantage of specific field rather than the raw_data field is - // that in some cases (e.g. int data), protobuf does a better packing via - // variable length storage, and may lead to smaller binary footprint. - // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED - optional bytes raw_data = 9; - - // For double - // Complex64 tensors are encoded as a single array of doubles, - // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the - // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - // is encoded as [1.0, 2.0 ,3.0 ,4.0] - // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 - repeated double double_data = 10 [packed = true]; - - // For uint64 and uint32 values - // When this field is present, the data_type field MUST be - // UINT32 or UINT64 - repeated uint64 uint64_data = 11 [packed = true]; - - // External data by file name - optional string external_data = 13; - - // If two tensors represent the same weights/content, use alias. - // Must exist a TensorProto named alias in the initializer list. - // To avoid the duplicate tensor in attribute, such as value in Constant node. - // This is useful, if everything is stored just in the proto. - optional string alias = 16; - - // Additional annotations. - repeated AttributeProto annotations = 17; - - // Device info - optional caffe2.DeviceOption device_option = 51; - - // For PyTorch serialized tensor. - optional int64 require_gradient = 52; - optional int64 is_buffer = 53; -} - -// Defines a tensor shape. A dimension can be either an integer value -// or a symbolic variable. A symbolic variable represents an unknown -// dimension. -message TensorShapeProto { - message Dimension { - oneof value { - int64 dim_value = 1; - string dim_param = 2; // namespace Shape - }; - // Standard denotation can optionally be used to denote tensor - // dimensions with standard semantic descriptions to ensure - // that operations are applied to the correct axis of a tensor. - // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition - // for pre-defined dimension denotations. - optional string denotation = 3; - }; - // To represent a scalar, using no dim to represent 0-d tensor. - repeated Dimension dim = 1; - - repeated Dimension stride = 51; -} - -// Types -// -// The standard ONNX data types. -message TypeProto { - - message Tensor { - // This field MUST NOT have the value of UNDEFINED - // This field MUST be present for this version of the IR. - optional TensorProto.DataType elem_type = 1; - optional TensorShapeProto shape = 2; - } - - // Sequence type: List, Tuple - message Sequence { - // elem_type and elem_type_list cannot appear together. - // If all the element types are the same, we use elem_type, - // otherwise, we specify the type of each element in elem_type_list. - optional TypeProto elem_type = 1; - repeated TypeProto elem_type_list = 51; - enum SequenceType { - UNDEFINED = 0; - LIST = 1; - TUPLE = 2; - } - optional SequenceType sequence_type = 52; - } - - // Map, (not necessary at this moment) - message Map { - optional TensorProto.DataType key_type = 1; - optional TypeProto value_type = 2; - } - - // Special type of blobs, based on the type_name, we can choose the right - // serializer and deserialzier. - message SpecialBlob { - optional string type_name = 1; - } - - oneof value { - // The type of a tensor. - Tensor tensor_type = 1; - Sequence sequence_type = 4; - Map map_type = 5; - SpecialBlob special_type = 51; - } - - // An optional denotation can be used to denote the whole - // type with a standard semantic description as to what is - // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition - // for pre-defined type denotations. - optional string denotation = 6; -} + optional string debug_info = 8; -// Operator Sets -// -// OperatorSets are uniquely identified by a (domain, opset_version) pair. -message OperatorSetIdProto { - // The domain of the operator set being identified. - // The empty string ("") or absence of this field implies the operator - // set that is defined as part of the ONNX specification. - // This field MUST be present in this version of the IR when referring to any other operator set. - optional string domain = 1; + // annotations, it is used for MetaNetDef's metadata + repeated caffe2.Argument annotations = 9; - // The version of the operator set being identified. - // This field MUST be present in this version of the IR. - optional int64 version = 2; } diff --git a/caffe2/python/convert.py b/caffe2/python/convert.py index 50eaf220c7..44f81d6e2d 100644 --- a/caffe2/python/convert.py +++ b/caffe2/python/convert.py @@ -8,59 +8,3 @@ from __future__ import unicode_literals from caffe2.proto import caffe2_pb2, torch_pb2 import caffe2.python._import_c_extension as C - - -def ArgumentToAttributeProto(arg): - serialized_arg = None - if hasattr(arg, 'SerializeToString') and callable(arg.SerializeToString): - serialized_arg = arg.SerializeToString() - elif isinstance(arg, bytes): - serialized_arg = arg - else: - raise ValueError('No SerializeToString method is detected. ' - 'neither arg is bytes.\ntype is {}'.format(type(arg))) - attr = torch_pb2.AttributeProto() - attr.ParseFromString(C.argument_to_attribute_proto(serialized_arg)) - return attr - - -def AttributeProtoToArgument(attr): - serialized_attr = None - if hasattr(attr, 'SerializeToString') and callable(attr.SerializeToString): - serialized_attr = attr.SerializeToString() - elif isinstance(attr, bytes): - serialized_attr = attr - else: - raise ValueError('No SerializeToString method is detected. ' - 'neither attr is bytes.\ntype is {}'.format(type(attr))) - arg = caffe2_pb2.Argument() - arg.ParseFromString(C.attribute_proto_to_argument(serialized_attr)) - return arg - - -def OperatorDefToNodeProto(op_def): - serialized_op_def = None - if hasattr(op_def, 'SerializeToString') and callable(op_def.SerializeToString): - serialized_op_def = op_def.SerializeToString() - elif isinstance(op_def, bytes): - serialized_op_def = op_def - else: - raise ValueError('No SerializeToString method is detected. ' - 'neither op_def is bytes.\ntype is {}'.format(type(op_def))) - node = torch_pb2.NodeProto() - node.ParseFromString(C.operator_def_to_node_proto(serialized_op_def)) - return node - - -def NodeProtoToOperatorDef(node_proto): - serialized_node_proto = None - if hasattr(node_proto, 'SerializeToString') and callable(node_proto.SerializeToString): - serialized_node_proto = node_proto.SerializeToString() - elif isinstance(node_proto, bytes): - serialized_node_proto = node_proto - else: - raise ValueError('No SerializeToString method is detected. ' - 'neither node_proto is bytes.\ntype is {}'.format(type(node_proto))) - op_def = caffe2_pb2.OperatorDef() - op_def.ParseFromString(C.node_proto_to_operator_def(serialized_node_proto)) - return op_def diff --git a/caffe2/python/convert_test.py b/caffe2/python/convert_test.py index c8de7e9750..82c969c901 100644 --- a/caffe2/python/convert_test.py +++ b/caffe2/python/convert_test.py @@ -12,239 +12,5 @@ class TestOperator(unittest.TestCase): def setUp(self): workspace.ResetWorkspace() - def testArgument2AttributeProto(self): - arg_f = caffe2_pb2.Argument() - arg_f.name = "TestArgF" - arg_f.f = 10.0 - attr_f = convert.ArgumentToAttributeProto(arg_f) - self.assertEqual(attr_f.name, arg_f.name) - self.assertEqual(attr_f.f, arg_f.f) - - arg_i = caffe2_pb2.Argument() - arg_i.name = "TestArgI" - arg_i.i = 100 - attr_i = convert.ArgumentToAttributeProto(arg_i) - self.assertEqual(attr_i.name, arg_i.name) - self.assertEqual(attr_i.i, arg_i.i) - - arg_s = caffe2_pb2.Argument() - arg_s.name = "TestArgS" - arg_s.s = "TestS".encode("utf-8") - attr_s = convert.ArgumentToAttributeProto(arg_s) - self.assertEqual(attr_s.name, arg_s.name) - self.assertEqual(attr_s.s, arg_s.s) - - # TODO: test net arg - - arg_floats = caffe2_pb2.Argument() - arg_floats.name = "TestArgFloats" - arg_floats.floats.extend([10.0, 11.0, 12.0]) - attr_floats = convert.ArgumentToAttributeProto(arg_floats) - self.assertEqual(attr_floats.name, arg_floats.name) - self.assertEqual(attr_floats.floats, arg_floats.floats) - - arg_ints = caffe2_pb2.Argument() - arg_ints.name = "TestArgInts" - arg_ints.ints.extend([100, 101, 102]) - attr_ints = convert.ArgumentToAttributeProto(arg_ints) - self.assertEqual(attr_ints.name, arg_ints.name) - self.assertEqual(attr_ints.ints, arg_ints.ints) - - arg_strings = caffe2_pb2.Argument() - arg_strings.name = "TestArgStrings" - arg_strings.strings.extend([ - "TestStrings1".encode("utf-8"), - "TestStrings2".encode("utf-8"), - ]) - attr_strings = convert.ArgumentToAttributeProto(arg_strings) - self.assertEqual(attr_strings.name, arg_strings.name) - self.assertEqual(attr_strings.strings, arg_strings.strings) - - # TODO: test nets arg - - def testAttributeProto2Argument(self): - attr_f = torch_pb2.AttributeProto() - attr_f.type = torch_pb2.AttributeProto.FLOAT - attr_f.name = "TestAttrF" - attr_f.f = 10.0 - arg_f = convert.AttributeProtoToArgument(attr_f) - self.assertEqual(arg_f.name, attr_f.name) - self.assertEqual(arg_f.f, attr_f.f) - - attr_i = torch_pb2.AttributeProto() - attr_i.type = torch_pb2.AttributeProto.INT - attr_i.name = "TestArgI" - attr_i.i = 100 - arg_i = convert.AttributeProtoToArgument(attr_i) - self.assertEqual(arg_i.name, attr_i.name) - self.assertEqual(arg_i.i, attr_i.i) - - attr_s = torch_pb2.AttributeProto() - attr_s.type = torch_pb2.AttributeProto.STRING - attr_s.name = "TestArgS" - attr_s.s = "TestS".encode("utf-8") - arg_s = convert.AttributeProtoToArgument(attr_s) - self.assertEqual(arg_s.name, attr_s.name) - self.assertEqual(arg_s.s, attr_s.s) - - # TODO: test graph attribute - - attr_floats = torch_pb2.AttributeProto() - attr_floats.type = torch_pb2.AttributeProto.FLOATS - attr_floats.name = "TestAttrFloats" - attr_floats.floats.extend([10.0, 11.0, 12.0]) - arg_floats = convert.AttributeProtoToArgument(attr_floats) - self.assertEqual(arg_floats.name, attr_floats.name) - self.assertEqual(arg_floats.floats, attr_floats.floats) - - attr_ints = torch_pb2.AttributeProto() - attr_ints.type = torch_pb2.AttributeProto.INTS - attr_ints.name = "TestArgInts" - attr_ints.ints.extend([100, 101, 102]) - arg_ints = convert.AttributeProtoToArgument(attr_ints) - self.assertEqual(arg_ints.name, attr_ints.name) - self.assertEqual(arg_ints.ints, attr_ints.ints) - - attr_strings = torch_pb2.AttributeProto() - attr_strings.type = torch_pb2.AttributeProto.STRINGS - attr_strings.name = "TestArgStrings" - attr_strings.strings.extend([ - "TestStrings1".encode("utf-8"), - "TestStrings2".encode("utf-8"), - ]) - arg_strings = convert.AttributeProtoToArgument(attr_strings) - self.assertEqual(arg_strings.name, attr_strings.name) - self.assertEqual(arg_strings.strings, attr_strings.strings) - - # TODO: test graphs attribute - - - def testOperatorDef2NodeProto(self): - op_def = caffe2_pb2.OperatorDef() - op_def.input.extend(["A", "B", "C"]) - op_def.output.extend(["X", "Y"]) - op_def.name = "TestOpName" - op_def.type = "TestOp" - arg1 = caffe2_pb2.Argument() - arg1.name = "TestArg1" - arg1.i = 1 - arg2 = caffe2_pb2.Argument() - arg2.name = "TestArg2" - arg1.s = "TestInfo".encode("utf-8") - op_def.arg.extend([arg1, arg2]) - op_def.device_option.CopyFrom(caffe2_pb2.DeviceOption()) - op_def.engine = "TestEngine".encode("utf-8") - op_def.control_input.extend(["input1", "input2"]) - op_def.is_gradient_op = True - op_def.debug_info = "TestDebugInfo" - - node = convert.OperatorDefToNodeProto(op_def) - - self.assertEqual(node.input, op_def.input) - self.assertEqual(node.output, op_def.output) - self.assertEqual(node.name, op_def.name) - self.assertEqual(node.op_type, op_def.type) - self.assertEqual(node.attribute[0].name, op_def.arg[0].name) - self.assertEqual(node.attribute[1].name, op_def.arg[1].name) - self.assertEqual(node.device_option, op_def.device_option) - node_engine = [a.s.decode("utf-8") for a in node.annotations if a.name == "engine"][0] - self.assertEqual(node_engine, op_def.engine) - node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0] - self.assertEqual(len(node_control_input), len(op_def.control_input)) - for x, y in zip(node_control_input, op_def.control_input): - self.assertEqual(x.decode("utf-8"), y) - self.assertEqual(node.doc_string, op_def.debug_info) - node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0] - self.assertEqual(node_is_gradient_op, int(op_def.is_gradient_op)) - - def testNodeProto2OperatorDef(self): - node = torch_pb2.NodeProto() - node.input.extend(["A", "B", "C"]) - node.output.extend(["X", "Y"]) - node.name = "TestOpName" - node.op_type = "TestOp" - attr1 = torch_pb2.AttributeProto() - attr1.name = "TestAttr1" - attr1.type = torch_pb2.AttributeProto.STRING - attr1.s = "TestInfo".encode("utf-8") - attr2 = torch_pb2.AttributeProto() - attr2.name = "TestAttr2" - attr2.type = torch_pb2.AttributeProto.INT - attr2.i = 10 - node.attribute.extend([attr1, attr2]) - node.device_option.CopyFrom(caffe2_pb2.DeviceOption()) - anno1 = torch_pb2.AttributeProto() - anno1.name = "engine" - anno1.type = torch_pb2.AttributeProto.STRING - anno1.s = "TestEngine".encode("utf-8") - anno2 = torch_pb2.AttributeProto() - anno2.name = "control_input" - anno2.type = torch_pb2.AttributeProto.STRINGS - anno2.strings.extend(["input1".encode("utf-8"), "input2".encode("utf-8")]) - anno3 = torch_pb2.AttributeProto() - anno3.name = "is_gradient_op" - anno3.type = torch_pb2.AttributeProto.INT - anno3.i = 1 - node.annotations.extend([anno1, anno2, anno3]) - node.doc_string = "TestDocString".encode("utf-8") - - op_def = convert.NodeProtoToOperatorDef(node) - - self.assertEqual(op_def.input, node.input) - self.assertEqual(op_def.output, node.output) - self.assertEqual(op_def.name, node.name) - self.assertEqual(op_def.type, node.op_type) - self.assertEqual(op_def.arg[0].name, node.attribute[0].name) - self.assertEqual(op_def.arg[1].name, node.attribute[1].name) - self.assertEqual(op_def.device_option, node.device_option) - node_engine = [a.s for a in node.annotations if a.name == "engine"][0] - self.assertEqual(op_def.engine, node_engine.decode("utf-8")) - node_control_input = [a.strings for a in node.annotations if a.name == "control_input"][0] - for x, y in zip(op_def.control_input, node_control_input): - self.assertEqual(x, y.decode("utf-8")) - self.assertEqual(op_def.debug_info, node.doc_string) - node_is_gradient_op = [a.i for a in node.annotations if a.name == "is_gradient_op"][0] - self.assertEqual(int(op_def.is_gradient_op), node_is_gradient_op) - - def testEnd2End(self): - op_def = caffe2_pb2.OperatorDef() - op_def.type = "Add" - op_def.input.extend(["input1"]) - op_def.input.extend(["input2"]) - op_def.output.extend(["output1"]) - node = convert.OperatorDefToNodeProto(op_def) - - input1 = np.random.randn(1, 3, 1, 5).astype(np.float32) - input2 = np.random.randn(2, 1, 4, 1).astype(np.float32) - ref_output1 = input1 + input2 - workspace.FeedBlob("input1", input1) - workspace.FeedBlob("input2", input2) - self.assertEqual(workspace.RunOperatorOnce(node.SerializeToString(), legacy_proto=False), True) - - self.assertEqual(workspace.HasBlob("output1"), True) - fetched_back = workspace.FetchBlob("output1") - np.testing.assert_array_equal(fetched_back, ref_output1) - - def testRoundTrip(self): - op_def = caffe2_pb2.OperatorDef() - op_def.type = "Add" - op_def.input.extend(["input1"]) - op_def.input.extend(["input2"]) - op_def.output.extend(["output1"]) - node = convert.OperatorDefToNodeProto(op_def) - new_op_def = convert.NodeProtoToOperatorDef(node) - - input1 = np.random.randn(1, 3, 1, 5).astype(np.float32) - input2 = np.random.randn(2, 1, 4, 1).astype(np.float32) - ref_output1 = input1 + input2 - workspace.FeedBlob("input1", input1) - workspace.FeedBlob("input2", input2) - self.assertEqual(workspace.RunOperatorOnce(new_op_def.SerializeToString()), True) - - self.assertEqual(workspace.HasBlob("output1"), True) - fetched_back = workspace.FetchBlob("output1") - np.testing.assert_array_equal(fetched_back, ref_output1) - - if __name__ == '__main__': unittest.main() diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 7062ead045..7ebee57d49 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -1187,17 +1187,10 @@ void addGlobalMethods(py::module& m) { return true; }); m.def("nets", []() { return gWorkspace->Nets(); }); - m.def("run_operator_once", [](const py::bytes& op_def, bool legacy_proto=true) { + m.def("run_operator_once", [](const py::bytes& op_def) { CAFFE_ENFORCE(gWorkspace); OperatorDef def; - if (legacy_proto) { - CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast(), &def)); - } else { - ::torch::NodeProto node; - CAFFE_ENFORCE( - ParseProtoFromLargeString(op_def.cast(), &node)); - NodeProtoToOperatorDef(node, &def); - } + CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast(), &def)); py::gil_scoped_release g; CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def)); return true; @@ -1534,38 +1527,6 @@ void addGlobalMethods(py::module& m) { CAFFE_ENFORCE(blob); return BlobStat::sizeBytes(*blob); }); - m.def("argument_to_attribute_proto", [](py::bytes arg_str) -> py::bytes { - Argument arg; - CAFFE_ENFORCE( - ParseProtoFromLargeString(arg_str.cast(), &arg)); - ::torch::AttributeProto attr; - ArgumentToAttributeProto(arg, &attr); - return attr.SerializeAsString(); - }); - m.def("attribute_proto_to_argument", [](py::bytes attr_str) -> py::bytes { - ::torch::AttributeProto attr; - CAFFE_ENFORCE( - ParseProtoFromLargeString(attr_str.cast(), &attr)); - Argument arg; - AttributeProtoToArgument(attr, &arg); - return arg.SerializeAsString(); - }); - m.def("operator_def_to_node_proto", [](py::bytes op_str) -> py::bytes { - OperatorDef op_def; - CAFFE_ENFORCE( - ParseProtoFromLargeString(op_str.cast(), &op_def)); - ::torch::NodeProto node; - OperatorDefToNodeProto(op_def, &node); - return node.SerializeAsString(); - }); - m.def("node_proto_to_operator_def", [](py::bytes node_str) -> py::bytes { - ::torch::NodeProto node_proto; - CAFFE_ENFORCE( - ParseProtoFromLargeString(node_str.cast(), &node_proto)); - OperatorDef op_def; - NodeProtoToOperatorDef(node_proto, &op_def); - return op_def.SerializeAsString(); - }); m.def("support_onnx_export", [](const std::string& op) -> bool { const OpSchema* schema = caffe2::OpSchemaRegistry::Schema(op); if (!schema) { diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py index a41cc15317..ef02f64dc9 100644 --- a/caffe2/python/workspace.py +++ b/caffe2/python/workspace.py @@ -163,8 +163,8 @@ def GetOperatorCost(operator, blobs): return C.get_operator_cost(StringifyProto(operator), blobs) -def RunOperatorOnce(operator, legacy_proto=True): - return C.run_operator_once(StringifyProto(operator), legacy_proto) +def RunOperatorOnce(operator): + return C.run_operator_once(StringifyProto(operator)) def RunOperatorsOnce(operators): diff --git a/caffe2/utils/proto_convert.cc b/caffe2/utils/proto_convert.cc index 790bd27429..1d69c8c80c 100644 --- a/caffe2/utils/proto_convert.cc +++ b/caffe2/utils/proto_convert.cc @@ -2,185 +2,4 @@ #include "caffe2/core/logging.h" namespace caffe2 { - -C10_EXPORT void ArgumentToAttributeProto( - const Argument& arg, - ::torch::AttributeProto* attr) { - CAFFE_ENFORCE(arg.has_name()); - attr->set_name(arg.name()); - if (arg.has_f()) { - attr->set_f(arg.f()); - } else if (arg.has_i()) { - attr->set_i(arg.i()); - } else if (arg.has_s()) { - attr->set_s(arg.s()); - } else if (arg.has_n()) { - // TODO - CAFFE_THROW("NetDef conversion is not implemented yet."); - } else if (arg.floats_size() > 0) { - attr->mutable_floats()->CopyFrom(arg.floats()); - } else if (arg.ints_size() > 0) { - attr->mutable_ints()->CopyFrom(arg.ints()); - } else if (arg.strings_size() > 0) { - attr->mutable_strings()->CopyFrom(arg.strings()); - } else if (arg.nets_size() > 0) { - // TODO - CAFFE_THROW("NetDefs conversion is not implemented yet."); - } -} - -C10_EXPORT void AttributeProtoToArgument( - const ::torch::AttributeProto& attr, - Argument* arg) { - CAFFE_ENFORCE(attr.has_name()); - arg->set_name(attr.name()); - CAFFE_ENFORCE(attr.has_type()); - const auto type = attr.type(); - if (type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_FLOAT) { - CAFFE_ENFORCE(attr.has_f()); - arg->set_f(attr.f()); - } else if ( - type == - ::torch::AttributeProto_AttributeType::AttributeProto_AttributeType_INT) { - CAFFE_ENFORCE(attr.has_i()); - arg->set_i(attr.i()); - } else if ( - type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_STRING) { - CAFFE_ENFORCE(attr.has_s()); - arg->set_s(attr.s()); - } else if ( - type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_TENSOR) { - CAFFE_THROW("Caffe2's Argument does not support tensor as attribute."); - } else if ( - type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_GRAPH) { - // TODO - CAFFE_THROW("GraphProto conversion is not implemented yet."); - } else if ( - type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_FLOATS) { - arg->mutable_floats()->CopyFrom(attr.floats()); - } else if ( - type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_INTS) { - arg->mutable_ints()->CopyFrom(attr.ints()); - } else if ( - type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_STRINGS) { - arg->mutable_strings()->CopyFrom(attr.strings()); - } else if ( - type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_TENSORS) { - CAFFE_THROW("Caffe2's Argument does not support tensors as attribute."); - } else if ( - type == - ::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_GRAPHS) { - // TODO - CAFFE_THROW("GraphProtos conversion is not implemented yet."); - } else { - CAFFE_THROW("Unknow Attribute type."); - } -} - -C10_EXPORT void OperatorDefToNodeProto( - const OperatorDef& def, - ::torch::NodeProto* node) { - node->mutable_input()->CopyFrom(def.input()); - node->mutable_output()->CopyFrom(def.output()); - if (def.has_name()) { - node->set_name(def.name()); - } - CAFFE_ENFORCE(def.has_type()); - node->set_op_type(def.type()); - for (int i = 0; i < def.arg_size(); ++i) { - auto attr = node->add_attribute(); - ArgumentToAttributeProto(def.arg(i), attr); - } - if (def.has_device_option()) { - node->mutable_device_option()->CopyFrom(def.device_option()); - } - if (def.has_engine()) { - auto attr = node->add_annotations(); - attr->set_name("engine"); - attr->set_type(::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_STRING); - attr->set_s(def.engine()); - } - if (def.control_input_size() > 0) { - auto attr = node->add_annotations(); - attr->set_name("control_input"); - attr->set_type(::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_STRINGS); - attr->mutable_strings()->CopyFrom(def.control_input()); - } - if (def.has_is_gradient_op()) { - auto attr = node->add_annotations(); - attr->set_name("is_gradient_op"); - attr->set_type(::torch::AttributeProto_AttributeType:: - AttributeProto_AttributeType_INT); - if (def.is_gradient_op()) { - attr->set_i(1); - } else { - attr->set_i(0); - } - } - if (def.has_debug_info()) { - node->set_doc_string(def.debug_info()); - } -} - -C10_EXPORT void NodeProtoToOperatorDef( - const ::torch::NodeProto& node, - OperatorDef* def) { - def->mutable_input()->CopyFrom(node.input()); - def->mutable_output()->CopyFrom(node.output()); - if (node.has_name()) { - def->set_name(node.name()); - } - - CAFFE_ENFORCE(node.has_op_type()); - def->set_type(node.op_type()); - for (int i = 0; i < node.attribute_size(); ++i) { - auto arg = def->add_arg(); - AttributeProtoToArgument(node.attribute(i), arg); - } - if (node.has_doc_string()) { - def->set_debug_info(node.doc_string()); - } - for (int i = 0; i < node.annotations_size(); ++i) { - const auto& attr = node.annotations(i); - CAFFE_ENFORCE(attr.has_name()); - if (attr.name() == "engine") { - CAFFE_ENFORCE(attr.has_s()); - def->set_engine(attr.s()); - } else if (attr.name() == "control_input") { - def->mutable_control_input()->CopyFrom(attr.strings()); - } else if (attr.name() == "is_gradient_op") { - CAFFE_ENFORCE(attr.has_i()); - if (i == 0) { - def->set_is_gradient_op(false); - } else { - def->set_is_gradient_op(true); - } - } - auto arg = def->add_arg(); - AttributeProtoToArgument(node.annotations(i), arg); - } - if (node.has_device_option()) { - def->mutable_device_option()->CopyFrom(node.device_option()); - } -} - } // namespace caffe2 diff --git a/caffe2/utils/proto_convert.h b/caffe2/utils/proto_convert.h index a9ca9c3ad4..91bcf1bafa 100644 --- a/caffe2/utils/proto_convert.h +++ b/caffe2/utils/proto_convert.h @@ -6,20 +6,6 @@ #include "caffe2/proto/torch_pb.h" namespace caffe2 { - -CAFFE2_API void ArgumentToAttributeProto( - const Argument& arg, - ::torch::AttributeProto* attr); -CAFFE2_API void AttributeProtoToArgument( - const ::torch::AttributeProto& attr, - Argument* arg); -CAFFE2_API void OperatorDefToNodeProto( - const OperatorDef& def, - ::torch::NodeProto* node); -CAFFE2_API void NodeProtoToOperatorDef( - const ::torch::NodeProto& node, - OperatorDef* def); - } // namespace caffe2 #endif // CAFFE2_UTILS_PROTO_CONVERT_H_ -- cgit v1.2.3