summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/ops/select.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/ops/select.py')
-rw-r--r--model-optimizer/extensions/ops/select.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/model-optimizer/extensions/ops/select.py b/model-optimizer/extensions/ops/select.py
new file mode 100644
index 000000000..b377eb2ba
--- /dev/null
+++ b/model-optimizer/extensions/ops/select.py
@@ -0,0 +1,63 @@
+"""
+ Copyright (c) 2018 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import networkx as nx
+import numpy as np
+
+from mo.graph.graph import Node
+from mo.ops.op import Op
+from mo.utils.error import Error
+
+
+class Select(Op):
+ op = 'Select'
+
+ def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+ mandatory_props = {
+ 'op': __class__.op,
+ 'infer': __class__.infer,
+ }
+ super().__init__(graph, mandatory_props, attrs)
+
+ @staticmethod
+ def infer(node: Node):
+ assert len(node.in_nodes()) == 3, "Select operation must have 3 inputs by TensorFlow reference:" \
+ " \'condition\', \'then\' and \'else\' tensors"
+ condition_node = node.in_node(0)
+ resulting_tensors = [node.in_node(1), node.in_node(2)]
+
+ assert np.array_equal(resulting_tensors[0].shape, resulting_tensors[1].shape), \
+ "TensorFlow \'Select\' operation has 3 inputs: \'condition\', \'then\' and \'else\' tensors." \
+ "\'then\' and \'else\' tensors must have the same shape by TensorFlow reference"
+ output_shape = resulting_tensors[0].shape
+
+ # Case with unknown condition
+ if not condition_node.has_valid('value'):
+ # infer only shapes
+ for out in node.out_nodes():
+ node.out_node(out).shape = np.array(output_shape)
+ return
+
+ condition_value = condition_node.value[0]
+
+ assert isinstance(condition_value, np.bool_), \
+ "TensorFlow \'Select\' operation has 3 inputs: \'condition\', \'then\' and \'else\' tensors. " \
+ "Value of \'condition\' tensor must be boolen by TensorFlow reference"
+
+ output_value = resulting_tensors[not condition_value].value
+ for _, out_node in node.graph.out_edges(node.id):
+ node.graph.node[out_node]['shape'] = np.array(output_shape)
+ node.graph.node[out_node]['value'] = None if output_value is None else np.array(output_value) \ No newline at end of file