summaryrefslogtreecommitdiff
path: root/tests/tools/tflite_run/src/tflite_run.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tests/tools/tflite_run/src/tflite_run.cc')
-rw-r--r--tests/tools/tflite_run/src/tflite_run.cc67
1 files changed, 53 insertions, 14 deletions
diff --git a/tests/tools/tflite_run/src/tflite_run.cc b/tests/tools/tflite_run/src/tflite_run.cc
index 5be6909e5..deed12856 100644
--- a/tests/tools/tflite_run/src/tflite_run.cc
+++ b/tests/tools/tflite_run/src/tflite_run.cc
@@ -15,14 +15,13 @@
*/
#include "tflite/ext/kernels/register.h"
-#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/lite/model.h"
-#include "bin_image.h"
#include "args.h"
#include "tensor_dumper.h"
#include "tensor_loader.h"
#include "misc/benchmark.h"
-#include "misc/environment.h"
+#include "misc/EnvVar.h"
#include "misc/fp32.h"
#include "tflite/Diff.h"
#include "tflite/Assert.h"
@@ -65,15 +64,23 @@ int main(const int argc, char **argv)
std::chrono::milliseconds t_prepare(0);
std::chrono::milliseconds t_invoke(0);
- nnfw::misc::benchmark::measure(t_prepare) << [&](void) {
- BuiltinOpResolver resolver;
+ try
+ {
+ nnfw::misc::benchmark::measure(t_prepare) << [&](void) {
+ BuiltinOpResolver resolver;
- InterpreterBuilder builder(*model, resolver);
+ InterpreterBuilder builder(*model, resolver);
- TFLITE_ENSURE(builder(&interpreter))
+ TFLITE_ENSURE(builder(&interpreter))
- interpreter->SetNumThreads(1);
- };
+ interpreter->SetNumThreads(nnfw::misc::EnvVar("THREAD").asInt(-1));
+ };
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << '\n';
+ return 1;
+ }
std::shared_ptr<nnfw::tflite::Session> sess;
@@ -88,12 +95,45 @@ int main(const int argc, char **argv)
sess->prepare();
+ if (args.getInputShapes().size() != 0)
+ {
+ const int dim_values = args.getInputShapes().size();
+ int offset = 0;
+
+ for (const auto &id : interpreter->inputs())
+ {
+ TfLiteTensor *tensor = interpreter->tensor(id);
+ std::vector<int32_t> new_dim;
+ new_dim.resize(tensor->dims->size);
+
+ for (uint32_t axis = 0; axis < tensor->dims->size; axis++, offset++)
+ {
+ new_dim[axis] =
+ ((offset < dim_values) ? args.getInputShapes()[offset] : tensor->dims->data[axis]);
+ }
+
+ interpreter->ResizeInputTensor(id, new_dim);
+
+ if (offset >= dim_values)
+ break;
+ }
+ interpreter->AllocateTensors();
+ }
+
TFLiteRun::TensorLoader tensor_loader(*interpreter);
- // Load input from dumped tensor file.
- if (!args.getCompareFilename().empty())
+ // Load input from raw or dumped tensor file.
+ // Two options are exclusive and will be checked from Args.
+ if (!args.getInputFilename().empty() || !args.getCompareFilename().empty())
{
- tensor_loader.load(args.getCompareFilename());
+ if (!args.getInputFilename().empty())
+ {
+ tensor_loader.loadRawTensors(args.getInputFilename(), interpreter->inputs());
+ }
+ else
+ {
+ tensor_loader.loadDumpedTensors(args.getCompareFilename());
+ }
for (const auto &o : interpreter->inputs())
{
@@ -226,8 +266,7 @@ int main(const int argc, char **argv)
// TODO Code duplication (copied from RandomTestRunner)
- int tolerance = 1;
- nnfw::misc::env::IntAccessor("TOLERANCE").access(tolerance);
+ int tolerance = nnfw::misc::EnvVar("TOLERANCE").asInt(1);
auto equals = [tolerance](float lhs, float rhs) {
// NOTE Hybrid approach