diff options
Diffstat (limited to 'model-optimizer/mo/front/kaldi/extractors/pooling_ext.py')
-rw-r--r-- | model-optimizer/mo/front/kaldi/extractors/pooling_ext.py | 14 |
1 files changed, 2 insertions, 12 deletions
diff --git a/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py b/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py index df01ff36f..44c64a32d 100644 --- a/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py +++ b/model-optimizer/mo/front/kaldi/extractors/pooling_ext.py @@ -13,12 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. """ -import numpy as np from mo.front.common.extractors.utils import layout_attrs from mo.front.common.partial_infer.utils import int64_array from mo.front.extractor import FrontExtractorOp -from mo.graph.graph import Node from mo.ops.op import Op @@ -29,20 +27,12 @@ class PoolingFrontExtractor(FrontExtractorOp): @staticmethod def extract(node): mapping_rule = { - 'window': int64_array([1, 1, node.pb.kernel, 1]), + 'window': int64_array([1, 1, 1, node.pb.kernel]), 'stride': int64_array([1, 1, node.pb.stride, node.pb.stride]), 'pool_stride': node.pb.pool_stride, 'pad': int64_array([[0, 0], [0, 0], [0, 0], [0, 0]]), - 'infer': PoolingFrontExtractor.infer + 'pad_spatial_shape': int64_array([[0, 0], [0, 0]]), } mapping_rule.update(layout_attrs()) Op.get_op_class_by_name('Pooling').update_node_stat(node, mapping_rule) return __class__.enabled - - @staticmethod - def infer(node: Node): - batch = node.in_node().in_node().in_node().shape[node.batch_dims] - input_dim_ = node.in_node().in_node().in_node().shape[1] - num_patches = int(np.ceil(input_dim_ / node.pool_stride)) - num_pools = 1 + int(np.ceil((num_patches - node.window[node.spatial_dims][0]) / node.stride[node.spatial_dims][0])) - node.out_node(0).shape = int64_array([batch, node.pool_stride, 1, num_pools]) |