diff options
Diffstat (limited to 'model-optimizer/extensions/back/EltwiseBroadcast.py')
-rw-r--r-- | model-optimizer/extensions/back/EltwiseBroadcast.py | 42 |
1 files changed, 26 insertions, 16 deletions
diff --git a/model-optimizer/extensions/back/EltwiseBroadcast.py b/model-optimizer/extensions/back/EltwiseBroadcast.py index 78713ba6f..a75974a0e 100644 --- a/model-optimizer/extensions/back/EltwiseBroadcast.py +++ b/model-optimizer/extensions/back/EltwiseBroadcast.py @@ -14,40 +14,51 @@ limitations under the License. """ -from mo.ops.tile import Tile +import logging as log + import networkx as nx +import numpy as np + from mo.back.replacement import BackReplacementPattern from mo.graph.graph import unique_id, Node - -import logging as log +from mo.ops.tile import Tile class EltwiseBroadcast(BackReplacementPattern): enabled = True - def pattern(self): + @staticmethod + def pattern(): return dict( nodes=[ ('op', dict(kind='op', type='Eltwise'))], - edges=[], - node_attrs=['kind', 'type'], - edge_attrs=[]) + edges=[] + ) - def replace_pattern(self, graph: nx.MultiDiGraph, match: dict): + @staticmethod + def replace_pattern(graph: nx.MultiDiGraph, match: dict): node = match['op'] shapes = [in_node.shape for _, in_node in node.in_nodes().items()] out_shape = node.out_node().shape tname = node.name + '/Broadcast/' tile = Tile(graph, dict(name=tname)) + + # Working with scalar values + for i, shape in enumerate(shapes): + if len(shape) == 0: + shapes[i] = np.ones(len(out_shape), dtype=np.int64) + node.in_node(i).shape = shapes[i].copy() + if node.in_node(i).value is not None: + node.in_node(i).value = np.reshape(node.in_node(i).value, newshape=shapes[i]) + if not all([len(shape) == len(out_shape) for shape in shapes]): log.warning("Cannot apply broadcast for Eltwise layer {} " - "because not all input shapes {} have the same number of elements " - "as output shape {}.".format( - node.soft_get('name'), - shapes, - out_shape - ) - ) + "because not all input shapes {} have the same number of elements " + "as output shape {}.".format(node.soft_get('name'), + shapes, + out_shape + ) + ) return input_idx = 0 @@ -71,4 +82,3 @@ class EltwiseBroadcast(BackReplacementPattern): graph.add_edge(input.id, node.id, **graph[old_input.id][node.id][0]) graph.remove_edge(old_input.id, node.id) input_idx += 1 - |