summaryrefslogtreecommitdiff
path: root/compiler/luci-interpreter/src/kernels/Transpose.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci-interpreter/src/kernels/Transpose.cpp')
-rw-r--r--compiler/luci-interpreter/src/kernels/Transpose.cpp10
1 files changed, 5 insertions, 5 deletions
diff --git a/compiler/luci-interpreter/src/kernels/Transpose.cpp b/compiler/luci-interpreter/src/kernels/Transpose.cpp
index 8265d9937..802d87295 100644
--- a/compiler/luci-interpreter/src/kernels/Transpose.cpp
+++ b/compiler/luci-interpreter/src/kernels/Transpose.cpp
@@ -18,7 +18,7 @@
#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+#include <tensorflow/lite/kernels/internal/reference/transpose.h>
#include <stdexcept>
@@ -29,7 +29,7 @@ namespace kernels
{
Transpose::Transpose(const Tensor *input, const Tensor *perm, Tensor *output)
- : Kernel({input, perm}, {output})
+ : Kernel({input, perm}, {output})
{
}
@@ -37,7 +37,7 @@ void Transpose::configure()
{
// Transpose op only supports 1D-4D input arrays.
int dims = input()->shape().num_dims();
- const int *perm_data = getTensorData<int32_t>(perm());
+ const int32_t *perm_data = getTensorData<int32_t>(perm());
assert(input()->shape().num_dims() <= 4);
assert(input()->element_type() == output()->element_type());
@@ -58,8 +58,8 @@ void Transpose::configure()
void Transpose::execute() const
{
tflite::TransposeParams params{};
- const int *perm_data = getTensorData<int32_t>(perm());
- const int size = perm()->shape().dim(0);
+ const int32_t *perm_data = getTensorData<int32_t>(perm());
+ const int32_t size = perm()->shape().dim(0);
params.perm_count = size;
for (int i = 0; i < size; i++)
params.perm[i] = perm_data[i];