summaryrefslogtreecommitdiff
path: root/runtime/contrib/mlapse/tfl/driver.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/contrib/mlapse/tfl/driver.cc')
-rw-r--r--runtime/contrib/mlapse/tfl/driver.cc280
1 files changed, 280 insertions, 0 deletions
diff --git a/runtime/contrib/mlapse/tfl/driver.cc b/runtime/contrib/mlapse/tfl/driver.cc
new file mode 100644
index 000000000..867a6051a
--- /dev/null
+++ b/runtime/contrib/mlapse/tfl/driver.cc
@@ -0,0 +1,280 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mlapse/benchmark_runner.h"
+#include "mlapse/multicast_observer.h"
+#include "mlapse/CSV_report_generator.h"
+
+#include "mlapse/tfl/load.h"
+
+// From 'nnfw_lib_tflite'
+#include <tflite/InterpreterSession.h>
+#include <tflite/NNAPISession.h>
+
+// From 'nnfw_lib_cpp14'
+#include <cpp14/memory.h>
+
+// From C++ Standard Library
+#include <cassert>
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+namespace
+{
+
+using namespace mlapse;
+
+class ConsoleReporter final : public mlapse::BenchmarkObserver
+{
+public:
+ ConsoleReporter() = default;
+
+public:
+ void notify(const NotificationArg<PhaseBegin> &arg) final
+ {
+ _phase = arg.phase;
+ _count = arg.count;
+
+ std::cout << tag() << " BEGIN" << std::endl;
+ }
+
+ void notify(const NotificationArg<PhaseEnd> &arg) final
+ {
+ std::cout << tag() << " END" << std::endl;
+
+ _phase = mlapse::uninitialized_phase();
+ _count = 0;
+ }
+
+ void notify(const NotificationArg<IterationBegin> &arg) final { _index = arg.index; }
+
+ void notify(const NotificationArg<IterationEnd> &arg) final
+ {
+ std::cout << tag() << " " << progress() << " - " << arg.latency.count() << "ms" << std::endl;
+ }
+
+private:
+ std::string progress(void) const
+ {
+ return "[" + std::to_string(_index + 1) + "/" + std::to_string(_count) + "]";
+ }
+
+ std::string tag(void) const
+ {
+ switch (_phase)
+ {
+ case Phase::Warmup:
+ return "WARMUP";
+ case Phase::Record:
+ return "RECORD";
+ default:
+ break;
+ }
+
+ return "unknown";
+ }
+
+ Phase _phase = mlapse::uninitialized_phase();
+ uint32_t _count = 0;
+ uint32_t _index = 0;
+};
+
+} // namespace
+
+// Q. Is is worth to make a library for these routines?
+namespace
+{
+
+enum class SessionType
+{
+ Interp,
+ NNAPI,
+};
+
+class SessionBuilder
+{
+public:
+ SessionBuilder(const SessionType &type) : _type{type}
+ {
+ // DO NOTHING
+ }
+
+public:
+ std::unique_ptr<nnfw::tflite::Session> with(tflite::Interpreter *interp) const
+ {
+ switch (_type)
+ {
+ case SessionType::Interp:
+ return nnfw::cpp14::make_unique<nnfw::tflite::InterpreterSession>(interp);
+ case SessionType::NNAPI:
+ return nnfw::cpp14::make_unique<nnfw::tflite::NNAPISession>(interp);
+ default:
+ break;
+ }
+
+ return nullptr;
+ }
+
+ std::unique_ptr<nnfw::tflite::Session>
+ with(const std::unique_ptr<tflite::Interpreter> &interp) const
+ {
+ return with(interp.get());
+ }
+
+private:
+ SessionType _type;
+};
+
+SessionBuilder make_session(const SessionType &type) { return SessionBuilder{type}; }
+
+} // namespace
+
+namespace
+{
+
+// mlapse-tfl
+// [REQUIRED] --model <path/to/tflite>
+// [OPTIONAL] --warmup-count N (default = 3)
+// [OPTIONAL] --record-count N (default = 10)
+// [OPTIONAL] --thread N or auto (default = auto)
+// [OPTIOANL] --nnapi (default = off)
+// [OPTIONAL] --pause N (default = 0)
+// [OPTIONAL] --csv-report <path/to/csv>
+int entry(const int argc, char **argv)
+{
+ // Create an observer
+ mlapse::MulticastObserver observer;
+
+ observer.append(nnfw::cpp14::make_unique<ConsoleReporter>());
+
+ // Set default parameters
+ std::string model_path;
+ bool model_path_initialized = false;
+
+ SessionType session_type = SessionType::Interp;
+ uint32_t warmup_count = 3;
+ uint32_t record_count = 10;
+ int num_thread = -1; // -1 means "auto"
+
+ // Read command-line arguments
+ std::map<std::string, std::function<uint32_t(const char *const *)>> opts;
+
+ opts["--model"] = [&model_path, &model_path_initialized](const char *const *tok) {
+ model_path = std::string{tok[0]};
+ model_path_initialized = true;
+ return 1; // # of arguments
+ };
+
+ opts["--record-count"] = [&record_count](const char *const *tok) {
+ record_count = std::stoi(tok[0]);
+ return 1; // # of arguments
+ };
+
+ opts["--thread"] = [](const char *const *tok) {
+ assert(std::string{tok[0]} == "auto");
+ return 1;
+ };
+
+ opts["--nnapi"] = [&session_type](const char *const *) {
+ session_type = SessionType::NNAPI;
+ return 0;
+ };
+
+ opts["--csv-report"] = [&observer](const char *const *tok) {
+ observer.append(nnfw::cpp14::make_unique<mlapse::CSVReportGenerator>(tok[0]));
+ return 1;
+ };
+
+ {
+ uint32_t offset = 1;
+
+ while (offset < argc)
+ {
+ auto opt = argv[offset];
+
+ auto it = opts.find(opt);
+
+ if (it == opts.end())
+ {
+ std::cout << "INVALID OPTION: " << opt << std::endl;
+ return 255;
+ }
+
+ auto func = it->second;
+
+ auto num_skip = func(argv + offset + 1);
+
+ offset += 1;
+ offset += num_skip;
+ }
+ }
+
+ // Check arguments
+ if (!model_path_initialized)
+ {
+ std::cerr << "ERROR: --model is missing" << std::endl;
+ return 255;
+ }
+
+ // Load T/F Lite model
+ auto model = mlapse::tfl::load_model(model_path);
+
+ if (model == nullptr)
+ {
+ std::cerr << "ERROR: Failed to load '" << model_path << "'" << std::endl;
+ return 255;
+ }
+
+ auto interp = mlapse::tfl::make_interpreter(model.get());
+
+ if (interp == nullptr)
+ {
+ std::cerr << "ERROR: Failed to create a T/F Lite interpreter" << std::endl;
+ return 255;
+ }
+
+ auto sess = make_session(session_type).with(interp);
+
+ if (sess == nullptr)
+ {
+ std::cerr << "ERROR: Failed to create a session" << std::endl;
+ }
+
+ // Run benchmark
+ mlapse::BenchmarkRunner benchmark_runner{warmup_count, record_count};
+
+ benchmark_runner.attach(&observer);
+ benchmark_runner.run(sess);
+
+ return 0;
+}
+
+} // namespace
+
+int main(int argc, char **argv)
+{
+ try
+ {
+ return entry(argc, argv);
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ }
+
+ return 255;
+}