summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py')
-rw-r--r--model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py b/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py
index 152e39ba8..2d987d5c7 100644
--- a/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py
+++ b/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py
@@ -15,13 +15,13 @@
"""
import networkx as nx
-from mo.front.mxnet.extractors.utils import get_json_layer_attrs
+
from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.mxnet.extractors.utils import get_json_layer_attrs
from mo.middle.passes.eliminate import remove_node_from_graph
class SsdPatternRemoveReshape(FrontReplacementSubgraph):
-
enabled = True
def pattern(self):
@@ -34,9 +34,8 @@ class SsdPatternRemoveReshape(FrontReplacementSubgraph):
edges=[
('multi_box_prior', 'concat', {'in': 0}),
('concat', 'reshape', {'in': 0})
- ],
- node_attrs=['op'],
- edge_attrs=['in'])
+ ]
+ )
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
"""
@@ -56,4 +55,5 @@ class SsdPatternRemoveReshape(FrontReplacementSubgraph):
concat_node = match['concat']
attr = get_json_layer_attrs(concat_node.graph.node[concat_node.id]['symbol_dict'])
if 'dim' in attr:
- attr['dim'] = 2 \ No newline at end of file
+ attr['dim'] = 2
+ concat_node['axis'] = 2