diff options
Diffstat (limited to 'model-optimizer/extensions/middle/UselessMerge.py')
-rw-r--r-- | model-optimizer/extensions/middle/UselessMerge.py | 13 |
1 files changed, 5 insertions, 8 deletions
diff --git a/model-optimizer/extensions/middle/UselessMerge.py b/model-optimizer/extensions/middle/UselessMerge.py index 7908bdf82..b0923bcd5 100644 --- a/model-optimizer/extensions/middle/UselessMerge.py +++ b/model-optimizer/extensions/middle/UselessMerge.py @@ -19,7 +19,7 @@ import logging as log import networkx as nx from extensions.middle.ConstSwitchResolver import ConstSwitchEraser -from mo.graph.graph import erase_node +from mo.middle.passes.eliminate import remove_op_node_with_data_node from mo.middle.replacement import MiddleReplacementPattern @@ -31,14 +31,11 @@ class UselessMergeEraser(MiddleReplacementPattern): def pattern(self): return dict( - nodes=[('merge', dict(kind='op', op='Merge')), - ('merge_data', dict(kind='data'))], - edges=[('merge', 'merge_data')] + nodes=[('merge', dict(kind='op', op='Merge'))], + edges=[] ) def replace_pattern(self, graph: nx.MultiDiGraph, match: dict): if len(graph.in_edges(match['merge'].id)) <= 1: - erase_node(match['merge']) - erase_node(match['merge_data']) - log.info("Useles Merge op and data nodes was deleted op='{}' data='{}'" - "".format(match['merge'].id, match['merge_data'].id)) + remove_op_node_with_data_node(graph, match['merge']) + log.info("Useles Merge op and data nodes was deleted op='{}'".format(match['merge'].id)) |