diff options
Diffstat (limited to 'model-optimizer/extensions/ops/splitv.py')
-rw-r--r-- | model-optimizer/extensions/ops/splitv.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/model-optimizer/extensions/ops/splitv.py b/model-optimizer/extensions/ops/splitv.py new file mode 100644 index 000000000..db8f3c47b --- /dev/null +++ b/model-optimizer/extensions/ops/splitv.py @@ -0,0 +1,41 @@ +""" + Copyright (c) 2018 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import logging as log + +import networkx as nx +import numpy as np + +from mo.front.common.partial_infer.split import tf_split_v_infer +from mo.graph.graph import Node +from mo.ops.op import Op + + +class SplitV(Op): + op = 'SplitV' + enabled = True + + def __init__(self, graph: nx.MultiDiGraph, attrs: dict): + super().__init__(graph, { + 'type': 'Split', + 'op': 'SplitV', + 'axis' : 1, + 'input_port': 0, + 'infer': tf_split_v_infer + }, attrs) + + def supported_attrs(self): + return ['axis', 'split_sizes'] |