diff options
Diffstat (limited to 'model-optimizer/mo/front/kaldi/extractors/copy_ext.py')
-rw-r--r-- | model-optimizer/mo/front/kaldi/extractors/copy_ext.py | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/model-optimizer/mo/front/kaldi/extractors/copy_ext.py b/model-optimizer/mo/front/kaldi/extractors/copy_ext.py new file mode 100644 index 000000000..3348ef14c --- /dev/null +++ b/model-optimizer/mo/front/kaldi/extractors/copy_ext.py @@ -0,0 +1,40 @@ +""" + Copyright (c) 2018 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import numpy as np + +from mo.front.caffe.extractors.utils import embed_input +from mo.front.common.partial_infer.elemental import copy_shape_infer +from mo.front.extractor import FrontExtractorOp +from mo.front.kaldi.loader.utils import read_binary_integer32_token, read_blob +from mo.ops.permute import Permute + + +class CopyFrontExtractor(FrontExtractorOp): + op = 'copy' + enabled = True + + @staticmethod + def extract(node): + pb = node.parameters + weights_size = read_binary_integer32_token(pb) + weights = read_blob(pb, weights_size, dtype=np.int32) - 1 + attrs = { + 'infer': copy_shape_infer + } + embed_input(attrs, 1, 'indexes', weights) + Permute.update_node_stat(node, attrs) + return __class__.enabled |