summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/middle/UpsampleToResample.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/middle/UpsampleToResample.py')
-rw-r--r--model-optimizer/extensions/middle/UpsampleToResample.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/model-optimizer/extensions/middle/UpsampleToResample.py b/model-optimizer/extensions/middle/UpsampleToResample.py
index 77227535e..98bbe2d11 100644
--- a/model-optimizer/extensions/middle/UpsampleToResample.py
+++ b/model-optimizer/extensions/middle/UpsampleToResample.py
@@ -22,6 +22,7 @@ import numpy as np
from extensions.ops.elementwise import Mul
from extensions.ops.interpolate import Interpolate
+from mo.front.common.layout import get_height_dim, get_width_dim
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
@@ -33,7 +34,6 @@ from mo.ops.strided_slice import StridedSlice
class UpsampleToResample(MiddleReplacementPattern):
enabled = True
force_clean_up = True
- graph_condition = [lambda graph: graph.graph['fw'] == 'onnx']
def run_after(self):
from extensions.middle.pass_separator import MiddleStart
@@ -54,6 +54,7 @@ class UpsampleToResample(MiddleReplacementPattern):
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
log.debug('UpsampleToResample is triggered')
upsample = match['upsample']
+ input_shape = upsample.in_port(0).data.get_shape()
if len(upsample.in_nodes()) == 2:
if upsample.in_node(1).value is None:
@@ -79,13 +80,15 @@ class UpsampleToResample(MiddleReplacementPattern):
shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()
- begin = Const(graph, {'value': np.array([2])}).create_node()
- end = Const(graph, {'value': np.array([4])}).create_node()
- stride = Const(graph, {'value': np.array([1])}).create_node()
+ begin = Const(graph, {'value': int64_array([get_height_dim(graph.graph['layout'],
+ len(input_shape))])}).create_node()
+ end = Const(graph, {'value': int64_array([get_width_dim(graph.graph['layout'],
+ len(input_shape)) + 1])}).create_node()
+ stride = Const(graph, {'value': int64_array([1])}).create_node()
ss = StridedSlice(graph, {'name': upsample.name + '/ss_0_port', 'begin_mask': np.array([1]),
'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
- 'shrink_axis_mask': np.array([0]),
- 'ellipsis_mask': np.array([0])}).create_node()
+ 'shrink_axis_mask': int64_array([0]),
+ 'ellipsis_mask': int64_array([0])}).create_node()
mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node()
@@ -99,7 +102,8 @@ class UpsampleToResample(MiddleReplacementPattern):
factor.out_port(0).connect(mul.in_port(1))
# Create Interpolate operation
- axes = int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2])
+ axes = int64_array([get_height_dim(graph.graph['layout'], len(input_shape)),
+ get_width_dim(graph.graph['layout'], len(input_shape))])
resample_op = Interpolate(graph, dict(name='Interpolate/{}'.format(upsample.name),
factor=factor_value, axes=axes,
mode=upsample.attrs()['mode'],