summaryrefslogtreecommitdiff
path: root/include/support/tflite/TensorView.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/support/tflite/TensorView.h')
-rw-r--r--include/support/tflite/TensorView.h42
1 files changed, 34 insertions, 8 deletions
diff --git a/include/support/tflite/TensorView.h b/include/support/tflite/TensorView.h
index 35c90a372..0475a4b45 100644
--- a/include/support/tflite/TensorView.h
+++ b/include/support/tflite/TensorView.h
@@ -31,30 +31,56 @@ namespace support
namespace tflite
{
-template<typename T> class TensorView;
-
-template<> class TensorView<float> final : public nnfw::util::tensor::Reader<float>
+template<typename T> class TensorView final : public nnfw::util::tensor::Reader<T>
{
public:
- TensorView(const nnfw::util::tensor::Shape &shape, float *base);
+ TensorView(const nnfw::util::tensor::Shape &shape, T *base)
+ : _shape{shape}, _base{base}
+ {
+ // Set 'stride'
+ _stride.init(_shape);
+ }
public:
const nnfw::util::tensor::Shape &shape(void) const { return _shape; }
public:
- float at(const nnfw::util::tensor::Index &index) const override;
- float &at(const nnfw::util::tensor::Index &index);
+ T at(const nnfw::util::tensor::Index &index) const override
+ {
+ const auto offset = _stride.offset(index);
+ return *(_base + offset);
+ }
+
+public:
+ T &at(const nnfw::util::tensor::Index &index)
+ {
+ const auto offset = _stride.offset(index);
+ return *(_base + offset);
+ }
private:
nnfw::util::tensor::Shape _shape;
public:
- float *_base;
+ T *_base;
nnfw::util::tensor::NonIncreasingStride _stride;
public:
// TODO Introduce Operand ID class
- static TensorView<float> make(::tflite::Interpreter &interp, int operand_id);
+ static TensorView<T> make(::tflite::Interpreter &interp, int tensor_index)
+ {
+ auto tensor_ptr = interp.tensor(tensor_index);
+
+ // Set 'shape'
+ nnfw::util::tensor::Shape shape(tensor_ptr->dims->size);
+
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ shape.dim(axis) = tensor_ptr->dims->data[axis];
+ }
+
+ return TensorView<T>(shape, interp.typed_tensor<T>(tensor_index));
+ }
};
} // namespace tflite