summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHyun Sik Yoon <63768207+hyunsik-yoon@users.noreply.github.com>2020-06-25 14:55:24 +0900
committerGitHub <noreply@github.com>2020-06-25 14:55:24 +0900
commit7e9377dd8d9094929b81e18ba95a2f10e0624d48 (patch)
tree07e9dc9acaa05a3388d095f2024497336f2bee8a
parent0221b2090d0a257df3b749053907cd6e75d8e8b6 (diff)
downloadnnfw-7e9377dd8d9094929b81e18ba95a2f10e0624d48.tar.gz
nnfw-7e9377dd8d9094929b81e18ba95a2f10e0624d48.tar.bz2
nnfw-7e9377dd8d9094929b81e18ba95a2f10e0624d48.zip
[onert] Make Transpose op get int32 input (#2604)
This makes Transpose op to handle int32 input. Signed-off-by: Hyun Sik Yoon <hyunsik.yoon.1024@gmail.com>
-rw-r--r--runtime/onert/backend/cpu/ops/TransposeLayer.cc12
-rw-r--r--runtime/onert/backend/cpu/ops/TransposeLayer.h2
2 files changed, 9 insertions, 5 deletions
diff --git a/runtime/onert/backend/cpu/ops/TransposeLayer.cc b/runtime/onert/backend/cpu/ops/TransposeLayer.cc
index ba94ea53a..6712edf30 100644
--- a/runtime/onert/backend/cpu/ops/TransposeLayer.cc
+++ b/runtime/onert/backend/cpu/ops/TransposeLayer.cc
@@ -34,7 +34,7 @@ TransposeLayer::TransposeLayer() : _input(nullptr), _output(nullptr), _perm()
// DO NOTHING
}
-void TransposeLayer::transposeFloat32()
+template <typename T> void TransposeLayer::transpose()
{
nnfw::cker::TransposeParams param;
param.perm_count = _perm.size();
@@ -44,8 +44,8 @@ void TransposeLayer::transposeFloat32()
}
nnfw::cker::Transpose(param, getTensorShape(_input),
- reinterpret_cast<const float *>(_input->buffer()), getTensorShape(_output),
- reinterpret_cast<float *>(_output->buffer()));
+ reinterpret_cast<const T *>(_input->buffer()), getTensorShape(_output),
+ reinterpret_cast<T *>(_output->buffer()));
}
void TransposeLayer::transposeQuant8()
@@ -66,7 +66,11 @@ void TransposeLayer::run()
{
if (_input->data_type() == OperandType::FLOAT32)
{
- transposeFloat32();
+ transpose<float>();
+ }
+ else if (_input->data_type() == OperandType::INT32)
+ {
+ transpose<int32_t>();
}
else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
{
diff --git a/runtime/onert/backend/cpu/ops/TransposeLayer.h b/runtime/onert/backend/cpu/ops/TransposeLayer.h
index 319412a55..ae48af553 100644
--- a/runtime/onert/backend/cpu/ops/TransposeLayer.h
+++ b/runtime/onert/backend/cpu/ops/TransposeLayer.h
@@ -36,7 +36,7 @@ public:
TransposeLayer();
public:
- void transposeFloat32();
+ template <typename T> void transpose();
void transposeQuant8();