diff options
author | Alexey Suhov <alexey.suhov@intel.com> | 2019-10-02 17:30:49 +0300 |
---|---|---|
committer | Alexey Suhov <alexey.suhov@intel.com> | 2019-10-02 17:30:49 +0300 |
commit | 2c83de45b9c148c94f582861198d5dfe40b4e65e (patch) | |
tree | 85aa192e301c183520d6c233bc58caf02e247485 /model-optimizer/mo/front/kaldi/extractors/softmax_ext.py | |
parent | c37d4661a27afb408a45f7752acea968032afcc0 (diff) | |
download | dldt-2c83de45b9c148c94f582861198d5dfe40b4e65e.tar.gz dldt-2c83de45b9c148c94f582861198d5dfe40b4e65e.tar.bz2 dldt-2c83de45b9c148c94f582861198d5dfe40b4e65e.zip |
publish master branch
Diffstat (limited to 'model-optimizer/mo/front/kaldi/extractors/softmax_ext.py')
-rw-r--r-- | model-optimizer/mo/front/kaldi/extractors/softmax_ext.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/model-optimizer/mo/front/kaldi/extractors/softmax_ext.py b/model-optimizer/mo/front/kaldi/extractors/softmax_ext.py new file mode 100644 index 000000000..1dee8685c --- /dev/null +++ b/model-optimizer/mo/front/kaldi/extractors/softmax_ext.py @@ -0,0 +1,37 @@ +""" + Copyright (c) 2018-2019 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from mo.front.common.partial_infer.elemental import copy_shape_infer +from mo.front.extractor import FrontExtractorOp +from mo.ops.softmax import Softmax + + +class SoftmaxComponentFrontExtractor(FrontExtractorOp): + op = 'softmaxcomponent' + enabled = True + + @staticmethod + def extract(node): + return SoftmaxFrontExtractor.extract(node) + + +class SoftmaxFrontExtractor(FrontExtractorOp): + op = 'softmax' + enabled = True + + @staticmethod + def extract(node): + Softmax.update_node_stat(node, {'infer': copy_shape_infer}) + return __class__.enabled |