summaryrefslogtreecommitdiff
path: root/tests/tools/onert_train/src
diff options
context:
space:
mode:
Diffstat (limited to 'tests/tools/onert_train/src')
-rw-r--r--tests/tools/onert_train/src/allocation.h46
-rw-r--r--tests/tools/onert_train/src/args.cc291
-rw-r--r--tests/tools/onert_train/src/args.h92
-rw-r--r--tests/tools/onert_train/src/formatter.h47
-rw-r--r--tests/tools/onert_train/src/h5formatter.cc258
-rw-r--r--tests/tools/onert_train/src/h5formatter.h41
-rw-r--r--tests/tools/onert_train/src/measure.h90
-rw-r--r--tests/tools/onert_train/src/nnfw_util.cc49
-rw-r--r--tests/tools/onert_train/src/nnfw_util.h37
-rw-r--r--tests/tools/onert_train/src/onert_train.cc277
-rw-r--r--tests/tools/onert_train/src/randomgen.cc77
-rw-r--r--tests/tools/onert_train/src/randomgen.h40
-rw-r--r--tests/tools/onert_train/src/rawdataloader.cc77
-rw-r--r--tests/tools/onert_train/src/rawdataloader.h51
-rw-r--r--tests/tools/onert_train/src/rawformatter.cc97
-rw-r--r--tests/tools/onert_train/src/rawformatter.h40
-rw-r--r--tests/tools/onert_train/src/types.h27
17 files changed, 1637 insertions, 0 deletions
diff --git a/tests/tools/onert_train/src/allocation.h b/tests/tools/onert_train/src/allocation.h
new file mode 100644
index 000000000..f5a6aa73b
--- /dev/null
+++ b/tests/tools/onert_train/src/allocation.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_ALLOCATION_H__
+#define __ONERT_TRAIN_ALLOCATION_H__
+
+#include <cstdlib>
+#include <cstdint>
+
+namespace onert_train
+{
+class Allocation
+{
+public:
+ Allocation() : data_(nullptr) {}
+ ~Allocation() { free(data_); }
+ void *data() const { return data_; }
+ void *alloc(uint64_t sz)
+ {
+ if (data_)
+ {
+ free(data_);
+ }
+
+ return data_ = malloc(sz);
+ }
+
+private:
+ void *data_;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_ALLOCATION_H__
diff --git a/tests/tools/onert_train/src/args.cc b/tests/tools/onert_train/src/args.cc
new file mode 100644
index 000000000..dbdd384b5
--- /dev/null
+++ b/tests/tools/onert_train/src/args.cc
@@ -0,0 +1,291 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "args.h"
+
+#include <functional>
+#include <iostream>
+#include <sys/stat.h>
+#include <json/json.h>
+
+namespace
+{
+
+// This function parses a json object and returns as a vector of integers
+// For example,
+// [0, [1, 2, 3, 4], 3, 40, 4, []] in JSON
+// is converted to:
+// {
+// 0 -> [1, 2, 3, 4]
+// 3 -> 40
+// 4 -> []
+// } in std::unordered_map. Note that the value type is still Json::Value.
+std::unordered_map<uint32_t, Json::Value> argArrayToMap(const Json::Value &jsonval)
+{
+ if (!jsonval.isArray() || (jsonval.size() % 2 != 0))
+ {
+ std::cerr << "JSON argument must be an even-sized array in JSON\n";
+ exit(1);
+ }
+
+ std::unordered_map<uint32_t, Json::Value> ret;
+ for (uint32_t i = 0; i < jsonval.size(); i += 2)
+ {
+ if (!jsonval[i].isUInt())
+ {
+ std::cerr << "Key values(values in even indices) must be unsigned integers\n";
+ exit(1);
+ }
+ uint32_t key = jsonval[i].asUInt();
+ Json::Value val = jsonval[i + 1];
+ ret[key] = jsonval[i + 1];
+ }
+ return ret;
+}
+
+void checkModelfile(const std::string &model_filename)
+{
+ if (model_filename.empty())
+ {
+ // TODO Print usage instead of the below message
+ std::cerr << "Please specify model file. Run with `--help` for usage."
+ << "\n";
+
+ exit(1);
+ }
+ else
+ {
+ if (access(model_filename.c_str(), F_OK) == -1)
+ {
+ std::cerr << "Model file not found: " << model_filename << "\n";
+ exit(1);
+ }
+ }
+}
+
+void checkPackage(const std::string &package_filename)
+{
+ if (package_filename.empty())
+ {
+ // TODO Print usage instead of the below message
+ std::cerr << "Please specify nnpackage file. Run with `--help` for usage."
+ << "\n";
+
+ exit(1);
+ }
+ else
+ {
+ if (access(package_filename.c_str(), F_OK) == -1)
+ {
+ std::cerr << "nnpackage not found: " << package_filename << "\n";
+ exit(1);
+ }
+ }
+}
+
+} // namespace
+
+namespace onert_train
+{
+
+Args::Args(const int argc, char **argv)
+{
+ Initialize();
+ Parse(argc, argv);
+}
+
+void Args::Initialize(void)
+{
+ auto process_nnpackage = [&](const std::string &package_filename) {
+ _package_filename = package_filename;
+
+ std::cerr << "Package Filename " << _package_filename << std::endl;
+ checkPackage(package_filename);
+ };
+
+ auto process_modelfile = [&](const std::string &model_filename) {
+ _model_filename = model_filename;
+
+ std::cerr << "Model Filename " << _model_filename << std::endl;
+ checkModelfile(model_filename);
+
+ _use_single_model = true;
+ };
+
+ auto process_path = [&](const std::string &path) {
+ struct stat sb;
+ if (stat(path.c_str(), &sb) == 0)
+ {
+ if (sb.st_mode & S_IFDIR)
+ {
+ _package_filename = path;
+ checkPackage(path);
+ std::cerr << "Package Filename " << path << std::endl;
+ }
+ else
+ {
+ _model_filename = path;
+ checkModelfile(path);
+ std::cerr << "Model Filename " << path << std::endl;
+ _use_single_model = true;
+ }
+ }
+ else
+ {
+ std::cerr << "Cannot find: " << path << "\n";
+ exit(1);
+ }
+ };
+
+ auto process_load_raw_inputfile = [&](const std::string &input_filename) {
+ _load_raw_input_filename = input_filename;
+
+ std::cerr << "Model Input Filename " << _load_raw_input_filename << std::endl;
+ checkModelfile(_load_raw_input_filename);
+ };
+
+ auto process_load_raw_expectedfile = [&](const std::string &expected_filename) {
+ _load_raw_expected_filename = expected_filename;
+
+ std::cerr << "Model Expected Filename " << _load_raw_expected_filename << std::endl;
+ checkModelfile(_load_raw_expected_filename);
+ };
+
+ auto process_output_sizes = [&](const std::string &output_sizes_json_str) {
+ Json::Value root;
+ Json::Reader reader;
+ if (!reader.parse(output_sizes_json_str, root, false))
+ {
+ std::cerr << "Invalid JSON format for output_sizes \"" << output_sizes_json_str << "\"\n";
+ exit(1);
+ }
+
+ auto arg_map = argArrayToMap(root);
+ for (auto &pair : arg_map)
+ {
+ uint32_t key = pair.first;
+ Json::Value &val_json = pair.second;
+ if (!val_json.isUInt())
+ {
+ std::cerr << "All the values in `output_sizes` must be unsigned integers\n";
+ exit(1);
+ }
+ uint32_t val = val_json.asUInt();
+ _output_sizes[key] = val;
+ }
+ };
+
+ // General options
+ po::options_description general("General options", 100);
+
+ // clang-format off
+ general.add_options()
+ ("help,h", "Print available options")
+ ("version", "Print version and exit immediately")
+ ("nnpackage", po::value<std::string>()->notifier(process_nnpackage), "NN Package file(directory) name")
+ ("modelfile", po::value<std::string>()->notifier(process_modelfile), "NN Model filename")
+ ("path", po::value<std::string>()->notifier(process_path), "NN Package or NN Modelfile path")
+ ("data_length", po::value<int>()->default_value(-1)->notifier([&](const auto &v) { _data_length = v; }), "Data length number")
+ ("load_input:raw", po::value<std::string>()->notifier(process_load_raw_inputfile),
+ "NN Model Raw Input data file\n"
+ "The datafile must have data for each input number.\n"
+ "If there are 3 inputs, the data of input0 must exist as much as data_length, "
+ "and the data for input1 and input2 must be held sequentially as data_length.\n"
+ )
+ ("load_expected:raw", po::value<std::string>()->notifier(process_load_raw_expectedfile),
+ "NN Model Raw Expected data file\n"
+ "(Same data policy with load_input:raw)\n"
+ )
+ ("mem_poll,m", po::value<bool>()->default_value(false)->notifier([&](const auto &v) { _mem_poll = v; }), "Check memory polling")
+ ("epoch", po::value<int>()->default_value(5)->notifier([&](const auto &v) { _epoch = v; }), "Epoch number (default: 5)")
+ ("batch_size", po::value<int>()->default_value(32)->notifier([&](const auto &v) { _batch_size = v; }), "Batch size (default: 32)")
+ ("learning_rate", po::value<float>()->default_value(1.0e-4)->notifier([&](const auto &v) { _learning_rate = v; }), "Learning rate (default: 1.0e-4)")
+ ("loss", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _loss_type = v; }),
+ "Loss type\n"
+ "0: MEAN_SQUARED_ERROR (default)\n"
+ "1: CATEGORICAL_CROSSENTROPY\n")
+ ("optimizer", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _optimizer_type = v; }),
+ "Optimizer type\n"
+ "0: SGD (default)\n"
+ "1: Adam\n")
+ ("verbose_level,v", po::value<int>()->default_value(0)->notifier([&](const auto &v) { _verbose_level = v; }),
+ "Verbose level\n"
+ "0: prints the only result. Messages btw run don't print\n"
+ "1: prints result and message btw run\n"
+ "2: prints all of messages to print\n")
+ ("output_sizes", po::value<std::string>()->notifier(process_output_sizes),
+ "The output buffer size in JSON 1D array\n"
+ "If not given, the model's output sizes are used\n"
+ "e.g. '[0, 40, 2, 80]' to set 0th tensor to 40 and 2nd tensor to 80.\n")
+ ;
+ // clang-format on
+
+ _options.add(general);
+ _positional.add("path", -1);
+}
+
+void Args::Parse(const int argc, char **argv)
+{
+ po::variables_map vm;
+ po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
+ vm);
+
+ if (vm.count("help"))
+ {
+ std::cout << "onert_train\n\n";
+ std::cout << "Usage: " << argv[0] << "[model path] [<options>]\n\n";
+ std::cout << _options;
+ std::cout << "\n";
+
+ exit(0);
+ }
+
+ if (vm.count("version"))
+ {
+ _print_version = true;
+ return;
+ }
+
+ {
+ auto conflicting_options = [&](const std::string &o1, const std::string &o2) {
+ if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted()))
+ {
+ throw boost::program_options::error(std::string("Two options '") + o1 + "' and '" + o2 +
+ "' cannot be given at once.");
+ }
+ };
+
+ // Cannot use both single model file and nnpackage at once
+ conflicting_options("modelfile", "nnpackage");
+
+ // Require modelfile, nnpackage, or path
+ if (!vm.count("modelfile") && !vm.count("nnpackage") && !vm.count("path"))
+ throw boost::program_options::error(
+ std::string("Require one of options modelfile, nnpackage, or path."));
+ }
+
+ try
+ {
+ po::notify(vm);
+ }
+ catch (const std::bad_cast &e)
+ {
+ std::cerr << "Bad cast error - " << e.what() << '\n';
+ exit(1);
+ }
+}
+
+} // end of namespace onert_train
diff --git a/tests/tools/onert_train/src/args.h b/tests/tools/onert_train/src/args.h
new file mode 100644
index 000000000..cbd87e111
--- /dev/null
+++ b/tests/tools/onert_train/src/args.h
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_ARGS_H__
+#define __ONERT_TRAIN_ARGS_H__
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include <boost/program_options.hpp>
+
+#include "types.h"
+
+namespace po = boost::program_options;
+
+namespace onert_train
+{
+
+using TensorShapeMap = std::unordered_map<uint32_t, TensorShape>;
+
+#if defined(ONERT_HAVE_HDF5) && ONERT_HAVE_HDF5 == 1
+enum class WhenToUseH5Shape
+{
+ NOT_PROVIDED, // Param not provided
+ PREPARE, // read shapes in h5 file and set them as inputs' shape before calling nnfw_prepare()
+ RUN, // read shapes in h5 file and set them as inputs' shape before calling nnfw_run()
+};
+#endif
+
+class Args
+{
+public:
+ Args(const int argc, char **argv);
+ void print(void);
+
+ const std::string &getPackageFilename(void) const { return _package_filename; }
+ const std::string &getModelFilename(void) const { return _model_filename; }
+ const bool useSingleModel(void) const { return _use_single_model; }
+ const int getDataLength(void) const { return _data_length; }
+ const std::string &getLoadRawInputFilename(void) const { return _load_raw_input_filename; }
+ const std::string &getLoadRawExpectedFilename(void) const { return _load_raw_expected_filename; }
+ const bool getMemoryPoll(void) const { return _mem_poll; }
+ const int getEpoch(void) const { return _epoch; }
+ const int getBatchSize(void) const { return _batch_size; }
+ const float getLearningRate(void) const { return _learning_rate; }
+ const int getLossType(void) const { return _loss_type; }
+ const int getOptimizerType(void) const { return _optimizer_type; }
+ const bool printVersion(void) const { return _print_version; }
+ const int getVerboseLevel(void) const { return _verbose_level; }
+ std::unordered_map<uint32_t, uint32_t> getOutputSizes(void) const { return _output_sizes; }
+
+private:
+ void Initialize();
+ void Parse(const int argc, char **argv);
+
+private:
+ po::positional_options_description _positional;
+ po::options_description _options;
+
+ std::string _package_filename;
+ std::string _model_filename;
+ bool _use_single_model = false;
+ int _data_length;
+ std::string _load_raw_input_filename;
+ std::string _load_raw_expected_filename;
+ bool _mem_poll;
+ int _epoch;
+ int _batch_size;
+ float _learning_rate;
+ int _loss_type;
+ int _optimizer_type;
+ bool _print_version = false;
+ int _verbose_level;
+ std::unordered_map<uint32_t, uint32_t> _output_sizes;
+};
+
+} // end of namespace onert_train
+
+#endif // __ONERT_TRAIN_ARGS_H__
diff --git a/tests/tools/onert_train/src/formatter.h b/tests/tools/onert_train/src/formatter.h
new file mode 100644
index 000000000..6d256804e
--- /dev/null
+++ b/tests/tools/onert_train/src/formatter.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_FORMATTER_H__
+#define __ONERT_TRAIN_FORMATTER_H__
+
+#include <string>
+#include <vector>
+
+#include "types.h"
+#include "allocation.h"
+
+struct nnfw_session;
+
+namespace onert_train
+{
+class Formatter
+{
+public:
+ virtual ~Formatter() = default;
+ Formatter(nnfw_session *sess) : session_(sess) {}
+ virtual void loadInputs(const std::string &filename, std::vector<Allocation> &inputs) = 0;
+ virtual void dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs) = 0;
+ virtual std::vector<TensorShape> readTensorShapes(const std::string &filename)
+ {
+ return std::vector<TensorShape>();
+ };
+
+protected:
+ nnfw_session *session_;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_FORMATTER_H__
diff --git a/tests/tools/onert_train/src/h5formatter.cc b/tests/tools/onert_train/src/h5formatter.cc
new file mode 100644
index 000000000..12c570b5d
--- /dev/null
+++ b/tests/tools/onert_train/src/h5formatter.cc
@@ -0,0 +1,258 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "h5formatter.h"
+#include "nnfw.h"
+#include "nnfw_util.h"
+
+#include <iostream>
+#include <stdexcept>
+#include <H5Cpp.h>
+
+namespace
+{
+onert_train::TensorShape getShape(H5::DataSet &data_set)
+{
+ std::vector<hsize_t> h5_shape; // hsize_t is unsigned long long
+ H5::DataSpace data_space = data_set.getSpace();
+ int rank = data_space.getSimpleExtentNdims();
+ h5_shape.resize(rank);
+
+ // read shape info from H5 file
+ data_space.getSimpleExtentDims(h5_shape.data(), NULL);
+
+ onert_train::TensorShape shape;
+ for (auto dim : h5_shape)
+ shape.emplace_back(static_cast<int>(dim));
+
+ return shape;
+}
+} // namespace
+
+namespace onert_train
+{
+static const char *h5_value_grpname = "value";
+
+std::vector<TensorShape> H5Formatter::readTensorShapes(const std::string &filename)
+{
+ uint32_t num_inputs;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session_, &num_inputs));
+ std::vector<TensorShape> tensor_shapes;
+
+ try
+ {
+ H5::Exception::dontPrint();
+
+ H5::H5File file(filename, H5F_ACC_RDONLY);
+ H5::Group value_group = file.openGroup(h5_value_grpname);
+
+ // Constraints: if there are n data set names, they should be unique and
+ // one of [ "0", "1", .. , "n-1" ]
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ H5::DataSet data_set = value_group.openDataSet(std::to_string(i));
+ H5::DataType type = data_set.getDataType();
+ auto shape = getShape(data_set);
+
+ tensor_shapes.emplace_back(shape);
+ }
+
+ return tensor_shapes;
+ }
+ catch (const H5::Exception &e)
+ {
+ H5::Exception::printErrorStack();
+ std::exit(-1);
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::exit(-1);
+ }
+}
+
+void H5Formatter::loadInputs(const std::string &filename, std::vector<Allocation> &inputs)
+{
+ uint32_t num_inputs;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session_, &num_inputs));
+ try
+ {
+ // Turn off the automatic error printing.
+ H5::Exception::dontPrint();
+
+ H5::H5File file(filename, H5F_ACC_RDONLY);
+ H5::Group value_group = file.openGroup(h5_value_grpname);
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
+
+ // TODO Add Assert(nnfw shape, h5 file shape size)
+
+ // allocate memory for data
+ auto bufsz = bufsize_for(&ti);
+ inputs[i].alloc(bufsz);
+
+ H5::DataSet data_set = value_group.openDataSet(std::to_string(i));
+ H5::DataType type = data_set.getDataType();
+ switch (ti.dtype)
+ {
+ case NNFW_TYPE_TENSOR_FLOAT32:
+ if (type == H5::PredType::IEEE_F32BE || type == H5::PredType::IEEE_F32LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_FLOAT);
+ else
+ throw std::runtime_error("model input type is f32. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_INT32:
+ if (type == H5::PredType::STD_I32BE || type == H5::PredType::STD_I32LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_INT32);
+ else
+ throw std::runtime_error("model input type is i32. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_INT64:
+ if (type == H5::PredType::STD_I64BE || type == H5::PredType::STD_I64LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_INT64);
+ else
+ throw std::runtime_error("model input type is i64. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ case NNFW_TYPE_TENSOR_BOOL:
+ case NNFW_TYPE_TENSOR_UINT8:
+ if (type == H5::PredType::STD_U8BE || type == H5::PredType::STD_U8LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_UINT8);
+ else
+ throw std::runtime_error(
+ "model input type is qasymm8, bool or uint8. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
+ if (type == H5::PredType::STD_I8BE || type == H5::PredType::STD_I8LE)
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_INT8);
+ else
+ throw std::runtime_error("model input type is int8. But h5 data type is different.");
+ break;
+ case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
+ throw std::runtime_error("NYI for NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED type");
+ default:
+ throw std::runtime_error("onert_run can load f32, i32, qasymm8, bool and uint8.");
+ }
+ NNPR_ENSURE_STATUS(nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), bufsz));
+ NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
+ }
+ }
+ catch (const H5::Exception &e)
+ {
+ H5::Exception::printErrorStack();
+ std::exit(-1);
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::exit(-1);
+ }
+};
+
+void H5Formatter::dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs)
+{
+ uint32_t num_outputs;
+ NNPR_ENSURE_STATUS(nnfw_output_size(session_, &num_outputs));
+ try
+ {
+ // Turn off the automatic error printing.
+ H5::Exception::dontPrint();
+
+ H5::H5File file(filename, H5F_ACC_TRUNC);
+ H5::Group value_group = file.createGroup(h5_value_grpname);
+ for (uint32_t i = 0; i < num_outputs; i++)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session_, i, &ti));
+ std::vector<hsize_t> dims(ti.rank);
+ for (uint32_t j = 0; j < ti.rank; ++j)
+ {
+ if (ti.dims[j] >= 0)
+ dims[j] = static_cast<hsize_t>(ti.dims[j]);
+ else
+ {
+ std::cerr << "Negative dimension in output tensor" << std::endl;
+ exit(-1);
+ }
+ }
+ H5::DataSpace data_space(ti.rank, dims.data());
+ switch (ti.dtype)
+ {
+ case NNFW_TYPE_TENSOR_FLOAT32:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::IEEE_F32BE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_FLOAT);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_INT32:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_I32LE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_INT32);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_INT64:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_I64LE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_INT64);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_UINT8:
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_U8BE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_UINT8);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_BOOL:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_U8LE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_INT8);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_I8LE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_INT8);
+ break;
+ }
+ case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
+ throw std::runtime_error("NYI for NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED type");
+ default:
+ throw std::runtime_error("onert_run can dump f32, i32, qasymm8, bool and uint8.");
+ }
+ }
+ }
+ catch (const H5::Exception &e)
+ {
+ H5::Exception::printErrorStack();
+ std::exit(-1);
+ }
+ catch (const std::runtime_error &e)
+ {
+ std::cerr << "Error during dumpOutputs on onert_run : " << e.what() << std::endl;
+ std::exit(-1);
+ }
+};
+
+} // end of namespace onert_train
diff --git a/tests/tools/onert_train/src/h5formatter.h b/tests/tools/onert_train/src/h5formatter.h
new file mode 100644
index 000000000..21ef16526
--- /dev/null
+++ b/tests/tools/onert_train/src/h5formatter.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_H5FORMATTER_H__
+#define __ONERT_TRAIN_H5FORMATTER_H__
+
+#include "allocation.h"
+#include "formatter.h"
+#include "types.h"
+
+#include <string>
+#include <vector>
+
+struct nnfw_session;
+
+namespace onert_train
+{
+class H5Formatter : public Formatter
+{
+public:
+ H5Formatter(nnfw_session *sess) : Formatter(sess) {}
+ std::vector<TensorShape> readTensorShapes(const std::string &filename) override;
+ void loadInputs(const std::string &filename, std::vector<Allocation> &inputs) override;
+ void dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs) override;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_H5FORMATTER_H__
diff --git a/tests/tools/onert_train/src/measure.h b/tests/tools/onert_train/src/measure.h
new file mode 100644
index 000000000..f7c8610d0
--- /dev/null
+++ b/tests/tools/onert_train/src/measure.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_MEASURE_H__
+#define __ONERT_TRAIN_MEASURE_H__
+
+#include <algorithm>
+#include <ctime>
+#include <vector>
+
+namespace
+{
+uint64_t nowMicros()
+{
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return static_cast<uint64_t>(ts.tv_nsec) / 1e3 + static_cast<uint64_t>(ts.tv_sec) * 1e6;
+}
+} // namespace
+
+namespace onert_train
+{
+
+struct Step
+{
+ uint64_t time; // us
+ // TODO Support memory usage
+};
+
+class Measure
+{
+public:
+ Measure() = default;
+
+ void set(const int epoch, const int step)
+ {
+ _results.clear();
+ _results.resize(epoch);
+ std::for_each(_results.begin(), _results.end(), [step](auto &v) { v.resize(step); });
+ }
+
+ void run(const int epoch, const int step, const std::function<void()> &func)
+ {
+ if (_results.empty() || _results.size() <= epoch || _results[epoch].size() <= step)
+ {
+ throw std::runtime_error("Please set the number of epochs and steps first");
+ }
+
+ _results[epoch][step].time = nowMicros();
+
+ func();
+
+ _results[epoch][step].time = nowMicros() - _results[epoch][step].time;
+ }
+
+ double timeMicros(const int epoch)
+ {
+ if (_results.empty() || _results.size() <= epoch)
+ {
+ throw std::runtime_error("Invalid epoch");
+ }
+
+ double sum = 0u;
+ std::for_each(_results[epoch].begin(), _results[epoch].end(),
+ [&sum](auto &v) { sum += v.time; });
+ return sum / _results[epoch].size();
+ }
+
+ double timeMs(const int epoch) { return timeMicros(epoch) / 1e3; }
+
+private:
+ std::vector<std::vector<Step>> _results;
+};
+
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_MEASURE_H__
diff --git a/tests/tools/onert_train/src/nnfw_util.cc b/tests/tools/onert_train/src/nnfw_util.cc
new file mode 100644
index 000000000..8dd2aa871
--- /dev/null
+++ b/tests/tools/onert_train/src/nnfw_util.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cassert>
+#include <string>
+#include "nnfw.h"
+
+namespace onert_train
+{
+uint64_t num_elems(const nnfw_tensorinfo *ti)
+{
+ uint64_t n = 1;
+ for (uint32_t i = 0; i < ti->rank; ++i)
+ {
+ assert(ti->dims[i] >= 0);
+ n *= ti->dims[i];
+ }
+ return n;
+}
+
+uint64_t bufsize_for(const nnfw_tensorinfo *ti)
+{
+ static int elmsize[] = {
+ sizeof(float), /* NNFW_TYPE_TENSOR_FLOAT32 */
+ sizeof(int), /* NNFW_TYPE_TENSOR_INT32 */
+ sizeof(uint8_t), /* NNFW_TYPE_TENSOR_QUANT8_ASYMM */
+ sizeof(bool), /* NNFW_TYPE_TENSOR_BOOL = 3 */
+ sizeof(uint8_t), /* NNFW_TYPE_TENSOR_UINT8 = 4 */
+ sizeof(int64_t), /* NNFW_TYPE_TENSOR_INT64 = 5 */
+ sizeof(int8_t), /* NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED = 6 */
+ sizeof(int16_t), /* NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED = 7 */
+ };
+ return elmsize[ti->dtype] * num_elems(ti);
+}
+
+} // namespace onert_train
diff --git a/tests/tools/onert_train/src/nnfw_util.h b/tests/tools/onert_train/src/nnfw_util.h
new file mode 100644
index 000000000..674e18fb2
--- /dev/null
+++ b/tests/tools/onert_train/src/nnfw_util.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_NNFW_UTIL_H__
+#define __ONERT_TRAIN_NNFW_UTIL_H__
+
+#include "nnfw.h"
+
+#define NNPR_ENSURE_STATUS(a) \
+ do \
+ { \
+ if ((a) != NNFW_STATUS_NO_ERROR) \
+ { \
+ exit(-1); \
+ } \
+ } while (0)
+
+namespace onert_train
+{
+uint64_t num_elems(const nnfw_tensorinfo *ti);
+uint64_t bufsize_for(const nnfw_tensorinfo *ti);
+} // end of namespace onert_train
+
+#endif // __ONERT_TRAIN_NNFW_UTIL_H__
diff --git a/tests/tools/onert_train/src/onert_train.cc b/tests/tools/onert_train/src/onert_train.cc
new file mode 100644
index 000000000..678d13fc9
--- /dev/null
+++ b/tests/tools/onert_train/src/onert_train.cc
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "allocation.h"
+#include "args.h"
+#include "benchmark.h"
+#include "measure.h"
+#include "nnfw.h"
+#include "nnfw_util.h"
+#include "nnfw_internal.h"
+#include "nnfw_experimental.h"
+#include "randomgen.h"
+#include "rawformatter.h"
+#include "rawdataloader.h"
+
+#include <boost/program_options.hpp>
+#include <cassert>
+#include <chrono>
+#include <cstdlib>
+#include <iostream>
+#include <libgen.h>
+#include <stdexcept>
+#include <unordered_map>
+#include <vector>
+
+static const char *default_backend_cand = "train";
+
+int main(const int argc, char **argv)
+{
+ using namespace onert_train;
+
+ try
+ {
+ Args args(argc, argv);
+ if (args.printVersion())
+ {
+ uint32_t version;
+ NNPR_ENSURE_STATUS(nnfw_query_info_u32(NULL, NNFW_INFO_ID_VERSION, &version));
+ std::cout << "onert_train (nnfw runtime: v" << (version >> 24) << "."
+ << ((version & 0x0000FF00) >> 8) << "." << (version & 0xFF) << ")" << std::endl;
+ exit(0);
+ }
+
+ // TODO Apply verbose level to phases
+ const int verbose = args.getVerboseLevel();
+ benchmark::Phases phases(benchmark::PhaseOption{});
+
+ nnfw_session *session = nullptr;
+ NNPR_ENSURE_STATUS(nnfw_create_session(&session));
+
+ // ModelLoad
+ phases.run("MODEL_LOAD", [&](const benchmark::Phase &, uint32_t) {
+ if (args.useSingleModel())
+ NNPR_ENSURE_STATUS(
+ nnfw_load_model_from_modelfile(session, args.getModelFilename().c_str()));
+ else
+ NNPR_ENSURE_STATUS(nnfw_load_model_from_file(session, args.getPackageFilename().c_str()));
+ });
+
+ // Set training backend
+ NNPR_ENSURE_STATUS(nnfw_set_available_backends(session, default_backend_cand));
+
+ uint32_t num_inputs;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session, &num_inputs));
+
+ uint32_t num_expecteds;
+ NNPR_ENSURE_STATUS(nnfw_output_size(session, &num_expecteds));
+
+ // verify input and output
+
+ auto verifyInputTypes = [session]() {
+ uint32_t sz;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session, &sz));
+ for (uint32_t i = 0; i < sz; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
+
+ if (ti.dtype < NNFW_TYPE_TENSOR_FLOAT32 || ti.dtype > NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED)
+ {
+ std::cerr << "E: not supported input type" << std::endl;
+ exit(-1);
+ }
+ }
+ };
+
+ auto verifyOutputTypes = [session]() {
+ uint32_t sz;
+ NNPR_ENSURE_STATUS(nnfw_output_size(session, &sz));
+
+ for (uint32_t i = 0; i < sz; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session, i, &ti));
+
+ if (ti.dtype < NNFW_TYPE_TENSOR_FLOAT32 || ti.dtype > NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED)
+ {
+ std::cerr << "E: not supported output type" << std::endl;
+ exit(-1);
+ }
+ }
+ };
+
+ verifyInputTypes();
+ verifyOutputTypes();
+
+ auto convertLossType = [](int type) {
+ switch (type)
+ {
+ case 0:
+ return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
+ case 1:
+ return NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
+ default:
+ std::cerr << "E: not supported loss type" << std::endl;
+ exit(-1);
+ }
+ };
+
+ auto convertOptType = [](int type) {
+ switch (type)
+ {
+ case 0:
+ return NNFW_TRAIN_OPTIMIZER_SGD;
+ case 1:
+ return NNFW_TRAIN_OPTIMIZER_ADAM;
+ default:
+ std::cerr << "E: not supported optimizer type" << std::endl;
+ exit(-1);
+ }
+ };
+
+ // prepare training info
+ nnfw_train_info tri;
+ tri.batch_size = args.getBatchSize();
+ tri.learning_rate = args.getLearningRate();
+ tri.loss = convertLossType(args.getLossType());
+ tri.opt = convertOptType(args.getOptimizerType());
+
+ // prepare execution
+
+ // TODO When nnfw_{prepare|run} are failed, can't catch the time
+ phases.run("PREPARE", [&](const benchmark::Phase &, uint32_t) {
+ NNPR_ENSURE_STATUS(nnfw_train_prepare(session, &tri));
+ });
+
+ // prepare input and expected tensor info lists
+ std::vector<nnfw_tensorinfo> input_infos;
+ std::vector<nnfw_tensorinfo> expected_infos;
+
+ // prepare data buffers
+ std::vector<Allocation> input_data(num_inputs);
+ std::vector<Allocation> expected_data(num_expecteds);
+
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
+ input_data[i].alloc(bufsize_for(&ti));
+ input_infos.emplace_back(std::move(ti));
+ }
+
+ for (uint32_t i = 0; i < num_expecteds; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session, i, &ti));
+ expected_data[i].alloc(bufsize_for(&ti));
+ expected_infos.emplace_back(std::move(ti));
+ }
+
+ auto data_length = args.getDataLength();
+
+ Generator generator;
+ RawDataLoader rawDataLoader;
+
+ if (!args.getLoadRawInputFilename().empty() && !args.getLoadRawExpectedFilename().empty())
+ {
+ generator =
+ rawDataLoader.loadData(args.getLoadRawInputFilename(), args.getLoadRawExpectedFilename(),
+ input_infos, expected_infos, data_length, tri.batch_size);
+ }
+ else
+ {
+ // TODO Use random generator
+ std::cerr << "E: not supported random input and expected generator" << std::endl;
+ exit(-1);
+ }
+
+ Measure measure;
+ std::vector<float> losses(num_expecteds);
+ phases.run("EXECUTE", [&](const benchmark::Phase &, uint32_t) {
+ const int num_step = data_length / tri.batch_size;
+ const int num_epoch = args.getEpoch();
+ measure.set(num_epoch, num_step);
+ for (uint32_t epoch = 0; epoch < num_epoch; ++epoch)
+ {
+ std::fill(losses.begin(), losses.end(), 0);
+ for (uint32_t n = 0; n < num_step; ++n)
+ {
+ // get batchsize data
+ if (!generator(n, input_data, expected_data))
+ break;
+
+ // prepare input
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ NNPR_ENSURE_STATUS(
+ nnfw_train_set_input(session, i, input_data[i].data(), &input_infos[i]));
+ }
+
+ // prepare output
+ for (uint32_t i = 0; i < num_expecteds; ++i)
+ {
+ NNPR_ENSURE_STATUS(
+ nnfw_train_set_expected(session, i, expected_data[i].data(), &expected_infos[i]));
+ }
+
+ // train
+ measure.run(epoch, n, [&]() { NNPR_ENSURE_STATUS(nnfw_train(session, true)); });
+
+ // store loss
+ for (int32_t i = 0; i < num_expecteds; ++i)
+ {
+ float temp = 0.f;
+ NNPR_ENSURE_STATUS(nnfw_train_get_loss(session, i, &temp));
+ losses[i] += temp;
+ }
+ }
+
+ // print loss
+ std::cout << std::fixed;
+ std::cout.precision(3);
+ std::cout << "Epoch " << epoch + 1 << "/" << num_epoch << " - " << measure.timeMs(epoch)
+ << "ms/step - loss: ";
+ std::cout.precision(4);
+ for (uint32_t i = 0; i < num_expecteds; ++i)
+ {
+ std::cout << "[" << i << "] " << losses[i] / num_step;
+ }
+ std::cout /* << "- accuracy: " << accuracy*/ << std::endl;
+ }
+ });
+
+ NNPR_ENSURE_STATUS(nnfw_close_session(session));
+
+ // prepare result
+ benchmark::Result result(phases);
+
+ // to stdout
+ benchmark::printResult(result);
+
+ return 0;
+ }
+ catch (boost::program_options::error &e)
+ {
+ std::cerr << "E: " << e.what() << std::endl;
+ exit(-1);
+ }
+ catch (std::runtime_error &e)
+ {
+ std::cerr << "E: Fail to run by runtime error:" << e.what() << std::endl;
+ exit(-1);
+ }
+}
diff --git a/tests/tools/onert_train/src/randomgen.cc b/tests/tools/onert_train/src/randomgen.cc
new file mode 100644
index 000000000..72599cbb2
--- /dev/null
+++ b/tests/tools/onert_train/src/randomgen.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "randomgen.h"
+#include "nnfw.h"
+#include "nnfw_util.h"
+#include "misc/RandomGenerator.h"
+
+#include <iostream>
+
+namespace onert_train
+{
+
+template <class T> void randomData(nnfw::misc::RandomGenerator &randgen, void *data, uint64_t size)
+{
+ for (uint64_t i = 0; i < size; i++)
+ reinterpret_cast<T *>(data)[i] = randgen.generate<T>();
+}
+
+void RandomGenerator::generate(std::vector<Allocation> &inputs)
+{
+ // generate random data
+ const int seed = 1;
+ nnfw::misc::RandomGenerator randgen{seed, 0.0f, 2.0f};
+ for (uint32_t i = 0; i < inputs.size(); ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
+ auto input_size_in_bytes = bufsize_for(&ti);
+ inputs[i].alloc(input_size_in_bytes);
+ switch (ti.dtype)
+ {
+ case NNFW_TYPE_TENSOR_FLOAT32:
+ randomData<float>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
+ randomData<uint8_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_BOOL:
+ randomData<bool>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_UINT8:
+ randomData<uint8_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_INT32:
+ randomData<int32_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_INT64:
+ randomData<int64_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
+ randomData<int16_t>(randgen, inputs[i].data(), num_elems(&ti));
+ break;
+ default:
+ std::cerr << "Not supported input type" << std::endl;
+ std::exit(-1);
+ }
+ NNPR_ENSURE_STATUS(
+ nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), input_size_in_bytes));
+ NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
+ }
+};
+
+} // end of namespace onert_train
diff --git a/tests/tools/onert_train/src/randomgen.h b/tests/tools/onert_train/src/randomgen.h
new file mode 100644
index 000000000..410c66d6f
--- /dev/null
+++ b/tests/tools/onert_train/src/randomgen.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_RANDOMGEN_H__
+#define __ONERT_TRAIN_RANDOMGEN_H__
+
+#include <string>
+#include <vector>
+
+#include "allocation.h"
+
+struct nnfw_session;
+
+namespace onert_train
+{
+class RandomGenerator
+{
+public:
+ RandomGenerator(nnfw_session *sess) : session_(sess) {}
+ void generate(std::vector<Allocation> &inputs);
+
+private:
+ nnfw_session *session_;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_RANDOMGEN_H__
diff --git a/tests/tools/onert_train/src/rawdataloader.cc b/tests/tools/onert_train/src/rawdataloader.cc
new file mode 100644
index 000000000..a3672a9f3
--- /dev/null
+++ b/tests/tools/onert_train/src/rawdataloader.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rawdataloader.h"
+#include "nnfw_util.h"
+
+#include <iostream>
+#include <stdexcept>
+#include <numeric>
+
+namespace onert_train
+{
+
+Generator RawDataLoader::loadData(const std::string &input_file, const std::string &expected_file,
+ const std::vector<nnfw_tensorinfo> &input_infos,
+ const std::vector<nnfw_tensorinfo> &expected_infos,
+ const uint32_t data_length, const uint32_t batch_size)
+{
+ std::vector<uint32_t> input_origins(input_infos.size());
+ uint32_t start = 0;
+ for (uint32_t i = 0; i < input_infos.size(); ++i)
+ {
+ input_origins.at(i) = start;
+ start += (bufsize_for(&input_infos[i]) / batch_size * data_length);
+ }
+
+ std::vector<uint32_t> expected_origins(expected_infos.size());
+ start = 0;
+ for (uint32_t i = 0; i < expected_infos.size(); ++i)
+ {
+ expected_origins.at(i) = start;
+ start += (bufsize_for(&expected_infos[i]) / batch_size * data_length);
+ }
+
+ try
+ {
+ _input_file = std::ifstream(input_file, std::ios::ate | std::ios::binary);
+ _expected_file = std::ifstream(expected_file, std::ios::ate | std::ios::binary);
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::exit(-1);
+ }
+
+ return [input_origins, expected_origins, &input_infos, &expected_infos,
+ this](uint32_t idx, std::vector<Allocation> &inputs, std::vector<Allocation> &expecteds) {
+ for (uint32_t i = 0; i < input_infos.size(); ++i)
+ {
+ auto bufsz = bufsize_for(&input_infos[i]);
+ _input_file.seekg(input_origins[i] + idx * bufsz, std::ios::beg);
+ _input_file.read(reinterpret_cast<char *>(inputs[i].data()), bufsz);
+ }
+ for (uint32_t i = 0; i < expected_infos.size(); ++i)
+ {
+ auto bufsz = bufsize_for(&expected_infos[i]);
+ _expected_file.seekg(expected_origins[i] + idx * bufsz, std::ios::beg);
+ _expected_file.read(reinterpret_cast<char *>(expecteds[i].data()), bufsz);
+ }
+ return true;
+ };
+}
+
+} // namespace onert_train
diff --git a/tests/tools/onert_train/src/rawdataloader.h b/tests/tools/onert_train/src/rawdataloader.h
new file mode 100644
index 000000000..3fb292770
--- /dev/null
+++ b/tests/tools/onert_train/src/rawdataloader.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_RAWDATALOADER_H__
+#define __ONERT_TRAIN_RAWDATALOADER_H__
+
+#include "allocation.h"
+#include "nnfw.h"
+
+#include <functional>
+#include <string>
+#include <vector>
+#include <fstream>
+
+namespace onert_train
+{
+
+using Generator = std::function<bool(uint32_t, /** index **/
+ std::vector<Allocation> &, /** input **/
+ std::vector<Allocation> & /** expected **/)>;
+
+class RawDataLoader
+{
+public:
+ RawDataLoader() = default;
+ Generator loadData(const std::string &input_file, const std::string &expected_file,
+ const std::vector<nnfw_tensorinfo> &input_infos,
+ const std::vector<nnfw_tensorinfo> &output_infos, const uint32_t data_length,
+ const uint32_t batch_size);
+
+private:
+ std::ifstream _input_file;
+ std::ifstream _expected_file;
+};
+
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_RAWDATALOADER_H__
diff --git a/tests/tools/onert_train/src/rawformatter.cc b/tests/tools/onert_train/src/rawformatter.cc
new file mode 100644
index 000000000..a17071684
--- /dev/null
+++ b/tests/tools/onert_train/src/rawformatter.cc
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rawformatter.h"
+#include "nnfw.h"
+#include "nnfw_util.h"
+
+#include <iostream>
+#include <fstream>
+#include <stdexcept>
+
+namespace onert_train
+{
+void RawFormatter::loadInputs(const std::string &filename, std::vector<Allocation> &inputs)
+{
+ uint32_t num_inputs;
+ NNPR_ENSURE_STATUS(nnfw_input_size(session_, &num_inputs));
+
+ // Support multiple inputs
+ // Option 1: Get comman-separated input file list like --load:raw a,b,c
+ // Option 2: Get prefix --load:raw in
+ // Internally access in.0, in.1, in.2, ... in.{N-1} where N is determined by nnfw info
+ // query api.
+ //
+ // Currently Option 2 is implemented.
+ try
+ {
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
+
+ // allocate memory for data
+ auto bufsz = bufsize_for(&ti);
+ inputs[i].alloc(bufsz);
+
+ std::ifstream file(filename + "." + std::to_string(i), std::ios::ate | std::ios::binary);
+ auto filesz = file.tellg();
+ if (bufsz != filesz)
+ {
+ throw std::runtime_error("Input " + std::to_string(i) +
+ " size does not match: " + std::to_string(bufsz) +
+ " expected, but " + std::to_string(filesz) + " provided.");
+ }
+ file.seekg(0, std::ios::beg);
+ file.read(reinterpret_cast<char *>(inputs[i].data()), filesz);
+ file.close();
+
+ NNPR_ENSURE_STATUS(nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), bufsz));
+ NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
+ }
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << e.what() << std::endl;
+ std::exit(-1);
+ }
+};
+
+void RawFormatter::dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs)
+{
+ uint32_t num_outputs;
+ NNPR_ENSURE_STATUS(nnfw_output_size(session_, &num_outputs));
+ try
+ {
+ for (uint32_t i = 0; i < num_outputs; i++)
+ {
+ nnfw_tensorinfo ti;
+ NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session_, i, &ti));
+ auto bufsz = bufsize_for(&ti);
+
+ std::ofstream file(filename + "." + std::to_string(i), std::ios::out | std::ios::binary);
+ file.write(reinterpret_cast<const char *>(outputs[i].data()), bufsz);
+ file.close();
+ std::cerr << filename + "." + std::to_string(i) + " is generated.\n";
+ }
+ }
+ catch (const std::runtime_error &e)
+ {
+ std::cerr << "Error during dumpOutputs on onert_run : " << e.what() << std::endl;
+ std::exit(-1);
+ }
+}
+} // end of namespace onert_train
diff --git a/tests/tools/onert_train/src/rawformatter.h b/tests/tools/onert_train/src/rawformatter.h
new file mode 100644
index 000000000..90e81b2e9
--- /dev/null
+++ b/tests/tools/onert_train/src/rawformatter.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_RAWFORMATTER_H__
+#define __ONERT_TRAIN_RAWFORMATTER_H__
+
+#include "allocation.h"
+#include "formatter.h"
+#include "types.h"
+
+#include <string>
+#include <vector>
+
+struct nnfw_session;
+
+namespace onert_train
+{
+class RawFormatter : public Formatter
+{
+public:
+ RawFormatter(nnfw_session *sess) : Formatter(sess) {}
+ void loadInputs(const std::string &filename, std::vector<Allocation> &inputs) override;
+ void dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs) override;
+};
+} // namespace onert_train
+
+#endif // __ONERT_TRAIN_RAWFORMATTER_H__
diff --git a/tests/tools/onert_train/src/types.h b/tests/tools/onert_train/src/types.h
new file mode 100644
index 000000000..6e2693016
--- /dev/null
+++ b/tests/tools/onert_train/src/types.h
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TRAIN_TYPES_H__
+#define __ONERT_TRAIN_TYPES_H__
+
+namespace onert_train
+{
+
+using TensorShape = std::vector<int>;
+
+} // end of namespace onert_train
+
+#endif // __ONERT_TRAIN_TYPES_H__