diff options
Diffstat (limited to 'runtime/onert/backend/cpu/ops/SliceLayer.cc')
-rw-r--r-- | runtime/onert/backend/cpu/ops/SliceLayer.cc | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/runtime/onert/backend/cpu/ops/SliceLayer.cc b/runtime/onert/backend/cpu/ops/SliceLayer.cc index 449c073e6..6332fbb56 100644 --- a/runtime/onert/backend/cpu/ops/SliceLayer.cc +++ b/runtime/onert/backend/cpu/ops/SliceLayer.cc @@ -41,8 +41,8 @@ void SliceLayer::GetBeginAndSizeVectors(int dimensions, const IPortableTensor *b { for (int idx = dimensions - 1; idx >= 0; --idx) { - begins->push_back(reinterpret_cast<T *>(begin->buffer())[idx]); - sizes->push_back(reinterpret_cast<T *>(size->buffer())[idx]); + begins->push_back(getBuffer<T>(begin)[idx]); + sizes->push_back(getBuffer<T>(size)[idx]); } } @@ -55,10 +55,21 @@ template <typename T> void SliceLayer::sliceImpl() begins.reserve(kMaxDim); sizes.reserve(kMaxDim); - GetBeginAndSizeVectors<int32_t>(_input->num_dimensions(), _begin, _size, &begins, &sizes); + if (_begin->data_type() == OperandType::INT32) + { + GetBeginAndSizeVectors<int32_t>(_input->getShape().rank(), _begin, _size, &begins, &sizes); + } + else if (_begin->data_type() == OperandType::INT64) + { + GetBeginAndSizeVectors<int64_t>(_input->getShape().rank(), _begin, _size, &begins, &sizes); + } + else + { + throw std::runtime_error{"Slice: unsupported begin and/or size data type"}; + } // begins : 0-based, sizes : 1-based - for (int i = _input->num_dimensions(); i < kMaxDim; ++i) + for (int i = _input->getShape().rank(); i < kMaxDim; ++i) { begins.push_back(0); sizes.push_back(1); @@ -73,9 +84,8 @@ template <typename T> void SliceLayer::sliceImpl() op_params.size[i] = sizes[3 - i]; } - nnfw::cker::Slice(op_params, getExtendedTensorShape(_input), - reinterpret_cast<const T *>(_input->buffer()), - reinterpret_cast<T *>(_output->buffer())); + nnfw::cker::Slice(op_params, getExtendedTensorShape(_input), getBuffer<T>(_input), + getBuffer<T>(_output)); } void SliceLayer::configure(const IPortableTensor *input, const IPortableTensor *begin, |