diff options
Diffstat (limited to 'model-optimizer/extensions/middle/UpsampleToResample.py')
-rw-r--r-- | model-optimizer/extensions/middle/UpsampleToResample.py | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/model-optimizer/extensions/middle/UpsampleToResample.py b/model-optimizer/extensions/middle/UpsampleToResample.py index 77227535e..98bbe2d11 100644 --- a/model-optimizer/extensions/middle/UpsampleToResample.py +++ b/model-optimizer/extensions/middle/UpsampleToResample.py @@ -22,6 +22,7 @@ import numpy as np from extensions.ops.elementwise import Mul from extensions.ops.interpolate import Interpolate +from mo.front.common.layout import get_height_dim, get_width_dim from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Graph, Node from mo.middle.replacement import MiddleReplacementPattern @@ -33,7 +34,6 @@ from mo.ops.strided_slice import StridedSlice class UpsampleToResample(MiddleReplacementPattern): enabled = True force_clean_up = True - graph_condition = [lambda graph: graph.graph['fw'] == 'onnx'] def run_after(self): from extensions.middle.pass_separator import MiddleStart @@ -54,6 +54,7 @@ class UpsampleToResample(MiddleReplacementPattern): def replace_pattern(self, graph: Graph, match: Dict[str, Node]): log.debug('UpsampleToResample is triggered') upsample = match['upsample'] + input_shape = upsample.in_port(0).data.get_shape() if len(upsample.in_nodes()) == 2: if upsample.in_node(1).value is None: @@ -79,13 +80,15 @@ class UpsampleToResample(MiddleReplacementPattern): shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node() - begin = Const(graph, {'value': np.array([2])}).create_node() - end = Const(graph, {'value': np.array([4])}).create_node() - stride = Const(graph, {'value': np.array([1])}).create_node() + begin = Const(graph, {'value': int64_array([get_height_dim(graph.graph['layout'], + len(input_shape))])}).create_node() + end = Const(graph, {'value': int64_array([get_width_dim(graph.graph['layout'], + len(input_shape)) + 1])}).create_node() + stride = Const(graph, {'value': int64_array([1])}).create_node() ss = StridedSlice(graph, {'name': upsample.name + '/ss_0_port', 'begin_mask': np.array([1]), 'end_mask': np.array([0]), 'new_axis_mask': np.array([0]), - 'shrink_axis_mask': np.array([0]), - 'ellipsis_mask': np.array([0])}).create_node() + 'shrink_axis_mask': int64_array([0]), + 'ellipsis_mask': int64_array([0])}).create_node() mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node() @@ -99,7 +102,8 @@ class UpsampleToResample(MiddleReplacementPattern): factor.out_port(0).connect(mul.in_port(1)) # Create Interpolate operation - axes = int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2]) + axes = int64_array([get_height_dim(graph.graph['layout'], len(input_shape)), + get_width_dim(graph.graph['layout'], len(input_shape))]) resample_op = Interpolate(graph, dict(name='Interpolate/{}'.format(upsample.name), factor=factor_value, axes=axes, mode=upsample.attrs()['mode'], |