summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/front/tf/basic_lstm_cell.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/front/tf/basic_lstm_cell.py')
-rw-r--r--model-optimizer/extensions/front/tf/basic_lstm_cell.py204
1 files changed, 204 insertions, 0 deletions
diff --git a/model-optimizer/extensions/front/tf/basic_lstm_cell.py b/model-optimizer/extensions/front/tf/basic_lstm_cell.py
new file mode 100644
index 000000000..c74414852
--- /dev/null
+++ b/model-optimizer/extensions/front/tf/basic_lstm_cell.py
@@ -0,0 +1,204 @@
+"""
+ Copyright (c) 2017-2018 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import networkx as nx
+
+from extensions.ops.lstm_cell import LSTMCell
+from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.graph.graph import Node, replace_node, get_inputs_with_ports
+
+
+class BasicLSTMCell(FrontReplacementSubgraph):
+    """ Front-end replacer that fuses a TensorFlow BasicLSTMCell sub-graph into a single LSTMCell node.
+
+        The sub-graph described by `pattern()` (ConcatV2 -> MatMul -> Add -> Split -> gate activations)
+        is matched in the graph and replaced with one `LSTMCell` operation; the outputs of that operation
+        take the place of the matched `mul_2` and `add_1` nodes.
+    """
+    enabled = True
+
+    # list of names of all original nodes that are supported by IE
+    # this list is collected gradually by a separate transformation
+    # original name in this case is a selected node in the pattern
+    # that is returned from anchor() function
+    instances_supported_by_IE = []
+
+    # True if transformation should be activated only for instances collected in instances_supported_by_IE list
+    # It will be set to True by a separate transformation
+    second_round = False
+
+    def __init__(self):
+        super().__init__()
+
+        # NOTE(review): the three lists below are assigned on the class (not on the instance) every time
+        # an instance is created; the static `infer` reads them, so it presumably relies on at least one
+        # instance having been constructed before it is called -- confirm against the replacer registry.
+
+        # Inputs that are required by LSTMCell operation definition
+        __class__.inputs = ['input_op', 'input_hidden_state', 'input_cell_state', 'weights', 'biases']
+
+        # Extra inputs that are not expected by LSTMCell but required for extra checks
+        # at middle-end partial inference stage. They are consumed by the extended infer function
+        # and then removed.
+        __class__.extra_inputs = ['concat_axis', 'split_axis', 'shift_const']
+
+        # Pattern nodes whose consumers are reconnected to LSTMCell outputs 0 and 1 respectively
+        __class__.outputs = ['mul_2', 'add_1']
+
+    def pattern(self):
+        """ Return the description (nodes and edges) of the sub-graph that this replacer matches.
+            A new dictionary is built on every call, so callers may modify the result freely.
+        """
+        return dict(
+            nodes=[
+                ('concat_axis', dict()),
+                ('concat', dict(op='ConcatV2')),
+                ('weights', dict()),
+                ('matmul', dict(op='MatMul')),
+                ('biases', dict()),
+                ('biasadd', dict(op='Add')),
+                ('split_axis', dict()),
+                ('split', dict(op='Split')),
+                ('shift_const', dict()),
+                ('shift', dict(op='Add')),
+                ('sigmoid_0', dict(op='Activation', operation='sigmoid')),
+                ('mul_0', dict(op='Mul')),
+                ('sigmoid_1', dict(op='Activation', operation='sigmoid')),
+                ('tanh_0', dict(op='Activation', operation='tanh')),
+                ('mul_1', dict(op='Mul')),
+                ('add_1', dict(op='Add')),
+                ('tanh_1', dict(op='Activation', operation='tanh')),
+                ('sigmoid_2', dict(op='Activation', operation='sigmoid')),
+                ('mul_2', dict(op='Mul'))
+            ],
+            edges=[
+                # This important block specifies how input/hidden are concatenated
+                ('concat_axis', 'concat', {'in': 2}),
+
+                ('concat', 'matmul', {'in': 0}),
+                ('weights', 'matmul', {'in': 1}),
+                ('matmul', 'biasadd', {'in': 0}),
+                ('biases', 'biasadd', {'in': 1}),
+
+                ('split_axis', 'split', {'in': 0}),
+                ('biasadd', 'split', {'in': 1}),
+
+                # This important block specifies how gates are ordered in TF graph
+                ('split', 'sigmoid_1', {'out': 0}),  # i
+                ('split', 'tanh_0', {'out': 1}),  # c
+                ('split', 'shift', {'out': 2}),  # f (this is unbiased f, there is an extra addition here)
+                ('split', 'sigmoid_2', {'out': 3}),  # o
+
+                ('shift_const', 'shift', {}),
+                ('shift', 'sigmoid_0', {}),
+                ('sigmoid_0', 'mul_0', {}),
+
+                ('sigmoid_1', 'mul_1', {}),
+                ('tanh_0', 'mul_1', {}),
+
+                ('mul_0', 'add_1', {}),
+                ('mul_1', 'add_1', {}),
+
+                ('add_1', 'tanh_1', {}),
+                ('tanh_1', 'mul_2', {}),
+                ('sigmoid_2', 'mul_2', {}),
+            ])
+
+    @staticmethod
+    def mark_supported_by_IE(node: Node):
+        """ Mark a given node as a supported LSTMCell by setting attribute `supported_by_IE`.
+            The node original name is also included in the list of all supported by IE LSTMCell
+            instances for possible second round of the network conversion.
+        """
+        assert node.has_valid('original_name'), \
+            'Node {} doesn\'t have a reference to original FW operation name; bad LSTMCell'.format(
+                node.soft_get('name'))
+        __class__.instances_supported_by_IE.append(node.original_name)
+        node['supported_by_IE'] = True
+
+    @staticmethod
+    def finalize_first_round():
+        """ Switch the mode of this pattern into `second round` where only supported patterns are converted. """
+        __class__.second_round = True
+
+    @staticmethod
+    def anchor():
+        """ Mnemonic name in the pattern that is used as an anchor name for this pattern in the original graph.
+            Used for the second round of the pattern application when only a part of instances is allowed
+            for conversion.
+        """
+        return 'concat'
+
+    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+        """ Replace one matched BasicLSTMCell sub-graph with a single LSTMCell node.
+
+            :param graph: graph that contains the matched sub-graph; it is modified in place
+            :param match: mapping from pattern node names to matched nodes; the entries `input_op`,
+                          `input_hidden_state` and `input_cell_state` are added to it by this function
+            :return: None; at the second round nothing is done for instances that are not listed
+                     in `instances_supported_by_IE`
+        """
+        # node that is used to identify this pattern application instance for switching between supported
+        # and not supported LSTMCell sub-graphs; this value will be searched in __class__.instances_supported_by_IE.
+        anchor_node = match[__class__.anchor()]
+        # the name is exactly what is missing when this fires, so the node is reported by its id
+        assert anchor_node.has_valid('name'), \
+            'LSTMCell anchor node {} doesn\'t have attribute name; such nodes are not supported.'.format(
+                anchor_node.id)
+
+        if __class__.second_round and anchor_node.name not in __class__.instances_supported_by_IE:
+            # at the second round of conversion we apply pattern selectively: only instances from
+            # __class__.instances_supported_by_IE are allowed for conversion; all others should be skipped
+            return
+
+        concat = match['concat']
+        match['input_op'] = concat.in_node(0)
+        match['input_hidden_state'] = concat.in_node(1)
+
+        # the pattern does not fix the port order of mul_0 inputs, so the cell state is
+        # whichever of the two inputs is not sigmoid_0
+        mul_0 = match['mul_0']
+        first_input = mul_0.in_node(0)
+        match['input_cell_state'] = first_input if first_input.id != match['sigmoid_0'].id else mul_0.in_node(1)
+
+        # pattern() returns a fresh structure, so extending the edge list here does not affect later matches
+        pattern_edges = self.pattern()['edges']
+        pattern_edges.extend([('input_op', 'concat'), ('input_cell_state', 'mul_0'), ('input_hidden_state', 'concat')])
+        inputs = get_inputs_with_ports(graph, match, pattern_edges, __class__.inputs + __class__.extra_inputs)
+
+        lstm_op = LSTMCell(graph, dict(
+            name=match['concat'].name + '/LSTMCell',
+            mark_supported_by_IE=__class__.mark_supported_by_IE,
+            original_name=anchor_node.name,
+            finalize_first_round=__class__.finalize_first_round,
+        ))
+        lstm_node = lstm_op.create_node(inputs)
+        # keep the regular infer function aside; the extended one below runs first and then restores it
+        lstm_node['old_infer'] = lstm_node.infer
+        lstm_node.infer = __class__.infer
+
+        # this node consumes one of the resulting LSTMCell outputs,
+        # it should be removed before reconnecting the nodes,
+        # otherwise it will be reconnected to the new cell output
+        graph.remove_node(match['tanh_1'].id)
+
+        for i, output in enumerate(__class__.outputs):
+            replace_node(match[output], lstm_node, i)
+
+        lstm_node['tf'] = True
+        lstm_node['extra_inputs'] = {name: match[name].id for name in __class__.extra_inputs}
+        lstm_node['inputs'] = {name: match[name].id for name in __class__.inputs}
+
+    @staticmethod
+    def infer(node: Node):
+        """ Extended infer function that validates the extra inputs before running the regular LSTMCell infer.
+
+            Checks that the node has all regular and extra inputs, that `concat_axis` and `split_axis` are
+            constants equal to 1, that `shift_const` is a constant scalar (a copy of it is stored in the
+            `shift_const` attribute of the node) and that `weights` and `biases` are constants. Then the
+            original infer function saved in `old_infer` is restored on the node and called.
+
+            :param node: LSTMCell node created by `replace_sub_graph`
+            :raises AssertionError: if any of the checks listed above fails
+        """
+        node_name = node.soft_get('name')
+        # extra inputs are connected right after the regular ones, so their ports are shifted by this value
+        num_inputs = len(__class__.inputs)
+        assert len(node.in_nodes()) == num_inputs + len(__class__.extra_inputs), \
+            'LSTMCell node {} has unexpected number of inputs'.format(node_name)
+
+        for axis in ['concat_axis', 'split_axis']:
+            axis_port = __class__.extra_inputs.index(axis) + num_inputs
+            axis_input = node.in_node(axis_port)
+            assert axis_input.has_valid('value'), \
+                'Input {} of LSTMCell node {} is not a constant'.format(axis, node_name)
+            assert axis_input.value == 1, \
+                'Input {} of LSTMCell node {} is expected to be 1'.format(axis, node_name)
+
+        shift_const = node.in_node(__class__.extra_inputs.index('shift_const') + num_inputs)
+        assert shift_const.has_valid('value'), \
+            'Input shift_const of LSTMCell node {} is not a constant'.format(node_name)
+        shift_const = shift_const.value
+        assert shift_const.ndim == 0, \
+            'Input shift_const of LSTMCell node {} is expected to be a scalar'.format(node_name)
+        node['shift_const'] = shift_const.copy()
+
+        weights_node = node.in_node(__class__.inputs.index('weights'))
+        biases_node = node.in_node(__class__.inputs.index('biases'))
+
+        assert weights_node.has_valid('value'), \
+            'Input weights of LSTMCell node {} is not a constant'.format(node_name)
+        assert biases_node.has_valid('value'), \
+            'Input biases of LSTMCell node {} is not a constant'.format(node_name)
+
+        # Restore original infer function (to avoid calling previous code twice) and call it
+        node.infer = node.old_infer
+        node.infer(node)