summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/back/EltwiseBroadcast.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/back/EltwiseBroadcast.py')
-rw-r--r--model-optimizer/extensions/back/EltwiseBroadcast.py42
1 files changed, 26 insertions, 16 deletions
diff --git a/model-optimizer/extensions/back/EltwiseBroadcast.py b/model-optimizer/extensions/back/EltwiseBroadcast.py
index 78713ba6f..a75974a0e 100644
--- a/model-optimizer/extensions/back/EltwiseBroadcast.py
+++ b/model-optimizer/extensions/back/EltwiseBroadcast.py
@@ -14,40 +14,51 @@
limitations under the License.
"""
-from mo.ops.tile import Tile
+import logging as log
+
import networkx as nx
+import numpy as np
+
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import unique_id, Node
-
-import logging as log
+from mo.ops.tile import Tile
class EltwiseBroadcast(BackReplacementPattern):
enabled = True
- def pattern(self):
+ @staticmethod
+ def pattern():
return dict(
nodes=[
('op', dict(kind='op', type='Eltwise'))],
- edges=[],
- node_attrs=['kind', 'type'],
- edge_attrs=[])
+ edges=[]
+ )
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ @staticmethod
+ def replace_pattern(graph: nx.MultiDiGraph, match: dict):
node = match['op']
shapes = [in_node.shape for _, in_node in node.in_nodes().items()]
out_shape = node.out_node().shape
tname = node.name + '/Broadcast/'
tile = Tile(graph, dict(name=tname))
+
+ # Working with scalar values
+ for i, shape in enumerate(shapes):
+ if len(shape) == 0:
+ shapes[i] = np.ones(len(out_shape), dtype=np.int64)
+ node.in_node(i).shape = shapes[i].copy()
+ if node.in_node(i).value is not None:
+ node.in_node(i).value = np.reshape(node.in_node(i).value, newshape=shapes[i])
+
if not all([len(shape) == len(out_shape) for shape in shapes]):
log.warning("Cannot apply broadcast for Eltwise layer {} "
- "because not all input shapes {} have the same number of elements "
- "as output shape {}.".format(
- node.soft_get('name'),
- shapes,
- out_shape
- )
- )
+ "because not all input shapes {} have the same number of elements "
+ "as output shape {}.".format(node.soft_get('name'),
+ shapes,
+ out_shape
+ )
+ )
return
input_idx = 0
@@ -71,4 +82,3 @@ class EltwiseBroadcast(BackReplacementPattern):
graph.add_edge(input.id, node.id, **graph[old_input.id][node.id][0])
graph.remove_edge(old_input.id, node.id)
input_idx += 1
-