diff options
Diffstat (limited to 'include/support/tflite/TensorView.h')
-rw-r--r-- | include/support/tflite/TensorView.h | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/include/support/tflite/TensorView.h b/include/support/tflite/TensorView.h index 35c90a372..0475a4b45 100644 --- a/include/support/tflite/TensorView.h +++ b/include/support/tflite/TensorView.h @@ -31,30 +31,56 @@ namespace support namespace tflite { -template<typename T> class TensorView; - -template<> class TensorView<float> final : public nnfw::util::tensor::Reader<float> +template<typename T> class TensorView final : public nnfw::util::tensor::Reader<T> { public: - TensorView(const nnfw::util::tensor::Shape &shape, float *base); + TensorView(const nnfw::util::tensor::Shape &shape, T *base) + : _shape{shape}, _base{base} + { + // Set 'stride' + _stride.init(_shape); + } public: const nnfw::util::tensor::Shape &shape(void) const { return _shape; } public: - float at(const nnfw::util::tensor::Index &index) const override; - float &at(const nnfw::util::tensor::Index &index); + T at(const nnfw::util::tensor::Index &index) const override + { + const auto offset = _stride.offset(index); + return *(_base + offset); + } + +public: + T &at(const nnfw::util::tensor::Index &index) + { + const auto offset = _stride.offset(index); + return *(_base + offset); + } private: nnfw::util::tensor::Shape _shape; public: - float *_base; + T *_base; nnfw::util::tensor::NonIncreasingStride _stride; public: // TODO Introduce Operand ID class - static TensorView<float> make(::tflite::Interpreter &interp, int operand_id); + static TensorView<T> make(::tflite::Interpreter &interp, int tensor_index) + { + auto tensor_ptr = interp.tensor(tensor_index); + + // Set 'shape' + nnfw::util::tensor::Shape shape(tensor_ptr->dims->size); + + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + shape.dim(axis) = tensor_ptr->dims->data[axis]; + } + + return TensorView<T>(shape, interp.typed_tensor<T>(tensor_index)); + } }; } // namespace tflite |