summaryrefslogtreecommitdiff
path: root/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/mo/front/kaldi/extractors/pooling_ext.py')
-rw-r--r--model-optimizer/mo/front/kaldi/extractors/pooling_ext.py14
1 files changed, 2 insertions, 12 deletions
diff --git a/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py b/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py
index df01ff36f..44c64a32d 100644
--- a/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py
+++ b/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py
@@ -13,12 +13,10 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
-import numpy as np
from mo.front.common.extractors.utils import layout_attrs
from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import FrontExtractorOp
-from mo.graph.graph import Node
from mo.ops.op import Op
@@ -29,20 +27,12 @@ class PoolingFrontExtractor(FrontExtractorOp):
@staticmethod
def extract(node):
mapping_rule = {
- 'window': int64_array([1, 1, node.pb.kernel, 1]),
+ 'window': int64_array([1, 1, 1, node.pb.kernel]),
'stride': int64_array([1, 1, node.pb.stride, node.pb.stride]),
'pool_stride': node.pb.pool_stride,
'pad': int64_array([[0, 0], [0, 0], [0, 0], [0, 0]]),
- 'infer': PoolingFrontExtractor.infer
+ 'pad_spatial_shape': int64_array([[0, 0], [0, 0]]),
}
mapping_rule.update(layout_attrs())
Op.get_op_class_by_name('Pooling').update_node_stat(node, mapping_rule)
return __class__.enabled
-
- @staticmethod
- def infer(node: Node):
- batch = node.in_node().in_node().in_node().shape[node.batch_dims]
- input_dim_ = node.in_node().in_node().in_node().shape[1]
- num_patches = int(np.ceil(input_dim_ / node.pool_stride))
- num_pools = 1 + int(np.ceil((num_patches - node.window[node.spatial_dims][0]) / node.stride[node.spatial_dims][0]))
- node.out_node(0).shape = int64_array([batch, node.pool_stride, 1, num_pools])