summaryrefslogtreecommitdiff
path: root/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc')
-rw-r--r--runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc41
1 files changed, 14 insertions, 27 deletions
diff --git a/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc b/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc
index ff2f79309..1a5c735ee 100644
--- a/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc
+++ b/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc
@@ -55,36 +55,16 @@
int new_pv[4] = {0};
::arm_compute::Coordinates axises = getARMComputeAxises(rank);
- if (rank == 4)
+ for (uint32_t i = 0; i < rank; ++i)
{
- /**
- axises = {3,1,0,2}
- NNAPI PermutationVector
- N 0 3
- H 1 1
- W 2 0
- C 3 2
- **/
- new_pv[0] = axises[runtime_pv[2]];
- new_pv[1] = axises[runtime_pv[1]];
- new_pv[2] = axises[runtime_pv[3]];
- new_pv[3] = axises[runtime_pv[0]];
- }
- else
- {
- /**
- mapping/axises = {rank-1 to 0}
- CHW --------> WHC
- or
- WH ----------> HW
- **/
- for (int id = 0; id < rank; ++id)
- {
- new_pv[id] = axises[runtime_pv[rank - id - 1]];
- }
+ new_pv[axises[i]] = ToARMComputeAxis(rank, runtime_pv[i]).value();
}
- return ::arm_compute::PermutationVector{new_pv[0], new_pv[1], new_pv[2], new_pv[3]};
+ ::arm_compute::PermutationVector ACL_PV =
+ ::arm_compute::PermutationVector{new_pv[0], new_pv[1], new_pv[2], new_pv[3]};
+ ACL_PV.set_num_dimensions(rank);
+
+ return ACL_PV;
}
::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape,
@@ -163,3 +143,10 @@
return ::arm_compute::TensorInfo(shape, 1, asDataType(type),
asQuantizationInfo(scale, zeroPoint));
}
+
+::arm_compute::TensorInfo asTensorInfo(const ::arm_compute::TensorShape &shape,
+ const ::arm_compute::DataType &type, const float scale,
+ const int32_t zeroPoint)
+{
+ return ::arm_compute::TensorInfo(shape, 1, type, asQuantizationInfo(scale, zeroPoint));
+}