diff options
Diffstat (limited to 'tools/tflite_benchmark_model/benchmark_model.cc')
-rw-r--r-- | tools/tflite_benchmark_model/benchmark_model.cc | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/tools/tflite_benchmark_model/benchmark_model.cc b/tools/tflite_benchmark_model/benchmark_model.cc new file mode 100644 index 000000000..7869180bf --- /dev/null +++ b/tools/tflite_benchmark_model/benchmark_model.cc @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "benchmark_model.h" + +#include <time.h> + +#include <iostream> +#include <sstream> + +#include "tensorflow/contrib/lite/profiling/time.h" +#include "logging.h" + +namespace { +void SleepForSeconds(double sleep_seconds) { + if (sleep_seconds <= 0.0) { + return; + } + // Convert the run_delay string into a timespec. + timespec req; + req.tv_sec = static_cast<time_t>(sleep_seconds); + req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; + // If requested, sleep between runs for an arbitrary amount of time. + // This can be helpful to determine the effect of mobile processor + // scaling and thermal throttling. +#ifdef PLATFORM_WINDOWS + Sleep(sleep_seconds * 1000); +#else + nanosleep(&req, nullptr); +#endif +} + +} // namespace + +namespace nnfw { +namespace benchmark { +using tensorflow::Stat; + +BenchmarkParams BenchmarkModel::DefaultParams() { + BenchmarkParams params; + params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(50)); + params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f)); + params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1)); + params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>("")); + params.AddParam("output_prefix", BenchmarkParam::Create<std::string>("")); + params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1)); + return params; +} + +BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {} + +void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { + auto inference_us = results.inference_time_us(); + auto init_us = results.startup_latency_us(); + auto warmup_us = results.warmup_time_us(); + TFLITE_LOG(INFO) << "Average inference timings in us: " + << "Warmup: " << warmup_us.avg() << ", " + << "Init: " << init_us << ", " + << "no stats: " << inference_us.avg(); +} + +std::vector<Flag> BenchmarkModel::GetFlags() { + return { + CreateFlag<int32_t>("num_runs", ¶ms_, "number of runs"), + CreateFlag<float>("run_delay", ¶ms_, "delay between runs in seconds"), + CreateFlag<int32_t>("num_threads", ¶ms_, "number of threads"), + CreateFlag<std::string>("benchmark_name", ¶ms_, "benchmark name"), + CreateFlag<std::string>("output_prefix", ¶ms_, + "benchmark output prefix"), + CreateFlag<int32_t>("warmup_runs", ¶ms_, + "how many runs to initialize model"), + }; +} + +void BenchmarkModel::LogFlags() { + TFLITE_LOG(INFO) << "Num runs: [" << params_.Get<int32_t>("num_runs") << "]"; + TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" + << params_.Get<float>("run_delay") << "]"; + TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads") + << "]"; + TFLITE_LOG(INFO) << "Benchmark name: [" + << params_.Get<std::string>("benchmark_name") << "]"; + TFLITE_LOG(INFO) << "Output prefix: [" + << params_.Get<std::string>("output_prefix") << "]"; + TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get<int32_t>("warmup_runs") + << "]"; +} + +Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) { + Stat<int64_t> run_stats; + TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations "; + for (int run = 0; run < num_times; run++) { + listeners_.OnSingleRunStart(run_type); + int64_t start_us = tflite::profiling::time::NowMicros(); + RunImpl(); + int64_t end_us = tflite::profiling::time::NowMicros(); + listeners_.OnSingleRunEnd(); + + run_stats.UpdateStat(end_us - start_us); + SleepForSeconds(params_.Get<float>("run_delay")); + } + + std::stringstream stream; + run_stats.OutputToStream(&stream); + TFLITE_LOG(INFO) << stream.str() << std::endl; + + return run_stats; +} + +void BenchmarkModel::Run(int argc, char **argv) { + if (!ParseFlags(argc, argv)) { + return; + } + + LogFlags(); + + listeners_.OnBenchmarkStart(params_); + int64_t initialization_start_us = tflite::profiling::time::NowMicros(); + Init(); + int64_t initialization_end_us = tflite::profiling::time::NowMicros(); + int64_t startup_latency_us = initialization_end_us - initialization_start_us; + TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3 + << "ms"; + + uint64_t input_bytes = ComputeInputBytes(); + Stat<int64_t> warmup_time_us = + Run(params_.Get<int32_t>("warmup_runs"), WARMUP); + Stat<int64_t> inference_time_us = + Run(params_.Get<int32_t>("num_runs"), REGULAR); + listeners_.OnBenchmarkEnd( + {startup_latency_us, input_bytes, warmup_time_us, inference_time_us}); +} + +bool BenchmarkModel::ParseFlags(int argc, char **argv) { + auto flag_list = GetFlags(); + const bool parse_result = + Flags::Parse(&argc, const_cast<const char **>(argv), flag_list); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0], flag_list); + TFLITE_LOG(ERROR) << usage; + return false; + } + return ValidateFlags(); +} + +} // namespace benchmark +} // namespace nnfw |