summaryrefslogtreecommitdiff
path: root/include/support/tflite/Diff.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/support/tflite/Diff.h')
-rw-r--r--include/support/tflite/Diff.h92
1 files changed, 51 insertions, 41 deletions
diff --git a/include/support/tflite/Diff.h b/include/support/tflite/Diff.h
index b17c9313c..f4f3f6fe8 100644
--- a/include/support/tflite/Diff.h
+++ b/include/support/tflite/Diff.h
@@ -20,88 +20,92 @@
#include "tensorflow/contrib/lite/interpreter.h"
#include "util/tensor/Index.h"
+#include "util/tensor/Diff.h"
+#include "util/tensor/Shape.h"
+#include "util/tensor/Comparator.h"
#include "support/tflite/TensorView.h"
#include <functional>
#include <vector>
-// NOTE The code below is subject to change.
-// TODO Introduce namespaces
-struct TfLiteTensorDiff
-{
- nnfw::util::tensor::Index index;
- float expected;
- float obtained;
-
- TfLiteTensorDiff(const nnfw::util::tensor::Index &i) : index(i)
- {
- // DO NOTHING
- }
-};
-
-class TfLiteTensorComparator
+class TfLiteInterpMatchApp
{
public:
- TfLiteTensorComparator(const std::function<bool (float lhs, float rhs)> &fn) : _compare_fn{fn}
+ TfLiteInterpMatchApp(const nnfw::util::tensor::Comparator &comparator)
+ : _verbose{false}, _comparator(comparator)
{
// DO NOTHING
}
public:
- struct Observer
- {
- virtual void notify(const nnfw::util::tensor::Index &index, float expected, float obtained) = 0;
- };
+ int &verbose(void) { return _verbose; }
+
+private:
+ int _verbose;
public:
- // NOTE Observer should live longer than comparator
- std::vector<TfLiteTensorDiff> compare(const nnfw::support::tflite::TensorView<float> &expected,
- const nnfw::support::tflite::TensorView<float> &obtained,
- Observer *observer = nullptr) const;
+ bool run(::tflite::Interpreter &pure, ::tflite::Interpreter &nnapi) const;
+ template <typename T>
+ bool compareSingleTensorView(const nnfw::support::tflite::TensorView<T> &expected,
+ const nnfw::support::tflite::TensorView<T> &obtained,
+ int id) const;
private:
- std::function<bool (float lhs, float rhs)> _compare_fn;
+ const nnfw::util::tensor::Comparator &_comparator;
};
-class TfLiteInterpMatchApp
+#include "support/tflite/interp/Builder.h"
+#include "support/tflite/Quantization.h"
+
+#include <random>
+
+class RandomGenerator
{
public:
- TfLiteInterpMatchApp(const TfLiteTensorComparator &comparator)
- : _verbose{false}, _comparator(comparator)
+ RandomGenerator(int seed, float mean, float stddev,
+ const TfLiteQuantizationParams quantization = make_default_quantization())
+ : _rand{seed}, _dist{mean, stddev}, _quantization{quantization}
{
// DO NOTHING
}
public:
- int &verbose(void) { return _verbose; }
-
-private:
- int _verbose;
+ template <typename T>
+ T generate(const ::nnfw::util::tensor::Shape &, const ::nnfw::util::tensor::Index &)
+ {
+ return generate<T>();
+ }
-public:
- bool run(::tflite::Interpreter &pure, ::tflite::Interpreter &nnapi) const;
+ template <typename T> T generate(void)
+ {
+ return _dist(_rand);
+ }
private:
- const TfLiteTensorComparator &_comparator;
+ std::minstd_rand _rand;
+ std::normal_distribution<float> _dist;
+ const TfLiteQuantizationParams _quantization;
};
-#include "support/tflite/interp/Builder.h"
-
-#include <random>
+template <>
+uint8_t RandomGenerator::generate<uint8_t>(void);
// For NNAPI testing
struct RandomTestParam
{
int verbose;
int tolerance;
+ int tensor_logging = 0;
+ std::string log_path = ""; // meaningful only when tensor_logging is 1
};
class RandomTestRunner
{
public:
- RandomTestRunner(int seed, const RandomTestParam &param)
- : _rand{seed}, _param{param}
+ RandomTestRunner(int seed, const RandomTestParam &param,
+ const TfLiteQuantizationParams quantization = make_default_quantization())
+ : _randgen{seed, 0.0f, 2.0f, quantization}, _param{param}
{
// DO NOTHING
}
@@ -111,9 +115,15 @@ public:
// Return 0 if test succeeds
int run(const nnfw::support::tflite::interp::Builder &builder);
+public:
+ RandomGenerator &generator() { return _randgen; };
+
private:
- std::minstd_rand _rand;
+ RandomGenerator _randgen;
const RandomTestParam _param;
+
+public:
+ static RandomTestRunner make(int seed);
};
#endif // __NNFW_SUPPORT_TFLITE_COMPARE_H__