summaryrefslogtreecommitdiff
path: root/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc')
-rw-r--r--runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc13
1 files changed, 11 insertions, 2 deletions
diff --git a/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc b/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc
index b5c038200..92b5c4b4c 100644
--- a/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc
+++ b/runtimes/neurun/src/backend/acl_cl/TensorBuilder.cc
@@ -42,6 +42,7 @@ void TensorBuilder::registerTensorInfo(const model::operand::Index &ind,
assert(_tensors.size() == 0);
_tensor_info_map.insert({ind, info});
+ _apply_dim_correction_map.insert({ind, true});
}
void TensorBuilder::registerSubTensorInfo(const model::operand::Index &ind,
@@ -50,6 +51,7 @@ void TensorBuilder::registerSubTensorInfo(const model::operand::Index &ind,
assert(_tensors.size() == 0);
_subtensor_info_map.insert({ind, info});
+ _apply_dim_correction_map.insert({ind, true});
}
void TensorBuilder::notifyFirstUse(const model::operand::Index &)
@@ -75,7 +77,9 @@ void TensorBuilder::prepare(void)
{
auto ind = entry.first;
const auto &info = entry.second;
- auto tensor = std::make_shared<::neurun::backend::acl_cl::operand::CLTensor>(info);
+ const auto &tensor_info =
+ asTensorInfo(info.shape(), info.typeInfo(), _apply_dim_correction_map[ind]);
+ auto tensor = std::make_shared<::neurun::backend::acl_cl::operand::CLTensor>(tensor_info);
_tensors[ind] = tensor;
}
@@ -134,7 +138,7 @@ void TensorBuilder::prepare(void)
assert(info.type().offset() == parent_tensor->info()->quantization_info().offset);
assert(info.type().scale() == parent_tensor->info()->quantization_info().scale);
assert(asDataType(info.type().type()) == parent_tensor->info()->data_type());
- auto shape = asTensorShape(info.shape());
+ auto shape = asTensorShape(info.shape(), _apply_dim_correction_map[current]);
// Only support axis: 3 (channel)
::arm_compute::Coordinates coordinates;
@@ -241,6 +245,11 @@ bool TensorBuilder::isSubTensorOf(const model::operand::Index &parent,
return true;
}
+void TensorBuilder::dimCorrection(const model::operand::Index &index, bool apply_dim_correction)
+{
+ _apply_dim_correction_map[index] = apply_dim_correction;
+}
+
} // namespace acl_cl
} // namespace backend
} // namespace neurun