diff options
Diffstat (limited to 'include/support/tflite/Diff.h')
-rw-r--r-- | include/support/tflite/Diff.h | 119 |
1 files changed, 119 insertions, 0 deletions
diff --git a/include/support/tflite/Diff.h b/include/support/tflite/Diff.h new file mode 100644 index 000000000..b17c9313c --- /dev/null +++ b/include/support/tflite/Diff.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_COMPARE_H__ +#define __NNFW_SUPPORT_TFLITE_COMPARE_H__ + +#include "tensorflow/contrib/lite/interpreter.h" + +#include "util/tensor/Index.h" + +#include "support/tflite/TensorView.h" + +#include <functional> +#include <vector> + +// NOTE The code below is subject to change. +// TODO Introduce namespaces +struct TfLiteTensorDiff +{ + nnfw::util::tensor::Index index; + float expected; + float obtained; + + TfLiteTensorDiff(const nnfw::util::tensor::Index &i) : index(i) + { + // DO NOTHING + } +}; + +class TfLiteTensorComparator +{ +public: + TfLiteTensorComparator(const std::function<bool (float lhs, float rhs)> &fn) : _compare_fn{fn} + { + // DO NOTHING + } + +public: + struct Observer + { + virtual void notify(const nnfw::util::tensor::Index &index, float expected, float obtained) = 0; + }; + +public: + // NOTE Observer should live longer than comparator + std::vector<TfLiteTensorDiff> compare(const nnfw::support::tflite::TensorView<float> &expected, + const nnfw::support::tflite::TensorView<float> &obtained, + Observer *observer = nullptr) const; + +private: + std::function<bool (float lhs, float rhs)> _compare_fn; +}; + +class TfLiteInterpMatchApp +{ +public: + TfLiteInterpMatchApp(const TfLiteTensorComparator &comparator) + : _verbose{false}, _comparator(comparator) + { + // DO NOTHING + } + +public: + int &verbose(void) { return _verbose; } + +private: + int _verbose; + +public: + bool run(::tflite::Interpreter &pure, ::tflite::Interpreter &nnapi) const; + +private: + const TfLiteTensorComparator &_comparator; +}; + +#include "support/tflite/interp/Builder.h" + +#include <random> + +// For NNAPI testing +struct RandomTestParam +{ + int verbose; + int tolerance; +}; + +class RandomTestRunner +{ +public: + RandomTestRunner(int seed, const RandomTestParam ¶m) + : _rand{seed}, _param{param} + { + // DO NOTHING + } + +public: + // NOTE this method updates '_rand' + // Return 0 if test succeeds + int run(const nnfw::support::tflite::interp::Builder &builder); + +private: + std::minstd_rand _rand; + const RandomTestParam _param; +}; + +#endif // __NNFW_SUPPORT_TFLITE_COMPARE_H__ |