diff options
Diffstat (limited to 'model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py')
-rw-r--r-- | model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py b/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py index 152e39ba8..2d987d5c7 100644 --- a/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py +++ b/model-optimizer/extensions/front/mxnet/ssd_pattern_remove_reshape.py @@ -15,13 +15,13 @@ """ import networkx as nx -from mo.front.mxnet.extractors.utils import get_json_layer_attrs + from mo.front.common.replacement import FrontReplacementSubgraph +from mo.front.mxnet.extractors.utils import get_json_layer_attrs from mo.middle.passes.eliminate import remove_node_from_graph class SsdPatternRemoveReshape(FrontReplacementSubgraph): - enabled = True def pattern(self): @@ -34,9 +34,8 @@ class SsdPatternRemoveReshape(FrontReplacementSubgraph): edges=[ ('multi_box_prior', 'concat', {'in': 0}), ('concat', 'reshape', {'in': 0}) - ], - node_attrs=['op'], - edge_attrs=['in']) + ] + ) def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict): """ @@ -56,4 +55,5 @@ class SsdPatternRemoveReshape(FrontReplacementSubgraph): concat_node = match['concat'] attr = get_json_layer_attrs(concat_node.graph.node[concat_node.id]['symbol_dict']) if 'dim' in attr: - attr['dim'] = 2
\ No newline at end of file + attr['dim'] = 2 + concat_node['axis'] = 2 |