diff options
Diffstat (limited to 'runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc')
-rw-r--r-- | runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc b/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc index 7ef023788..3b08fd5b1 100644 --- a/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc +++ b/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc @@ -28,8 +28,8 @@ namespace ops { BatchMatMulLayer::BatchMatMulLayer() - : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false), - _kernel(new nnfw::cker::BatchMatMul()) + : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false), + _kernel(new nnfw::cker::BatchMatMul()) { // DO NOTHING } @@ -39,16 +39,15 @@ BatchMatMulLayer::~BatchMatMulLayer() = default; void BatchMatMulLayer::batchMatMulFloat32() { nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel; - nnfw::cker::Shape lhs_shape = getTensorShape(_lhs); - nnfw::cker::Shape rhs_shape = getTensorShape(_rhs); - nnfw::cker::Shape output_shape = getTensorShape(_output); + nnfw::cker::Shape lhs_shape = getShape(_lhs); + nnfw::cker::Shape rhs_shape = getShape(_rhs); + nnfw::cker::Shape output_shape = getShape(_output); // TODO implement for constant input batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y); - batchmatmul_kernel(lhs_shape, reinterpret_cast<const float *>(_lhs->buffer()), rhs_shape, - reinterpret_cast<const float *>(_rhs->buffer()), _adj_x, _adj_y, output_shape, - reinterpret_cast<float *>(_output->buffer())); + batchmatmul_kernel(lhs_shape, getBuffer<float>(_lhs), rhs_shape, getBuffer<float>(_rhs), _adj_x, + _adj_y, output_shape, getBuffer<float>(_output)); } void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x, @@ -67,7 +66,7 @@ void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTens void BatchMatMulLayer::run() { - if (_lhs->data_type() == OperandType::FLOAT32) + if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32)) { batchMatMulFloat32(); } |