diff options
Diffstat (limited to 'model-optimizer/extensions/front/tf/deconv_ext.py')
-rw-r--r-- | model-optimizer/extensions/front/tf/deconv_ext.py | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/model-optimizer/extensions/front/tf/deconv_ext.py b/model-optimizer/extensions/front/tf/deconv_ext.py new file mode 100644 index 000000000..67de85650 --- /dev/null +++ b/model-optimizer/extensions/front/tf/deconv_ext.py @@ -0,0 +1,81 @@ +""" + Copyright (c) 2018 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import numpy as np + +from mo.front.common.partial_infer.utils import convert_tf_padding_to_str, int64_array +from mo.front.extractor import FrontExtractorOp +from mo.front.tf.extractors.utils import tf_data_format_spatial, tf_data_format_channel, tf_data_format_batch, \ + tf_int_list +from mo.ops.deconvolution import Deconvolution +from mo.ops.op import PermuteAttrs + + +class Conv2DBackpropInputFrontExtractor(FrontExtractorOp): + op = 'Conv2DBackpropInput' + enabled = True + + @staticmethod + def extract(node): + attrs = tf_create_attrs(node, 3, 2) + attrs.update({'op': __class__.op, + 'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]), + inv=int64_array([2, 3, 1, 0])) + }) + + # update the attributes of the node + Deconvolution.update_node_stat(node, attrs) + return __class__.enabled + + +class Conv3DBackpropInputV2InputFrontExtractor(FrontExtractorOp): + op = 'Conv3DBackpropInputV2' + enabled = True + + @staticmethod + def extract(node): + attrs = tf_create_attrs(node, 4, 3) + attrs.update({'op': __class__.op, + 'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([3, 4, 0, 1, 2]), + inv=int64_array([2, 3, 4, 0, 1])) + }) + + # update the attributes of the node + Deconvolution.update_node_stat(node, attrs) + return __class__.enabled + + +def tf_create_attrs(node, input_feature_channel, output_feature_channel): + data_format = node.pb.attr["data_format"] + + return { + 'auto_pad': convert_tf_padding_to_str(node.pb.attr['padding']), + 'bias_addable': True, + 'bias_term': False, + 'spatial_dims': tf_data_format_spatial(data_format), + 'channel_dims': tf_data_format_channel(data_format), + 'batch_dims': tf_data_format_batch(data_format), + 'pad': None, # will be inferred when input shape is known + 'pad_spatial_shape': None, + 'output_spatial_shape': None, + 'output_shape': None, + 'output': None, + 'stride': tf_int_list(node.pb.attr["strides"].list), + 'type': None, # don't set type until we are sure it is really translated to correct IR; see infer function + 'group': None, + 'layout': data_format.s.decode(), + 'input_feature_channel': input_feature_channel, + 'output_feature_channel': output_feature_channel, + } |