summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/middle/UselessMerge.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/middle/UselessMerge.py')
-rw-r--r--model-optimizer/extensions/middle/UselessMerge.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/model-optimizer/extensions/middle/UselessMerge.py b/model-optimizer/extensions/middle/UselessMerge.py
index 7908bdf82..b0923bcd5 100644
--- a/model-optimizer/extensions/middle/UselessMerge.py
+++ b/model-optimizer/extensions/middle/UselessMerge.py
@@ -19,7 +19,7 @@ import logging as log
import networkx as nx
from extensions.middle.ConstSwitchResolver import ConstSwitchEraser
-from mo.graph.graph import erase_node
+from mo.middle.passes.eliminate import remove_op_node_with_data_node
from mo.middle.replacement import MiddleReplacementPattern
@@ -31,14 +31,11 @@ class UselessMergeEraser(MiddleReplacementPattern):
def pattern(self):
return dict(
- nodes=[('merge', dict(kind='op', op='Merge')),
- ('merge_data', dict(kind='data'))],
- edges=[('merge', 'merge_data')]
+ nodes=[('merge', dict(kind='op', op='Merge'))],
+ edges=[]
)
def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
if len(graph.in_edges(match['merge'].id)) <= 1:
- erase_node(match['merge'])
- erase_node(match['merge_data'])
- log.info("Useles Merge op and data nodes was deleted op='{}' data='{}'"
- "".format(match['merge'].id, match['merge_data'].id))
+ remove_op_node_with_data_node(graph, match['merge'])
+ log.info("Useles Merge op and data nodes was deleted op='{}'".format(match['merge'].id))