summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--runtime/onert/frontend/base_loader/include/base_loader.h42
1 files changed, 30 insertions, 12 deletions
diff --git a/runtime/onert/frontend/base_loader/include/base_loader.h b/runtime/onert/frontend/base_loader/include/base_loader.h
index 012e9e9b5..7cc72ad86 100644
--- a/runtime/onert/frontend/base_loader/include/base_loader.h
+++ b/runtime/onert/frontend/base_loader/include/base_loader.h
@@ -1048,19 +1048,34 @@ void BaseLoader<LoaderDomain, SpecificLoader>::loadBatchMatMul(const Operator *o
loadOperationIO(op, inputs, outputs);
ir::operation::BatchMatMul::Param param;
- if (op->custom_options() == nullptr)
- {
- param.adj_x = false;
- param.adj_y = false;
- }
- else
+ const auto builtin_op = _model->operator_codes()->Get(op->opcode_index())->builtin_code();
+
+ switch (builtin_op)
{
- size_t custom_op_data_size = op->custom_options()->size();
- auto custom_op_data = op->custom_options()->Data();
- auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
- auto attr_map = data_root.AsMap();
- param.adj_x = attr_map["adj_x"].AsBool();
- param.adj_y = attr_map["adj_y"].AsBool();
+ case BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
+ param.adj_x = op->builtin_options_as_BatchMatMulOptions()->adjoint_lhs();
+ param.adj_y = op->builtin_options_as_BatchMatMulOptions()->adjoint_rhs();
+ break;
+ case BuiltinOperator::BuiltinOperator_CUSTOM:
+ if (op->custom_options() == nullptr)
+ {
+ param.adj_x = false;
+ param.adj_y = false;
+ }
+ else
+ {
+ size_t custom_op_data_size = op->custom_options()->size();
+ auto custom_op_data = op->custom_options()->Data();
+ auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
+ auto attr_map = data_root.AsMap();
+ param.adj_x = attr_map["adj_x"].AsBool();
+ param.adj_y = attr_map["adj_y"].AsBool();
+ }
+ break;
+ default:
+ throw std::runtime_error(
+ std::string("Wrong loaded operation: ").append(EnumNameBuiltinOperator(builtin_op)) +
+ " as " + EnumNameBuiltinOperator(BuiltinOperator::BuiltinOperator_BATCH_MATMUL));
}
std::unique_ptr<ir::Operation> new_op{new ir::operation::BatchMatMul{inputs, outputs, param}};
@@ -2025,6 +2040,9 @@ void BaseLoader<LoaderDomain, SpecificLoader>::loadOperation(const Operator *op,
case BuiltinOperator::BuiltinOperator_RANGE:
loadRange(op, subg);
return;
+ case BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
+ loadBatchMatMul(op, subg);
+ return;
default:
throw std::runtime_error(
std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));