diff options
Diffstat (limited to 'runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc')
-rw-r--r-- | runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc | 41 |
1 files changed, 14 insertions, 27 deletions
diff --git a/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc b/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc index ff2f79309..1a5c735ee 100644 --- a/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc +++ b/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc @@ -55,36 +55,16 @@ int new_pv[4] = {0}; ::arm_compute::Coordinates axises = getARMComputeAxises(rank); - if (rank == 4) + for (uint32_t i = 0; i < rank; ++i) { - /** - axises = {3,1,0,2} - NNAPI PermutationVector - N 0 3 - H 1 1 - W 2 0 - C 3 2 - **/ - new_pv[0] = axises[runtime_pv[2]]; - new_pv[1] = axises[runtime_pv[1]]; - new_pv[2] = axises[runtime_pv[3]]; - new_pv[3] = axises[runtime_pv[0]]; - } - else - { - /** - mapping/axises = {rank-1 to 0} - CHW --------> WHC - or - WH ----------> HW - **/ - for (int id = 0; id < rank; ++id) - { - new_pv[id] = axises[runtime_pv[rank - id - 1]]; - } + new_pv[axises[i]] = ToARMComputeAxis(rank, runtime_pv[i]).value(); } - return ::arm_compute::PermutationVector{new_pv[0], new_pv[1], new_pv[2], new_pv[3]}; + ::arm_compute::PermutationVector ACL_PV = + ::arm_compute::PermutationVector{new_pv[0], new_pv[1], new_pv[2], new_pv[3]}; + ACL_PV.set_num_dimensions(rank); + + return ACL_PV; } ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape, @@ -163,3 +143,10 @@ return ::arm_compute::TensorInfo(shape, 1, asDataType(type), asQuantizationInfo(scale, zeroPoint)); } + +::arm_compute::TensorInfo asTensorInfo(const ::arm_compute::TensorShape &shape, + const ::arm_compute::DataType &type, const float scale, + const int32_t zeroPoint) +{ + return ::arm_compute::TensorInfo(shape, 1, type, asQuantizationInfo(scale, zeroPoint)); +} |