summaryrefslogtreecommitdiff
path: root/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc')
-rw-r--r-- runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc | 17
1 file changed, 8 insertions, 9 deletions
diff --git a/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc b/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc
index 7ef023788..3b08fd5b1 100644
--- a/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc
+++ b/runtime/onert/backend/cpu/ops/BatchMatMulLayer.cc
@@ -28,8 +28,8 @@ namespace ops
{
BatchMatMulLayer::BatchMatMulLayer()
- : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
- _kernel(new nnfw::cker::BatchMatMul())
+ : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
+ _kernel(new nnfw::cker::BatchMatMul())
{
// DO NOTHING
}
@@ -39,16 +39,15 @@ BatchMatMulLayer::~BatchMatMulLayer() = default;
void BatchMatMulLayer::batchMatMulFloat32()
{
nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel;
- nnfw::cker::Shape lhs_shape = getTensorShape(_lhs);
- nnfw::cker::Shape rhs_shape = getTensorShape(_rhs);
- nnfw::cker::Shape output_shape = getTensorShape(_output);
+ nnfw::cker::Shape lhs_shape = getShape(_lhs);
+ nnfw::cker::Shape rhs_shape = getShape(_rhs);
+ nnfw::cker::Shape output_shape = getShape(_output);
// TODO implement for constant input
batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y);
- batchmatmul_kernel(lhs_shape, reinterpret_cast<const float *>(_lhs->buffer()), rhs_shape,
- reinterpret_cast<const float *>(_rhs->buffer()), _adj_x, _adj_y, output_shape,
- reinterpret_cast<float *>(_output->buffer()));
+ batchmatmul_kernel(lhs_shape, getBuffer<float>(_lhs), rhs_shape, getBuffer<float>(_rhs), _adj_x,
+ _adj_y, output_shape, getBuffer<float>(_output));
}
void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x,
@@ -67,7 +66,7 @@ void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTens
void BatchMatMulLayer::run()
{
- if (_lhs->data_type() == OperandType::FLOAT32)
+ if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32))
{
batchMatMulFloat32();
}