summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/middle/lstm_sequence_tensor_iterator.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/middle/lstm_sequence_tensor_iterator.py')
-rw-r--r--model-optimizer/extensions/middle/lstm_sequence_tensor_iterator.py151
1 files changed, 151 insertions, 0 deletions
diff --git a/model-optimizer/extensions/middle/lstm_sequence_tensor_iterator.py b/model-optimizer/extensions/middle/lstm_sequence_tensor_iterator.py
new file mode 100644
index 000000000..55dd79aaf
--- /dev/null
+++ b/model-optimizer/extensions/middle/lstm_sequence_tensor_iterator.py
@@ -0,0 +1,151 @@
+"""
+ Copyright (c) 2018 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import networkx as nx
+import numpy as np
+from copy import deepcopy
+
+from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
+from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.op import Op
+from mo.ops.permute import Permute
+from mo.ops.reshape import Reshape
+from extensions.ops.lstm_cell import LSTMCell
+from extensions.ops.tensor_iterator import TensorIterator
+from extensions.middle.FusePermutesSequence import FusePermutesSequence
+
+
+class LSTMSequenceTensorIterator(MiddleReplacementPattern):
+    ''' Converts a normalized LSTMSequence op to TensorIterator.
+
+    A "normalized" LSTMSequence is one that has already been processed by
+    the LSTMSequenceNormalize transform, which ensures its strict form.
+
+    This transformation builds an alternative sub-graph for LSTMSequence
+    with TensorIterator connected in the same way as the original
+    LSTMSequence node, and with an internal body represented as an
+    LSTMCell op node with the necessary squeeze/unsqueeze Reshapes
+    around it.
+    '''
+
+    enabled = True
+
+
+    def run_after(self):
+        # Depends on the strict LSTMSequence form that normalization produces.
+        return [LSTMSequenceNormalize]
+
+
+    def run_before(self):
+        # Permutes around the resulting TensorIterator must still be visible
+        # to the permute-fusing pass, so run before it.
+        return [FusePermutesSequence]
+
+
+    def pattern(self):
+        # Match an LSTMSequence op with its mandatory data nodes only:
+        # input (in port 0), weights (in port 1), biases (in port 2) and the
+        # first output (out port 0). Optional state inputs/outputs are
+        # handled explicitly in replace_pattern instead of being matched.
+        return dict(
+            nodes=[
+                ('lstm', dict(kind='op', op='LSTMSequence')),
+                ('input', dict(kind='data')),
+                ('weights', dict(kind='data')),
+                ('biases', dict(kind='data')),
+                # don't capture optional input initial states here
+                ('output', dict(kind='data')),
+                # don't capture optional output last states here
+            ],
+            edges=[
+                ('input', 'lstm', {'in': 0}),
+                ('weights', 'lstm', {'bin': 'weights', 'in': 1}),
+                ('biases', 'lstm', {'bin': 'biases', 'in': 2}),
+                ('lstm', 'output', {'out': 0}),
+            ]
+        )
+
+
+    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+        lstm = match['lstm']
+
+        # Build TensorIterator body first
+        body = nx.MultiDiGraph(name=lstm.name + '/sub_graph')
+        # Clone the LSTM input data nodes into the body graph. Order of
+        # in-ports is [0, 3, 4, 1, 2]: data input, then (per the edge_attrs
+        # passed to create_node_with_data below) two state inputs, weights
+        # and biases. Values are copied only when present (constants).
+        inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp), {'shape': lstm.in_node(inp).shape.copy(), 'value': lstm.in_node(inp).value.copy() if lstm.in_node(inp).value is not None else None}) for inp in [0, 3, 4, 1, 2]]
+        # Inside the body each iteration sees a single slice along the
+        # sequence dimension, so that dimension becomes 1 ...
+        inputs[0].shape[lstm.sequence_dim] = 1
+        # ... and is then squeezed away entirely before entering LSTMCell.
+        input_squeeze = Reshape(body, dict(name=lstm.name + '/input_squeeze', dim=np.delete(inputs[0].shape, lstm.sequence_dim), internal_layer_id=0))
+        inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])
+        lstm_cell_op = LSTMCell(body, dict(hidden_size=match['lstm'].hidden_size, name=lstm.name + '/LSTMCell', internal_layer_id=1))
+        # Body output data nodes for out ports 0 and 1. Out port 1 (second
+        # state) may be absent on the original node, in which case its shape
+        # is borrowed from input port 3 — presumably the matching initial
+        # state; TODO confirm against LSTMSequenceNormalize.
+        outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out), {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes() else lstm.in_node(3).shape.copy()}) for out in [0,1] ]
+        # Mirror of the input squeeze: LSTMCell produces a squeezed tensor,
+        # which is unsqueezed back to a length-1 sequence dimension so the
+        # TensorIterator can concatenate per-iteration results along it.
+        unsqueezed_output_shape = outputs[0].shape.copy()
+        unsqueezed_output_shape[lstm.sequence_dim] = 1
+        squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
+        outputs[0].shape = squeezed_output_shape
+        # NOTE(review): name lacks the '/' separator used by '/input_squeeze'
+        # above ("...output_unsqueeze" is concatenated directly) — cosmetic
+        # naming inconsistency; confirm whether intended.
+        output_unsqueeze = Reshape(body, dict(name=lstm.name + 'output_unsqueeze', dim=unsqueezed_output_shape, internal_layer_id=2))
+        # TODO edge attributes should be assigned by the op itself
+        lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs, edge_attrs=[{}, {'internal_port_id': 1}, {'internal_port_id': 2}, {'bin': 'weights'}, {'bin': 'biases'}])
+        # Tag LSTMCell's two outgoing edges with internal ports 4 and 5;
+        # these feed the back_edges declared below (recurrent state loop).
+        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
+        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
+        # Re-point output 0 through the unsqueeze Reshape; its edge becomes
+        # internal port 3, the body's sequence output.
+        lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0]])
+        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
+
+        ti_op = TensorIterator(graph, {
+            'name': lstm.name + '/TensorIterator',
+            'body': body,
+
+            # FOR TESTING PURPOSES
+            'input_port_map': [
+                # External port 0 is iterated: sliced one step at a time
+                # along the sequence dimension into body layer 0 (squeeze).
+                {
+                    'external_port_id': 0,
+                    'internal_layer_id': 0,
+                    'internal_port_id': 0,
+                    'axis': lstm.sequence_dim,
+                    'stride': 1,
+                    'part_size': 1,
+                },
+                # External ports 1 and 2 (no 'axis') are passed whole into
+                # layer 1 (LSTMCell) as the initial recurrent states.
+                {
+                    'external_port_id': 1,
+                    'internal_layer_id': 1,
+                    'internal_port_id': 1,
+                },
+                {
+                    'external_port_id': 2,
+                    'internal_layer_id': 1,
+                    'internal_port_id': 2,
+                },
+            ],
+
+            'output_port_map': [
+                # Per-iteration results from layer 2 (unsqueeze) are
+                # concatenated along the sequence dimension into external
+                # port 3 — the original LSTMSequence output.
+                {
+                    'external_port_id': 3,
+                    'internal_layer_id': 2,
+                    'internal_port_id': 3,
+                    'axis': lstm.sequence_dim,
+                    'stride': 1,
+                    'part_size': 1,
+                },
+            ],
+            # Recurrence: LSTMCell outputs (ports 4, 5) loop back to its
+            # state inputs (ports 1, 2) for the next iteration.
+            'back_edges': [
+                {
+                    'from_layer': 1,
+                    'from_port': 4,
+                    'to_layer': 1,
+                    'to_port': 1,
+                },
+                {
+                    'from_layer': 1,
+                    'from_port': 5,
+                    'to_layer': 1,
+                    'to_port': 2,
+                },
+            ]
+        })
+
+        # Wire the TensorIterator externally exactly like the original node:
+        # inputs are the LSTM's data input and both initial states
+        # (in ports 0, 3, 4); outputs reuse the LSTM's existing data nodes.
+        outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 3, 4]], data_nodes=list(lstm.out_nodes().values()), edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}, {'external_port_id': 2}])
+        graph.remove_node(lstm.id)
+        # The reused output data node's incoming edge gets external port 3,
+        # matching the output_port_map entry above.
+        outs.in_edge(0)['external_port_id'] = 3