diff options
Diffstat (limited to 'model-optimizer/mo/front/kaldi/extractor.py')
-rw-r--r-- | model-optimizer/mo/front/kaldi/extractor.py | 24 |
1 files changed, 6 insertions, 18 deletions
diff --git a/model-optimizer/mo/front/kaldi/extractor.py b/model-optimizer/mo/front/kaldi/extractor.py index 2d4e9e1cd..f0e3b3b11 100644 --- a/model-optimizer/mo/front/kaldi/extractor.py +++ b/model-optimizer/mo/front/kaldi/extractor.py @@ -21,39 +21,27 @@ from mo.utils.utils import refer_to_faq_msg def node_pb_arg(pb_extractor): - return lambda node: pb_extractor(node.pb) + return lambda node: pb_extractor(node.parameters) -kaldi_type_extractors = { - # Data Layers - 'globalinput': node_pb_arg(lambda x: dict(op='Placeholder', type='Input', - infer=lambda node: single_output_infer(node, lambda n: n.shape))), - - # Utility Layers - 'softmax': node_pb_arg(lambda _: dict(op='SoftMax', type='SoftMax', infer=copy_shape_infer)), -} +kaldi_type_extractors = {} def common_kaldi_fields(node: Node) -> dict: - pb = node.pb if node.pb else node - layer_type = pb.type + layer_type = node.op return { 'kind': 'op', - 'name': pb.name, - 'type': layer_type, + 'name': node.id, 'op': layer_type, # generic code relies on op; it should be overridden by specific op extractor 'infer': None, - 'precision': 'FP32' # TODO use real precision derived from the model + 'precision': 'FP32' } def kaldi_extractor(node: Node) -> (bool, dict): - if node.has_valid('op') and node.op == 'Identity': - return True, {} result = common_kaldi_fields(node) - - layer_type = result['type'].lower() + layer_type = result['op'] if layer_type not in kaldi_type_extractors: raise Error('Found unsupported layer {}. '.format(node.id) + 'Model Optimizer does not support this layer type: {}. '.format(layer_type) + |