summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2018-02-20 13:56:52 -0800
committerGitHub <noreply@github.com>2018-02-20 13:56:52 -0800
commitcc7e61c88d8250c31bd5cd4897323335bc550d3a (patch)
tree8afe634a8aba693a92344b0dc98c9cce5fb03651
parent7283d5194a71e5fb47baa7025f7d3874b8fab3c6 (diff)
downloadpytorch-cc7e61c88d8250c31bd5cd4897323335bc550d3a.tar.gz
pytorch-cc7e61c88d8250c31bd5cd4897323335bc550d3a.tar.bz2
pytorch-cc7e61c88d8250c31bd5cd4897323335bc550d3a.zip
Move onnx-caffe2 inside caffe2 (#1921)
* Move onnx-caffe2 inside caffe2 * Update to the lastest onnx-caffe2 and update jenkins env * Rename onnx_caffe2 to onnx * Add __init__.py to caffe2/python/onnx * Change CI check variable to JENKINS_URL * Cherrypick recent onnx-caffe2 update
-rw-r--r--caffe2/python/onnx/__init__.py0
-rw-r--r--caffe2/python/onnx/backend.py1159
-rw-r--r--caffe2/python/onnx/backend_rep.py77
-rw-r--r--caffe2/python/onnx/bin/conversion.py104
-rw-r--r--caffe2/python/onnx/error.py23
-rw-r--r--caffe2/python/onnx/frontend.py551
-rw-r--r--caffe2/python/onnx/helper.py157
-rw-r--r--caffe2/python/onnx/tests/caffe2_ref_test.py357
-rw-r--r--caffe2/python/onnx/tests/conversion_test.py244
-rw-r--r--caffe2/python/onnx/tests/helper_test.py45
-rw-r--r--caffe2/python/onnx/tests/onnx_backend_test.py50
-rw-r--r--caffe2/python/onnx/tests/optimize_onnx_test.py118
-rw-r--r--caffe2/python/onnx/tests/ssa_test.py122
-rw-r--r--caffe2/python/onnx/tests/test_utils.py43
-rw-r--r--caffe2/python/onnx/workspace.py80
-rwxr-xr-xdocker/jenkins/common/install_python.sh1
16 files changed, 3131 insertions, 0 deletions
diff --git a/caffe2/python/onnx/__init__.py b/caffe2/python/onnx/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/caffe2/python/onnx/__init__.py
diff --git a/caffe2/python/onnx/backend.py b/caffe2/python/onnx/backend.py
new file mode 100644
index 0000000000..95f8f59f0f
--- /dev/null
+++ b/caffe2/python/onnx/backend.py
@@ -0,0 +1,1159 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.backend
+
+"""Backend for running ONNX on Caffe2
+
+To run this, you will need to have Caffe2 installed as well.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import collections
+from subprocess import Popen, PIPE
+
+import caffe2
+from caffe2.python import core, workspace, rnn_cell, gru_cell
+from caffe2.python.model_helper import ModelHelper
+from caffe2.proto import caffe2_pb2
+import caffe2.python.utils
+import numpy as np
+import onnx
+from onnx import checker, GraphProto, TensorProto, AttributeProto, ModelProto
+import onnx.numpy_helper
+import onnx.defs
+import onnx.optimizer
+from onnx.backend.base import Backend, Device, DeviceType, namedtupledict
+
+from caffe2.python.onnx.workspace import Workspace
+from caffe2.python.onnx.backend_rep import Caffe2Rep
+from caffe2.python.onnx.helper import dummy_name
+
+import warnings
+
+def force_unicode(s):
+    """Return ``s`` decoded as UTF-8 when it has a ``decode`` method.
+
+    Objects without ``decode`` raise ``AttributeError`` on the attempt and
+    are handed back unchanged.
+    """
+    try:
+        text = s.decode('utf-8')
+    except AttributeError:
+        text = s
+    return text
+
+def get_device_option(device):
+    """Build the Caffe2 DeviceOption matching an ONNX backend ``device``.
+
+    Only ``DeviceType.CPU`` and ``DeviceType.CUDA`` are mapped; any other
+    ``device.type`` makes the table lookup raise ``KeyError``.
+    """
+    caffe2_device_types = {
+        DeviceType.CPU: caffe2_pb2.CPU,
+        DeviceType.CUDA: caffe2_pb2.CUDA,
+    }
+    c2_device_type = caffe2_device_types[device.type]
+    return core.DeviceOption(c2_device_type, device.device_id)
+
+
+class OnnxAttributes(dict):
+    """
+    Dictionary of ONNX attribute values keyed by attribute name, holding
+    converted Python objects rather than the protobuf representation.
+    """
+    @staticmethod
+    def from_onnx(args):
+        """Collect the AttributeProtos in ``args`` into an OnnxAttributes, keyed by name."""
+        attrs = OnnxAttributes()
+        attrs.update((arg.name, convertAttributeProto(arg)) for arg in args)
+        return attrs
+
+    def caffe2(self, kmap=lambda k: k):
+        """Yield one Caffe2 Argument per attribute whose mapped name is non-empty.
+
+        ``kmap`` translates an attribute name into the name used for the
+        Caffe2 argument; attributes that map to ``''`` are skipped.
+        """
+        for key, value in self.items():
+            if kmap(key) == '':
+                continue
+            yield caffe2.python.utils.MakeArgument(kmap(key), value)
+
+
+# TODO: Move this into ONNX main library
+def convertAttributeProto(onnx_arg):
+    """
+    Turn an ONNX AttributeProto into the Python value it carries.
+
+    Scalar fields (f, i, s, t) are checked first, in that order, then the
+    repeated fields (floats, ints, strings); repeated fields come back as
+    plain lists. Raises ValueError when none of them is populated.
+
+    NB: Tensor attribute gets returned as the straight proto.
+    """
+    for scalar_field in ('f', 'i', 's', 't'):
+        if onnx_arg.HasField(scalar_field):
+            return getattr(onnx_arg, scalar_field)
+
+    for repeated_field in ('floats', 'ints', 'strings'):
+        values = getattr(onnx_arg, repeated_field)
+        if len(values):
+            return list(values)
+
+    raise ValueError("Unsupported ONNX attribute: {}".format(onnx_arg))
+
+
+# TODO: Move this into ONNX main library
+class OnnxNode(object):
+    """
+    Python-friendly mirror of an ONNX NodeProto.
+
+    Instances are ordinary mutable objects, so a node can be edited towards
+    its Caffe2 form first and only translated into the Caffe2 protobuf
+    afterwards, instead of taking the proto apart and rebuilding it.
+    """
+    def __init__(self, node):
+        attrs = OnnxAttributes.from_onnx(node.attribute)
+        # "consumed_inputs" is lifted out of the attribute dict and kept in
+        # its own field (None when the node does not carry it).
+        consumed_inputs = attrs.pop("consumed_inputs", None)
+
+        self.name = str(node.name)
+        self.op_type = str(node.op_type)
+        self.attrs = attrs
+        self.consumed_inputs = consumed_inputs
+        self.inputs = list(node.input)
+        self.outputs = list(node.output)
+
+
+# Bundle returned by node converters such as _create_rnn/_create_lstm/_create_gru.
+# Field order matters for positional construction and unpacking:
+# (ops, init_ops, interface_blobs).
+Caffe2Ops = collections.namedtuple('Caffe2Ops', ['ops', 'init_ops', 'interface_blobs'])
+
+
+class Caffe2Backend(Backend):
+
+    # The greatest version of the ONNX operator set which we are aware of.
+    # Models whose version is larger than this will cause us to emit a warning
+    # that we are attempting to translate on a "best effort" basis.
+    #
+    # If you increase this, make SURE you cross-reference all BC-breaking
+    # changes from one version to the next, and any that you did not
+    # implement, mark as broken in _broken_operators
+    _known_opset_version = 3
+
+    # This dictionary will record operators which are KNOWN to be
+    # broken, so we give a good error message rather than do something
+    # bogus and then fail.
+    _broken_operators = {
+        # 'BrokenOp': version_it_was_broken_in
+    }
+
+    # Operators that are different between Caffe2 and
+    # ONNX but only in their name.
+    # In most cases, this should be empty - as the effort of ONNX is
+    # to unify the operator definitions.
+    # Each ONNX op type must be listed once: a repeated key in a dict
+    # literal is silently collapsed to the last occurrence.
+    _renamed_operators = {
+        'Caffe2ConvTranspose': 'ConvTranspose',
+        'GlobalMaxPool': 'MaxPool',
+        'GlobalAveragePool': 'AveragePool',
+        'Pad': 'PadImage',
+        'Neg': 'Negative',
+        'BatchNormalization': 'SpatialBN',
+        'InstanceNormalization': 'InstanceNorm',
+        'MatMul': 'BatchMatMul',
+        'Upsample': 'ResizeNearest',
+        'Identity': 'Copy',
+        'Equal': 'EQ',
+        'Less': 'LT',
+        'Greater': 'GT',
+        'Unsqueeze': 'ExpandDims',
+    }
+
+    # ONNX attribute name -> Caffe2 argument name.
+    # _global_renamed_attrs is keyed by attribute name only, while
+    # _per_op_renamed_attrs is keyed by ONNX op type first.
+    # An empty target name ('') is the value OnnxAttributes.caffe2 treats
+    # as "do not emit this attribute".
+    _global_renamed_attrs = {'kernel_shape': 'kernels'}
+    _per_op_renamed_attrs = {
+        'Squeeze': {'axes': 'dims'},
+        'Unsqueeze': {'axes': 'dims'},
+        'Transpose': {'perm': 'axes'},
+        'Upsample': {'mode': ''},
+        'ConvTranspose': {'output_padding': 'adjs'},
+        'Selu': {'gamma': 'scale'},
+    }
+
+    # operators whose behavior is different beyond renaming
+    # the value is the NAME of a classmethod on this class with the
+    # signature (init_model, pred_model, n, opset_version) that converts
+    # the ONNX node into Caffe2 operator(s)
+    _special_operators = {
+        'Constant': '_create_constant',
+        'Conv': '_create_conv_pool_op_base',
+        'AveragePool': '_create_conv_pool_op_base',
+        'GlobalAveragePool': '_create_conv_pool_op_base',
+        'GlobalMaxPool': '_create_conv_pool_op_base',
+        'MaxPool': '_create_conv_pool_op_base',
+        'Reshape': '_create_reshape',
+        'Gather': '_create_gather',
+        'Gemm': '_create_gemm',
+        'Pad': '_create_pad',
+        'Concat': '_create_concat',
+        'LogSoftmax': '_create_logsoftmax',
+        'Slice': '_create_slice',
+        'LSTM': '_create_lstm',
+        'GRU': '_create_gru',
+        'RNN': '_create_rnn',
+        'Sqrt': '_create_sqrt',
+        'Reciprocal': '_create_reciprocal',
+    }
+
+ # NB: By default, you will use the LATEST definition of the operator,
+ # so this interface MAY make BC-breaking changes. Specify an
+ # opset_version if you don't want this to version.
+    @classmethod
+    def run_node(cls, node, inputs, device='CPU', opset_version=_known_opset_version):
+        """Run a single ONNX node in a scratch workspace and return its outputs.
+
+        ``inputs`` is either a dict of blob name -> value, or a sequence
+        matched positionally against ``node.input`` (lengths must agree).
+        A falsy ``opset_version`` falls back to ``cls._known_opset_version``.
+        Returns a namedtuple-like 'Outputs' object with one field per entry
+        of ``node.output``.
+        """
+        super(Caffe2Backend, cls).run_node(node, inputs, device)
+
+        device_option = get_device_option(Device(device))
+        with Workspace(), core.DeviceScope(device_option):  # temporary!
+            if isinstance(inputs, dict):
+                feeds = inputs.items()
+            else:
+                assert len(node.input) == len(inputs), "{}: expected {} but got {}".format(
+                    node.op_type, len(node.input), len(inputs))
+                feeds = zip(node.input, inputs)
+            for blob_name, blob_value in feeds:
+                workspace.FeedBlob(blob_name, blob_value)
+
+            cls._inplace_rewrite([node])
+            effective_opset = opset_version or cls._known_opset_version
+            # NOTE(review): Caffe2Ops declares its fields as
+            # (ops, init_ops, interface_blobs); if _onnx_node_to_caffe2_op
+            # returns that tuple, the two names below are swapped relative
+            # to it -- confirm against that method.
+            init_ops, ops, _ = cls._onnx_node_to_caffe2_op(
+                None, None, node, effective_opset)
+            all_ops = init_ops + ops
+            for c2_op in all_ops:
+                c2_op.device_option.CopyFrom(device_option)
+            workspace.RunOperatorsOnce(all_ops)
+
+            output_values = []
+            for blob_name in node.output:
+                output_values.append(workspace.FetchBlob(blob_name))
+            return namedtupledict('Outputs', node.output)(*output_values)
+
+    @classmethod
+    def _create_tensor_filling_op(cls, onnx_tensor, name=None):
+        """
+        Build the Caffe2 GivenTensor*Fill operator that materialises an
+        ONNX TensorProto.
+
+        The operator writes to blob `name`, falling back to the tensor's own
+        name; one of the two must be non-empty. It carries two arguments, in
+        this order: "values" (flattened tensor contents) and "shape" (the
+        tensor dims). Raises RuntimeError for data types with no fill op.
+        """
+        name = name or onnx_tensor.name
+        assert name
+
+        # (accepted ONNX data types, Caffe2 op type, field of the "values" argument)
+        fill_specs = (
+            ((TensorProto.FLOAT,), 'GivenTensorFill', 'floats'),
+            ((TensorProto.DOUBLE,), 'GivenTensorDoubleFill', 'floats'),
+            ((TensorProto.INT64, TensorProto.UINT32), 'GivenTensorInt64Fill', 'ints'),
+            ((TensorProto.UINT8,
+              TensorProto.INT8,
+              TensorProto.UINT16,
+              TensorProto.INT16,
+              TensorProto.INT32), 'GivenTensorIntFill', 'ints'),
+            ((TensorProto.BOOL,), 'GivenTensorBoolFill', 'ints'),
+            ((TensorProto.STRING,), 'GivenTensorStringFill', 'strings'),
+        )
+        for data_types, fill_op_type, values_field in fill_specs:
+            if onnx_tensor.data_type in data_types:
+                break
+        else:
+            raise RuntimeError(
+                "unrecognized tensor type {}".format(onnx_tensor.data_type))
+
+        c2_op = caffe2_pb2.OperatorDef()
+        c2_op.type = fill_op_type
+
+        c2_values = c2_op.arg.add()
+        c2_values.name = "values"
+        if values_field == 'strings':
+            flat_values = onnx_tensor.string_data
+        else:
+            # Use the onnx.numpy_helper because the data may be raw
+            flat_values = onnx.numpy_helper.to_array(onnx_tensor).flatten().tolist()
+        getattr(c2_values, values_field).extend(flat_values)
+
+        c2_shape = c2_op.arg.add()
+        c2_shape.name = "shape"
+        c2_shape.ints.extend(onnx_tensor.dims)
+
+        c2_op.output.append(name)
+
+        return c2_op
+
+    @classmethod
+    def _create_constant(cls, init_model, pred_model, n, opset_version):
+        """Convert a Constant node into a fill op that writes its "value" tensor to the node's single output."""
+        assert len(n.outputs) == 1
+        value_tensor = n.attrs["value"]
+        output_name = n.outputs[0]
+        return cls._create_tensor_filling_op(value_tensor, output_name)
+
+    @classmethod
+    def _create_gather(cls, init_model, pred_model, n, opset_version):
+        """Convert an ONNX Gather node.
+
+        axis 0 (the default) becomes Caffe2 "Gather" and axis 1 becomes
+        "BatchGather"; any other axis raises ValueError.
+        """
+        (A, B) = n.inputs
+        (Y, ) = n.outputs
+        axis = n.attrs.get('axis', 0)
+
+        if axis == 0:
+            return core.CreateOperator("Gather", [A, B], [Y])
+        elif axis == 1:
+            return core.CreateOperator("BatchGather", [A, B], [Y])
+        # The two message fragments used to be joined without a separating
+        # space, producing "...0 or 1,whereas axis is ...".
+        raise ValueError(
+            'Caffe2 only supports Gather with axis being 0 or 1, ' +
+            'whereas axis is ' + str(axis))
+
+    @classmethod
+    def _create_logsoftmax(cls, init_model, pred_model, n, opset_version):
+        """Convert LogSoftmax into a Softmax op followed by a Log op.
+
+        The Softmax result goes through a freshly named intermediate blob;
+        `axis` defaults to 1.
+        """
+        # NB: this implementation is not backward stable.
+        (data,) = n.inputs
+        (result,) = n.outputs
+        softmax_axis = n.attrs.get('axis', 1)
+        softmax_blob = dummy_name()
+        return [
+            core.CreateOperator('Softmax', [data], [softmax_blob], axis=softmax_axis),
+            core.CreateOperator('Log', [softmax_blob], [result]),
+        ]
+
+    @classmethod
+    def _create_gemm(cls, init_model, pred_model, n, opset_version):
+        """Convert an ONNX Gemm node into a list of Caffe2 operators.
+
+        A (by alpha) and C (by beta) are pre-scaled with Scale ops when the
+        factor differs from 1. When A is not transposed, B is transposed and
+        broadcast is set, a single FC op is emitted; otherwise MatMul
+        followed by Add.
+        """
+        (A, B, C) = n.inputs
+        (Y,) = n.outputs
+        attrs = n.attrs
+        ops = []
+
+        def scaled(blob, factor):
+            # Emit a Scale op only for a non-unit factor; return the blob to use.
+            if factor != 1:
+                scaled_blob = dummy_name()
+                ops.append(core.CreateOperator('Scale', [blob], [scaled_blob], scale=factor))
+                return scaled_blob
+            return blob
+
+        A = scaled(A, attrs.get('alpha', 1.))
+        C = scaled(C, attrs.get('beta', 1.))
+
+        trans_a = attrs.get('transA', 0)
+        trans_b = attrs.get('transB', 0)
+        broadcast = attrs.get('broadcast', 0)
+
+        if not trans_a and trans_b and broadcast:
+            ops.append(core.CreateOperator('FC', [A, B, C], [Y]))
+            return ops
+
+        product = dummy_name()
+        ops.append(core.CreateOperator(
+            'MatMul', [A, B], [product], trans_a=trans_a, trans_b=trans_b))
+        ops.append(core.CreateOperator(
+            'Add', [product, C], [Y], broadcast=broadcast))
+        return ops
+
+ @classmethod
+ def _rnn_shape_inference(cls, init_model, pred_model, n, input_blob, W):
+ """
+ Best-effort lookup of the input size of the recurrent node `n`.
+
+ Tries, in order: the declared shape of weight `W` among the init
+ model's graph inputs, the declared shape of `input_blob` among the
+ predict model's graph inputs, and finally a walk backwards through
+ the producers of the node's first input. Returns None when the walk
+ reaches a blob that does not have exactly one producing node.
+ """
+ # ad-hoc, informally-specified, bug-ridden, slow
+ # implementation of shape inference
+
+ # if the weight matrices are directly provided as
+ # initializers, their dimensions should be available in the
+ # init net model.
+ for x in init_model.graph.input:
+ if x.name == W:
+ return x.type.tensor_type.shape.dim[1].dim_value
+
+ # otherwise, assume that the input_blob is either a direct
+ # graph input, or another rnn op of the same type. This
+ # matches the pattern produced by exporting from pytorch
+ # (where the weight matrices are unusable for this purpose due
+ # to reshaping operations that lose shape information).
+ for x in pred_model.graph.input:
+ if x.name == input_blob:
+ return x.type.tensor_type.shape.dim[2].dim_value
+
+ # Walk backwards from `n` along the first input of each node.
+ curr = n
+ while True:
+ # A Gather fed directly by a graph input: use dim 1 of that input.
+ for x in pred_model.graph.input:
+ if x.name == curr.inputs[0] and curr.op_type == 'Gather':
+ return x.type.tensor_type.shape.dim[1].dim_value
+ # Producers of the current first input, matched on their FIRST output only.
+ # Every graph node is re-wrapped as an OnnxNode on each step of the walk.
+ prev = [x for x in map(OnnxNode, pred_model.graph.node) if x.outputs[0] == curr.inputs[0]]
+ if len(prev) != 1:
+ # zero or several producers: give up (implicit None)
+ return
+ prev = prev[0]
+ # An upstream node of the same op type: its hidden_size is the answer.
+ if prev.op_type == n.op_type:
+ return prev.attrs['hidden_size']
+ curr = prev
+
+ @classmethod
+ def _create_rnn(cls, init_model, pred_model, n, opset_version):
+ """
+ Convert an ONNX RNN node into Caffe2 operators via rnn_cell.BasicRNN.
+
+ Requires both models (asserted). Only 'forward' and 'bidirectional'
+ directions are accepted, and any attribute other than hidden_size,
+ activations and direction trips an assertion. The node must have
+ exactly six inputs, since they are unpacked positionally.
+
+ Returns Caffe2Ops(predict ops, init ops, external inputs of the
+ predict net). Raises RuntimeError when the input size cannot be
+ inferred.
+ """
+ assert init_model is not None, "cannot convert RNNs without access to the full model"
+ assert pred_model is not None, "cannot convert RNNs without access to the full model"
+
+ attrs = dict(n.attrs) # make a copy, which is safe to mutate
+ hidden_size = attrs.pop('hidden_size')
+ # only the first entry of 'activations' is used
+ activation = force_unicode(attrs.pop('activations', ('tanh',))[0])
+ direction = force_unicode(attrs.pop('direction', 'forward'))
+ # whatever is still left in attrs is an attribute this converter does not handle
+ assert not attrs, "unsupported RNN attributes: " + str(attrs.keys())
+ assert direction in ['forward', 'bidirectional'], "unsupported backwards RNN"
+
+ input_blob, W, R, B, sequence_lens, initial_h = n.inputs
+
+ # an empty input name is turned into None before being passed on
+ if sequence_lens == "":
+ sequence_lens = None
+
+ input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
+ if input_size is None:
+ raise RuntimeError("best-effort shape inference for RNN input failed")
+
+ # weight/bias/state slicing goes into init_net, the recurrence into pred_mh
+ init_net = core.Net("init-net")
+ pred_mh = ModelHelper()
+
+ def make_rnn(direction_offset):
+ # direction_offset is 0 for the forward pass and 1 for the reversed pass
+ # (see the calls below); it selects which slice of W/R/B/initial_h is used.
+ name = dummy_name()
+
+ # input and recurrence biases are squashed together in
+ # onnx but not in caffe2
+
+ bias_offset = 2 * direction_offset * hidden_size
+ init_net.Slice(B, name + "/i2h_b",
+ starts=[bias_offset + 0 * hidden_size],
+ ends =[bias_offset + 1 * hidden_size])
+ init_net.Slice(B, name + "/gates_t_b",
+ starts=[bias_offset + 1 * hidden_size],
+ ends =[bias_offset + 2 * hidden_size])
+
+ weight_offset = direction_offset * hidden_size
+ init_net.Slice(W, name + '/i2h_w',
+ starts=[weight_offset + 0 * hidden_size, 0],
+ ends =[weight_offset + 1 * hidden_size,-1])
+ init_net.Slice(R, name + '/gates_t_w',
+ starts=[weight_offset + 0 * hidden_size, 0],
+ ends =[weight_offset + 1 * hidden_size,-1])
+
+ initial_h_sliced = name + '/initial_h'
+ init_net.Slice(initial_h, initial_h_sliced,
+ starts=[direction_offset + 0, 0, 0],
+ ends =[direction_offset + 1,-1,-1])
+
+ # the reversed direction runs on a sequence-reversed copy of the input
+ if direction_offset == 1:
+ input = pred_mh.net.ReversePackedSegs(
+ [input_blob, sequence_lens], name + "/input-reversed")
+ else:
+ input = input_blob
+
+ hidden_t_all, hidden_t_last = rnn_cell.BasicRNN(
+ pred_mh,
+ input,
+ sequence_lens,
+ [initial_h_sliced],
+ input_size,
+ hidden_size,
+ name,
+ drop_states=True,
+ forward_only=True,
+ activation=activation
+ )
+
+ # reverse the per-step outputs again for the reversed direction
+ if direction_offset == 1:
+ hidden_t_all = pred_mh.net.ReversePackedSegs(
+ [hidden_t_all, sequence_lens], name + "/output-reversed")
+
+ return hidden_t_all, hidden_t_last
+
+ if direction == 'forward':
+ hidden_t_all, hidden_t_last = make_rnn(0)
+ # rename the generated blobs to the node's declared outputs by cloning the net
+ pred_mh.net = pred_mh.net.Clone(
+ "dummy-clone-net",
+ blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] }
+ )
+ elif direction == 'bidirectional':
+ hidden_t_all_f, hidden_t_last_f = make_rnn(0)
+ hidden_t_all_b, hidden_t_last_b = make_rnn(1)
+ # join forward and reversed results along axis 2; the second Concat
+ # output gets a throwaway name
+ pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
+ [n.outputs[0], dummy_name()], axis=2)
+ pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
+ [n.outputs[1], dummy_name()], axis=2)
+
+ return Caffe2Ops(list(pred_mh.Proto().op),
+ list(init_net.Proto().op),
+ list(pred_mh.Proto().external_input))
+
+ @classmethod
+ def _create_lstm(cls, init_model, pred_model, n, opset_version):
+ """
+ Convert an ONNX LSTM node into Caffe2 operators via rnn_cell.LSTM.
+
+ Requires both models (asserted). Only 'forward' and 'bidirectional'
+ directions are accepted, and any attribute other than hidden_size and
+ direction trips an assertion. The node must have exactly seven
+ inputs, since they are unpacked positionally.
+
+ Returns Caffe2Ops(predict ops, init ops, external inputs of the
+ predict net). Raises RuntimeError when the input size cannot be
+ inferred.
+ """
+ assert init_model is not None, "cannot convert LSTMs without access to the full model"
+ assert pred_model is not None, "cannot convert LSTMs without access to the full model"
+
+ attrs = dict(n.attrs) # make a copy, which is safe to mutate
+ hidden_size = attrs.pop('hidden_size')
+ direction = force_unicode(attrs.pop('direction', 'forward'))
+ # whatever is still left in attrs is an attribute this converter does not handle
+ assert not attrs, "unsupported LSTM attributes: " + str(attrs.keys())
+ assert direction in ['forward', 'bidirectional'], "unsupported backwards LSTM"
+
+ input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs
+
+ # an empty input name is turned into None before being passed on
+ if sequence_lens == "":
+ sequence_lens = None
+
+ input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
+ if input_size is None:
+ raise RuntimeError("best-effort shape inference for LSTM input failed")
+
+ # weight/bias/state slicing goes into init_net, the recurrence into pred_mh
+ init_net = core.Net("init-net")
+ pred_mh = ModelHelper()
+
+ def make_lstm(direction_offset):
+ # direction_offset is 0 for the forward pass and 1 for the reversed pass
+ # (see the calls below); it selects which slice of W/R/B and of the
+ # initial states is used.
+ name = dummy_name()
+
+ # input and recurrence biases are squashed together in
+ # onnx but not in caffe2
+
+ # per direction: 4 * hidden_size input biases, then 4 * hidden_size recurrence biases
+ bias_offset = 8 * direction_offset * hidden_size
+ Bi = init_net.Slice(B, name + "_bias_i2h",
+ starts=[bias_offset + 0 * hidden_size],
+ ends =[bias_offset + 4 * hidden_size])
+ Br = init_net.Slice(B, name + "_bias_gates",
+ starts=[bias_offset + 4 * hidden_size],
+ ends =[bias_offset + 8 * hidden_size])
+
+ weight_offset = 4 * direction_offset * hidden_size
+ W_ = init_net.Slice(W, name + '/i2h_w_pre',
+ starts=[weight_offset + 0 * hidden_size, 0],
+ ends =[weight_offset + 4 * hidden_size,-1])
+ R_ = init_net.Slice(R, name + '/gates_t_w_pre',
+ starts=[weight_offset + 0 * hidden_size, 0],
+ ends =[weight_offset + 4 * hidden_size,-1])
+
+ # caffe2 has a different order from onnx. We need to rearrange
+ # i o f c -> i f o c
+ # each entry: (source blob, target blob suffix, (start, end) pairs for
+ # the dimensions after the first -- empty for the 1-D biases)
+ reforms = ((W_, 'i2h_w', [(0, -1)]),
+ (R_, 'gates_t_w', [(0, -1)]),
+ (Bi, 'i2h_b' , []),
+ (Br, 'gates_t_b', []))
+ for name_from, name_to, extra_dims in reforms:
+ xi, xo, xf, xc = [name_from + suffix for suffix in ("_i", "_o", "_f", "_c")]
+ # cut the four gate blocks out in ONNX order ...
+ for i, x in enumerate([xi, xo, xf, xc]):
+ dim0 = i * hidden_size, (i+1) * hidden_size
+ starts, ends = zip(dim0, *extra_dims)
+ init_net.Slice(name_from, x, starts=starts, ends=ends)
+ # ... and glue them back together in Caffe2 order
+ init_net.Concat([xi, xf, xo, xc], ['%s/%s' % (name, name_to), dummy_name()], axis=0)
+
+ initial_h_sliced = name + '/initial_h'
+ init_net.Slice(initial_h, initial_h_sliced,
+ starts=[direction_offset + 0, 0, 0],
+ ends =[direction_offset + 1,-1,-1])
+ initial_c_sliced = name + '/initial_c'
+ init_net.Slice(initial_c, initial_c_sliced,
+ starts=[direction_offset + 0, 0, 0],
+ ends =[direction_offset + 1,-1,-1])
+
+ # the reversed direction runs on a sequence-reversed copy of the input
+ if direction_offset == 1:
+ input = pred_mh.net.ReversePackedSegs(
+ [input_blob, sequence_lens], name + "/input-reversed")
+ else:
+ input = input_blob
+
+ # the two cell-state results and `params` are not used further in this function
+ hidden_t_all, hidden_t_last, _, _, params = rnn_cell.LSTM(
+ pred_mh,
+ input,
+ sequence_lens,
+ [initial_h_sliced, initial_c_sliced],
+ input_size,
+ hidden_size,
+ name,
+ drop_states=True,
+ forward_only=True,
+ return_params=True
+ )
+
+ # reverse the per-step outputs again for the reversed direction
+ if direction_offset == 1:
+ hidden_t_all = pred_mh.net.ReversePackedSegs(
+ [hidden_t_all, sequence_lens], name + "/output-reversed")
+
+ return hidden_t_all, hidden_t_last
+
+ if direction == 'forward':
+ hidden_t_all, hidden_t_last = make_lstm(0)
+ # rename the generated blobs to the node's declared outputs by cloning the net
+ pred_mh.net = pred_mh.net.Clone(
+ "dummy-clone-net",
+ blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] }
+ )
+ elif direction == 'bidirectional':
+ hidden_t_all_f, hidden_t_last_f = make_lstm(0)
+ hidden_t_all_b, hidden_t_last_b = make_lstm(1)
+ # join forward and reversed results along axis 2; the second Concat
+ # output gets a throwaway name
+ pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
+ [n.outputs[0], dummy_name()], axis=2)
+ pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
+ [n.outputs[1], dummy_name()], axis=2)
+
+ return Caffe2Ops(list(pred_mh.Proto().op),
+ list(init_net.Proto().op),
+ list(pred_mh.Proto().external_input))
+
+ @classmethod
+ def _create_gru(cls, init_model, pred_model, n, opset_version):
+ """
+ Convert an ONNX GRU node into Caffe2 operators via gru_cell.GRU.
+
+ Requires both models (asserted). Only 'forward' and 'bidirectional'
+ directions are accepted, and any attribute other than hidden_size,
+ linear_before_reset and direction trips an assertion. The node must
+ have exactly six inputs, since they are unpacked positionally.
+
+ Returns Caffe2Ops(predict ops, init ops, external inputs of the
+ predict net). Raises RuntimeError when the input size cannot be
+ inferred.
+ """
+ assert init_model is not None, "cannot convert GRUs without access to the full model"
+ assert pred_model is not None, "cannot convert GRUs without access to the full model"
+
+ attrs = dict(n.attrs) # make a copy, which is safe to mutate
+ hidden_size = attrs.pop('hidden_size')
+ linear_before_reset = attrs.pop('linear_before_reset', 0)
+ direction = force_unicode(attrs.pop('direction', 'forward'))
+ # whatever is still left in attrs is an attribute this converter does not handle
+ assert not attrs, "unsupported GRU attributes: " + str(attrs.keys())
+ assert direction in ['forward', 'bidirectional'], "unsupported backwards GRU"
+
+ input_blob, W, R, B, sequence_lens, initial_h = n.inputs
+
+ # an empty input name is turned into None before being passed on
+ if sequence_lens == "":
+ sequence_lens = None
+
+ input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
+ if input_size is None:
+ raise RuntimeError("best-effort shape inference for GRU input failed")
+
+ # weight/bias/state slicing goes into init_net, the recurrence into pred_mh
+ init_net = core.Net("init-net")
+ pred_mh = ModelHelper()
+
+ def make_gru(direction_offset):
+ # direction_offset is 0 for the forward pass and 1 for the reversed pass
+ # (see the calls below); it selects which slice of W/R/B/initial_h is used.
+ name = dummy_name()
+
+ # input and recurrence biases are squashed together in
+ # onnx but not in caffe2
+
+ # per direction: 3 * hidden_size input biases, then 3 * hidden_size recurrence biases
+ bias_offset = 6 * direction_offset * hidden_size
+ Bi = init_net.Slice(B, name + "_bias_i2h",
+ starts=[bias_offset + 0 * hidden_size],
+ ends =[bias_offset + 3 * hidden_size])
+ Br = init_net.Slice(B, name + "_bias_gates",
+ starts=[bias_offset + 3 * hidden_size],
+ ends =[bias_offset + 6 * hidden_size])
+
+ weight_offset = 3 * direction_offset * hidden_size
+ W_ = init_net.Slice(W, name + '/i2h_w_pre',
+ starts=[weight_offset + 0 * hidden_size, 0],
+ ends =[weight_offset + 3 * hidden_size,-1])
+ R_ = init_net.Slice(R, name + '/gates_t_w_pre',
+ starts=[weight_offset + 0 * hidden_size, 0],
+ ends =[weight_offset + 3 * hidden_size,-1])
+
+ # caffe2 has a different order from onnx. We need to rearrange
+ # z r h -> r z h
+ # each entry: (source blob, target blob suffix, whether to also emit the
+ # re-concatenated blob, (start, end) pairs for the dimensions after the
+ # first -- empty for the 1-D biases)
+ reforms = ((W_, 'i2h_w', True, [(0,-1)]),
+ (R_, 'gate_t_w', False, [(0,-1)]),
+ (Bi, 'i2h_b', True, []),
+ (Br, 'gate_t_b', False, []))
+ for name_from, name_to, do_concat, extra_dims in reforms:
+ # per-gate blobs are named "<name>/<update|reset|output>_<name_to>"
+ xz, xr, xh = ['%s/%s_%s' % (name, prefix, name_to) for prefix in ('update', 'reset', 'output')]
+ for i, x in enumerate([xz, xr, xh]):
+ dim0 = i * hidden_size, (i+1) * hidden_size
+ starts, ends = zip(dim0, *extra_dims)
+ init_net.Slice(name_from, x, starts=starts, ends=ends)
+ # only the input-to-hidden tensors are additionally concatenated (in r z h
+ # order); the recurrence tensors stay as three separate per-gate blobs
+ if do_concat:
+ init_net.Concat([xr, xz, xh], ['%s/%s' % (name, name_to), dummy_name()], axis=0)
+
+ initial_h_sliced = name + '/initial_h'
+ init_net.Slice(initial_h, initial_h_sliced,
+ starts=[direction_offset + 0, 0, 0],
+ ends =[direction_offset + 1,-1,-1])
+
+ # the reversed direction runs on a sequence-reversed copy of the input
+ if direction_offset == 1:
+ input = pred_mh.net.ReversePackedSegs(
+ [input_blob, sequence_lens], name + "/input-reversed")
+ else:
+ input = input_blob
+
+ hidden_t_all, hidden_t_last = gru_cell.GRU(
+ pred_mh,
+ input,
+ sequence_lens,
+ [initial_h_sliced],
+ input_size,
+ hidden_size,
+ name,
+ drop_states=True,
+ forward_only=True,
+ linear_before_reset=linear_before_reset
+ )
+
+ # reverse the per-step outputs again for the reversed direction
+ if direction_offset == 1:
+ hidden_t_all = pred_mh.net.ReversePackedSegs(
+ [hidden_t_all, sequence_lens], name + "/output-reversed")
+
+ return hidden_t_all, hidden_t_last
+
+ if direction == 'forward':
+ hidden_t_all, hidden_t_last = make_gru(0)
+ # rename the generated blobs to the node's declared outputs by cloning the net
+ pred_mh.net = pred_mh.net.Clone(
+ "dummy-clone-net",
+ blob_remap={ hidden_t_all: n.outputs[0], hidden_t_last: n.outputs[1] }
+ )
+ elif direction == 'bidirectional':
+ hidden_t_all_f, hidden_t_last_f = make_gru(0)
+ hidden_t_all_b, hidden_t_last_b = make_gru(1)
+ # join forward and reversed results along axis 2; the second Concat
+ # output gets a throwaway name
+ pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
+ [n.outputs[0], dummy_name()], axis=2)
+ pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
+ [n.outputs[1], dummy_name()], axis=2)
+
+ return Caffe2Ops(list(pred_mh.Proto().op),
+ list(init_net.Proto().op),
+ list(pred_mh.Proto().external_input))
+
+    @classmethod
+    def _create_pad(cls, init_model, pred_model, n, opset_version):
+        """Convert an ONNX Pad node, keeping only the padding of the last two dims.
+
+        The pads attribute ('paddings' before opset 2, 'pads' afterwards)
+        must have 8 entries with zeros for the first two dimensions, and no
+        negative entry; otherwise ValueError is raised. The attribute list
+        is rewritten in place to its four spatial entries before the common
+        conversion runs.
+        """
+        pads_key = 'paddings' if opset_version < 2 else 'pads'
+        pads = n.attrs[pads_key]
+
+        # first two dim is for batch and channel
+        batch_channel_pads = pads[:2] + pads[4:6]
+        if len(pads) != 8 or set(batch_channel_pads) != {0}:
+            raise ValueError('Caffe2 only supports padding 2D Tensor, whereas padding is ' + str(pads))
+
+        # Guard the invalid (negative) pads attribute.
+        if min(pads) < 0:
+            raise ValueError('ONNX does not support negative pads in Pad, but get {}.'.format(pads))
+
+        # in-place slice assignment, so n.attrs sees the shortened list
+        pads[:] = pads[2:4] + pads[6:8]
+        return cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
+
+    @classmethod
+    def _create_concat(cls, init_model, pred_model, n, opset_version):
+        """Convert Concat through the common path, then add a throwaway second output."""
+        # TODO: Caffe2 Concat has an extra output. It should be only
+        # used when doing training, so we should change Caffe2 to allow
+        # 1 output.
+        c2_op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
+        assert len(c2_op.output) == 1
+        extra_output = dummy_name()
+        c2_op.output.append(extra_output)
+        return c2_op
+
+ @classmethod
+ def _create_slice(cls, init_model, pred_model, n, opset_version):
+ """
+ Convert an ONNX Slice node into a list of Caffe2 operators.
+
+ The common conversion produces the Slice op; its starts/ends/axes
+ arguments are then removed and replaced by two runtime-built int32
+ tensors that are passed as extra inputs. The helper ops that build
+ those tensors come first in the returned list, the Slice op last.
+ """
+ op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
+ args = {arg.name: arg for arg in op.arg}
+ starts_vals = np.array(
+ args.pop('starts').ints, dtype=np.int64).tolist()
+ # negative ends are shifted down by one; non-negative ends are kept
+ # NOTE(review): presumably compensates for a different end convention
+ # in Caffe2's Slice -- confirm against the operator docs
+ ends_vals = np.array(
+ [i - 1 if i < 0 else i for i in args.pop('ends').ints],
+ dtype=np.int64).tolist()
+ if 'axes' in args:
+ axes_vals = np.array(
+ args.pop('axes').ints, dtype=np.int32).tolist()
+ else:
+ # no axes given: starts/ends apply to axes 0..len(starts)-1
+ ndims = len(starts_vals)
+ axes_vals = np.array(range(ndims), dtype=np.int32).tolist()
+
+ # the converted op is expected to have exactly one input here
+ data, = op.input
+ ops = []
+
+ shape_tensor = dummy_name()
+ ops.append(core.CreateOperator(
+ 'Shape',
+ [data],
+ [shape_tensor]
+ ))
+
+ axes_tensor = dummy_name()
+ ops.extend([
+ core.CreateOperator(
+ 'GivenTensorIntFill',
+ [],
+ [axes_tensor],
+ shape=[len(axes_vals)],
+ values=axes_vals,
+ ),
+ ])
+
+ # starts: constant tensor of 0 derived from shape_tensor, then the
+ # given values scattered to the positions listed in axes_tensor
+ # (assumes ConstantFill sizes its output like its input -- TODO confirm)
+ starts_vals_tensor = dummy_name()
+ starts_tensor = dummy_name()
+ casted_starts_tensor = dummy_name()
+ ops.extend([
+ core.CreateOperator(
+ 'GivenTensorInt64Fill',
+ [],
+ [starts_vals_tensor],
+ shape=[len(starts_vals)],
+ values=starts_vals,
+ ),
+ core.CreateOperator(
+ 'ConstantFill',
+ [shape_tensor],
+ [starts_tensor],
+ dtype=caffe2_pb2.TensorProto.INT64,
+ value=0,
+ ),
+ core.CreateOperator(
+ 'ScatterAssign',
+ [starts_tensor, axes_tensor, starts_vals_tensor],
+ [starts_tensor],
+ ),
+ # Slice only accepts starts as int
+ core.CreateOperator(
+ 'Cast',
+ [starts_tensor],
+ [casted_starts_tensor],
+ to=caffe2_pb2.TensorProto.INT32,
+ ),
+ ])
+
+ # ends: same construction as starts, but the fill value is -1
+ ends_vals_tensor = dummy_name()
+ ends_tensor = dummy_name()
+ casted_ends_tensor = dummy_name()
+ ops.extend([
+ core.CreateOperator(
+ 'GivenTensorInt64Fill',
+ [],
+ [ends_vals_tensor],
+ shape=[len(ends_vals)],
+ values=ends_vals,
+ ),
+ core.CreateOperator(
+ 'ConstantFill',
+ [shape_tensor],
+ [ends_tensor],
+ dtype=caffe2_pb2.TensorProto.INT64,
+ value=-1,
+ ),
+ core.CreateOperator(
+ 'ScatterAssign',
+ [ends_tensor, axes_tensor, ends_vals_tensor],
+ [ends_tensor],
+ ),
+ # Slice only accepts ends as int
+ core.CreateOperator(
+ 'Cast',
+ [ends_tensor],
+ [casted_ends_tensor],
+ to=caffe2_pb2.TensorProto.INT32,
+ ),
+ ])
+
+ # re-point the Slice op at the runtime tensors and keep only the
+ # arguments that were not popped above
+ op.input[:] = [data, casted_starts_tensor, casted_ends_tensor]
+ del op.arg[:]
+ op.arg.extend(args.values())
+ ops.append(op)
+
+ return ops
+
+ # Note [Caffe2 ConvPoolOpBase]
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # To understand what is going on here, we have to talk a little bit about
+ # Caffe2's internals.
+ #
+ # First, it's important to know that all of Caffe2's pooling and convolution
+ # operators inherit from "ConvPoolOpBase", which is an abstract class that
+ # defines all of the attributes (kernels, dilations, strides, etc) which one
+ # sees on these operators. Unfortunately, Caffe2's documentation generator
+ # doesn't know how to handle cases like this, so for example, if you look at
+ # the docs for MaxPool at <https://caffe2.ai/docs/operators-catalogue.html#maxpool>
+ # you won't see any of the attributes. You have to go source diving to
+ # find the information; in particular, you want to look at:
+ # https://github.com/caffe2/caffe2/blob/master/caffe2/operators/conv_pool_op_base.h
+ # This class handles *global* pooling as well.
+ #
+ # Second, it's important to know what Caffe2 expects for padding, which can
+ # be somewhat difficult to understand from the code because Caffe2 handles
+ # both singular/pluralized spellings of padding, and there is also legacy
+ # padding business. The short version of the story is that, for NON-legacy
+ # padding (which is what we want to output), padding is expected to be
+ # *twice* the size of kernels. So if you have a 2D convolution, Caffe2
+ # will accept two values in 'kernels', but FOUR values in 'pads';
+ # furthermore, this is *mandatory.*
+ #
+ # Finally, ConvPoolOpBase is not the only class of its kind; there is
+ # also ConvTransposeUnpoolBase, which backs ConvTranspose. So don't
+ # be tricked by the fact that Conv and ConvTranspose have similar
+ # parameters; they exercise different codepaths and need to be handled
+ # differently.
+
+    @classmethod
+    def _create_conv_pool_op_base(cls, init_model, pred_model, n, opset_version):
+        """Adjust a Conv / pooling node's attributes, then run the common conversion.
+
+        Op types starting with 'Global' get global_pooling=1. When both
+        'kernel_shape' and 'pads' are present and have equal length, 'pads'
+        is doubled (the list repeated once).
+        """
+        attrs = n.attrs
+        if n.op_type.startswith('Global'):
+            attrs['global_pooling'] = 1
+
+        if 'kernel_shape' in attrs and 'pads' in attrs:
+            kernels = attrs['kernel_shape']
+            pads = attrs['pads']
+            if len(pads) == len(kernels):
+                # Caffe2 requires pads to be twice the size of kernels.
+                attrs['pads'] = pads * 2
+
+        return cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
+
+    @classmethod
+    def _create_reshape(cls, init_model, pred_model, n, opset_version):
+        """Convert Reshape through the common path, then add a throwaway extra output."""
+        reshape_op = cls._common_onnx_node_to_caffe2_op(init_model, pred_model, n, opset_version)
+        # Caffe2 has an extra output
+        extra_output = dummy_name()
+        reshape_op.output.append(extra_output)
+        return reshape_op
+
+    @classmethod
+    def _create_sqrt(cls, init_model, pred_model, n, opset_version):
+        """Convert Sqrt into a Caffe2 Pow op with exponent 0.5."""
+        source, = n.inputs
+        target, = n.outputs
+        return core.CreateOperator('Pow', [source], [target], exponent=0.5)
+
+    @classmethod
+    def _create_reciprocal(cls, init_model, pred_model, n, opset_version):
+        """Convert Reciprocal into a Caffe2 Pow op with exponent -1.0."""
+        source, = n.inputs
+        target, = n.outputs
+        return core.CreateOperator('Pow', [source], [target], exponent=-1.0)
+
+    @classmethod
+    def _direct_initialize_parameters(cls, initializer, ws, device_option):
+        """Feed every initializer tensor into workspace `ws` under the tensor's own name."""
+        for tensor in initializer:
+            array = onnx.numpy_helper.to_array(tensor)
+            ws.FeedBlob(tensor.name, array, device_option)
+
+    @classmethod
+    def _direct_initialize_inputs(cls, inputs, initialized, ws, device_option):
+        """Feed an all-ones array of the declared shape for each input not in `initialized`."""
+        pending = (value_info for value_info in inputs
+                   if value_info.name not in initialized)
+        for value_info in pending:
+            dims = [d.dim_value for d in value_info.type.tensor_type.shape.dim]
+            ws.FeedBlob(value_info.name, np.ones(dims), device_option)
+
+    @staticmethod
+    def optimize_onnx(input, init=False, predict=False):
+        """Run the ONNX optimizer on `input` with the transpose passes.
+
+        'split_init' and 'split_predict' are appended (in that order) when
+        `init` / `predict` are truthy.
+        """
+        passes = [
+            'fuse_consecutive_transposes',
+            'eliminate_nop_transpose',
+            'fuse_transpose_into_gemm',
+        ]
+        optional_passes = (('split_init', init), ('split_predict', predict))
+        passes.extend(name for name, enabled in optional_passes if enabled)
+        return onnx.optimizer.optimize(input, passes)
+
+ @classmethod
+ def prepare(cls, model, device='CPU', **kwargs):
+ '''
+ Build a Caffe2Rep for an ONNX model: feed its initializers (and
+ placeholder inputs) into a fresh Workspace and translate the graph
+ into Caffe2 init/predict nets.
+
+ For Onnx Caffe2Backend, we require that init_graph doesn't initialize the actual input of the predict_graph,
+
+ for example, if "img" is the input blob for the predict_net, we require that in init_graph and in
+ initializer of the predict_graph, "img" is not initialized. We don't have a check for this, since
+ there is no way we can know which blob is the input of the predict_graph.
+ '''
+ super(Caffe2Backend, cls).prepare(model, device, **kwargs)
+
+
+ # Find the version of the default operator set (domain missing or
+ # empty). Any other domain only produces a warning.
+ opset_version = None
+ for imp in model.opset_import:
+ if not imp.HasField("domain") or imp.domain == "":
+ opset_version = imp.version
+ if imp.version > cls._known_opset_version:
+ warnings.warn("This version of onnx-caffe2 targets ONNX operator set version {}, but the model we are trying to import uses version {}. We will try to import it anyway, but if the model uses operators which had BC-breaking changes in the intervening versions, import will fail.".format(cls._known_opset_version, imp.version))
+ else:
+ warnings.warn("Unrecognized operator set {}".format(imp.domain))
+ if opset_version is None:
+ # No default-domain opset declared: only tolerated for IR version < 3,
+ # in which case opset version 1 is assumed.
+ if model.ir_version >= 0x00000003:
+ raise RuntimeError("Model with IR version >= 3 did not specify ONNX operator set version (onnx-caffe2 requires it)")
+ else:
+ opset_version = 1
+
+ ws = Workspace()
+ device_option = get_device_option(Device(device))
+
+ # Directly load initializer data into blobs in workspace
+ cls._direct_initialize_parameters(
+ model.graph.initializer,
+ ws,
+ device_option,
+ )
+
+ initialized = {init.name for init in model.graph.initializer}
+
+ # Graph inputs without an initializer get a placeholder blob.
+ cls._direct_initialize_inputs(
+ model.graph.input,
+ initialized,
+ ws,
+ device_option,
+ )
+
+ # Names of graph inputs without an initializer, in graph order; they
+ # are handed to Caffe2Rep together with the workspace.
+ uninitialized = [value_info.name for value_info in model.graph.input if value_info.name not in initialized]
+
+ # include_initializers=False: initializers were already fed into `ws` above.
+ init_net, predict_net = cls._onnx_model_to_caffe2_net(model, device, opset_version, False)
+
+ retval = Caffe2Rep(init_net, predict_net, ws, uninitialized)
+ return retval
+
+    @classmethod
+    # TODO: This method needs a refactor for clarity
+    def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version):
+        """
+        Translate one ONNX node definition into a Caffe2Ops bundle.
+
+        The translator named in '_special_operators' for the node's op_type
+        is used when present, otherwise '_common_onnx_node_to_caffe2_op'.
+        A translator may return a Caffe2Ops (returned unchanged), a single
+        operator, or an iterable of operators; the latter two are wrapped
+        in a Caffe2Ops whose other two fields are empty lists.
+        """
+        # 'collections.Iterable' was removed in Python 3.10; the ABC lives in
+        # 'collections.abc' on Python 3. Fall back for Python 2.
+        try:
+            from collections.abc import Iterable
+        except ImportError:
+            from collections import Iterable
+
+        if node_def.op_type in cls._special_operators:
+            translator = getattr(cls, cls._special_operators[node_def.op_type])
+        else:
+            translator = cls._common_onnx_node_to_caffe2_op
+        ops = translator(init_model, pred_model, OnnxNode(node_def), opset_version)
+        if isinstance(ops, Caffe2Ops):
+            return ops
+        if not isinstance(ops, Iterable):
+            # A single operator: normalise to a one-element list.
+            ops = [ops]
+        return Caffe2Ops(ops, [], [])
+
+ @classmethod
+ def _common_onnx_node_to_caffe2_op(cls, init_model, pred_model, onnx_node, opset_version):
+ """
+ This translator performs the basic translation of ONNX nodes into
+ Caffe2 operators. Besides doing a straightforward marshalling from
+ one format to another, it also does these extra things:
+
+ - Renames operators based on '_renamed_operators'
+ - Renames attributes based on '_global_renamed_attrs' and
+ '_per_op_renamed_attrs'
+
+ If you're writing a custom translator, consider calling this first,
+ and then fixing things up further.
+
+ Raises ValueError when the op type is listed in '_broken_operators'
+ with a version <= opset_version, or when core.IsOperator rejects the
+ (renamed) op type. init_model and pred_model are not used here.
+ """
+ c2_op = caffe2_pb2.OperatorDef()
+
+ # Inputs, outputs and name are copied over unchanged.
+ c2_op.input.extend(onnx_node.inputs)
+ c2_op.output.extend(onnx_node.outputs)
+ c2_op.name = onnx_node.name
+
+ onnx_op_type = onnx_node.op_type
+ # Ops absent from '_broken_operators' default to +Inf, i.e. never broken.
+ broken_version = cls._broken_operators.get(onnx_op_type, float('Inf'))
+ if broken_version <= opset_version:
+ raise ValueError(
+ "Don't know how to translate op {} in ONNX operator set v{} (I only support prior to v{})".format(onnx_op_type, opset_version, broken_version))
+ c2_op.type = cls._renamed_operators.get(onnx_op_type, onnx_op_type)
+ if not core.IsOperator(c2_op.type):
+ raise ValueError(
+ "Don't know how to translate op {}".format(onnx_op_type))
+
+ # Attribute-name mapping: a per-op rename wins over a global rename;
+ # names in neither table pass through unchanged.
+ def kmap(k):
+ if (onnx_op_type in cls._per_op_renamed_attrs and
+ k in cls._per_op_renamed_attrs[onnx_op_type]):
+ return cls._per_op_renamed_attrs[onnx_op_type][k]
+ if k in cls._global_renamed_attrs:
+ return cls._global_renamed_attrs[k]
+ return k
+ c2_op.arg.extend(onnx_node.attrs.caffe2(kmap=kmap))
+
+ return c2_op
+
+
+ @classmethod
+ def _inplace_rewrite(cls, graph_or_nodes):
+ '''
+ currently we use this to translate ONNX-style
+ consumed_input annotations to Caffe2-style in place
+ updates (use same input and output names).
+
+ Accepts either a GraphProto or a sequence of nodes and modifies the
+ nodes (and, for a graph, the graph output names) in place.
+ '''
+ is_graph = isinstance(graph_or_nodes, GraphProto)
+ if is_graph:
+ nodes = graph_or_nodes.node
+ else:
+ nodes = graph_or_nodes
+
+ # Maps a replaced output name to the input name that now stands for it.
+ renamed = {}
+
+ for node in nodes:
+ # Apply renames recorded by earlier nodes to this node's inputs.
+ node.input[:] = [renamed.get(input_name, input_name)
+ for input_name in node.input]
+ consumed_inputs = OnnxNode(node).consumed_inputs or []
+ # Output positions not (yet) tied to a consumed input.
+ output_idxes = set(range(len(node.output)))
+ schema = onnx.defs.get_schema(node.op_type)
+ for i, consumed in enumerate(consumed_inputs):
+ if not consumed:
+ continue
+ _, output_idx = schema.consumed(i)
+ # consumed outputs are not always present
+ # for instance batch norm in test mode
+ # does not return the consumed inputs
+ if output_idx < len(node.output):
+ output_idxes.remove(output_idx)
+ old_val = node.output[output_idx]
+ new_val = node.input[i]
+ # The output takes the name of the consumed input.
+ node.output[output_idx] = new_val
+ renamed[old_val] = new_val
+ # Remaining outputs still get any applicable rename.
+ for idx in output_idxes:
+ name = node.output[idx]
+ node.output[idx] = renamed.get(name, name)
+ if is_graph:
+ # Graph-level outputs must follow the renames as well.
+ for output in graph_or_nodes.output:
+ output.name = renamed.get(output.name, output.name)
+
+    @staticmethod
+    def _all_names_in_graph(graph):
+        """Return the set of all names used in `graph`.
+
+        Covers graph inputs, graph outputs and every node's inputs and
+        outputs. A `None` graph yields an empty set.
+        """
+        if graph is None:
+            return set()
+
+        collected = {value_info.name for value_info in graph.input}
+        collected |= {value_info.name for value_info in graph.output}
+        for node in graph.node:
+            collected |= set(node.input)
+            collected |= set(node.output)
+        return collected
+
+ @classmethod
+ def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_initializers):
+ # Convert `onnx_model` into a pair of Caffe2 NetDefs (init_net, pred_net).
+ # With include_initializers set, tensor-filling ops for the graph
+ # initializers (and any translator init ops) go into init_net.
+ device_option = get_device_option(Device(device))
+
+ # The optimizer works on serialized bytes: run it once with the
+ # 'split_init' pass and once with 'split_predict', parsing each result
+ # into its own ModelProto.
+ init_model = ModelProto()
+ init_model.ParseFromString(cls.optimize_onnx(onnx_model.SerializeToString(), init=True))
+ cls._inplace_rewrite(init_model.graph)
+
+ pred_model = ModelProto()
+ pred_model.ParseFromString(cls.optimize_onnx(onnx_model.SerializeToString(), predict=True))
+ cls._inplace_rewrite(pred_model.graph)
+
+ init_net = caffe2_pb2.NetDef()
+ pred_net = caffe2_pb2.NetDef()
+
+ init_net.name = onnx_model.graph.name + '_init'
+ pred_net.name = onnx_model.graph.name + '_predict'
+
+ if include_initializers:
+ init_net.op.extend(cls._create_tensor_filling_op(tp) for tp in onnx_model.graph.initializer)
+
+ # Called with a name set, dummy_name() only re-seeds the generator so
+ # later dummy names avoid every name already used in either graph.
+ dummy_name(cls._all_names_in_graph(init_model.graph) | cls._all_names_in_graph(pred_model.graph))
+
+ for net, model in ( (init_net, init_model), (pred_net, pred_model) ):
+ net.device_option.CopyFrom(device_option)
+ for node in model.graph.node:
+ c2ops = cls._onnx_node_to_caffe2_op(
+ init_model, pred_model, node, opset_version)
+ # Translator init ops go to init_net when initializers are included,
+ # otherwise to the net currently being built.
+ (init_net if include_initializers else net).op.extend(c2ops.init_ops)
+ net.op.extend(c2ops.ops)
+ net.external_input.extend(c2ops.interface_blobs)
+ net.external_output.extend(
+ value_info.name for value_info in model.graph.output)
+ net.external_input.extend(
+ value_info.name for value_info in model.graph.input)
+
+ return init_net, pred_net
+
+    # wrapper for backwards compatibility
+    @classmethod
+    def onnx_graph_to_caffe2_net(cls, model, device="CPU", opset_version=_known_opset_version):
+        """Convert `model` to (init_net, predict_net) with initializers included."""
+        return cls._onnx_model_to_caffe2_net(
+            model,
+            device=device,
+            opset_version=opset_version,
+            include_initializers=True,
+        )
+
+    @classmethod
+    def supports_device(cls, device_str):
+        """Report whether this backend can run on the device named by `device_str`."""
+        device_type = Device(device_str).type
+        if device_type == DeviceType.CPU:
+            return True
+        if device_type == DeviceType.CUDA:
+            # CUDA availability is whatever the caffe2 workspace module reports.
+            return workspace.has_gpu_support
+        return False
+
+
+# Module-level aliases for the Caffe2Backend entry points, so they can be
+# called as plain functions of this module.
+prepare = Caffe2Backend.prepare
+
+run_node = Caffe2Backend.run_node
+
+run_model = Caffe2Backend.run_model
+
+supports_device = Caffe2Backend.supports_device # noqa
diff --git a/caffe2/python/onnx/backend_rep.py b/caffe2/python/onnx/backend_rep.py
new file mode 100644
index 0000000000..f427b6d2d1
--- /dev/null
+++ b/caffe2/python/onnx/backend_rep.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.backend_rep
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core, workspace
+from caffe2.proto import caffe2_pb2
+from onnx.backend.base import BackendRep, namedtupledict
+
+class Caffe2Rep(BackendRep):
+ """Runnable representation of a converted model: an init net, a predict
+ net and the workspace they execute in."""
+
+ def __init__(self, init_net, predict_net, workspace, uninitialized):
+ super(Caffe2Rep, self).__init__()
+ self.init_net = init_net
+ self.predict_net = predict_net
+ # NOTE: the `workspace` parameter shadows the imported caffe2
+ # `workspace` module inside __init__ only.
+ self.workspace = workspace
+ # The list of uninitialized external_inputs in workspace, we need this to
+ # pair the name with given sequence inputs.
+ self.uninitialized = uninitialized
+ # Both flags make net creation / init-net execution happen only once.
+ self.nets_created = False
+ self.ran_init_net = False
+
+ @property
+ def _name_scope(self):
+ # 'gpu_<id>' when the predict net targets CUDA, otherwise no scope.
+ if self.predict_net.device_option.device_type == caffe2_pb2.CUDA:
+ return 'gpu_{}'.format(self.predict_net.device_option.cuda_gpu_id)
+ return ''
+
+ def run(self, inputs, **kwargs):
+ """Feed `inputs`, run the nets and return the predict net's outputs.
+
+ `inputs` may be a dict (blob name -> value), a list/tuple matched
+ positionally against `self.uninitialized`, or a single value fed to
+ the first uninitialized input. Returns a namedtuple-like 'Outputs'
+ keyed by the predict net's external outputs. Raises RuntimeError
+ when a list/tuple has the wrong length.
+ """
+ super(Caffe2Rep, self).run(inputs, **kwargs)
+ # Inside this method `workspace` is the imported caffe2 module;
+ # entering `self.workspace` as a context manager is what is expected
+ # to route those module-level calls to this rep's workspace
+ # (NOTE(review): confirm against the Workspace class).
+ with self.workspace:
+ with core.DeviceScope(self.predict_net.device_option):
+ if isinstance(inputs, dict):
+ with core.NameScope(self._name_scope):
+ for key, value in inputs.items():
+ workspace.FeedBlob(key, value)
+ elif isinstance(inputs, list) or isinstance(inputs, tuple):
+ if len(self.uninitialized) != len(inputs):
+ raise RuntimeError('Expected {} values for uninitialized '
+ 'graph inputs ({}), but got {}.'.format(
+ len(self.uninitialized),
+ ', '.join(self.uninitialized),
+ len(inputs)))
+ for i, value in enumerate(inputs):
+ # namescope already baked into protobuf
+ workspace.FeedBlob(self.uninitialized[i], value)
+ else:
+ # single input
+ workspace.FeedBlob(self.uninitialized[0], inputs)
+ if not self.nets_created:
+ workspace.CreateNet(self.init_net)
+ workspace.CreateNet(self.predict_net)
+ self.nets_created = True
+ if not self.ran_init_net:
+ workspace.RunNet(self.init_net.name)
+ self.ran_init_net = True
+ workspace.RunNet(self.predict_net.name)
+ output_values = [workspace.FetchBlob(name)
+ for name in self.predict_net.external_output]
+ return namedtupledict('Outputs',
+ self.predict_net.external_output)(*output_values)
diff --git a/caffe2/python/onnx/bin/conversion.py b/caffe2/python/onnx/bin/conversion.py
new file mode 100644
index 0000000000..fd5eb70ef1
--- /dev/null
+++ b/caffe2/python/onnx/bin/conversion.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.bin.conversion
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from caffe2.proto import caffe2_pb2
+import click
+import numpy as np
+from onnx import checker, ModelProto
+
+from caffe2.python.onnx.backend import Caffe2Backend as c2
+import caffe2.python.onnx.frontend as c2_onnx
+
+
+@click.command(
+    help='convert caffe2 net to onnx model',
+    context_settings={
+        'help_option_names': ['-h', '--help']
+    }
+)
+@click.argument('caffe2_net', type=click.File('rb'))
+@click.option('--caffe2-net-name',
+              type=str,
+              help="Name of the caffe2 net")
+@click.option('--caffe2-init-net',
+              type=click.File('rb'),
+              help="Path of the caffe2 init net pb file")
+@click.option('--value-info',
+              type=str,
+              help='A json string providing the '
+                   'type and shape information of the inputs')
+@click.option('-o', '--output', required=True,
+              type=click.File('wb'),
+              help='Output path for the onnx model pb file')
+def caffe2_to_onnx(caffe2_net,
+                   caffe2_net_name,
+                   caffe2_init_net,
+                   value_info,
+                   output):
+    """Convert a serialized Caffe2 net into a serialized ONNX model.
+
+    Args:
+        caffe2_net: binary file holding the predict NetDef.
+        caffe2_net_name: optional name overriding the one in the NetDef.
+        caffe2_init_net: optional binary file holding the init NetDef.
+        value_info: optional JSON string with input type/shape information.
+        output: binary file the ONNX model is written to.
+
+    Raises:
+        click.BadParameter: if neither the NetDef nor --caffe2-net-name
+            provides a net name.
+    """
+    c2_net_proto = caffe2_pb2.NetDef()
+    c2_net_proto.ParseFromString(caffe2_net.read())
+    # The command-line name wins over the name stored in the proto.
+    net_name = caffe2_net_name or c2_net_proto.name
+    if not net_name:
+        raise click.BadParameter(
+            'The input caffe2 net does not have name, '
+            '--caffe2-net-name must be provided')
+    c2_net_proto.name = net_name
+    if caffe2_init_net:
+        c2_init_net_proto = caffe2_pb2.NetDef()
+        c2_init_net_proto.ParseFromString(caffe2_init_net.read())
+        # Derive the init net name from the resolved net name; formatting
+        # the raw option produced 'None_init' when it was not given.
+        c2_init_net_proto.name = '{}_init'.format(net_name)
+    else:
+        c2_init_net_proto = None
+
+    if value_info:
+        value_info = json.loads(value_info)
+
+    onnx_model = c2_onnx.caffe2_net_to_onnx_model(
+        predict_net=c2_net_proto,
+        init_net=c2_init_net_proto,
+        value_info=value_info)
+
+    output.write(onnx_model.SerializeToString())
+
+
+@click.command(
+    help='convert onnx model to caffe2 net',
+    context_settings={
+        'help_option_names': ['-h', '--help']
+    }
+)
+@click.argument('onnx_model', type=click.File('rb'))
+@click.option('-o', '--output', required=True,
+              type=click.File('wb'),
+              help='Output path for the caffe2 net file')
+@click.option('--init-net-output',
+              required=True,
+              type=click.File('wb'),
+              help='Output path for the caffe2 init net file')
+def onnx_to_caffe2(onnx_model, output, init_net_output):
+    """Read a serialized ONNX model and write the Caffe2 init and predict nets."""
+    model_proto = ModelProto()
+    model_proto.ParseFromString(onnx_model.read())
+
+    init_net, predict_net = c2.onnx_graph_to_caffe2_net(model_proto)
+    # The init net is written first, then the predict net.
+    for stream, net in ((init_net_output, init_net), (output, predict_net)):
+        stream.write(net.SerializeToString())
diff --git a/caffe2/python/onnx/error.py b/caffe2/python/onnx/error.py
new file mode 100644
index 0000000000..760d7b048b
--- /dev/null
+++ b/caffe2/python/onnx/error.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.error
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+class BaseException(Exception):
+    """Base class for the exceptions defined in this module.
+
+    NOTE: within this module the name shadows the builtin BaseException.
+    """
+    pass
+class Unsupported(BaseException):
+    """Raised when a construct cannot be converted."""
+    pass
diff --git a/caffe2/python/onnx/frontend.py b/caffe2/python/onnx/frontend.py
new file mode 100644
index 0000000000..bac6024e09
--- /dev/null
+++ b/caffe2/python/onnx/frontend.py
@@ -0,0 +1,551 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.frontend
+
+"""Caffe2 Protobuf to ONNX converter
+
+To run this, you will need to have Caffe2 installed as well.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import itertools
+import collections
+import logging
+import re
+
+from caffe2.python import core as caffe2_core
+from enum import Enum
+from onnx import (defs, checker, helper, numpy_helper, mapping,
+ ModelProto, GraphProto, NodeProto, AttributeProto, TensorProto, OperatorSetIdProto)
+from onnx.helper import make_tensor, make_tensor_value_info
+import numpy as np
+
+from caffe2.python.onnx.helper import make_model, c2_native_run_net, dummy_name
+from caffe2.python.onnx.error import Unsupported
+
+# NOTE: basicConfig runs at import time, so importing this module
+# configures the root logger at INFO level if it has no handlers yet.
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class Caffe2Frontend(object):
+ # This number controls the semantics of the operators we target. Whenever
+ # ONNX makes a BC breaking change to semantics of operators, having this set
+ # to an accurate number will prevent our models from exporting. However,
+ # we should strive to keep this up-to-date as much as possible.
+ _target_opset_version = 3
+
+ _renamed_operators = {
+ 'SpatialBN': 'BatchNormalization',
+ 'Conv1D': 'Conv',
+ 'Conv2D': 'Conv',
+ 'Conv3D': 'Conv',
+ 'ConvTranspose1D': 'ConvTranspose',
+ 'ConvTranspose2D': 'ConvTranspose',
+ 'ConvTranspose3D': 'ConvTranspose',
+ 'MaxPool1D': 'MaxPool',
+ 'MaxPool2D': 'MaxPool',
+ 'MaxPool3D': 'MaxPool',
+ 'AveragePool1D': 'AveragePool',
+ 'AveragePool2D': 'AveragePool',
+ 'AveragePool3D': 'AveragePool',
+ }
+
+ # caffe2 arguments that are completely removed in onnx
+ _blacklist_caffe2_args = {
+ 'order': {b'NCHW'},
+ 'cudnn_exhaustive_search': {0, 1},
+ 'use_cudnn': {0, 1},
+ }
+
+ _global_renamed_args = {
+ 'kernels': 'kernel_shape',
+ }
+
+ _per_op_renamed_args = {
+ 'Squeeze': {'dims': 'axes'},
+ 'Transpose': {'axes': 'perm'},
+ }
+
+ _special_operators = {
+ 'Conv': '_create_conv_pool_op',
+ 'ConvTranspose': '_create_conv_pool_op',
+ 'ChannelShuffle': '_create_channel_shuffle',
+ 'MaxPool': '_create_conv_pool_op',
+ 'AveragePool': '_create_conv_pool_op',
+ 'Concat': '_create_concat',
+ 'FC': '_create_gemm',
+ 'LRN': '_create_lrn',
+ 'Slice': '_create_slice',
+ 'Reshape': '_create_reshape',
+ }
+
+    @classmethod
+    def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg):
+        """
+        Convert one Caffe2 Argument of `op_def` into an ONNX attribute.
+
+        The attribute name is looked up in '_per_op_renamed_args' for the
+        op type first and in '_global_renamed_args' second; names found in
+        neither table are kept. Returns None for arguments listed in
+        '_blacklist_caffe2_args' (they are dropped).
+
+        Raises:
+            ValueError: if `arg` carries no data field.
+            AssertionError: if a blacklisted argument has a value outside
+                the set allowed for it.
+        """
+        # name: a per-op rename wins, then the global table applies. This
+        # matches the lookup order used by the backend's attribute mapping,
+        # so ops with per-op renames no longer skip the global renames.
+        caffe2_name = arg.name
+        per_op_renames = cls._per_op_renamed_args.get(op_def.type, {})
+        if caffe2_name in per_op_renames:
+            name = per_op_renames[caffe2_name]
+        else:
+            name = cls._global_renamed_args.get(caffe2_name, caffe2_name)
+
+        # value: scalar fields are checked first, then the repeated fields.
+        if arg.HasField('f'):
+            value = arg.f
+        elif arg.HasField('i'):
+            value = arg.i
+        elif arg.HasField('s'):
+            value = arg.s
+        elif arg.floats:
+            value = arg.floats
+        elif arg.ints:
+            value = arg.ints
+        elif arg.strings:
+            value = arg.strings
+        else:
+            raise ValueError('Could not find data field in arg: {}'.format(arg))
+
+        # The blacklist is keyed by Caffe2 argument names, so test and index
+        # it with the same (original) name.
+        if caffe2_name in cls._blacklist_caffe2_args:
+            assert value in cls._blacklist_caffe2_args[caffe2_name]
+            return None
+
+        return helper.make_attribute(name, value)
+
+ @classmethod
+ def caffe2_arg_to_onnx_attr(cls, op_def, arg):
+ """Convert one argument of `op_def`; delegates to _common_caffe2_arg_to_onnx_attr."""
+ return cls._common_caffe2_arg_to_onnx_attr(op_def, arg)
+
+    @classmethod
+    def _common_caffe2_op_to_onnx_node(cls, op_def, shapes):
+        """Build a NodeProto from `op_def`: same name, inputs and outputs, with
+        the op type mapped through '_renamed_operators' and each argument
+        converted to an attribute (dropped arguments are skipped)."""
+        onnx_node = NodeProto()
+        onnx_node.name = op_def.name
+
+        onnx_node.op_type = cls._renamed_operators.get(op_def.type, op_def.type)
+
+        onnx_node.input.extend(op_def.input)
+        onnx_node.output.extend(op_def.output)
+
+        converted = [cls.caffe2_arg_to_onnx_attr(op_def, arg)
+                     for arg in op_def.arg]
+        # Falsy results (None for dropped arguments) are filtered out.
+        onnx_node.attribute.extend([attr for attr in converted if attr])
+
+        return onnx_node
+
+    @classmethod
+    def _create_concat(cls, op_def, shapes):
+        """Convert a Concat op: drop a second output and default 'axis' to 1."""
+        concat_node = cls._common_caffe2_op_to_onnx_node(op_def, shapes)
+        if len(concat_node.output) == 2:
+            del concat_node.output[1]
+        arg_names = {arg.name for arg in op_def.arg}
+        if 'axis' not in arg_names:
+            # No explicit axis on the Caffe2 op: state axis 1 explicitly.
+            concat_node.attribute.extend([helper.make_attribute('axis', 1)])
+        return concat_node
+
+    @classmethod
+    def _create_reshape(cls, op_def, shapes):
+        """Convert a Reshape op, dropping the second output when there are two."""
+        reshape_node = cls._common_caffe2_op_to_onnx_node(op_def, shapes)
+        if len(reshape_node.output) != 2:
+            return reshape_node
+        del reshape_node.output[1]
+        return reshape_node
+
+ @classmethod
+ def _create_conv_pool_op(cls, op_def, shapes):
+ # Convert a conv/pool op: turn global pooling into a 'Global*' op type
+ # and fold per-side / scalar arguments (kernel_h, pad_t, stride, ...)
+ # into list-valued attributes.
+ node = cls._common_caffe2_op_to_onnx_node(op_def, shapes)
+
+ if node.op_type in ['MaxPool', 'AveragePool']:
+ for i, attr in enumerate(node.attribute):
+ if attr.name == 'global_pooling' and attr.i:
+ node.op_type = 'Global{}'.format(node.op_type)
+ del node.attribute[i]
+ break
+
+ # Working copy of the attributes by name; apply_trans edits it and the
+ # node's attribute list is rebuilt from it at the end.
+ attrs = {attr.name: attr for attr in node.attribute}
+ def apply_trans(k, dim=2, ks=None):
+ # Replace `k_h`/`k_w` (dim == 2), `k_t`/`k_l`/`k_b`/`k_r` (dim == 4)
+ # or the scalar `k` (repeated `dim` times) by one list attribute
+ # named `ks` (default: k + 's').
+ ks = ks or (k + 's')
+ if dim == 2:
+ k_h, k_w = k + '_h', k + '_w'
+ else:
+ k_t, k_l, k_b, k_r = k + '_t', k + '_l', k + '_b', k + '_r'
+
+ vals = None
+ if (dim == 2 and k_h in attrs and k_w in attrs):
+ vals = [attrs[k_h].i, attrs[k_w].i]
+ del attrs[k_h]
+ del attrs[k_w]
+ elif (dim == 4 and
+ k_t in attrs and k_l in attrs and k_b in attrs and k_r in attrs):
+ vals = [attrs[k_t].i,
+ attrs[k_l].i,
+ attrs[k_b].i,
+ attrs[k_r].i]
+ del attrs[k_t]
+ del attrs[k_l]
+ del attrs[k_b]
+ del attrs[k_r]
+ elif k in attrs:
+ vals = [attrs[k].i] * dim
+ del attrs[k]
+
+ # For 'Global*' ops the source attributes are removed but no list
+ # attribute is added.
+ if vals and not node.op_type.startswith('Global'):
+ attrs[ks] = helper.make_attribute(ks, vals)
+
+ apply_trans('kernel', ks='kernel_shape')
+ apply_trans('stride')
+ apply_trans('dilation')
+ apply_trans('adj')
+ apply_trans('pad', 4)
+
+ del node.attribute[:]
+ node.attribute.extend(attrs.values())
+ return node
+
+    @classmethod
+    def _create_gemm(cls, op_def, shapes):
+        """
+        Convert a three-input op (x, w, b) into an ONNX 'Gemm' node.
+
+        When the op has an 'axis' argument, x is first reshaped to 2-D
+        around that axis and the Gemm result is reshaped back to
+        x_shape[:axis] + [-1]. When it has 'axis_w', w is reshaped to 2-D
+        the same way. Returns the list of nodes in execution order.
+
+        `shapes` maps blob names to their shapes.
+        """
+        def flat_dims(shape, axis):
+            # Collapse `shape` to [outer, inner] around `axis` as plain
+            # Python ints, so both reshapes emit the same attribute type.
+            return [int(np.prod(shape[:axis])), int(np.prod(shape[axis:]))]
+
+        x, w, b = op_def.input
+        args = {arg.name: arg for arg in op_def.arg}
+        y, = op_def.output
+        x_shape = list(shapes[x])
+
+        nodes = []
+        if 'axis' in args:
+            axis = args['axis'].i
+            reshaped_x = dummy_name()
+            nodes.append(helper.make_node(
+                'Reshape',
+                inputs=[x],
+                outputs=[reshaped_x],
+                shape=flat_dims(x_shape, axis),
+            ))
+            x = reshaped_x
+
+        if 'axis_w' in args:
+            axis_w = args['axis_w'].i
+            reshaped_w = dummy_name()
+            nodes.append(helper.make_node(
+                'Reshape',
+                inputs=[w],
+                outputs=[reshaped_w],
+                shape=flat_dims(shapes[w], axis_w),
+            ))
+            w = reshaped_w
+
+        # With 'axis' the Gemm result is an intermediate blob that still has
+        # to be reshaped into y.
+        gemm_y_output = dummy_name() if 'axis' in args else y
+        nodes.append(helper.make_node(
+            'Gemm',
+            inputs=[x, w, b],
+            outputs=[gemm_y_output],
+            name=op_def.name,
+            transB=1,
+            broadcast=1,
+        ))
+
+        if 'axis' in args:
+            axis = args['axis'].i
+            nodes.append(helper.make_node(
+                'Reshape',
+                inputs=[gemm_y_output],
+                outputs=[y],
+                shape=x_shape[:axis] + [-1],
+            ))
+
+        return nodes
+
+    @classmethod
+    def _create_lrn(cls, op_def, shapes):
+        """Convert an LRN op, dropping the second output when there are two."""
+        lrn_node = cls._common_caffe2_op_to_onnx_node(op_def, shapes)
+        if len(lrn_node.output) != 2:
+            return lrn_node
+        del lrn_node.output[1]
+        return lrn_node
+
+ @classmethod
+ def _create_slice(cls, op_def, shapes):
+ # Convert a Slice op that has static 'starts'/'ends' arguments. Raises
+ # Unsupported when the op has more than one input.
+ if len(op_def.input) > 1:
+ raise Unsupported(
+ 'ONNX Slice operator does not support dynamic slice.')
+ node = cls._common_caffe2_op_to_onnx_node(op_def, shapes)
+ attrs = {attr.name: attr for attr in node.attribute}
+ ndims = len(attrs['starts'].ints)
+
+ # An explicit 'axes' attribute covering axes 0..ndims-1 is added.
+ node.attribute.extend([helper.make_attribute('axes', range(ndims))])
+
+ data, = node.input
+ shape = shapes[data]
+
+ # Negative ends are rewritten in place on the node's 'ends' attribute:
+ # -1 becomes the full dimension size, any other negative end is shifted
+ # by one. NOTE(review): presumably converts an inclusive negative end
+ # into an exclusive one - confirm against the Caffe2 Slice semantics.
+ ends = attrs['ends'].ints
+ for i, end in enumerate(ends):
+ if end >= 0:
+ continue
+ if end == -1:
+ end = shape[i]
+ else:
+ end = end + 1
+ ends[i] = end
+
+ return node
+
+    @classmethod
+    def _create_channel_shuffle(cls, op_def, shapes):
+        """Express a channel shuffle as Reshape -> Transpose -> Reshape.
+
+        The 4-D input (n, c, h, w) is split into (n, g, c // g, h, w), the
+        two channel axes are swapped, and the result is merged back to
+        (n, c, h, w). Returns the three nodes in order.
+        """
+        (src,) = op_def.input
+        (dst,) = op_def.output
+        batch, channels, height, width = shapes[src]
+        arg_by_name = {arg.name: arg for arg in op_def.arg}
+        group = arg_by_name['group'].i
+        assert channels % group == 0
+
+        split_blob = dummy_name()
+        split_node = helper.make_node(
+            'Reshape',
+            inputs=[src],
+            outputs=[split_blob],
+            shape=[batch, group, channels // group, height, width],
+        )
+
+        swapped_blob = dummy_name()
+        swap_node = helper.make_node(
+            'Transpose',
+            inputs=[split_blob],
+            outputs=[swapped_blob],
+            perm=[0, 2, 1, 3, 4],
+        )
+
+        merge_node = helper.make_node(
+            'Reshape',
+            inputs=[swapped_blob],
+            outputs=[dst],
+            shape=[batch, channels, height, width],
+        )
+        return [split_node, swap_node, merge_node]
+
+    @classmethod
+    def caffe2_op_to_onnx_node(cls, op_def, shapes):
+        """
+        Convert one Caffe2 operator into a list (or other iterable) of ONNX nodes.
+
+        The translator named in '_special_operators' for the op type is
+        used when present, otherwise '_common_caffe2_op_to_onnx_node'. A
+        translator returning a single node gets it wrapped in a list.
+        """
+        # 'collections.Iterable' was removed in Python 3.10; the ABC lives in
+        # 'collections.abc' on Python 3. Fall back for Python 2.
+        try:
+            from collections.abc import Iterable
+        except ImportError:
+            from collections import Iterable
+
+        if op_def.type in cls._special_operators:
+            translator = getattr(cls, cls._special_operators[op_def.type])
+        else:
+            translator = cls._common_caffe2_op_to_onnx_node
+        nodes = translator(op_def, shapes)
+        if not isinstance(nodes, Iterable):
+            nodes = [nodes]
+        return nodes
+
+    @staticmethod
+    def _all_names_in_net(net):
+        """Return the set of all blob names used in `net`.
+
+        Covers external inputs, external outputs and every op's inputs and
+        outputs. A `None` net yields an empty set.
+        """
+        if net is None:
+            return set()
+
+        collected = set(net.external_input)
+        collected |= set(net.external_output)
+        for op in net.op:
+            collected |= set(op.input)
+            collected |= set(op.output)
+        return collected
+
+    @classmethod
+    def caffe2_net_to_onnx_graph(cls,
+                                 predict_net,
+                                 init_net=None,
+                                 value_info=None):
+        """
+        Convert a Caffe2 predict net (plus optional init net) to a GraphProto.
+
+        The nets are rewritten to SSA names in place, executed once on
+        random inputs to obtain blob shapes and output types, and then
+        translated op by op.
+
+        Args:
+            predict_net: NetDef to convert; modified in place.
+            init_net: optional NetDef providing the initializers; modified
+                in place.
+            value_info: dict mapping input name -> (elem_type, shape). It is
+                modified in place: keys are SSA-renamed and entries for
+                initializers and outputs are added.
+
+        Returns:
+            The GraphProto, after checker.check_graph accepted it.
+
+        Raises:
+            ValueError: if value_info is not a dict.
+            RuntimeError: if an external input has no entry in value_info.
+        """
+        if value_info is None:
+            value_info = {}
+        if not isinstance(value_info, dict):
+            raise ValueError('Please pass value_info as a '
+                             'name -> (type, shape) dictionary')
+
+        cls._ssa_rewrite(predict_net, init_net, value_info)
+
+        if init_net:
+            initializer = cls.caffe2_init_net_to_initializer(init_net)
+            value_info.update({init.name: (init.data_type, init.dims)
+                               for init in initializer})
+        else:
+            initializer = []
+
+        # Check whether we have got type shape info of all input
+        missing = (set(list(predict_net.external_input)) -
+                   set(value_info.keys()))
+        if missing:
+            raise RuntimeError('Could not find value info of inputs: {}'.format(
+                ', '.join(missing)))
+
+        # Random data of the declared type and shape for every external input.
+        inputs = {}
+        for name in predict_net.external_input:
+            elem_type, shape = value_info[name]
+            inputs[name] = np.random.randn(*shape).astype(
+                mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
+
+        ws, outputs = c2_native_run_net(
+            init_net,
+            predict_net,
+            inputs)
+
+        # Output types and shapes are taken from the actual run.
+        for name in predict_net.external_output:
+            output = outputs[name]
+            elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output.dtype]
+            shape = output.shape
+            value_info[name] = (elem_type, shape)
+
+        graph_def = GraphProto()
+        graph_def.name = predict_net.name
+        graph_def.initializer.extend(initializer)
+        # This is a mapping from Caffe2 names to ONNX names
+        graph_def.input.extend(
+            make_tensor_value_info(
+                name=name,
+                elem_type=value_info[name][0],
+                shape=value_info[name][1])
+            for name in predict_net.external_input)
+
+        # Re-seed the dummy-name generator so generated names cannot clash
+        # with any name already used by either net.
+        dummy_name(cls._all_names_in_net(predict_net) |
+                   cls._all_names_in_net(init_net))
+
+        for op in predict_net.op:
+            shapes = {}
+            for name in itertools.chain(op.input, op.output):
+                blob = ws.FetchBlob(name)
+                if hasattr(blob, 'shape'):
+                    shapes[name] = blob.shape
+            graph_def.node.extend(
+                cls.caffe2_op_to_onnx_node(
+                    op, shapes=shapes))
+
+        # Every name produced by a node or provided by an initializer.
+        all_output = {init.name for init in graph_def.initializer}
+        for node in graph_def.node:
+            all_output.update(node.output)
+        # graph_def.output is still empty here, so the requested outputs have
+        # to be read from the net itself; comparing against graph_def.output
+        # meant this warning could never fire.
+        redundant_output = set(predict_net.external_output) - all_output
+        if redundant_output:
+            logger.warning(
+                'There are graph output not produced by any node or initializer: {}'
+                '! Will drop them.'.format(', '.join(sorted(redundant_output))))
+        graph_def.output.extend(
+            make_tensor_value_info(
+                name=name,
+                elem_type=value_info[name][0],
+                shape=value_info[name][1])
+            for name in predict_net.external_output
+            if name in all_output)
+
+        cls._annotate_consumed(graph_def)
+        checker.check_graph(graph_def)
+        return graph_def
+
+ @classmethod
+ def caffe2_init_net_to_initializer(cls, init_net):
+ # Turn every GivenTensor*Fill op of `init_net` into a tensor (built with
+ # make_tensor) named after the op's first output. Raises RuntimeError
+ # for any other operator type.
+ initializer = []
+ for op in init_net.op:
+ assert not op.input
+ try:
+ # Fill op type -> (ONNX data type, field of the 'values' arg
+ # that holds the data).
+ data_type, field_name = {
+ 'GivenTensorFill': (TensorProto.FLOAT, 'floats'),
+ 'GivenTensorInt64Fill': (TensorProto.INT64, 'ints'),
+ 'GivenTensorIntFill': (TensorProto.INT32, 'ints'),
+ 'GivenTensorBoolFill': (TensorProto.BOOL, 'ints'),
+ 'GivenTensorStringFill': (TensorProto.STRING, 'strings'),
+ }[op.type]
+ except KeyError:
+ raise RuntimeError(
+ "Can not translate init_net with operator '{}' "
+ "to initializer".format(op.type)
+ )
+ # Everything except strings is stored as raw bytes of the matching
+ # numpy dtype.
+ raw = (data_type != TensorProto.STRING)
+ args = {a.name: a for a in op.arg}
+ vals = getattr(args['values'], field_name)
+ if raw:
+ vals = np.asarray(
+ vals,
+ dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[data_type]).tobytes()
+ initializer.append(make_tensor(
+ name=op.output[0],
+ data_type=data_type,
+ dims=args['shape'].ints,
+ vals=vals,
+ raw=raw,
+ ))
+ return initializer
+
+ @classmethod
+ def _annotate_consumed(cls, graph_def):
+ # Add a 'consumed_inputs' attribute (one 0/1 flag per input) to every
+ # node with at least one input that its op schema marks CONSUME_ENFORCED.
+ for node in graph_def.node:
+ schema = defs.get_schema(node.op_type)
+ consumes = []
+ for i, _input_name in enumerate(node.input):
+ # output_idx is not needed here; only the use type matters.
+ consume_type, output_idx = schema.consumed(i)
+ if consume_type == defs.OpSchema.UseType.CONSUME_ENFORCED:
+ consumes.append(1)
+ else:
+ consumes.append(0)
+
+ # Nodes without any consumed input get no attribute at all.
+ if any(consumes):
+ node.attribute.extend([helper.make_attribute(
+ 'consumed_inputs',
+ consumes,
+ )])
+
+ @classmethod
+ def _ssa_rewrite(cls, net, init_net, value_info):
+ # Rename every blob in `net` (and `init_net`, and the keys of
+ # `value_info`) to '<name>_<version>', all in place. Versions for ops
+ # in `net` come from caffe2_core.get_ssa; init-net blobs, value_info
+ # keys and external inputs always get version 0.
+ def ssa_name(name, version):
+ return '{}_{}'.format(name, version)
+
+ if init_net:
+ for op in init_net.op:
+ # Only single-output GivenTensor*Fill ops are accepted here.
+ assert re.match('GivenTensor.*Fill', op.type)
+ assert len(op.output) == 1
+ op.output[0] = ssa_name(op.output[0], 0)
+ init_net.external_input[:] = [ssa_name(name, 0)
+ for name in init_net.external_input]
+ init_net.external_output[:] = [ssa_name(name, 0)
+ for name in init_net.external_output]
+ if value_info:
+ # Rebuild the dict contents in place so the caller's object sees
+ # the renamed keys.
+ ssa_value_info = {ssa_name(name, 0): value
+ for name, value in value_info.items()}
+ value_info.clear()
+ value_info.update(ssa_value_info)
+ net.external_input[:] = [ssa_name(name, 0)
+ for name in net.external_input]
+ ssa, blob_versions = caffe2_core.get_ssa(net)
+ assert len(net.op) == len(ssa)
+ for op, (versioned_inputs, versioned_outputs) in zip(net.op, ssa):
+ op.input[:] = [ssa_name(name, version)
+ for name, version in versioned_inputs]
+ op.output[:] = [ssa_name(name, version)
+ for name, version in versioned_outputs]
+ # External outputs use the version recorded in blob_versions.
+ net.external_output[:] = [ssa_name(name, blob_versions[name])
+ for name in net.external_output]
+
+    @classmethod
+    def caffe2_net_to_onnx_model(cls, *args, **kwargs):
+        """Convert a Caffe2 net to a checked ONNX model.
+
+        Arguments are forwarded to caffe2_net_to_onnx_graph. The resulting
+        model imports the default ('') operator set at
+        '_target_opset_version'.
+        """
+        graph = cls.caffe2_net_to_onnx_graph(*args, **kwargs)
+        model = make_model(graph)
+
+        default_opset = OperatorSetIdProto()
+        default_opset.domain = ''  # ONNX
+        default_opset.version = cls._target_opset_version
+        model.opset_import.extend([default_opset])
+
+        checker.check_model(model)
+        return model
+
+
+# Module-level aliases for the Caffe2Frontend entry points, so they can be
+# called as plain functions of this module.
+caffe2_net_to_onnx_graph = Caffe2Frontend.caffe2_net_to_onnx_graph
+caffe2_net_to_onnx_model = Caffe2Frontend.caffe2_net_to_onnx_model
+caffe2_init_net_to_initializer = Caffe2Frontend.caffe2_init_net_to_initializer
diff --git a/caffe2/python/onnx/helper.py b/caffe2/python/onnx/helper.py
new file mode 100644
index 0000000000..7d95619e9a
--- /dev/null
+++ b/caffe2/python/onnx/helper.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.helper
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.proto import caffe2_pb2
+from onnx import helper
+from onnx.backend.base import namedtupledict
+
+from caffe2.python.onnx.workspace import Workspace
+
+import io
+import logging
+import time
+
+
+log = logging.getLogger(__name__)
+
+
class _DummyNameFactory(object):
    """Hands out ``OC2_DUMMY_<n>`` names that avoid a set of used names.

    The state lives on the class, so it is shared by every caller.
    """
    used_names = set()
    counter = 0

    @classmethod
    def dummy_name(cls, used_names=None):
        """Return a fresh dummy name, or reset the factory.

        Called without an argument, returns the next ``OC2_DUMMY_<n>`` name
        that has not been handed out or registered as used.  Called with an
        iterable of names, resets the counter, replaces the set of used
        names with the given ones and returns ``None``.
        """
        if used_names is not None:
            # Reset mode: forget previous names and start counting from 0.
            cls.used_names.clear()
            cls.used_names.update(used_names)
            cls.counter = 0
            return None

        candidate = 'OC2_DUMMY_{}'.format(cls.counter)
        cls.counter += 1
        while candidate in cls.used_names:
            candidate = 'OC2_DUMMY_{}'.format(cls.counter)
            cls.counter += 1
        cls.used_names.add(candidate)
        return candidate


dummy_name = _DummyNameFactory.dummy_name
+
+
def make_model(graph, **kwargs):
    """Build an ONNX model from ``graph`` via ``onnx.helper.make_model``.

    ``producer_name`` defaults to ``'onnx-caffe2'`` unless the caller
    supplies one; all other keyword arguments are forwarded unchanged.
    """
    if 'producer_name' not in kwargs:
        kwargs['producer_name'] = 'onnx-caffe2'
    return helper.make_model(graph=graph, **kwargs)
+
+
def c2_native_run_op(op_def, inputs):
    """Run a single Caffe2 operator in a fresh workspace.

    Args:
        op_def: the operator definition to run.
        inputs: either a dict mapping blob name to value, or a sequence of
            values matched positionally against ``op_def.input``.

    Returns:
        A ``(workspace, outputs)`` pair, where ``outputs`` is a
        namedtupledict keyed by the names in ``op_def.output``.
    """
    ws = Workspace()

    if isinstance(inputs, dict):
        named_inputs = inputs.items()
    else:
        assert(len(op_def.input) == len(inputs))
        named_inputs = zip(op_def.input, inputs)
    for blob_name, blob_value in named_inputs:
        ws.FeedBlob(blob_name, blob_value, op_def.device_option)

    ws.RunOperatorOnce(op_def)

    output_names = op_def.output
    fetched = [ws.FetchBlob(blob_name) for blob_name in output_names]
    outputs_type = namedtupledict('Outputs', output_names)
    return ws, outputs_type(*fetched)
+
+
def c2_native_run_net(init_net, predict_net, inputs):
    """Run ``predict_net`` (after an optional ``init_net``) in a fresh workspace.

    Args:
        init_net: optional net run once before feeding the inputs.
        predict_net: the net to execute.
        inputs: either a dict mapping blob name to value, or a sequence of
            values.  A sequence is matched against the external inputs that
            the workspace does not hold yet when the counts agree;
            otherwise it is matched against the first ``len(inputs)``
            external inputs.

    Returns:
        A ``(workspace, outputs)`` pair, where ``outputs`` is a
        namedtupledict keyed by ``predict_net.external_output``.
    """
    ws = Workspace()
    if init_net:
        ws.RunNetOnce(init_net)

    def feed(blob_name, blob_value):
        ws.FeedBlob(blob_name, blob_value, predict_net.device_option)

    if isinstance(inputs, dict):
        for blob_name, blob_value in inputs.items():
            feed(blob_name, blob_value)
    else:
        missing = [blob_name
                   for blob_name in predict_net.external_input
                   if not ws.HasBlob(blob_name)]
        if len(missing) == len(inputs):
            for blob_name, blob_value in zip(missing, inputs):
                feed(blob_name, blob_value)
        else:
            # The counts do not line up with the missing blobs (for example
            # everything is already initialized), so feed the first
            # len(inputs) external inputs instead.
            assert(len(inputs) <= len(predict_net.external_input))
            for position, blob_value in enumerate(inputs):
                feed(predict_net.external_input[position], blob_value)

    ws.RunNetOnce(predict_net)

    output_names = predict_net.external_output
    fetched = [ws.FetchBlob(blob_name) for blob_name in output_names]
    outputs_type = namedtupledict('Outputs', output_names)
    return ws, outputs_type(*fetched)
+
+
def load_caffe2_net(file):
    """Read a serialized ``NetDef`` from the path ``file`` and return it."""
    net = caffe2_pb2.NetDef()
    with open(file, "rb") as stream:
        serialized = stream.read()
        net.ParseFromString(serialized)
    return net
+
+
def save_caffe2_net(net, file, output_txt=False):
    """Write ``net`` in serialized binary form to the path ``file``.

    When ``output_txt`` is true, the ``str()`` form of the net is also
    written to ``file + "txt"`` (so ``foo.pb`` gets a ``foo.pbtxt``
    sibling).
    """
    with open(file, "wb") as binary_out:
        binary_out.write(net.SerializeToString())

    if not output_txt:
        return

    with open(file + "txt", "w") as text_out:
        text_out.write(str(net))
+
+
def benchmark_caffe2_model(init_net, predict_net, warmup_iters=3, main_iters=10, layer_details=True):
    '''
    Benchmark ``predict_net`` in a fresh workspace.

    ``init_net`` (if given) is run once first, then the predict net is
    created and handed to the workspace's ``BenchmarkNet``.
    Return the execution time per iteration (millisecond), i.e. the first
    entry of the benchmark results.
    '''
    workspace = Workspace()
    if init_net:
        workspace.RunNetOnce(init_net)
    workspace.CreateNet(predict_net)
    timings = workspace.BenchmarkNet(
        predict_net.name, warmup_iters, main_iters, layer_details)
    # Drop the workspace explicitly before returning.
    del workspace
    return timings[0]
+
+
def benchmark_pytorch_model(model, inputs, training=False, warmup_iters=3,
                            main_iters=10, verbose=False):
    '''
    Run the model several times, and measure the execution time.
    Return the execution time per iteration (millisecond).

    ``model`` is called as ``model(*inputs)``: first ``warmup_iters`` times
    untimed, then ``main_iters`` times with each call timed individually.
    ``main_iters`` must be positive.  ``training`` and ``verbose`` are
    accepted for interface compatibility and are not used here.
    '''
    # Imported here so the function does not rely on a new module-level
    # import.  default_timer is the best available clock for measuring
    # short intervals; time.time() can be too coarse and report 0 elapsed.
    from timeit import default_timer

    for _ in range(warmup_iters):
        model(*inputs)

    total_pytorch_time = 0.0
    for _ in range(main_iters):
        started = default_timer()
        model(*inputs)
        total_pytorch_time += default_timer() - started

    ms_per_iter = total_pytorch_time * 1000 / main_iters
    # A very fast model on a coarse clock can still measure 0 seconds in
    # total; report an infinite rate instead of raising ZeroDivisionError.
    if total_pytorch_time > 0:
        iters_per_second = main_iters / total_pytorch_time
    else:
        iters_per_second = float('inf')

    log.info("The PyTorch model execution time per iter is {} milliseconds, "
             "{} iters per second.".format(ms_per_iter, iters_per_second))
    return ms_per_iter
diff --git a/caffe2/python/onnx/tests/caffe2_ref_test.py b/caffe2/python/onnx/tests/caffe2_ref_test.py
new file mode 100644
index 0000000000..d6bc396451
--- /dev/null
+++ b/caffe2/python/onnx/tests/caffe2_ref_test.py
@@ -0,0 +1,357 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.tests.caffe2_ref_test
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import json
+import os
+import unittest
+
+from caffe2.python import core
+from caffe2.proto import caffe2_pb2
+
+import onnx
+from onnx.helper import make_node, make_graph, make_tensor, make_tensor_value_info
+from caffe2.python.onnx.helper import make_model, c2_native_run_net, c2_native_run_op
+
+from onnx import defs, mapping
+import caffe2.python.onnx.frontend as c2_onnx
+import caffe2.python.onnx.backend as c2
+
+import numpy as np
+from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory
+
+from caffe2.python.onnx.helper import dummy_name
+from caffe2.python.onnx.tests.test_utils import TestCase
+
+
class TestCaffe2Basic(TestCase):
    """Small ONNX nodes and graphs executed through the Caffe2 backend."""

    def test_dummy_name(self):
        # Two consecutive calls must hand out different names.
        first = dummy_name()
        second = dummy_name()
        assert first != second, "Got same names in different calls: {}".format(first)

    def test_relu_node_inplace(self):
        X = np.random.randn(3, 2).astype(np.float32)
        expected = np.clip(X, 0, np.inf)

        # Run the node on its own.
        relu = make_node(
            "Relu", ["X"], ["Y"], consumed_inputs=[1])
        result = c2.run_node(relu, {"X": X})
        np.testing.assert_almost_equal(result.X, expected)

        # Run the same node wrapped in a graph whose output is named "X".
        relu = make_node(
            "Relu", ["X"], ["Y"], consumed_inputs=[1])
        graph = make_graph(
            [relu],
            name="test",
            inputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])],
            outputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])])
        rep = c2.prepare(make_model(graph))
        result = rep.run({"X": X})
        np.testing.assert_almost_equal(result.X, expected)

    def test_relu_graph(self):
        X = np.random.randn(3, 2).astype(np.float32)
        expected = np.clip(X, 0, np.inf)

        relu = make_node(
            "Relu", ["X"], ["Y"])
        result = c2.run_node(relu, {"X": X})
        np.testing.assert_almost_equal(result.Y, expected)

        graph = make_graph(
            [relu],
            name="test",
            inputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])],
            outputs=[make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2])])
        rep = c2.prepare(make_model(graph))
        result = rep.run(X)
        np.testing.assert_almost_equal(result.Y, expected)

    def test_initializer(self):
        X = np.array([[1, 2], [3, 4]]).astype(np.float32)
        Y = np.array([[1, 2], [3, 4]]).astype(np.float32)
        weight = np.array([[1, 0], [0, 1]])

        nodes = [
            make_node("Add", ["X", "Y"], ["Z0"]),
            make_node("Cast", ["Z0"], ["Z"], to="float"),
            make_node("Mul", ["Z", "weight"], ["W0"]),
            make_node("Tanh", ["W0"], ["W1"]),
            make_node("Sigmoid", ["W1"], ["W2"]),
            make_node("Scale", ["W2"], ["W3"], scale=-1.0),
        ]
        graph = make_graph(
            nodes,
            name="test_initializer",
            inputs=[
                make_tensor_value_info("X", onnx.TensorProto.FLOAT, (2, 2)),
                make_tensor_value_info("Y", onnx.TensorProto.FLOAT, (2, 2)),
                make_tensor_value_info("weight", onnx.TensorProto.FLOAT, (2, 2)),
            ],
            outputs=[
                make_tensor_value_info("W3", onnx.TensorProto.FLOAT, (2, 2))
            ],
            # "weight" is provided by the graph itself, not by the caller.
            initializer=[make_tensor("weight",
                                     onnx.TensorProto.FLOAT,
                                     [2, 2],
                                     weight.flatten().astype(float))]
        )

        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        expected = -sigmoid(np.tanh((X + Y) * weight))
        rep = c2.prepare(make_model(graph))
        result = rep.run({"X": X, "Y": Y})
        np.testing.assert_almost_equal(result["W3"], expected)

    def test_gemm(self):
        def run_gemm(operands, **attrs):
            # Build a Gemm node with the given attributes and return "Y".
            node_def = make_node(
                'Gemm',
                ['A', 'B', 'C'],
                ["Y"],
                **attrs)
            return c2.run_node(node_def, operands)["Y"]

        # simple
        A = np.random.randn(3, 2).astype(np.float32)
        B = np.random.randn(2, 4).astype(np.float32)
        C = np.random.randn(3, 4).astype(np.float32)
        np.testing.assert_almost_equal(
            run_gemm([A, B, C]),
            np.dot(A, B) + C)

        # transA: feed A transposed and let Gemm transpose it back
        A_t = np.transpose(A)
        np.testing.assert_almost_equal(
            run_gemm([A_t, B, C], transA=True),
            np.dot(np.transpose(A_t), B) + C)

        # transB: feed B transposed and let Gemm transpose it back
        B_t = np.transpose(B)
        np.testing.assert_almost_equal(
            run_gemm([A, B_t, C], transB=True),
            np.dot(A, np.transpose(B_t)) + C)

        # scale
        alpha = np.random.random()
        beta = np.random.random()
        np.testing.assert_almost_equal(
            run_gemm([A, B, C], alpha=alpha, beta=beta),
            alpha * np.dot(A, B) + beta * C)

        # broadcast
        C = np.random.randn(4).astype(np.float32)
        np.testing.assert_almost_equal(
            run_gemm([A, B, C], alpha=alpha, beta=beta, broadcast=1),
            alpha * np.dot(A, B) + beta * C)

    def test_tensor_filling_ops(self):
        dtypes = [
            onnx.TensorProto.FLOAT,
            onnx.TensorProto.DOUBLE,
            onnx.TensorProto.BOOL,
            onnx.TensorProto.INT8,
            onnx.TensorProto.INT16,
            onnx.TensorProto.INT32,
            onnx.TensorProto.INT64,
            onnx.TensorProto.UINT8,
            onnx.TensorProto.UINT16,
            onnx.TensorProto.UINT32,
        ]
        for dtype in dtypes:
            shape = (1, 2, 3)
            vals = np.random.randn(*shape)
            if dtype != onnx.TensorProto.BOOL:
                vals *= 5
            vals = vals.astype(
                mapping.TENSOR_TYPE_TO_NP_TYPE[dtype])
            tensor = make_tensor(
                name='test-tensor-{}'.format(dtype),
                data_type=dtype,
                dims=[1, 2, 3],
                vals=vals.flatten().tolist(),
            )
            fill_op = c2.Caffe2Backend._create_tensor_filling_op(tensor)
            self.assertEqual(len(fill_op.input), 0)
            self.assertEqual(fill_op.output, [tensor.name])
            ws, output = c2_native_run_op(fill_op, inputs=[])
            self.assertEqual(len(output), 1)
            np.testing.assert_almost_equal(output[0], vals)
            np.testing.assert_almost_equal(ws.FetchBlob(fill_op.output[0]), vals)

    def test_slice(self):
        X = np.random.randn(1, 2, 3).astype(np.float32)
        starts = np.array([0, 1, 0], dtype=np.int32)
        ends = np.array([-1, 2, 3], dtype=np.int32)

        slice_op = core.CreateOperator(
            'Slice',
            inputs=['X'],
            outputs=['Y'],
            starts=starts,
            ends=ends,
        )
        predict_net = caffe2_pb2.NetDef()
        predict_net.name = 'test-slice-net'
        predict_net.external_input[:] = ['X']
        predict_net.external_output[:] = ['Y']
        predict_net.op.extend([slice_op])

        # Run the net natively first; its single output is unpacked here and
        # then superseded by the result of the converted model below.
        ws, (Y,) = c2_native_run_net(
            init_net=None,
            predict_net=predict_net,
            inputs=[X])

        onnx_model = c2_onnx.caffe2_net_to_onnx_model(
            predict_net=predict_net,
            value_info={
                'X': (onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[X.dtype], X.shape)
            })
        Y, = c2.run_model(onnx_model, inputs=[X])
        np.testing.assert_almost_equal(Y, X[:, 1:2, :])
+
+
class TestCaffe2End2End(TestCase):
    """Convert downloaded Caffe2 models to ONNX and compare the outputs.

    Each test runs a model natively in Caffe2, converts it to ONNX, runs
    the converted model through the Caffe2 ONNX backend and checks that
    both runs produce the same outputs.
    """

    def _model_dir(self, model):
        """Return the directory holding the files of ``model``.

        The location honours the ``ONNX_MODELS`` environment variable and
        otherwise falls back to ``<ONNX_HOME or ~/.caffe2>/models``.
        """
        caffe2_home = os.path.expanduser(os.getenv('ONNX_HOME', '~/.caffe2'))
        models_dir = os.getenv('ONNX_MODELS', os.path.join(caffe2_home, 'models'))
        return os.path.join(models_dir, model)

    def _test_net(self,
                  net_name,
                  input_blob_dims=(1, 3, 224, 224),
                  decimal=7):
        """Check that ``net_name`` gives the same outputs before and after conversion.

        Args:
            net_name: model name; the model is downloaded if its directory
                does not exist yet.
            input_blob_dims: ``(n, c, h, w)`` of the random input tensor.
            decimal: precision passed to ``assertSameOutputs``.
        """
        np.random.seed(seed=0)
        model_dir = self._model_dir(net_name)
        if not os.path.exists(model_dir):
            self._download(net_name)

        c2_predict_pb = os.path.join(model_dir, 'predict_net.pb')
        c2_predict_net = caffe2_pb2.NetDef()
        with open(c2_predict_pb, 'rb') as f:
            c2_predict_net.ParseFromString(f.read())
        c2_predict_net.name = net_name

        c2_init_pb = os.path.join(model_dir, 'init_net.pb')
        c2_init_net = caffe2_pb2.NetDef()
        with open(c2_init_pb, 'rb') as f:
            c2_init_net.ParseFromString(f.read())
        c2_init_net.name = net_name + '_init'

        n, c, h, w = input_blob_dims
        data = np.random.randn(n, c, h, w).astype(np.float32)
        inputs = [data]
        _, c2_outputs = c2_native_run_net(c2_init_net, c2_predict_net, inputs)
        # Drop the reference to the native workspace before running the
        # converted model.
        del _

        # Use a context manager so the JSON file handle is closed promptly
        # instead of being left to the garbage collector.
        with open(os.path.join(model_dir, 'value_info.json')) as f:
            value_info = json.load(f)

        model = c2_onnx.caffe2_net_to_onnx_model(
            predict_net=c2_predict_net,
            init_net=c2_init_net,
            value_info=value_info)
        c2_ir = c2.prepare(model)
        onnx_outputs = c2_ir.run(inputs)
        self.assertSameOutputs(c2_outputs, onnx_outputs, decimal=decimal)

    def _download(self, model):
        """Download the three files of ``model`` into its model directory.

        On any failure the partially filled directory is removed and the
        original exception is re-raised, so only the current test fails and
        its traceback is preserved.
        """
        model_dir = self._model_dir(model)
        assert not os.path.exists(model_dir)
        os.makedirs(model_dir)
        for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest,
                                          show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print("Abort: {reason}".format(reason=e))
                print("Cleaning up...")
                deleteDirectory(model_dir)
                # Re-raise rather than calling the site-provided exit(),
                # which would terminate the whole test process.
                raise

    def test_alexnet(self):
        self._test_net('bvlc_alexnet', decimal=4)

    def test_resnet50(self):
        self._test_net('resnet50')

    @unittest.skipIf(
        os.environ.get('JENKINS_URL'),
        'Taking too long to download!')
    def test_vgg16(self):
        self._test_net('vgg16')

    @unittest.skipIf(
        os.environ.get('JENKINS_URL'),
        'Running vgg19 on Travis with Python 2 keeps getting OOM!')
    def test_vgg19(self):
        self._test_net('vgg19')

    def test_inception_v1(self):
        self._test_net('inception_v1', decimal=2)

    def test_inception_v2(self):
        self._test_net('inception_v2')

    @unittest.skip('Need to add support for ConstantFill operator')
    def test_squeezenet(self):
        self._test_net('squeezenet')

    def test_shufflenet(self):
        self._test_net('shufflenet')

    def test_densenet121(self):
        self._test_net('densenet121')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/caffe2/python/onnx/tests/conversion_test.py b/caffe2/python/onnx/tests/conversion_test.py
new file mode 100644
index 0000000000..99c7d94848
--- /dev/null
+++ b/caffe2/python/onnx/tests/conversion_test.py
@@ -0,0 +1,244 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.tests.conversion_test
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import tempfile
+import textwrap
+import traceback
+
+from caffe2.proto import caffe2_pb2
+from caffe2.python import brew, core
+from caffe2.python.model_helper import ModelHelper
+from click.testing import CliRunner
+import numpy as np
+from onnx import helper, ModelProto, TensorProto
+from caffe2.python.onnx.helper import make_model, c2_native_run_net
+
+from caffe2.python.onnx.bin.conversion import caffe2_to_onnx, onnx_to_caffe2
+from caffe2.python.onnx.helper import dummy_name
+import caffe2.python.onnx.backend as c2
+from caffe2.python.onnx.tests.test_utils import TestCase
+
+
class TestConversion(TestCase):
    """Tests for the caffe2_to_onnx / onnx_to_caffe2 command line tools."""

    def _run_command(self, cmd, *args, **kwargs):
        """Invoke the click command ``cmd`` and assert that it exited with 0.

        Extra arguments are forwarded to ``CliRunner.invoke``.  The invoke
        result is returned.
        """
        outcome = CliRunner().invoke(cmd, *args, **kwargs)
        failure_message = textwrap.dedent('''
        Command exited with non-zero exit code:
        output: {}
        exception: {}
        exc_info: {}
        '''.format(outcome.output,
                   outcome.exception,
                   traceback.format_exception(*outcome.exc_info)))
        self.assertEqual(outcome.exit_code, 0, failure_message)
        return outcome

    def test_caffe2_to_onnx(self):
        caffe2_net = tempfile.NamedTemporaryFile()
        caffe2_init_net = tempfile.NamedTemporaryFile()
        output = tempfile.NamedTemporaryFile()

        # Predict net: a single Relu.
        model = ModelHelper(name='caffe2-to-onnx-test')
        brew.relu(model, ["X"], "Y")
        caffe2_net.write(model.net.Proto().SerializeToString())
        caffe2_net.flush()

        # Init net: fills "X" with zeros.
        init_model = ModelHelper(name='caffe2-to-onnx-init-test')
        init_model.net.GivenTensorFill([], 'X', shape=[2, 2],
                                       values=np.zeros((2, 2)).flatten().astype(float))
        caffe2_init_net.write(init_model.net.Proto().SerializeToString())
        caffe2_init_net.flush()

        cli_args = [
            caffe2_net.name,
            '--caffe2-init-net', caffe2_init_net.name,
            '--output', output.name,
        ]
        self._run_command(caffe2_to_onnx, cli_args, catch_exceptions=False)

        onnx_model = ModelProto()
        onnx_model.ParseFromString(output.read())
        graph = onnx_model.graph
        self.assertEqual(len(graph.node), 1)
        self.assertEqual(graph.node[0].op_type, 'Relu')
        self.assertEqual(len(graph.initializer), 1)
        self.assertEqual(graph.initializer[0].name, graph.input[0].name)

    def test_caffe2_to_onnx_value_info(self):
        caffe2_net = tempfile.NamedTemporaryFile()
        output = tempfile.NamedTemporaryFile()

        model = ModelHelper(name='caffe2-to-onnx-test')
        brew.relu(model, ["X"], "Y")
        caffe2_net.write(model.net.Proto().SerializeToString())
        caffe2_net.flush()

        # Without an init net or value info the conversion must fail.
        args = [caffe2_net.name, '--output', output.name]
        self.assertRaisesRegexp(Exception,
                                'value info',
                                self._run_command, caffe2_to_onnx, args)

        # With value info for "X" it must succeed.
        value_info_json = json.dumps({
            'X': (TensorProto.FLOAT, (2, 2)),
        })
        args.extend(['--value-info', value_info_json])
        self._run_command(caffe2_to_onnx, args)

        onnx_model = ModelProto()
        onnx_model.ParseFromString(output.read())
        graph = onnx_model.graph
        self.assertEqual(len(graph.node), 1)
        self.assertEqual(graph.node[0].op_type, 'Relu')
        self.assertEqual(len(graph.initializer), 0)

    def test_onnx_to_caffe2(self):
        onnx_model = tempfile.NamedTemporaryFile()
        output = tempfile.NamedTemporaryFile()
        init_net_output = tempfile.NamedTemporaryFile()

        mul_node = helper.make_node(
            "Mul", ["X", "W"], ["Y"])
        graph_inputs = [
            helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3)),
            helper.make_tensor_value_info("W", TensorProto.FLOAT, (3, 2)),
        ]
        graph_outputs = [
            helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 2)),
        ]
        graph_def = helper.make_graph(
            [mul_node],
            "test",
            graph_inputs,
            graph_outputs,
            initializer=[helper.make_tensor("W",
                                            TensorProto.FLOAT,
                                            [3, 2],
                                            np.zeros((3, 2)).flatten().astype(float))])
        model_def = make_model(graph_def, producer_name='onnx-to-caffe2-test')
        onnx_model.write(model_def.SerializeToString())
        onnx_model.flush()

        self._run_command(
            onnx_to_caffe2, [
                onnx_model.name,
                '--output', output.name,
                '--init-net-output', init_net_output.name,
            ])

        caffe2_net = caffe2_pb2.NetDef()
        caffe2_net.ParseFromString(output.read())
        self.assertEqual(len(caffe2_net.op), 1)
        self.assertEqual(caffe2_net.op[0].type, 'Mul')

        caffe2_init_net = caffe2_pb2.NetDef()
        caffe2_init_net.ParseFromString(init_net_output.read())
        self.assertEqual(len(caffe2_init_net.op), 1)
        # The init net must produce exactly the initializer blob "W".
        produced = {blob
                    for init_op in caffe2_init_net.op
                    for blob in init_op.output}
        self.assertEqual(produced, {'W'})

    def test_convert_end2end(self):
        predict_net_f = tempfile.NamedTemporaryFile()
        init_net_f = tempfile.NamedTemporaryFile()
        onnx_model_f = tempfile.NamedTemporaryFile()

        x, w, b, y = 'X', 'W', 'b', 'Y'

        fc_op = core.CreateOperator(
            'FC',
            inputs=[x, w, b],
            outputs=[y],
            axis=2,
        )
        predict_net = caffe2_pb2.NetDef()
        predict_net.name = 'test-convert-end2end'
        predict_net.external_input[:] = [x, w, b]
        predict_net.external_output[:] = [y]
        predict_net.op.extend([fc_op])
        predict_net_f.write(predict_net.SerializeToString())
        predict_net_f.flush()

        init_net = caffe2_pb2.NetDef()
        init_net.name = 'test-convert-end2end-init'
        init_net.external_output[:] = [w, b]
        x_val = np.random.randn(1, 3, 2).astype(np.float32)
        w_val = np.random.randn(4, 2).astype(np.float32)
        b_val = np.random.randn(4).astype(np.float32)
        init_net.op.extend([
            core.CreateOperator(
                'GivenTensorFill',
                [],
                [w],
                values=w_val,
                shape=w_val.shape,
            ),
            core.CreateOperator(
                'GivenTensorFill',
                [],
                [b],
                values=b_val,
                shape=b_val.shape,
            ),
        ])
        init_net_f.write(init_net.SerializeToString())
        init_net_f.flush()

        y_val = np.matmul(x_val, w_val.transpose()) + b_val

        # Round-trip Caffe2 -> ONNX -> Caffe2 several times; every pass
        # starts from the files written by the previous one.
        for _ in range(5):
            value_info_json = json.dumps({
                x: (TensorProto.FLOAT, (1, 3, 2)),
            })
            self._run_command(
                caffe2_to_onnx, [
                    predict_net_f.name,
                    '--caffe2-init-net', init_net_f.name,
                    '--output', onnx_model_f.name,
                    '--value-info', value_info_json,
                ],
                catch_exceptions=False,
            )

            onnx_model_f.seek(0)
            onnx_model = ModelProto()
            onnx_model.ParseFromString(onnx_model_f.read())
            onnx_result = c2.run_model(
                onnx_model, {onnx_model.graph.input[0].name: x_val})
            np.testing.assert_almost_equal(onnx_result, [y_val])

            self._run_command(
                onnx_to_caffe2, [
                    onnx_model_f.name,
                    '--output', predict_net_f.name,
                    '--init-net-output', init_net_f.name,
                ])
            predict_net_f.seek(0)
            predict_net = caffe2_pb2.NetDef()
            predict_net.ParseFromString(predict_net_f.read())
            init_net_f.seek(0)
            init_net = caffe2_pb2.NetDef()
            init_net.ParseFromString(init_net_f.read())
            # The input blob may have been renamed by the conversion; use
            # the new name for the native run and for the next pass.
            x = predict_net.external_input[0]
            native_result = c2_native_run_net(init_net=init_net,
                                              predict_net=predict_net,
                                              inputs={x: x_val})[1]
            np.testing.assert_almost_equal(native_result, [y_val])
diff --git a/caffe2/python/onnx/tests/helper_test.py b/caffe2/python/onnx/tests/helper_test.py
new file mode 100644
index 0000000000..f0ce6498cb
--- /dev/null
+++ b/caffe2/python/onnx/tests/helper_test.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.tests.helper_test
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+
+from caffe2.python.onnx.helper import dummy_name
+
+from caffe2.python.onnx.tests.test_utils import TestCase
+
+
class TestCaffe2Basic(TestCase):
    """Checks the reset and uniqueness behaviour of ``dummy_name``."""

    def test_dummy_name(self):
        def take_three():
            return [dummy_name() for _ in range(3)]

        # Resetting with no used names restarts the same sequence.
        dummy_name([])
        names_1 = take_three()
        dummy_name([])
        names_2 = take_three()
        self.assertEqual(names_1, names_2)

        # Resetting with names marked as used must skip those names.
        dummy_name(names_1)
        names_3 = take_three()
        self.assertFalse(set(names_1) & set(names_3))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py
new file mode 100644
index 0000000000..941a15d5a0
--- /dev/null
+++ b/caffe2/python/onnx/tests/onnx_backend_test.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.tests.onnx_backend_test
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+
+import unittest
+import onnx.backend.test
+
+import caffe2.python.onnx.backend as c2
+
# This is a pytest magic variable to load extra plugins
pytest_plugins = ('onnx.backend.test.report',)

backend_test = onnx.backend.test.BackendTest(c2, __name__)

# Operators the backend does not support yet: Ceil, Floor, Hardsigmoid,
# Pow, Mean and Hardmax.
backend_test.exclude(r'(test_ceil|test_floor'
                     '|test_hardsigmoid|test_pow'
                     '|test_mean|test_hardmax)')

# Skip vgg to speed up CI
if 'JENKINS_URL' in os.environ:
    backend_test.exclude(r'(test_vgg19|test_vgg)')

# import all test cases at global scope to make them visible to python.unittest
reported_cases = backend_test.enable_report().test_cases
globals().update(reported_cases)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/caffe2/python/onnx/tests/optimize_onnx_test.py b/caffe2/python/onnx/tests/optimize_onnx_test.py
new file mode 100644
index 0000000000..6efbfc52c9
--- /dev/null
+++ b/caffe2/python/onnx/tests/optimize_onnx_test.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.tests.optimize_onnx_test
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import tarfile
+import tempfile
+import unittest
+
+from collections import namedtuple
+from subprocess import Popen, PIPE
+from six.moves.urllib.request import urlretrieve
+import numpy as np
+
+import onnx
+from onnx import helper, ModelProto, TensorProto
+from onnx.backend.test.runner import Runner
+import caffe2.python.onnx.backend as c2
+
+from caffe2.python.onnx.tests.test_utils import TestCase
+
class TestRoundtrip(TestCase):
    """Checks that loading and re-serializing an ONNX model is lossless."""

    def _roundtrip(self, model_name):
        """Load ``model_name`` two ways and compare against the file bytes."""
        # The runner only reads ``model_name`` from the test-case object, so
        # a one-field namedtuple is enough to stand in for it.
        fake_case = namedtuple('dummy', ['model_name'])(model_name)
        model_dir = Runner(c2)._prepare_model_data(fake_case)

        pb_path = os.path.join(model_dir, 'model.pb')

        loaded_from_path = onnx.load(pb_path)

        with open(pb_path, 'rb') as pb:
            loaded_from_bytes = onnx.load_from_string(pb.read())

        graph_before = onnx.helper.printable_graph(loaded_from_path.graph)
        graph_after = onnx.helper.printable_graph(loaded_from_bytes.graph)
        assert graph_before == graph_after

        with open(pb_path, 'rb') as pb:
            assert loaded_from_bytes.SerializeToString() == pb.read()

    # arbitrarily pick one relatively small model to sanity test with
    def test_squeezenet_v3(self):
        self._roundtrip('squeezenet-ir-version-3')

    # testing just to be sure that we no-op instead of breaking on an
    # older IR version.
    def test_squeezenet_v1(self):
        self._roundtrip('squeezenet-ir-version-1')
+
class TestOptimize(TestCase):
    """Tests for the ONNX graph optimizations exposed by ``Caffe2Backend``."""

    def _optimized(self, graph):
        """Wrap ``graph`` in a model, optimize it and return the parsed result."""
        original = helper.make_model(graph, producer_name='onnx-to-caffe2-test')
        optimized_bytes = c2.Caffe2Backend.optimize_onnx(
            original.SerializeToString())
        optimized = ModelProto()
        optimized.ParseFromString(optimized_bytes)
        return optimized

    def test_nop_transpose(self):
        # An identity permutation must be removed entirely.
        trans = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 1])
        graph = helper.make_graph(
            [trans],
            "test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3))],
            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (3, 2))])
        optimized_model = self._optimized(graph)

        for node in optimized_model.graph.node:
            assert node.op_type != "Transpose"

    def test_fuse_transpose(self):
        # Three chained transposes must collapse into a single node.
        chain = [
            helper.make_node("Transpose", ["X"], ["Y"], perm=[1, 0, 2]),
            helper.make_node("Transpose", ["Y"], ["Z"], perm=[2, 0, 1]),
            helper.make_node("Transpose", ["Z"], ["A"], perm=[2, 0, 1]),
        ]
        graph = helper.make_graph(
            chain,
            "test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4))],
            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (4, 3, 2))])
        optimized_model = self._optimized(graph)

        assert len(list(optimized_model.graph.node)) == 1

    def test_fuse_transpose_into_gemm(self):
        # Transposes feeding a Gemm must be folded into the Gemm node.
        trans_a = helper.make_node("Transpose", ["X"], ["A"], perm=[1, 0])
        trans_b = helper.make_node("Transpose", ["Y"], ["B"], perm=[1, 0])
        gemm = helper.make_node("Gemm", ["A", "B", "C"], ["Z"])
        graph = helper.make_graph(
            [trans_a, trans_b, gemm],
            "test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3)),
             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5, 2)),
             helper.make_tensor_value_info("C", TensorProto.FLOAT, (3, 5))],
            [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (3, 5))])
        optimized_model = self._optimized(graph)

        assert len(list(optimized_model.graph.node)) == 1
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/caffe2/python/onnx/tests/ssa_test.py b/caffe2/python/onnx/tests/ssa_test.py
new file mode 100644
index 0000000000..851482e494
--- /dev/null
+++ b/caffe2/python/onnx/tests/ssa_test.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.tests.ssa_test
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import onnx
+import numpy as np
+from caffe2.proto import caffe2_pb2
+from caffe2.python import core
+from onnx import helper, TensorProto
+
+import caffe2.python.onnx.frontend as c2_onnx
+from caffe2.python.onnx.helper import c2_native_run_net
+from caffe2.python.onnx.tests.test_utils import TestCase
+
+
+class TestFrontendSSAConversion(TestCase):
+    """Exercises Caffe2Frontend._ssa_rewrite on a small FC + Add network."""
+
+    def test_ssa(self):
+        """Rewrite a net that writes blob 'Y' twice; check names and outputs."""
+        X = np.random.randn(4, 2).astype(np.float32)
+        W = np.random.randn(3, 2).astype(np.float32)
+        b = np.random.randn(3).astype(np.float32)
+        s = np.random.randn(1).astype(np.float32)
+        # Reference value computed directly with NumPy.
+        np_result = X.dot(W.transpose()) + b + s
+
+        net = caffe2_pb2.NetDef()
+        net.name = 'test-ssa'
+        net.external_input[:] = ['W', 'X', 'b', 's']
+        fc_op = core.CreateOperator('FC', ['X', 'W', 'b'], ['Y'])
+        # The Add writes back into 'Y', the blob the FC already produced.
+        add_op = core.CreateOperator('Add', ['Y', 's'], ['Y'], broadcast=True)
+        net.op.extend([fc_op, add_op])
+        net.external_output[:] = ['Y']
+
+        init_net = caffe2_pb2.NetDef()
+        init_net.name = 'test-ssa-init'
+        init_net.op.extend([
+            core.CreateOperator(
+                'GivenTensorFill',
+                [],
+                [blob_name],
+                values=blob_value,
+                shape=blob_value.shape,
+            )
+            for blob_name, blob_value in (('W', W), ('b', b), ('s', s))
+        ])
+        init_net.external_output[:] = ['W', 'b', 's']
+
+        # Run once before the rewrite to capture the baseline output.
+        _, orig_output = c2_native_run_net(
+            predict_net=net, init_net=init_net, inputs=[X])
+
+        value_info = {'X': (TensorProto.FLOAT, X.shape)}
+        c2_onnx.Caffe2Frontend._ssa_rewrite(net, init_net, value_info)
+
+        # Predict net: each blob name now carries a version suffix.
+        self.assertEqual(net.external_input, ['W_0', 'X_0', 'b_0', 's_0'])
+        first_op, second_op = net.op[0], net.op[1]
+        self.assertEqual(first_op.input, ['X_0', 'W_0', 'b_0'])
+        self.assertEqual(first_op.output, ['Y_1'])
+        self.assertEqual(second_op.input, ['Y_1', 's_0'])
+        self.assertEqual(second_op.output, ['Y_2'])
+        self.assertEqual(net.external_output, ['Y_2'])
+
+        # Init net: the fill ops take no inputs and emit the renamed blobs.
+        self.assertEqual(init_net.external_input, [])
+        for index, renamed in enumerate(['W_0', 'b_0', 's_0']):
+            self.assertEqual(init_net.op[index].input, [])
+            self.assertEqual(init_net.op[index].output, [renamed])
+        self.assertEqual(init_net.external_output, ['W_0', 'b_0', 's_0'])
+        self.assertEqual(value_info, {'X_0': (TensorProto.FLOAT, X.shape)})
+
+        # The rewritten nets must still compute the same result.
+        _, ssa_output = c2_native_run_net(
+            predict_net=net, init_net=init_net, inputs=[X])
+
+        self.assertSameOutputs(ssa_output, orig_output)
+        self.assertSameOutputs(ssa_output, [np_result])
diff --git a/caffe2/python/onnx/tests/test_utils.py b/caffe2/python/onnx/tests/test_utils.py
new file mode 100644
index 0000000000..8d8a9ceeca
--- /dev/null
+++ b/caffe2/python/onnx/tests/test_utils.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.tests.test_utils
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+
+import numpy as np
+
+
+class TestCase(unittest.TestCase):
+    """Base test case that seeds NumPy and adds array-comparison helpers."""
+
+    def setUp(self):
+        """Seed NumPy's global RNG with 0 before each test."""
+        np.random.seed(seed=0)
+
+    def assertSameOutputs(self, outputs1, outputs2, decimal=7):
+        """Assert that two sequences of outputs match pairwise.
+
+        Args:
+            outputs1: First sequence of array-like values.
+            outputs2: Second sequence; must have the same length as outputs1.
+            decimal: Precision forwarded to np.testing.assert_almost_equal.
+        """
+        self.assertEqual(len(outputs1), len(outputs2))
+        for o1, o2 in zip(outputs1, outputs2):
+            np.testing.assert_almost_equal(o1, o2, decimal=decimal)
+
+    @classmethod
+    def add_test_case(cls, name, test_func):
+        """Attach test_func to this class as a test method called name.
+
+        This used to be declared without a receiver while its body referred
+        to ``self``, so it could never succeed. As a classmethod it works
+        whether it is called on the class or on an instance.
+
+        Args:
+            name: Attribute name for the new test; must start with 'test_'.
+            test_func: Callable to install under that name.
+
+        Raises:
+            ValueError: If name lacks the 'test_' prefix, or if the class
+                already has an attribute called name.
+        """
+        if not name.startswith('test_'):
+            raise ValueError('Test name must start with test_: {}'.format(name))
+        if hasattr(cls, name):
+            raise ValueError('Duplicated test name: {}'.format(name))
+        setattr(cls, name, test_func)
diff --git a/caffe2/python/onnx/workspace.py b/caffe2/python/onnx/workspace.py
new file mode 100644
index 0000000000..a115c91861
--- /dev/null
+++ b/caffe2/python/onnx/workspace.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+## @package onnx
+# Module caffe2.python.onnx.workspace
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import uuid
+
+from caffe2.python import workspace
+
+
+class Workspace(object):
+    """
+    An object representing a Caffe2 workspace.  It is a context manager,
+    so you can say 'with workspace:' to use the represented workspace
+    as your global workspace.  It also supports every method supported
+    by caffe2.python.workspace, but instead of running these operations
+    in the global workspace, it runs them in the workspace represented
+    by this object.  When this object goes dead, the workspace (and all
+    nets and blobs within it) are freed.
+
+    Why do we need this class?  Caffe2's workspace model is very "global state"
+    oriented, in that there is always some ambient global workspace you are
+    working in which holds on to all of your networks and blobs.  This class
+    makes it possible to work with workspaces more locally, and without
+    forgetting to deallocate everything in the end.
+    """
+    def __init__(self):
+        # Caffe2 (apparently) doesn't provide any native method of generating
+        # a fresh, unused workspace, so we have to fake it by generating
+        # a unique ID and hoping it's not used already / will not be used
+        # directly in the future.
+        self.workspace_id = str(uuid.uuid4())
+        # A stack, so that the context manager is reentrant.
+        self.workspace_stack = []
+
+    def __getattr__(self, attr):
+        """Return a wrapper that runs workspace.<attr> inside this workspace.
+
+        Raises:
+            AttributeError: If caffe2.python.workspace has no such attribute.
+        """
+        # Resolve the target now rather than inside the wrapper, so that an
+        # unknown name fails at lookup time (and hasattr() reports it
+        # correctly) instead of only when the wrapper is finally called.
+        target = getattr(workspace, attr)
+
+        def f(*args, **kwargs):
+            with self:
+                return target(*args, **kwargs)
+        return f
+
+    def __enter__(self):
+        """Switch to this workspace, remembering the one that was active.
+
+        Returns:
+            This object, so that 'with Workspace() as ws:' binds ws to it.
+        """
+        self.workspace_stack.append(workspace.CurrentWorkspace())
+        workspace.SwitchWorkspace(self.workspace_id, create_if_missing=True)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Switch back to the workspace that was active before __enter__."""
+        w = self.workspace_stack.pop()
+        # Strictly speaking, create_if_missing here is unnecessary, since a user
+        # is not supposed to be allowed to destruct a workspace while we're in
+        # it.  However, empirically, it has been observed that during abnormal
+        # shutdown, Caffe2 deletes its default workspace fairly early in the
+        # final calls to destructors.  In this case, we may attempt to exit
+        # to a default workspace which no longer exists.  create_if_missing=True
+        # will (harmlessly) recreate the workspace before we finally quit.
+        workspace.SwitchWorkspace(w, create_if_missing=True)
+
+    def __del__(self):
+        # NB: This is a 'self' call because we need to switch into the workspace
+        # we want to reset before we actually reset it.  A direct call to
+        # workspace.ResetWorkspace() will reset the ambient workspace, which
+        # is not what we want.
+        self.ResetWorkspace()
diff --git a/docker/jenkins/common/install_python.sh b/docker/jenkins/common/install_python.sh
index 5d01999e82..d262d7ed88 100755
--- a/docker/jenkins/common/install_python.sh
+++ b/docker/jenkins/common/install_python.sh
@@ -146,3 +146,4 @@ pip install --no-cache-dir \
scikit-image \
tabulate \
virtualenv
+