summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/front/mxnet/rnn_param_concat.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/extensions/front/mxnet/rnn_param_concat.py')
-rw-r--r--model-optimizer/extensions/front/mxnet/rnn_param_concat.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/model-optimizer/extensions/front/mxnet/rnn_param_concat.py b/model-optimizer/extensions/front/mxnet/rnn_param_concat.py
new file mode 100644
index 000000000..8b21e7e43
--- /dev/null
+++ b/model-optimizer/extensions/front/mxnet/rnn_param_concat.py
@@ -0,0 +1,35 @@
+"""
+ Copyright (c) 2018 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from mo.front.extractor import FrontExtractorOp
+from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
+from mo.ops.concat import Concat
+
+
+class RNNParamConcatFrontExtractor(FrontExtractorOp):
+ op = '_rnn_param_concat'
+ enabled = True
+
+ @staticmethod
+ def extract(node):
+ attrs = get_mxnet_layer_attrs(node.symbol_dict)
+ data = {
+ 'axis': attrs.int("dim", 1),
+ }
+
+ # update the attributes of the node
+ Concat.update_node_stat(node, data)
+ return __class__.enabled