author    Sanggyu Lee / Motion Control Lab (SR) / Principal Engineer / Samsung Electronics <sg5.lee@samsung.com>  2018-11-26 14:56:12 +0900
committer Hyeongseok Oh / Motion Control Lab (SR) / Staff Engineer / Samsung Electronics <hseok82.oh@samsung.com>  2018-11-26 14:56:12 +0900
commit    9fe5ce4d9d386e50e529e52b09bb4951d887a7b1 (patch)
tree      394716a4b8aa3eeb16a8863f1b6ba408377c1760 /tools
parent    af76fbca19f76449164406f7ff2e69dc8e6a5c88 (diff)
tflite_benchmark_model is updated to v1.12.0. (#3660)
Most files are unchanged from v1.12.0. My modification adds support for operators that expand into multiple kernels; you can find the changes in stats_calculator.cc and profile_summarizer.cc. Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
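For illustration, a minimal standalone sketch of the numbering scheme this change introduces in profile_summarizer.cc: when one operator expands into several backend kernels, consecutive profile events carry the same operator index, and each event is recorded under the operator's output name with a per-operator sequence suffix. `Event`, `op_index`, and the tensor names below are illustrative stand-ins, not the actual nnfw/TFLite types.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for a profile event; the real code uses
// tflite::profiling::ProfileEvent and reads event->event_metadata.
struct Event {
  int op_index;         // which TFLite node produced this kernel event
  std::string outputs;  // ToString(op_details.outputs) in the real code
};

int main() {
  // Two kernels expanded from op 0, then one kernel from op 1.
  std::vector<Event> events = {{0, "conv_out"}, {0, "conv_out"}, {1, "pool_out"}};

  int prev_op_idx = -1;
  int child_op_no = 1;
  for (const auto& e : events) {
    // Consecutive events from the same op get #1, #2, ... suffixes,
    // mirroring the child_op_no logic in profile_summarizer.cc.
    child_op_no = (prev_op_idx == e.op_index) ? child_op_no + 1 : 1;
    std::cout << e.outputs << "#" << child_op_no << "\n";
    prev_op_idx = e.op_index;
  }
  // Prints: conv_out#1, conv_out#2, pool_out#1
}
```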
Diffstat (limited to 'tools')
-rw-r--r--  tools/tflite_benchmark_model/CMakeLists.txt             |    7
-rw-r--r--  tools/tflite_benchmark_model/README.md                  |   24
-rw-r--r--  tools/tflite_benchmark_model/benchmark_main.cc          |   53
-rw-r--r--  tools/tflite_benchmark_model/benchmark_model.cc         |  175
-rw-r--r--  tools/tflite_benchmark_model/benchmark_model.h          |  177
-rw-r--r--  tools/tflite_benchmark_model/benchmark_params.cc        |   73
-rw-r--r--  tools/tflite_benchmark_model/benchmark_params.h         |  118
-rw-r--r--  tools/tflite_benchmark_model/benchmark_tflite_model.cc  |  105
-rw-r--r--  tools/tflite_benchmark_model/benchmark_tflite_model.h   |   95
-rw-r--r--  tools/tflite_benchmark_model/command_line_flags.cc      |  214
-rw-r--r--  tools/tflite_benchmark_model/command_line_flags.h       |  141
-rw-r--r--  tools/tflite_benchmark_model/logging.h                  |   92
-rw-r--r--  tools/tflite_benchmark_model/profile_summarizer.cc      |   45
-rw-r--r--  tools/tflite_benchmark_model/profile_summarizer.h       |   55
-rw-r--r--  tools/tflite_benchmark_model/stats_calculator.cc        |  317
15 files changed, 416 insertions, 1275 deletions
diff --git a/tools/tflite_benchmark_model/CMakeLists.txt b/tools/tflite_benchmark_model/CMakeLists.txt
index d52690460..dd54dc5b5 100644
--- a/tools/tflite_benchmark_model/CMakeLists.txt
+++ b/tools/tflite_benchmark_model/CMakeLists.txt
@@ -1,5 +1,12 @@
file(GLOB_RECURSE SOURCES "*.cc")
+nnfw_find_package(TensorFlowSource REQUIRED)
+set(TENSORFLOW_LITE_BASE "${TensorFlowSource_DIR}/tensorflow/contrib/lite")
+list(APPEND SOURCES "${TENSORFLOW_LITE_BASE}/tools/benchmark/benchmark_main.cc"
+ "${TENSORFLOW_LITE_BASE}/tools/benchmark/benchmark_model.cc"
+ "${TENSORFLOW_LITE_BASE}/tools/benchmark/benchmark_params.cc"
+ "${TENSORFLOW_LITE_BASE}/tools/benchmark/command_line_flags.cc")
+
add_executable(tflite_benchmark_model ${SOURCES})
target_compile_definitions(tflite_benchmark_model PUBLIC "TFLITE_PROFILING_ENABLED")
target_link_libraries(tflite_benchmark_model tensorflow-lite ${LIB_PTHREAD} dl nnfw_util nnfw_support_tflite)
diff --git a/tools/tflite_benchmark_model/README.md b/tools/tflite_benchmark_model/README.md
index 93769305b..8d997639f 100644
--- a/tools/tflite_benchmark_model/README.md
+++ b/tools/tflite_benchmark_model/README.md
@@ -9,7 +9,7 @@ of runs. Aggregrate latency statistics are reported after running the benchmark.
The instructions below are for running the binary on Desktop and Android,
for iOS please use the
-[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
## Parameters
@@ -17,11 +17,6 @@ The binary takes the following required parameters:
* `graph`: `string` \
The path to the TFLite model file.
-* `input_layer`: `string` \
- The name of the input layer, this is typically the first layer of the model.
-* `input_layer_shape`: `string` \
- The shape of the input layer. This is a comma separated string of the shape
- of tensor of input layer.
and the following optional parameters:
@@ -29,11 +24,13 @@ and the following optional parameters:
The number of threads to use for running TFLite interpreter.
* `warmup_runs`: `int` (default=1) \
The number of warmup runs to do before starting the benchmark.
+* `num_runs`: `int` (default=50) \
+ The number of runs. Increase this to reduce variance.
* `run_delay`: `float` (default=-1.0) \
The delay in seconds between subsequent benchmark runs. Non-positive values
mean use no delay.
* `use_nnapi`: `bool` (default=false) \
- Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/).
+ Whether to use [Android NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/).
This API is available on recent Android devices.
## To build/install/run
@@ -75,8 +72,6 @@ adb push mobilenet_quant_v1_224.tflite /data/local/tmp
```
adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
- --input_layer="input" \
- --input_layer_shape="1,224,224,3" \
--num_threads=4
```
@@ -93,13 +88,10 @@ For example:
```
bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \
--graph=mobilenet_quant_v1_224.tflite \
- --input_layer="Placeholder" \
- --input_layer_shape="1,224,224,3" \
--num_threads=4
```
-The MobileNet graph used as an example here may be downloaded from
-https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
+The MobileNet graph used as an example here may be downloaded from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip).
## Reducing variance between runs on Android.
@@ -115,10 +107,8 @@ E.g. for running the benchmark on big cores on Pixel 2 with a single thread one
can use the following command:
```
-adb shell tasket f0 /data/local/tmp/benchmark_model \
+adb shell taskset f0 /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
- --input_layer="input" \
- --input_layer_shape="1,224,224,3" \
--num_threads=1
```
@@ -205,5 +195,3 @@ Memory (bytes): count=0
Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9
```
-
-
diff --git a/tools/tflite_benchmark_model/benchmark_main.cc b/tools/tflite_benchmark_model/benchmark_main.cc
deleted file mode 100644
index 7e4231c48..000000000
--- a/tools/tflite_benchmark_model/benchmark_main.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "benchmark_tflite_model.h"
-#include "logging.h"
-
-namespace nnfw {
-namespace benchmark {
-
-int Main(int argc, char** argv) {
-#ifdef TFLITE_CUSTOM_OPS_HEADER
- TFLITE_LOG(INFO) << "STARTING with custom ops!";
-#else
- TFLITE_LOG(INFO) << "STARTING!";
-#endif
- BenchmarkTfLiteModel benchmark;
- BenchmarkLoggingListener listener;
- benchmark.AddListener(&listener);
- benchmark.Run(argc, argv);
- return 0;
-}
-} // namespace benchmark
-} // namespace nnfw
-
-int main(int argc, char** argv) { return nnfw::benchmark::Main(argc, argv); }
diff --git a/tools/tflite_benchmark_model/benchmark_model.cc b/tools/tflite_benchmark_model/benchmark_model.cc
deleted file mode 100644
index 7869180bf..000000000
--- a/tools/tflite_benchmark_model/benchmark_model.cc
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "benchmark_model.h"
-
-#include <time.h>
-
-#include <iostream>
-#include <sstream>
-
-#include "tensorflow/contrib/lite/profiling/time.h"
-#include "logging.h"
-
-namespace {
-void SleepForSeconds(double sleep_seconds) {
- if (sleep_seconds <= 0.0) {
- return;
- }
- // Convert the run_delay string into a timespec.
- timespec req;
- req.tv_sec = static_cast<time_t>(sleep_seconds);
- req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000;
- // If requested, sleep between runs for an arbitrary amount of time.
- // This can be helpful to determine the effect of mobile processor
- // scaling and thermal throttling.
-#ifdef PLATFORM_WINDOWS
- Sleep(sleep_seconds * 1000);
-#else
- nanosleep(&req, nullptr);
-#endif
-}
-
-} // namespace
-
-namespace nnfw {
-namespace benchmark {
-using tensorflow::Stat;
-
-BenchmarkParams BenchmarkModel::DefaultParams() {
- BenchmarkParams params;
- params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(50));
- params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
- params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
- params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
- params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
- params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
- return params;
-}
-
-BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {}
-
-void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
- auto inference_us = results.inference_time_us();
- auto init_us = results.startup_latency_us();
- auto warmup_us = results.warmup_time_us();
- TFLITE_LOG(INFO) << "Average inference timings in us: "
- << "Warmup: " << warmup_us.avg() << ", "
- << "Init: " << init_us << ", "
- << "no stats: " << inference_us.avg();
-}
-
-std::vector<Flag> BenchmarkModel::GetFlags() {
- return {
- CreateFlag<int32_t>("num_runs", &params_, "number of runs"),
- CreateFlag<float>("run_delay", &params_, "delay between runs in seconds"),
- CreateFlag<int32_t>("num_threads", &params_, "number of threads"),
- CreateFlag<std::string>("benchmark_name", &params_, "benchmark name"),
- CreateFlag<std::string>("output_prefix", &params_,
- "benchmark output prefix"),
- CreateFlag<int32_t>("warmup_runs", &params_,
- "how many runs to initialize model"),
- };
-}
-
-void BenchmarkModel::LogFlags() {
- TFLITE_LOG(INFO) << "Num runs: [" << params_.Get<int32_t>("num_runs") << "]";
- TFLITE_LOG(INFO) << "Inter-run delay (seconds): ["
- << params_.Get<float>("run_delay") << "]";
- TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
- << "]";
- TFLITE_LOG(INFO) << "Benchmark name: ["
- << params_.Get<std::string>("benchmark_name") << "]";
- TFLITE_LOG(INFO) << "Output prefix: ["
- << params_.Get<std::string>("output_prefix") << "]";
- TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get<int32_t>("warmup_runs")
- << "]";
-}
-
-Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
- Stat<int64_t> run_stats;
- TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations ";
- for (int run = 0; run < num_times; run++) {
- listeners_.OnSingleRunStart(run_type);
- int64_t start_us = tflite::profiling::time::NowMicros();
- RunImpl();
- int64_t end_us = tflite::profiling::time::NowMicros();
- listeners_.OnSingleRunEnd();
-
- run_stats.UpdateStat(end_us - start_us);
- SleepForSeconds(params_.Get<float>("run_delay"));
- }
-
- std::stringstream stream;
- run_stats.OutputToStream(&stream);
- TFLITE_LOG(INFO) << stream.str() << std::endl;
-
- return run_stats;
-}
-
-void BenchmarkModel::Run(int argc, char **argv) {
- if (!ParseFlags(argc, argv)) {
- return;
- }
-
- LogFlags();
-
- listeners_.OnBenchmarkStart(params_);
- int64_t initialization_start_us = tflite::profiling::time::NowMicros();
- Init();
- int64_t initialization_end_us = tflite::profiling::time::NowMicros();
- int64_t startup_latency_us = initialization_end_us - initialization_start_us;
- TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3
- << "ms";
-
- uint64_t input_bytes = ComputeInputBytes();
- Stat<int64_t> warmup_time_us =
- Run(params_.Get<int32_t>("warmup_runs"), WARMUP);
- Stat<int64_t> inference_time_us =
- Run(params_.Get<int32_t>("num_runs"), REGULAR);
- listeners_.OnBenchmarkEnd(
- {startup_latency_us, input_bytes, warmup_time_us, inference_time_us});
-}
-
-bool BenchmarkModel::ParseFlags(int argc, char **argv) {
- auto flag_list = GetFlags();
- const bool parse_result =
- Flags::Parse(&argc, const_cast<const char **>(argv), flag_list);
- if (!parse_result) {
- std::string usage = Flags::Usage(argv[0], flag_list);
- TFLITE_LOG(ERROR) << usage;
- return false;
- }
- return ValidateFlags();
-}
-
-} // namespace benchmark
-} // namespace nnfw
diff --git a/tools/tflite_benchmark_model/benchmark_model.h b/tools/tflite_benchmark_model/benchmark_model.h
deleted file mode 100644
index 5645e2910..000000000
--- a/tools/tflite_benchmark_model/benchmark_model.h
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_BENCHMARK_MODEL_H__
-#define __TFLITE_BENCHMARK_MODEL_BENCHMARK_MODEL_H__
-
-#include <cmath>
-#include <limits>
-#include <ostream>
-#include <string>
-#include <unordered_set>
-#include <vector>
-
-#include "benchmark_params.h"
-#include "command_line_flags.h"
-#include "tensorflow/core/util/stats_calculator.h"
-
-namespace nnfw {
-namespace benchmark {
-
-enum RunType {
- WARMUP,
- REGULAR,
-};
-
-class BenchmarkResults {
- public:
- BenchmarkResults(int64_t startup_latency_us, uint64_t input_bytes,
- tensorflow::Stat<int64_t> warmup_time_us,
- tensorflow::Stat<int64_t> inference_time_us)
- : startup_latency_us_(startup_latency_us),
- input_bytes_(input_bytes),
- warmup_time_us_(warmup_time_us),
- inference_time_us_(inference_time_us) {}
-
- tensorflow::Stat<int64_t> inference_time_us() const {
- return inference_time_us_;
- }
- tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
- int64_t startup_latency_us() const { return startup_latency_us_; }
- uint64_t input_bytes() const { return input_bytes_; }
- double throughput_MB_per_second() const {
- double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
- inference_time_us_.sum();
- return bytes_per_sec / (1024.0 * 1024.0);
- }
-
- private:
- int64_t startup_latency_us_;
- uint64_t input_bytes_;
- tensorflow::Stat<int64_t> warmup_time_us_;
- tensorflow::Stat<int64_t> inference_time_us_;
-};
-
-class BenchmarkListener {
- public:
- virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
- virtual void OnSingleRunStart(RunType runType) {}
- virtual void OnSingleRunEnd() {}
- virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
- virtual ~BenchmarkListener() {}
-};
-
-// A listener that forwards its method calls to a collection of listeners.
-class BenchmarkListeners : public BenchmarkListener {
- public:
- // Added a listener to the listener collection.
- // |listener| is not owned by the instance of |BenchmarkListeners|.
- // |listener| should not be null and should outlast the instance of
- // |BenchmarkListeners|.
- void AddListener(BenchmarkListener* listener) {
- listeners_.push_back(listener);
- }
-
- void OnBenchmarkStart(const BenchmarkParams& params) override {
- for (auto listener : listeners_) {
- listener->OnBenchmarkStart(params);
- }
- }
-
- void OnSingleRunStart(RunType runType) override {
- for (auto listener : listeners_) {
- listener->OnSingleRunStart(runType);
- }
- }
-
- void OnSingleRunEnd() override {
- for (auto listener : listeners_) {
- listener->OnSingleRunEnd();
- }
- }
-
- void OnBenchmarkEnd(const BenchmarkResults& results) override {
- for (auto listener : listeners_) {
- listener->OnBenchmarkEnd(results);
- }
- }
-
- ~BenchmarkListeners() {}
-
- private:
- // Use vector so listeners are invoked in the order they are added.
- std::vector<BenchmarkListener*> listeners_;
-};
-
-// Benchmark listener that just logs the results of benchmark run.
-class BenchmarkLoggingListener : public BenchmarkListener {
- void OnBenchmarkEnd(const BenchmarkResults& results) override;
-};
-
-template <typename T>
-Flag CreateFlag(const char* name, BenchmarkParams* params,
- const std::string& usage) {
- return Flag(name, [params, name](const T& val) { params->Set<T>(name, val); },
- params->Get<T>(name), usage);
-}
-
-// Benchmarks a model.
-//
-// Subclasses need to implement initialization and running of the model.
-// The results can be collected by adding BenchmarkListener(s).
-class BenchmarkModel {
- public:
- static BenchmarkParams DefaultParams();
- BenchmarkModel();
- BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {}
- virtual ~BenchmarkModel() {}
- bool ParseFlags(int argc, char** argv);
- virtual void Init() = 0;
- void Run(int argc, char** argv);
- void AddListener(BenchmarkListener* listener) {
- listeners_.AddListener(listener);
- }
-
- protected:
- virtual void LogFlags();
- virtual bool ValidateFlags() { return true; }
- virtual std::vector<Flag> GetFlags();
- virtual uint64_t ComputeInputBytes() = 0;
- virtual tensorflow::Stat<int64_t> Run(int num_times, RunType run_type);
- virtual void RunImpl() = 0;
- BenchmarkParams params_;
- BenchmarkListeners listeners_;
-};
-
-} // namespace benchmark
-} // namespace nnfw
-
-#endif //__TFLITE_BENCHMARK_MODEL_BENCHMARK_MODEL_H__
diff --git a/tools/tflite_benchmark_model/benchmark_params.cc b/tools/tflite_benchmark_model/benchmark_params.cc
deleted file mode 100644
index 7b667a442..000000000
--- a/tools/tflite_benchmark_model/benchmark_params.cc
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "benchmark_params.h"
-
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "logging.h"
-
-namespace nnfw {
-namespace benchmark {
-
-void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a,
- BenchmarkParam::ParamType b) {
- TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter.";
-}
-
-template <>
-BenchmarkParam::ParamType BenchmarkParam::GetValueType<int32_t>() {
- return BenchmarkParam::ParamType::TYPE_INT32;
-}
-
-template <>
-BenchmarkParam::ParamType BenchmarkParam::GetValueType<bool>() {
- return BenchmarkParam::ParamType::TYPE_BOOL;
-}
-
-template <>
-BenchmarkParam::ParamType BenchmarkParam::GetValueType<float>() {
- return BenchmarkParam::ParamType::TYPE_FLOAT;
-}
-
-template <>
-BenchmarkParam::ParamType BenchmarkParam::GetValueType<std::string>() {
- return BenchmarkParam::ParamType::TYPE_STRING;
-}
-
-void BenchmarkParams::AssertParamExists(const std::string& name) const {
- TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found.";
-}
-
-} // namespace benchmark
-} // namespace nnfw
diff --git a/tools/tflite_benchmark_model/benchmark_params.h b/tools/tflite_benchmark_model/benchmark_params.h
deleted file mode 100644
index 1ac3f4af6..000000000
--- a/tools/tflite_benchmark_model/benchmark_params.h
+++ /dev/null
@@ -1,118 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_BENCHMARK_PARAMS_H__
-#define __TFLITE_BENCHMARK_MODEL_BENCHMARK_PARAMS_H__
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "logging.h"
-
-namespace nnfw {
-namespace benchmark {
-
-template <typename T>
-class TypedBenchmarkParam;
-
-class BenchmarkParam {
- protected:
- enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };
-
- public:
- template <typename T>
- static std::unique_ptr<BenchmarkParam> Create(const T& default_value) {
- return std::unique_ptr<BenchmarkParam>(
- new TypedBenchmarkParam<T>(default_value));
- }
-
- template <typename T>
- TypedBenchmarkParam<T>* AsTyped() {
- AssertHasSameType(GetValueType<T>(), type_);
- return static_cast<TypedBenchmarkParam<T>*>(this);
- }
- virtual ~BenchmarkParam() {}
- BenchmarkParam(ParamType type) : type_(type) {}
-
- private:
- static void AssertHasSameType(ParamType a, ParamType b);
- protected:
- template <typename T>
- static ParamType GetValueType();
-
- const ParamType type_;
-};
-
-template <typename T>
-class TypedBenchmarkParam : public BenchmarkParam {
- public:
- TypedBenchmarkParam(const T& value)
- : BenchmarkParam(GetValueType<T>()), value_(value) {}
- void Set(const T& value) { value_ = value; }
-
- T Get() { return value_; }
-
- private:
- T value_;
-};
-
-class BenchmarkParams {
- public:
- void AddParam(const std::string& name,
- std::unique_ptr<BenchmarkParam> value) {
- params_[name] = std::move(value);
- }
-
- bool HasParam(const std::string& name) const {
- return params_.find(name) != params_.end();
- }
-
- template <typename T>
- void Set(const std::string& name, const T& value) {
- AssertParamExists(name);
- params_.at(name)->AsTyped<T>()->Set(value);
- }
-
- template <typename T>
- T Get(const std::string& name) const {
- AssertParamExists(name);
- return params_.at(name)->AsTyped<T>()->Get();
- }
-
- private:
- void AssertParamExists(const std::string& name) const;
- std::unordered_map<std::string, std::unique_ptr<BenchmarkParam>> params_;
-};
-
-} // namespace benchmark
-} // namespace nnfw
-#endif // __TFLITE_BENCHMARK_MODEL_BENCHMARK_PARAMS_H__
diff --git a/tools/tflite_benchmark_model/benchmark_tflite_model.cc b/tools/tflite_benchmark_model/benchmark_tflite_model.cc
index d277795a3..611bd6a78 100644
--- a/tools/tflite_benchmark_model/benchmark_tflite_model.cc
+++ b/tools/tflite_benchmark_model/benchmark_tflite_model.cc
@@ -29,7 +29,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "benchmark_tflite_model.h"
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h"
#include <cstdarg>
#include <cstdlib>
@@ -39,11 +39,16 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#ifdef TFLITE_FLEX
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
+#endif // TFLITE_FLEX
#include "support/tflite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
#include "tensorflow/contrib/lite/string_util.h"
-#include "logging.h"
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
+
+// For profiling nnapi_delegate
#include "util/profiling/profiling.h"
#include "support/tflite/nnapi_delegate.h"
@@ -51,7 +56,7 @@ limitations under the License.
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
#endif
-namespace nnfw {
+namespace tflite {
namespace benchmark {
void ProfilingListener::SetInterpreter(tflite::Interpreter* interpreter) {
@@ -130,7 +135,7 @@ void FillRandomValue(T* ptr, const std::vector<int>& sizes,
void FillRandomString(tflite::DynamicBuffer* buffer,
const std::vector<int>& sizes,
- const std::function<std::string()>& random_func) {
+ const std::function<string()>& random_func) {
int num_elements = 1;
for (int dim : sizes) {
num_elements *= dim;
@@ -142,7 +147,7 @@ void FillRandomString(tflite::DynamicBuffer* buffer,
}
bool PopulateInputLayerInfo(
- const std::string& names_string, const std::string& shapes_string,
+ const string& names_string, const string& shapes_string,
std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) {
std::vector<std::string> names = Split(names_string, ',');
std::vector<std::string> shapes = Split(shapes_string, ':');
@@ -216,8 +221,8 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
return flags;
}
-void BenchmarkTfLiteModel::LogFlags() {
- BenchmarkModel::LogFlags();
+void BenchmarkTfLiteModel::LogParams() {
+ BenchmarkModel::LogParams();
TFLITE_LOG(INFO) << "Graph: [" << params_.Get<std::string>("graph") << "]";
TFLITE_LOG(INFO) << "Input layers: ["
<< params_.Get<std::string>("input_layer") << "]";
@@ -226,7 +231,7 @@ void BenchmarkTfLiteModel::LogFlags() {
TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
}
-bool BenchmarkTfLiteModel::ValidateFlags() {
+bool BenchmarkTfLiteModel::ValidateParams() {
if (params_.Get<std::string>("graph").empty()) {
TFLITE_LOG(ERROR)
<< "Please specify the name of your TF Lite input file with --graph";
@@ -247,6 +252,46 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
return total_input_bytes;
}
+void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
+ auto interpreter_inputs = interpreter->inputs();
+ // Set the values of the input tensors.
+ for (int j = 0; j < inputs.size(); ++j) {
+ const InputLayerInfo& input = inputs[j];
+ int i = interpreter_inputs[j];
+ TfLiteTensor* t = interpreter->tensor(i);
+ std::vector<int> sizes = input.shape;
+
+ // TODO(ahentz): below we ignore the O-th dimension (number of batches).
+ if (t->type == kTfLiteFloat32) {
+ FillRandomValue<float>(
+ interpreter->typed_tensor<float>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
+ } else if (t->type == kTfLiteInt32) {
+ // TODO(yunluli): This is currently only used for handling embedding input
+ // for speech models. Generalize if necessary.
+ FillRandomValue<int32_t>(
+ interpreter->typed_tensor<int32_t>(i),
+ std::vector<int32_t>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<int32_t>(rand()) % 100; });
+ } else if (t->type == kTfLiteUInt8) {
+ FillRandomValue<uint8_t>(
+ interpreter->typed_tensor<uint8_t>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<uint8_t>(rand()) % 255; });
+ } else if (t->type == kTfLiteString) {
+ tflite::DynamicBuffer buffer;
+ FillRandomString(&buffer, sizes, []() {
+ return "we're have some friends over saturday to hang out in the yard";
+ });
+ buffer.WriteToTensor(interpreter->tensor(i));
+ } else {
+ TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
+ << " of type " << t->type;
+ }
+ }
+}
+
void BenchmarkTfLiteModel::Init() {
std::string graph = params_.Get<std::string>("graph");
model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
@@ -269,7 +314,7 @@ void BenchmarkTfLiteModel::Init() {
TFLITE_LOG(FATAL) << "Failed to construct interpreter";
}
profiling_listener_.SetInterpreter(interpreter.get());
- profiling::Context::get().setProfiler(interpreter->GetProfiler());
+ ::profiling::Context::get().setProfiler(interpreter->GetProfiler());
const int32_t num_threads = params_.Get<int32_t>("num_threads");
@@ -280,6 +325,16 @@ void BenchmarkTfLiteModel::Init() {
bool use_nnapi = params_.Get<bool>("use_nnapi");
interpreter->UseNNAPI(use_nnapi);
+
+#ifdef TFLITE_FLEX
+ TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
+ delegate_ = FlexDelegate::Create();
+ if (delegate_) {
+ interpreter->ModifyGraphWithDelegate(delegate_.get(),
+ /*allow_dynamic_tensors=*/true);
+ }
+#endif // TFLITE_FLEX
+
auto interpreter_inputs = interpreter->inputs();
if (!inputs.empty()) {
@@ -311,36 +366,6 @@ void BenchmarkTfLiteModel::Init() {
if (interpreter->AllocateTensors() != kTfLiteOk) {
TFLITE_LOG(FATAL) << "Failed to allocate tensors!";
}
-
- // Set the values of the input tensors.
- for (int j = 0; j < inputs.size(); ++j) {
- const InputLayerInfo& input = inputs[j];
- int i = interpreter_inputs[j];
- TfLiteTensor* t = interpreter->tensor(i);
- std::vector<int> sizes = input.shape;
-
- // TODO(ahentz): below we ignore the O-th dimension (number of batches).
- if (t->type == kTfLiteFloat32) {
- FillRandomValue<float>(
- interpreter->typed_tensor<float>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
- } else if (t->type == kTfLiteUInt8) {
- FillRandomValue<uint8_t>(
- interpreter->typed_tensor<uint8_t>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<uint8_t>(rand()) % 255; });
- } else if (t->type == kTfLiteString) {
- tflite::DynamicBuffer buffer;
- FillRandomString(&buffer, sizes, []() {
- return "we're have some friends over saturday to hang out in the yard";
- });
- buffer.WriteToTensor(interpreter->tensor(i));
- } else {
- TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
- << " of type " << t->type;
- }
- }
}
void BenchmarkTfLiteModel::RunImpl() {
@@ -357,4 +382,4 @@ void BenchmarkTfLiteModel::RunImpl() {
}
} // namespace benchmark
-} // namespace nnfw
+} // namespace tflite
diff --git a/tools/tflite_benchmark_model/benchmark_tflite_model.h b/tools/tflite_benchmark_model/benchmark_tflite_model.h
deleted file mode 100644
index 7892de1f7..000000000
--- a/tools/tflite_benchmark_model/benchmark_tflite_model.h
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_BENCHMARK_TFLITE_MODEL_H__
-#define __TFLITE_BENCHMARK_MODEL_BENCHMARK_TFLITE_MODEL_H__
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
-#include "benchmark_model.h"
-
-namespace nnfw {
-namespace benchmark {
-
-// Dumps profiling events if profiling is enabled
-class ProfilingListener : public BenchmarkListener {
- public:
- explicit ProfilingListener() : interpreter_(nullptr), has_profiles_(false) {}
-
- void SetInterpreter(tflite::Interpreter* interpreter);
-
- void OnSingleRunStart(RunType run_type) override;
-
- void OnSingleRunEnd() override;
-
- void OnBenchmarkEnd(const BenchmarkResults& results) override;
-
- private:
- tflite::Interpreter* interpreter_;
- tflite::profiling::Profiler profiler_;
- tflite::profiling::ProfileSummarizer summarizer_;
- bool has_profiles_;
-};
-
-// Benchmarks a TFLite model by running tflite interpreter.
-class BenchmarkTfLiteModel : public BenchmarkModel {
- public:
- BenchmarkTfLiteModel();
- BenchmarkTfLiteModel(BenchmarkParams params);
-
- std::vector<Flag> GetFlags() override;
- void LogFlags() override;
- bool ValidateFlags() override;
- uint64_t ComputeInputBytes() override;
- void Init() override;
- void RunImpl() override;
- virtual ~BenchmarkTfLiteModel() {}
-
- struct InputLayerInfo {
- std::string name;
- std::vector<int> shape;
- };
-
- private:
- std::unique_ptr<tflite::FlatBufferModel> model;
- std::unique_ptr<tflite::Interpreter> interpreter;
- std::vector<InputLayerInfo> inputs;
- ProfilingListener profiling_listener_;
-};
-
-} // namespace benchmark
-} // namespace nnfw
-
-#endif //__TFLITE_BENCHMARK_MODEL_BENCHMARK_TFLITE_MODEL_H__
diff --git a/tools/tflite_benchmark_model/command_line_flags.cc b/tools/tflite_benchmark_model/command_line_flags.cc
deleted file mode 100644
index eacca9f73..000000000
--- a/tools/tflite_benchmark_model/command_line_flags.cc
+++ /dev/null
@@ -1,214 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "command_line_flags.h"
-
-#include <cstring>
-#include <sstream>
-#include <string>
-#include <utility>
-#include <vector>
-
-namespace nnfw {
-namespace {
-
-template <typename T>
-std::string ToString(T val) {
- std::ostringstream stream;
- stream << val;
- return stream.str();
-}
-
-bool ParseFlag(const std::string& arg, const std::string& flag,
- const std::function<bool(const std::string&)>& parse_func,
- bool* value_parsing_ok) {
- *value_parsing_ok = true;
- std::string flag_prefix = "--" + flag + "=";
- if (arg.find(flag_prefix) != 0) {
- return false;
- }
- bool has_value = arg.size() >= flag_prefix.size();
- *value_parsing_ok = has_value;
- if (has_value) {
- *value_parsing_ok = parse_func(arg.substr(flag_prefix.size()));
- }
- return true;
-}
-
-template <typename T>
-bool ParseFlag(const std::string& flag_value,
- const std::function<void(const T&)>& hook) {
- std::istringstream stream(flag_value);
- T read_value;
- stream >> read_value;
- if (!stream.eof() && !stream.good()) {
- return false;
- }
- hook(read_value);
- return true;
-}
-
-bool ParseBoolFlag(const std::string& flag_value,
- const std::function<void(const bool&)>& hook) {
- if (flag_value != "true" && flag_value != "false") {
- return false;
- }
-
- hook(flag_value == "true");
- return true;
-}
-} // namespace
-
-Flag::Flag(const char* name, const std::function<void(const int32_t&)>& hook,
- int32_t default_value, const std::string& usage_text)
- : name_(name),
- type_(TYPE_INT32),
- value_hook_([hook](const std::string& flag_value) {
- return ParseFlag<int32_t>(flag_value, hook);
- }),
- default_for_display_(ToString(default_value)),
- usage_text_(usage_text) {}
-
-Flag::Flag(const char* name, const std::function<void(const int64_t&)>& hook,
- int64_t default_value, const std::string& usage_text)
- : name_(name),
- type_(TYPE_INT64),
- value_hook_([hook](const std::string& flag_value) {
- return ParseFlag<int64_t>(flag_value, hook);
- }),
- default_for_display_(ToString(default_value)),
- usage_text_(usage_text) {}
-
-Flag::Flag(const char* name, const std::function<void(const float&)>& hook,
- float default_value, const std::string& usage_text)
- : name_(name),
- type_(TYPE_FLOAT),
- value_hook_([hook](const std::string& flag_value) {
- return ParseFlag<float>(flag_value, hook);
- }),
- default_for_display_(ToString(default_value)),
- usage_text_(usage_text) {}
-
-Flag::Flag(const char* name, const std::function<void(const bool&)>& hook,
- bool default_value, const std::string& usage_text)
- : name_(name),
- type_(TYPE_BOOL),
- value_hook_([hook](const std::string& flag_value) {
- return ParseBoolFlag(flag_value, hook);
- }),
- default_for_display_(default_value ? "true" : "false"),
- usage_text_(usage_text) {}
-
-Flag::Flag(const char* name,
- const std::function<void(const std::string&)>& hook,
- const std::string& default_value, const std::string& usage_text)
- : name_(name),
- type_(TYPE_STRING),
- value_hook_([hook](const std::string& flag_value) {
- hook(flag_value);
- return true;
- }),
- default_for_display_(default_value),
- usage_text_(usage_text) {}
-
-bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
- return ParseFlag(arg, name_, value_hook_, value_parsing_ok);
-}
-
-std::string Flag::GetTypeName() const {
- switch (type_) {
- case TYPE_INT32:
- return "int32";
- case TYPE_INT64:
- return "int64";
- case TYPE_FLOAT:
- return "float";
- case TYPE_BOOL:
- return "bool";
- case TYPE_STRING:
- return "string";
- }
-
- return "unknown";
-}
-
-/*static*/ bool Flags::Parse(int* argc, const char** argv,
- const std::vector<Flag>& flag_list) {
- bool result = true;
- std::vector<const char*> unknown_flags;
- for (int i = 1; i < *argc; ++i) {
- if (std::string(argv[i]) == "--") {
- while (i < *argc) {
- unknown_flags.push_back(argv[i]);
- ++i;
- }
- break;
- }
-
- bool was_found = false;
- for (const Flag& flag : flag_list) {
- bool value_parsing_ok;
- was_found = flag.Parse(argv[i], &value_parsing_ok);
- if (!value_parsing_ok) {
- result = false;
- }
- if (was_found) {
- break;
- }
- }
- if (!was_found) {
- unknown_flags.push_back(argv[i]);
- }
- }
- int dst = 1; // Skip argv[0]
- for (auto f : unknown_flags) {
- argv[dst++] = f;
- }
- argv[dst++] = nullptr;
- *argc = unknown_flags.size() + 1;
- return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0);
-}
-
-/*static*/ std::string Flags::Usage(const std::string& cmdline,
- const std::vector<Flag>& flag_list) {
- std::ostringstream usage_text;
- usage_text << "usage: " << cmdline << "\n";
- if (!flag_list.empty()) {
- usage_text << "Flags:\n";
- }
-
- for (const Flag& flag : flag_list) {
- auto type_name = flag.GetTypeName();
- usage_text << "\t";
- usage_text << "--" << flag.name_ << "=" << flag.default_for_display_;
- usage_text << "\t" << type_name << "\t" << flag.usage_text_ << "\n";
- }
- return usage_text.str();
-}
-
-} // namespace nnfw
diff --git a/tools/tflite_benchmark_model/command_line_flags.h b/tools/tflite_benchmark_model/command_line_flags.h
deleted file mode 100644
index 766417d87..000000000
--- a/tools/tflite_benchmark_model/command_line_flags.h
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_COMMAND_LINE_FLAGS_H__
-#define __TFLITE_BENCHMARK_MODEL_COMMAND_LINE_FLAGS_H__
-
-#include <functional>
-#include <string>
-#include <vector>
-
-namespace nnfw {
-// A simple command-line argument parsing module.
-// Dependency free simplified port of core/util/command_line_flags.
-// This class is written for benchmarks and uses inefficient string
-// concatenation. This was written to avoid dependency on tensorflow/core/util
-// which transitively brings in a lot of other dependencies that are not
-// necessary for tflite benchmarking code.
-// The recommended way of using it is with local variables and an initializer
-// list of Flag objects, for example:
-//
-// int some_int = 10;
-// bool some_switch = false;
-// std::string some_name = "something";
-//
-// std::vector<tensorFlow::Flag> flag_list = {
-// Flag::CreateFlag("some_int", &some_int, "an integer that affects X"),
-// Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"),
-// Flag::CreateFlag("some_name", &some_name, "a string that affects Z")
-// };
-// // Get usage message before ParseFlags() to capture default values.
-// std::string usage = Flag::Usage(argv[0], flag_list);
-// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list);
-//
-// tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
-// if (argc != 1 || !parsed_values_ok) {
-// ...output usage and error message...
-// }
-//
-// The argc and argv values are adjusted by the Parse function so all that
-// remains is the program name (at argv[0]) and any unknown arguments fill the
-// rest of the array. This means you can check for flags that weren't understood
-// by seeing if argv is greater than 1.
-// The result indicates if there were any errors parsing the values that were
-// passed to the command-line switches. For example, --some_int=foo would return
-// false because the argument is expected to be an integer.
-//
-// NOTE: Unlike gflags-style libraries, this library is intended to be
-// used in the `main()` function of your binary. It does not handle
-// flag definitions that are scattered around the source code.
-
-// A description of a single command line flag, holding its name, type, usage
-// text, and a pointer to the corresponding variable.
-class Flag {
- public:
- template <typename T>
- static Flag CreateFlag(const char* name, T* val, const char* usage) {
- return Flag(name, [val](const T& v) { *val = v; }, *val, usage);
- }
-
- Flag(const char* name, const std::function<void(const int32_t&)>& hook,
- int32_t default_value, const std::string& usage_text);
- Flag(const char* name, const std::function<void(const int64_t&)>& hook,
- int64_t default_value, const std::string& usage_text);
- Flag(const char* name, const std::function<void(const float&)>& hook,
- float default_value, const std::string& usage_text);
- Flag(const char* name, const std::function<void(const bool&)>& hook,
- bool default_value, const std::string& usage_text);
- Flag(const char* name, const std::function<void(const std::string&)>& hook,
- const std::string& default_value, const std::string& usage_text);
-
- private:
- friend class Flags;
-
- bool Parse(const std::string& arg, bool* value_parsing_ok) const;
-
- std::string name_;
- enum {
- TYPE_INT32,
- TYPE_INT64,
- TYPE_BOOL,
- TYPE_STRING,
- TYPE_FLOAT,
- } type_;
-
- std::string GetTypeName() const;
-
- std::function<bool(const std::string&)> value_hook_;
- std::string default_for_display_;
-
- std::string usage_text_;
-};
-
-class Flags {
- public:
- // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag
- // instances matching flags in flaglist[]. Update the variables associated
- // with matching flags, and remove the matching arguments from (*argc, argv).
- // Return true iff all recognized flag values were parsed correctly, and the
- // first remaining argument is not "--help".
- static bool Parse(int* argc, const char** argv,
- const std::vector<Flag>& flag_list);
-
- // Return a usage message with command line cmdline, and the
- // usage_text strings in flag_list[].
- static std::string Usage(const std::string& cmdline,
- const std::vector<Flag>& flag_list);
-};
-
-} // namespace nnfw
-
-#endif // __TFLITE_BENCHMARK_MODEL_COMMAND_LINE_FLAGS_H__
-
-
diff --git a/tools/tflite_benchmark_model/logging.h b/tools/tflite_benchmark_model/logging.h
deleted file mode 100644
index e694a0926..000000000
--- a/tools/tflite_benchmark_model/logging.h
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef __TFLITE_BENCHMARK_MODEL_LOGGING_H_
-#define __TFLITE_BENCHMARK_MODEL_LOGGING_H_
-
-// LOG and CHECK macros for benchmarks.
-
-#include <cstdlib>
-#include <iostream>
-#include <sstream>
-
-namespace nnfw {
-namespace logging {
-// A wrapper that logs to stderr.
-//
-// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros.
-class LoggingWrapper {
- public:
- enum class LogSeverity : int {
- INFO = 0,
- WARN = 1,
- ERROR = 2,
- FATAL = 3,
- };
- LoggingWrapper(LogSeverity severity)
- : severity_(severity), should_log_(true) {}
- LoggingWrapper(LogSeverity severity, bool log)
- : severity_(severity), should_log_(log) {}
- std::stringstream& Stream() { return stream_; }
- ~LoggingWrapper() {
- if (should_log_) {
- std::cerr << stream_.str() << std::endl;
- if (severity_ == LogSeverity::FATAL) {
- std::flush(std::cerr);
- std::abort();
- }
- }
- }
-
- private:
- std::stringstream stream_;
- LogSeverity severity_;
- bool should_log_;
-};
-
-} // namespace logging
-
-} // namespace nnfw
-
-#define TFLITE_LOG(severity) \
- nnfw::logging::LoggingWrapper( \
- nnfw::logging::LoggingWrapper::LogSeverity::severity) \
- .Stream()
-
-#define TFLITE_BENCHMARK_CHECK(condition) \
- nnfw::logging::LoggingWrapper( \
- nnfw::logging::LoggingWrapper::LogSeverity::FATAL, \
- (condition) ? false : true) \
- .Stream()
-
-#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b)
-
-#endif // __TFLITE_BENCHMARK_MODEL_BENCHMARK_LOGGING_H_
diff --git a/tools/tflite_benchmark_model/profile_summarizer.cc b/tools/tflite_benchmark_model/profile_summarizer.cc
index 4d12b50af..ce19b0c98 100644
--- a/tools/tflite_benchmark_model/profile_summarizer.cc
+++ b/tools/tflite_benchmark_model/profile_summarizer.cc
@@ -39,8 +39,6 @@ namespace tflite {
namespace profiling {
namespace {
-using Detail = tensorflow::StatsCalculator::Detail;
-
struct OperatorDetails {
std::string name;
std::vector<std::string> inputs;
@@ -94,18 +92,30 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
} else {
op_name = tflite::EnumNamesBuiltinOperator()[code];
}
+ const char* profiling_string =
+ interpreter.OpProfilingString(node_reg->second, &node_reg->first);
OperatorDetails details;
details.name = op_name;
+ if (profiling_string) {
+ details.name += ":" + std::string(profiling_string);
+ }
details.inputs = GetTensorNames(interpreter, inputs);
details.outputs = GetTensorNames(interpreter, outputs);
return details;
}
+tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() {
+ auto options = tensorflow::StatSummarizerOptions();
+ options.show_summary = true;
+ options.show_memory = false;
+ return options;
+}
+
} // namespace
ProfileSummarizer::ProfileSummarizer()
- : stats_calculator_(new ::tensorflow::StatsCalculator(
- tensorflow::StatSummarizerOptions())) {}
+ : stats_calculator_(
+ new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {}
void ProfileSummarizer::ProcessProfiles(
const std::vector<const ProfileEvent*>& profile_stats,
@@ -129,35 +139,22 @@ void ProfileSummarizer::ProcessProfiles(
int64_t base_start_us = events[0]->begin_timestamp_us;
int node_num = 0;
int64_t curr_total_us = 0;
- std::map<std::string, Detail> details;
int prev_op_idx = -1;
- int seq_no = 1;
+ int child_op_no = 1;
for (auto event : events) {
auto op_details = GetOperatorDetails(interpreter, event->event_metadata);
- bool is_continued = (prev_op_idx == event->event_metadata);
- seq_no = is_continued ? seq_no + 1 : 1;
- auto node_name = ToString(op_details.outputs) + "#" + std::to_string(seq_no);
- auto result = details.emplace(node_name, Detail());
- Detail* detail = &(result.first->second);
- detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us);
+ bool from_same_op = (prev_op_idx == event->event_metadata);
+ child_op_no = from_same_op ? child_op_no + 1 : 1;
+ auto node_name = ToString(op_details.outputs) + "#" + std::to_string(child_op_no);
+ int64_t start_us = event->begin_timestamp_us - base_start_us;
int64_t node_exec_time =
event->end_timestamp_us - event->begin_timestamp_us;
- detail->rel_end_us.UpdateStat(node_exec_time);
+ stats_calculator_->AddNodeStats(node_name, op_details.name, node_num,
+ start_us, node_exec_time, 0 /*memory */);
curr_total_us += node_exec_time;
++node_num;
-
- if (result.second) {
- detail->name = node_name;
- detail->type = op_details.name;
- detail->run_order = node_num;
- detail->times_called = 0;
- }
- if (!is_continued) {
- ++detail->times_called;
- }
prev_op_idx = event->event_metadata;
}
- stats_calculator_->UpdateDetails(details);
stats_calculator_->UpdateRunTotalUs(curr_total_us);
}
} // namespace profiling
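The rewritten loop above no longer builds a local `std::map<std::string, Detail>`; it hands each node's timing straight to the stats calculator. A hedged sketch of the per-node accumulation this delegates to (the simplified `AddNodeStats` signature below is assumed for illustration; the real `tensorflow::StatsCalculator` call also takes the node type, run order, start time, and memory):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Sketch of the per-node accumulation that AddNodeStats performs; the
// real tensorflow::StatsCalculator::Detail keeps Stat<int64_t>
// accumulators for start times and durations as well.
struct NodeStats {
  int64_t total_us = 0;
  int64_t times_called = 0;
};

class MiniStatsCalculator {
 public:
  void AddNodeStats(const std::string& name, int64_t exec_us) {
    auto& s = stats_[name];
    s.total_us += exec_us;
    ++s.times_called;
  }
  void Dump() const {
    for (const auto& [name, s] : stats_)
      std::cout << name << ": avg " << (s.total_us / s.times_called)
                << " us over " << s.times_called << " runs\n";
  }

 private:
  std::map<std::string, NodeStats> stats_;
};
```

Centralizing accumulation in the stats calculator is what lets the `#N`-suffixed node names from expanded kernels aggregate cleanly across repeated benchmark runs.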
diff --git a/tools/tflite_benchmark_model/profile_summarizer.h b/tools/tflite_benchmark_model/profile_summarizer.h
deleted file mode 100644
index a529ff874..000000000
--- a/tools/tflite_benchmark_model/profile_summarizer.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
-#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
-
-#include <vector>
-
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/profiling/profiler.h"
-#include "tensorflow/core/util/stats_calculator.h"
-
-namespace tflite {
-namespace profiling {
-
-// Creates a summary of operator invocations in the interpreter.
-class ProfileSummarizer {
- public:
- ProfileSummarizer();
- virtual ~ProfileSummarizer() {}
-
- // Process profile events to update statistics for operator invocations.
- void ProcessProfiles(const std::vector<const ProfileEvent*>& profile_stats,
- const tflite::Interpreter& interpreter);
-
- // Returns a string detailing the accumulated runtime stats in a tab-separated
- // format which can be pasted into a spreadsheet for further analysis.
- std::string GetOutputString() const {
- return stats_calculator_->GetOutputString();
- }
-
- std::string GetShortSummary() const {
- return stats_calculator_->GetShortSummary();
- }
-
- private:
- std::unique_ptr<tensorflow::StatsCalculator> stats_calculator_;
-};
-
-} // namespace profiling
-} // namespace tflite
-
-#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
diff --git a/tools/tflite_benchmark_model/stats_calculator.cc b/tools/tflite_benchmark_model/stats_calculator.cc
new file mode 100644
index 000000000..578650701
--- /dev/null
+++ b/tools/tflite_benchmark_model/stats_calculator.cc
@@ -0,0 +1,317 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/stats_calculator.h"
+
+#include <iomanip>
+#include <map>
+#include <queue>
+#include <sstream>
+#include <string>
+#include <algorithm>
+
+namespace tensorflow {
+
+StatsCalculator::StatsCalculator(const StatSummarizerOptions& options)
+ : options_(options) {}
+
+std::string StatsCalculator::GetShortSummary() const {
+ std::stringstream stream;
+ stream << "Timings (microseconds): ";
+ run_total_us_.OutputToStream(&stream);
+ stream << std::endl;
+
+ stream << "Memory (bytes): ";
+ memory_.OutputToStream(&stream);
+ stream << std::endl;
+
+ stream << details_.size() << " nodes observed" << std::endl;
+ return stream.str();
+}
+
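+// Prepares a right-aligned, fixed-precision column of the given width for
+// the table printers below.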
+std::ostream& InitField(std::ostream& stream, int width) {
+ stream << "\t" << std::right << std::setw(width) << std::fixed
+ << std::setprecision(3);
+ return stream;
+}
+
+std::string StatsCalculator::HeaderString(const std::string& title) const {
+ std::stringstream stream;
+
+ stream << "============================== " << title
+ << " ==============================" << std::endl;
+
+ InitField(stream, 24) << "[node type]";
+ InitField(stream, 9) << "[start]";
+ InitField(stream, 9) << "[first]";
+ InitField(stream, 9) << "[avg ms]";
+ InitField(stream, 8) << "[%]";
+ InitField(stream, 8) << "[cdf%]";
+ InitField(stream, 10) << "[mem KB]";
+ InitField(stream, 9) << "[times called]";
+ stream << "\t"
+ << "[Name]";
+ return stream.str();
+}
+
+std::string StatsCalculator::ColumnString(const Detail& detail,
+ const int64_t cumulative_stat_on_node,
+ const Stat<int64_t>& stat) const {
+ const double start_ms = detail.start_us.avg() / 1000.0;
+ const double first_time_ms = detail.rel_end_us.first() / 1000.0;
+ const double avg_time_ms = detail.rel_end_us.avg() / 1000.0;
+ const double percentage = detail.rel_end_us.sum() * 100.0 / stat.sum();
+ const double cdf_percentage = (cumulative_stat_on_node * 100.0f) / stat.sum();
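+  // times_called accumulates across every run; report a per-run count.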
+ const int64_t times_called = detail.times_called / num_runs();
+
+ std::stringstream stream;
+ InitField(stream, 24) << detail.type;
+ InitField(stream, 9) << start_ms;
+ InitField(stream, 9) << first_time_ms;
+ InitField(stream, 9) << avg_time_ms;
+ InitField(stream, 7) << percentage << "%";
+ InitField(stream, 7) << cdf_percentage << "%";
+ InitField(stream, 10) << detail.mem_used.newest() / 1000.0;
+ InitField(stream, 9) << times_called;
+ stream << "\t" << detail.name;
+
+ return stream.str();
+}
+
+void StatsCalculator::OrderNodesByMetric(
+ SortingMetric metric, std::vector<const Detail*>* details) const {
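+  // Fixed-width, right-aligned keys make lexicographic comparison in the
+  // priority queue agree with numeric ordering for the numeric metrics;
+  // the queue then pops entries in descending key order.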
+ std::priority_queue<std::pair<std::string, const Detail*>> sorted_list;
+ const int num_nodes = details_.size();
+
+ for (const auto& det : details_) {
+ const Detail* detail = &(det.second);
+ std::stringstream stream;
+ stream << std::setw(20) << std::right << std::setprecision(10)
+ << std::fixed;
+
+ switch (metric) {
+ case BY_NAME:
+ stream << detail->name;
+ break;
+ case BY_RUN_ORDER:
+ stream << num_nodes - detail->run_order;
+ break;
+ case BY_TIME:
+ stream << detail->rel_end_us.avg();
+ break;
+ case BY_MEMORY:
+ stream << detail->mem_used.avg();
+ break;
+ case BY_TYPE:
+ stream << detail->type;
+ break;
+ default:
+ stream << "";
+ break;
+ }
+
+ sorted_list.emplace(stream.str(), detail);
+ }
+
+ while (!sorted_list.empty()) {
+ auto entry = sorted_list.top();
+ sorted_list.pop();
+ details->push_back(entry.second);
+ }
+}
+
+void StatsCalculator::ComputeStatsByType(
+ std::map<std::string, int64_t>* node_type_map_count,
+ std::map<std::string, int64_t>* node_type_map_time,
+ std::map<std::string, int64_t>* node_type_map_memory,
+ std::map<std::string, int64_t>* node_type_map_times_called,
+ int64_t* accumulated_us) const {
+ int64_t run_count = run_total_us_.count();
+
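+  // Detail stats accumulate over every run, so divide sums by run_count to
+  // report per-run values.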
+ for (const auto& det : details_) {
+ const std::string node_name = det.first;
+ const Detail& detail = det.second;
+
+ int64_t curr_time_val =
+ static_cast<int64_t>(detail.rel_end_us.sum() / run_count);
+ *accumulated_us += curr_time_val;
+
+ int64_t curr_memory_val = detail.mem_used.newest();
+
+ const std::string& node_type = detail.type;
+
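+    // Children of a kernel-expanded op share a name prefix and differ only
+    // in the "#N" suffix, so count a node toward its type only for the
+    // "#1" child. (Node names always end in "#N", so comparing the last
+    // two characters is safe.)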
+ const std::string sharp1("#1");
+    bool first = std::mismatch(sharp1.rbegin(), sharp1.rend(),
+                               node_name.rbegin()).first == sharp1.rend();
+
+ if (first) {
+ (*node_type_map_count)[node_type] += 1;
+ (*node_type_map_times_called)[node_type] += detail.times_called / run_count;
+ }
+ (*node_type_map_time)[node_type] += curr_time_val;
+ (*node_type_map_memory)[node_type] += curr_memory_val;
+ }
+}
+
+std::string StatsCalculator::GetStatsByNodeType() const {
+ std::stringstream stream;
+
+ stream << "Number of nodes executed: " << details_.size() << std::endl;
+
+ stream << "============================== Summary by node type "
+ "=============================="
+ << std::endl;
+
+ std::map<std::string, int64_t> node_type_map_count;
+ std::map<std::string, int64_t> node_type_map_time;
+ std::map<std::string, int64_t> node_type_map_memory;
+ std::map<std::string, int64_t> node_type_map_times_called;
+ int64_t accumulated_us = 0;
+
+ ComputeStatsByType(&node_type_map_count, &node_type_map_time,
+ &node_type_map_memory, &node_type_map_times_called,
+ &accumulated_us);
+
+  // Sort node types by total time per run, descending.
+ std::priority_queue<std::pair<int64_t, std::pair<std::string, int64_t>>>
+ timings;
+ for (const auto& node_type : node_type_map_time) {
+ const int64_t mem_used = node_type_map_memory[node_type.first];
+ timings.emplace(node_type.second,
+ std::pair<std::string, int64_t>(node_type.first, mem_used));
+ }
+
+ InitField(stream, 24) << "[Node type]";
+ InitField(stream, 9) << "[count]";
+ InitField(stream, 10) << "[avg ms]";
+ InitField(stream, 11) << "[avg %]";
+ InitField(stream, 11) << "[cdf %]";
+ InitField(stream, 10) << "[mem KB]";
+ InitField(stream, 10) << "[times called]";
+ stream << std::endl;
+
+ float cdf = 0.0f;
+ while (!timings.empty()) {
+ auto entry = timings.top();
+ timings.pop();
+
+ const std::string node_type = entry.second.first;
+ const float memory = entry.second.second / 1000.0f;
+
+ const int64_t node_type_total_us = entry.first;
+ const float time_per_run_ms = node_type_total_us / 1000.0f;
+
+ const float percentage =
+ ((entry.first / static_cast<float>(accumulated_us)) * 100.0f);
+ cdf += percentage;
+
+ InitField(stream, 24) << node_type;
+ InitField(stream, 9) << node_type_map_count[node_type];
+ InitField(stream, 10) << time_per_run_ms;
+ InitField(stream, 10) << percentage << "%";
+ InitField(stream, 10) << cdf << "%";
+ InitField(stream, 10) << memory;
+ InitField(stream, 9) << node_type_map_times_called[node_type];
+ stream << std::endl;
+ }
+ stream << std::endl;
+ return stream.str();
+}
+
+std::string StatsCalculator::GetStatsByMetric(const std::string& title,
+ SortingMetric sorting_metric,
+ int num_stats) const {
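+  // A non-positive num_stats prints every node; otherwise output stops
+  // after num_stats rows.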
+ std::vector<const Detail*> details;
+ OrderNodesByMetric(sorting_metric, &details);
+
+ double cumulative_stat_on_node = 0;
+
+ std::stringstream stream;
+ stream << HeaderString(title) << std::endl;
+ int stat_num = 0;
+ for (auto detail : details) {
+ ++stat_num;
+ if (num_stats > 0 && stat_num > num_stats) {
+ break;
+ }
+
+ // TODO(andrewharp): Make this keep track of the particular metric for cdf.
+ cumulative_stat_on_node += detail->rel_end_us.sum();
+ stream << ColumnString(*detail, cumulative_stat_on_node, run_total_us_)
+ << std::endl;
+ }
+ stream << std::endl;
+ return stream.str();
+}
+
+std::string StatsCalculator::GetOutputString() const {
+ std::stringstream stream;
+ if (options_.show_run_order) {
+ stream << GetStatsByMetric("Run Order", BY_RUN_ORDER,
+ options_.run_order_limit);
+ }
+ if (options_.show_time) {
+ stream << GetStatsByMetric("Top by Computation Time", BY_TIME,
+ options_.time_limit);
+ }
+ if (options_.show_memory) {
+ stream << GetStatsByMetric("Top by Memory Use", BY_MEMORY,
+ options_.memory_limit);
+ }
+ if (options_.show_type) {
+ stream << GetStatsByNodeType();
+ }
+ if (options_.show_summary) {
+ stream << GetShortSummary() << std::endl;
+ }
+ return stream.str();
+}
+
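+// Accumulates one profile event into the per-node table, keyed by node
+// name. The first call for a name creates its Detail entry; later calls
+// (e.g. subsequent runs) only update the running statistics. A hypothetical
+// caller recording an op expanded into two kernels might do:
+//
+//   calc.AddNodeStats("out#1", "CONV_2D", 0, /*start_us=*/0,
+//                     /*rel_end_us=*/500, /*mem_used=*/0);
+//   calc.AddNodeStats("out#2", "CONV_2D", 1, /*start_us=*/500,
+//                     /*rel_end_us=*/120, /*mem_used=*/0);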
+void StatsCalculator::AddNodeStats(const std::string& name,
+ const std::string& type, int64_t run_order,
+ int64_t start_us, int64_t rel_end_us,
+ int64_t mem_used) {
+ Detail* detail = nullptr;
+ if (details_.find(name) == details_.end()) {
+ details_.insert({name, {}});
+ detail = &details_.at(name);
+ detail->type = type;
+ detail->name = name;
+ detail->run_order = run_order;
+ } else {
+ detail = &details_.at(name);
+ }
+ detail->start_us.UpdateStat(start_us);
+ detail->rel_end_us.UpdateStat(rel_end_us);
+ detail->mem_used.UpdateStat(mem_used);
+ detail->times_called++;
+}
+
+} // namespace tensorflow