diff options
Diffstat (limited to 'runtime/contrib/android_benchmark_app/cpp/ndk_main.cpp')
-rw-r--r-- | runtime/contrib/android_benchmark_app/cpp/ndk_main.cpp | 228 |
1 files changed, 228 insertions, 0 deletions
diff --git a/runtime/contrib/android_benchmark_app/cpp/ndk_main.cpp b/runtime/contrib/android_benchmark_app/cpp/ndk_main.cpp new file mode 100644 index 000000000..f2ca1312c --- /dev/null +++ b/runtime/contrib/android_benchmark_app/cpp/ndk_main.cpp @@ -0,0 +1,228 @@ +#include "ndk_main.h" + +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +#include "tflite/Assert.h" +#include "tflite/Session.h" +#include "tflite/InterpreterSession.h" +#include "tflite/NNAPISession.h" +#include "tflite/ext/kernels/register.h" + +#include "misc/benchmark.h" + +#include <boost/accumulators/accumulators.hpp> +#include <boost/accumulators/statistics/stats.hpp> +#include <boost/accumulators/statistics/mean.hpp> +#include <boost/accumulators/statistics/min.hpp> +#include <boost/accumulators/statistics/max.hpp> + +#include <cassert> +#include <chrono> +#include <sstream> + +#include <android/log.h> + +using namespace tflite; +using namespace tflite::ops::builtin; + +static StderrReporter error_reporter; + +static std::unique_ptr<FlatBufferModel> model; + +inline void setText(JNIEnv *env, jobject thisObj, const std::string &message) +{ + jclass thisClass = env->GetObjectClass(thisObj); + jmethodID setTextMethod = env->GetMethodID(thisClass, "setText", "(Ljava/lang/String;)V"); + + assert(setTextMethod != nullptr); + + env->CallVoidMethod(thisObj, setTextMethod, env->NewStringUTF(message.c_str())); +} + +inline void setTitle(JNIEnv *env, jobject thisObj, const std::string &message) +{ + jclass thisClass = env->GetObjectClass(thisObj); + jmethodID setTextMethod = env->GetMethodID(thisClass, "setTitle", "(Ljava/lang/String;)V"); + + assert(setTextMethod != nullptr); + + env->CallVoidMethod(thisObj, setTextMethod, env->NewStringUTF(message.c_str())); + + // Clear message + setText(env, thisObj, ""); +} + +inline void setText(JNIEnv *env, jobject thisObj, const std::stringstream &ss) +{ + setText(env, thisObj, ss.str()); +} + +inline std::unique_ptr<FlatBufferModel> loadModel(JNIEnv *env, jobject thisObj, + jobject model_buffer) +{ + const char *model_base = static_cast<char *>(env->GetDirectBufferAddress(model_buffer)); + jlong model_size = env->GetDirectBufferCapacity(model_buffer); + + return FlatBufferModel::BuildFromBuffer(model_base, static_cast<size_t>(model_size), + &error_reporter); +} + +struct Activity +{ + virtual ~Activity() = default; + + virtual void prepare(void) const = 0; + virtual void run(void) const = 0; + virtual void teardown(void) const = 0; +}; + +struct LiteActivity final : public Activity +{ +public: + LiteActivity(nnfw::tflite::Session &sess) : _sess(sess) + { + // DO NOTHING + } + +public: + void prepare(void) const override { _sess.prepare(); } + void run(void) const override { _sess.run(); } + void teardown(void) const override { _sess.teardown(); } + +private: + nnfw::tflite::Session &_sess; +}; + +struct SimpleActivity final : public Activity +{ +public: + SimpleActivity(const std::function<void(void)> &fn) : _fn{fn} + { + // DO NOTHING + } + +public: + void prepare(void) const override {} + void run(void) const override { _fn(); } + void teardown(void) const override {} + +private: + std::function<void(void)> _fn; +}; + +inline void runBenchmark(JNIEnv *env, jobject thisObj, Activity &act) +{ + auto runTrial = [&](void) { + std::chrono::milliseconds elapsed(0); + + act.prepare(); + nnfw::misc::benchmark::measure(elapsed) << [&](void) { act.run(); }; + act.teardown(); + + return elapsed; + }; + + // Warm-up + for (uint32_t n = 0; n < 3; ++n) + { + auto elapsed = runTrial(); + + std::stringstream ss; + ss << "Warm-up #" << n << " takes " << elapsed.count() << "ms" << std::endl; + setText(env, thisObj, ss); + } + + // Measure + using namespace boost::accumulators; + + accumulator_set<double, stats<tag::mean, tag::min, tag::max>> acc; + + for (uint32_t n = 0; n < 100; ++n) + { + auto elapsed = runTrial(); + + std::stringstream ss; + ss << "Iteration #" << n << " takes " << elapsed.count() << "ms" << std::endl; + setText(env, thisObj, ss); + + acc(elapsed.count()); + } + + std::stringstream ss; + ss << "Average is " << mean(acc) << "ms" << std::endl; + ss << "Min is " << min(acc) << "ms" << std::endl; + ss << "Max is " << max(acc) << "ms" << std::endl; + setText(env, thisObj, ss); +} + +JNIEXPORT void JNICALL Java_com_ndk_tflbench_MainActivity_runInterpreterBenchmark( + JNIEnv *env, jobject thisObj, jobject model_buffer) +{ + setTitle(env, thisObj, "Running Interpreter Benchmark"); + + auto model = loadModel(env, thisObj, model_buffer); + assert(model != nullptr); + + nnfw::tflite::BuiltinOpResolver resolver; + InterpreterBuilder builder(*model, resolver); + + std::unique_ptr<Interpreter> interpreter; + + TFLITE_ENSURE(builder(&interpreter)); + + interpreter->SetNumThreads(-1); + + nnfw::tflite::InterpreterSession sess(interpreter.get()); + LiteActivity act{sess}; + runBenchmark(env, thisObj, act); +} + +static void runNNAPIBenchmark(JNIEnv *env, jobject thisObj, jobject model_buffer) +{ + auto model = loadModel(env, thisObj, model_buffer); + assert(model != nullptr); + + nnfw::tflite::BuiltinOpResolver resolver; + InterpreterBuilder builder(*model, resolver); + + std::unique_ptr<Interpreter> interpreter; + + TFLITE_ENSURE(builder(&interpreter)); + + nnfw::tflite::NNAPISession sess(interpreter.get()); + LiteActivity act{sess}; + runBenchmark(env, thisObj, act); +} + +JNIEXPORT void JNICALL Java_com_ndk_tflbench_MainActivity_runNNAPIBenchmark(JNIEnv *env, + jobject thisObj, + jobject model_buffer) +{ + setTitle(env, thisObj, "Running NNAPI Benchmark"); + + try + { + runNNAPIBenchmark(env, thisObj, model_buffer); + } + catch (const std::exception &ex) + { + std::stringstream ss; + ss << "Caught an exception " << ex.what(); + setText(env, thisObj, ss); + } +} + +JNIEXPORT jstring JNICALL Java_com_ndk_tflbench_MainActivity_getModelName(JNIEnv *env, + jobject thisObj) +{ + return env->NewStringUTF(MODEL_NAME); +} + +#define TF_ENSURE(e) \ + { \ + if (!(e).ok()) \ + { \ + throw std::runtime_error{"'" #e "' FAILED"}; \ + } \ + } |