Diffstat (limited to 'model-optimizer/extensions/middle/Reduce.py')
-rw-r--r-- | model-optimizer/extensions/middle/Reduce.py | 107
1 file changed, 68 insertions, 39 deletions
diff --git a/model-optimizer/extensions/middle/Reduce.py b/model-optimizer/extensions/middle/Reduce.py
index 8d8537e84..6c6c91d27 100644
--- a/model-optimizer/extensions/middle/Reduce.py
+++ b/model-optimizer/extensions/middle/Reduce.py
@@ -19,9 +19,11 @@ import logging as log
 
 import networkx as nx
 import numpy as np
 
+from mo.front.caffe.extractors.utils import get_canonical_axis_index
 from mo.front.common.layout import get_batch_dim, get_features_dim
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.pooling import Pooling
+from mo.ops.power import Power
 from mo.ops.reshape import Reshape
@@ -29,7 +31,13 @@ class ReduceReplacer(MiddleReplacementPattern):
     op = "Reduce"
     enabled = True
 
-    supported_reduce_types = ['mean']
+    supported_reduce_types = ['mean', 'max', 'sum']
+
+    pool_method_map = {
+        'max': 'max',
+        'mean': 'avg',
+        'sum': 'avg'
+    }
 
     def pattern(self):
         return dict(
@@ -41,19 +49,24 @@ class ReduceReplacer(MiddleReplacementPattern):
     def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
         node = match['reduce']
 
-        if not node.has_valid('reduce_type') and node.reduce_type.lower() not in self.supported_reduce_types:
+        if not node.has_valid('reduce_type') or node.reduce_type.lower() not in self.supported_reduce_types:
             log.error("Reduce type {} is not supported for node {}".format(node.soft_get('reduce_type'), node.id))
             return
 
+        reduce_type = node.reduce_type.lower()
+        if reduce_type not in self.pool_method_map:
+            log.error("Reduce type {} is not included in pool_method_map. Please update pool_method_map with new key "
+                      "{}".format(reduce_type, reduce_type))
+            return
+
+        input_data = node.in_node()
+        output_data = node.out_node()
+
         input_shape = node.in_node().shape
         output_shape = node.out_node().shape
-        ndim = len(input_shape)
 
-        # Currently only NCHW layout is supported
-        layout = graph.graph['layout']
-        if layout != 'NCHW':
-            log.error('{} layout currently is not supported'.format(layout))
-            return
+        # normalize node.axis to exclude negative indices
+        node.axis = [get_canonical_axis_index(input_shape, a) for a in node.axis]
 
         axis = node.axis
@@ -63,42 +76,58 @@ class ReduceReplacer(MiddleReplacementPattern):
             log.error("Reduce with not consecutive axes {} is not supported ".format(axis))
             return
 
+        layout = graph.graph['layout']
+
         # So now we are sure that we can convert Reduce to appropriate operation
-        if node.reduce_type.lower() == 'mean':
-            # 1. Calculate shape that will be used in reduction
-            reduction_dim = np.prod([input_shape[idx] for idx in axis])
-            begin_dims = np.array([input_shape[idx] for idx in range(axis[0])])
-            end_dim = np.prod([input_shape[idx] for idx in range(axis[-1] + 1, len(input_shape))])
-            # 2. Create reshape with appropriate shape
+        # 1. Calculate shape that will be used in reduction
+        reduction_dim = np.prod([input_shape[idx] for idx in axis])
+        begin_dims = np.array([input_shape[idx] for idx in range(axis[0])])
+        end_dim = np.prod([input_shape[idx] for idx in range(axis[-1] + 1, len(input_shape))])
+
+        # 2. Create reshape with appropriate shape
+        if layout == 'NCHW':
             if len(begin_dims) > 2:
-                begin_dims = np.array([np.prod(begin_dims[0:-1], begin_dims[-1])], dtype=np.int64)
+                begin_dims = np.array([np.prod(begin_dims[0:-1]), begin_dims[-1]], dtype=np.int64)
             else:
                 # Expand begin_dims to 2
                 begin_dims = np.array(np.append(begin_dims, [1] * (2 - len(begin_dims))), dtype=np.int64)
-
             reshape_shape = np.array([*begin_dims, reduction_dim, end_dim], dtype=np.int64)
+            pool_window = np.array([1, 1, reduction_dim, 1], dtype=np.int64)
+        elif layout == 'NHWC':
+            begin_dims = np.prod(begin_dims)
+            reshape_shape = np.array([begin_dims, reduction_dim, 1, end_dim], dtype=np.int64)
+            pool_window = np.array([1, reduction_dim, 1, 1], dtype=np.int64)
+        else:
+            log.error('{} layout currently is not supported'.format(layout))
+            return
 
-            # 3. Reduce => Reshape->Pooling->Reshape
-            reshape_op = Reshape(graph, {'name': node.id + '/Reshape', 'dim': reshape_shape})
-            final_reshape_op = Reshape(graph, {'name': node.id + '/FinalReshape', 'dim': output_shape})
-            pooling_op = Pooling(graph,
-                                 dict(name=node.id + '/Pool', window=np.array([1, 1, reduction_dim, 1], dtype=np.int64),
-                                      output_spatial_shape=None,
-                                      batch_dims=np.array([get_batch_dim(layout, 4)], dtype=np.int64),
-                                      channel_dims=np.array([get_features_dim(layout, 4)], dtype=np.int64),
-                                      exclude_pad='false', pool_method='avg'))
-
-            input_data = node.in_node()
-            output_data = node.out_node()
-
-            graph.remove_edge(input_data.id, node.id)
-            graph.remove_edge(node.id, output_data.id)
-
-            final_reshape_op.create_node_with_data(
-                inputs=[pooling_op.create_node_with_data(
-                    inputs=[reshape_op.create_node_with_data(
-                        inputs=[input_data]
-                    )]
-                )],
-                data_nodes=output_data)
+        # 3. Reduce => Reshape->Pooling->Reshape
+        reshape_op = Reshape(graph, {'name': node.id + '/Reshape', 'dim': reshape_shape})
+        final_reshape_op = Reshape(graph, {'name': node.id + '/FinalReshape', 'dim': output_shape})
+        pooling_op = Pooling(graph,
+                             dict(name=node.id + '/Pool',
+                                  window=pool_window,
+                                  output_spatial_shape=None,
+                                  batch_dims=np.array([get_batch_dim(layout, 4)], dtype=np.int64),
+                                  channel_dims=np.array([get_features_dim(layout, 4)], dtype=np.int64),
+                                  exclude_pad='false', pool_method=self.pool_method_map[reduce_type]))
+
+        graph.remove_edge(input_data.id, node.id)
+        graph.remove_edge(node.id, output_data.id)
+
+        final_reshape_op.create_node_with_data(
+            inputs=[pooling_op.create_node_with_data(
+                inputs=[reshape_op.create_node_with_data(
+                    inputs=[input_data]
+                )]
+            )],
+            data_nodes=output_data)
+
+        # 4. If it is reduction with summation, we need to multiply by size of the reduction slice with Mul op
+        if reduce_type == 'sum':
+            output_data.in_node().insert_node_with_data_after(
                output_data,
                Power,
                {'name': node.name + '/Mul', 'scale': float(reduction_dim)}
            )
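For context on the new import: get_canonical_axis_index folds a possibly negative axis index into the [0, ndim) range, which is what lets the pass handle axes such as -1 before checking that they are consecutive. A minimal sketch of the behaviour this relies on (illustrative only; the actual helper lives in mo/front/caffe/extractors/utils.py and its exact body may differ):

    def canonical_axis_index(shape, axis):
        # Fold a negative axis into [0, len(shape)): -1 -> last dimension, etc.
        # Assumed equivalent to get_canonical_axis_index; shown for illustration.
        return len(shape) + axis if axis < 0 else axis

    assert canonical_axis_index((2, 3, 4, 5), -1) == 3
    assert canonical_axis_index((2, 3, 4, 5), 2) == 2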
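The rewrite rests on one observation: a reduction over consecutive axes can be collapsed into a single dimension by a Reshape, performed by a Pooling whose window spans exactly that dimension, and followed by a final Reshape to the expected output shape. A plain-numpy sketch of that equivalence, using the same begin/reduction/end shape arithmetic as the patch (reduce_via_reshape_pool is a made-up name, not Model Optimizer API):

    import numpy as np

    def reduce_via_reshape_pool(data, axis, op=np.mean):
        # `axis` must be sorted, consecutive and non-negative, mirroring the
        # checks ReduceReplacer performs before converting the node.
        shape = data.shape
        reduction_dim = int(np.prod([shape[i] for i in axis]))
        begin_dims = [int(shape[i]) for i in range(axis[0])]
        end_dim = int(np.prod([shape[i] for i in range(axis[-1] + 1, len(shape))]))

        # Collapse the reduced span into one dimension, "pool" across it
        # (the role of the 1 x reduction_dim window), then restore the shape.
        reshaped = data.reshape([*begin_dims, reduction_dim, end_dim])
        pooled = op(reshaped, axis=len(begin_dims))
        out_shape = [d for i, d in enumerate(shape) if i not in axis]
        return pooled.reshape(out_shape)

    x = np.random.rand(2, 3, 4, 5)
    assert np.allclose(reduce_via_reshape_pool(x, [1, 2]), x.mean(axis=(1, 2)))
    assert np.allclose(reduce_via_reshape_pool(x, [1, 2], np.max), x.max(axis=(1, 2)))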
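Step 4 exists because pool_method_map lowers 'sum' to 'avg' pooling, which must then be compensated: the mean of a window of N elements times N is the sum, hence the Power layer appended with scale=reduction_dim (assuming Power with only 'scale' set, power and shift at their defaults, acts as y = scale * x). In numpy terms:

    import numpy as np

    x = np.random.rand(2, 12, 5)   # axis 1 stands in for the collapsed reduction_dim
    n = x.shape[1]
    avg = x.mean(axis=1)           # what the 'avg' pooling emits for reduce_type == 'sum'
    assert np.allclose(avg * n, x.sum(axis=1))   # the Power(scale=n) fix-up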