Diffstat (limited to 'tests/tools/tflite_comparator/src/tflite_comparator.cc')
 tests/tools/tflite_comparator/src/tflite_comparator.cc | 398 ++++++++++++++++
 1 file changed, 398 insertions(+), 0 deletions(-)
diff --git a/tests/tools/tflite_comparator/src/tflite_comparator.cc b/tests/tools/tflite_comparator/src/tflite_comparator.cc
new file mode 100644
index 000000000..383a4e4de
--- /dev/null
+++ b/tests/tools/tflite_comparator/src/tflite_comparator.cc
@@ -0,0 +1,398 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "args.h"
+
+#include <nnfw_experimental.h>
+#include <nnfw_internal.h>
+
+#include <misc/EnvVar.h>
+#include <misc/fp32.h>
+#include <misc/RandomGenerator.h>
+
+#include <tflite/Assert.h>
+#include <tflite/InterpreterSession.h>
+
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+const int RUN_FAILED = 1;
+const int FILE_ERROR = 2;
+
+using namespace nnfw::tflite;
+
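+// Abort with an error message when an NNFW API call does not return NNFW_STATUS_NO_ERROR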
+#define NNFW_ASSERT_FAIL(expr, msg)       \
+  do                                      \
+  {                                       \
+    if ((expr) != NNFW_STATUS_NO_ERROR)   \
+    {                                     \
+      std::cerr << msg << std::endl;      \
+      exit(-1);                           \
+    }                                     \
+  } while (0)
+
+// Read raw input data from the given file; dest must already be sized to the file length
+void readData(const std::string &path, std::vector<uint8_t> &dest)
+{
+  std::ifstream in(path, std::ios::binary);
+ if (!in.good())
+ {
+ std::cerr << "can not open data file " << path << "\n";
+ exit(FILE_ERROR);
+ }
+ in.seekg(0, std::ifstream::end);
+ size_t len = in.tellg();
+ in.seekg(0, std::ifstream::beg);
+
+ assert(dest.size() == len);
+ in.read(reinterpret_cast<char *>(dest.data()), len);
+}
+
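+// Fill the destination byte buffer with random values of element type T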
+template <typename T>
+void randomData(nnfw::misc::RandomGenerator &randgen, std::vector<uint8_t> &dest)
+{
+ size_t elements = dest.size() / sizeof(T);
+ assert(dest.size() % sizeof(T) == 0);
+
+ std::vector<T> vec(elements);
+ for (uint64_t i = 0; i < elements; i++)
+ {
+ vec[i] = randgen.generate<T>();
+ }
+ memcpy(dest.data(), vec.data(), elements * sizeof(T));
+}
+
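+// Fill the destination buffer with random boolean values stored as 0/1 bytes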
+void randomBoolData(nnfw::misc::RandomGenerator &randgen, std::vector<uint8_t> &dest)
+{
+ size_t elements = dest.size();
+ std::vector<uint8_t> vec(elements);
+ for (uint64_t i = 0; i < elements; i++)
+ {
+ bool value = randgen.generate<bool>();
+ dest[i] = value ? 1 : 0;
+ }
+}
+
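+// Total number of elements described by a tensor info (product of all dimensions)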
+inline uint64_t num_elems(const nnfw_tensorinfo *ti)
+{
+ uint64_t n = 1;
+ for (uint32_t i = 0; i < ti->rank; ++i)
+ {
+ n *= ti->dims[i];
+ }
+ return n;
+}
+
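+// Size in bytes of one element of the given NNFW tensor type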
+inline size_t sizeOfNnfwType(NNFW_TYPE type)
+{
+ switch (type)
+ {
+ case NNFW_TYPE_TENSOR_BOOL:
+ case NNFW_TYPE_TENSOR_UINT8:
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
+ return 1;
+ case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
+ return 2;
+ case NNFW_TYPE_TENSOR_FLOAT32:
+ case NNFW_TYPE_TENSOR_INT32:
+ return 4;
+ case NNFW_TYPE_TENSOR_INT64:
+ return 8;
+ default:
+ throw std::runtime_error{"Invalid tensor type"};
+ }
+}
+
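+// Integer comparison: an element matches when it equals the reference or the
+// absolute difference is within the TOLERANCE environment variable (default 0)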
+template <typename T>
+bool isClose(const T *ref_buf, const std::vector<uint8_t> &act_buf, uint32_t index)
+{
+ // TODO better way for handling quant error?
+ auto tolerance = static_cast<uint64_t>(nnfw::misc::EnvVar("TOLERANCE").asInt(0));
+ bool match = true;
+
+ for (uint32_t e = 0; e < act_buf.size() / sizeof(T); e++)
+ {
+ T ref = ref_buf[e];
+ T act = reinterpret_cast<const T *>(act_buf.data())[e];
+ uint64_t diff = static_cast<uint64_t>(((ref > act) ? (ref - act) : (act - ref)));
+
+ if (ref != act && diff > tolerance)
+ {
+ std::cerr << "Output #" << index << ", Element Index : " << e << ", ref: " << ref
+ << ", act: " << act << " (diff: " << diff << ")" << std::endl;
+ match = false;
+ }
+ }
+
+ return match;
+}
+
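+// Float specialization: an element matches under an absolute-epsilon check or a
+// relative-epsilon check scaled by the tolerance taken from TOLERANCE (default 1)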
+template <>
+bool isClose<float>(const float *ref_buf, const std::vector<uint8_t> &act_buf, uint32_t index)
+{
+ uint32_t tolerance = nnfw::misc::EnvVar("TOLERANCE").asInt(1);
+ bool match = true;
+
+ for (uint32_t e = 0; e < act_buf.size() / sizeof(float); e++)
+ {
+ float ref = ref_buf[e];
+ float act = reinterpret_cast<const float *>(act_buf.data())[e];
+ float diff = std::fabs(ref - act);
+
+ bool match_elem = nnfw::misc::fp32::absolute_epsilon_equal(ref, act)
+ ? true
+ : nnfw::misc::fp32::epsilon_equal(ref, act, tolerance);
+
+ if (!match_elem)
+ {
+ std::cerr << "Output #" << index << ", Element Index : " << e << ", ref: " << ref
+ << ", act: " << act << " (diff: " << diff << ")" << std::endl;
+ match = false;
+ }
+ }
+
+ return match;
+}
+
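+// Boolean comparison: any non-zero byte is treated as true and values must match exactly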
+bool exact(const uint8_t *ref_buf, const std::vector<uint8_t> &act_buf, uint32_t index)
+{
+ bool match = true;
+ for (uint32_t e = 0; e < act_buf.size() / sizeof(uint8_t); e++)
+ {
+ uint8_t ref_raw = ref_buf[e];
+ bool ref = (ref_raw != 0 ? true : false);
+ uint8_t act_raw = reinterpret_cast<const uint8_t *>(act_buf.data())[e];
+ bool act = (act_raw != 0 ? true : false);
+ if (ref != act)
+ {
+ std::cerr << "Output #" << index << ", Element Index : " << e << ", ref: " << ref
+ << ", act: " << act << std::endl;
+ match = false;
+ }
+ }
+
+ return match;
+}
+
+int main(const int argc, char **argv)
+{
+ TFLiteRun::Args args(argc, argv);
+
+ auto tflite_file = args.getTFLiteFilename();
+ auto data_files = args.getDataFilenames();
+
+ if (tflite_file.empty())
+ {
+ args.print(argv);
+ return RUN_FAILED;
+ }
+
+ std::cout << "[Execution] Stage start!" << std::endl;
+ // Loading
+ nnfw_session *onert_session = nullptr;
+  NNFW_ASSERT_FAIL(nnfw_create_session(&onert_session),
+                   "[ ERROR ] Failure during session creation");
+ if (onert_session == nullptr)
+ {
+ std::cerr << "[ ERROR ] Failure to open session" << std::endl;
+ exit(-1);
+ }
+
+ NNFW_ASSERT_FAIL(nnfw_load_model_from_modelfile(onert_session, tflite_file.c_str()),
+ "[ ERROR ] Failure during model load");
+
+ uint32_t num_inputs;
+ uint32_t num_outputs;
+ NNFW_ASSERT_FAIL(nnfw_input_size(onert_session, &num_inputs),
+ "[ ERROR ] Failure during get model inputs");
+ NNFW_ASSERT_FAIL(nnfw_output_size(onert_session, &num_outputs),
+ "[ ERROR ] Failure during get model outputs");
+
+ std::cout << "[Execution] Model is deserialized!" << std::endl;
+
+ // Compile
+  NNFW_ASSERT_FAIL(nnfw_prepare(onert_session), "[ ERROR ] Failure during model compilation");
+
+ std::cout << "[Execution] Model compiled!" << std::endl;
+
+ // Prepare input/output data
+ std::vector<std::vector<uint8_t>> inputs(num_inputs);
+ std::vector<std::vector<uint8_t>> outputs(num_outputs);
+
+ bool generate_data = data_files.empty();
+ bool read_data = data_files.size() == num_inputs;
+ if (!generate_data && !read_data)
+ {
+    std::cerr << "[ ERROR ] Wrong number of input files." << std::endl;
+    exit(RUN_FAILED);
+ }
+
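+  // Random generator with a fixed seed so generated inputs are reproducible across runs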
+ const int seed = 1; /* TODO Add an option for seed value */
+ nnfw::misc::RandomGenerator randgen{seed, 0.0f, 2.0f};
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ nnfw_tensorinfo ti_input;
+ NNFW_ASSERT_FAIL(nnfw_input_tensorinfo(onert_session, i, &ti_input),
+ "[ ERROR ] Failure during get input data info");
+ size_t input_size = num_elems(&ti_input) * sizeOfNnfwType(ti_input.dtype);
+
+ inputs[i].resize(input_size);
+
+ if (generate_data)
+ {
+ switch (ti_input.dtype)
+ {
+ case NNFW_TYPE_TENSOR_BOOL:
+ randomBoolData(randgen, inputs[i]);
+ break;
+ case NNFW_TYPE_TENSOR_UINT8:
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ randomData<uint8_t>(randgen, inputs[i]);
+ break;
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
+ randomData<int8_t>(randgen, inputs[i]);
+ break;
+        case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
+          randomData<int16_t>(randgen, inputs[i]);
+          break;
+        case NNFW_TYPE_TENSOR_FLOAT32:
+ randomData<float>(randgen, inputs[i]);
+ break;
+ case NNFW_TYPE_TENSOR_INT32:
+ randomData<int32_t>(randgen, inputs[i]);
+ break;
+        case NNFW_TYPE_TENSOR_INT64:
+          randomData<int64_t>(randgen, inputs[i]);
+ break;
+ default:
+          std::cerr << "[ ERROR ] "
+                    << "Unsupported input data type" << std::endl;
+ exit(-1);
+ break;
+ }
+ }
+ else /* read_data */
+ readData(data_files[i], inputs[i]);
+
+ NNFW_ASSERT_FAIL(nnfw_set_input(onert_session, i, ti_input.dtype, inputs[i].data(), input_size),
+ "[ ERROR ] Failure to set input tensor buffer");
+ }
+
+ std::cout << "[Execution] Input data is defined!" << std::endl;
+
+ for (uint32_t i = 0; i < num_outputs; i++)
+ {
+ nnfw_tensorinfo ti_output;
+ NNFW_ASSERT_FAIL(nnfw_output_tensorinfo(onert_session, i, &ti_output),
+ "[ ERROR ] Failure during get output tensor info");
+
+ uint64_t output_elements = num_elems(&ti_output);
+ size_t output_size = output_elements * sizeOfNnfwType(ti_output.dtype);
+ outputs[i].resize(output_size);
+
+ NNFW_ASSERT_FAIL(
+ nnfw_set_output(onert_session, i, ti_output.dtype, outputs[i].data(), output_size),
+ "[ ERROR ] Failure to set output tensor buffer");
+ }
+
+ // Execute
+ NNFW_ASSERT_FAIL(nnfw_run(onert_session), "[Execution] Can't execute");
+
+ std::cout << "[Execution] Done!" << std::endl;
+
+ // Compare with tflite
+ std::cout << "[Comparison] Stage start!" << std::endl;
+ // Read tflite model
+ auto model = TfLiteModelCreateFromFile(tflite_file.c_str());
+ auto options = TfLiteInterpreterOptionsCreate();
+ TfLiteInterpreterOptionsSetNumThreads(options, nnfw::misc::EnvVar("THREAD").asInt(1));
+ auto interpreter = TfLiteInterpreterCreate(model, options);
+
+ auto sess = std::make_shared<nnfw::tflite::InterpreterSession>(interpreter);
+ sess->prepare();
+ // Set input and run
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto input_tensor = TfLiteInterpreterGetInputTensor(interpreter, i);
+ memcpy(TfLiteTensorData(input_tensor), inputs[i].data(), inputs[i].size());
+ }
+  if (!sess->run())
+  {
+    std::cerr << "[Comparison] TFLite run failed!" << std::endl;
+    exit(RUN_FAILED);
+  }
+ std::cout << "[Comparison] TFLite run done!" << std::endl;
+
+ bool find_unmatched_output = false;
+
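+  // Compare each onert output buffer against the corresponding TFLite output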
+ for (uint32_t out_idx = 0; out_idx < num_outputs; out_idx++)
+ {
+    nnfw_tensorinfo ti;
+    NNFW_ASSERT_FAIL(nnfw_output_tensorinfo(onert_session, out_idx, &ti),
+                     "[ ERROR ] Failure during get output tensor info");
+
+ bool matched = true;
+ // Check output tensor values
+ auto output_tensor = TfLiteInterpreterGetOutputTensor(interpreter, out_idx);
+ auto ref_output = TfLiteTensorData(output_tensor);
+ const auto &output = outputs[out_idx];
+
+ switch (ti.dtype)
+ {
+ case NNFW_TYPE_TENSOR_BOOL:
+ matched = exact(reinterpret_cast<uint8_t *>(ref_output), output, out_idx);
+ break;
+ case NNFW_TYPE_TENSOR_UINT8:
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ matched = isClose<uint8_t>(reinterpret_cast<uint8_t *>(ref_output), output, out_idx);
+ break;
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
+ matched = isClose<int8_t>(reinterpret_cast<int8_t *>(ref_output), output, out_idx);
+ break;
+ case NNFW_TYPE_TENSOR_INT32:
+ matched = isClose<int32_t>(reinterpret_cast<int32_t *>(ref_output), output, out_idx);
+ break;
+ case NNFW_TYPE_TENSOR_FLOAT32:
+ matched = isClose<float>(reinterpret_cast<float *>(ref_output), output, out_idx);
+ break;
+ case NNFW_TYPE_TENSOR_INT64:
+ matched = isClose<int64_t>(reinterpret_cast<int64_t *>(ref_output), output, out_idx);
+ break;
+ default:
+ throw std::runtime_error{"Invalid tensor type"};
+ }
+
+ if (!matched)
+ find_unmatched_output = true;
+ }
+
+ // Print results
+ int ret = 0;
+ if (find_unmatched_output)
+ {
+ std::cout << "[Comparison] outputs is not equal!" << std::endl;
+ ret = 1;
+ }
+ else
+ {
+ std::cout << "[Comparison] Outputs is equal!" << std::endl;
+ }
+ std::cout << "[Comparison] Done!" << std::endl;
+
+ nnfw_close_session(onert_session);
+
+ return ret;
+}