summaryrefslogtreecommitdiff
path: root/tests/tools/tflite_run/src/args.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tests/tools/tflite_run/src/args.cc')
-rw-r--r--tests/tools/tflite_run/src/args.cc31
1 files changed, 27 insertions, 4 deletions
diff --git a/tests/tools/tflite_run/src/args.cc b/tests/tools/tflite_run/src/args.cc
index 713a0a9d2..6c85d884e 100644
--- a/tests/tools/tflite_run/src/args.cc
+++ b/tests/tools/tflite_run/src/args.cc
@@ -23,7 +23,7 @@
namespace TFLiteRun
{
-Args::Args(const int argc, char **argv)
+Args::Args(const int argc, char **argv) noexcept
{
Initialize();
Parse(argc, argv);
@@ -38,7 +38,9 @@ void Args::Initialize(void)
// clang-format off
general.add_options()
("help,h", "Display available options")
+ ("input,i", po::value<std::string>()->default_value(""), "Input filename")
("dump,d", po::value<std::string>()->default_value(""), "Output filename")
+ ("ishapes", po::value<std::vector<int>>()->multitoken(), "Input shapes")
("compare,c", po::value<std::string>()->default_value(""), "filename to be compared with")
("tflite", po::value<std::string>()->required());
// clang-format on
@@ -52,9 +54,7 @@ void Args::Parse(const int argc, char **argv)
po::variables_map vm;
po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
vm);
- po::notify(vm);
-#if 0 // Enable this when we have mutually conflicting options
{
auto conflicting_options = [&](const std::string &o1, const std::string &o2) {
if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted()))
@@ -66,7 +66,6 @@ void Args::Parse(const int argc, char **argv)
conflicting_options("input", "compare");
}
-#endif
if (vm.count("help"))
{
@@ -78,6 +77,8 @@ void Args::Parse(const int argc, char **argv)
exit(0);
}
+ po::notify(vm);
+
if (vm.count("dump"))
{
_dump_filename = vm["dump"].as<std::string>();
@@ -88,6 +89,28 @@ void Args::Parse(const int argc, char **argv)
_compare_filename = vm["compare"].as<std::string>();
}
+ if (vm.count("input"))
+ {
+ _input_filename = vm["input"].as<std::string>();
+
+ if (!_input_filename.empty())
+ {
+ if (!boost::filesystem::exists(_input_filename))
+ {
+ std::cerr << "input image file not found: " << _input_filename << "\n";
+ }
+ }
+ }
+
+ if (vm.count("ishapes"))
+ {
+ _input_shapes.resize(vm["ishapes"].as<std::vector<int>>().size());
+ for (auto i = 0; i < _input_shapes.size(); i++)
+ {
+ _input_shapes[i] = vm["ishapes"].as<std::vector<int>>()[i];
+ }
+ }
+
if (vm.count("tflite"))
{
_tflite_filename = vm["tflite"].as<std::string>();