diff options
Diffstat (limited to 'model-optimizer/extensions/ops/rank.py')
-rw-r--r-- | model-optimizer/extensions/ops/rank.py | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/model-optimizer/extensions/ops/rank.py b/model-optimizer/extensions/ops/rank.py new file mode 100644 index 000000000..ed17048cc --- /dev/null +++ b/model-optimizer/extensions/ops/rank.py @@ -0,0 +1,40 @@ +""" + Copyright (c) 2018 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import networkx as nx +import numpy as np + +from mo.graph.graph import Node +from mo.ops.op import Op +from mo.utils.error import Error + + +class Rank(Op): + op = 'Rank' + + def __init__(self, graph: nx.MultiDiGraph, attrs: dict): + mandatory_props = { + 'op': __class__.op, + 'infer': __class__.infer, + } + super().__init__(graph, mandatory_props, attrs) + + @staticmethod + def infer(node: Node): + rank = len(node.in_node(0).shape) + out_value = np.array(rank) + node.out_node().value = out_value + node.out_node().shape = out_value.shape |