diff options
Diffstat (limited to 'runtime/libs/tflite/include/tflite')
16 files changed, 1378 insertions, 0 deletions
diff --git a/runtime/libs/tflite/include/tflite/Assert.h b/runtime/libs/tflite/include/tflite/Assert.h new file mode 100644 index 000000000..148ac7e01 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/Assert.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file Assert.h + * @brief This file contains helper function of assertion + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_ASSERT_H__ +#define __NNFW_TFLITE_ASSERT_H__ + +#include "tensorflow/lite/context.h" + +#include <sstream> + +#define STR_DETAIL(value) #value +#define STR(value) STR_DETAIL(value) + +#define TFLITE_ENSURE(exp) \ + { \ + const TfLiteStatus status = (exp); \ + \ + if (status != kTfLiteOk) \ + { \ + std::ostringstream ss; \ + ss << #exp << " failed (" << __FILE__ << ":" << __LINE__ << ")"; \ + throw std::runtime_error{ss.str()}; \ + } \ + } + +#endif // __NNFW_TFLITE_ASSERT_H__ diff --git a/runtime/libs/tflite/include/tflite/Diff.h b/runtime/libs/tflite/include/tflite/Diff.h new file mode 100644 index 000000000..eca2fd502 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/Diff.h @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file Diff.h + * @brief This file contains classes for testing correctess of implementation + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_DIFF_H__ +#define __NNFW_TFLITE_DIFF_H__ + +#include "tensorflow/lite/interpreter.h" + +#include "misc/tensor/Index.h" +#include "misc/tensor/Diff.h" +#include "misc/tensor/Shape.h" +#include "misc/tensor/Comparator.h" + +#include "tflite/TensorView.h" + +#include <functional> +#include <vector> + +/** + * @brief Class to define TfLite interpreter match application + */ +class TfLiteInterpMatchApp +{ +public: + /** + * @brief Construct a new TfLiteInterpMatchApp object with Comparator + * @param[in] comparator Comparator object for tensor comparation + */ + TfLiteInterpMatchApp(const nnfw::misc::tensor::Comparator &comparator) + : _verbose{false}, _comparator(comparator) + { + // DO NOTHING + } + +public: + /** + * @brief Get reference verbose for debugging information + * @return Reference of verbose value + */ + int &verbose(void) { return _verbose; } + +private: + int _verbose; + +public: + /** + * @brief Run two interpreter and return the output matching + * @param[in] pure Interpreter object of expected(with TfLite) + * @param[in] nnapi Interpreter object of obtained(through NNAPI) + * @return @c true if two Interpreter results are same, otherwise @c false + */ + bool run(::tflite::Interpreter &pure, ::tflite::Interpreter &nnapi) const; + /** + * @brief Compare two TensorView values and return the match result + * @param[in] expected TensorView object to read expected values + * @param[in] obtained TensorView object to read obtained values + * @param[in] id Tensor ID value used for debug message + * @return @c true if two TensorView values are same, otherwise @c false + */ + template <typename T> + bool compareSingleTensorView(const nnfw::tflite::TensorView<T> &expected, + const nnfw::tflite::TensorView<T> &obtained, int id) const; + +private: + const nnfw::misc::tensor::Comparator &_comparator; +}; + +#include "tflite/interp/Builder.h" +#include "tflite/Quantization.h" + +#include <random> + +/** + * @brief Class to generate random values + */ +class RandomGenerator +{ +public: + /** + * @brief Construct a new RandomGenerator object + * @param[in] seed Random seed value + * @param[in] mean Mean value of normal random number generation + * @param[in] stddev Standard deviation of random number generation + * @param[in] quantization TfLiteQuantizationParams type to represent quantization value + * (not used yet) + */ + RandomGenerator(uint32_t seed, float mean, float stddev, + const TfLiteQuantizationParams quantization = make_default_quantization()) + : _rand{seed}, _dist{mean, stddev}, _quantization{quantization} + { + (void)_quantization; + } + +public: + /** + * @brief Generate random numbers for type T + * @param[in] s Shape value + * @param[in] i Index value + * @return Random generated value + * @note This is same as T generate(void) as two input parameters are not used + */ + template <typename T> + T generate(const ::nnfw::misc::tensor::Shape &, const ::nnfw::misc::tensor::Index &) + { + return generate<T>(); + } + + /** + * @brief Generate random numbers for type T + * @return Random generated value + */ + template <typename T> T generate(void) { return _dist(_rand); } + +private: + std::minstd_rand _rand; + std::normal_distribution<float> _dist; + // unused + const TfLiteQuantizationParams _quantization; +}; + +template <> uint8_t RandomGenerator::generate<uint8_t>(void); +template <> bool RandomGenerator::generate<bool>(void); + +/** + * @brief Structure for NNAPI correctness test + */ +struct RandomTestParam +{ + int verbose; //!< Verbosity of debug information + int tolerance; //!< Torlerance of value difference + int tensor_logging = 0; //!< Save logging to a file if not 0 + std::string log_path = ""; //!< Path of log file, meaningful only when tensor_logging is 1 +}; + +/** + * @brief Class to define Random test runner + */ +class RandomTestRunner +{ +public: + /** + * @brief Construct a new RandomTestRunner object + * @param[in] seed Random seed value + * @param[in] param RandomTestParam object for test runner + * @param[in] quantization TfLiteQuantizationParams type to represent quantization value + */ + RandomTestRunner(uint32_t seed, const RandomTestParam ¶m, + const TfLiteQuantizationParams quantization = make_default_quantization()) + : _randgen{seed, 0.0f, 2.0f, quantization}, _param{param} + { + // DO NOTHING + } + +public: + /** + * @brief Run the random test runner + * @param[in] builder Interpreter Builder used to run + * @return 0 if test succeeds, otherwise failure + */ + int run(const nnfw::tflite::Builder &builder); + +public: + /** + * @brief Get RandomGenerator reference + * @return RandomGenerator reference + */ + RandomGenerator &generator() { return _randgen; }; + +private: + RandomGenerator _randgen; + const RandomTestParam _param; + +public: + /** + * @brief Create a RandomTestRunner object + * @param[in] seed Random seed value + * @return RandomGenerator object + */ + static RandomTestRunner make(uint32_t seed); +}; + +#endif // __NNFW_TFLITE_DIFF_H__ diff --git a/runtime/libs/tflite/include/tflite/FeatureView.h b/runtime/libs/tflite/include/tflite/FeatureView.h new file mode 100644 index 000000000..a8f069c40 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/FeatureView.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file FeatureView.h + * @brief This file contains FeatureView class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_FEATURE_VIEW_H__ +#define __NNFW_TFLITE_FEATURE_VIEW_H__ + +#include "tensorflow/lite/interpreter.h" + +#include "tflite/InputIndex.h" +#include "tflite/OutputIndex.h" + +#include "misc/feature/Shape.h" +#include "misc/feature/Reader.h" + +namespace nnfw +{ +namespace tflite +{ + +template <typename T> class FeatureView; + +/** + * @brief Class to support reading element of float type feature + */ +template <> class FeatureView<float> : public nnfw::misc::feature::Reader<float> +{ +public: + /** + * @brief Construct a new FeatureView object + * @param[in] interp Interpreter to read from + * @param[in] index InputIndex index of input + */ + FeatureView(::tflite::Interpreter &interp, const InputIndex &index); + /** + * @brief Construct a new FeatureView object + * @param[in] interp Interpreter to read from + * @param[in] index OutputIndex index of output + */ + FeatureView(::tflite::Interpreter &interp, const OutputIndex &index); + +public: + /** + * @brief Get value of element using channel, row and column index + * @param[in] ch Channel index + * @param[in] row Row index + * @param[in] col Column index + * @return Value of element + */ + float at(uint32_t ch, uint32_t row, uint32_t col) const; + /** + * @brief Get reference of element using channel, row and column index + * @param[in] ch Channel index + * @param[in] row Row index + * @param[in] col Column index + * @return Reference of element + */ + float &at(uint32_t ch, uint32_t row, uint32_t col); + + float at(uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) const = 0; + +private: + /** + * @brief Get offset of element from channel, row and column index + * @param[in] ch Channel index + * @param[in] row Row index + * @param[in] col Column index + * @return Offset of element + */ + uint32_t getElementOffset(uint32_t ch, uint32_t row, uint32_t col) const + { + uint32_t res = 0; + + // TensorFlow Lite assumes that NHWC ordering for tessor + res += row * _shape.W * _shape.C; + res += col * _shape.C; + res += ch; + + return res; + } + +private: + nnfw::misc::feature::Shape _shape; + float *_base; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_FEATURE_VIEW_H__ diff --git a/runtime/libs/tflite/include/tflite/InputIndex.h b/runtime/libs/tflite/include/tflite/InputIndex.h new file mode 100644 index 000000000..f535b2626 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/InputIndex.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file InputIndex.h + * @brief This file contains InputIndex class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_INPUT_INDEX_H__ +#define __NNFW_TFLITE_INPUT_INDEX_H__ + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Class to express index of input + */ +class InputIndex +{ +public: + /** + * @brief Construct a new InputIndex object with index value + * @param [in] index The value of index + */ + InputIndex(int index) : _index(index) + { + // DO NOTHING + } + +public: + /** + * @brief Get index value as int + * @return Index value as int + */ + int asInt(void) const { return _index; } + +private: + int _index; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_INPUT_INDEX_H__ diff --git a/runtime/libs/tflite/include/tflite/InterpreterSession.h b/runtime/libs/tflite/include/tflite/InterpreterSession.h new file mode 100644 index 000000000..deaf05a7f --- /dev/null +++ b/runtime/libs/tflite/include/tflite/InterpreterSession.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file InterpreterSession.h + * @brief This file contains InterpreterSession class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_INTERPRETER_SESSION_H__ +#define __NNFW_TFLITE_INTERPRETER_SESSION_H__ + +#include "Session.h" + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Class to define TfLite interpreter session which is inherited from Session class + */ +class InterpreterSession final : public Session +{ +public: + /** + * @brief Construct a InterpreterSession object with interpreter of TfLite + * @param[in] interp The TfLite interpreter pointer + */ + InterpreterSession(::tflite::Interpreter *interp) : _interp{interp} + { + // DO NOTHING + } + +public: + /** + * @brief Get TfLite interpreter pointer + * @return The TfLite interpreter + */ + ::tflite::Interpreter *interp(void) override { return _interp; } + +public: + /** + * @brief Prepare the TfLite interpreter session + * @return @c true if tensor preparation is successful, otherwise @c false + */ + bool prepare(void) override + { + _interp->UseNNAPI(false); + + if (kTfLiteOk != _interp->AllocateTensors()) + { + return false; + } + + return true; + } + + /** + * @brief Run the Invoke function of TfLite interpreter + * @return @c true if Invoke() is successful, otherwise @c false + */ + bool run(void) override + { + // Return true if Invoke returns kTfLiteOk + return kTfLiteOk == _interp->Invoke(); + } + + /** + * @brief Tear down TfLite interpreter session + * @return @c true always + */ + bool teardown(void) override + { + // Do NOTHING currently + return true; + } + +private: + ::tflite::Interpreter *const _interp; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_INTERPRETER_SESSION_H__ diff --git a/runtime/libs/tflite/include/tflite/NNAPISession.h b/runtime/libs/tflite/include/tflite/NNAPISession.h new file mode 100644 index 000000000..f430e86d3 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/NNAPISession.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file NNAPISession.h + * @brief This file contains NNAPISession class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_NNAPI_SESSION_H__ +#define __NNFW_TFLITE_NNAPI_SESSION_H__ + +#include "Session.h" +#include "tflite/ext/nnapi_delegate.h" + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Class to define NNAPI interpreter session which is inherited from Session class + */ +class NNAPISession final : public Session +{ +public: + /** + * @brief Construct a NNAPISession object with interpreter of TfLite + * @param[in] interp The TfLite interpreter pointer + * @note Invoke BuildGraph() of NNAPI delegate from Interpreter + */ + NNAPISession(::tflite::Interpreter *interp) : _interp{interp} + { + // Construct Graph from Interpreter + // primary_subgraph: Experimental interface. Return 1st sugbraph + _delegate.BuildGraph(&interp->primary_subgraph()); + } + +public: + /** + * @brief Get TfLite interpreter pointer + * @return The TfLite interpreter + */ + ::tflite::Interpreter *interp(void) override { return _interp; } + +public: + /** + * @brief Prepare the TfLite interpreter session + * @return @c true if tensor preparation is successful, otherwise @c false + */ + bool prepare(void) override + { + // Explicitly turn off T/F lite internal NNAPI delegation in order to use locally defined + // NNAPI delegation. + _interp->UseNNAPI(false); + + if (kTfLiteOk != _interp->AllocateTensors()) + { + return false; + } + + return true; + } + + /** + * @brief Run the Invoke function of NNAPI delegate + * @return @c true if Invoke() is successful, otherwise @c false + */ + bool run(void) override { return kTfLiteOk == _delegate.Invoke(&_interp->primary_subgraph()); } + + /** + * @brief Tear down TfLite interpreter session + * @return @c true always + */ + bool teardown(void) override + { + // DO NOTHING + return true; + } + +private: + ::tflite::Interpreter *const _interp; + nnfw::tflite::NNAPIDelegate _delegate; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_NNAPI_SESSION_H__ diff --git a/runtime/libs/tflite/include/tflite/OutputIndex.h b/runtime/libs/tflite/include/tflite/OutputIndex.h new file mode 100644 index 000000000..dd1ca8d44 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/OutputIndex.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file OutputIndex.h + * @brief This file contains OutputIndex class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_OUTPUT_INDEX_H__ +#define __NNFW_TFLITE_OUTPUT_INDEX_H__ + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Class to define OutputIndex + */ +class OutputIndex +{ +public: + /** + * @brief Construct a OutputIndex object with index value + * @param[in] index The value of index + */ + OutputIndex(int index) : _index(index) + { + // DO NOTHING + } + +public: + /** + * @brief Get index value as int + * @return Index valuel as int + */ + int asInt(void) const { return _index; } + +private: + int _index; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_OUTPUT_INDEX_H__ diff --git a/runtime/libs/tflite/include/tflite/Quantization.h b/runtime/libs/tflite/include/tflite/Quantization.h new file mode 100644 index 000000000..8272bcdc0 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/Quantization.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file Quantization.h + * @brief This file contains BitwiseIntToFloat union and quantization related + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_QUANTIZATION_H__ +#define __NNFW_TFLITE_QUANTIZATION_H__ + +/** + * @brief Union to provide bitwise conversion of integer and float + */ +union BitwiseIntToFloat { + int i; + float f; +}; + +static const float FLOAT_NEAREST_TO_1 = BitwiseIntToFloat{0x3f7fffff}.f; + +#include "tensorflow/lite/context.h" + +/** + * @brief Get TfLiteQuantizationParams object with default values + * @return TfLiteQuantizationParams object + */ +TfLiteQuantizationParams make_default_quantization(void); + +#endif // __NNFW_TFLITE_QUANTIZATION_H__ diff --git a/runtime/libs/tflite/include/tflite/Session.h b/runtime/libs/tflite/include/tflite/Session.h new file mode 100644 index 000000000..b653acf61 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/Session.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file Session.h + * @brief This file contains Session class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_SESSION_H__ +#define __NNFW_TFLITE_SESSION_H__ + +#include <tensorflow/lite/interpreter.h> + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Structure to provide interface methods of interpreter session + */ +struct Session +{ + /** + * @brief Destruct Session object using default destructor + */ + virtual ~Session() = default; + + /** + * @brief Get the Interpreter object pointer + * @return The Interpreter object pointer + */ + virtual ::tflite::Interpreter *interp(void) = 0; + + /** + * @brief Prepare the session + * @return @c true if prepare method succeeded, otherwise @c false + */ + virtual bool prepare(void) = 0; + /** + * @brief Run the session + * @return @c true if run method succeeded, otherwise @c false + */ + virtual bool run(void) = 0; + /** + * @brief Teardown(release) the session + * @return @c true if teardown method succeeded, otherwise @c false + */ + virtual bool teardown(void) = 0; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_INTERP_SESSION_H__ diff --git a/runtime/libs/tflite/include/tflite/TensorLogger.h b/runtime/libs/tflite/include/tflite/TensorLogger.h new file mode 100644 index 000000000..a824c3411 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/TensorLogger.h @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file TensorLogger.h + * @brief This file contains TensorLogger class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_TENSOR_LOGGER_H__ +#define __NNFW_TFLITE_TENSOR_LOGGER_H__ + +#include "misc/tensor/IndexIterator.h" +#include "tflite/TensorView.h" + +#include <tensorflow/lite/interpreter.h> +#include <tensorflow/lite/context.h> +#include <fstream> +#include <iomanip> + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Class to write input and output value / shape into a file in python form + * @note This is a utility to write input and output value / shape into a file in python form.\n + * any python app can load this value by running the python code below:\n + * exec(open(filename).read())\n + * generated python code looks like the following: \n + * tensor_shape_gen = []\n + * tensor_value_gen = []\n\n + * tensor_shape_gen.append("{2, 1, 2}")\n + * tensor_value_gen.append([1, 2, 3, 4])\n\n + * tensor_shape_gen.append("{2}")\n + * tensor_value_gen.append([1, 2])\n\n + * tensor_shape_gen.append("{2, 1, 2}")\n + * tensor_value_gen.append([1, 4, 3, 8])\n + */ +class TensorLogger +{ +private: + std::ofstream _outfile; + +public: + /** + * @brief Get TensorLogger instance + * @return The TensorLogger instance + */ + static TensorLogger &get() + { + static TensorLogger instance; + return instance; + } + + /** + * @brief Save the tensor details to file from interpreter + * @param[in] path The file path to save + * @param[in] interp The TfLite interpreter + */ + void save(const std::string &path, ::tflite::Interpreter &interp) + { + open(path); + + int log_index = 0; + for (const auto id : interp.inputs()) + { + _outfile << "# input tensors" << std::endl; + printTensor(interp, id, log_index++); + } + for (const auto id : interp.outputs()) + { + _outfile << "# output tensors" << std::endl; + printTensor(interp, id, log_index++); + } + close(); + } + +private: + void open(const std::string &path) + { + if (!_outfile.is_open()) + _outfile.open(path, std::ios_base::out); + + _outfile << "# ------ file: " << path << " ------" << std::endl + << "tensor_shape_gen = []" << std::endl + << "tensor_value_gen = []" << std::endl + << std::endl; + } + + void printTensor(::tflite::Interpreter &interp, const int id, const int log_index) + { + const TfLiteTensor *tensor = interp.tensor(id); + + _outfile << "# tensor name: " << tensor->name << std::endl; + _outfile << "# tflite::interpreter.tensor(" << id << ") -> " + "tensor_value_gen[" + << log_index << "]" << std::endl; + + if (tensor->type == kTfLiteInt32) + { + printTensorShape(tensor); + printTensorValue<int32_t>(tensor, tensor->data.i32); + } + else if (interp.tensor(id)->type == kTfLiteUInt8) + { + printTensorShape(tensor); + printTensorValue<uint8_t>(tensor, tensor->data.uint8); + } + else if (tensor->type == kTfLiteFloat32) + { + printTensorShape(tensor); + printTensorValue<float>(tensor, tensor->data.f); + } + } + + void printTensorShape(const TfLiteTensor *tensor) + { + _outfile << "tensor_shape_gen.append('{"; + + int r = 0; + for (; r < tensor->dims->size - 1; r++) + { + _outfile << tensor->dims->data[r] << ", "; + } + _outfile << tensor->dims->data[r]; + + _outfile << "}')" << std::endl; + } + + template <typename T> void printTensorValue(const TfLiteTensor *tensor, T *tensor_data_ptr) + { + _outfile << "tensor_value_gen.append(["; + + _outfile << std::fixed << std::setprecision(10); + + const T *end = reinterpret_cast<const T *>(tensor->data.raw_const + tensor->bytes); + for (T *ptr = tensor_data_ptr; ptr < end; ptr++) + _outfile << *ptr << ", "; + + _outfile << "])" << std::endl << std::endl; + } + + void close() + { + _outfile << "# --------- tensor shape and value defined above ---------" << std::endl; + _outfile.close(); + } +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_TENSOR_LOGGER_H__ diff --git a/runtime/libs/tflite/include/tflite/TensorShapeUtils.h b/runtime/libs/tflite/include/tflite/TensorShapeUtils.h new file mode 100644 index 000000000..ba8687413 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/TensorShapeUtils.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file TensorShapeUtils.h + * @brief This file contains utilities function of tensor shape + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_TENSOR_SHAPE_UTILS_H__ +#define __NNFW_TFLITE_TENSOR_SHAPE_UTILS_H__ + +#include "misc/tensor/Shape.h" + +#include <vector> + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Converts tensor::Shape into a vector + * @param[in] shape The tensor shape to be converted + * @return vector value of given shape object + */ +static inline std::vector<int32_t> as_dims(const nnfw::misc::tensor::Shape &shape) +{ + std::vector<int32_t> dims; + + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + dims.emplace_back(shape.dim(axis)); + } + + return dims; +} + +/** + * @brief Broadcasts between two given shapes + * @param[in] lhs_shape The left hand side shape + * @param[in] rhs_shape The right hand side shape + * @return The broadcasted shape + */ +nnfw::misc::tensor::Shape broadcast(const nnfw::misc::tensor::Shape &lhs_shape, + const nnfw::misc::tensor::Shape &rhs_shape); + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_TENSOR_SHAPE_UTILS_H__ diff --git a/runtime/libs/tflite/include/tflite/TensorUtils.h b/runtime/libs/tflite/include/tflite/TensorUtils.h new file mode 100644 index 000000000..08af1468b --- /dev/null +++ b/runtime/libs/tflite/include/tflite/TensorUtils.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file TensorUtils.h + * @brief This file contains utilities function + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_TENSOR_UTILS_H__ +#define __NNFW_TFLITE_TENSOR_UTILS_H__ + +#include <tensorflow/lite/context.h> + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Get @c true if tensor type is kTfLiteFloat32, otherwise @c false + * @param[in] tensor The tensor object to be compared + * @return @c true if tensor type is kTfLiteFloat32, otherwise @c false + */ +inline bool isFloatTensor(const TfLiteTensor *tensor) { return tensor->type == kTfLiteFloat32; } + +/** + * @brief Get @c true if tensor is 4-D tensor and the first dimension length is 1, + * otherwise @c false + * @param[in] tensor The tensor object to be compared + * @return @c true if tensor is 4-D tensor and the first dimension length is 1, otherwise @c false + */ +inline bool isFeatureTensor(const TfLiteTensor *tensor) +{ + return (tensor->dims->size == 4) && (tensor->dims->data[0] == 1); +} + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_TENSOR_UTILS_H__ diff --git a/runtime/libs/tflite/include/tflite/TensorView.h b/runtime/libs/tflite/include/tflite/TensorView.h new file mode 100644 index 000000000..ce791a73f --- /dev/null +++ b/runtime/libs/tflite/include/tflite/TensorView.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file TensorView.h + * @brief This file contains TensorView class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_TENSOR_VIEW_H__ +#define __NNFW_TFLITE_TENSOR_VIEW_H__ + +#include "tensorflow/lite/interpreter.h" + +#include "misc/tensor/Shape.h" +#include "misc/tensor/Index.h" +#include "misc/tensor/Reader.h" +#include "misc/tensor/NonIncreasingStride.h" + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Class to define TensorView which is inherited from nnfw::misc::tensor::Reader<T> class + */ +template <typename T> class TensorView final : public nnfw::misc::tensor::Reader<T> +{ +public: + /** + * @brief Construct a TensorView object with base and shape informations + * @param[in] shape The shape of a tensor + * @param[in] base The base address of a tensor + */ + TensorView(const nnfw::misc::tensor::Shape &shape, T *base) : _shape{shape}, _base{base} + { + // Set 'stride' + _stride.init(_shape); + } + +public: + /** + * @brief Get shape of tensor + * @return Reference of shape + */ + const nnfw::misc::tensor::Shape &shape(void) const { return _shape; } + +public: + /** + * @brief Get value of tensor index + * @param[in] index The tensor index + * @return The value at the index + */ + T at(const nnfw::misc::tensor::Index &index) const override + { + const auto offset = _stride.offset(index); + return *(_base + offset); + } + +public: + /** + * @brief Get reference value of tensor index + * @param[in] index The tensor index + * @return The reference value at the index + */ + T &at(const nnfw::misc::tensor::Index &index) + { + const auto offset = _stride.offset(index); + return *(_base + offset); + } + +private: + nnfw::misc::tensor::Shape _shape; /**< The tensor shape */ + +public: + T *_base; /**< The base address of tensor */ + nnfw::misc::tensor::NonIncreasingStride _stride; /**< The NonIncreasingStride object */ + +public: + // TODO Introduce Operand ID class + /** + * @brief Create TensorView object using given parameters + * @param[in] interp The TfLite interpreter + * @param[in] tensor_index The tensor index + * @return The new TensorView<T> object + */ + static TensorView<T> make(::tflite::Interpreter &interp, int tensor_index) + { + auto tensor_ptr = interp.tensor(tensor_index); + + // Set 'shape' + nnfw::misc::tensor::Shape shape(tensor_ptr->dims->size); + + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + shape.dim(axis) = tensor_ptr->dims->data[axis]; + } + + return TensorView<T>(shape, interp.typed_tensor<T>(tensor_index)); + } +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_TENSOR_VIEW_H__ diff --git a/runtime/libs/tflite/include/tflite/interp/Builder.h b/runtime/libs/tflite/include/tflite/interp/Builder.h new file mode 100644 index 000000000..0f54e1779 --- /dev/null +++ b/runtime/libs/tflite/include/tflite/interp/Builder.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file Builder.h + * @brief This file contains Builder structure + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_INTERP_BUILDER_H__ +#define __NNFW_TFLITE_INTERP_BUILDER_H__ + +#include <tensorflow/lite/interpreter.h> + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Structure to Builder + */ +struct Builder +{ + /** + * @brief Destroy the Builder object + */ + virtual ~Builder() = default; + + /** + * @brief Build a FlatBuffer model + * @return The TfLite interpreter object + */ + virtual std::unique_ptr<::tflite::Interpreter> build(void) const = 0; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_INTERP_BUILDER_H__ diff --git a/runtime/libs/tflite/include/tflite/interp/FlatBufferBuilder.h b/runtime/libs/tflite/include/tflite/interp/FlatBufferBuilder.h new file mode 100644 index 000000000..2d96af50b --- /dev/null +++ b/runtime/libs/tflite/include/tflite/interp/FlatBufferBuilder.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file FlatBufferBuilder.h + * @brief This file contains FlatBufferBuilder class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_INTERP_FLAT_BUFFER_BUILDER_H__ +#define __NNFW_TFLITE_INTERP_FLAT_BUFFER_BUILDER_H__ + +#include <tensorflow/lite/model.h> + +#include "tflite/interp/Builder.h" + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Class to define FlatBufferBuilder which is inherited from Builder + */ +class FlatBufferBuilder final : public Builder +{ +public: + /** + * @brief Construct a FlatBufferBuilder object with FlatBufferModel of TfLite + * @param[in] model The TfLite Flatbuffer model + */ + FlatBufferBuilder(const ::tflite::FlatBufferModel &model) : _model{model} + { + // DO NOTHING + } + +public: + /** + * @brief Build a FlatBuffer model + * @return The TfLite interpreter pointer address + */ + std::unique_ptr<::tflite::Interpreter> build(void) const override; + +private: + const ::tflite::FlatBufferModel &_model; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_INTERP_FLAT_BUFFER_BUILDER_H__ diff --git a/runtime/libs/tflite/include/tflite/interp/FunctionBuilder.h b/runtime/libs/tflite/include/tflite/interp/FunctionBuilder.h new file mode 100644 index 000000000..7bfb8db2d --- /dev/null +++ b/runtime/libs/tflite/include/tflite/interp/FunctionBuilder.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file FunctionBuilder.h + * @brief This file contains FunctionBuilder class + * @ingroup COM_AI_RUNTIME + */ + +#ifndef __NNFW_TFLITE_INTERP_FUNCTION_BUILDER_H__ +#define __NNFW_TFLITE_INTERP_FUNCTION_BUILDER_H__ + +#include <tensorflow/lite/model.h> + +#include "tflite/interp/Builder.h" + +namespace nnfw +{ +namespace tflite +{ + +/** + * @brief Class to define FunctionBuilder which is inherited from Builder + */ +class FunctionBuilder final : public Builder +{ +public: + using SetupFunc = std::function<void(::tflite::Interpreter &)>; + +public: + /** + * @brief Construct a FunctionBuilder object with SetupFunction + * @param[in] fn The SetupFunc object + */ + FunctionBuilder(const SetupFunc &fn) : _fn{fn} + { + // DO NOTHING + } + +public: + /** + * @brief Build a SetupFunc + * @return The TfLite interpreter pointer address + */ + std::unique_ptr<::tflite::Interpreter> build(void) const override; + +private: + SetupFunc _fn; +}; + +} // namespace tflite +} // namespace nnfw + +#endif // __NNFW_TFLITE_INTERP_FUNCTION_BUILDER_H__ |