Diffstat (limited to 'tests/tools/onert_run/src')
-rw-r--r-- | tests/tools/onert_run/src/args.cc      |  4
-rw-r--r-- | tests/tools/onert_run/src/args.h       |  4
-rw-r--r-- | tests/tools/onert_run/src/onert_run.cc | 59
3 files changed, 67 insertions, 0 deletions
diff --git a/tests/tools/onert_run/src/args.cc b/tests/tools/onert_run/src/args.cc
index 1e9d1aa69..a64d81db5 100644
--- a/tests/tools/onert_run/src/args.cc
+++ b/tests/tools/onert_run/src/args.cc
@@ -299,6 +299,10 @@ void Args::Initialize(void)
        "0: prints the only result. Messages btw run don't print\n"
        "1: prints result and message btw run\n"
        "2: prints all of messages to print\n")
+    ("quantize,q", po::value<std::string>()->default_value("")->notifier([&](const auto &v) { _quantize = v; }), "Request quantization with type (int8 or int16)")
+    ("qpath", po::value<std::string>()->default_value("")->notifier([&](const auto &v) { _quantized_model_path = v; }),
+     "Path to export quantized model.\n"
+     "If it is not set, the quantized model will be exported to the same directory of the original model/package with q8/q16 suffix.")
     ;
   // clang-format on

diff --git a/tests/tools/onert_run/src/args.h b/tests/tools/onert_run/src/args.h
index e35a761ed..97d9b1af1 100644
--- a/tests/tools/onert_run/src/args.h
+++ b/tests/tools/onert_run/src/args.h
@@ -69,6 +69,8 @@ public:
   /// @brief Return true if "--shape_run" or "--shape_prepare" is provided
   bool shapeParamProvided();
   const int getVerboseLevel(void) const { return _verbose_level; }
+  const std::string &getQuantize(void) const { return _quantize; }
+  const std::string &getQuantizedModelPath(void) const { return _quantized_model_path; }

 private:
   void Initialize();
@@ -99,6 +101,8 @@ private:
   bool _print_version = false;
   int _verbose_level;
   bool _use_single_model = false;
+  std::string _quantize;
+  std::string _quantized_model_path;
 };

 } // end of namespace onert_run

diff --git a/tests/tools/onert_run/src/onert_run.cc b/tests/tools/onert_run/src/onert_run.cc
index 5acb2bb64..0bc64bb2b 100644
--- a/tests/tools/onert_run/src/onert_run.cc
+++ b/tests/tools/onert_run/src/onert_run.cc
@@ -23,6 +23,7 @@
 #include "nnfw.h"
 #include "nnfw_util.h"
 #include "nnfw_internal.h"
+#include "nnfw_experimental.h"
 #include "randomgen.h"
 #include "rawformatter.h"
 #ifdef RUY_PROFILER
@@ -48,6 +49,33 @@ void overwriteShapeMap(onert_run::TensorShapeMap &shape_map,
     shape_map[i] = shapes[i];
 }

+std::string genQuantizedModelPathFromModelPath(const std::string &model_path, bool is_q16)
+{
+  auto const extension_pos = model_path.find(".circle");
+  if (extension_pos == std::string::npos)
+  {
+    std::cerr << "Input model isn't .circle." << std::endl;
+    exit(-1);
+  }
+  auto const qstring = std::string("_quantized_") + (is_q16 ? "q16" : "q8");
+  return model_path.substr(0, extension_pos) + qstring + ".circle";
+}
+
+std::string genQuantizedModelPathFromPackagePath(const std::string &package_path, bool is_q16)
+{
+  auto package_path_without_slash = package_path;
+  if (package_path_without_slash.back() == '/')
+    package_path_without_slash.pop_back();
+  auto package_name_pos = package_path_without_slash.find_last_of('/');
+  if (package_name_pos == std::string::npos)
+    package_name_pos = 0;
+  else
+    package_name_pos++;
+  auto package_name = package_path_without_slash.substr(package_name_pos);
+  auto const qstring = std::string("_quantized_") + (is_q16 ? "q16" : "q8");
+  return package_path_without_slash + "/" + package_name + qstring + ".circle";
+}
+
 int main(const int argc, char **argv)
 {
   using namespace onert_run;
@@ -85,6 +113,37 @@ int main(const int argc, char **argv)
       NNPR_ENSURE_STATUS(nnfw_load_model_from_file(session, args.getPackageFilename().c_str()));
     });

+  // Quantize model
+  auto quantize = args.getQuantize();
+  if (!quantize.empty())
+  {
+    NNFW_QUANTIZE_TYPE quantize_type = NNFW_QUANTIZE_TYPE_NOT_SET;
+    if (quantize == "int8")
+      quantize_type = NNFW_QUANTIZE_TYPE_U8_ASYM;
+    if (quantize == "int16")
+      quantize_type = NNFW_QUANTIZE_TYPE_I16_SYM;
+    NNPR_ENSURE_STATUS(nnfw_set_quantization_type(session, quantize_type));
+
+    if (args.getQuantizedModelPath() != "")
+      NNPR_ENSURE_STATUS(
+        nnfw_set_quantized_model_path(session, args.getQuantizedModelPath().c_str()));
+    else
+    {
+      if (args.useSingleModel())
+        NNPR_ENSURE_STATUS(nnfw_set_quantized_model_path(
+          session,
+          genQuantizedModelPathFromModelPath(args.getModelFilename(), quantize == "int16")
+            .c_str()));
+      else
+        NNPR_ENSURE_STATUS(nnfw_set_quantized_model_path(
+          session,
+          genQuantizedModelPathFromPackagePath(args.getPackageFilename(), quantize == "int16")
+            .c_str()));
+    }
+
+    NNPR_ENSURE_STATUS(nnfw_quantize(session));
+  }
+
   char *available_backends = std::getenv("BACKENDS");
   if (available_backends)
     NNPR_ENSURE_STATUS(nnfw_set_available_backends(session, available_backends));