summaryrefslogtreecommitdiff
path: root/tests/tools/tflite_loader/src/args.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tests/tools/tflite_loader/src/args.cc')
-rw-r--r--tests/tools/tflite_loader/src/args.cc83
1 files changed, 83 insertions, 0 deletions
diff --git a/tests/tools/tflite_loader/src/args.cc b/tests/tools/tflite_loader/src/args.cc
new file mode 100644
index 000000000..3fe1d0bf4
--- /dev/null
+++ b/tests/tools/tflite_loader/src/args.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "args.h"
+
+#include <iostream>
+
+#include <boost/filesystem.hpp>
+
+namespace TFLiteRun
+{
+
+Args::Args(const int argc, char **argv) noexcept
+{
+ Initialize();
+ Parse(argc, argv);
+}
+
+void Args::Initialize(void)
+{
+ // General options
+ po::options_description general("General options");
+
+ // clang-format off
+ general.add_options()
+ ("help,h", "Display available options")
+ ("tflite", po::value<std::string>()->default_value("")->required(), "Input tflite model file for serialization")
+ ("data,d", po::value<std::vector<std::string>>()->multitoken()->default_value(std::vector<std::string>{}, ""), "Input data file for model");
+ // clang-format on
+
+ _options.add(general);
+ _positional.add("tflite", 1);
+}
+
+void Args::print(char **argv)
+{
+ std::cout << "tflite_loader" << std::endl << std::endl;
+ std::cout << "Load tflite model by Loader and TFLite and compare their output" << std::endl;
+ std::cout << "Usage:" << std::endl;
+ std::cout << argv[0] << " --tflite model_file.tflite --data input_data.dat" << std::endl;
+ std::cout << _options;
+ std::cout << std::endl;
+}
+
+void Args::Parse(const int argc, char **argv)
+{
+ po::variables_map vm;
+ po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
+ vm);
+ po::notify(vm);
+
+ if (vm.count("help"))
+ {
+ print(argv);
+
+ exit(0);
+ }
+
+ if (vm.count("tflite"))
+ {
+ _tflite_filename = vm["tflite"].as<std::string>();
+ }
+
+ if (vm.count("data"))
+ {
+ _data_filenames = vm["data"].as<std::vector<std::string>>();
+ }
+}
+
+} // end of namespace TFLiteRun