diff options
Diffstat (limited to 'runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc')
-rw-r--r-- | runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc b/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc index b5c038200..92b5c4b4c 100644 --- a/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc +++ b/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc @@ -42,6 +42,7 @@ void TensorBuilder::registerTensorInfo(const model::operand::Index &ind, assert(_tensors.size() == 0); _tensor_info_map.insert({ind, info}); + _apply_dim_correction_map.insert({ind, true}); } void TensorBuilder::registerSubTensorInfo(const model::operand::Index &ind, @@ -50,6 +51,7 @@ void TensorBuilder::registerSubTensorInfo(const model::operand::Index &ind, assert(_tensors.size() == 0); _subtensor_info_map.insert({ind, info}); + _apply_dim_correction_map.insert({ind, true}); } void TensorBuilder::notifyFirstUse(const model::operand::Index &) @@ -75,7 +77,9 @@ void TensorBuilder::prepare(void) { auto ind = entry.first; const auto &info = entry.second; - auto tensor = std::make_shared<::neurun::backend::acl_cl::operand::CLTensor>(info); + const auto &tensor_info = + asTensorInfo(info.shape(), info.typeInfo(), _apply_dim_correction_map[ind]); + auto tensor = std::make_shared<::neurun::backend::acl_cl::operand::CLTensor>(tensor_info); _tensors[ind] = tensor; } @@ -134,7 +138,7 @@ void TensorBuilder::prepare(void) assert(info.type().offset() == parent_tensor->info()->quantization_info().offset); assert(info.type().scale() == parent_tensor->info()->quantization_info().scale); assert(asDataType(info.type().type()) == parent_tensor->info()->data_type()); - auto shape = asTensorShape(info.shape()); + auto shape = asTensorShape(info.shape(), _apply_dim_correction_map[current]); // Only support axis: 3 (channel) ::arm_compute::Coordinates coordinates; @@ -241,6 +245,11 @@ bool TensorBuilder::isSubTensorOf(const model::operand::Index &parent, return true; } +void TensorBuilder::dimCorrection(const model::operand::Index &index, bool apply_dim_correction) +{ + _apply_dim_correction_map[index] = apply_dim_correction; +} + } // namespace acl_cl } // namespace backend } // namespace neurun |