Diffstat (limited to 'tests/tools')
-rw-r--r--  tests/tools/onert_run/src/args.cc                    4
-rw-r--r--  tests/tools/onert_run/src/args.h                     4
-rw-r--r--  tests/tools/onert_run/src/onert_run.cc              59
-rw-r--r--  tests/tools/onert_train/CMakeLists.txt              60
-rw-r--r--  tests/tools/onert_train/README.md                   13
-rw-r--r--  tests/tools/onert_train/src/allocation.h            46
-rw-r--r--  tests/tools/onert_train/src/args.cc                291
-rw-r--r--  tests/tools/onert_train/src/args.h                  92
-rw-r--r--  tests/tools/onert_train/src/formatter.h             47
-rw-r--r--  tests/tools/onert_train/src/h5formatter.cc         258
-rw-r--r--  tests/tools/onert_train/src/h5formatter.h           41
-rw-r--r--  tests/tools/onert_train/src/measure.h               90
-rw-r--r--  tests/tools/onert_train/src/nnfw_util.cc            49
-rw-r--r--  tests/tools/onert_train/src/nnfw_util.h             37
-rw-r--r--  tests/tools/onert_train/src/onert_train.cc         277
-rw-r--r--  tests/tools/onert_train/src/randomgen.cc            77
-rw-r--r--  tests/tools/onert_train/src/randomgen.h             40
-rw-r--r--  tests/tools/onert_train/src/rawdataloader.cc        77
-rw-r--r--  tests/tools/onert_train/src/rawdataloader.h         51
-rw-r--r--  tests/tools/onert_train/src/rawformatter.cc         97
-rw-r--r--  tests/tools/onert_train/src/rawformatter.h          40
-rw-r--r--  tests/tools/onert_train/src/types.h                 27
-rw-r--r--  tests/tools/onert_train/test/rawdataloader.test.cc 224
23 files changed, 2001 insertions, 0 deletions
diff --git a/tests/tools/onert_run/src/args.cc b/tests/tools/onert_run/src/args.cc
index 1e9d1aa69..a64d81db5 100644
--- a/tests/tools/onert_run/src/args.cc
+++ b/tests/tools/onert_run/src/args.cc
@@ -299,6 +299,10 @@ void Args::Initialize(void)
"0: prints the only result. Messages btw run don't print\n"
"1: prints result and message btw run\n"
"2: prints all of messages to print\n")
+ ("quantize,q", po::value<std::string>()->default_value("")->notifier([&](const auto &v) { _quantize = v; }), "Request quantization with type (int8 or int16)")
+ ("qpath", po::value<std::string>()->default_value("")->notifier([&](const auto &v) { _quantized_model_path = v; }),
+ "Path to export quantized model.\n"
+ "If it is not set, the quantized model will be exported to the same directory of the original model/package with q8/q16 suffix.")
;
// clang-format on
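A hedged example of the new options in use (the model filename and invocation shape are illustrative; only --quantize/-q and --qpath come from this change):

    $ onert_run --quantize int8 mobilenet.circle
    $ onert_run -q int16 --qpath ./out/mobilenet.q16.circle mobilenet.circle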
diff --git a/tests/tools/onert_run/src/args.h b/tests/tools/onert_run/src/args.h
index e35a761ed..97d9b1af1 100644
--- a/tests/tools/onert_run/src/args.h
+++ b/tests/tools/onert_run/src/args.h
@@ -69,6 +69,8 @@ public:
/// @brief Return true if "--shape_run" or "--shape_prepare" is provided
bool shapeParamProvided();
const int getVerboseLevel(void) const { return _verbose_level; }
+ const std::string &getQuantize(void) const { return _quantize; }
+ const std::string &getQuantizedModelPath(void) const { return _quantized_model_path; }
private:
void Initialize();
@@ -99,6 +101,8 @@ private:
bool _print_version = false;
int _verbose_level;
bool _use_single_model = false;
+ std::string _quantize;
+ std::string _quantized_model_path;
};
} // end of namespace onert_run
diff --git a/tests/tools/onert_run/src/onert_run.cc b/tests/tools/onert_run/src/onert_run.cc
index 5acb2bb64..0bc64bb2b 100644
--- a/tests/tools/onert_run/src/onert_run.cc
+++ b/tests/tools/onert_run/src/onert_run.cc
@@ -23,6 +23,7 @@
#include "nnfw.h"
#include "nnfw_util.h"
#include "nnfw_internal.h"
+#include "nnfw_experimental.h"
#include "randomgen.h"
#include "rawformatter.h"
#ifdef RUY_PROFILER
@@ -48,6 +49,33 @@ void overwriteShapeMap(onert_run::TensorShapeMap &shape_map,
shape_map[i] = shapes[i];
}
+std::string genQuantizedModelPathFromModelPath(const std::string &model_path, bool is_q16)
+{
+ auto const extension_pos = model_path.find(".circle");
+ if (extension_pos == std::string::npos)
+ {
+ std::cerr << "Input model isn't .circle." << std::endl;
+ exit(-1);
+ }
+ auto const qstring = std::string("_quantized_") + (is_q16 ? "q16" : "q8");
+ return model_path.substr(0, extension_pos) + qstring + ".circle";
+}
+
+std::string genQuantizedModelPathFromPackagePath(const std::string &package_path, bool is_q16)
+{
+ auto package_path_without_slash = package_path;
+ if (package_path_without_slash.back() == '/')
+ package_path_without_slash.pop_back();
+ auto package_name_pos = package_path_without_slash.find_last_of('/');
+ if (package_name_pos == std::string::npos)
+ package_name_pos = 0;
+ else
+ package_name_pos++;
+ auto package_name = package_path_without_slash.substr(package_name_pos);
+ auto const qstring = std::string("_quantized_") + (is_q16 ? "q16" : "q8");
+ return package_path_without_slash + "/" + package_name + qstring + ".circle";
+}
+
int main(const int argc, char **argv)
{
using namespace onert_run;
@@ -85,6 +113,37 @@ int main(const int argc, char **argv)
NNPR_ENSURE_STATUS(nnfw_load_model_from_file(session, args.getPackageFilename().c_str()));
});
+ // Quantize model
+ auto quantize = args.getQuantize();
+ if (!quantize.empty())
+ {
+ NNFW_QUANTIZE_TYPE quantize_type = NNFW_QUANTIZE_TYPE_NOT_SET;
+ if (quantize == "int8")
+ quantize_type = NNFW_QUANTIZE_TYPE_U8_ASYM;
+ if (quantize == "int16")
+ quantize_type = NNFW_QUANTIZE_TYPE_I16_SYM;
+ NNPR_ENSURE_STATUS(nnfw_set_quantization_type(session, quantize_type));
+
+ if (args.getQuantizedModelPath() != "")
+ NNPR_ENSURE_STATUS(
+ nnfw_set_quantized_model_path(session, args.getQuantizedModelPath().c_str()));
+ else
+ {
+ if (args.useSingleModel())
+ NNPR_ENSURE_STATUS(nnfw_set_quantized_model_path(
+ session,
+ genQuantizedModelPathFromModelPath(args.getModelFilename(), quantize == "int16")
+ .c_str()));
+ else
+ NNPR_ENSURE_STATUS(nnfw_set_quantized_model_path(
+ session,
+ genQuantizedModelPathFromPackagePath(args.getPackageFilename(), quantize == "int16")
+ .c_str()));
+ }
+
+ NNPR_ENSURE_STATUS(nnfw_quantize(session));
+ }
+
char *available_backends = std::getenv("BACKENDS");
if (available_backends)
NNPR_ENSURE_STATUS(nnfw_set_available_backends(session, available_backends));
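For reference, a minimal standalone sketch of the default-path derivation above (hypothetical helper name, standard library only; the real tool prints an error and exits when the ".circle" extension is missing):

    #include <iostream>
    #include <string>

    // Mirrors genQuantizedModelPathFromModelPath: inserts a "_quantized_q8" /
    // "_quantized_q16" suffix just before the ".circle" extension.
    std::string defaultQuantizedPath(const std::string &model_path, bool is_q16)
    {
      auto pos = model_path.find(".circle");
      if (pos == std::string::npos)
        return ""; // onert_run itself exits with an error in this case
      auto suffix = std::string("_quantized_") + (is_q16 ? "q16" : "q8");
      return model_path.substr(0, pos) + suffix + ".circle";
    }

    int main()
    {
      std::cout << defaultQuantizedPath("mobilenet.circle", false) << "\n"; // mobilenet_quantized_q8.circle
      std::cout << defaultQuantizedPath("mobilenet.circle", true) << "\n";  // mobilenet_quantized_q16.circle
    }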
diff --git a/tests/tools/onert_train/CMakeLists.txt b/tests/tools/onert_train/CMakeLists.txt
new file mode 100644
index 000000000..f047b2ad0
--- /dev/null
+++ b/tests/tools/onert_train/CMakeLists.txt
@@ -0,0 +1,60 @@
+if(NOT BUILD_ONERT_TRAIN)
+ return()
+endif(NOT BUILD_ONERT_TRAIN)
+
+if(NOT BUILD_ONERT)
+ return()
+endif(NOT BUILD_ONERT)
+
+list(APPEND ONERT_TRAIN_SRCS "src/onert_train.cc")
+list(APPEND ONERT_TRAIN_SRCS "src/args.cc")
+list(APPEND ONERT_TRAIN_SRCS "src/nnfw_util.cc")
+list(APPEND ONERT_TRAIN_SRCS "src/randomgen.cc")
+list(APPEND ONERT_TRAIN_SRCS "src/rawformatter.cc")
+list(APPEND ONERT_TRAIN_SRCS "src/rawdataloader.cc")
+
+nnfw_find_package(Boost REQUIRED program_options)
+nnfw_find_package(HDF5 QUIET)
+
+if (HDF5_FOUND)
+ list(APPEND ONERT_TRAIN_SRCS "src/h5formatter.cc")
+endif()
+
+add_executable(onert_train ${ONERT_TRAIN_SRCS})
+
+if (HDF5_FOUND)
+ target_compile_definitions(onert_train PRIVATE ONERT_HAVE_HDF5=1)
+ target_include_directories(onert_train PRIVATE ${HDF5_INCLUDE_DIRS})
+ target_link_libraries(onert_train ${HDF5_CXX_LIBRARIES})
+else()
+ message(WARNING "HDF5 NOT found. Install libhdf5-dev or set EXT_HDF5_DIR to support load/dump in onert_train.")
+endif(HDF5_FOUND)
+
+target_include_directories(onert_train PRIVATE src)
+target_include_directories(onert_train PRIVATE ${Boost_INCLUDE_DIRS})
+
+target_link_libraries(onert_train nnfw_lib_tflite jsoncpp)
+target_link_libraries(onert_train nnfw-dev)
+target_link_libraries(onert_train ${Boost_PROGRAM_OPTIONS_LIBRARY})
+target_link_libraries(onert_train nnfw_lib_benchmark)
+
+install(TARGETS onert_train DESTINATION bin)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+# Unit Tests
+set(TEST_ONERT_TRAIN test_onert_train)
+
+file(GLOB_RECURSE ONERT_TRAIN_TEST_SRCS "test/*.cc")
+list(APPEND ONERT_TRAIN_TEST_SRCS "src/rawdataloader.cc")
+list(APPEND ONERT_TRAIN_TEST_SRCS "src/nnfw_util.cc")
+
+add_executable(${TEST_ONERT_TRAIN} ${ONERT_TRAIN_TEST_SRCS})
+
+target_link_libraries(${TEST_ONERT_TRAIN} nnfw-dev)
+target_link_libraries(${TEST_ONERT_TRAIN} gtest gtest_main dl ${LIB_PTHREAD})
+
+add_test(${TEST_ONERT_TRAIN} ${TEST_ONERT_TRAIN})
+install(TARGETS ${TEST_ONERT_TRAIN} DESTINATION unittest)
diff --git a/tests/tools/onert_train/README.md b/tests/tools/onert_train/README.md
new file mode 100644
index 000000000..a201237f6
--- /dev/null
+++ b/tests/tools/onert_train/README.md
@@ -0,0 +1,13 @@
+# onert_train
+
+`onert_train` is a tool for training AI models. It trains a given model using user-provided input data and expected output data, then stores the trained model or runs inference with it.
+
+This tool supports the following input model formats:
+- circle
+- nnpackage
+
+## Usage
+
+### Simple train
+
+### Simple inference with the trained model
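A hedged example invocation for the empty sections above (flag names are taken from src/args.cc in this change; the model and data files are placeholders):

    $ onert_train model.circle \
        --load_input:raw input.bin --load_expected:raw expected.bin \
        --data_length 1000 --epoch 5 --batch_size 32 \
        --learning_rate 0.001 --loss 1 --optimizer 1

Here --loss 1 selects categorical cross-entropy and --optimizer 1 selects Adam, per the option help in src/args.cc below.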
diff --git a/tests/tools/onert_train/src/allocation.h b/tests/tools/onert_train/src/allocation.h
new file mode 100644
index 000000000..f5a6aa73b
--- /dev/null
+++ b/tests/tools/onert_train/src/allocation.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_ALLOCATION_H__
+#define __ONERT_TRAIN_ALLOCATION_H__
+
+#include <cstdlib>
+#include <cstdint>
+
+namespace onert_train
+{
+class Allocation
+{
+public:
+ Allocation() : data_(nullptr) {}
+ ~Allocation() { free(data_); }
+ void *data() const { return data_; }
+ void *alloc(uint64_t sz)
+ {
+ if (data_)
+ {
+ free(data_);
+ }
+
+ return data_ = malloc(sz);
+ }
+
+private:
+ void *data_;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_ALLOCATION_H__
diff --git a/tests/tools/onert_train/src/args.cc b/tests/tools/onert_train/src/args.cc
new file mode 100644
index 000000000..dbdd384b5
--- /dev/null
+++ b/tests/tools/onert_train/src/args.cc
@@ -0,0 +1,291 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "args.h"
+
+#include <functional>
+#include <iostream>
+#include <sys/stat.h>
+#include <json/json.h>
+
+namespace
+{
+
+// This function parses a JSON array of alternating keys and values and returns it
+// as a map from unsigned-integer keys to Json::Value values.
+// For example,
+// [0, [1, 2, 3, 4], 3, 40, 4, []] in JSON
+// is converted to:
+// {
+// 0 -> [1, 2, 3, 4]
+// 3 -> 40
+// 4 -> []
+// } in std::unordered_map. Note that the value type is still Json::Value.
+std::unordered_map<uint32_t, Json::Value> argArrayToMap(const Json::Value &jsonval)
+{
+ if (!jsonval.isArray() || (jsonval.size() % 2 != 0))
+ {
+ std::cerr << "JSON argument must be an even-sized array in JSON\n";
+ exit(1);
+ }
+
+ std::unordered_map<uint32_t, Json::Value> ret;
+ for (uint32_t i = 0; i < jsonval.size(); i += 2)
+ {
+ if (!jsonval[i].isUInt())
+ {
+ std::cerr << "Key values(values in even indices) must be unsigned integers\n";
+ exit(1);
+ }
+ uint32_t key = jsonval[i].asUInt();
+    ret[key] = jsonval[i + 1];
+ }
+ return ret;
+}
+
+void checkModelfile(const std::string &model_filename)
+{
+ if (model_filename.empty())
+ {
+ // TODO Print usage instead of the below message
+ std::cerr << "Please specify model file. Run with `--help` for usage."
+ << "\n";
+
+ exit(1);
+ }
+ else
+ {
+ if (access(model_filename.c_str(), F_OK) == -1)
+ {
+ std::cerr << "Model file not found: " << model_filename << "\n";
+ exit(1);
+ }
+ }
+}
+
+void checkPackage(const std::string &package_filename)
+{
+ if (package_filename.empty())
+ {
+ // TODO Print usage instead of the below message
+ std::cerr << "Please specify nnpackage file. Run with `--help` for usage."
+ << "\n";
+
+ exit(1);
+ }
+ else
+ {
+ if (access(package_filename.c_str(), F_OK) == -1)
+ {
+ std::cerr << "nnpackage not found: " << package_filename << "\n";
+ exit(1);
+ }
+ }
+}
+
+} // namespace
+
+namespace onert_train
+{
+
+Args::Args(const int argc, char **argv)
+{
+ Initialize();
+ Parse(argc, argv);
+}
+
+void Args::Initialize(void)
+{
+ auto process_nnpackage = [&](const std::string &package_filename) {
+ _package_filename = package_filename;
+
+ std::cerr << "Package Filename " << _package_filename << std::endl;
+ checkPackage(package_filename);
+ };
+
+ auto process_modelfile = [&](const std::string &model_filename) {
+ _model_filename = model_filename;
+
+ std::cerr << "Model Filename " << _model_filename << std::endl;
+ checkModelfile(model_filename);
+
+ _use_single_model = true;
+ };
+
+ auto process_path = [&](const std::string &path) {
+ struct stat sb;
+ if (stat(path.c_str(), &sb) == 0)
+ {
+ if (sb.st_mode & S_IFDIR)
+ {
+ _package_filename = path;
+ checkPackage(path);
+ std::cerr << "Package Filename " << path << std::endl;
+ }
+ else
+ {
+ _model_filename = path;
+ checkModelfile(path);
+ std::cerr << "Model Filename " << path << std::endl;
+ _use_single_model = true;
+ }
+ }
+ else
+ {
+ std::cerr << "Cannot find: " << path << "\n";
+ exit(1);
+ }
+ };
+
+ auto process_load_raw_inputfile = [&](const std::string &input_filename) {
+ _load_raw_input_filename = input_filename;
+
+ std::cerr << "Model Input Filename " << _load_raw_input_filename << std::endl;
+ checkModelfile(_load_raw_input_filename);
+ };
+
+ auto process_load_raw_expectedfile = [&](const std::string &expected_filename) {
+ _load_raw_expected_filename = expected_filename;
+
+ std::cerr << "Model Expected Filename " << _load_raw_expected_filename << std::endl;
+ checkModelfile(_load_raw_expected_filename);
+ };
+
+ auto process_output_sizes = [&](const std::string &output_sizes_json_str) {
+ Json::Value root;
+ Json::Reader reader;
+ if (!reader.parse(output_sizes_json_str, root, false))
+ {
+ std::cerr << "Invalid JSON format for output_sizes \"" << output_sizes_json_str << "\"\n";
+ exit(1);
+ }
+
+ auto arg_map = argArrayToMap(root);
+ for (auto &pair : arg_map)
+ {
+ uint32_t key = pair.first;
+ Json::Value &val_json = pair.second;
+ if (!val_json.isUInt())
+ {
+ std::cerr << "All the values in `output_sizes` must be unsigned integers\n";
+ exit(1);
+ }
+ uint32_t val = val_json.asUInt();
+ _output_sizes[key] = val;
+ }
+ };
+
+ // General options
+ po::options_description general("General options", 100);
+
+ // clang-format off
+ general.add_options()
+ ("help,h", "Print available options")
+ ("version", "Print version and exit immediately")
+ ("nnpackage", po::value<std::string>()->notifier(process_nnpackage), "NN Package file(directory) name")
+ ("modelfile", po::value<std::string>()->notifier(process_modelfile), "NN Model filename")
+ ("path", po::value<std::string>()->notifier(process_path), "NN Package or NN Modelfile path")
+ ("data_length", po::value<int>()->default_value(-1)->notifier([&](const auto &v) { _data_length = v; }), "Data length number")
+ ("load_input:raw", po::value<std::string>()->notifier(process_load_raw_inputfile),
+ "NN Model Raw Input data file\n"
+ "The datafile must have data for each input number.\n"
+ "If there are 3 inputs, the data of input0 must exist as much as data_length, "
+ "and the data for input1 and input2 must be held sequentially as data_length.\n"
+ )
+ ("load_expected:raw", po::value<std::string>()->notifier(process_load_raw_expectedfile),
+ "NN Model Raw Expected data file\n"
+ "(Same data policy with load_input:raw)\n"
+ )
+ ("mem_poll,m", po::value<bool>()->default_value(false)->notifier([&](const auto &v) { _mem_poll = v; }), "Check memory polling")
+ ("epoch", po::value<int>()->default_value(5)->notifier([&](const auto &v) { _epoch = v; }), "Epoch number (default: 5)")
+ ("batch_size", po::value<int>()->default_value(32)->notifier([&](const auto &v) { _batch_size = v; }), "Batch size (default: 32)")
+ ("learning_rate", po::value<float>()->default_value(1.0e-4)->notifier([&](const auto &v) { _learning_rate = v; }), "Learning rate (default: 1.0e-4)")
+ ("loss", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _loss_type = v; }),
+ "Loss type\n"
+ "0: MEAN_SQUARED_ERROR (default)\n"
+ "1: CATEGORICAL_CROSSENTROPY\n")
+ ("optimizer", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _optimizer_type = v; }),
+ "Optimizer type\n"
+ "0: SGD (default)\n"
+ "1: Adam\n")
+ ("verbose_level,v", po::value<int>()->default_value(0)->notifier([&](const auto &v) { _verbose_level = v; }),
+ "Verbose level\n"
+ "0: prints the only result. Messages btw run don't print\n"
+ "1: prints result and message btw run\n"
+ "2: prints all of messages to print\n")
+ ("output_sizes", po::value<std::string>()->notifier(process_output_sizes),
+ "The output buffer size in JSON 1D array\n"
+ "If not given, the model's output sizes are used\n"
+ "e.g. '[0, 40, 2, 80]' to set 0th tensor to 40 and 2nd tensor to 80.\n")
+ ;
+ // clang-format on
+
+ _options.add(general);
+ _positional.add("path", -1);
+}
+
+void Args::Parse(const int argc, char **argv)
+{
+ po::variables_map vm;
+ po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
+ vm);
+
+ if (vm.count("help"))
+ {
+ std::cout << "onert_train\n\n";
+ std::cout << "Usage: " << argv[0] << "[model path] [<options>]\n\n";
+ std::cout << _options;
+ std::cout << "\n";
+
+ exit(0);
+ }
+
+ if (vm.count("version"))
+ {
+ _print_version = true;
+ return;
+ }
+
+ {
+ auto conflicting_options = [&](const std::string &o1, const std::string &o2) {
+ if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted()))
+ {
+ throw boost::program_options::error(std::string("Two options '") + o1 + "' and '" + o2 +
+ "' cannot be given at once.");
+ }
+ };
+
+ // Cannot use both single model file and nnpackage at once
+ conflicting_options("modelfile", "nnpackage");
+
+ // Require modelfile, nnpackage, or path
+ if (!vm.count("modelfile") && !vm.count("nnpackage") && !vm.count("path"))
+      throw boost::program_options::error(
+        std::string("One of the options 'modelfile', 'nnpackage', or 'path' is required."));
+ }
+
+ try
+ {
+ po::notify(vm);
+ }
+ catch (const std::bad_cast &e)
+ {
+ std::cerr << "Bad cast error - " << e.what() << '\n';
+ exit(1);
+ }
+}
+
+} // end of namespace onert_train
diff --git a/tests/tools/onert_train/src/args.h b/tests/tools/onert_train/src/args.h
new file mode 100644
index 000000000..cbd87e111
--- /dev/null
+++ b/tests/tools/onert_train/src/args.h
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_ARGS_H__
+#define __ONERT_TRAIN_ARGS_H__
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include <boost/program_options.hpp>
+
+#include "types.h"
+
+namespace po = boost::program_options;
+
+namespace onert_train
+{
+
+using TensorShapeMap = std::unordered_map<uint32_t, TensorShape>;
+
+#if defined(ONERT_HAVE_HDF5) && ONERT_HAVE_HDF5 == 1
+enum class WhenToUseH5Shape
+{
+ NOT_PROVIDED, // Param not provided
+ PREPARE, // read shapes in h5 file and set them as inputs' shape before calling nnfw_prepare()
+ RUN, // read shapes in h5 file and set them as inputs' shape before calling nnfw_run()
+};
+#endif
+
+class Args
+{
+public:
+ Args(const int argc, char **argv);
+ void print(void);
+
+ const std::string &getPackageFilename(void) const { return _package_filename; }
+ const std::string &getModelFilename(void) const { return _model_filename; }
+ const bool useSingleModel(void) const { return _use_single_model; }
+ const int getDataLength(void) const { return _data_length; }
+ const std::string &getLoadRawInputFilename(void) const { return _load_raw_input_filename; }
+ const std::string &getLoadRawExpectedFilename(void) const { return _load_raw_expected_filename; }
+ const bool getMemoryPoll(void) const { return _mem_poll; }
+ const int getEpoch(void) const { return _epoch; }
+ const int getBatchSize(void) const { return _batch_size; }
+ const float getLearningRate(void) const { return _learning_rate; }
+ const int getLossType(void) const { return _loss_type; }
+ const int getOptimizerType(void) const { return _optimizer_type; }
+ const bool printVersion(void) const { return _print_version; }
+ const int getVerboseLevel(void) const { return _verbose_level; }
+ std::unordered_map<uint32_t, uint32_t> getOutputSizes(void) const { return _output_sizes; }
+
+private:
+ void Initialize();
+ void Parse(const int argc, char **argv);
+
+private:
+ po::positional_options_description _positional;
+ po::options_description _options;
+
+ std::string _package_filename;
+ std::string _model_filename;
+ bool _use_single_model = false;
+ int _data_length;
+ std::string _load_raw_input_filename;
+ std::string _load_raw_expected_filename;
+ bool _mem_poll;
+ int _epoch;
+ int _batch_size;
+ float _learning_rate;
+ int _loss_type;
+ int _optimizer_type;
+ bool _print_version = false;
+ int _verbose_level;
+ std::unordered_map<uint32_t, uint32_t> _output_sizes;
+};
+
+} // end of namespace onert_train
+
+#endif // __ONERT_TRAIN_ARGS_H__
diff --git a/tests/tools/onert_train/src/formatter.h b/tests/tools/onert_train/src/formatter.h
new file mode 100644
index 000000000..6d256804e
--- /dev/null
+++ b/tests/tools/onert_train/src/formatter.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_FORMATTER_H__
+#define __ONERT_TRAIN_FORMATTER_H__
+
+#include <string>
+#include <vector>
+
+#include "types.h"
+#include "allocation.h"
+
+struct nnfw_session;
+
+namespace onert_train
+{
+class Formatter
+{
+public:
+ virtual ~Formatter() = default;
+ Formatter(nnfw_session *sess) : session_(sess) {}
+ virtual void loadInputs(const std::string &filename, std::vector<Allocation> &inputs) = 0;
+ virtual void dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs) = 0;
+ virtual std::vector<TensorShape> readTensorShapes(const std::string &filename)
+ {
+ return std::vector<TensorShape>();
+ };
+
+protected:
+ nnfw_session *session_;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_FORMATTER_H__
diff --git a/tests/tools/onert_train/src/h5formatter.cc b/tests/tools/onert_train/src/h5formatter.cc
new file mode 100644
index 000000000..12c570b5d
--- /dev/null
+++ b/tests/tools/onert_train/src/h5formatter.cc
@@ -0,0 +1,258 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "h5formatter.h"
+#include "nnfw.h"
+#include "nnfw_util.h"
+
+#include <iostream>
+#include <stdexcept>
+#include <H5Cpp.h>
+
+namespace
+{
+onert_train::TensorShape getShape(H5::DataSet &data_set)
+{
+ std::vector<hsize_t> h5_shape; // hsize_t is unsigned long long
+ H5::DataSpace data_space = data_set.getSpace();
+ int rank = data_space.getSimpleExtentNdims();
+ h5_shape.resize(rank);
+
+ // read shape info from H5 file
+ data_space.getSimpleExtentDims(h5_shape.data(), NULL);
+
+ onert_train::TensorShape shape;
+ for (auto dim : h5_shape)
+ shape.emplace_back(static_cast<int>(dim));
+
+ return shape;
+}
+} // namespace
+
+namespace onert_train
+{
+static const char *h5_value_grpname = "value";
+
+std::vector<TensorShape> H5Formatter::readTensorShapes(const std::string &filename)
+{
+ uint32_t num_inputs;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session_, &num_inputs));
+ std::vector<TensorShape> tensor_shapes;
+
+ try
+ {
+ H5::Exception::dontPrint();
+
+ H5::H5File file(filename, H5F_ACC_RDONLY);
+ H5::Group value_group = file.openGroup(h5_value_grpname);
+
+ // Constraints: if there are n data set names, they should be unique and
+ // one of [ "0", "1", .. , "n-1" ]
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ H5::DataSet data_set = value_group.openDataSet(std::to_string(i));
+ H5::DataType type = data_set.getDataType();
+ auto shape = getShape(data_set);
+
+ tensor_shapes.emplace_back(shape);
+ }
+
+ return tensor_shapes;
+ }
+ catch (const H5::Exception &e)
+ {
+ H5::Exception::printErrorStack();
+ std::exit(-1);
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::exit(-1);
+ }
+}
+
+void H5Formatter::loadInputs(const std::string &filename, std::vector<Allocation> &inputs)
+{
+ uint32_t num_inputs;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session_, &num_inputs));
+ try
+ {
+ // Turn off the automatic error printing.
+ H5::Exception::dontPrint();
+
+ H5::H5File file(filename, H5F_ACC_RDONLY);
+ H5::Group value_group = file.openGroup(h5_value_grpname);
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
+
+ // TODO Add Assert(nnfw shape, h5 file shape size)
+
+ // allocate memory for data
+ auto bufsz = bufsize_for(&ti);
+ inputs[i].alloc(bufsz);
+
+ H5::DataSet data_set = value_group.openDataSet(std::to_string(i));
+ H5::DataType type = data_set.getDataType();
+ switch (ti.dtype)
+ {
+ case NNFW_TYPE_TENSOR_FLOAT32:
+ if (type == H5::PredType::IEEE_F32BE || type == H5::PredType::IEEE_F32LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_FLOAT);
+ else
+ throw std::runtime_error("model input type is f32. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_INT32:
+ if (type == H5::PredType::STD_I32BE || type == H5::PredType::STD_I32LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_INT32);
+ else
+ throw std::runtime_error("model input type is i32. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_INT64:
+ if (type == H5::PredType::STD_I64BE || type == H5::PredType::STD_I64LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_INT64);
+ else
+ throw std::runtime_error("model input type is i64. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ case NNFW_TYPE_TENSOR_BOOL:
+ case NNFW_TYPE_TENSOR_UINT8:
+ if (type == H5::PredType::STD_U8BE || type == H5::PredType::STD_U8LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_UINT8);
+ else
+ throw std::runtime_error(
+ "model input type is qasymm8, bool or uint8. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
+ if (type == H5::PredType::STD_I8BE || type == H5::PredType::STD_I8LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_INT8);
+ else
+ throw std::runtime_error("model input type is int8. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
+ throw std::runtime_error("NYI for NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED type");
+ default:
+ throw std::runtime_error("onert_run can load f32, i32, qasymm8, bool and uint8.");
+ }
+ NNPR_ENSURE_STATUS(nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), bufsz));
+ NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
+ }
+ }
+ catch (const H5::Exception &e)
+ {
+ H5::Exception::printErrorStack();
+ std::exit(-1);
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::exit(-1);
+ }
+};
+
+void H5Formatter::dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs)
+{
+ uint32_t num_outputs;
+ NNPR_ENSURE_STATUS(nnfw_output_size(session_, &num_outputs));
+ try
+ {
+ // Turn off the automatic error printing.
+ H5::Exception::dontPrint();
+
+ H5::H5File file(filename, H5F_ACC_TRUNC);
+ H5::Group value_group = file.createGroup(h5_value_grpname);
+ for (uint32_t i = 0; i < num_outputs; i++)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session_, i, &ti));
+ std::vector<hsize_t> dims(ti.rank);
+ for (uint32_t j = 0; j < ti.rank; ++j)
+ {
+ if (ti.dims[j] >= 0)
+ dims[j] = static_cast<hsize_t>(ti.dims[j]);
+ else
+ {
+ std::cerr << "Negative dimension in output tensor" << std::endl;
+ exit(-1);
+ }
+ }
+ H5::DataSpace data_space(ti.rank, dims.data());
+ switch (ti.dtype)
+ {
+ case NNFW_TYPE_TENSOR_FLOAT32:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::IEEE_F32BE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_FLOAT);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_INT32:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_I32LE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_INT32);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_INT64:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_I64LE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_INT64);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_UINT8:
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_U8BE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_UINT8);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_BOOL:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_U8LE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_INT8);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_I8LE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_INT8);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
+ throw std::runtime_error("NYI for NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED type");
+ default:
+ throw std::runtime_error("onert_run can dump f32, i32, qasymm8, bool and uint8.");
+ }
+ }
+ }
+ catch (const H5::Exception &e)
+ {
+ H5::Exception::printErrorStack();
+ std::exit(-1);
+ }
+ catch (const std::runtime_error &e)
+ {
+ std::cerr << "Error during dumpOutputs on onert_run : " << e.what() << std::endl;
+ std::exit(-1);
+ }
+};
+
+} // end of namespace onert_train
diff --git a/tests/tools/onert_train/src/h5formatter.h b/tests/tools/onert_train/src/h5formatter.h
new file mode 100644
index 000000000..21ef16526
--- /dev/null
+++ b/tests/tools/onert_train/src/h5formatter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_H5FORMATTER_H__
+#define __ONERT_TRAIN_H5FORMATTER_H__
+
+#include "allocation.h"
+#include "formatter.h"
+#include "types.h"
+
+#include <string>
+#include <vector>
+
+struct nnfw_session;
+
+namespace onert_train
+{
+class H5Formatter : public Formatter
+{
+public:
+ H5Formatter(nnfw_session *sess) : Formatter(sess) {}
+ std::vector<TensorShape> readTensorShapes(const std::string &filename) override;
+ void loadInputs(const std::string &filename, std::vector<Allocation> &inputs) override;
+ void dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs) override;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_H5FORMATTER_H__
diff --git a/tests/tools/onert_train/src/measure.h b/tests/tools/onert_train/src/measure.h
new file mode 100644
index 000000000..f7c8610d0
--- /dev/null
+++ b/tests/tools/onert_train/src/measure.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_MEASURE_H__
+#define __ONERT_TRAIN_MEASURE_H__
+
+#include <algorithm>
+#include <cstdint>
+#include <ctime>
+#include <functional>
+#include <stdexcept>
+#include <vector>
+
+namespace
+{
+uint64_t nowMicros()
+{
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return static_cast<uint64_t>(ts.tv_nsec) / 1e3 + static_cast<uint64_t>(ts.tv_sec) * 1e6;
+}
+} // namespace
+
+namespace onert_train
+{
+
+struct Step
+{
+ uint64_t time; // us
+ // TODO Support memory usage
+};
+
+class Measure
+{
+public:
+ Measure() = default;
+
+ void set(const int epoch, const int step)
+ {
+ _results.clear();
+ _results.resize(epoch);
+ std::for_each(_results.begin(), _results.end(), [step](auto &v) { v.resize(step); });
+ }
+
+ void run(const int epoch, const int step, const std::function<void()> &func)
+ {
+ if (_results.empty() || _results.size() <= epoch || _results[epoch].size() <= step)
+ {
+ throw std::runtime_error("Please set the number of epochs and steps first");
+ }
+
+ _results[epoch][step].time = nowMicros();
+
+ func();
+
+ _results[epoch][step].time = nowMicros() - _results[epoch][step].time;
+ }
+
+ double timeMicros(const int epoch)
+ {
+ if (_results.empty() || _results.size() <= epoch)
+ {
+ throw std::runtime_error("Invalid epoch");
+ }
+
+    double sum = 0.0;
+ std::for_each(_results[epoch].begin(), _results[epoch].end(),
+ [&sum](auto &v) { sum += v.time; });
+ return sum / _results[epoch].size();
+ }
+
+ double timeMs(const int epoch) { return timeMicros(epoch) / 1e3; }
+
+private:
+ std::vector<std::vector<Step>> _results;
+};
+
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_MEASURE_H__
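A minimal usage sketch for Measure, following the same set/run/timeMs call pattern that src/onert_train.cc uses below (the timed body is a placeholder, and the header is assumed to be on the include path):

    #include "measure.h"

    #include <iostream>

    int main()
    {
      onert_train::Measure measure;
      const int epochs = 2, steps = 10;
      measure.set(epochs, steps); // must precede run(), or run() throws
      for (int e = 0; e < epochs; ++e)
        for (int s = 0; s < steps; ++s)
          measure.run(e, s, [] { /* one training step */ });
      std::cout << measure.timeMs(0) << " ms/step (epoch 0 average)\n";
    }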
diff --git a/tests/tools/onert_train/src/nnfw_util.cc b/tests/tools/onert_train/src/nnfw_util.cc
new file mode 100644
index 000000000..8dd2aa871
--- /dev/null
+++ b/tests/tools/onert_train/src/nnfw_util.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cassert>
+#include <string>
+#include "nnfw.h"
+
+namespace onert_train
+{
+uint64_t num_elems(const nnfw_tensorinfo *ti)
+{
+ uint64_t n = 1;
+ for (uint32_t i = 0; i < ti->rank; ++i)
+ {
+ assert(ti->dims[i] >= 0);
+ n *= ti->dims[i];
+ }
+ return n;
+}
+
+uint64_t bufsize_for(const nnfw_tensorinfo *ti)
+{
+ static int elmsize[] = {
+ sizeof(float), /* NNFW_TYPE_TENSOR_FLOAT32 */
+ sizeof(int), /* NNFW_TYPE_TENSOR_INT32 */
+ sizeof(uint8_t), /* NNFW_TYPE_TENSOR_QUANT8_ASYMM */
+ sizeof(bool), /* NNFW_TYPE_TENSOR_BOOL = 3 */
+ sizeof(uint8_t), /* NNFW_TYPE_TENSOR_UINT8 = 4 */
+ sizeof(int64_t), /* NNFW_TYPE_TENSOR_INT64 = 5 */
+ sizeof(int8_t), /* NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED = 6 */
+ sizeof(int16_t), /* NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED = 7 */
+ };
+ return elmsize[ti->dtype] * num_elems(ti);
+}
+
+} // namespace onert_train
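As a worked example (shape hypothetical): for an NNFW_TYPE_TENSOR_FLOAT32 tensor of shape [32, 224, 224, 3], num_elems returns 32 * 224 * 224 * 3 = 4,816,896, so bufsize_for returns 4,816,896 * sizeof(float) = 19,267,584 bytes.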
diff --git a/tests/tools/onert_train/src/nnfw_util.h b/tests/tools/onert_train/src/nnfw_util.h
new file mode 100644
index 000000000..674e18fb2
--- /dev/null
+++ b/tests/tools/onert_train/src/nnfw_util.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_NNFW_UTIL_H__
+#define __ONERT_TRAIN_NNFW_UTIL_H__
+
+#include "nnfw.h"
+
+#define NNPR_ENSURE_STATUS(a) \
+ do \
+ { \
+ if ((a) != NNFW_STATUS_NO_ERROR) \
+ { \
+ exit(-1); \
+ } \
+ } while (0)
+
+namespace onert_train
+{
+uint64_t num_elems(const nnfw_tensorinfo *ti);
+uint64_t bufsize_for(const nnfw_tensorinfo *ti);
+} // end of namespace onert_train
+
+#endif // __ONERT_TRAIN_NNFW_UTIL_H__
diff --git a/tests/tools/onert_train/src/onert_train.cc b/tests/tools/onert_train/src/onert_train.cc
new file mode 100644
index 000000000..678d13fc9
--- /dev/null
+++ b/tests/tools/onert_train/src/onert_train.cc
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "allocation.h"
+#include "args.h"
+#include "benchmark.h"
+#include "measure.h"
+#include "nnfw.h"
+#include "nnfw_util.h"
+#include "nnfw_internal.h"
+#include "nnfw_experimental.h"
+#include "randomgen.h"
+#include "rawformatter.h"
+#include "rawdataloader.h"
+
+#include <boost/program_options.hpp>
+#include <cassert>
+#include <chrono>
+#include <cstdlib>
+#include <iostream>
+#include <libgen.h>
+#include <stdexcept>
+#include <unordered_map>
+#include <vector>
+
+static const char *default_backend_cand = "train";
+
+int main(const int argc, char **argv)
+{
+ using namespace onert_train;
+
+ try
+ {
+ Args args(argc, argv);
+ if (args.printVersion())
+ {
+ uint32_t version;
+ NNPR_ENSURE_STATUS(nnfw_query_info_u32(NULL, NNFW_INFO_ID_VERSION, &version));
+ std::cout << "onert_train (nnfw runtime: v" << (version >> 24) << "."
+ << ((version & 0x0000FF00) >> 8) << "." << (version & 0xFF) << ")" << std::endl;
+ exit(0);
+ }
+
+ // TODO Apply verbose level to phases
+ const int verbose = args.getVerboseLevel();
+ benchmark::Phases phases(benchmark::PhaseOption{});
+
+ nnfw_session *session = nullptr;
+ NNPR_ENSURE_STATUS(nnfw_create_session(&session));
+
+ // ModelLoad
+ phases.run("MODEL_LOAD", [&](const benchmark::Phase &, uint32_t) {
+ if (args.useSingleModel())
+ NNPR_ENSURE_STATUS(
+ nnfw_load_model_from_modelfile(session, args.getModelFilename().c_str()));
+ else
+ NNPR_ENSURE_STATUS(nnfw_load_model_from_file(session, args.getPackageFilename().c_str()));
+ });
+
+ // Set training backend
+ NNPR_ENSURE_STATUS(nnfw_set_available_backends(session, default_backend_cand));
+
+ uint32_t num_inputs;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session, &num_inputs));
+
+ uint32_t num_expecteds;
+ NNPR_ENSURE_STATUS(nnfw_output_size(session, &num_expecteds));
+
+ // verify input and output
+
+ auto verifyInputTypes = [session]() {
+ uint32_t sz;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session, &sz));
+ for (uint32_t i = 0; i < sz; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
+
+ if (ti.dtype < NNFW_TYPE_TENSOR_FLOAT32 || ti.dtype > NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED)
+ {
+ std::cerr << "E: not supported input type" << std::endl;
+ exit(-1);
+ }
+ }
+ };
+
+ auto verifyOutputTypes = [session]() {
+ uint32_t sz;
+ NNPR_ENSURE_STATUS(nnfw_output_size(session, &sz));
+
+ for (uint32_t i = 0; i < sz; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session, i, &ti));
+
+ if (ti.dtype < NNFW_TYPE_TENSOR_FLOAT32 || ti.dtype > NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED)
+ {
+ std::cerr << "E: not supported output type" << std::endl;
+ exit(-1);
+ }
+ }
+ };
+
+ verifyInputTypes();
+ verifyOutputTypes();
+
+ auto convertLossType = [](int type) {
+ switch (type)
+ {
+ case 0:
+ return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
+ case 1:
+ return NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
+ default:
+ std::cerr << "E: not supported loss type" << std::endl;
+ exit(-1);
+ }
+ };
+
+ auto convertOptType = [](int type) {
+ switch (type)
+ {
+ case 0:
+ return NNFW_TRAIN_OPTIMIZER_SGD;
+ case 1:
+ return NNFW_TRAIN_OPTIMIZER_ADAM;
+ default:
+ std::cerr << "E: not supported optimizer type" << std::endl;
+ exit(-1);
+ }
+ };
+
+ // prepare training info
+ nnfw_train_info tri;
+ tri.batch_size = args.getBatchSize();
+ tri.learning_rate = args.getLearningRate();
+ tri.loss = convertLossType(args.getLossType());
+ tri.opt = convertOptType(args.getOptimizerType());
+
+ // prepare execution
+
+    // TODO When nnfw_{prepare|run} fails, the elapsed time cannot be captured
+ phases.run("PREPARE", [&](const benchmark::Phase &, uint32_t) {
+ NNPR_ENSURE_STATUS(nnfw_train_prepare(session, &tri));
+ });
+
+ // prepare input and expected tensor info lists
+ std::vector<nnfw_tensorinfo> input_infos;
+ std::vector<nnfw_tensorinfo> expected_infos;
+
+ // prepare data buffers
+ std::vector<Allocation> input_data(num_inputs);
+ std::vector<Allocation> expected_data(num_expecteds);
+
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
+ input_data[i].alloc(bufsize_for(&ti));
+ input_infos.emplace_back(std::move(ti));
+ }
+
+ for (uint32_t i = 0; i < num_expecteds; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session, i, &ti));
+ expected_data[i].alloc(bufsize_for(&ti));
+ expected_infos.emplace_back(std::move(ti));
+ }
+
+ auto data_length = args.getDataLength();
+
+ Generator generator;
+ RawDataLoader rawDataLoader;
+
+ if (!args.getLoadRawInputFilename().empty() && !args.getLoadRawExpectedFilename().empty())
+ {
+ generator =
+ rawDataLoader.loadData(args.getLoadRawInputFilename(), args.getLoadRawExpectedFilename(),
+ input_infos, expected_infos, data_length, tri.batch_size);
+ }
+ else
+ {
+ // TODO Use random generator
+ std::cerr << "E: not supported random input and expected generator" << std::endl;
+ exit(-1);
+ }
+
+ Measure measure;
+ std::vector<float> losses(num_expecteds);
+ phases.run("EXECUTE", [&](const benchmark::Phase &, uint32_t) {
+ const int num_step = data_length / tri.batch_size;
+ const int num_epoch = args.getEpoch();
+ measure.set(num_epoch, num_step);
+ for (uint32_t epoch = 0; epoch < num_epoch; ++epoch)
+ {
+ std::fill(losses.begin(), losses.end(), 0);
+ for (uint32_t n = 0; n < num_step; ++n)
+ {
+ // get batchsize data
+ if (!generator(n, input_data, expected_data))
+ break;
+
+ // prepare input
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ NNPR_ENSURE_STATUS(
+ nnfw_train_set_input(session, i, input_data[i].data(), &input_infos[i]));
+ }
+
+ // prepare output
+ for (uint32_t i = 0; i < num_expecteds; ++i)
+ {
+ NNPR_ENSURE_STATUS(
+ nnfw_train_set_expected(session, i, expected_data[i].data(), &expected_infos[i]));
+ }
+
+ // train
+ measure.run(epoch, n, [&]() { NNPR_ENSURE_STATUS(nnfw_train(session, true)); });
+
+ // store loss
+        for (uint32_t i = 0; i < num_expecteds; ++i)
+ {
+ float temp = 0.f;
+ NNPR_ENSURE_STATUS(nnfw_train_get_loss(session, i, &temp));
+ losses[i] += temp;
+ }
+ }
+
+ // print loss
+ std::cout << std::fixed;
+ std::cout.precision(3);
+ std::cout << "Epoch " << epoch + 1 << "/" << num_epoch << " - " << measure.timeMs(epoch)
+ << "ms/step - loss: ";
+ std::cout.precision(4);
+ for (uint32_t i = 0; i < num_expecteds; ++i)
+ {
+ std::cout << "[" << i << "] " << losses[i] / num_step;
+ }
+ std::cout /* << "- accuracy: " << accuracy*/ << std::endl;
+ }
+ });
+
+ NNPR_ENSURE_STATUS(nnfw_close_session(session));
+
+ // prepare result
+ benchmark::Result result(phases);
+
+ // to stdout
+ benchmark::printResult(result);
+
+ return 0;
+ }
+ catch (boost::program_options::error &e)
+ {
+ std::cerr << "E: " << e.what() << std::endl;
+ exit(-1);
+ }
+ catch (std::runtime_error &e)
+ {
+ std::cerr << "E: Fail to run by runtime error:" << e.what() << std::endl;
+ exit(-1);
+ }
+}
diff --git a/tests/tools/onert_train/src/randomgen.cc b/tests/tools/onert_train/src/randomgen.cc
new file mode 100644
index 000000000..72599cbb2
--- /dev/null
+++ b/tests/tools/onert_train/src/randomgen.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "randomgen.h"
+#include "nnfw.h"
+#include "nnfw_util.h"
+#include "misc/RandomGenerator.h"
+
+#include <iostream>
+
+namespace onert_train
+{
+
+template <class T> void randomData(nnfw::misc::RandomGenerator &randgen, void *data, uint64_t size)
+{
+ for (uint64_t i = 0; i < size; i++)
+ reinterpret_cast<T *>(data)[i] = randgen.generate<T>();
+}
+
+void RandomGenerator::generate(std::vector<Allocation> &inputs)
+{
+ // generate random data
+ const int seed = 1;
+ nnfw::misc::RandomGenerator randgen{seed, 0.0f, 2.0f};
+ for (uint32_t i = 0; i < inputs.size(); ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
+ auto input_size_in_bytes = bufsize_for(&ti);
+ inputs[i].alloc(input_size_in_bytes);
+ switch (ti.dtype)
+ {
+ case NNFW_TYPE_TENSOR_FLOAT32:
+ randomData<float>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ randomData<uint8_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_BOOL:
+ randomData<bool>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_UINT8:
+ randomData<uint8_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_INT32:
+ randomData<int32_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_INT64:
+ randomData<int64_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
+ randomData<int16_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ default:
+ std::cerr << "Not supported input type" << std::endl;
+ std::exit(-1);
+ }
+ NNPR_ENSURE_STATUS(
+ nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), input_size_in_bytes));
+ NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
+ }
+};
+
+} // end of namespace onert_train
diff --git a/tests/tools/onert_train/src/randomgen.h b/tests/tools/onert_train/src/randomgen.h
new file mode 100644
index 000000000..410c66d6f
--- /dev/null
+++ b/tests/tools/onert_train/src/randomgen.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_RANDOMGEN_H__
+#define __ONERT_TRAIN_RANDOMGEN_H__
+
+#include <string>
+#include <vector>
+
+#include "allocation.h"
+
+struct nnfw_session;
+
+namespace onert_train
+{
+class RandomGenerator
+{
+public:
+ RandomGenerator(nnfw_session *sess) : session_(sess) {}
+ void generate(std::vector<Allocation> &inputs);
+
+private:
+ nnfw_session *session_;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_RANDOMGEN_H__
diff --git a/tests/tools/onert_train/src/rawdataloader.cc b/tests/tools/onert_train/src/rawdataloader.cc
new file mode 100644
index 000000000..a3672a9f3
--- /dev/null
+++ b/tests/tools/onert_train/src/rawdataloader.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rawdataloader.h"
+#include "nnfw_util.h"
+
+#include <iostream>
+#include <stdexcept>
+#include <numeric>
+
+namespace onert_train
+{
+
+Generator RawDataLoader::loadData(const std::string &input_file, const std::string &expected_file,
+ const std::vector<nnfw_tensorinfo> &input_infos,
+ const std::vector<nnfw_tensorinfo> &expected_infos,
+ const uint32_t data_length, const uint32_t batch_size)
+{
+ std::vector<uint32_t> input_origins(input_infos.size());
+ uint32_t start = 0;
+ for (uint32_t i = 0; i < input_infos.size(); ++i)
+ {
+ input_origins.at(i) = start;
+ start += (bufsize_for(&input_infos[i]) / batch_size * data_length);
+ }
+
+ std::vector<uint32_t> expected_origins(expected_infos.size());
+ start = 0;
+ for (uint32_t i = 0; i < expected_infos.size(); ++i)
+ {
+ expected_origins.at(i) = start;
+ start += (bufsize_for(&expected_infos[i]) / batch_size * data_length);
+ }
+
+ try
+ {
+ _input_file = std::ifstream(input_file, std::ios::ate | std::ios::binary);
+ _expected_file = std::ifstream(expected_file, std::ios::ate | std::ios::binary);
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::exit(-1);
+ }
+
+ return [input_origins, expected_origins, &input_infos, &expected_infos,
+ this](uint32_t idx, std::vector<Allocation> &inputs, std::vector<Allocation> &expecteds) {
+ for (uint32_t i = 0; i < input_infos.size(); ++i)
+ {
+ auto bufsz = bufsize_for(&input_infos[i]);
+ _input_file.seekg(input_origins[i] + idx * bufsz, std::ios::beg);
+ _input_file.read(reinterpret_cast<char *>(inputs[i].data()), bufsz);
+ }
+ for (uint32_t i = 0; i < expected_infos.size(); ++i)
+ {
+ auto bufsz = bufsize_for(&expected_infos[i]);
+ _expected_file.seekg(expected_origins[i] + idx * bufsz, std::ios::beg);
+ _expected_file.read(reinterpret_cast<char *>(expecteds[i].data()), bufsz);
+ }
+ return true;
+ };
+}
+
+} // namespace onert_train
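A sketch of the file layout loadData assumes, derived from the origin computation above (all sizes hypothetical): each input's samples for the whole dataset are stored contiguously, one input after another, and batch idx of input i is read at origins[i] + idx * bufsize_for(input i):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors the origin computation in RawDataLoader::loadData.
    // batch_bufsizes[i] is bufsize_for() of input i, i.e. the bytes of one batch.
    std::vector<uint32_t> computeOrigins(const std::vector<uint32_t> &batch_bufsizes,
                                         uint32_t data_length, uint32_t batch_size)
    {
      std::vector<uint32_t> origins(batch_bufsizes.size());
      uint32_t start = 0;
      for (std::size_t i = 0; i < batch_bufsizes.size(); ++i)
      {
        origins[i] = start;
        // per-sample bytes * total sample count = span of input i in the file
        start += batch_bufsizes[i] / batch_size * data_length;
      }
      return origins;
    }

    int main()
    {
      // Two inputs with 4-byte and 8-byte samples, batch 32, 1000 samples total.
      for (auto o : computeOrigins({4 * 32, 8 * 32}, 1000, 32))
        std::cout << o << "\n"; // prints 0, then 4000
    }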
diff --git a/tests/tools/onert_train/src/rawdataloader.h b/tests/tools/onert_train/src/rawdataloader.h
new file mode 100644
index 000000000..3fb292770
--- /dev/null
+++ b/tests/tools/onert_train/src/rawdataloader.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_RAWDATALOADER_H__
+#define __ONERT_TRAIN_RAWDATALOADER_H__
+
+#include "allocation.h"
+#include "nnfw.h"
+
+#include <functional>
+#include <string>
+#include <vector>
+#include <fstream>
+
+namespace onert_train
+{
+
+using Generator = std::function<bool(uint32_t, /** index **/
+ std::vector<Allocation> &, /** input **/
+ std::vector<Allocation> & /** expected **/)>;
+
+class RawDataLoader
+{
+public:
+ RawDataLoader() = default;
+ Generator loadData(const std::string &input_file, const std::string &expected_file,
+ const std::vector<nnfw_tensorinfo> &input_infos,
+ const std::vector<nnfw_tensorinfo> &output_infos, const uint32_t data_length,
+ const uint32_t batch_size);
+
+private:
+ std::ifstream _input_file;
+ std::ifstream _expected_file;
+};
+
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_RAWDATALOADER_H__
diff --git a/tests/tools/onert_train/src/rawformatter.cc b/tests/tools/onert_train/src/rawformatter.cc
new file mode 100644
index 000000000..a17071684
--- /dev/null
+++ b/tests/tools/onert_train/src/rawformatter.cc
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rawformatter.h"
+#include "nnfw.h"
+#include "nnfw_util.h"
+
+#include <iostream>
+#include <fstream>
+#include <stdexcept>
+
+namespace onert_train
+{
+void RawFormatter::loadInputs(const std::string &filename, std::vector<Allocation> &inputs)
+{
+ uint32_t num_inputs;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session_, &num_inputs));
+
+ // Support multiple inputs
+  // Option 1: Get comma-separated input file list like --load:raw a,b,c
+ // Option 2: Get prefix --load:raw in
+ // Internally access in.0, in.1, in.2, ... in.{N-1} where N is determined by nnfw info
+ // query api.
+ //
+ // Currently Option 2 is implemented.
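+  //   e.g. with --load:raw in and N == 3, the files in.0, in.1 and in.2 are read.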
+ try
+ {
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
+
+ // allocate memory for data
+ auto bufsz = bufsize_for(&ti);
+ inputs[i].alloc(bufsz);
+
+ std::ifstream file(filename + "." + std::to_string(i), std::ios::ate | std::ios::binary);
+ auto filesz = file.tellg();
+ if (bufsz != filesz)
+ {
+ throw std::runtime_error("Input " + std::to_string(i) +
+ " size does not match: " + std::to_string(bufsz) +
+ " expected, but " + std::to_string(filesz) + " provided.");
+ }
+ file.seekg(0, std::ios::beg);
+ file.read(reinterpret_cast<char *>(inputs[i].data()), filesz);
+ file.close();
+
+ NNPR_ENSURE_STATUS(nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), bufsz));
+ NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
+ }
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::exit(-1);
+ }
+}
+
+void RawFormatter::dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs)
+{
+ uint32_t num_outputs;
+ NNPR_ENSURE_STATUS(nnfw_output_size(session_, &num_outputs));
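+  // Each output tensor i is dumped as raw bytes to "<filename>.<i>".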
+ try
+ {
+ for (uint32_t i = 0; i < num_outputs; i++)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session_, i, &ti));
+ auto bufsz = bufsize_for(&ti);
+
+ std::ofstream file(filename + "." + std::to_string(i), std::ios::out | std::ios::binary);
+ file.write(reinterpret_cast<const char *>(outputs[i].data()), bufsz);
+ file.close();
+ std::cerr << filename + "." + std::to_string(i) + " is generated.\n";
+ }
+ }
+ catch (const std::runtime_error &e)
+ {
+ std::cerr << "Error during dumpOutputs on onert_run : " << e.what() << std::endl;
+ std::exit(-1);
+ }
+}
+} // namespace onert_train
diff --git a/tests/tools/onert_train/src/rawformatter.h b/tests/tools/onert_train/src/rawformatter.h
new file mode 100644
index 000000000..90e81b2e9
--- /dev/null
+++ b/tests/tools/onert_train/src/rawformatter.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_RAWFORMATTER_H__
+#define __ONERT_TRAIN_RAWFORMATTER_H__
+
+#include "allocation.h"
+#include "formatter.h"
+#include "types.h"
+
+#include <string>
+#include <vector>
+
+struct nnfw_session;
+
+namespace onert_train
+{
+class RawFormatter : public Formatter
+{
+public:
+ RawFormatter(nnfw_session *sess) : Formatter(sess) {}
+ void loadInputs(const std::string &filename, std::vector<Allocation> &inputs) override;
+ void dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs) override;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_RAWFORMATTER_H__
diff --git a/tests/tools/onert_train/src/types.h b/tests/tools/onert_train/src/types.h
new file mode 100644
index 000000000..6e2693016
--- /dev/null
+++ b/tests/tools/onert_train/src/types.h
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_TYPES_H__
+#define __ONERT_TRAIN_TYPES_H__
+
+#include <vector>
+
+namespace onert_train
+{
+
+using TensorShape = std::vector<int>;
+
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_TYPES_H__
diff --git a/tests/tools/onert_train/test/rawdataloader.test.cc b/tests/tools/onert_train/test/rawdataloader.test.cc
new file mode 100644
index 000000000..b2930b37e
--- /dev/null
+++ b/tests/tools/onert_train/test/rawdataloader.test.cc
@@ -0,0 +1,224 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <nnfw.h>
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <numeric>
+
+#include "../src/rawdataloader.h"
+#include "../src/nnfw_util.h"
+
+namespace
+{
+using namespace onert_train;
+
+class DataFileGenerator
+{
+public:
+ DataFileGenerator(uint32_t data_length)
+ : _data_length{data_length}, _input_file{"input.bin"}, _expected_file{"expected.bin"}
+ {
+ }
+ ~DataFileGenerator()
+ {
+ try
+ {
+ if (std::remove(_input_file.c_str()) != 0)
+ {
+ std::cerr << "Failed to remove " << _input_file << std::endl;
+ }
+ if (std::remove(_expected_file.c_str()) != 0)
+ {
+ std::cerr << "Failed to remove " << _expected_file << std::endl;
+ }
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << "Exception: " << e.what() << std::endl;
+ }
+ }
+
+ template <typename T>
+ const std::string &generateInputData(const std::vector<std::vector<T>> &data)
+ {
+ generateData(_input_file, data);
+ return _input_file;
+ }
+
+ template <typename T>
+ const std::string &generateExpectedData(const std::vector<std::vector<T>> &data)
+ {
+ generateData(_expected_file, data);
+ return _expected_file;
+ }
+
+private:
+ template <typename T>
+ void generateData(const std::string &name, const std::vector<std::vector<T>> &data)
+ {
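+    // For each tensor, write its per-sample payload _data_length times in a
+    // row, i.e. the file holds data_length identical flattened samples.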
+ try
+ {
+ std::ofstream file(name, std::ios::binary);
+ for (uint32_t i = 0; i < data.size(); ++i)
+ {
+ for (uint32_t j = 0; j < _data_length; ++j)
+ {
+ for (uint32_t k = 0; k < data[i].size(); ++k)
+ {
+ file.write(reinterpret_cast<const char *>(&data[i][k]), sizeof(data[i][k]));
+ }
+ }
+ }
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << "Exception: " << e.what() << std::endl;
+ }
+ }
+
+private:
+ uint32_t _data_length;
+ std::string _input_file;
+ std::string _expected_file;
+};
+
+class RawDataLoaderTest : public testing::Test
+{
+protected:
+  void SetUp() override { NNPR_ENSURE_STATUS(nnfw_create_session(&_session)); }
+
+  void TearDown() override { NNPR_ENSURE_STATUS(nnfw_close_session(_session)); }
+
+ nnfw_session *_session = nullptr;
+};
+
+TEST_F(RawDataLoaderTest, loadDatas_1)
+{
+ const uint32_t data_length = 100;
+ const uint32_t num_input = 1;
+ const uint32_t num_expected = 1;
+ const uint32_t batch_size = 16;
+
+ // Set data tensor info
+ nnfw_tensorinfo in_info = {
+ .dtype = NNFW_TYPE_TENSOR_INT32,
+ .rank = 4,
+ .dims = {batch_size, 2, 2, 2},
+ };
+ std::vector<nnfw_tensorinfo> in_infos{in_info};
+
+ nnfw_tensorinfo expected_info = {
+ .dtype = NNFW_TYPE_TENSOR_INT32,
+ .rank = 4,
+ .dims = {batch_size, 1, 1, 1},
+ };
+ std::vector<nnfw_tensorinfo> expected_infos{expected_info};
+
+ // Generate test data
+ std::vector<std::vector<uint32_t>> in(num_input);
+ for (uint32_t i = 0; i < num_input; ++i)
+ {
+ in[i].resize(num_elems(&in_infos[i]) / batch_size);
+    std::generate(in[i].begin(), in[i].end(), [n = uint32_t{0}]() mutable { return n++; });
+ }
+
+ std::vector<std::vector<uint32_t>> expected(num_expected);
+ for (uint32_t i = 0; i < num_expected; ++i)
+ {
+ expected[i].resize(num_elems(&expected_infos[i]) / batch_size);
+    const auto sum = std::accumulate(in[i].begin(), in[i].end(), 0u);
+    std::fill(expected[i].begin(), expected[i].end(), sum);
+ }
+
+  // Generate test data files
+ DataFileGenerator file_gen(data_length);
+ auto &input_file = file_gen.generateInputData<uint32_t>(in);
+ auto &expected_file = file_gen.generateExpectedData<uint32_t>(expected);
+
+  // Build expected data for each loaded batch
+ std::vector<std::vector<uint32_t>> expected_in(num_input);
+ std::vector<std::vector<uint32_t>> expected_ex(num_expected);
+ for (uint32_t i = 0; i < num_input; ++i)
+ {
+ for (uint32_t j = 0; j < batch_size; ++j)
+ {
+ expected_in[i].insert(expected_in[i].end(), in[i].begin(), in[i].end());
+ }
+ }
+ for (uint32_t i = 0; i < num_expected; ++i)
+ {
+ for (uint32_t j = 0; j < batch_size; ++j)
+ {
+ expected_ex[i].insert(expected_ex[i].end(), expected[i].begin(), expected[i].end());
+ }
+ }
+
+  // Load test data
+ RawDataLoader loader;
+ Generator generator =
+ loader.loadData(input_file, expected_file, in_infos, expected_infos, data_length, batch_size);
+
+  // Allocate memory for input and expected buffers
+ std::vector<Allocation> inputs(num_input);
+ for (uint32_t i = 0; i < num_input; ++i)
+ {
+ inputs[i].alloc(bufsize_for(&in_infos[i]));
+ }
+ std::vector<Allocation> expecteds(num_expected);
+ for (uint32_t i = 0; i < num_expected; ++i)
+ {
+ expecteds[i].alloc(bufsize_for(&expected_infos[i]));
+ }
+
+ uint32_t num_sample = data_length / batch_size;
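+  // Integer division truncates: 100 / 16 == 6 full batches, so the last
+  // 4 samples are never visited.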
+ for (uint32_t i = 0; i < num_sample; ++i)
+ {
+    auto ok = generator(i, inputs, expecteds);
+    EXPECT_TRUE(ok);
+
+ std::vector<std::vector<uint32_t>> gen_in(num_input);
+ for (uint32_t h = 0; h < num_input; ++h)
+ {
+ auto num_elem = num_elems(&in_infos[h]);
+ for (uint32_t k = 0; k < num_elem; ++k)
+ {
+ auto inbufs = reinterpret_cast<uint32_t *>(inputs[h].data()) + k;
+ gen_in[h].emplace_back(*inbufs);
+ }
+ }
+ std::vector<std::vector<uint32_t>> gen_ex(num_expected);
+ for (uint32_t h = 0; h < num_expected; ++h)
+ {
+ auto num_elem = num_elems(&expected_infos[h]);
+ for (uint32_t k = 0; k < num_elem; ++k)
+ {
+ auto exbufs = reinterpret_cast<uint32_t *>(expecteds[h].data()) + k;
+ gen_ex[h].emplace_back(*exbufs);
+ }
+ }
+
+ EXPECT_EQ(gen_in, expected_in);
+ EXPECT_EQ(gen_ex, expected_ex);
+ }
+}
+
+} // namespace