diff options
author | Yinghai Lu <yinghai@fb.com> | 2018-02-20 13:56:52 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-20 13:56:52 -0800 |
commit | cc7e61c88d8250c31bd5cd4897323335bc550d3a (patch) | |
tree | 8afe634a8aba693a92344b0dc98c9cce5fb03651 | |
parent | 7283d5194a71e5fb47baa7025f7d3874b8fab3c6 (diff) | |
download | pytorch-cc7e61c88d8250c31bd5cd4897323335bc550d3a.tar.gz pytorch-cc7e61c88d8250c31bd5cd4897323335bc550d3a.tar.bz2 pytorch-cc7e61c88d8250c31bd5cd4897323335bc550d3a.zip |
Move onnx-caffe2 inside caffe2 (#1921)
* Move onnx-caffe2 inside caffe2
* Update to the lastest onnx-caffe2 and update jenkins env
* Rename onnx_caffe2 to onnx
* Add __init__.py to caffe2/python/onnx
* Change CI check variable to JENKINS_URL
* Cherrypick recent onnx-caffe2 update
-rw-r--r-- | caffe2/python/onnx/__init__.py | 0 | ||||
-rw-r--r-- | caffe2/python/onnx/backend.py | 1159 | ||||
-rw-r--r-- | caffe2/python/onnx/backend_rep.py | 77 | ||||
-rw-r--r-- | caffe2/python/onnx/bin/conversion.py | 104 | ||||
-rw-r--r-- | caffe2/python/onnx/error.py | 23 | ||||
-rw-r--r-- | caffe2/python/onnx/frontend.py | 551 | ||||
-rw-r--r-- | caffe2/python/onnx/helper.py | 157 | ||||
-rw-r--r-- | caffe2/python/onnx/tests/caffe2_ref_test.py | 357 | ||||
-rw-r--r-- | caffe2/python/onnx/tests/conversion_test.py | 244 | ||||
-rw-r--r-- | caffe2/python/onnx/tests/helper_test.py | 45 | ||||
-rw-r--r-- | caffe2/python/onnx/tests/onnx_backend_test.py | 50 | ||||
-rw-r--r-- | caffe2/python/onnx/tests/optimize_onnx_test.py | 118 | ||||
-rw-r--r-- | caffe2/python/onnx/tests/ssa_test.py | 122 | ||||
-rw-r--r-- | caffe2/python/onnx/tests/test_utils.py | 43 | ||||
-rw-r--r-- | caffe2/python/onnx/workspace.py | 80 | ||||
-rwxr-xr-x | docker/jenkins/common/install_python.sh | 1 |
16 files changed, 3131 insertions, 0 deletions
diff --git a/caffe2/python/onnx/__init__.py b/caffe2/python/onnx/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/caffe2/python/onnx/__init__.py diff --git a/caffe2/python/onnx/backend.py b/caffe2/python/onnx/backend.py new file mode 100644 index 0000000000..95f8f59f0f --- /dev/null +++ b/caffe2/python/onnx/backend.py @@ -0,0 +1,1159 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.backend + +"""Backend for running ONNX on Caffe2 + +To run this, you will need to have Caffe2 installed as well. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import collections +from subprocess import Popen, PIPE + +import caffe2 +from caffe2.python import core, workspace, rnn_cell, gru_cell +from caffe2.python.model_helper import ModelHelper +from caffe2.proto import caffe2_pb2 +import caffe2.python.utils +import numpy as np +import onnx +from onnx import checker, GraphProto, TensorProto, AttributeProto, ModelProto +import onnx.numpy_helper +import onnx.defs +import onnx.optimizer +from onnx.backend.base import Backend, Device, DeviceType, namedtupledict + +from caffe2.python.onnx.workspace import Workspace +from caffe2.python.onnx.backend_rep import Caffe2Rep +from caffe2.python.onnx.helper import dummy_name + +import warnings + +def force_unicode(s): + try: + return s.decode('utf-8') + except AttributeError: + return s + +def get_device_option(device): + m = {DeviceType.CPU: caffe2_pb2.CPU, + DeviceType.CUDA: caffe2_pb2.CUDA} + return core.DeviceOption(m[device.type], device.device_id) + + +class OnnxAttributes(dict): + """ + This is a more convenient way to work with ONNX/Caffe2 attributes + that is not the protobuf representation. + """ + @staticmethod + def from_onnx(args): + d = OnnxAttributes() + for arg in args: + d[arg.name] = convertAttributeProto(arg) + return d + + def caffe2(self, kmap=lambda k: k): + for k, v in self.items(): + if kmap(k) != '': + yield caffe2.python.utils.MakeArgument(kmap(k), v) + + +# TODO: Move this into ONNX main library +def convertAttributeProto(onnx_arg): + """ + Convert an ONNX AttributeProto into an appropriate Python object + for the type. + + NB: Tensor attribute gets returned as the straight proto. + """ + if onnx_arg.HasField('f'): + return onnx_arg.f + elif onnx_arg.HasField('i'): + return onnx_arg.i + elif onnx_arg.HasField('s'): + return onnx_arg.s + elif onnx_arg.HasField('t'): + return onnx_arg.t # this is a proto! + elif len(onnx_arg.floats): + return list(onnx_arg.floats) + elif len(onnx_arg.ints): + return list(onnx_arg.ints) + elif len(onnx_arg.strings): + return list(onnx_arg.strings) + else: + raise ValueError("Unsupported ONNX attribute: {}".format(onnx_arg)) + + +# TODO: Move this into ONNX main library +class OnnxNode(object): + """ + Reimplementation of NodeProto from ONNX, but in a form + more convenient to work with from Python. + + We may temporarily edit these nodes to get them into Caffe2 form, + before actually translating into the Caffe2 protobuf, since this + is easier than decomposing everything, and putting it back together + when we're ready. + """ + def __init__(self, node): + self.name = str(node.name) + self.op_type = str(node.op_type) + self.attrs = OnnxAttributes.from_onnx(node.attribute) + self.consumed_inputs = self.attrs.pop("consumed_inputs", None) + self.inputs = list(node.input) + self.outputs = list(node.output) + + +Caffe2Ops = collections.namedtuple('Caffe2Ops', ['ops', 'init_ops', 'interface_blobs']) + + +class Caffe2Backend(Backend): + + # The greatest version of the ONNX operator set which we are aware of. + # Models whose version is larger than this will cause us to emit a warning + # that we are attempting to translate on a "best effort" basis. + # + # If you increase this, make SURE you cross-reference all BC-breaking + # changes from one version to the next, and any that you did not + # implement, mark as broken in _broken_operators + _known_opset_version = 3 + + # This dictionary will record operators which are KNOWN to be + # broken, so we give a good error message rather than do something + # bogus and then fail. + _broken_operators = { + # 'BrokenOp': version_it_was_broken_in + } + + # Operators that are different between Caffe2 and + # ONNX but only in their name. + # In most cases, this should be empty - as the effort of ONNX is + # to unify the operator definitions. + _renamed_operators = { + 'Caffe2ConvTranspose': 'ConvTranspose', + 'GlobalMaxPool': 'MaxPool', + 'GlobalAveragePool': 'AveragePool', + 'Pad': 'PadImage', + 'Neg': 'Negative', + 'BatchNormalization': 'SpatialBN', + 'InstanceNormalization': 'InstanceNorm', + 'MatMul': 'BatchMatMul', + 'Upsample': 'ResizeNearest', + 'Identity': 'Copy', + 'InstanceNormalization': 'InstanceNorm', + 'Equal': 'EQ', + 'Less': 'LT', + 'Greater': 'GT', + 'Unsqueeze': 'ExpandDims', + } + + _global_renamed_attrs = {'kernel_shape': 'kernels'} + _per_op_renamed_attrs = { + 'Squeeze': {'axes': 'dims'}, + 'Unsqueeze': {'axes': 'dims'}, + 'Transpose': {'perm': 'axes'}, + 'Upsample': {'mode': ''}, + 'ConvTranspose': {'output_padding': 'adjs'}, + 'Selu': {'gamma': 'scale'}, + } + + # operators whose behavior is different beyond renaming + # the value is an attribute of this class that is a + # function from ToffeIR node_def to caffe2 op_def + _special_operators = { + 'Constant': '_create_constant', + 'Conv': '_create_conv_pool_op_base', + 'AveragePool': '_create_conv_pool_op_base', + 'GlobalAveragePool': '_create_conv_pool_op_base', + 'GlobalMaxPool': '_create_conv_pool_op_base', + 'MaxPool': '_create_conv_pool_op_base', + 'Reshape': '_create_reshape', + 'Gather': '_create_gather', + 'Gemm': '_create_gemm', + 'Pad': '_create_pad', + 'Concat': '_create_concat', + 'LogSoftmax': '_create_logsoftmax', + 'Slice': '_create_slice', + 'LSTM': '_create_lstm', + 'GRU': '_create_gru', + 'RNN': '_create_rnn', + 'Sqrt': '_create_sqrt', + 'Reciprocal': '_create_reciprocal', + } + + # NB: By default, you will use the LATEST definition of the operator, + # so this interface MAY make BC-breaking changes. Specify an + # opset_version if you don't want this to version. + @classmethod + def run_node(cls, node, inputs, device='CPU', opset_version=_known_opset_version): + super(Caffe2Backend, cls).run_node(node, inputs, device) + + device_option = get_device_option(Device(device)) + with Workspace(), core.DeviceScope(device_option): # temporary! + if isinstance(inputs, dict): + for key, value in inputs.items(): + workspace.FeedBlob(key, value) + else: + assert len(node.input) == len(inputs), "{}: expected {} but got {}".format( + node.op_type, len(node.input), len(inputs)) + for key, value in zip(node.input, inputs): + workspace.FeedBlob(key, value) + + cls._inplace_rewrite([node]) + init_ops, ops, _ = cls._onnx_node_to_caffe2_op( + None, None, node, opset_version or cls._known_opset_version) + ops = init_ops + ops + for op in ops: + op.device_option.CopyFrom(device_option) + workspace.RunOperatorsOnce(ops) + output_values = [workspace.FetchBlob(name) for name in node.output] + return namedtupledict('Outputs', node.output)(*output_values) + + @classmethod + def _create_tensor_filling_op(cls, onnx_tensor, name=None): + """ + Given an Onnx TensorProto, translate it into a Caffe2 operator + which produces the given tensor filling op. + """ + assert name or onnx_tensor.name + name = name or onnx_tensor.name + + c2_op = caffe2_pb2.OperatorDef() + + c2_values = c2_op.arg.add() + c2_values.name = "values" + + def tensor2list(onnx_tensor): + # Use the onnx.numpy_helper because the data may be raw + return onnx.numpy_helper.to_array(onnx_tensor).flatten().tolist() + + if onnx_tensor.data_type in [TensorProto.FLOAT]: + c2_op.type = 'GivenTensorFill' + c2_values.floats.extend(tensor2list(onnx_tensor)) + elif onnx_tensor.data_type in [TensorProto.DOUBLE]: + c2_op.type = 'GivenTensorDoubleFill' + c2_values.floats.extend(tensor2list(onnx_tensor)) + elif onnx_tensor.data_type in [TensorProto.INT64, + TensorProto.UINT32]: + c2_op.type = 'GivenTensorInt64Fill' + c2_values.ints.extend(tensor2list(onnx_tensor)) + elif onnx_tensor.data_type in [TensorProto.UINT8, + TensorProto.INT8, + TensorProto.UINT16, + TensorProto.INT16, + TensorProto.INT32]: + c2_op.type = 'GivenTensorIntFill' + c2_values.ints.extend(tensor2list(onnx_tensor)) + elif onnx_tensor.data_type == TensorProto.BOOL: + c2_op.type = 'GivenTensorBoolFill' + c2_values.ints.extend(tensor2list(onnx_tensor)) + elif onnx_tensor.data_type == TensorProto.STRING: + c2_op.type = 'GivenTensorStringFill' + c2_values.strings.extend(onnx_tensor.string_data) + else: + raise RuntimeError( + "unrecognized tensor type {}".format(onnx_tensor.data_type)) + + c2_shape = c2_op.arg.add() + c2_shape.name = "shape" + c2_shape.ints.extend(onnx_tensor.dims) + + c2_op.output.append(name) + + return c2_op + + @classmethod + def _create_constant(cls, init_model, pred_model, n, opset_version): + assert len(n.outputs) == 1 + return cls._create_tensor_filling_op(n.attrs["value"], n.outputs[0]) + + @classmethod + def _create_gather(cls, init_model, pred_model, n, opset_version): + (A, B) = n.inputs + (Y, ) = n.outputs + axis = n.attrs.get('axis', 0) + + if axis == 0: + return core.CreateOperator("Gather", [A, B], [Y]) + elif axis == 1: + return core.CreateOperator("BatchGather", [A, B], [Y]) + raise ValueError( + 'Caffe2 only supports Gather with axis being 0 or 1,' + + 'whereas axis is ' + str(axis)) + + @classmethod + def _create_logsoftmax(cls, init_model, pred_model, n, opset_version): + # NB: this implementation is not backward stable. + (A,) = n.inputs + (Y,) = n.outputs + axis = n.attrs.get('axis', 1) + ops = [] + softmax_A = dummy_name() + ops.append(core.CreateOperator('Softmax', [A], [softmax_A], axis=axis)) + ops.append(core.CreateOperator('Log', [softmax_A], [Y])) + return ops + + @classmethod + def _create_gemm(cls, init_model, pred_model, n, opset_version): + (A, B, C) = n.inputs + (Y,) = n.outputs + alpha = n.attrs.get('alpha', 1.) + beta = n.attrs.get('beta', 1.) + + ops = [] + if alpha != 1: + scaled_A = dummy_name() + ops.append(core.CreateOperator('Scale', [A], [scaled_A], scale=alpha)) + A = scaled_A + if beta != 1: + scaled_C = dummy_name() + ops.append(core.CreateOperator('Scale', [C], [scaled_C], scale=beta)) + C = scaled_C + + trans_a = n.attrs.get('transA', 0) + trans_b = n.attrs.get('transB', 0) + broadcast = n.attrs.get('broadcast', 0) + if not trans_a and trans_b and broadcast: + ops.append(core.CreateOperator('FC', + [A, B, C], + [Y])) + else: + AB = dummy_name() + ops.append(core.CreateOperator('MatMul', + [A, B], + [AB], + trans_a=trans_a, + trans_b=trans_b)) + ops.append(core.CreateOperator('Add', + [AB, C], + [Y], + broadcast=broadcast)) + + return ops + + @classmethod + def _rnn_shape_inference(cls, init_model, pred_model, n, input_blob, W): + # ad-hoc, informally-specified, bug-ridden, slow + # implementation of shape inference + + # if the weight matrices are directly provided as + # initializers, their dimensions should be available in the + # init net model. + for x in init_model.graph.input: + if x.name == W: + return x.type.tensor_type.shape.dim[1].dim_value + + # otherwise, assume that the input_blob is either a direct + # graph input, or another rnn op of the same type. This + # matches the pattern produced by exporting from pytorch + # (where the weight matrices are unusable for this purpose due + # to reshaping operations that lose shape information). + for x in pred_model.graph.input: + if x.name == input_blob: + return x.type.tensor_type.shape.dim[2].dim_value + + curr = n + while True: + for x in pred_model.graph.input: + if x.name == curr.inputs[0] and curr.op_type == 'Gather': + return x.type.tensor_type.shape.dim[1].dim_value + prev = [x for x in map(OnnxNode, pred_model.graph.node) if x.outputs[0] == curr.inputs[0]] + if len(prev) != 1: + return + prev = prev[0] + if prev.op_type == n.op_type: + return prev.attrs['hidden_size'] + curr = prev + + @classmethod + def _create_rnn(cls, init_model, pred_model, n, opset_version): + assert init_model is not None, "cannot convert RNNs without access to the full model" + assert pred_model is not None, "cannot convert RNNs without access to the full model" + + attrs = dict(n.attrs) # make a copy, which is safe to mutate + hidden_size = attrs.pop('hidden_size') + activation = force_unicode(attrs.pop('activations', ('tanh',))[0]) + direction = force_unicode(attrs.pop('direction', 'forward')) + assert not attrs, "unsupported RNN attributes: " + str(attrs.keys()) + assert direction in ['forward', 'bidirectional'], "unsupported backwards RNN" + + input_blob, W, R, B, sequence_lens, initial_h = n.inputs + + if sequence_lens == "": + sequence_lens = None + + input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W) + if input_size is None: + raise RuntimeError("best-effort shape inference for RNN input failed") + + init_net = core.Net("init-net") + pred_mh = ModelHelper() + + def make_rnn(direction_offset): + name = dummy_name() + + # input and recurrence biases are squashed together in + # onnx but not in caffe2 + + bias_offset = 2 * direction_offset * hidden_size + init_net.Slice(B, name + "/i2h_b", + starts=[bias_offset + 0 * hidden_size], + ends =[bias_offset + 1 * hidden_size]) + init_net.Slice(B, name + "/gates_t_b", + starts=[bias_offset + 1 * hidden_size], + ends =[bias_offset + 2 * hidden_size]) + + weight_offset = direction_offset * hidden_size + init_net.Slice(W, name + '/i2h_w', + starts=[weight_offset + 0 * hidden_size, 0], + ends =[weight_offset + 1 * hidden_size,-1]) + init_net.Slice(R, name + '/gates_t_w', + starts=[weight_offset + 0 * hidden_size, 0], + ends =[weight_offset + 1 * hidden_size,-1]) + + initial_h_sliced = name + '/initial_h' + init_net.Slice(initial_h, initial_h_sliced, + starts=[direction_offset + 0, 0, 0], + ends =[direction_offset + 1,-1,-1]) + + if direction_offset == 1: + input = pred_mh.net.ReversePackedSegs( + [input_blob, sequence_lens], name + "/input-reversed") + else: + input = input_blob + + hidden_t_all, hidden_t_last = rnn_cell.BasicRNN( + pred_mh, + input, + sequence_lens, + [initial_h_sliced], + input_size, + hidden_size, + name, + drop_states=True, + forward_only=True, + activation=activation + ) + + if direction_offset == 1: + hidden_t_all = pred_mh.net.ReversePackedSegs( + [hidden_t_all, sequence_lens], name + "/output-reversed") + + return hidden_t_all, hidden_t_last + + if direction == 'forward': + hidden_t_all, hidden_t_last = make_rnn(0) + pred_mh.net = pred_mh.net.Clone( + "dummy-clone-net", + blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] } + ) + elif direction == 'bidirectional': + hidden_t_all_f, hidden_t_last_f = make_rnn(0) + hidden_t_all_b, hidden_t_last_b = make_rnn(1) + pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b], + [n.outputs[0], dummy_name()], axis=2) + pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b], + [n.outputs[1], dummy_name()], axis=2) + + return Caffe2Ops(list(pred_mh.Proto().op), + list(init_net.Proto().op), + list(pred_mh.Proto().external_input)) + + @classmethod + def _create_lstm(cls, init_model, pred_model, n, opset_version): + assert init_model is not None, "cannot convert LSTMs without access to the full model" + assert pred_model is not None, "cannot convert LSTMs without access to the full model" + + attrs = dict(n.attrs) # make a copy, which is safe to mutate + hidden_size = attrs.pop('hidden_size') + direction = force_unicode(attrs.pop('direction', 'forward')) + assert not attrs, "unsupported LSTM attributes: " + str(attrs.keys()) + assert direction in ['forward', 'bidirectional'], "unsupported backwards LSTM" + + input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs + + if sequence_lens == "": + sequence_lens = None + + input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W) + if input_size is None: + raise RuntimeError("best-effort shape inference for LSTM input failed") + + init_net = core.Net("init-net") + pred_mh = ModelHelper() + + def make_lstm(direction_offset): + name = dummy_name() + + # input and recurrence biases are squashed together in + # onnx but not in caffe2 + + bias_offset = 8 * direction_offset * hidden_size + Bi = init_net.Slice(B, name + "_bias_i2h", + starts=[bias_offset + 0 * hidden_size], + ends =[bias_offset + 4 * hidden_size]) + Br = init_net.Slice(B, name + "_bias_gates", + starts=[bias_offset + 4 * hidden_size], + ends =[bias_offset + 8 * hidden_size]) + + weight_offset = 4 * direction_offset * hidden_size + W_ = init_net.Slice(W, name + '/i2h_w_pre', + starts=[weight_offset + 0 * hidden_size, 0], + ends =[weight_offset + 4 * hidden_size,-1]) + R_ = init_net.Slice(R, name + '/gates_t_w_pre', + starts=[weight_offset + 0 * hidden_size, 0], + ends =[weight_offset + 4 * hidden_size,-1]) + + # caffe2 has a different order from onnx. We need to rearrange + # i o f c -> i f o c + reforms = ((W_, 'i2h_w', [(0, -1)]), + (R_, 'gates_t_w', [(0, -1)]), + (Bi, 'i2h_b' , []), + (Br, 'gates_t_b', [])) + for name_from, name_to, extra_dims in reforms: + xi, xo, xf, xc = [name_from + suffix for suffix in ("_i", "_o", "_f", "_c")] + for i, x in enumerate([xi, xo, xf, xc]): + dim0 = i * hidden_size, (i+1) * hidden_size + starts, ends = zip(dim0, *extra_dims) + init_net.Slice(name_from, x, starts=starts, ends=ends) + init_net.Concat([xi, xf, xo, xc], ['%s/%s' % (name, name_to), dummy_name()], axis=0) + + initial_h_sliced = name + '/initial_h' + init_net.Slice(initial_h, initial_h_sliced, + starts=[direction_offset + 0, 0, 0], + ends =[direction_offset + 1,-1,-1]) + initial_c_sliced = name + '/initial_c' + init_net.Slice(initial_c, initial_c_sliced, + starts=[direction_offset + 0, 0, 0], + ends =[direction_offset + 1,-1,-1]) + + if direction_offset == 1: + input = pred_mh.net.ReversePackedSegs( + [input_blob, sequence_lens], name + "/input-reversed") + else: + input = input_blob + + hidden_t_all, hidden_t_last, _, _, params = rnn_cell.LSTM( + pred_mh, + input, + sequence_lens, + [initial_h_sliced, initial_c_sliced], + input_size, + hidden_size, + name, + drop_states=True, + forward_only=True, + return_params=True + ) + + if direction_offset == 1: + hidden_t_all = pred_mh.net.ReversePackedSegs( + [hidden_t_all, sequence_lens], name + "/output-reversed") + + return hidden_t_all, hidden_t_last + + if direction == 'forward': + hidden_t_all, hidden_t_last = make_lstm(0) + pred_mh.net = pred_mh.net.Clone( + "dummy-clone-net", + blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] } + ) + elif direction == 'bidirectional': + hidden_t_all_f, hidden_t_last_f = make_lstm(0) + hidden_t_all_b, hidden_t_last_b = make_lstm(1) + pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b], + [n.outputs[0], dummy_name()], axis=2) + pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b], + [n.outputs[1], dummy_name()], axis=2) + + return Caffe2Ops(list(pred_mh.Proto().op), + list(init_net.Proto().op), + list(pred_mh.Proto().external_input)) + + @classmethod + def _create_gru(cls, init_model, pred_model, n, opset_version): + assert init_model is not None, "cannot convert GRUs without access to the full model" + assert pred_model is not None, "cannot convert GRUs without access to the full model" + + attrs = dict(n.attrs) # make a copy, which is safe to mutate + hidden_size = attrs.pop('hidden_size') + linear_before_reset = attrs.pop('linear_before_reset', 0) + direction = force_unicode(attrs.pop('direction', 'forward')) + assert not attrs, "unsupported GRU attributes: " + str(attrs.keys()) + assert direction in ['forward', 'bidirectional'], "unsupported backwards GRU" + + input_blob, W, R, B, sequence_lens, initial_h = n.inputs + + if sequence_lens == "": + sequence_lens = None + + input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W) + if input_size is None: + raise RuntimeError("best-effort shape inference for GRU input failed") + + init_net = core.Net("init-net") + pred_mh = ModelHelper() + + def make_gru(direction_offset): + name = dummy_name() + + # input and recurrence biases are squashed together in + # onnx but not in caffe2 + + bias_offset = 6 * direction_offset * hidden_size + Bi = init_net.Slice(B, name + "_bias_i2h", + starts=[bias_offset + 0 * hidden_size], + ends =[bias_offset + 3 * hidden_size]) + Br = init_net.Slice(B, name + "_bias_gates", + starts=[bias_offset + 3 * hidden_size], + ends =[bias_offset + 6 * hidden_size]) + + weight_offset = 3 * direction_offset * hidden_size + W_ = init_net.Slice(W, name + '/i2h_w_pre', + starts=[weight_offset + 0 * hidden_size, 0], + ends =[weight_offset + 3 * hidden_size,-1]) + R_ = init_net.Slice(R, name + '/gates_t_w_pre', + starts=[weight_offset + 0 * hidden_size, 0], + ends =[weight_offset + 3 * hidden_size,-1]) + + # caffe2 has a different order from onnx. We need to rearrange + # z r h -> r z h + reforms = ((W_, 'i2h_w', True, [(0,-1)]), + (R_, 'gate_t_w', False, [(0,-1)]), + (Bi, 'i2h_b', True, []), + (Br, 'gate_t_b', False, [])) + for name_from, name_to, do_concat, extra_dims in reforms: + xz, xr, xh = ['%s/%s_%s' % (name, prefix, name_to) for prefix in ('update', 'reset', 'output')] + for i, x in enumerate([xz, xr, xh]): + dim0 = i * hidden_size, (i+1) * hidden_size + starts, ends = zip(dim0, *extra_dims) + init_net.Slice(name_from, x, starts=starts, ends=ends) + if do_concat: + init_net.Concat([xr, xz, xh], ['%s/%s' % (name, name_to), dummy_name()], axis=0) + + initial_h_sliced = name + '/initial_h' + init_net.Slice(initial_h, initial_h_sliced, + starts=[direction_offset + 0, 0, 0], + ends =[direction_offset + 1,-1,-1]) + + if direction_offset == 1: + input = pred_mh.net.ReversePackedSegs( + [input_blob, sequence_lens], name + "/input-reversed") + else: + input = input_blob + + hidden_t_all, hidden_t_last = gru_cell.GRU( + pred_mh, + input, + sequence_lens, + [initial_h_sliced], + input_size, + hidden_size, + name, + drop_states=True, + forward_only=True, + linear_before_reset=linear_before_reset + ) + + if direction_offset == 1: + hidden_t_all = pred_mh.net.ReversePackedSegs( + [hidden_t_all, sequence_lens], name + "/output-reversed") + + return hidden_t_all, hidden_t_last + + if direction == 'forward': + hidden_t_all, hidden_t_last = make_gru(0) + pred_mh.net = pred_mh.net.Clone( + "dummy-clone-net", + blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] } + ) + elif direction == 'bidirectional': + hidden_t_all_f, hidden_t_last_f = make_gru(0) + hidden_t_all_b, hidden_t_last_b = make_gru(1) + pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b], + [n.outputs[0], dummy_name()], axis=2) + pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b], + [n.outputs[1], dummy_name()], axis=2) + + return Caffe2Ops(list(pred_mh.Proto().op), + list(init_net.Proto().op), + list(pred_mh.Proto().external_input)) + + @classmethod + def _create_pad(cls, init_model, pred_model, n, opset_version): + if opset_version < 2: + pads = n.attrs['paddings'] + else: + pads = n.attrs['pads'] + if not (len(pads) == 8 and + # first two dim is for batch and channel + set(pads[:2] + pads[4:6]) == {0}): + raise ValueError('Caffe2 only supports padding 2D Tensor, whereas padding is ' + str(pads)) + # Guard the invalid (negative) pads attribute. + if min(pads) < 0: + raise ValueError('ONNX does not support negative pads in Pad, but get {}.'.format(pads)) + pads[:] = pads[2:4] + pads[6:8] + return cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version) + + @classmethod + def _create_concat(cls, init_model, pred_model, n, opset_version): + # TODO: Caffe2 Concat has an extra output. It should be only + # used when doing training, so we should change Caffe2 to allow + # 1 output. + op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version) + assert len(op.output) == 1 + op.output.append(dummy_name()) + return op + + @classmethod + def _create_slice(cls, init_model, pred_model, n, opset_version): + op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version) + args = {arg.name: arg for arg in op.arg} + starts_vals = np.array( + args.pop('starts').ints, dtype=np.int64).tolist() + ends_vals = np.array( + [i - 1 if i < 0 else i for i in args.pop('ends').ints], + dtype=np.int64).tolist() + if 'axes' in args: + axes_vals = np.array( + args.pop('axes').ints, dtype=np.int32).tolist() + else: + ndims = len(starts_vals) + axes_vals = np.array(range(ndims), dtype=np.int32).tolist() + + data, = op.input + ops = [] + + shape_tensor = dummy_name() + ops.append(core.CreateOperator( + 'Shape', + [data], + [shape_tensor] + )) + + axes_tensor = dummy_name() + ops.extend([ + core.CreateOperator( + 'GivenTensorIntFill', + [], + [axes_tensor], + shape=[len(axes_vals)], + values=axes_vals, + ), + ]) + + starts_vals_tensor = dummy_name() + starts_tensor = dummy_name() + casted_starts_tensor = dummy_name() + ops.extend([ + core.CreateOperator( + 'GivenTensorInt64Fill', + [], + [starts_vals_tensor], + shape=[len(starts_vals)], + values=starts_vals, + ), + core.CreateOperator( + 'ConstantFill', + [shape_tensor], + [starts_tensor], + dtype=caffe2_pb2.TensorProto.INT64, + value=0, + ), + core.CreateOperator( + 'ScatterAssign', + [starts_tensor, axes_tensor, starts_vals_tensor], + [starts_tensor], + ), + # Slice only accepts starts as int + core.CreateOperator( + 'Cast', + [starts_tensor], + [casted_starts_tensor], + to=caffe2_pb2.TensorProto.INT32, + ), + ]) + + ends_vals_tensor = dummy_name() + ends_tensor = dummy_name() + casted_ends_tensor = dummy_name() + ops.extend([ + core.CreateOperator( + 'GivenTensorInt64Fill', + [], + [ends_vals_tensor], + shape=[len(ends_vals)], + values=ends_vals, + ), + core.CreateOperator( + 'ConstantFill', + [shape_tensor], + [ends_tensor], + dtype=caffe2_pb2.TensorProto.INT64, + value=-1, + ), + core.CreateOperator( + 'ScatterAssign', + [ends_tensor, axes_tensor, ends_vals_tensor], + [ends_tensor], + ), + # Slice only accepts ends as int + core.CreateOperator( + 'Cast', + [ends_tensor], + [casted_ends_tensor], + to=caffe2_pb2.TensorProto.INT32, + ), + ]) + + op.input[:] = [data, casted_starts_tensor, casted_ends_tensor] + del op.arg[:] + op.arg.extend(args.values()) + ops.append(op) + + return ops + + # Note [Caffe2 ConvPoolOpBase] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # To understand what is going on here, we have to talk a little bit about + # Caffe2's internals. + # + # First, it's important to know that all of Caffe2's pooling and convolution + # operators inherit from "ConvPoolOpBase", which is an abstract class that + # defines all of the attributes (kernels, dilations, strides, etc) which one + # sees on these operators. Unfortunately, Caffe2's documentation generator + # doesn't know how to handle cases like this, so for example, if you look at + # the docs for MaxPool at <https://caffe2.ai/docs/operators-catalogue.html#maxpool> + # you won't see any of the attributes. You have to go source diving to + # find the information; in particular, you want to look at: + # https://github.com/caffe2/caffe2/blob/master/caffe2/operators/conv_pool_op_base.h + # This class handles *global* pooling as well. + # + # Second, it's important to know what Caffe2 expects for padding, which can + # be somewhat difficult to understand from the code because Caffe2 handles + # both singular/pluralized spellings of padding, and there is also legacy + # padding business. The short version of the story is that, for NON-legacy + # padding (which is what we want to output), padding is expected to be + # *twice* the size of kernels. So if you have a 2D convolution, Caffe2 + # will accept two values in 'kernels', but FOUR values in 'pads'; + # furthermore, this is *mandatory.* + # + # Finally, ConvPoolOpBase is not the only class of it's kind; there is + # also ConvTransposeUnpoolBase, which backs ConvTranspose. So don't + # be tricked by the fact that Conv and ConvTranspose have similar + # parameters; they exercise different codepaths and need to be handled + # differently. + + @classmethod + def _create_conv_pool_op_base(cls, init_model, pred_model, n, opset_version): + if n.op_type.startswith('Global'): + n.attrs['global_pooling'] = 1 + + try: + kernels = n.attrs['kernel_shape'] + pads = n.attrs['pads'] + except KeyError: + pass + else: + if len(kernels) == len(pads): + # Caffe2 requires pads to be twice the size of kernels. + n.attrs['pads'] = pads * 2 + + return cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version) + + @classmethod + def _create_reshape(cls, init_model, pred_model, n, opset_version): + c2_op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version) + # Caffe2 has an extra output + c2_op.output.append(dummy_name()) + return c2_op + + @classmethod + def _create_sqrt(cls, init_model, pred_model, n, opset_version): + (X,) = n.inputs + (Y,) = n.outputs + return core.CreateOperator( + 'Pow', + [X], + [Y], + exponent=0.5, + ) + + @classmethod + def _create_reciprocal(cls, init_model, pred_model, n, opset_version): + (X,) = n.inputs + (Y,) = n.outputs + return core.CreateOperator( + 'Pow', + [X], + [Y], + exponent=-1.0, + ) + + @classmethod + def _direct_initialize_parameters(cls, initializer, ws, device_option): + for tp in initializer: + ws.FeedBlob(tp.name, onnx.numpy_helper.to_array(tp), device_option) + + @classmethod + def _direct_initialize_inputs(cls, inputs, initialized, ws, device_option): + for value_info in inputs: + if value_info.name in initialized: + continue + shape = list(d.dim_value for d in value_info.type.tensor_type.shape.dim) + ws.FeedBlob(value_info.name, np.ones(shape), device_option) + + @staticmethod + def optimize_onnx(input, init=False, predict=False): + passes = ['fuse_consecutive_transposes', + 'eliminate_nop_transpose', + 'fuse_transpose_into_gemm'] + if init: + passes.append('split_init') + if predict: + passes.append('split_predict') + out = onnx.optimizer.optimize(input, passes) + return out + + @classmethod + def prepare(cls, model, device='CPU', **kwargs): + ''' + For Onnx Caffe2Backend, we require that init_graph don't initialize the actual input of the predict_graph, + + for example, if "img" is the input blob for the predict_net, we require that in init_graph and in + initializer of the predict_graph, "img" is not initalized. We don't have a check for this, since + there is no way we can know which blob is the input of the predict_graph. + ''' + super(Caffe2Backend, cls).prepare(model, device, **kwargs) + + + opset_version = None + for imp in model.opset_import: + if not imp.HasField("domain") or imp.domain == "": + opset_version = imp.version + if imp.version > cls._known_opset_version: + warnings.warn("This version of onnx-caffe2 targets ONNX operator set version {}, but the model we are trying to import uses version {}. We will try to import it anyway, but if the model uses operators which had BC-breaking changes in the intervening versions, import will fail.".format(cls._known_opset_version, imp.version)) + else: + warnings.warn("Unrecognized operator set {}".format(imp.domain)) + if opset_version is None: + if model.ir_version >= 0x00000003: + raise RuntimeError("Model with IR version >= 3 did not specify ONNX operator set version (onnx-caffe2 requires it)") + else: + opset_version = 1 + + ws = Workspace() + device_option = get_device_option(Device(device)) + + # Directly load initializer data into blobs in workspace + cls._direct_initialize_parameters( + model.graph.initializer, + ws, + device_option, + ) + + initialized = {init.name for init in model.graph.initializer} + + cls._direct_initialize_inputs( + model.graph.input, + initialized, + ws, + device_option, + ) + + uninitialized = [value_info.name for value_info in model.graph.input if value_info.name not in initialized] + + init_net, predict_net = cls._onnx_model_to_caffe2_net(model, device, opset_version, False) + + retval = Caffe2Rep(init_net, predict_net, ws, uninitialized) + return retval + + @classmethod + # TODO: This method needs a refactor for clarity + def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version): + if node_def.op_type in cls._special_operators: + translator = getattr(cls, cls._special_operators[node_def.op_type]) + else: + translator = cls._common_onnx_node_to_caffe2_op + ops = translator(init_model, pred_model, OnnxNode(node_def), opset_version) + if isinstance(ops, Caffe2Ops): + return ops + if not isinstance(ops, collections.Iterable): + ops = [ops] + return Caffe2Ops(ops, [], []) + + @classmethod + def _common_onnx_node_to_caffe2_op(cls, init_model, pred_model, onnx_node, opset_version): + """ + This translator performs the basic translation of ONNX nodes into + Caffe2 operators. Besides doing a straightforward marshalling from + one format to another, it also does these extra things: + + - Renames operators based on '_renamed_operators' + - Renames attributes based on '_global_renamed_attrs' and + '_per_op_renamed_attrs' + + If you're writing a custom translator, consider calling this first, + and then fixing things up further. + """ + c2_op = caffe2_pb2.OperatorDef() + + c2_op.input.extend(onnx_node.inputs) + c2_op.output.extend(onnx_node.outputs) + c2_op.name = onnx_node.name + + onnx_op_type = onnx_node.op_type + broken_version = cls._broken_operators.get(onnx_op_type, float('Inf')) + if broken_version <= opset_version: + raise ValueError( + "Don't know how to translate op {} in ONNX operator set v{} (I only support prior to v{})".format(onnx_op_type, opset_version, broken_version)) + c2_op.type = cls._renamed_operators.get(onnx_op_type, onnx_op_type) + if not core.IsOperator(c2_op.type): + raise ValueError( + "Don't know how to translate op {}".format(onnx_op_type)) + + def kmap(k): + if (onnx_op_type in cls._per_op_renamed_attrs and + k in cls._per_op_renamed_attrs[onnx_op_type]): + return cls._per_op_renamed_attrs[onnx_op_type][k] + if k in cls._global_renamed_attrs: + return cls._global_renamed_attrs[k] + return k + c2_op.arg.extend(onnx_node.attrs.caffe2(kmap=kmap)) + + return c2_op + + + @classmethod + def _inplace_rewrite(cls, graph_or_nodes): + ''' + currently we use this to translate ONNX-style + consumed_input annotations to Caffe2-style in place + updates (use same input and output names). + ''' + is_graph = isinstance(graph_or_nodes, GraphProto) + if is_graph: + nodes = graph_or_nodes.node + else: + nodes = graph_or_nodes + + renamed = {} + + for node in nodes: + node.input[:] = [renamed.get(input_name, input_name) + for input_name in node.input] + consumed_inputs = OnnxNode(node).consumed_inputs or [] + output_idxes = set(range(len(node.output))) + schema = onnx.defs.get_schema(node.op_type) + for i, consumed in enumerate(consumed_inputs): + if not consumed: + continue + _, output_idx = schema.consumed(i) + # consumed outputs are not always present + # for instance batch norm in test mode + # does not return the consumed inputs + if output_idx < len(node.output): + output_idxes.remove(output_idx) + old_val = node.output[output_idx] + new_val = node.input[i] + node.output[output_idx] = new_val + renamed[old_val] = new_val + for idx in output_idxes: + name = node.output[idx] + node.output[idx] = renamed.get(name, name) + if is_graph: + for output in graph_or_nodes.output: + output.name = renamed.get(output.name, output.name) + + @staticmethod + def _all_names_in_graph(graph): + if graph is None: + return set() + + names = set() + names.update(value_info.name for value_info in graph.input) + names.update(value_info.name for value_info in graph.output) + for node in graph.node: + names.update(node.input) + names.update(node.output) + return names + + @classmethod + def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_initializers): + device_option = get_device_option(Device(device)) + + init_model = ModelProto() + init_model.ParseFromString(cls.optimize_onnx(onnx_model.SerializeToString(), init=True)) + cls._inplace_rewrite(init_model.graph) + + pred_model = ModelProto() + pred_model.ParseFromString(cls.optimize_onnx(onnx_model.SerializeToString(), predict=True)) + cls._inplace_rewrite(pred_model.graph) + + init_net = caffe2_pb2.NetDef() + pred_net = caffe2_pb2.NetDef() + + init_net.name = onnx_model.graph.name + '_init' + pred_net.name = onnx_model.graph.name + '_predict' + + if include_initializers: + init_net.op.extend(cls._create_tensor_filling_op(tp) for tp in onnx_model.graph.initializer) + + dummy_name(cls._all_names_in_graph(init_model.graph) | cls._all_names_in_graph(pred_model.graph)) + + for net, model in ( (init_net, init_model), (pred_net, pred_model) ): + net.device_option.CopyFrom(device_option) + for node in model.graph.node: + c2ops = cls._onnx_node_to_caffe2_op( + init_model, pred_model, node, opset_version) + (init_net if include_initializers else net).op.extend(c2ops.init_ops) + net.op.extend(c2ops.ops) + net.external_input.extend(c2ops.interface_blobs) + net.external_output.extend( + value_info.name for value_info in model.graph.output) + net.external_input.extend( + value_info.name for value_info in model.graph.input) + + return init_net, pred_net + + # wrapper for backwards compatability + @classmethod + def onnx_graph_to_caffe2_net(cls, model, device="CPU", opset_version=_known_opset_version): + return cls._onnx_model_to_caffe2_net(model, device=device, opset_version=opset_version, include_initializers=True) + + @classmethod + def supports_device(cls, device_str): + device = Device(device_str) + if device.type == DeviceType.CPU: + return True + elif device.type == DeviceType.CUDA: + return workspace.has_gpu_support + return False + + +prepare = Caffe2Backend.prepare + +run_node = Caffe2Backend.run_node + +run_model = Caffe2Backend.run_model + +supports_device = Caffe2Backend.supports_device # noqa diff --git a/caffe2/python/onnx/backend_rep.py b/caffe2/python/onnx/backend_rep.py new file mode 100644 index 0000000000..f427b6d2d1 --- /dev/null +++ b/caffe2/python/onnx/backend_rep.py @@ -0,0 +1,77 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.backend_rep +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, workspace +from caffe2.proto import caffe2_pb2 +from onnx.backend.base import BackendRep, namedtupledict + +class Caffe2Rep(BackendRep): + def __init__(self, init_net, predict_net, workspace, uninitialized): + super(Caffe2Rep, self).__init__() + self.init_net = init_net + self.predict_net = predict_net + self.workspace = workspace + # The list of uninitialized external_inputs in workspace, we need this to + # pair the name with given sequence inputs. + self.uninitialized = uninitialized + self.nets_created = False + self.ran_init_net = False + + @property + def _name_scope(self): + if self.predict_net.device_option.device_type == caffe2_pb2.CUDA: + return 'gpu_{}'.format(self.predict_net.device_option.cuda_gpu_id) + return '' + + def run(self, inputs, **kwargs): + super(Caffe2Rep, self).run(inputs, **kwargs) + with self.workspace: + with core.DeviceScope(self.predict_net.device_option): + if isinstance(inputs, dict): + with core.NameScope(self._name_scope): + for key, value in inputs.items(): + workspace.FeedBlob(key, value) + elif isinstance(inputs, list) or isinstance(inputs, tuple): + if len(self.uninitialized) != len(inputs): + raise RuntimeError('Expected {} values for uninitialized ' + 'graph inputs ({}), but got {}.'.format( + len(self.uninitialized), + ', '.join(self.uninitialized), + len(inputs))) + for i, value in enumerate(inputs): + # namescope already baked into protobuf + workspace.FeedBlob(self.uninitialized[i], value) + else: + # single input + workspace.FeedBlob(self.uninitialized[0], inputs) + if not self.nets_created: + workspace.CreateNet(self.init_net) + workspace.CreateNet(self.predict_net) + self.nets_created = True + if not self.ran_init_net: + workspace.RunNet(self.init_net.name) + self.ran_init_net = True + workspace.RunNet(self.predict_net.name) + output_values = [workspace.FetchBlob(name) + for name in self.predict_net.external_output] + return namedtupledict('Outputs', + self.predict_net.external_output)(*output_values) diff --git a/caffe2/python/onnx/bin/conversion.py b/caffe2/python/onnx/bin/conversion.py new file mode 100644 index 0000000000..fd5eb70ef1 --- /dev/null +++ b/caffe2/python/onnx/bin/conversion.py @@ -0,0 +1,104 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.bin.conversion + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from caffe2.proto import caffe2_pb2 +import click +import numpy as np +from onnx import checker, ModelProto + +from caffe2.python.onnx.backend import Caffe2Backend as c2 +import caffe2.python.onnx.frontend as c2_onnx + + +@click.command( + help='convert caffe2 net to onnx model', + context_settings={ + 'help_option_names': ['-h', '--help'] + } +) +@click.argument('caffe2_net', type=click.File('rb')) +@click.option('--caffe2-net-name', + type=str, + help="Name of the caffe2 net") +@click.option('--caffe2-init-net', + type=click.File('rb'), + help="Path of the caffe2 init net pb file") +@click.option('--value-info', + type=str, + help='A json string providing the ' + 'type and shape information of the inputs') +@click.option('-o', '--output', required=True, + type=click.File('wb'), + help='Output path for the onnx model pb file') +def caffe2_to_onnx(caffe2_net, + caffe2_net_name, + caffe2_init_net, + value_info, + output): + c2_net_proto = caffe2_pb2.NetDef() + c2_net_proto.ParseFromString(caffe2_net.read()) + if not c2_net_proto.name and not caffe2_net_name: + raise click.BadParameter( + 'The input caffe2 net does not have name, ' + '--caffe2-net-name must be provided') + c2_net_proto.name = caffe2_net_name or c2_net_proto.name + if caffe2_init_net: + c2_init_net_proto = caffe2_pb2.NetDef() + c2_init_net_proto.ParseFromString(caffe2_init_net.read()) + c2_init_net_proto.name = '{}_init'.format(caffe2_net_name) + else: + c2_init_net_proto = None + + if value_info: + value_info = json.loads(value_info) + + onnx_model = c2_onnx.caffe2_net_to_onnx_model( + predict_net=c2_net_proto, + init_net=c2_init_net_proto, + value_info=value_info) + + output.write(onnx_model.SerializeToString()) + + +@click.command( + help='convert onnx model to caffe2 net', + context_settings={ + 'help_option_names': ['-h', '--help'] + } +) +@click.argument('onnx_model', type=click.File('rb')) +@click.option('-o', '--output', required=True, + type=click.File('wb'), + help='Output path for the caffe2 net file') +@click.option('--init-net-output', + required=True, + type=click.File('wb'), + help='Output path for the caffe2 init net file') +def onnx_to_caffe2(onnx_model, output, init_net_output): + onnx_model_proto = ModelProto() + onnx_model_proto.ParseFromString(onnx_model.read()) + + init_net, predict_net = c2.onnx_graph_to_caffe2_net(onnx_model_proto) + init_net_output.write(init_net.SerializeToString()) + output.write(predict_net.SerializeToString()) diff --git a/caffe2/python/onnx/error.py b/caffe2/python/onnx/error.py new file mode 100644 index 0000000000..760d7b048b --- /dev/null +++ b/caffe2/python/onnx/error.py @@ -0,0 +1,23 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.error +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +class BaseException(Exception): pass +class Unsupported(BaseException): pass diff --git a/caffe2/python/onnx/frontend.py b/caffe2/python/onnx/frontend.py new file mode 100644 index 0000000000..bac6024e09 --- /dev/null +++ b/caffe2/python/onnx/frontend.py @@ -0,0 +1,551 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.frontend + +"""Caffe2 Protobuf to ONNX converter + +To run this, you will need to have Caffe2 installed as well. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import itertools +import collections +import logging +import re + +from caffe2.python import core as caffe2_core +from enum import Enum +from onnx import (defs, checker, helper, numpy_helper, mapping, + ModelProto, GraphProto, NodeProto, AttributeProto, TensorProto, OperatorSetIdProto) +from onnx.helper import make_tensor, make_tensor_value_info +import numpy as np + +from caffe2.python.onnx.helper import make_model, c2_native_run_net, dummy_name +from caffe2.python.onnx.error import Unsupported + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Caffe2Frontend(object): + # This number controls the semantics of the operators we target. Whenever + # ONNX makes a BC breaking change to semantics of operators, having this set + # to an accurate number will prevent our models form exporting. However, + # we should strive to keep this up-to-date as much as possible. + _target_opset_version = 3 + + _renamed_operators = { + 'SpatialBN': 'BatchNormalization', + 'Conv1D': 'Conv', + 'Conv2D': 'Conv', + 'Conv3D': 'Conv', + 'ConvTranspose1D': 'ConvTranspose', + 'ConvTranspose2D': 'ConvTranspose', + 'ConvTranspose3D': 'ConvTranspose', + 'MaxPool1D': 'MaxPool', + 'MaxPool2D': 'MaxPool', + 'MaxPool3D': 'MaxPool', + 'AveragePool1D': 'AveragePool', + 'AveragePool2D': 'AveragePool', + 'AveragePool3D': 'AveragePool', + } + + # caffe2 arguments that are completely removed in onnx + _blacklist_caffe2_args = { + 'order': {b'NCHW'}, + 'cudnn_exhaustive_search': {0, 1}, + 'use_cudnn': {0, 1}, + } + + _global_renamed_args = { + 'kernels': 'kernel_shape', + } + + _per_op_renamed_args = { + 'Squeeze': {'dims': 'axes'}, + 'Transpose': {'axes': 'perm'}, + } + + _special_operators = { + 'Conv': '_create_conv_pool_op', + 'ConvTranspose': '_create_conv_pool_op', + 'ChannelShuffle': '_create_channel_shuffle', + 'MaxPool': '_create_conv_pool_op', + 'AveragePool': '_create_conv_pool_op', + 'Concat': '_create_concat', + 'FC': '_create_gemm', + 'LRN': '_create_lrn', + 'Slice': '_create_slice', + 'Reshape': '_create_reshape', + } + + @classmethod + def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg): + # name + op_type = op_def.type + if op_type in cls._per_op_renamed_args: + name = cls._per_op_renamed_args[op_type].get( + arg.name, arg.name) + else: + name = cls._global_renamed_args.get(arg.name, arg.name) + + # value + if arg.HasField('f'): + value = arg.f + elif arg.HasField('i'): + value = arg.i + elif arg.HasField('s'): + value = arg.s + elif arg.floats: + value = arg.floats + elif arg.ints: + value = arg.ints + elif arg.strings: + value = arg.strings + else: + raise ValueError('Could not find data field in arg: {}'.format(arg)) + + if name in cls._blacklist_caffe2_args: + assert value in cls._blacklist_caffe2_args[arg.name] + return None + + return helper.make_attribute(name, value) + + @classmethod + def caffe2_arg_to_onnx_attr(cls, op_def, arg): + return cls._common_caffe2_arg_to_onnx_attr(op_def, arg) + + @classmethod + def _common_caffe2_op_to_onnx_node(cls, op_def, shapes): + node_def = NodeProto() + node_def.name = op_def.name + + node_def.op_type = cls._renamed_operators.get(op_def.type, op_def.type) + + node_def.input.extend(op_def.input) + node_def.output.extend(op_def.output) + + attrs = filter(None, [cls.caffe2_arg_to_onnx_attr(op_def, arg) + for arg in op_def.arg]) + node_def.attribute.extend(attrs) + + return node_def + + @classmethod + def _create_concat(cls, op_def, shapes): + node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) + if len(node.output) == 2: + del node.output[1] + explicit_axis = any(arg.name == 'axis' for arg in op_def.arg) + if not explicit_axis: + node.attribute.extend([helper.make_attribute('axis', 1)]) + return node + + @classmethod + def _create_reshape(cls, op_def, shapes): + node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) + if len(node.output) == 2: + del node.output[1] + return node + + @classmethod + def _create_conv_pool_op(cls, op_def, shapes): + node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) + + if node.op_type in ['MaxPool', 'AveragePool']: + for i, attr in enumerate(node.attribute): + if attr.name == 'global_pooling' and attr.i: + node.op_type = 'Global{}'.format(node.op_type) + del node.attribute[i] + break + + attrs = {attr.name: attr for attr in node.attribute} + def apply_trans(k, dim=2, ks=None): + ks = ks or (k + 's') + if dim == 2: + k_h, k_w = k + '_h', k + '_w' + else: + k_t, k_l, k_b, k_r = k + '_t', k + '_l', k + '_b', k + '_r' + + vals = None + if (dim == 2 and k_h in attrs and k_w in attrs): + vals = [attrs[k_h].i, attrs[k_w].i] + del attrs[k_h] + del attrs[k_w] + elif (dim == 4 and + k_t in attrs and k_l in attrs and k_b in attrs and k_r in attrs): + vals = [attrs[k_t].i, + attrs[k_l].i, + attrs[k_b].i, + attrs[k_r].i] + del attrs[k_t] + del attrs[k_l] + del attrs[k_b] + del attrs[k_r] + elif k in attrs: + vals = [attrs[k].i] * dim + del attrs[k] + + if vals and not node.op_type.startswith('Global'): + attrs[ks] = helper.make_attribute(ks, vals) + + apply_trans('kernel', ks='kernel_shape') + apply_trans('stride') + apply_trans('dilation') + apply_trans('adj') + apply_trans('pad', 4) + + del node.attribute[:] + node.attribute.extend(attrs.values()) + return node + + @classmethod + def _create_gemm(cls, op_def, shapes): + x, w, b = op_def.input + args = {arg.name: arg for arg in op_def.arg} + y, = op_def.output + x_shape = list(shapes[x]) + + nodes = [] + if 'axis' in args: + axis = args['axis'].i + outer = np.prod(x_shape[:axis]).astype(int) + inner = np.prod(x_shape[axis:]).astype(int) + reshaped_x = dummy_name() + nodes.append(helper.make_node( + 'Reshape', + inputs=[x], + outputs=[reshaped_x], + shape=[outer, inner], + )) + x = reshaped_x + + if 'axis_w' in args: + axis_w = args['axis_w'].i + w_shape = shapes[w] + outer = np.prod(w_shape[:axis_w]).astype(int).item() + inner = np.prod(w_shape[axis_w:]).astype(int).item() + reshaped_w = dummy_name() + nodes.append(helper.make_node( + 'Reshape', + inputs=[w], + outputs=[reshaped_w], + shape=[outer, inner], + )) + w = reshaped_w + + gemm_y_output = dummy_name() if 'axis' in args else y + nodes.append(helper.make_node( + 'Gemm', + inputs=[x, w, b], + outputs=[gemm_y_output], + name=op_def.name, + transB=1, + broadcast=1, + )) + + if 'axis' in args: + axis = args['axis'].i + nodes.append(helper.make_node( + 'Reshape', + inputs=[gemm_y_output], + outputs=[y], + shape=x_shape[:axis] + [-1], + )) + + return nodes + + @classmethod + def _create_lrn(cls, op_def, shapes): + node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) + if len(node.output) == 2: + del node.output[1] + return node + + @classmethod + def _create_slice(cls, op_def, shapes): + if len(op_def.input) > 1: + raise Unsupported( + 'ONNX Slice operator does not support dynamic slice.') + node = cls._common_caffe2_op_to_onnx_node(op_def, shapes) + attrs = {attr.name: attr for attr in node.attribute} + ndims = len(attrs['starts'].ints) + + node.attribute.extend([helper.make_attribute('axes', range(ndims))]) + + data, = node.input + shape = shapes[data] + + ends = attrs['ends'].ints + for i, end in enumerate(ends): + if end >= 0: + continue + if end == -1: + end = shape[i] + else: + end = end + 1 + ends[i] = end + + return node + + @classmethod + def _create_channel_shuffle(cls, op_def, shapes): + x, = op_def.input + y, = op_def.output + n, c, h, w = shapes[x] + args = {arg.name: arg for arg in op_def.arg} + g = args['group'].i + assert c % g == 0 + + nodes = [] + + tmp1 = dummy_name() + nodes.append(helper.make_node( + 'Reshape', + inputs=[x], + outputs=[tmp1], + shape=[n, g, c // g, h, w], + )) + + tmp2 = dummy_name() + nodes.append(helper.make_node( + 'Transpose', + inputs=[tmp1], + outputs=[tmp2], + perm=[0, 2, 1, 3, 4], + )) + + nodes.append(helper.make_node( + 'Reshape', + inputs=[tmp2], + outputs=[y], + shape=[n, c, h, w], + )) + return nodes + + @classmethod + def caffe2_op_to_onnx_node(cls, op_def, shapes): + if op_def.type in cls._special_operators: + translator = getattr(cls, cls._special_operators[op_def.type]) + else: + translator = cls._common_caffe2_op_to_onnx_node + nodes = translator(op_def, shapes) + if not isinstance(nodes, collections.Iterable): + nodes = [nodes] + return nodes + + @staticmethod + def _all_names_in_net(net): + if net is None: + return set() + + names = set() + names.update(net.external_input) + names.update(net.external_output) + for op in net.op: + names.update(op.input) + names.update(op.output) + return names + + @classmethod + def caffe2_net_to_onnx_graph(cls, + predict_net, + init_net=None, + value_info=None): + if value_info is None: + value_info = {} + if not isinstance(value_info, dict): + raise ValueError('Please pass value_info as a ' + 'name -> (type, shape) dictionary') + + cls._ssa_rewrite(predict_net, init_net, value_info) + + if init_net: + initializer = cls.caffe2_init_net_to_initializer(init_net) + value_info.update({init.name: (init.data_type, init.dims) + for init in initializer}) + else: + initializer = [] + + # Check whether we have got type shape info of all input + missing = (set(list(predict_net.external_input)) - + set(value_info.keys())) + if missing: + raise RuntimeError('Could not find value info of inputs: {}'.format( + ', '.join(missing))) + + inputs = {} + for name in predict_net.external_input: + elem_type, shape = value_info[name] + inputs[name] = np.random.randn(*shape).astype( + mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type]) + + ws, outputs = c2_native_run_net( + init_net, + predict_net, + inputs) + + for name in predict_net.external_output: + output = outputs[name] + elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output.dtype] + shape = output.shape + value_info[name] = (elem_type, shape) + + graph_def = GraphProto() + graph_def.name = predict_net.name + graph_def.initializer.extend(initializer) + # This is a mapping from Caffe2 names to ONNX names + graph_def.input.extend( + make_tensor_value_info( + name=name, + elem_type=value_info[name][0], + shape=value_info[name][1]) + for name in predict_net.external_input) + + dummy_name(cls._all_names_in_net(predict_net) | + cls._all_names_in_net(init_net)) + + for op in predict_net.op: + shapes = {} + for name in itertools.chain(op.input, op.output): + blob = ws.FetchBlob(name) + if hasattr(blob, 'shape'): + shapes[name] = blob.shape + graph_def.node.extend( + cls.caffe2_op_to_onnx_node( + op, shapes=shapes)) + + all_output = set(sum((list(node.output) for node in graph_def.node), + [init.name for init in graph_def.initializer])) + redundant_output = set(vi.name for vi in graph_def.output) - all_output + if redundant_output: + logger.warning( + 'There are graph output not produced by any node or initializer: {}' + '! Will drop them.'.format(', '.join(redundant_output))) + graph_def.output.extend( + make_tensor_value_info( + name=name, + elem_type=value_info[name][0], + shape=value_info[name][1]) + for name in predict_net.external_output + if name in all_output) + + cls._annotate_consumed(graph_def) + checker.check_graph(graph_def) + return graph_def + + @classmethod + def caffe2_init_net_to_initializer(cls, init_net): + initializer = [] + for op in init_net.op: + assert not op.input + try: + data_type, field_name = { + 'GivenTensorFill': (TensorProto.FLOAT, 'floats'), + 'GivenTensorInt64Fill': (TensorProto.INT64, 'ints'), + 'GivenTensorIntFill': (TensorProto.INT32, 'ints'), + 'GivenTensorBoolFill': (TensorProto.BOOL, 'ints'), + 'GivenTensorStringFill': (TensorProto.STRING, 'strings'), + }[op.type] + except KeyError: + raise RuntimeError( + "Can not translate init_net with operator '{}' " + "to initializer".format(op.type) + ) + raw = (data_type != TensorProto.STRING) + args = {a.name: a for a in op.arg} + vals = getattr(args['values'], field_name) + if raw: + vals = np.asarray( + vals, + dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[data_type]).tobytes() + initializer.append(make_tensor( + name=op.output[0], + data_type=data_type, + dims=args['shape'].ints, + vals=vals, + raw=raw, + )) + return initializer + + @classmethod + def _annotate_consumed(cls, graph_def): + for node in graph_def.node: + schema = defs.get_schema(node.op_type) + consumes = [] + for i, _input_name in enumerate(node.input): + consume_type, output_idx = schema.consumed(i) + if consume_type == defs.OpSchema.UseType.CONSUME_ENFORCED: + consumes.append(1) + else: + consumes.append(0) + + if any(consumes): + node.attribute.extend([helper.make_attribute( + 'consumed_inputs', + consumes, + )]) + + @classmethod + def _ssa_rewrite(cls, net, init_net, value_info): + def ssa_name(name, version): + return '{}_{}'.format(name, version) + + if init_net: + for op in init_net.op: + assert re.match('GivenTensor.*Fill', op.type) + assert len(op.output) == 1 + op.output[0] = ssa_name(op.output[0], 0) + init_net.external_input[:] = [ssa_name(name, 0) + for name in init_net.external_input] + init_net.external_output[:] = [ssa_name(name, 0) + for name in init_net.external_output] + if value_info: + ssa_value_info = {ssa_name(name, 0): value + for name, value in value_info.items()} + value_info.clear() + value_info.update(ssa_value_info) + net.external_input[:] = [ssa_name(name, 0) + for name in net.external_input] + ssa, blob_versions = caffe2_core.get_ssa(net) + assert len(net.op) == len(ssa) + for op, (versioned_inputs, versioned_outputs) in zip(net.op, ssa): + op.input[:] = [ssa_name(name, version) + for name, version in versioned_inputs] + op.output[:] = [ssa_name(name, version) + for name, version in versioned_outputs] + net.external_output[:] = [ssa_name(name, blob_versions[name]) + for name in net.external_output] + + @classmethod + def caffe2_net_to_onnx_model(cls, *args, **kwargs): + model = make_model(cls.caffe2_net_to_onnx_graph(*args, **kwargs)) + opset_id = OperatorSetIdProto() + opset_id.domain = '' # ONNX + opset_id.version = cls._target_opset_version + model.opset_import.extend([opset_id]) + checker.check_model(model) + return model + + +caffe2_net_to_onnx_graph = Caffe2Frontend.caffe2_net_to_onnx_graph +caffe2_net_to_onnx_model = Caffe2Frontend.caffe2_net_to_onnx_model +caffe2_init_net_to_initializer = Caffe2Frontend.caffe2_init_net_to_initializer diff --git a/caffe2/python/onnx/helper.py b/caffe2/python/onnx/helper.py new file mode 100644 index 0000000000..7d95619e9a --- /dev/null +++ b/caffe2/python/onnx/helper.py @@ -0,0 +1,157 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.helper +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.proto import caffe2_pb2 +from onnx import helper +from onnx.backend.base import namedtupledict + +from caffe2.python.onnx.workspace import Workspace + +import io +import logging +import time + + +log = logging.getLogger(__name__) + + +class _DummyNameFactory(object): + used_names = set() + counter = 0 + + @classmethod + def dummy_name(cls, used_names=None): + if used_names is not None: + cls.used_names.clear() + cls.used_names.update(used_names) + cls.counter = 0 + return None + else: + while True: + name = 'OC2_DUMMY_{}'.format(cls.counter) + cls.counter += 1 + if name not in cls.used_names: + cls.used_names.add(name) + return name + +dummy_name = _DummyNameFactory.dummy_name + + +def make_model(graph, **kwargs): + kwargs.setdefault('producer_name', 'onnx-caffe2') + return helper.make_model(graph=graph, **kwargs) + + +def c2_native_run_op(op_def, inputs): + ws = Workspace() + if isinstance(inputs, dict): + for key, value in inputs.items(): + ws.FeedBlob(key, value, op_def.device_option) + else: + assert(len(op_def.input) == len(inputs)) + for key, value in zip(op_def.input, inputs): + ws.FeedBlob(key, value, op_def.device_option) + + ws.RunOperatorOnce(op_def) + + output_names = op_def.output + output_values = [ws.FetchBlob(name) for name in output_names] + return ws, namedtupledict('Outputs', output_names)(*output_values) + + +def c2_native_run_net(init_net, predict_net, inputs): + ws = Workspace() + if init_net: + ws.RunNetOnce(init_net) + + if isinstance(inputs, dict): + for key, value in inputs.items(): + ws.FeedBlob(key, value, predict_net.device_option) + else: + uninitialized = [input_name + for input_name in predict_net.external_input + if not ws.HasBlob(input_name)] + if len(uninitialized) == len(inputs): + for key, value in zip(uninitialized, inputs): + ws.FeedBlob(key, value, predict_net.device_option) + else: + # If everything is initialized, + # we just initialized the first len(inputs) external_input. + assert(len(inputs) <= len(predict_net.external_input)) + for i in range(len(inputs)): + ws.FeedBlob(predict_net.external_input[i], inputs[i], + predict_net.device_option) + + ws.RunNetOnce(predict_net) + + output_names = predict_net.external_output + output_values = [ws.FetchBlob(name) for name in output_names] + return ws, namedtupledict('Outputs', output_names)(*output_values) + + +def load_caffe2_net(file): + net = caffe2_pb2.NetDef() + with open(file, "rb") as f: + net.ParseFromString(f.read()) + return net + + +def save_caffe2_net(net, file, output_txt=False): + with open(file, "wb") as f: + f.write(net.SerializeToString()) + if output_txt: + with open(file + "txt", "w") as f: + f.write(str(net)) + + +def benchmark_caffe2_model(init_net, predict_net, warmup_iters=3, main_iters=10, layer_details=True): + ''' + Run the benchmark net on the target model. + Return the execution time per iteration (millisecond). + ''' + ws = Workspace() + if init_net: + ws.RunNetOnce(init_net) + ws.CreateNet(predict_net) + results = ws.BenchmarkNet(predict_net.name, warmup_iters, main_iters, layer_details) + del ws + return results[0] + + +def benchmark_pytorch_model(model, inputs, training=False, warmup_iters=3, + main_iters=10, verbose=False): + ''' + Run the model several times, and measure the execution time. + Return the execution time per iteration (millisecond). + ''' + for _i in range(warmup_iters): + model(*inputs) + total_pytorch_time = 0.0 + for _i in range(main_iters): + ts = time.time() + model(*inputs) + te = time.time() + total_pytorch_time += te - ts + log.info("The PyTorch model execution time per iter is {} milliseconds, " + "{} iters per second.".format(total_pytorch_time / main_iters * 1000, + main_iters / total_pytorch_time)) + return total_pytorch_time * 1000 / main_iters diff --git a/caffe2/python/onnx/tests/caffe2_ref_test.py b/caffe2/python/onnx/tests/caffe2_ref_test.py new file mode 100644 index 0000000000..d6bc396451 --- /dev/null +++ b/caffe2/python/onnx/tests/caffe2_ref_test.py @@ -0,0 +1,357 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.tests.caffe2_ref_test + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import json +import os +import unittest + +from caffe2.python import core +from caffe2.proto import caffe2_pb2 + +import onnx +from onnx.helper import make_node, make_graph, make_tensor, make_tensor_value_info +from caffe2.python.onnx.helper import make_model, c2_native_run_net, c2_native_run_op + +from onnx import defs, mapping +import caffe2.python.onnx.frontend as c2_onnx +import caffe2.python.onnx.backend as c2 + +import numpy as np +from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory + +from caffe2.python.onnx.helper import dummy_name +from caffe2.python.onnx.tests.test_utils import TestCase + + +class TestCaffe2Basic(TestCase): + def test_dummy_name(self): + n1 = dummy_name() + n2 = dummy_name() + assert n1 != n2, "Got same names in different calls: {}".format(n1) + + def test_relu_node_inplace(self): + X = np.random.randn(3, 2).astype(np.float32) + Y_ref = np.clip(X, 0, np.inf) + + node_def = make_node( + "Relu", ["X"], ["Y"], consumed_inputs=[1]) + output = c2.run_node( + node_def, {"X": X}) + np.testing.assert_almost_equal(output.X, Y_ref) + + node_def = make_node( + "Relu", ["X"], ["Y"], consumed_inputs=[1]) + graph_def = make_graph( + [node_def], + name="test", + inputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])], + outputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])]) + c2_rep = c2.prepare(make_model(graph_def)) + output = c2_rep.run({"X": X}) + np.testing.assert_almost_equal(output.X, Y_ref) + + def test_relu_graph(self): + X = np.random.randn(3, 2).astype(np.float32) + Y_ref = np.clip(X, 0, np.inf) + + node_def = make_node( + "Relu", ["X"], ["Y"]) + output = c2.run_node( + node_def, {"X": X}) + np.testing.assert_almost_equal(output.Y, Y_ref) + + graph_def = make_graph( + [node_def], + name="test", + inputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])], + outputs=[make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2])]) + c2_rep = c2.prepare(make_model(graph_def)) + output = c2_rep.run(X) + np.testing.assert_almost_equal(output.Y, Y_ref) + + def test_initializer(self): + X = np.array([[1, 2], [3, 4]]).astype(np.float32) + Y = np.array([[1, 2], [3, 4]]).astype(np.float32) + weight = np.array([[1, 0], [0, 1]]) + graph_def = make_graph( + [make_node("Add", ["X", "Y"], ["Z0"]), + make_node("Cast", ["Z0"], ["Z"], to="float"), + make_node("Mul", ["Z", "weight"], ["W0"]), + make_node("Tanh", ["W0"], ["W1"]), + make_node("Sigmoid", ["W1"], ["W2"]), + make_node("Scale", ["W2"], ["W3"], scale=-1.0)], + name="test_initializer", + inputs=[ + make_tensor_value_info("X", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("Y", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("weight", onnx.TensorProto.FLOAT, (2, 2)), + ], + outputs=[ + make_tensor_value_info("W3", onnx.TensorProto.FLOAT, (2, 2)) + ], + initializer=[make_tensor("weight", + onnx.TensorProto.FLOAT, + [2, 2], + weight.flatten().astype(float))] + ) + + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + W_ref = -sigmoid(np.tanh((X + Y) * weight)) + c2_rep = c2.prepare(make_model(graph_def)) + output = c2_rep.run({"X": X, "Y": Y}) + np.testing.assert_almost_equal(output["W3"], W_ref) + + def test_gemm(self): + # simple + A = np.random.randn(3, 2).astype(np.float32) + B = np.random.randn(2, 4).astype(np.float32) + C = np.random.randn(3, 4).astype(np.float32) + node_def = make_node( + 'Gemm', + ['A', 'B', 'C'], + ["Y"]) + output = c2.run_node(node_def, [A, B, C]) + np.testing.assert_almost_equal(output["Y"], np.dot(A, B) + C) + + # transA + A = np.transpose(A) + node_def = make_node( + 'Gemm', + ['A', 'B', 'C'], + ["Y"], + transA=True) + output = c2.run_node(node_def, [A, B, C]) + np.testing.assert_almost_equal( + output["Y"], + np.dot(np.transpose(A), B) + C) + # revert A + A = np.transpose(A) + + # transB + B = np.transpose(B) + node_def = make_node( + 'Gemm', + ['A', 'B', 'C'], + ["Y"], + transB=True) + output = c2.run_node(node_def, [A, B, C]) + np.testing.assert_almost_equal( + output["Y"], + np.dot(A, np.transpose(B)) + C) + # revert A + B = np.transpose(B) + + # scale + alpha = np.random.random() + beta = np.random.random() + node_def = make_node( + 'Gemm', + ['A', 'B', 'C'], + ["Y"], + alpha=alpha, + beta=beta) + output = c2.run_node(node_def, [A, B, C]) + np.testing.assert_almost_equal( + output["Y"], + alpha * np.dot(A, B) + beta * C) + + # broadcast + C = np.random.randn(4).astype(np.float32) + node_def = make_node( + 'Gemm', + ['A', 'B', 'C'], + ["Y"], + alpha=alpha, + beta=beta, + broadcast=1) + output = c2.run_node(node_def, [A, B, C]) + np.testing.assert_almost_equal( + output["Y"], + alpha * np.dot(A, B) + beta * C) + + def test_tensor_filling_ops(self): + for dtype in [ + onnx.TensorProto.FLOAT, + onnx.TensorProto.DOUBLE, + onnx.TensorProto.BOOL, + onnx.TensorProto.INT8, + onnx.TensorProto.INT16, + onnx.TensorProto.INT32, + onnx.TensorProto.INT64, + onnx.TensorProto.UINT8, + onnx.TensorProto.UINT16, + onnx.TensorProto.UINT32, + ]: + shape = (1, 2, 3) + vals = np.random.randn(*shape) + if dtype != onnx.TensorProto.BOOL: + vals *= 5 + vals = vals.astype( + mapping.TENSOR_TYPE_TO_NP_TYPE[dtype]) + tensor = make_tensor( + name='test-tensor-{}'.format(dtype), + data_type=dtype, + dims=[1, 2, 3], + vals=vals.flatten().tolist(), + ) + op = c2.Caffe2Backend._create_tensor_filling_op(tensor) + self.assertEqual(len(op.input), 0) + self.assertEqual(op.output, [tensor.name]) + ws, output = c2_native_run_op(op, inputs=[]) + self.assertEqual(len(output), 1) + np.testing.assert_almost_equal(output[0], vals) + np.testing.assert_almost_equal(ws.FetchBlob(op.output[0]), vals) + + def test_slice(self): + X = np.random.randn(1, 2, 3).astype(np.float32) + starts = np.array([0, 1, 0], dtype=np.int32) + ends = np.array([-1, 2, 3], dtype=np.int32) + + predict_net = caffe2_pb2.NetDef() + predict_net.name = 'test-slice-net' + predict_net.external_input[:] = ['X'] + predict_net.external_output[:] = ['Y'] + predict_net.op.extend([ + core.CreateOperator( + 'Slice', + inputs=['X'], + outputs=['Y'], + starts=starts, + ends=ends, + ), + ]) + ws, (Y,) = c2_native_run_net( + init_net=None, + predict_net=predict_net, + inputs=[X]) + + onnx_model = c2_onnx.caffe2_net_to_onnx_model( + predict_net=predict_net, + value_info={ + 'X': (onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[X.dtype], X.shape) + }) + Y, = c2.run_model(onnx_model, inputs=[X]) + np.testing.assert_almost_equal(Y, X[:, 1:2, :]) + + +class TestCaffe2End2End(TestCase): + def _model_dir(self, model): + caffe2_home = os.path.expanduser(os.getenv('ONNX_HOME', '~/.caffe2')) + models_dir = os.getenv('ONNX_MODELS', os.path.join(caffe2_home, 'models')) + return os.path.join(models_dir, model) + + def _test_net(self, + net_name, + input_blob_dims=(1, 3, 224, 224), + decimal=7): + np.random.seed(seed=0) + model_dir = self._model_dir(net_name) + if not os.path.exists(model_dir): + self._download(net_name) + c2_predict_pb = os.path.join(model_dir, 'predict_net.pb') + c2_predict_net = caffe2_pb2.NetDef() + with open(c2_predict_pb, 'rb') as f: + c2_predict_net.ParseFromString(f.read()) + c2_predict_net.name = net_name + + c2_init_pb = os.path.join(model_dir, 'init_net.pb') + c2_init_net = caffe2_pb2.NetDef() + with open(c2_init_pb, 'rb') as f: + c2_init_net.ParseFromString(f.read()) + c2_init_net.name = net_name + '_init' + + n, c, h, w = input_blob_dims + data = np.random.randn(n, c, h, w).astype(np.float32) + inputs = [data] + _, c2_outputs = c2_native_run_net(c2_init_net, c2_predict_net, inputs) + del _ + + model = c2_onnx.caffe2_net_to_onnx_model( + predict_net=c2_predict_net, + init_net=c2_init_net, + value_info=json.load(open(os.path.join(model_dir, 'value_info.json')))) + c2_ir = c2.prepare(model) + onnx_outputs = c2_ir.run(inputs) + self.assertSameOutputs(c2_outputs, onnx_outputs, decimal=decimal) + + def _download(self, model): + model_dir = self._model_dir(model) + assert not os.path.exists(model_dir) + os.makedirs(model_dir) + for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']: + url = getURLFromName(model, f) + dest = os.path.join(model_dir, f) + try: + try: + downloadFromURLToFile(url, dest, + show_progress=False) + except TypeError: + # show_progress not supported prior to + # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1 + # (Sep 17, 2017) + downloadFromURLToFile(url, dest) + except Exception as e: + print("Abort: {reason}".format(reason=e)) + print("Cleaning up...") + deleteDirectory(model_dir) + exit(1) + + def test_alexnet(self): + self._test_net('bvlc_alexnet', decimal=4) + + def test_resnet50(self): + self._test_net('resnet50') + + @unittest.skipIf( + os.environ.get('JENKINS_URL'), + 'Taking too long to download!') + def test_vgg16(self): + self._test_net('vgg16') + + @unittest.skipIf( + os.environ.get('JENKINS_URL'), + 'Running vgg19 on Travis with Python 2 keeps getting OOM!') + def test_vgg19(self): + self._test_net('vgg19') + + def test_inception_v1(self): + self._test_net('inception_v1', decimal=2) + + def test_inception_v2(self): + self._test_net('inception_v2') + + @unittest.skip('Need to add support for ConstantFill operator') + def test_squeezenet(self): + self._test_net('squeezenet') + + def test_shufflenet(self): + self._test_net('shufflenet') + + def test_densenet121(self): + self._test_net('densenet121') + + +if __name__ == '__main__': + unittest.main() diff --git a/caffe2/python/onnx/tests/conversion_test.py b/caffe2/python/onnx/tests/conversion_test.py new file mode 100644 index 0000000000..99c7d94848 --- /dev/null +++ b/caffe2/python/onnx/tests/conversion_test.py @@ -0,0 +1,244 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.tests.conversion_test + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import tempfile +import textwrap +import traceback + +from caffe2.proto import caffe2_pb2 +from caffe2.python import brew, core +from caffe2.python.model_helper import ModelHelper +from click.testing import CliRunner +import numpy as np +from onnx import helper, ModelProto, TensorProto +from caffe2.python.onnx.helper import make_model, c2_native_run_net + +from caffe2.python.onnx.bin.conversion import caffe2_to_onnx, onnx_to_caffe2 +from caffe2.python.onnx.helper import dummy_name +import caffe2.python.onnx.backend as c2 +from caffe2.python.onnx.tests.test_utils import TestCase + + +class TestConversion(TestCase): + def _run_command(self, cmd, *args, **kwargs): + runner = CliRunner() + result = runner.invoke(cmd, *args, **kwargs) + self.assertEqual(result.exit_code, 0, textwrap.dedent(''' + Command exited with non-zero exit code: + output: {} + exception: {} + exc_info: {} + '''.format(result.output, + result.exception, + traceback.format_exception(*result.exc_info)))) + return result + + def test_caffe2_to_onnx(self): + caffe2_net = tempfile.NamedTemporaryFile() + caffe2_init_net = tempfile.NamedTemporaryFile() + output = tempfile.NamedTemporaryFile() + + model = ModelHelper(name='caffe2-to-onnx-test') + brew.relu(model, ["X"], "Y") + caffe2_net.write(model.net.Proto().SerializeToString()) + caffe2_net.flush() + + init_model = ModelHelper(name='caffe2-to-onnx-init-test') + init_model.net.GivenTensorFill([], 'X', shape=[2, 2], + values=np.zeros((2, 2)).flatten().astype(float)) + caffe2_init_net.write(init_model.net.Proto().SerializeToString()) + caffe2_init_net.flush() + + result = self._run_command( + caffe2_to_onnx, [ + caffe2_net.name, + '--caffe2-init-net', caffe2_init_net.name, + '--output', output.name, + ], + catch_exceptions=False, + ) + + onnx_model = ModelProto() + onnx_model.ParseFromString(output.read()) + self.assertEqual(len(onnx_model.graph.node), 1) + self.assertEqual(onnx_model.graph.node[0].op_type, 'Relu') + self.assertEqual(len(onnx_model.graph.initializer), 1) + self.assertEqual(onnx_model.graph.initializer[0].name, onnx_model.graph.input[0].name) + + def test_caffe2_to_onnx_value_info(self): + caffe2_net = tempfile.NamedTemporaryFile() + output = tempfile.NamedTemporaryFile() + + model = ModelHelper(name='caffe2-to-onnx-test') + brew.relu(model, ["X"], "Y") + caffe2_net.write(model.net.Proto().SerializeToString()) + caffe2_net.flush() + + args = [caffe2_net.name, '--output', output.name] + self.assertRaisesRegexp(Exception, + 'value info', + self._run_command, caffe2_to_onnx, args) + + args.extend([ + '--value-info', + json.dumps({ + 'X': (TensorProto.FLOAT, (2, 2)), + })]) + result = self._run_command(caffe2_to_onnx, args) + + onnx_model = ModelProto() + onnx_model.ParseFromString(output.read()) + self.assertEqual(len(onnx_model.graph.node), 1) + self.assertEqual(onnx_model.graph.node[0].op_type, 'Relu') + self.assertEqual(len(onnx_model.graph.initializer), 0) + + def test_onnx_to_caffe2(self): + onnx_model = tempfile.NamedTemporaryFile() + output = tempfile.NamedTemporaryFile() + init_net_output = tempfile.NamedTemporaryFile() + + node_def = helper.make_node( + "Mul", ["X", "W"], ["Y"]) + graph_def = helper.make_graph( + [node_def], + "test", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, (3, 2))], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 2))], + initializer=[helper.make_tensor("W", + TensorProto.FLOAT, + [3, 2], + np.zeros((3, 2)).flatten().astype(float))]) + model_def = make_model(graph_def, producer_name='onnx-to-caffe2-test') + onnx_model.write(model_def.SerializeToString()) + onnx_model.flush() + + result = self._run_command( + onnx_to_caffe2, [ + onnx_model.name, + '--output', output.name, + '--init-net-output', init_net_output.name, + ]) + + caffe2_net = caffe2_pb2.NetDef() + caffe2_net.ParseFromString(output.read()) + self.assertEqual(len(caffe2_net.op), 1) + self.assertEqual(caffe2_net.op[0].type, 'Mul') + + caffe2_init_net = caffe2_pb2.NetDef() + caffe2_init_net.ParseFromString(init_net_output.read()) + self.assertEqual(len(caffe2_init_net.op), 1) + self.assertEqual(set(sum([list(init_op.output) + for init_op in caffe2_init_net.op], [])), + {'W'}) + + def test_convert_end2end(self): + predict_net_f = tempfile.NamedTemporaryFile() + init_net_f = tempfile.NamedTemporaryFile() + onnx_model_f = tempfile.NamedTemporaryFile() + + x = 'X' + w = 'W' + b = 'b' + y = 'Y' + + predict_net = caffe2_pb2.NetDef() + predict_net.name = 'test-convert-end2end' + predict_net.external_input[:] = [x, w, b] + predict_net.external_output[:] = [y] + predict_net.op.extend([ + core.CreateOperator( + 'FC', + inputs=[x, w, b], + outputs=[y], + axis=2, + ), + ]) + predict_net_f.write(predict_net.SerializeToString()) + predict_net_f.flush() + + init_net = caffe2_pb2.NetDef() + init_net.name = 'test-convert-end2end-init' + init_net.external_output[:] = [w, b] + x_val = np.random.randn(1, 3, 2).astype(np.float32) + w_val = np.random.randn(4, 2).astype(np.float32) + b_val = np.random.randn(4).astype(np.float32) + init_net.op.extend([ + core.CreateOperator( + 'GivenTensorFill', + [], + [w], + values=w_val, + shape=w_val.shape, + ), + core.CreateOperator( + 'GivenTensorFill', + [], + [b], + values=b_val, + shape=b_val.shape, + ), + ]) + init_net_f.write(init_net.SerializeToString()) + init_net_f.flush() + + y_val = np.matmul(x_val, w_val.transpose()) + b_val + for _ in range(5): + self._run_command( + caffe2_to_onnx, [ + predict_net_f.name, + '--caffe2-init-net', init_net_f.name, + '--output', onnx_model_f.name, + '--value-info', + json.dumps({ + x: (TensorProto.FLOAT, (1, 3, 2)), + }), + ], + catch_exceptions=False, + ) + + onnx_model_f.seek(0) + onnx_model = ModelProto() + onnx_model.ParseFromString(onnx_model_f.read()) + np.testing.assert_almost_equal( + c2.run_model( + onnx_model, {onnx_model.graph.input[0].name: x_val}), + [y_val]) + + self._run_command( + onnx_to_caffe2, [ + onnx_model_f.name, + '--output', predict_net_f.name, + '--init-net-output', init_net_f.name, + ]) + predict_net_f.seek(0) + predict_net = caffe2_pb2.NetDef() + predict_net.ParseFromString(predict_net_f.read()) + init_net_f.seek(0) + init_net = caffe2_pb2.NetDef() + init_net.ParseFromString(init_net_f.read()) + x = predict_net.external_input[0] + np.testing.assert_almost_equal(c2_native_run_net(init_net=init_net, + predict_net=predict_net, + inputs={x: x_val})[1], + [y_val]) diff --git a/caffe2/python/onnx/tests/helper_test.py b/caffe2/python/onnx/tests/helper_test.py new file mode 100644 index 0000000000..f0ce6498cb --- /dev/null +++ b/caffe2/python/onnx/tests/helper_test.py @@ -0,0 +1,45 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.tests.helper_test + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest + +from caffe2.python.onnx.helper import dummy_name + +from caffe2.python.onnx.tests.test_utils import TestCase + + +class TestCaffe2Basic(TestCase): + def test_dummy_name(self): + dummy_name([]) + names_1 = [dummy_name() for _ in range(3)] + dummy_name([]) + names_2 = [dummy_name() for _ in range(3)] + self.assertEqual(names_1, names_2) + + dummy_name(names_1) + names_3 = [dummy_name() for _ in range(3)] + self.assertFalse(set(names_1) & set(names_3)) + + +if __name__ == '__main__': + unittest.main() diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py new file mode 100644 index 0000000000..941a15d5a0 --- /dev/null +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -0,0 +1,50 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.tests.onnx_backend_test + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os + +import unittest +import onnx.backend.test + +import caffe2.python.onnx.backend as c2 + +# This is a pytest magic variable to load extra plugins +pytest_plugins = 'onnx.backend.test.report', + +backend_test = onnx.backend.test.BackendTest(c2, __name__) + +backend_test.exclude(r'(test_ceil|test_floor' # Does not support Ceil and Floor. + '|test_hardsigmoid|test_pow' # Does not support Hardsigmoid and Pow. + '|test_mean|test_hardmax)') # Does not support Mean and Hardmax. + +# Skip vgg to speed up CI +if 'JENKINS_URL' in os.environ: + backend_test.exclude(r'(test_vgg19|test_vgg)') + +# import all test cases at global scope to make them visible to python.unittest +globals().update(backend_test + .enable_report() + .test_cases) + +if __name__ == '__main__': + unittest.main() diff --git a/caffe2/python/onnx/tests/optimize_onnx_test.py b/caffe2/python/onnx/tests/optimize_onnx_test.py new file mode 100644 index 0000000000..6efbfc52c9 --- /dev/null +++ b/caffe2/python/onnx/tests/optimize_onnx_test.py @@ -0,0 +1,118 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.tests.optimize_onnx_test + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import tarfile +import tempfile +import unittest + +from collections import namedtuple +from subprocess import Popen, PIPE +from six.moves.urllib.request import urlretrieve +import numpy as np + +import onnx +from onnx import helper, ModelProto, TensorProto +from onnx.backend.test.runner import Runner +import caffe2.python.onnx.backend as c2 + +from caffe2.python.onnx.tests.test_utils import TestCase + +class TestRoundtrip(TestCase): + def _roundtrip(self, model_name): + model_dir = Runner(c2)._prepare_model_data( + namedtuple('dummy', ['model_name'])(model_name)) + + pb_path = os.path.join(model_dir, 'model.pb') + + before_roundtrip = onnx.load(pb_path) + + with open(pb_path, 'rb') as pb: + after_roundtrip = onnx.load_from_string(pb.read()) + + assert onnx.helper.printable_graph(before_roundtrip.graph) \ + == onnx.helper.printable_graph(after_roundtrip.graph) + + with open(pb_path, 'rb') as pb: + assert after_roundtrip.SerializeToString() == pb.read() + + # arbitrarily pick one relatively small model to sanity test with + def test_squeezenet_v3(self): + self._roundtrip('squeezenet-ir-version-3') + + # testing just to be sure that we no-op instead of breaking on an + # older IR version. + def test_squeezenet_v1(self): + self._roundtrip('squeezenet-ir-version-1') + +class TestOptimize(TestCase): + def _optimized(self, graph): + orig_model = helper.make_model(graph, producer_name='onnx-to-caffe2-test') + orig_model_str = orig_model.SerializeToString() + optimized_model_str = c2.Caffe2Backend.optimize_onnx(orig_model_str) + optimized_model = ModelProto() + optimized_model.ParseFromString(optimized_model_str) + return optimized_model + + def test_nop_transpose(self): + trans = helper.make_node("Transpose", ["X"], ["Y"], perm=[0,1]) + graph = helper.make_graph( + [trans], + "test", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3))], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 2))]) + optimized_model = self._optimized(graph) + + for node in optimized_model.graph.node: + assert node.op_type != "Transpose" + + def test_fuse_transpose(self): + trans1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[1,0,2]) + trans2 = helper.make_node("Transpose", ["Y"], ["Z"], perm=[2,0,1]) + trans3 = helper.make_node("Transpose", ["Z"], ["A"], perm=[2,0,1]) + graph = helper.make_graph( + [trans1, trans2, trans3], + "test", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4))], + [helper.make_tensor_value_info("A", TensorProto.FLOAT, (4, 3, 2))]) + optimized_model = self._optimized(graph) + + assert len(list(optimized_model.graph.node)) == 1 + + def test_fuse_transpose_into_gemm(self): + trans1 = helper.make_node("Transpose", ["X"], ["A"], perm=[1,0]) + trans2 = helper.make_node("Transpose", ["Y"], ["B"], perm=[1,0]) + gemm = helper.make_node("Gemm", ["A", "B", "C"], ["Z"]) + graph = helper.make_graph( + [trans1, trans2, gemm], + "test", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3)), + helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5, 2)), + helper.make_tensor_value_info("C", TensorProto.FLOAT, (3, 5))], + [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (3, 5))]) + optimized_model = self._optimized(graph) + + assert len(list(optimized_model.graph.node)) == 1 + +if __name__ == '__main__': + unittest.main() diff --git a/caffe2/python/onnx/tests/ssa_test.py b/caffe2/python/onnx/tests/ssa_test.py new file mode 100644 index 0000000000..851482e494 --- /dev/null +++ b/caffe2/python/onnx/tests/ssa_test.py @@ -0,0 +1,122 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.tests.ssa_test + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import onnx +import numpy as np +from caffe2.proto import caffe2_pb2 +from caffe2.python import core +from onnx import helper, TensorProto + +import caffe2.python.onnx.frontend as c2_onnx +from caffe2.python.onnx.helper import c2_native_run_net +from caffe2.python.onnx.tests.test_utils import TestCase + + +class TestFrontendSSAConversion(TestCase): + def test_ssa(self): + X = np.random.randn(4, 2).astype(np.float32) + W = np.random.randn(3, 2).astype(np.float32) + b = np.random.randn(3).astype(np.float32) + s = np.random.randn(1).astype(np.float32) + np_result = X.dot(W.transpose()) + b + s + + net = caffe2_pb2.NetDef() + net.name = 'test-ssa' + net.external_input[:] = ['W', 'X', 'b', 's'] + net.op.extend([ + core.CreateOperator( + 'FC', + ['X', 'W', 'b'], + ['Y'] + ), + core.CreateOperator( + 'Add', + ['Y', 's'], + ['Y'], + broadcast=True, + ) + ]) + net.external_output[:] = ['Y'] + + init_net = caffe2_pb2.NetDef() + init_net.name = 'test-ssa-init' + init_net.op.extend([ + core.CreateOperator( + 'GivenTensorFill', + [], + ['W'], + values=W, + shape=W.shape, + ), + core.CreateOperator( + 'GivenTensorFill', + [], + ['b'], + values=b, + shape=b.shape, + ), + core.CreateOperator( + 'GivenTensorFill', + [], + ['s'], + values=s, + shape=s.shape, + ) + ]) + init_net.external_output[:] = ['W', 'b', 's'] + + _, orig_output = c2_native_run_net( + predict_net=net, + init_net=init_net, + inputs=[X]) + + value_info = {'X': (TensorProto.FLOAT, X.shape)} + c2_onnx.Caffe2Frontend._ssa_rewrite( + net, + init_net, + value_info) + + self.assertEqual(net.external_input, ['W_0', 'X_0', 'b_0', 's_0']) + self.assertEqual(net.op[0].input, ['X_0', 'W_0', 'b_0']) + self.assertEqual(net.op[0].output, ['Y_1']) + self.assertEqual(net.op[1].input, ['Y_1', 's_0']) + self.assertEqual(net.op[1].output, ['Y_2']) + self.assertEqual(net.external_output, ['Y_2']) + + self.assertEqual(init_net.external_input, []) + self.assertEqual(init_net.op[0].input, []) + self.assertEqual(init_net.op[0].output, ['W_0']) + self.assertEqual(init_net.op[1].input, []) + self.assertEqual(init_net.op[1].output, ['b_0']) + self.assertEqual(init_net.op[2].input, []) + self.assertEqual(init_net.op[2].output, ['s_0']) + self.assertEqual(init_net.external_output, ['W_0', 'b_0', 's_0']) + self.assertEqual(value_info, {'X_0': (TensorProto.FLOAT, X.shape)}) + + _, ssa_output = c2_native_run_net( + predict_net=net, + init_net=init_net, + inputs=[X]) + + self.assertSameOutputs(ssa_output, orig_output) + self.assertSameOutputs(ssa_output, [np_result]) diff --git a/caffe2/python/onnx/tests/test_utils.py b/caffe2/python/onnx/tests/test_utils.py new file mode 100644 index 0000000000..8d8a9ceeca --- /dev/null +++ b/caffe2/python/onnx/tests/test_utils.py @@ -0,0 +1,43 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.tests.test_utils + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest + +import numpy as np + + +class TestCase(unittest.TestCase): + def setUp(self): + np.random.seed(seed=0) + + def assertSameOutputs(self, outputs1, outputs2, decimal=7): + self.assertEqual(len(outputs1), len(outputs2)) + for o1, o2 in zip(outputs1, outputs2): + np.testing.assert_almost_equal(o1, o2, decimal=decimal) + + def add_test_case(name, test_func): + if not name.startswith('test_'): + raise ValueError('Test name must start with test_: {}'.format(name)) + if hasattr(self, name): + raise ValueError('Duplicated test name: {}'.format(name)) + setattr(self, name, test_func) diff --git a/caffe2/python/onnx/workspace.py b/caffe2/python/onnx/workspace.py new file mode 100644 index 0000000000..a115c91861 --- /dev/null +++ b/caffe2/python/onnx/workspace.py @@ -0,0 +1,80 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +## @package onnx +# Module caffe2.python.onnx.workspace + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import uuid + +from caffe2.python import workspace + + +class Workspace(object): + """ + An object representing a Caffe2 workspace. It is a context manager, + so you can say 'with workspace:' to use the represented workspace + as your global workspace. It also supports every method supported + by caffe2.python.workspace, but instead of running these operations + in the global workspace, it runs them in the workspace represented + by this object. When this object goes dead, the workspace (and all + nets and blobs within it) are freed. + + Why do we need this class? Caffe2's workspace model is very "global state" + oriented, in that there is always some ambient global workspace you are + working in which holds on to all of your networks and blobs. This class + makes it possible to work with workspaces more locally, and without + forgetting to deallocate everything in the end. + """ + def __init__(self): + # Caffe2 (apparently) doesn't provide any native method of generating + # a fresh, unused workspace, so we have to fake it by generating + # a unique ID and hoping it's not used already / will not be used + # directly in the future. + self.workspace_id = str(uuid.uuid4()) + # A stack, so that the context manager is reentrant. + self.workspace_stack = [] + + def __getattr__(self, attr): + def f(*args, **kwargs): + with self: + return getattr(workspace, attr)(*args, **kwargs) + return f + + def __enter__(self): + self.workspace_stack.append(workspace.CurrentWorkspace()) + workspace.SwitchWorkspace(self.workspace_id, create_if_missing=True) + + def __exit__(self, exc_type, exc_value, traceback): + w = self.workspace_stack.pop() + # Strictly speaking, create_if_missing here is unnecessary, since a user + # is not supposed to be allowed to destruct a workspace while we're in + # it. However, empirically, it has been observed that during abnormal + # shutdown, Caffe2 deletes its default workspace fairly early in the + # final calls to destructors. In this case, we may attempt to exit + # to a default workspace which no longer exists. create_if_missing=True + # will (harmlessly) recreate the workspace before we finally quit.) + workspace.SwitchWorkspace(w, create_if_missing=True) + + def __del__(self): + # NB: This is a 'self' call because we need to switch into the workspace + # we want to reset before we actually reset it. A direct call to + # workspace.ResetWorkspace() will reset the ambient workspace, which + # is not want we want. + self.ResetWorkspace() diff --git a/docker/jenkins/common/install_python.sh b/docker/jenkins/common/install_python.sh index 5d01999e82..d262d7ed88 100755 --- a/docker/jenkins/common/install_python.sh +++ b/docker/jenkins/common/install_python.sh @@ -146,3 +146,4 @@ pip install --no-cache-dir \ scikit-image \ tabulate \ virtualenv + |