Diffstat (limited to 'runtimes/contrib/tflite_test/tflite_test.cpp')
-rw-r--r--  runtimes/contrib/tflite_test/tflite_test.cpp  239
1 file changed, 239 insertions(+), 0 deletions(-)
diff --git a/runtimes/contrib/tflite_test/tflite_test.cpp b/runtimes/contrib/tflite_test/tflite_test.cpp
new file mode 100644
index 000000000..80ba448c6
--- /dev/null
+++ b/runtimes/contrib/tflite_test/tflite_test.cpp
@@ -0,0 +1,239 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <tensorflow/lite/model.h>
+#include <tensorflow/core/public/session.h>
+
+#include "tflite/Assert.h"
+#include "tflite/Session.h"
+#include "tflite/InterpreterSession.h"
+#include "tflite/NNAPISession.h"
+#include "tflite/ext/kernels/register.h"
+
+#include "misc/fp32.h"
+
+#include <iostream>
+
+#include <string>
+#include <vector>
+
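+// Evaluates a TensorFlow Status-returning expression and throws if the status is not OK.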
+#define TF_ENSURE(e) \
+ { \
+ if (!(e).ok()) \
+ { \
+ throw std::runtime_error{"'" #e "' FAILED"}; \
+ } \
+ }
+
+using namespace tflite;
+using namespace tflite::ops::builtin;
+
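+// Loads a TFLite flatbuffer model from the given path, reporting load errors to stderr.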
+std::unique_ptr<FlatBufferModel> BuildModelFromFile(const std::string &path)
+{
+ static StderrReporter reporter;
+ return FlatBufferModel::BuildFromFile(path.c_str(), &reporter);
+}
+
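+// Builds a TFLite interpreter for the given model using the builtin op resolver.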
+std::unique_ptr<Interpreter> BuildInterpFromModel(const std::unique_ptr<FlatBufferModel> &model)
+{
+ std::unique_ptr<Interpreter> interp;
+
+ BuiltinOpResolver resolver;
+ InterpreterBuilder builder(*model, resolver);
+
+ TFLITE_ENSURE(builder(&interp));
+
+ return std::move(interp);
+}
+
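+// Converts the dimensions of a TfLiteTensor into a tensorflow::TensorShape.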
+tensorflow::TensorShape asTensorflowShape(const TfLiteTensor *tensor)
+{
+ tensorflow::TensorShape shape;
+
+ const int rank = tensor->dims->size;
+
+ for (int axis = 0; axis < rank; ++axis)
+ {
+ shape.AddDim(tensor->dims->data[axis]);
+ }
+
+ return shape;
+}
+
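+// Returns the number of elements in a TfLiteTensor (product of its dimensions); rank-0 tensors yield 0.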
+uint32_t count_elements(const TfLiteTensor *tensor)
+{
+ const int rank = tensor->dims->size;
+
+ if (rank == 0)
+ {
+ return 0;
+ }
+
+ uint32_t res = 1;
+
+ for (int axis = 0; axis < rank; ++axis)
+ {
+ res *= tensor->dims->data[axis];
+ }
+
+ return res;
+}
+
+int main(int argc, char **argv)
+{
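+ // Run the TFLite model through NNAPI when the USE_NNAPI environment variable is set;
+ // otherwise run it on the plain TFLite interpreter.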
+ bool use_nnapi = false;
+
+ if (std::getenv("USE_NNAPI") != nullptr)
+ {
+ use_nnapi = true;
+ }
+
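+ // Example invocation (hypothetical file names):
+ //   $ USE_NNAPI=1 ./tflite_test model.tflite model.pb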
+ if (argc < 3)
+ {
+ std::cerr << "USAGE: " << argv[0] << " [T/F lite model] [T/F model]" << std::endl;
+ return 255;
+ }
+
+ //
+ // Prepare TensorFlow Lite session
+ //
+ const std::string lite_model_path{argv[1]};
+
+ auto lite_model = BuildModelFromFile(lite_model_path);
+ auto lite_interp = BuildInterpFromModel(lite_model);
+
+ std::shared_ptr<nnfw::tflite::Session> lite_sess;
+
+ if (use_nnapi)
+ {
+ lite_sess = std::make_shared<nnfw::tflite::NNAPISession>(lite_interp.get());
+ }
+ else
+ {
+ lite_sess = std::make_shared<nnfw::tflite::InterpreterSession>(lite_interp.get());
+ }
+
+ //
+ // Prepare TensorFlow session
+ //
+ const std::string full_model_path{argv[2]};
+
+ tensorflow::Session *full_sess;
+ tensorflow::GraphDef full_model;
+
+ TF_ENSURE(tensorflow::NewSession(tensorflow::SessionOptions(), &full_sess));
+ TF_ENSURE(ReadBinaryProto(tensorflow::Env::Default(), full_model_path, &full_model));
+ TF_ENSURE(full_sess->Create(full_model));
+
+ //
+ // Collect input/output tensor names and shapes from the TFLite interpreter
+ //
+ std::vector<tensorflow::Tensor> input_nodes;
+ std::vector<std::string> input_names;
+
+ for (uint32_t n = 0; n < lite_interp->inputs().size(); ++n)
+ {
+ const TfLiteTensor *tensor = lite_interp->tensor(lite_interp->inputs().at(n));
+
+ input_nodes.emplace_back(tensorflow::DT_FLOAT, asTensorflowShape(tensor));
+ input_names.emplace_back(tensor->name);
+ }
+
+ assert(input_nodes.size() == input_names.size());
+ assert(input_nodes.size() == lite_interp->inputs().size());
+
+ std::vector<std::string> output_names;
+ std::vector<tensorflow::Tensor> output_nodes;
+
+ for (uint32_t n = 0; n < lite_interp->outputs().size(); ++n)
+ {
+ const TfLiteTensor *tensor = lite_interp->tensor(lite_interp->outputs().at(n));
+
+ output_names.emplace_back(tensor->name);
+ }
+
+ assert(output_names.size() == lite_interp->outputs().size());
+ // output_nodes will be initialized after the TensorFlow session run
+ assert(output_nodes.size() == 0);
+
+ //
+ // Prepare inference
+ //
+ lite_sess->prepare();
+
+ // TODO Feed inputs (for both TensorFlow and TensorFlow Lite)
+ std::vector<std::pair<std::string, tensorflow::Tensor>> inputs;
+
+ for (uint32_t n = 0; n < input_nodes.size(); ++n)
+ {
+ inputs.emplace_back(input_names.at(n), input_nodes.at(n));
+ }
+
+ //
+ // Run inference
+ //
+ TF_ENSURE(full_sess->Run(inputs, output_names, {}, &output_nodes));
+
+ lite_sess->run();
+
+ //
+ // Compare Output
+ //
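+ // A pair of values is considered equal if absolute_epsilon_equal holds; otherwise fall back to epsilon_equal.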
+ auto equals = [](float lhs, float rhs) {
+ // TODO Allow users to set tolerance
+ if (nnfw::misc::fp32::absolute_epsilon_equal(lhs, rhs))
+ {
+ return true;
+ }
+
+ return nnfw::misc::fp32::epsilon_equal(lhs, rhs);
+ };
+
+ const uint32_t output_count = output_names.size();
+
+ bool matched = true;
+
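+ // Compare each TensorFlow output tensor element-wise against the corresponding TFLite output.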
+ for (uint32_t n = 0; n < output_count; ++n)
+ {
+ const TfLiteTensor *tensor = lite_interp->tensor(lite_interp->outputs().at(n));
+
+ // TODO Compare shape
+
+ const auto element_count = count_elements(tensor);
+
+ std::cout << "Compare output #" << n << "(" << tensor->name << ", " << element_count
+ << " elements)" << std::endl;
+ for (uint32_t index = 0; index < element_count; ++index)
+ {
+ const auto full_value = output_nodes.at(n).flat<float>().data()[index];
+ const auto lite_value = lite_sess->interp()->typed_output_tensor<float>(n)[index];
+
+ if (!equals(full_value, lite_value))
+ {
+ std::cerr << full_value << " is expected, but " << lite_value << " is obtaeind (at " << n
+ << ":" << index << ")" << std::endl;
+ matched = false;
+ }
+ }
+ }
+
+ //
+ // Cleanup
+ //
+ lite_sess->teardown();
+
+ return matched ? 0 : 255;
+}