summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/middle/TensorIteratorInput.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/middle/TensorIteratorInput.py')
-rw-r--r--model-optimizer/extensions/middle/TensorIteratorInput.py184
1 files changed, 184 insertions, 0 deletions
diff --git a/model-optimizer/extensions/middle/TensorIteratorInput.py b/model-optimizer/extensions/middle/TensorIteratorInput.py
new file mode 100644
index 000000000..a36809c24
--- /dev/null
+++ b/model-optimizer/extensions/middle/TensorIteratorInput.py
@@ -0,0 +1,184 @@
+"""
+ Copyright (c) 2018 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import logging as log
+
+import networkx as nx
+
+from extensions.ops.TensorIterator_ops import TensorIteratorInput
+from mo.middle.replacement import MiddleReplacementPattern
+
+
+class SmartInputMatcher(MiddleReplacementPattern):
+ """
+ This pattern match partitioned inputs for TensorIterator in dynamic_rnn loops in TF.
+ The structure of pattern without Data nodes between ops. Every node is named as op attribute of this node
+ (data nodes is marked by (data)):
+ TensorArray
+ | |
+ v v Condition (data)
+ Flow(data) Handle(data)-------------- |
+ | | | |
+ v v v v
+ Value (data) -> StridedSlice () -> Range(0;1) -> TensorArrayScatter -> Enter -> TensorArrayRead
+ | ^
+ |__________________________________________________|
+ """
+ enabled = True
+
+ @staticmethod
+ def pattern():
+ return dict(
+ nodes=[
+ ('TensorArray', dict(kind='op', op='TensorArrayV3')),
+ ('TensorArray_handle', dict(kind='data')),
+ ('TensorArray_flow', dict(kind='data')),
+ ('Enter', dict(kind='op', op='Enter')),
+ ('Enter_data', dict(kind='data')),
+
+ ('stack', dict(kind='op', op='Const')),
+ ('stack_data', dict(kind='data')),
+ ('stack_1', dict(kind='op', op='Const')),
+ ('stack_1_data', dict(kind='data')),
+ ('stack_2', dict(kind='op', op='Const')),
+ ('stack_2_data', dict(kind='data')),
+
+ ('start', dict(kind='op', op='Const')),
+ ('start_data', dict(kind='data')),
+
+ ('delta', dict(kind='op', op='Const')),
+ ('delta_data', dict(kind='data')),
+
+ ('StridedSlice', dict(kind='op', op='StridedSlice')),
+ ('StridedSlice_data', dict(kind='data')),
+ ('range', dict(kind='op', op='Range')),
+ ('range_data', dict(kind='data')),
+
+ ('TensorArrayScatter', dict(kind='op', op='TensorArrayScatterV3')),
+ ('TensorArrayScatter_data', dict(kind='data')),
+ ('Enter_1', dict(kind='op', op='Enter')),
+ ('Enter_1_data', dict(kind='data')),
+
+ ('TensorArrayRead', dict(kind='op', op='TensorArrayReadV3')),
+ ('TensorArrayRead_data', dict(kind='data')),
+
+ ('Condition_data', dict(kind='data')),
+ ],
+ edges=[
+ ('TensorArray', 'TensorArray_handle'),
+ ('TensorArray', 'TensorArray_flow'),
+ ('TensorArray_handle', 'Enter'),
+ ('Enter', 'Enter_data'),
+
+ ('stack', 'stack_data'),
+ ('stack_1', 'stack_1_data'),
+ ('stack_2', 'stack_2_data'),
+ ('stack_data', 'StridedSlice', {'in': 1}),
+ ('stack_1_data', 'StridedSlice', {'in': 2}),
+ ('stack_2_data', 'StridedSlice', {'in': 3}),
+
+ ('StridedSlice', 'StridedSlice_data'),
+ ('StridedSlice_data', 'range', {'in': 1}),
+ ('start', 'start_data'),
+ ('delta', 'delta_data'),
+
+ ('start_data', 'range', {'in': 0}),
+ ('delta_data', 'range', {'in': 2}),
+ ('range', 'range_data'),
+ ('range_data', 'TensorArrayScatter'),
+
+ ('TensorArray_handle', 'TensorArrayScatter'),
+ ('TensorArray_flow', 'TensorArrayScatter'),
+ ('TensorArrayScatter', 'TensorArrayScatter_data'),
+ ('TensorArrayScatter_data', 'Enter_1'),
+ ('Enter_1', 'Enter_1_data'),
+
+ ('Enter_data', 'TensorArrayRead'),
+ ('Enter_1_data', 'TensorArrayRead'),
+ ('Condition_data', 'TensorArrayRead'),
+ ('TensorArrayRead', 'TensorArrayRead_data'),
+ ],
+ )
+
+ @staticmethod
+ def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ log.debug('================== SmartInputFind ===============')
+
+ assert match['Enter_data'].value is not None
+ assert match['stack_data']['value'][0] == 0 and match['stack_1_data']['value'][0] == 1 and \
+ match['stack_2_data']['value'][0] == 1
+ assert match['start_data']['value'] == 0 and match['delta_data']['value'] == 1
+
+ ta_size_data = match['TensorArray'].in_node()
+ ta_size = ta_size_data.in_node()
+ value = match['TensorArrayScatter'].in_node(2)
+
+ start, end = None, None
+ if 0 in ta_size.in_nodes():
+ shape = match['StridedSlice'].in_node(0).in_node(0)
+ # Case when value for Strided slice is Const, not Shape
+ if shape['kind'] == 'op' and shape['op'] == 'Const':
+ start = 0
+ end = shape.value[0]
+ log.warning("You network cannot be reshaped since shapes of placeholders is a contants."
+ "Please, provide non-constant shapes. ")
+
+ # Create input node with params
+ # axis == 0 because in TensorArray we ALWAYS iterate over 0 axis, other params will be fill later (with
+ # condition)
+ input_node = TensorIteratorInput(graph, dict(axis=0, start=start, end=end, stride=None, part_size=None,
+ external_port_id=str(match['Enter_data'].value),
+ internal_layer_id=match['TensorArrayRead_data'].id,
+ name=match['TensorArrayRead'].name + '/TensorIteratorInput_'
+ ))
+ input_node.create_node_with_data(inputs=[ta_size_data, value, match['Condition_data']],
+ data_nodes=[match['TensorArrayRead_data']])
+ # Delete useless nodes
+ safe_nodes = ['TensorArrayRead_data', 'Condition', 'Condition_data']
+
+ nodes_for_remove = []
+ for node in match.keys():
+ if node not in safe_nodes:
+ nodes_for_remove.append(match[node].id)
+ graph.remove_nodes_from(nodes_for_remove)
+
+
+class SimpleInputMatcher(MiddleReplacementPattern):
+ """
+ This pattern match simple inputs (without partitions) in while loops in TF (this inputs are set by Enter nodes).
+ """
+ @staticmethod
+ def pattern():
+ return dict(
+ nodes=[
+ ('Enter', dict(kind='op', op='Enter')),
+ ],
+ edges=[
+ ],
+ )
+
+ @staticmethod
+ def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ log.debug('================== SimpletInputFind ===============')
+
+ input_node = TensorIteratorInput(graph, dict(external_port_id=None,
+ internal_layer_id=None,
+ name=match['Enter'].name + '/TensorIteratorInput_'
+ ))
+ input_node.create_node_with_data(inputs=[match['Enter'].in_node()], data_nodes=[match['Enter'].out_node()])
+
+ # Delete useless nodes
+ graph.remove_nodes_from([match['Enter'].id]) \ No newline at end of file