summaryrefslogtreecommitdiff
path: root/tests/tools/tflite_run/src/tflite_run.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tests/tools/tflite_run/src/tflite_run.cc')
-rw-r--r--tests/tools/tflite_run/src/tflite_run.cc262
1 files changed, 262 insertions, 0 deletions
diff --git a/tests/tools/tflite_run/src/tflite_run.cc b/tests/tools/tflite_run/src/tflite_run.cc
new file mode 100644
index 000000000..5be6909e5
--- /dev/null
+++ b/tests/tools/tflite_run/src/tflite_run.cc
@@ -0,0 +1,262 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tflite/ext/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+#include "bin_image.h"
+#include "args.h"
+#include "tensor_dumper.h"
+#include "tensor_loader.h"
+#include "misc/benchmark.h"
+#include "misc/environment.h"
+#include "misc/fp32.h"
+#include "tflite/Diff.h"
+#include "tflite/Assert.h"
+#include "tflite/Session.h"
+#include "tflite/InterpreterSession.h"
+#include "tflite/NNAPISession.h"
+#include "misc/tensor/IndexIterator.h"
+#include "misc/tensor/Object.h"
+
+#include <iostream>
+#include <chrono>
+#include <algorithm>
+
+using namespace tflite;
+using namespace nnfw::tflite;
+using namespace std::placeholders; // for _1, _2 ...
+
+void print_max_idx(float *f, int size)
+{
+ float *p = std::max_element(f, f + size);
+ std::cout << "max:" << p - f;
+}
+
+int main(const int argc, char **argv)
+{
+ bool use_nnapi = false;
+
+ if (std::getenv("USE_NNAPI") != nullptr)
+ {
+ use_nnapi = true;
+ }
+
+ StderrReporter error_reporter;
+
+ TFLiteRun::Args args(argc, argv);
+
+ auto model = FlatBufferModel::BuildFromFile(args.getTFLiteFilename().c_str(), &error_reporter);
+ std::unique_ptr<Interpreter> interpreter;
+
+ std::chrono::milliseconds t_prepare(0);
+ std::chrono::milliseconds t_invoke(0);
+
+ nnfw::misc::benchmark::measure(t_prepare) << [&](void) {
+ BuiltinOpResolver resolver;
+
+ InterpreterBuilder builder(*model, resolver);
+
+ TFLITE_ENSURE(builder(&interpreter))
+
+ interpreter->SetNumThreads(1);
+ };
+
+ std::shared_ptr<nnfw::tflite::Session> sess;
+
+ if (use_nnapi)
+ {
+ sess = std::make_shared<nnfw::tflite::NNAPISession>(interpreter.get());
+ }
+ else
+ {
+ sess = std::make_shared<nnfw::tflite::InterpreterSession>(interpreter.get());
+ }
+
+ sess->prepare();
+
+ TFLiteRun::TensorLoader tensor_loader(*interpreter);
+
+ // Load input from dumped tensor file.
+ if (!args.getCompareFilename().empty())
+ {
+ tensor_loader.load(args.getCompareFilename());
+
+ for (const auto &o : interpreter->inputs())
+ {
+ const auto &tensor_view = tensor_loader.get(o);
+ TfLiteTensor *tensor = interpreter->tensor(o);
+
+ memcpy(reinterpret_cast<void *>(tensor->data.f),
+ reinterpret_cast<const void *>(tensor_view._base), tensor->bytes);
+ }
+ }
+ else
+ {
+ const int seed = 1; /* TODO Add an option for seed value */
+ RandomGenerator randgen{seed, 0.0f, 2.0f};
+
+ // No input specified. So we fill the input tensors with random values.
+ for (const auto &o : interpreter->inputs())
+ {
+ TfLiteTensor *tensor = interpreter->tensor(o);
+ if (tensor->type == kTfLiteInt32)
+ {
+ // Generate singed 32-bit integer (s32) input
+ auto tensor_view = nnfw::tflite::TensorView<int32_t>::make(*interpreter, o);
+
+ int32_t value = 0;
+
+ nnfw::misc::tensor::iterate(tensor_view.shape())
+ << [&](const nnfw::misc::tensor::Index &ind) {
+ // TODO Generate random values
+ // Gather operation: index should be within input coverage.
+ tensor_view.at(ind) = value;
+ value++;
+ };
+ }
+ else if (tensor->type == kTfLiteUInt8)
+ {
+ // Generate unsigned 8-bit integer input
+ auto tensor_view = nnfw::tflite::TensorView<uint8_t>::make(*interpreter, o);
+
+ uint8_t value = 0;
+
+ nnfw::misc::tensor::iterate(tensor_view.shape())
+ << [&](const nnfw::misc::tensor::Index &ind) {
+ // TODO Generate random values
+ tensor_view.at(ind) = value;
+ value = (value + 1) & 0xFF;
+ };
+ }
+ else if (tensor->type == kTfLiteBool)
+ {
+ // Generate bool input
+ auto tensor_view = nnfw::tflite::TensorView<bool>::make(*interpreter, o);
+
+ auto fp = static_cast<bool (RandomGenerator::*)(const ::nnfw::misc::tensor::Shape &,
+ const ::nnfw::misc::tensor::Index &)>(
+ &RandomGenerator::generate<bool>);
+ const nnfw::misc::tensor::Object<bool> data(tensor_view.shape(),
+ std::bind(fp, randgen, _1, _2));
+
+ nnfw::misc::tensor::iterate(tensor_view.shape())
+ << [&](const nnfw::misc::tensor::Index &ind) {
+ const auto value = data.at(ind);
+ tensor_view.at(ind) = value;
+ };
+ }
+ else
+ {
+ assert(tensor->type == kTfLiteFloat32);
+
+ const float *end = reinterpret_cast<const float *>(tensor->data.raw_const + tensor->bytes);
+ for (float *ptr = tensor->data.f; ptr < end; ptr++)
+ {
+ *ptr = randgen.generate<float>();
+ }
+ }
+ }
+ }
+
+ TFLiteRun::TensorDumper tensor_dumper;
+ // Must be called before `interpreter->Invoke()`
+ tensor_dumper.addTensors(*interpreter, interpreter->inputs());
+
+ std::cout << "input tensor indices = [";
+ for (const auto &o : interpreter->inputs())
+ {
+ std::cout << o << ",";
+ }
+ std::cout << "]" << std::endl;
+
+ nnfw::misc::benchmark::measure(t_invoke) << [&sess](void) {
+ if (!sess->run())
+ {
+ assert(0 && "run failed!");
+ }
+ };
+
+ sess->teardown();
+
+ // Must be called after `interpreter->Invoke()`
+ tensor_dumper.addTensors(*interpreter, interpreter->outputs());
+
+ std::cout << "output tensor indices = [";
+ for (const auto &o : interpreter->outputs())
+ {
+ std::cout << o << "(";
+
+ print_max_idx(interpreter->tensor(o)->data.f, interpreter->tensor(o)->bytes / sizeof(float));
+
+ std::cout << "),";
+ }
+ std::cout << "]" << std::endl;
+
+ std::cout << "Prepare takes " << t_prepare.count() / 1000.0 << " seconds" << std::endl;
+ std::cout << "Invoke takes " << t_invoke.count() / 1000.0 << " seconds" << std::endl;
+
+ if (!args.getDumpFilename().empty())
+ {
+ const std::string &dump_filename = args.getDumpFilename();
+ tensor_dumper.dump(dump_filename);
+ std::cout << "Input/output tensors have been dumped to file \"" << dump_filename << "\"."
+ << std::endl;
+ }
+
+ if (!args.getCompareFilename().empty())
+ {
+ const std::string &compare_filename = args.getCompareFilename();
+ std::cout << "========================================" << std::endl;
+ std::cout << "Comparing the results with \"" << compare_filename << "\"." << std::endl;
+ std::cout << "========================================" << std::endl;
+
+ // TODO Code duplication (copied from RandomTestRunner)
+
+ int tolerance = 1;
+ nnfw::misc::env::IntAccessor("TOLERANCE").access(tolerance);
+
+ auto equals = [tolerance](float lhs, float rhs) {
+ // NOTE Hybrid approach
+ // TODO Allow users to set tolerance for absolute_epsilon_equal
+ if (nnfw::misc::fp32::absolute_epsilon_equal(lhs, rhs))
+ {
+ return true;
+ }
+
+ return nnfw::misc::fp32::epsilon_equal(lhs, rhs, tolerance);
+ };
+
+ nnfw::misc::tensor::Comparator comparator(equals);
+ TfLiteInterpMatchApp app(comparator);
+ bool res = true;
+
+ for (const auto &o : interpreter->outputs())
+ {
+ auto expected = tensor_loader.get(o);
+ auto obtained = nnfw::tflite::TensorView<float>::make(*interpreter, o);
+
+ res = res && app.compareSingleTensorView(expected, obtained, o);
+ }
+
+ if (!res)
+ {
+ return 255;
+ }
+ }
+
+ return 0;
+}