diff options
Diffstat (limited to 'tests/tools/tflite_run/src/tflite_run.cc')
-rw-r--r-- | tests/tools/tflite_run/src/tflite_run.cc | 67 |
1 files changed, 53 insertions, 14 deletions
diff --git a/tests/tools/tflite_run/src/tflite_run.cc b/tests/tools/tflite_run/src/tflite_run.cc index 5be6909e5..deed12856 100644 --- a/tests/tools/tflite_run/src/tflite_run.cc +++ b/tests/tools/tflite_run/src/tflite_run.cc @@ -15,14 +15,13 @@ */ #include "tflite/ext/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/model.h" -#include "bin_image.h" #include "args.h" #include "tensor_dumper.h" #include "tensor_loader.h" #include "misc/benchmark.h" -#include "misc/environment.h" +#include "misc/EnvVar.h" #include "misc/fp32.h" #include "tflite/Diff.h" #include "tflite/Assert.h" @@ -65,15 +64,23 @@ int main(const int argc, char **argv) std::chrono::milliseconds t_prepare(0); std::chrono::milliseconds t_invoke(0); - nnfw::misc::benchmark::measure(t_prepare) << [&](void) { - BuiltinOpResolver resolver; + try + { + nnfw::misc::benchmark::measure(t_prepare) << [&](void) { + BuiltinOpResolver resolver; - InterpreterBuilder builder(*model, resolver); + InterpreterBuilder builder(*model, resolver); - TFLITE_ENSURE(builder(&interpreter)) + TFLITE_ENSURE(builder(&interpreter)) - interpreter->SetNumThreads(1); - }; + interpreter->SetNumThreads(nnfw::misc::EnvVar("THREAD").asInt(-1)); + }; + } + catch (const std::exception &e) + { + std::cerr << e.what() << '\n'; + return 1; + } std::shared_ptr<nnfw::tflite::Session> sess; @@ -88,12 +95,45 @@ int main(const int argc, char **argv) sess->prepare(); + if (args.getInputShapes().size() != 0) + { + const int dim_values = args.getInputShapes().size(); + int offset = 0; + + for (const auto &id : interpreter->inputs()) + { + TfLiteTensor *tensor = interpreter->tensor(id); + std::vector<int32_t> new_dim; + new_dim.resize(tensor->dims->size); + + for (uint32_t axis = 0; axis < tensor->dims->size; axis++, offset++) + { + new_dim[axis] = + ((offset < dim_values) ? args.getInputShapes()[offset] : tensor->dims->data[axis]); + } + + interpreter->ResizeInputTensor(id, new_dim); + + if (offset >= dim_values) + break; + } + interpreter->AllocateTensors(); + } + TFLiteRun::TensorLoader tensor_loader(*interpreter); - // Load input from dumped tensor file. - if (!args.getCompareFilename().empty()) + // Load input from raw or dumped tensor file. + // Two options are exclusive and will be checked from Args. + if (!args.getInputFilename().empty() || !args.getCompareFilename().empty()) { - tensor_loader.load(args.getCompareFilename()); + if (!args.getInputFilename().empty()) + { + tensor_loader.loadRawTensors(args.getInputFilename(), interpreter->inputs()); + } + else + { + tensor_loader.loadDumpedTensors(args.getCompareFilename()); + } for (const auto &o : interpreter->inputs()) { @@ -226,8 +266,7 @@ int main(const int argc, char **argv) // TODO Code duplication (copied from RandomTestRunner) - int tolerance = 1; - nnfw::misc::env::IntAccessor("TOLERANCE").access(tolerance); + int tolerance = nnfw::misc::EnvVar("TOLERANCE").asInt(1); auto equals = [tolerance](float lhs, float rhs) { // NOTE Hybrid approach |