summaryrefslogtreecommitdiff
path: root/model-optimizer/mo/front/kaldi/extractors/softmax_ext.py
diff options
context:
space:
mode:
authorAlexey Suhov <alexey.suhov@intel.com>2019-10-02 17:30:49 +0300
committerAlexey Suhov <alexey.suhov@intel.com>2019-10-02 17:30:49 +0300
commit2c83de45b9c148c94f582861198d5dfe40b4e65e (patch)
tree85aa192e301c183520d6c233bc58caf02e247485 /model-optimizer/mo/front/kaldi/extractors/softmax_ext.py
parentc37d4661a27afb408a45f7752acea968032afcc0 (diff)
downloaddldt-2c83de45b9c148c94f582861198d5dfe40b4e65e.tar.gz
dldt-2c83de45b9c148c94f582861198d5dfe40b4e65e.tar.bz2
dldt-2c83de45b9c148c94f582861198d5dfe40b4e65e.zip
publish master branch
Diffstat (limited to 'model-optimizer/mo/front/kaldi/extractors/softmax_ext.py')
-rw-r--r--model-optimizer/mo/front/kaldi/extractors/softmax_ext.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/model-optimizer/mo/front/kaldi/extractors/softmax_ext.py b/model-optimizer/mo/front/kaldi/extractors/softmax_ext.py
new file mode 100644
index 000000000..1dee8685c
--- /dev/null
+++ b/model-optimizer/mo/front/kaldi/extractors/softmax_ext.py
@@ -0,0 +1,37 @@
+"""
+ Copyright (c) 2018-2019 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from mo.front.common.partial_infer.elemental import copy_shape_infer
+from mo.front.extractor import FrontExtractorOp
+from mo.ops.softmax import Softmax
+
+
+class SoftmaxComponentFrontExtractor(FrontExtractorOp):
+ op = 'softmaxcomponent'
+ enabled = True
+
+ @staticmethod
+ def extract(node):
+ return SoftmaxFrontExtractor.extract(node)
+
+
+class SoftmaxFrontExtractor(FrontExtractorOp):
+ op = 'softmax'
+ enabled = True
+
+ @staticmethod
+ def extract(node):
+ Softmax.update_node_stat(node, {'infer': copy_shape_infer})
+ return __class__.enabled