diff options
Diffstat (limited to 'include/support/tflite/Diff.h')
-rw-r--r-- | include/support/tflite/Diff.h | 92 |
1 files changed, 51 insertions, 41 deletions
diff --git a/include/support/tflite/Diff.h b/include/support/tflite/Diff.h index b17c9313c..f4f3f6fe8 100644 --- a/include/support/tflite/Diff.h +++ b/include/support/tflite/Diff.h @@ -20,88 +20,92 @@ #include "tensorflow/contrib/lite/interpreter.h" #include "util/tensor/Index.h" +#include "util/tensor/Diff.h" +#include "util/tensor/Shape.h" +#include "util/tensor/Comparator.h" #include "support/tflite/TensorView.h" #include <functional> #include <vector> -// NOTE The code below is subject to change. -// TODO Introduce namespaces -struct TfLiteTensorDiff -{ - nnfw::util::tensor::Index index; - float expected; - float obtained; - - TfLiteTensorDiff(const nnfw::util::tensor::Index &i) : index(i) - { - // DO NOTHING - } -}; - -class TfLiteTensorComparator +class TfLiteInterpMatchApp { public: - TfLiteTensorComparator(const std::function<bool (float lhs, float rhs)> &fn) : _compare_fn{fn} + TfLiteInterpMatchApp(const nnfw::util::tensor::Comparator &comparator) + : _verbose{false}, _comparator(comparator) { // DO NOTHING } public: - struct Observer - { - virtual void notify(const nnfw::util::tensor::Index &index, float expected, float obtained) = 0; - }; + int &verbose(void) { return _verbose; } + +private: + int _verbose; public: - // NOTE Observer should live longer than comparator - std::vector<TfLiteTensorDiff> compare(const nnfw::support::tflite::TensorView<float> &expected, - const nnfw::support::tflite::TensorView<float> &obtained, - Observer *observer = nullptr) const; + bool run(::tflite::Interpreter &pure, ::tflite::Interpreter &nnapi) const; + template <typename T> + bool compareSingleTensorView(const nnfw::support::tflite::TensorView<T> &expected, + const nnfw::support::tflite::TensorView<T> &obtained, + int id) const; private: - std::function<bool (float lhs, float rhs)> _compare_fn; + const nnfw::util::tensor::Comparator &_comparator; }; -class TfLiteInterpMatchApp +#include "support/tflite/interp/Builder.h" +#include "support/tflite/Quantization.h" + +#include <random> + +class RandomGenerator { public: - TfLiteInterpMatchApp(const TfLiteTensorComparator &comparator) - : _verbose{false}, _comparator(comparator) + RandomGenerator(int seed, float mean, float stddev, + const TfLiteQuantizationParams quantization = make_default_quantization()) + : _rand{seed}, _dist{mean, stddev}, _quantization{quantization} { // DO NOTHING } public: - int &verbose(void) { return _verbose; } - -private: - int _verbose; + template <typename T> + T generate(const ::nnfw::util::tensor::Shape &, const ::nnfw::util::tensor::Index &) + { + return generate<T>(); + } -public: - bool run(::tflite::Interpreter &pure, ::tflite::Interpreter &nnapi) const; + template <typename T> T generate(void) + { + return _dist(_rand); + } private: - const TfLiteTensorComparator &_comparator; + std::minstd_rand _rand; + std::normal_distribution<float> _dist; + const TfLiteQuantizationParams _quantization; }; -#include "support/tflite/interp/Builder.h" - -#include <random> +template <> +uint8_t RandomGenerator::generate<uint8_t>(void); // For NNAPI testing struct RandomTestParam { int verbose; int tolerance; + int tensor_logging = 0; + std::string log_path = ""; // meaningful only when tensor_logging is 1 }; class RandomTestRunner { public: - RandomTestRunner(int seed, const RandomTestParam ¶m) - : _rand{seed}, _param{param} + RandomTestRunner(int seed, const RandomTestParam ¶m, + const TfLiteQuantizationParams quantization = make_default_quantization()) + : _randgen{seed, 0.0f, 2.0f, quantization}, _param{param} { // DO NOTHING } @@ -111,9 +115,15 @@ public: // Return 0 if test succeeds int run(const nnfw::support::tflite::interp::Builder &builder); +public: + RandomGenerator &generator() { return _randgen; }; + private: - std::minstd_rand _rand; + RandomGenerator _randgen; const RandomTestParam _param; + +public: + static RandomTestRunner make(int seed); }; #endif // __NNFW_SUPPORT_TFLITE_COMPARE_H__ |