author     Hyeongseok Oh <hseok82.oh@samsung.com>   2023-09-08 10:51:25 +0000
committer  Hyeongseok Oh <hseok82.oh@samsung.com>   2023-09-08 10:51:25 +0000
commit     eed258505ee1ad0f72d9e0a8a3934f2e9e7b5e79 (patch)
tree       1aa860656489469003375a0f67edb1d729f7dc6b /runtime/onert
parent     3a0ad354832744d138b361ffcfd21f33494beb6b (diff)
Imported Upstream version 1.25.0
Tags: upstream/1.25.0, submit/tizen/20230908.105404
Diffstat (limited to 'runtime/onert')
280 files changed, 13358 insertions, 585 deletions
diff --git a/runtime/onert/CMakeLists.txt b/runtime/onert/CMakeLists.txt index 3c9ca99da..74f7ae568 100644 --- a/runtime/onert/CMakeLists.txt +++ b/runtime/onert/CMakeLists.txt @@ -6,4 +6,5 @@ add_subdirectory(backend) add_subdirectory(frontend) add_subdirectory(core) add_subdirectory(api) +add_subdirectory(odc) add_subdirectory(sample) diff --git a/runtime/onert/api/include/nnfw.h b/runtime/onert/api/include/nnfw.h index 658cba4d5..1f1541a7e 100644 --- a/runtime/onert/api/include/nnfw.h +++ b/runtime/onert/api/include/nnfw.h @@ -243,11 +243,11 @@ NNFW_STATUS nnfw_apply_tensorinfo(nnfw_session *session, uint32_t index, /** * @brief Set input model's tensor info for resizing * - * This function can be called at any time after calling {@link nnfw_model_load_from_file}. Changing + * This function can be called at any time after calling {@link nnfw_load_model_from_file}. Changing * input tensor's shape will cause shape inference for the model. There are two different types of * shape inference - static and dynamic. Which one to use is depend on the current state of the * session. - * When it is called after calling {@link nnfw_model_load_from_file} and before calling {@link + * When it is called after calling {@link nnfw_load_model_from_file} and before calling {@link * nnfw_prepare}, this info will be used when {@link nnfw_prepare}. And it will perform static shape * inference for all tensors. * When it is called after calling {@link nnfw_prepare} or even after {@link nnfw_run}, this info @@ -266,7 +266,7 @@ NNFW_STATUS nnfw_set_input_tensorinfo(nnfw_session *session, uint32_t index, * @brief Prepare session to be ready for inference * * This phase may finalize model compilation, scheduling, and additional settings. - * If {@link nnfw_apply_tensor} is called to apply input tensor info different with model + * If {@link nnfw_apply_tensorinfo} is called to apply input tensor info different with model * before this function, tries to resize all tensors. * * @param[in] session the session to be prepared @@ -309,7 +309,7 @@ NNFW_STATUS nnfw_run_async(nnfw_session *session); /** * @brief Wait for asynchronous run to finish * - * <p>This function must be called after calling {@link nnfw_run_asnyc}, and can be called only once + * <p>This function must be called after calling {@link nnfw_run_async}, and can be called only once * for a {@link nnfw_run_async} call. * * <p>When this function returns, it means that this session has finished the asynchronous run. Then @@ -496,7 +496,7 @@ NNFW_STATUS nnfw_set_op_backend(nnfw_session *session, const char *op, const cha * @note: The input session could be null for global information (e.g. runtime version).* * * @param[in] session session to be queried on. - * @param[in] information ID to be queried + * @param[in] id ID to be queried * @param[out] val uint32 value to be returned. 
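[Editor's note] The corrected `id` parameter above belongs to nnfw_query_info_u32 in nnfw.h. A minimal sketch of querying the global runtime version with it (the session may be null for global information, per the note above) and decoding the 0xMMmmmmPP layout that nnfw_version.h documents further down in this diff; this is an illustrative usage example, not code from the commit:

```cpp
#include <cstdint>
#include <cstdio>
#include "nnfw.h"

int main()
{
  uint32_t ver = 0;
  // Session may be null for global information such as the runtime version.
  if (nnfw_query_info_u32(nullptr, NNFW_INFO_ID_VERSION, &ver) != NNFW_STATUS_NO_ERROR)
    return 1;
  // NNFW_VERSION packs 0xMMmmmmPP: e.g. 0x01001900 prints "1.25.0",
  // matching the version bump in this commit.
  printf("runtime version %u.%u.%u\n", (ver >> 24) & 0xFF, (ver >> 8) & 0xFFFF, ver & 0xFF);
  return 0;
}
```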
* * @return @c NNFW_STATUS_NO_ERROR if successful diff --git a/runtime/onert/api/include/nnfw_experimental.h b/runtime/onert/api/include/nnfw_experimental.h index b20447e9e..3c8b08f52 100644 --- a/runtime/onert/api/include/nnfw_experimental.h +++ b/runtime/onert/api/include/nnfw_experimental.h @@ -19,6 +19,10 @@ #include "nnfw.h" +#ifdef __cplusplus +extern "C" { +#endif + // Used for custom kernel development /* @@ -152,4 +156,244 @@ NNFW_STATUS nnfw_push_pipeline_input(nnfw_session *session, void *inputs, void * */ NNFW_STATUS nnfw_pop_pipeline_output(nnfw_session *session, void *outputs); +/** + * Training C APIs + * + * Training APIs are designed to be used in the following order for training + * 1. nnfw_train_prepare + * 2. nnfw_train_set_input, nnfw_train_set_expected for inputs & expected outputs + * 3. nnfw_train + * 4. nnfw_train_get_loss + * + * If you want to inference after training with the same session, you can use the following order + * 1. nnfw_set_input + * 2. nnfw_set_output + * 3. nnfw_run + */ + +////////////////////////////////////////////// +// Essential APIs for training +////////////////////////////////////////////// +typedef enum +{ + NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR = 0, + NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY = 1, +} NNFW_TRAIN_LOSS; + +typedef enum +{ + NNFW_TRAIN_OPTIMIZER_SGD = 0, + NNFW_TRAIN_OPTIMIZER_ADAM = 1, +} NNFW_TRAIN_OPTIMIZER; + +/** + * @brief Training information to prepare training + * @todo Add more training information + * (e.g. optimizer, loss function, ...) + */ +typedef struct nnfw_train_info +{ + /** Learning rate */ + float learning_rate = 0.001f; + /** Batch size */ + uint32_t batch_size = 1; + /** loss type */ + NNFW_TRAIN_LOSS loss = NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR; + /** optimizer type */ + NNFW_TRAIN_OPTIMIZER opt = NNFW_TRAIN_OPTIMIZER_SGD; +} nnfw_train_info; + +/** + * @brief Prepare session to be ready for training + * @note The session will be entered into training mode + * + * @param[in] session The session to be prepared for training + * @param[in] info Training information. + * If info is nullptr, it will not change training information. + * If it is nullptr and model has not training information, + * it will use default training information. 
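[Editor's note] The call order documented in the training-API header above maps onto a short loop. A hedged sketch, assuming a session on which nnfw_load_model_from_file has already succeeded, a single input and output, correctly sized float buffers `x` and `y`, and a placeholder export path; error checks are mostly elided:

```cpp
#include <cstdio>
#include "nnfw_experimental.h"

void train_steps(nnfw_session *session, const float *x, const float *y, int steps)
{
  nnfw_train_info tinfo; // defaults: learning_rate 0.001, batch_size 1, MSE, SGD
  tinfo.loss = NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
  tinfo.opt = NNFW_TRAIN_OPTIMIZER_SGD;
  if (nnfw_train_prepare(session, &tinfo) != NNFW_STATUS_NO_ERROR)
    return;

  for (int i = 0; i < steps; ++i)
  {
    // A null tensorinfo keeps the model's input/expected shapes unchanged.
    nnfw_train_set_input(session, 0, x, nullptr);
    nnfw_train_set_expected(session, 0, y, nullptr);
    nnfw_train(session, true); // true: update weights (false would only validate)
    float loss = 0.f;
    nnfw_train_get_loss(session, 0, &loss);
    printf("step %d: loss %f\n", i, loss);
  }
  nnfw_train_export_circle(session, "trained.circle"); // path is a placeholder
}
```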
+ * Default training information is {learning_rate = 0.001f, batch_size = 1} + * + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_prepare(nnfw_session *session, const nnfw_train_info *info); + +/** + * @brief Set training input + * @note This function should be called after {@link nnfw_train_prepare} + * + * @param[in] session The session to be set training inputs and expected model outputs + * @param[in] index The index of training input + * @param[in] input The input buffers for training + * @param[in] input_info The shape and type of input buffer + * If it is nullptr, it will not change shape and batch size + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_set_input(nnfw_session *session, uint32_t index, const void *input, + const nnfw_tensorinfo *input_info); + +/** + * @brief Set training expected output + * @note This function should be called after {@link nnfw_train_prepare} + * + * @param session The session to be set training inputs and expected model outputs + * @param index The index of training expected output + * @param expected The expected buffers for training + * @param expected_info The shape and type of expected buffer + * If it is nullptr, it will not change shape and batch size + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_set_expected(nnfw_session *session, uint32_t index, const void *expected, + const nnfw_tensorinfo *expected_info); + +/** + * @brief Train the model + * @note This function should be called after {@link nnfw_train_set_input} and + * {@link nnfw_train_set_expected} for each input and expected output + * + * @param[in] session The session to be trained + * @param[in] update_weights If true, update weights of the model + * If false, do not update weights of the model (for validation) + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train(nnfw_session *session, bool update_weights); + +/** + * @brief Get loss value for expected output + * @note This function should be called after {@link nnfw_train} + * + * @param[in] session The session to get loss value + * @param[in] index The index of loss value [0, number of expected outputs) + * @param[out] loss The loss value + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_get_loss(nnfw_session *session, uint32_t index, float *loss); + +/** + * @brief Export circle model + * @note This function should be called on training mode + * This function should be called after {@link nnfw_train} + * + * @param[in] session The session to export inference model + * @param[in] path The path to export inference model + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_export_circle(nnfw_session *session, const char *path); + +////////////////////////////////////////////// +// Optional APIs for training +////////////////////////////////////////////// + +/** + * @brief Get the training model input information + * @note This function should be called after {@link nnfw_train_prepare} + * + * @param[in] session The session to get the training model input information + * @param[in] index The index of training model input + * @param[out] info The shape and type of training model input + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_input_tensorinfo(nnfw_session *session, uint32_t index, + nnfw_tensorinfo *info); + +/** + * @brief Get the training model expected output information + * @note This function should be called after {@link 
nnfw_train_prepare} + * + * @param[in] session The session to get the training model expected output information + * @param[in] index The index of training model expected output + * @param[out] info The shape and type of training model expected output + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_expected_tensorinfo(nnfw_session *session, uint32_t index, + nnfw_tensorinfo *info); + +////////////////////////////////////////////// +// Not planned to be implemented +////////////////////////////////////////////// + +/** + * @brief Convert between training mode and inference mode + * @note This function should be called after {@link nnfw_train} or {@link nnfw_prepare} + * + * @param[in] session The session to convert training mode to inference mode + * @param[in] train If false, convert training model to inference model + * If true, convert inference model to training model + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +// NNFW_STATUS nnfw_set_training_mode(nnfw_session *session, bool train); + +/** + * @brief Set training information after prepare training + * @note This function may be used after {@link nnfw_train_prepare} + * + * @param[in] session The session prepared for training + * @param[in] info Training information + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +// NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_info info); + +/** + * On-Device Quantization APIs + * + * On-Device Quantization APIs are designed to be used in the following order + * 1. nnfw_set_quantization_type + * 2. nnfw_set_quantized_model_path + * 3. nnfw_quantize + * + * You should use Quantization APIs after {@link nnfw_load_model_from_file}, + * before {@link nnfw_prepare} and {@link nnfw_set_input_tensorinfo}. + */ + +/** + * @brief quantization type + */ +typedef enum +{ + /** default value: type not set */ + NNFW_QUANTIZE_TYPE_NOT_SET, + /** asymmetric quantization with a scale and zero point */ + NNFW_QUANTIZE_TYPE_U8_ASYM, + /** symmetric quantization with a scale only */ + NNFW_QUANTIZE_TYPE_I16_SYM, +} NNFW_QUANTIZE_TYPE; + +/** + * @brief Set quantization type + * + * This function should be called before {@link nnfw_quantize} is invoked. + * + * @param[in] session nnfw_session to set quantization type + * @param[in] pref @c NNFW_QUANTIZE_TYPE + * @return @c NNFW_STATUS_NO_ERROR if successful, + * @c NNFW_STATUS_UNEXPECTED_NULL if session is null, + * otherwise return @c NNFW_STATUS_ERROR + */ +NNFW_STATUS nnfw_set_quantization_type(nnfw_session *session, NNFW_QUANTIZE_TYPE qtype); + +/** + * @brief Set exported quantized model path + * + * This function should be called before {@link nnfw_quantize} is invoked. 
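[Editor's note] The three-step quantization order documented above translates directly into code. A sketch assuming a session right after nnfw_load_model_from_file; the output path is a placeholder:

```cpp
#include "nnfw_experimental.h"

NNFW_STATUS quantize_model(nnfw_session *session)
{
  // 1. Choose the quantization scheme.
  NNFW_STATUS st = nnfw_set_quantization_type(session, NNFW_QUANTIZE_TYPE_U8_ASYM);
  if (st != NNFW_STATUS_NO_ERROR)
    return st;
  // 2. Tell the runtime where to write the quantized circle file.
  st = nnfw_set_quantized_model_path(session, "model.q8.circle");
  if (st != NNFW_STATUS_NO_ERROR)
    return st;
  // 3. Quantize; the session then reloads the quantized model in place
  //    (see nnfw_session::quantize further down in this diff).
  st = nnfw_quantize(session);
  if (st != NNFW_STATUS_NO_ERROR)
    return st;
  // Only now compile and run as usual.
  return nnfw_prepare(session);
}
```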
+ * + * TODO: If this function is not called, quantized model will not be exported + * + * @param[in] session nnfw_session to set quantized model path + * @param[in] path Quantized model path + * @return @c NNFW_STATUS_NO_ERROR if successful, otherwise return @c NNFW_STATUS_ERROR + */ +NNFW_STATUS nnfw_set_quantized_model_path(nnfw_session *session, const char *path); + +/** + * @brief Quantize circle model + * + * @param[in] session nnfw_session to quantize + * @return @c ODC_STATUS_NO_ERROR if successful, otherwise return @c ODC_STATUS_ERROR + */ +NNFW_STATUS nnfw_quantize(nnfw_session *session); + +#ifdef __cplusplus +} +#endif + #endif // __NNFW_EXPERIMENTAL_H__ diff --git a/runtime/onert/api/include/nnfw_version.h b/runtime/onert/api/include/nnfw_version.h index db35c6700..7a280a66d 100644 --- a/runtime/onert/api/include/nnfw_version.h +++ b/runtime/onert/api/include/nnfw_version.h @@ -21,6 +21,6 @@ * NNFW_VERSION is a uint32 value representing nnfw runtime version * in 0xMMmmmmPP, where MM = major, mmmm = minor, PP = patch */ -#define NNFW_VERSION 0x01001601 +#define NNFW_VERSION 0x01001900 #endif // __NNFW_VERSION_H__ diff --git a/runtime/onert/api/src/nnfw_api.cc b/runtime/onert/api/src/nnfw_api.cc index a0e6ee094..185738add 100644 --- a/runtime/onert/api/src/nnfw_api.cc +++ b/runtime/onert/api/src/nnfw_api.cc @@ -385,3 +385,133 @@ NNFW_STATUS nnfw_pop_pipeline_output(nnfw_session *session, void *outputs) NNFW_RETURN_ERROR_IF_NULL(session); return session->pop_pipeline_output((std::vector<void *> *)outputs); } + +// Training + +#ifdef ONERT_TRAIN + +NNFW_STATUS nnfw_train_prepare(nnfw_session *session, const nnfw_train_info *info) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_prepare(info); +} + +NNFW_STATUS nnfw_train_input_tensorinfo(nnfw_session *session, uint32_t index, + nnfw_tensorinfo *info) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_input_tensorinfo(index, info); +} + +NNFW_STATUS nnfw_train_expected_tensorinfo(nnfw_session *session, uint32_t index, + nnfw_tensorinfo *info) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_expected_tensorinfo(index, info); +} + +NNFW_STATUS nnfw_train_set_input(nnfw_session *session, uint32_t index, const void *input, + const nnfw_tensorinfo *input_info) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_set_input(index, input, input_info); +} + +NNFW_STATUS nnfw_train_set_expected(nnfw_session *session, uint32_t index, const void *expected, + const nnfw_tensorinfo *expected_info) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_set_expected(index, expected, expected_info); +} + +NNFW_STATUS nnfw_train(nnfw_session *session, bool update_weights) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_run(update_weights); +} + +NNFW_STATUS nnfw_train_get_loss(nnfw_session *session, uint32_t index, float *loss) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_get_loss(index, loss); +} + +NNFW_STATUS nnfw_train_export_circle(nnfw_session *session, const char *path) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_export_circle(path); +} + +#else // ONERT_TRAIN + +NNFW_STATUS nnfw_train_prepare(nnfw_session *session, const nnfw_train_info *) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS nnfw_train_input_tensorinfo(nnfw_session *session, uint32_t, nnfw_tensorinfo *) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS 
nnfw_train_expected_tensorinfo(nnfw_session *session, uint32_t, nnfw_tensorinfo *) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS nnfw_train_set_input(nnfw_session *session, uint32_t, const void *, + const nnfw_tensorinfo *) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS nnfw_train_set_expected(nnfw_session *session, uint32_t, const void *, + const nnfw_tensorinfo *) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS nnfw_train(nnfw_session *session, bool) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS nnfw_train_get_loss(nnfw_session *session, uint32_t, float *) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS nnfw_train_export_circle(nnfw_session *session, const char *) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return NNFW_STATUS_ERROR; +} + +#endif // ONERT_TRAIN + +// Quantization + +NNFW_STATUS nnfw_set_quantization_type(nnfw_session *session, NNFW_QUANTIZE_TYPE qtype) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->set_quantization_type(qtype); +} + +NNFW_STATUS nnfw_set_quantized_model_path(nnfw_session *session, const char *path) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->set_quantized_model_path(path); +} + +NNFW_STATUS nnfw_quantize(nnfw_session *session) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->quantize(); +} diff --git a/runtime/onert/api/src/nnfw_api_internal.cc b/runtime/onert/api/src/nnfw_api_internal.cc index 8eedb5314..fc02a9227 100644 --- a/runtime/onert/api/src/nnfw_api_internal.cc +++ b/runtime/onert/api/src/nnfw_api_internal.cc @@ -28,6 +28,7 @@ #include "ir/NNPkg.h" #include "ir/OpCode.h" #include "util/TracingCtx.h" +#include "odc/QuantizeManager.h" #include <fstream> #include <iostream> @@ -73,7 +74,7 @@ onert::ir::Layout convertLayout(NNFW_LAYOUT layout) return onert::ir::Layout::UNKNOWN; } -NNFW_STATUS getTensorIndexImpl(const onert::ir::Graph &graph, const char *tensorname, +NNFW_STATUS getTensorIndexImpl(const onert::ir::IGraph &graph, const char *tensorname, uint32_t *index, bool is_input) { if (!tensorname || !index) @@ -195,11 +196,34 @@ std::unique_ptr<onert::ir::Model> loadModel(const std::string filename, return std::unique_ptr<onert::ir::Model>(nullptr); } +#ifdef ONERT_TRAIN +uint64_t getBufSize(const nnfw_tensorinfo *info) +{ + static int elmsize[] = { + sizeof(float), /* NNFW_TYPE_TENSOR_FLOAT32 = 0 */ + sizeof(int), /* NNFW_TYPE_TENSOR_INT32 = 1 */ + sizeof(uint8_t), /* NNFW_TYPE_TENSOR_QUANT8_ASYMM = 2 */ + sizeof(bool), /* NNFW_TYPE_TENSOR_BOOL = 3 */ + sizeof(uint8_t), /* NNFW_TYPE_TENSOR_UINT8 = 4 */ + sizeof(int64_t), /* NNFW_TYPE_TENSOR_INT64 = 5 */ + sizeof(int8_t), /* NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED = 6 */ + sizeof(int16_t), /* NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED = 7 */ + }; + + uint64_t n = 1; + for (int32_t i = 0; i < info->rank; ++i) + { + assert(info->dims[i] >= 0); + n *= info->dims[i]; + } + return elmsize[info->dtype] * n; +} +#endif // ONERT_TRAIN } // namespace nnfw_session::nnfw_session() : _nnpkg{nullptr}, _coptions{}, _compiler_artifact{nullptr}, _execution{nullptr}, - _kernel_registry{nullptr} + _kernel_registry{nullptr}, _quant_manager{nullptr} { // DO NOTHING } @@ -268,6 +292,9 @@ NNFW_STATUS nnfw_session::load_model_from_modelfile(const char *model_file_path) return NNFW_STATUS_UNEXPECTED_NULL; } + // Create quantize manager + _quant_manager = 
std::make_unique<onert::odc::QuantizeManager>(std::string(model_file_path)); + std::string filename{model_file_path}; // TODO: Use std::filesystem::path when we can use c++17. auto dotidx = filename.find_last_of('.'); @@ -352,6 +379,11 @@ NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir) return NNFW_STATUS_ERROR; } + // Create quantize manager + // TODO Support multiple models + auto const model_filename = package_path + std::string("/") + models[0].asString(); + _quant_manager = std::make_unique<onert::odc::QuantizeManager>(model_filename); + for (uint16_t i = 0; i < num_models; ++i) { auto model_file_path = package_path + std::string("/") + models[i].asString(); @@ -359,7 +391,7 @@ NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir) auto model = loadModel(model_file_path, model_type); if (model == nullptr) return NNFW_STATUS_ERROR; - model->primary_subgraph()->bindKernelBuilder(_kernel_registry->getBuilder()); + model->bindKernelBuilder(_kernel_registry->getBuilder()); _nnpkg->push(onert::ir::ModelIndex{i}, std::move(model)); _coptions.push_back(onert::compiler::CompilerOptions::fromGlobalConfig()); } @@ -697,8 +729,7 @@ NNFW_STATUS nnfw_session::apply_tensorinfo(uint32_t index, nnfw_tensorinfo ti) { // In this case, if we apply input shape, it will propagate after compilation and excution - auto &info = _nnpkg->inputInfo(index); - info.shape(new_shape); + _nnpkg->changeInputShape(index, new_shape); } else // when called after nnfw_session::prepare() _execution->changeInputShape(onert::ir::IOIndex(index), new_shape); @@ -941,7 +972,7 @@ NNFW_STATUS nnfw_session::set_config(const char *key, const char *value) return NNFW_STATUS_NO_ERROR; } -const onert::ir::Graph *nnfw_session::primary_subgraph() +const onert::ir::IGraph *nnfw_session::primary_subgraph() { if (_nnpkg != nullptr) { @@ -1129,3 +1160,400 @@ NNFW_STATUS nnfw_session::set_backends_per_operation(const char *backend_setting return NNFW_STATUS_NO_ERROR; } + +#ifdef ONERT_TRAIN +NNFW_STATUS nnfw_session::train_prepare(const nnfw_train_info *info) +{ + // We may need different state to represent training model is loaded + if (!isStateModelLoaded()) + { + std::cerr << "Error during model prepare training: "; + if (_state == State::PREPARED_TRAINING) + std::cerr << "prepare should be run once"; + else + std::cerr << "invalid state"; + std::cerr << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + try + { + nnfw_train_info tinfo; + if (info != nullptr) + { + tinfo = *info; + } + + auto convertLossType = [](const int &type) { + if (type == NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR) + return onert::ir::operation::Loss::Type::MEAN_SQUARED_ERROR; + if (type == NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY) + return onert::ir::operation::Loss::Type::CATEGORICAL_CROSSENTROPY; + else + throw std::runtime_error("not supported loss type"); + }; + onert::compiler::train::LossInfo loss_info; + loss_info.type = convertLossType(tinfo.loss); + + auto convertOptType = [](const int &type) { + if (type == NNFW_TRAIN_OPTIMIZER_SGD) + return onert::exec::train::optimizer::OptimizerCode::SGD; + else if (type == NNFW_TRAIN_OPTIMIZER_ADAM) + return onert::exec::train::optimizer::OptimizerCode::Adam; + else + throw std::runtime_error("not supported optimizer type"); + }; + onert::compiler::train::OptimizerInfo opt_info; + opt_info.learning_rate = tinfo.learning_rate; + opt_info.optim_code = convertOptType(tinfo.opt); + + onert::compiler::train::TrainingInfo training_info; + 
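[Editor's note] For reference, the size check that train_set_input/train_set_expected below perform against a user-supplied tensorinfo boils down to the getBufSize helper added above: the element size for the dtype times the product of all dims. A standalone restatement with a worked value; the shape is an illustrative assumption:

```cpp
#include <cassert>
#include <cstdint>

// Element size times product of dims, as in the getBufSize helper above.
uint64_t buf_size(uint64_t elm_size, const int32_t *dims, int32_t rank)
{
  uint64_t n = 1;
  for (int32_t i = 0; i < rank; ++i)
  {
    assert(dims[i] >= 0);
    n *= static_cast<uint64_t>(dims[i]);
  }
  return elm_size * n;
}

// e.g. a FLOAT32 input of shape [2, 28, 28, 1]:
//   buf_size(4, dims, 4) == 4 * 2*28*28*1 == 6272 bytes.
// train_set_input rejects a tensorinfo whose byte size differs from the
// session's expected input size rather than resizing the tensor.
```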
training_info.setBatchSize(tinfo.batch_size); + training_info.setLossInfo(loss_info); + training_info.setOptimizerInfo(opt_info); + + auto compiler = + onert::compiler::CompilerFactory::get().create(_nnpkg, _coptions, &training_info); + _nnpkg.reset(); + _compiler_artifact = compiler->compile(); + _execution = std::make_unique<onert::exec::Execution>(_compiler_artifact->_executors); + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::train_prepare : " << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + _state = State::PREPARED_TRAINING; + return NNFW_STATUS_NO_ERROR; +} + +NNFW_STATUS nnfw_session::train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti) +{ + if (!isStatePreparedOrFinishedTraining()) + { + std::cerr << "Error during nnfw_session::train_input_tensorinfo : invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + // Check index is valid: [0, getInputSize()) + + // NYI + (void)index; + (void)ti; + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS nnfw_session::train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti) +{ + if (!isStatePreparedOrFinishedTraining()) + { + std::cerr << "Error during nnfw_session::train_expected_tensorinfo : invalid state" + << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + // Check index is valid: [0, getExpectedSize()) + + // NYI + (void)index; + (void)ti; + return NNFW_STATUS_ERROR; +} + +NNFW_STATUS nnfw_session::train_set_input(uint32_t index, const void *input, + const nnfw_tensorinfo *input_tensorinfo) +{ + if (input == nullptr) + { + std::cerr << "Error during nnfw_session::train_set_input : input buffer is null" << std::endl; + return NNFW_STATUS_UNEXPECTED_NULL; + } + + if (!isStatePreparedOrFinishedTraining()) + { + std::cerr << "Error during nnfw_session::train_set_input : invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + if (index >= getInputSize()) + { + std::cerr << "Error during nnfw_session::train_set_input : index is out of range" << std::endl; + return NNFW_STATUS_ERROR; + } + + try + { + auto ind = onert::ir::IOIndex(index); + auto size = _execution->getInputTotalSize(ind); + if (input_tensorinfo && getBufSize(input_tensorinfo) != size) + { + std::cerr + << "Error during nnfw_session::train_set_input : not supporeted to change tensorinfo" + << std::endl; + return NNFW_STATUS_ERROR; + } + + _execution->setInput(ind, input, size); + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::train_set_input : " << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + return NNFW_STATUS_NO_ERROR; +} + +NNFW_STATUS nnfw_session::train_set_expected(uint32_t index, const void *expected, + const nnfw_tensorinfo *expected_tensorinfo) +{ + if (expected == nullptr) + { + std::cerr << "Error during nnfw_session::train_set_expected : expected buffer is null" + << std::endl; + return NNFW_STATUS_UNEXPECTED_NULL; + } + + if (!isStatePreparedOrFinishedTraining()) + { + std::cerr << "Error during nnfw_session::train_set_expected : invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + if (index >= getOutputSize()) + { + std::cerr << "Error during nnfw_session::train_set_expected : index is out of range" + << std::endl; + return NNFW_STATUS_ERROR; + } + + try + { + auto output_ind = onert::ir::IOIndex(index); + auto size = _execution->getOutputTotalSize(output_ind); + if (expected_tensorinfo && getBufSize(expected_tensorinfo) != size) + { + std::cerr << "Error during nnfw_session::train_set_expected : invalid 
tensorinfo" + << std::endl; + return NNFW_STATUS_ERROR; + } + + // NOTE Find the loss input index + // Input is added as many as the number of outputs. + // The loss index is calculated from the value obtained by subtracting the + // total output(added loss input) from the total input size. + auto input_index = getInputSize() - getOutputSize() + index; + auto input_ind = onert::ir::IOIndex(input_index); + _execution->setInput(input_ind, expected, size); + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::train_set_expected : " << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + return NNFW_STATUS_NO_ERROR; +} + +NNFW_STATUS nnfw_session::train_run(bool update_weights) +{ + if (!isStatePreparedOrFinishedTraining()) + { + std::cerr << "Error during nnfw_session::train_run : invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + try + { + if (update_weights) + { + _execution->train(_training_step++); + } + else + _execution->execute(); + } + catch (const onert::InsufficientBufferSizeException &e) + { + // Currently insufficient buffer always means output buffer. + std::cerr << "Error during nnfw_session::train_run : " << e.what() << std::endl; + return NNFW_STATUS_INSUFFICIENT_OUTPUT_SIZE; + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::train_run : " << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + _state = State::FINISHED_TRAINING; + return NNFW_STATUS_NO_ERROR; +} + +NNFW_STATUS nnfw_session::train_get_loss(uint32_t index, float *loss) +{ + if (loss == nullptr) + { + std::cerr << "Error during nnfw_session::train_get_loss : loss is null" << std::endl; + return NNFW_STATUS_UNEXPECTED_NULL; + } + + if (!isStateFinishedTraining()) + { + std::cerr << "Error during nnfw_session::train_get_loss : invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + if (index >= getOutputSize()) + { + std::cerr << "Error during nnfw_session::train_get_loss : index is out of range" << std::endl; + return NNFW_STATUS_ERROR; + } + + try + { + auto ind = onert::ir::IOIndex(index); + *loss = _execution->getLoss(ind); + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::train_get_loss : " << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + return NNFW_STATUS_NO_ERROR; +} + +NNFW_STATUS nnfw_session::train_export_circle(const char *path) +{ + if (path == nullptr) + { + std::cerr << "Error during nnfw_session::train_export_circle : path is null" << std::endl; + return NNFW_STATUS_UNEXPECTED_NULL; + } + + // Check training mode is enabled + if (!isStateFinishedTraining()) + { + std::cerr << "Error during nnfw_session::train_export_circle : invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + // NYI + return NNFW_STATUS_ERROR; +} + +bool nnfw_session::isStatePreparedTraining() +{ + if (_state == State::PREPARED_TRAINING) + { + assert(_nnpkg == nullptr); + assert(!_coptions.empty()); + assert(_execution != nullptr); + return true; + } + else + return false; +} + +bool nnfw_session::isStateFinishedTraining() +{ + if (_state == State::FINISHED_TRAINING) + { + assert(_nnpkg == nullptr); + assert(!_coptions.empty()); + assert(_execution != nullptr); + return true; + } + else + return false; +} + +bool nnfw_session::isStatePreparedOrFinishedTraining() +{ + return isStatePreparedTraining() || isStateFinishedTraining(); +} + +#endif // ONERT_TRAIN + +NNFW_STATUS nnfw_session::set_quantization_type(NNFW_QUANTIZE_TYPE qtype) +{ + try + { + if 
(!isStateModelLoaded()) + { + std::cerr << "invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + bool is_q16 = false; + switch (qtype) + { + case NNFW_QUANTIZE_TYPE_U8_ASYM: + break; + case NNFW_QUANTIZE_TYPE_I16_SYM: + is_q16 = true; + break; + default: + return NNFW_STATUS_INVALID_STATE; + } + _quant_manager->quantizeType(is_q16); + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::set_quantization_type : " << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + return NNFW_STATUS_NO_ERROR; +} + +NNFW_STATUS nnfw_session::set_quantized_model_path(const char *path) +{ + try + { + if (!isStateModelLoaded()) + { + std::cerr << "invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + _quant_manager->exportModelPath(std::string(path)); + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::set_quantized_model_path : " << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + return NNFW_STATUS_NO_ERROR; +} + +NNFW_STATUS nnfw_session::quantize() +{ + try + { + if (!isStateModelLoaded()) + { + std::cerr << "invalid state" << std::endl; + return NNFW_STATUS_INVALID_STATE; + } + + auto result = _quant_manager->quantize(); + if (!result) + return NNFW_STATUS_INVALID_STATE; + + // Replace model + // TODO Support buffer replace, not file reload + auto model = loadModel(_quant_manager->exportModelPath(), "circle"); + if (model == nullptr) + return NNFW_STATUS_ERROR; + _nnpkg->replaceModel(std::move(model)); + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::quantize : " << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + return NNFW_STATUS_NO_ERROR; +} diff --git a/runtime/onert/api/src/nnfw_api_internal.h b/runtime/onert/api/src/nnfw_api_internal.h index 8e2c2fba6..62791765e 100644 --- a/runtime/onert/api/src/nnfw_api_internal.h +++ b/runtime/onert/api/src/nnfw_api_internal.h @@ -39,7 +39,7 @@ class Execution; } // namespace exec namespace ir { -class Graph; +struct IGraph; class Model; class NNPkg; } // namespace ir @@ -48,6 +48,10 @@ namespace compiler struct CompilerArtifact; class CompilerOptions; } // namespace compiler +namespace odc +{ +class QuantizeManager; +} // namespace odc } // namespace onert struct nnfw_session @@ -90,11 +94,13 @@ private: */ enum class State { - INITIALIZED, //< Session is initialized and nothing has done to it - MODEL_LOADED, //< Model is loaded - PREPARED, //< Prepared(compiled) for execution - RUNNING, //< Execution is in progress (only for asynchronous execution) - FINISHED_RUN //< Executed at least once + INITIALIZED, //< Session is initialized and nothing has done to it + MODEL_LOADED, //< Model is loaded + PREPARED, //< Prepared(compiled) for execution + RUNNING, //< Execution is in progress (only for asynchronous execution) + FINISHED_RUN, //< Executed at least once + PREPARED_TRAINING, //< Prepared for training + FINISHED_TRAINING //< Trained at least once }; public: @@ -160,8 +166,25 @@ public: */ NNFW_STATUS set_backends_per_operation(const char *backend_settings); +#ifdef ONERT_TRAIN + NNFW_STATUS train_prepare(const nnfw_train_info *info); + NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti); + NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti); + NNFW_STATUS train_set_input(uint32_t index, const void *input, + const nnfw_tensorinfo *input_tensorinfo); + NNFW_STATUS train_set_expected(uint32_t index, const void *expected, + const nnfw_tensorinfo 
*expected_tensorinfo); + NNFW_STATUS train_run(bool update_weights); + NNFW_STATUS train_get_loss(uint32_t index, float *loss); + NNFW_STATUS train_export_circle(const char *path); +#endif // ONERT_TRAIN + + NNFW_STATUS set_quantization_type(NNFW_QUANTIZE_TYPE qtype); + NNFW_STATUS set_quantized_model_path(const char *path); + NNFW_STATUS quantize(); + private: - const onert::ir::Graph *primary_subgraph(); + const onert::ir::IGraph *primary_subgraph(); uint32_t getInputSize(); uint32_t getOutputSize(); @@ -171,6 +194,11 @@ private: bool isStateRunning(); bool isStateFinishedRun(); bool isStatePreparedOrFinishedRun(); +#ifdef ONERT_TRAIN + bool isStatePreparedTraining(); + bool isStateFinishedTraining(); + bool isStatePreparedOrFinishedTraining(); +#endif // ONERT_TRAIN private: State _state{State::INITIALIZED}; @@ -180,6 +208,10 @@ private: std::unique_ptr<onert::exec::Execution> _execution; std::shared_ptr<onert::api::CustomKernelRegistry> _kernel_registry; std::vector<std::thread> _threads; +#ifdef ONERT_TRAIN + uint32_t _training_step{0}; +#endif // ONERT_TRAIN + std::unique_ptr<onert::odc::QuantizeManager> _quant_manager; }; #endif // __API_NNFW_API_INTERNAL_H__ diff --git a/runtime/onert/backend/CMakeLists.txt b/runtime/onert/backend/CMakeLists.txt index c43160ba7..e6af06afe 100644 --- a/runtime/onert/backend/CMakeLists.txt +++ b/runtime/onert/backend/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(cl_common) add_subdirectory(acl_common) # Backends +set(LIB_ONERT_BACKEND_CPU onert_backend_cpu) add_subdirectory(cpu) add_subdirectory(acl_cl) add_subdirectory(acl_neon) @@ -12,3 +13,9 @@ add_subdirectory(ruy) add_subdirectory(gpu_cl) add_subdirectory(xnnpack) add_subdirectory(trix) + +# Backend to train +if(ENABLE_ONERT_TRAIN) + add_subdirectory(train) +endif(ENABLE_ONERT_TRAIN) + diff --git a/runtime/onert/backend/acl_cl/Config.cc b/runtime/onert/backend/acl_cl/Config.cc index c10fdc1fe..4d12d60b3 100644 --- a/runtime/onert/backend/acl_cl/Config.cc +++ b/runtime/onert/backend/acl_cl/Config.cc @@ -47,7 +47,7 @@ bool Config::initialize() return true; } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout frontend_layout) +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout frontend_layout) { const std::string acl_layout_str = util::getConfigString(util::config::ACL_LAYOUT); if (acl_layout_str == "NHWC") diff --git a/runtime/onert/backend/acl_cl/Config.h b/runtime/onert/backend/acl_cl/Config.h index f71e81b6a..1fa1aeb00 100644 --- a/runtime/onert/backend/acl_cl/Config.h +++ b/runtime/onert/backend/acl_cl/Config.h @@ -35,7 +35,7 @@ public: std::string id() override { return "acl_cl"; } bool initialize() override; bool supportPermutation() override { return true; } - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportDynamicTensor() override { return false; } bool supportFP16() override { return true; } void sync() const override { arm_compute::CLScheduler::get().sync(); } diff --git a/runtime/onert/backend/acl_cl/KernelGenerator.cc b/runtime/onert/backend/acl_cl/KernelGenerator.cc index 5b0ec92b7..dcf31858e 100644 --- a/runtime/onert/backend/acl_cl/KernelGenerator.cc +++ b/runtime/onert/backend/acl_cl/KernelGenerator.cc @@ -82,7 +82,7 @@ void KernelGenerator::visit(const ir::operation::BatchToSpaceND &node) } auto crops = _ctx.at(crops_index).asVector<int32_t>(); - for (auto crop : crops) + for (auto &&crop : crops) { 
if (crop != 0) { diff --git a/runtime/onert/backend/acl_cl/Optimizer.cc b/runtime/onert/backend/acl_cl/Optimizer.cc index a9ce888ee..0f779f483 100644 --- a/runtime/onert/backend/acl_cl/Optimizer.cc +++ b/runtime/onert/backend/acl_cl/Optimizer.cc @@ -44,7 +44,7 @@ void Optimizer::optimize() acl_common::AclSubTensorAnalyzer sa{*_context->graph()}; sa.setUsePadding(); _context->graph()->operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &op) { + [&](const ir::OperationIndex &, const ir::IOperation &op) { sa.setLayout(_context->graph()->layout()); op.accept(sa); }); diff --git a/runtime/onert/backend/acl_common/AclTensorManager.h b/runtime/onert/backend/acl_common/AclTensorManager.h index 268cec201..41a89fbf2 100644 --- a/runtime/onert/backend/acl_common/AclTensorManager.h +++ b/runtime/onert/backend/acl_common/AclTensorManager.h @@ -261,13 +261,13 @@ template <typename T_ITensor, typename T_Tensor, typename T_SubTensor> void AclTensorManager<T_ITensor, T_Tensor, T_SubTensor>::iterate( const std::function<void(const ir::OperandIndex &)> &fn) { - for (auto it : _nonconst_mgr->tensors()) + for (auto &&it : _nonconst_mgr->tensors()) fn(it.first); - for (auto it : _nonconst_mgr->subtensors()) + for (auto &&it : _nonconst_mgr->subtensors()) fn(it.first); - for (auto it : _const_mgr->tensors()) + for (auto &&it : _const_mgr->tensors()) fn(it.first); } diff --git a/runtime/onert/backend/acl_neon/Config.cc b/runtime/onert/backend/acl_neon/Config.cc index 4e78efd2d..3f1758c80 100644 --- a/runtime/onert/backend/acl_neon/Config.cc +++ b/runtime/onert/backend/acl_neon/Config.cc @@ -27,7 +27,7 @@ namespace acl_neon bool Config::initialize() { return true; } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout frontend_layout) +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout frontend_layout) { const std::string acl_layout_str = util::getConfigString(util::config::ACL_LAYOUT); if (acl_layout_str == "NHWC") diff --git a/runtime/onert/backend/acl_neon/Config.h b/runtime/onert/backend/acl_neon/Config.h index 089d9479a..ffd9b21e3 100644 --- a/runtime/onert/backend/acl_neon/Config.h +++ b/runtime/onert/backend/acl_neon/Config.h @@ -33,7 +33,7 @@ class Config : public IConfig public: std::string id() override { return "acl_neon"; } bool initialize() override; - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportPermutation() override { return true; } bool supportDynamicTensor() override { return false; } bool supportFP16() override { return false; } diff --git a/runtime/onert/backend/acl_neon/KernelGenerator.cc b/runtime/onert/backend/acl_neon/KernelGenerator.cc index 94ea86dcf..e71aa3693 100644 --- a/runtime/onert/backend/acl_neon/KernelGenerator.cc +++ b/runtime/onert/backend/acl_neon/KernelGenerator.cc @@ -111,7 +111,7 @@ void KernelGenerator::visit(const ir::operation::BatchToSpaceND &node) } auto crops = _ctx.at(crops_index).asVector<int32_t>(); - for (auto crop : crops) + for (auto &&crop : crops) { if (crop != 0) { diff --git a/runtime/onert/backend/acl_neon/Optimizer.cc b/runtime/onert/backend/acl_neon/Optimizer.cc index 283edd174..f207ca8cb 100644 --- a/runtime/onert/backend/acl_neon/Optimizer.cc +++ b/runtime/onert/backend/acl_neon/Optimizer.cc @@ -44,7 +44,7 @@ void Optimizer::optimize() acl_common::AclSubTensorAnalyzer sa{*_context->graph()}; sa.setUsePadding(); _context->graph()->operations().iterate( 
- [&](const ir::OperationIndex &, const ir::Operation &op) { + [&](const ir::OperationIndex &, const ir::IOperation &op) { sa.setLayout(_context->graph()->layout()); op.accept(sa); }); diff --git a/runtime/onert/backend/cl_common/include/cl_common/BackendContext.h b/runtime/onert/backend/cl_common/include/cl_common/BackendContext.h index 5536d2780..76d403949 100644 --- a/runtime/onert/backend/cl_common/include/cl_common/BackendContext.h +++ b/runtime/onert/backend/cl_common/include/cl_common/BackendContext.h @@ -51,7 +51,7 @@ public: FunctionMap ret; // kernel_gen - for (auto op_ind : _data.op_order) + for (auto &&op_ind : _data.op_order) { auto fn_seq = kernel_gen->generate(op_ind); ret.emplace_back(op_ind, std::move(fn_seq)); @@ -80,7 +80,7 @@ public: protected: void initConsts() { - _data.graph->operations().iterate([&](const ir::OperationIndex &, const ir::Operation &op) { + _data.graph->operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &op) { constant_initializer->setLayout(graph()->layout()); op.accept(*constant_initializer); }); @@ -144,7 +144,7 @@ protected: // 1. Scan DEF of outputs. If the DEF, allocate it // 2. Scan DEF of inputs. If variable tensor, allocate it // 3. Scan USE of inputs. Decrease the USE and deallocate if the USE is 0 - for (const auto op_ind : _data.op_order) + for (const auto &op_ind : _data.op_order) { const auto &op = graph()->operations().at(op_ind); auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; diff --git a/runtime/onert/backend/cpu/BackendContext.cc b/runtime/onert/backend/cpu/BackendContext.cc index da48a785d..45de6b972 100644 --- a/runtime/onert/backend/cpu/BackendContext.cc +++ b/runtime/onert/backend/cpu/BackendContext.cc @@ -37,7 +37,7 @@ FunctionMap BackendContext::genKernels() { FunctionMap ret; - for (auto op_ind : _data.op_order) + for (auto &&op_ind : _data.op_order) { auto fn_seq = kernel_gen->generate(op_ind); ret.emplace_back(op_ind, std::move(fn_seq)); diff --git a/runtime/onert/backend/cpu/CMakeLists.txt b/runtime/onert/backend/cpu/CMakeLists.txt index 99643b983..1383263e7 100644 --- a/runtime/onert/backend/cpu/CMakeLists.txt +++ b/runtime/onert/backend/cpu/CMakeLists.txt @@ -1,11 +1,10 @@ -set(LIB_ONERT_BACKEND_CPU onert_backend_cpu) - nnfw_find_package(Ruy REQUIRED) file(GLOB_RECURSE SOURCES "*.cc") add_library(${LIB_ONERT_BACKEND_CPU} SHARED ${SOURCES}) +target_include_directories(${LIB_ONERT_BACKEND_CPU} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) target_link_libraries(${LIB_ONERT_BACKEND_CPU} PRIVATE nnfw_lib_cker nnfw_lib_misc) target_link_libraries(${LIB_ONERT_BACKEND_CPU} PRIVATE onert_core) target_link_libraries(${LIB_ONERT_BACKEND_CPU} PRIVATE nnfw_common) @@ -15,6 +14,7 @@ target_link_libraries(${LIB_ONERT_BACKEND_CPU} INTERFACE ruy_instrumentation) target_link_libraries(${LIB_ONERT_BACKEND_CPU} PRIVATE ndarray) set_target_properties(${LIB_ONERT_BACKEND_CPU} PROPERTIES OUTPUT_NAME backend_cpu) +set_target_properties(${LIB_ONERT_BACKEND_CPU} PROPERTIES POSITION_INDEPENDENT_CODE ON) if(CMAKE_BUILD_TYPE_LC STREQUAL "release") add_custom_command(TARGET ${LIB_ONERT_BACKEND_CPU} POST_BUILD diff --git a/runtime/onert/backend/cpu/Config.cc b/runtime/onert/backend/cpu/Config.cc index 3ace47f5d..f80c2caf1 100644 --- a/runtime/onert/backend/cpu/Config.cc +++ b/runtime/onert/backend/cpu/Config.cc @@ -25,7 +25,7 @@ namespace cpu bool Config::initialize() { return true; } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout) { return ir::Layout::NHWC; } +ir::Layout 
Config::supportLayout(const ir::IOperation &, ir::Layout) { return ir::Layout::NHWC; } } // namespace cpu } // namespace backend diff --git a/runtime/onert/backend/cpu/Config.h b/runtime/onert/backend/cpu/Config.h index 37e49581a..841a839d1 100644 --- a/runtime/onert/backend/cpu/Config.h +++ b/runtime/onert/backend/cpu/Config.h @@ -33,7 +33,7 @@ class Config : public IConfig public: std::string id() override { return "cpu"; } bool initialize() override; - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportPermutation() override { return true; } bool supportDynamicTensor() override { return true; } bool supportFP16() override { return false; } diff --git a/runtime/onert/backend/cpu/KernelGenerator.cc b/runtime/onert/backend/cpu/KernelGenerator.cc index 896883bc3..c927bf5d4 100644 --- a/runtime/onert/backend/cpu/KernelGenerator.cc +++ b/runtime/onert/backend/cpu/KernelGenerator.cc @@ -257,7 +257,7 @@ std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationI assert(_return_fn); // _return_fn must have been generated ret->append(std::move(_return_fn)); - for (auto ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) + for (auto &&ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) { auto portable_tensor = _tensor_reg->getPortableTensor(ind); if (portable_tensor) diff --git a/runtime/onert/backend/cpu/ops/ConvolutionLayer.cc b/runtime/onert/backend/cpu/ops/ConvolutionLayer.cc index 4672fe406..62e8ae4ba 100644 --- a/runtime/onert/backend/cpu/ops/ConvolutionLayer.cc +++ b/runtime/onert/backend/cpu/ops/ConvolutionLayer.cc @@ -16,6 +16,7 @@ #include "ConvolutionLayer.h" #include "OperationUtils.h" +#include "cker/PortableTensorUtils.h" #include "../Tensor.h" #include "ir/Padding.h" @@ -34,7 +35,7 @@ ConvolutionLayer::ConvolutionLayer() _paddingType(ir::PaddingType::EXPLICIT), _paddingLeft(0), _paddingTop(0), _paddingRight(0), _paddingBottom(0), _strideWidth(0), _strideHeight(0), _dilationWidthFactor(1), _dilationHeightFactor(1), _activation(ir::Activation::NONE), - _conv_kernel(new nnfw::cker::Conv()), _prepare(false) + _conv_kernel(new nnfw::cker::Conv()), _prepare(false), _is_hybrid(false) { // DO NOTHING } @@ -151,6 +152,47 @@ void ConvolutionLayer::convQ8i() reinterpret_cast<int8_t *>(_output->buffer())); } +void ConvolutionLayer::convQ8iHybridPerChannel() +{ + float output_activation_min = 0; + float output_activation_max = 0; + CalculateActivationRange(_activation, &output_activation_min, &output_activation_max); + + const int batch_size = getShape(_input).Dims(0); + if (batch_size == 0) + throw std::runtime_error{"Convolution input batch_size = 0"}; + auto input_shape = getShape(_input); + const int input_size = input_shape.FlatSize() / batch_size; + + auto input_quantized_ptr = _hybrid_arena->input_quantized.data(); + auto input_scaling_factors_ptr = _hybrid_arena->input_scaling_factors.data(); + auto input_offsets_ptr = _hybrid_arena->input_offsets.data(); + for (int b = 0; b < batch_size; ++b) + { + const int offset = b * input_size; + nnfw::cker::PortableAsymmetricQuantizeFloats( + reinterpret_cast<const float *>(_input->buffer()) + offset, input_size, + input_quantized_ptr + offset, &input_scaling_factors_ptr[b], &input_offsets_ptr[b]); + } + nnfw::cker::ConvParams op_params; + op_params.padding_type = getPaddingType(_paddingType); + op_params.padding_values.width = _paddingLeft; + 
op_params.padding_values.height = _paddingTop; + op_params.stride_width = _strideWidth; + op_params.stride_height = _strideHeight; + op_params.dilation_width_factor = _dilationWidthFactor; + op_params.dilation_height_factor = _dilationHeightFactor; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + const auto *filter_per_channel_scales = _kernel->data_scales().data(); + nnfw::cker::reference::HybridConvPerChannel( + op_params, input_scaling_factors_ptr, getShape(_input), input_quantized_ptr, getShape(_kernel), + reinterpret_cast<const int8_t *>(_kernel->buffer()), getShape(_bias), + reinterpret_cast<const float *>(_bias->buffer()), getShape(_output), + reinterpret_cast<float *>(_output->buffer()), filter_per_channel_scales, input_offsets_ptr); +} + void ConvolutionLayer::configure(const IPortableTensor *input, const IPortableTensor *kernel, const IPortableTensor *bias, const ir::PaddingType paddingType, const uint32_t paddingLeft, const uint32_t paddingRight, @@ -174,12 +216,13 @@ void ConvolutionLayer::configure(const IPortableTensor *input, const IPortableTe _dilationHeightFactor = dilationHeightFactor; _activation = activation; _output = output; + _is_hybrid = _input->data_type() == OperandType::FLOAT32 && + _kernel->data_type() == OperandType::QUANT_INT8_SYMM; } void ConvolutionLayer::run() { prepare(); - if (_input->is_dynamic() || _kernel->is_dynamic()) { const auto ifm_shape = _input->getShape().asFeature(_input->layout()); @@ -209,7 +252,11 @@ void ConvolutionLayer::run() _paddingTop = padding.top; _paddingBottom = padding.bottom; } - if (_input->data_type() == OperandType::FLOAT32) + if (_is_hybrid) + { + convQ8iHybridPerChannel(); + } + else if (_input->data_type() == OperandType::FLOAT32) { convFloat32(); } @@ -236,6 +283,27 @@ void ConvolutionLayer::prepare() if (_prepare) return; + if (_is_hybrid) + { + // ensure weight is per-channel quantized. + int32_t kernel_output_channel = getShape(_kernel).Dims(0); + // zero_points comes from flatbuffer vector. Its size is within uint32_t range. + size_t kernel_zerop_cnt = _kernel->data_scales().size(); + // promote to int64_t to compare int32_t and uint32_t + if ((int64_t)kernel_output_channel != (int64_t)kernel_zerop_cnt) + throw std::runtime_error{"Conv2D hybrid supports only per-channel quantized weight."}; + + // allocate memory for activation quantization. 
+ // - quantized values (int8_t type and same shape of original input) + // - quantization params (= scale/zeropoint for each input) + auto input_shape = getShape(_input); + const int batch_size = input_shape.Dims(0); + const int input_size = input_shape.FlatSize() / batch_size; + _hybrid_arena = std::make_unique<nnfw::cker::ConvHybridTempArena>(batch_size, input_size); + _prepare = true; + return; + } + nnfw::cker::Conv &kernel = *_conv_kernel; if (_input->data_type() == OperandType::FLOAT32 && _kernel->is_constant()) { diff --git a/runtime/onert/backend/cpu/ops/ConvolutionLayer.h b/runtime/onert/backend/cpu/ops/ConvolutionLayer.h index 9f5253c8e..5e1bd0b08 100644 --- a/runtime/onert/backend/cpu/ops/ConvolutionLayer.h +++ b/runtime/onert/backend/cpu/ops/ConvolutionLayer.h @@ -29,7 +29,9 @@ namespace nnfw namespace cker { class Conv; -} +struct ConvHybridTempArena; +class Shape; +} // namespace cker } // namespace nnfw namespace onert @@ -48,13 +50,6 @@ public: ~ConvolutionLayer(); public: - void convFloat32(); - - void convQ8uPerTensor(); - void convQ8uPerChannel(); - - void convQ8i(); - void configure(const IPortableTensor *input, const IPortableTensor *kernel, const IPortableTensor *bias, ir::PaddingType _paddingType, const uint32_t paddingLeft, const uint32_t paddingRight, const uint32_t paddingTop, @@ -62,10 +57,15 @@ public: const uint32_t strideHeight, const uint32_t dilationWidthFactor, const uint32_t dilationHeightFactor, const ir::Activation activation, IPortableTensor *output); - + void prepare() override; void run() override; - void prepare() override; +private: + void convFloat32(); + void convQ8uPerTensor(); + void convQ8uPerChannel(); + void convQ8i(); + void convQ8iHybridPerChannel(); private: const IPortableTensor *_input; @@ -87,8 +87,10 @@ private: ir::Activation _activation; std::unique_ptr<nnfw::cker::Conv> _conv_kernel; + std::unique_ptr<nnfw::cker::ConvHybridTempArena> _hybrid_arena; bool _prepare; + bool _is_hybrid; }; } // namespace ops diff --git a/runtime/onert/backend/cpu/ops/DepthwiseConvolutionLayer.cc b/runtime/onert/backend/cpu/ops/DepthwiseConvolutionLayer.cc index 8a48497d5..9e6de17f2 100644 --- a/runtime/onert/backend/cpu/ops/DepthwiseConvolutionLayer.cc +++ b/runtime/onert/backend/cpu/ops/DepthwiseConvolutionLayer.cc @@ -16,6 +16,7 @@ #include "DepthwiseConvolutionLayer.h" +#include "cker/PortableTensorUtils.h" #include <cker/operation/DepthwiseConv.h> namespace onert @@ -147,6 +148,50 @@ void DepthwiseConvolutionLayer::convQ8i() _external_context->ruy_context()); } +void DepthwiseConvolutionLayer::convQ8iHybridPerChannel() +{ + if (!_prepared) + { + prepareQ8iHybridPerChannel(); + _prepared = true; + } + + float output_activation_min = 0, output_activation_max = 0; + CalculateActivationRange(_activation, &output_activation_min, &output_activation_max); + + auto input_shape = getShape(_input); + const int batch_size = input_shape.Dims(0); + const int input_size = input_shape.FlatSize() / batch_size; + + auto scaling_factors_ptr = _input_scaling_factors.data(); + auto input_offsets_ptr = _input_offsets.data(); + + for (int b = 0; b < batch_size; ++b) + { + const int offset = b * input_size; + nnfw::cker::PortableAsymmetricQuantizeFloats(getBuffer<float>(_input) + offset, input_size, + _input_quantized.data() + offset, + &scaling_factors_ptr[b], &input_offsets_ptr[b]); + } + + nnfw::cker::DepthwiseConvParams op_params; + op_params.padding_values.width = _paddingLeft; + op_params.padding_values.height = _paddingTop; + op_params.depth_multiplier = 
_multiplier; + op_params.stride_width = _strideWidth; + op_params.stride_height = _strideHeight; + op_params.dilation_width_factor = _dilationWidth; + op_params.dilation_height_factor = _dilationHeight; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + nnfw::cker::reference_integer_ops::DepthwiseConvHybridPerChannel( + op_params, _input_scaling_factors.data(), getShape(_input), _input_quantized.data(), + getShape(_kernel), getBuffer<int8_t>(_kernel), getShape(_bias), getBuffer<float>(_bias), + getShape(_output), getBuffer<float>(_output), _kernel->data_scales().data(), + _input_offsets.data()); +} + void DepthwiseConvolutionLayer::prepareQ8i() { GetQuantizedConvolutionMultipliersAndShifts( @@ -163,6 +208,31 @@ void DepthwiseConvolutionLayer::prepareQ8uPerChannel() _per_channel_output_shift); } +void DepthwiseConvolutionLayer::prepareQ8iHybridPerChannel() +{ + // allocate memory for activation quantization. + // - quantized values (int8_t type and same shape of original input) + // - quantization params (= scale/zeropoint for each input) + auto input_shape = getShape(_input); + const int batch_size = input_shape.Dims(0); + const int input_size = input_shape.FlatSize() / batch_size; + _input_quantized.resize(input_size); + // TODO: Optimize the case of batch_size = 1 + _input_scaling_factors.resize(batch_size); + _input_offsets.resize(batch_size); +} + +void DepthwiseConvolutionLayer::ensureQ8iHybridPerChannel() +{ + // ensure weight is per-channel quantized. + int32_t kernel_input_channel = getShape(_kernel).Dims(3); + // zero_points comes from flatbuffer vector. Its size is within uint32_t range. + size_t kernel_zerop_cnt = _kernel->data_scales().size(); + // promote to int64_t to compare int32_t and uint32_t + if ((int64_t)kernel_input_channel != (int64_t)kernel_zerop_cnt) + throw std::runtime_error{"DConv2D hybrid supports only per-channel quantized weight."}; +} + void DepthwiseConvolutionLayer::configure( const IPortableTensor *input, const IPortableTensor *kernel, const IPortableTensor *bias, const uint32_t paddingLeft, const uint32_t paddingRight, const uint32_t paddingTop, @@ -186,8 +256,16 @@ void DepthwiseConvolutionLayer::configure( _activation = activation; _output = output; _external_context = external_context; + _is_hybrid = _input->data_type() == OperandType::FLOAT32 && + _kernel->data_type() == OperandType::QUANT_INT8_SYMM; - if (_input->data_type() == OperandType::QUANT_INT8_ASYMM) + if (_is_hybrid) + { + ensureQ8iHybridPerChannel(); + prepareQ8iHybridPerChannel(); + _prepared = true; + } + else if (_input->data_type() == OperandType::QUANT_INT8_ASYMM) { if (_kernel->is_constant() && !_input->is_dynamic() && !_output->is_dynamic()) { @@ -209,7 +287,11 @@ void DepthwiseConvolutionLayer::configure( void DepthwiseConvolutionLayer::run() { - if (_input->data_type() == OperandType::FLOAT32) + if (_is_hybrid) + { + convQ8iHybridPerChannel(); + } + else if (_input->data_type() == OperandType::FLOAT32) { convFloat32(); } diff --git a/runtime/onert/backend/cpu/ops/DepthwiseConvolutionLayer.h b/runtime/onert/backend/cpu/ops/DepthwiseConvolutionLayer.h index 5c910109a..5721f8796 100644 --- a/runtime/onert/backend/cpu/ops/DepthwiseConvolutionLayer.h +++ b/runtime/onert/backend/cpu/ops/DepthwiseConvolutionLayer.h @@ -44,6 +44,7 @@ public: void convQ8uPerChannel(); void convQ8i(); + void convQ8iHybridPerChannel(); void configure(const IPortableTensor *input, const IPortableTensor *kernel, const IPortableTensor *bias, 
const uint32_t paddingLeft, @@ -58,6 +59,8 @@ public: private: void prepareQ8i(); void prepareQ8uPerChannel(); + void prepareQ8iHybridPerChannel(); + void ensureQ8iHybridPerChannel(); private: const IPortableTensor *_input{nullptr}; @@ -87,6 +90,12 @@ private: // Per channel output multiplier and shift. std::vector<int32_t> _per_channel_output_multiplier; std::vector<int> _per_channel_output_shift; + + // For hybrid + bool _is_hybrid{false}; + std::vector<int8_t> _input_quantized; + std::vector<float> _input_scaling_factors; + std::vector<int32_t> _input_offsets; }; } // namespace ops diff --git a/runtime/onert/backend/cpu/ops/ElementwiseActivationLayer.h b/runtime/onert/backend/cpu/ops/ElementwiseActivationLayer.h index 948ab3b57..d8a90148f 100644 --- a/runtime/onert/backend/cpu/ops/ElementwiseActivationLayer.h +++ b/runtime/onert/backend/cpu/ops/ElementwiseActivationLayer.h @@ -54,7 +54,7 @@ public: void EvalUsingLookupTable(const IPortableTensor *input, IPortableTensor *output); -private: +protected: const IPortableTensor *_input; IPortableTensor *_output; uint8_t _table[256]; diff --git a/runtime/onert/backend/cpu/ops/FullyConnectedLayer.cc b/runtime/onert/backend/cpu/ops/FullyConnectedLayer.cc index 6857f7f9f..32cad84cb 100644 --- a/runtime/onert/backend/cpu/ops/FullyConnectedLayer.cc +++ b/runtime/onert/backend/cpu/ops/FullyConnectedLayer.cc @@ -43,7 +43,16 @@ FullyConnectedLayer::~FullyConnectedLayer() = default; void FullyConnectedLayer::fullyConnectedFloat32() { nnfw::cker::FullyConnectedParams op_params; + float output_activation_min = 0; + float output_activation_max = 0; + CalculateActivationRange(_activation, &output_activation_min, &output_activation_max); + op_params.activation = convertActivationType(_activation); + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + // TODO Set both cacheables to false when training + op_params.lhs_cacheable = _weights->is_constant(); + op_params.rhs_cacheable = _input->is_constant(); nnfw::cker::FullyConnected(op_params, getShape(_input), getBuffer<float>(_input), getShape(_weights), getBuffer<float>(_weights), getShape(_bias), diff --git a/runtime/onert/backend/cpu/ops/FullyConnectedLayer.h b/runtime/onert/backend/cpu/ops/FullyConnectedLayer.h index cd12829a0..c56398def 100644 --- a/runtime/onert/backend/cpu/ops/FullyConnectedLayer.h +++ b/runtime/onert/backend/cpu/ops/FullyConnectedLayer.h @@ -66,7 +66,7 @@ public: void prepare() override; -private: +protected: const IPortableTensor *_input; const IPortableTensor *_weights; const IPortableTensor *_bias; diff --git a/runtime/onert/backend/cpu/ops/OperationUtils.cc b/runtime/onert/backend/cpu/ops/OperationUtils.cc index aa4ef352e..686865af2 100644 --- a/runtime/onert/backend/cpu/ops/OperationUtils.cc +++ b/runtime/onert/backend/cpu/ops/OperationUtils.cc @@ -256,7 +256,7 @@ uint32_t sizeOfData(OperandType type, const std::vector<int32_t> &dimensions) break; } - for (auto d : dimensions) + for (auto &&d : dimensions) { assert(d >= 0); size *= static_cast<uint32_t>(d); diff --git a/runtime/onert/backend/gpu_cl/BackendContext.cc b/runtime/onert/backend/gpu_cl/BackendContext.cc index b09319d98..9d4577013 100644 --- a/runtime/onert/backend/gpu_cl/BackendContext.cc +++ b/runtime/onert/backend/gpu_cl/BackendContext.cc @@ -90,7 +90,7 @@ FunctionMap BackendContext::genKernels() { FunctionMap fn_map; - for (auto op_ind : _data.op_order) + for (auto &&op_ind : _data.op_order) { auto fn_seq = kernel_gen->generate(op_ind); 
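 // NOTE kernel_gen->generate(op_ind) builds the function sequence for one operation; iterating _data.op_order keeps the emplaced pairs in execution order.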
fn_map.emplace_back(op_ind, std::move(fn_seq)); diff --git a/runtime/onert/backend/gpu_cl/Config.cc b/runtime/onert/backend/gpu_cl/Config.cc index 9959a471b..9b314d679 100644 --- a/runtime/onert/backend/gpu_cl/Config.cc +++ b/runtime/onert/backend/gpu_cl/Config.cc @@ -41,7 +41,7 @@ bool Config::initialize() } } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout) { return ir::Layout::NHWC; } +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout) { return ir::Layout::NHWC; } } // namespace gpu_cl } // namespace backend diff --git a/runtime/onert/backend/gpu_cl/Config.h b/runtime/onert/backend/gpu_cl/Config.h index f8f94aaf4..980eb228b 100644 --- a/runtime/onert/backend/gpu_cl/Config.h +++ b/runtime/onert/backend/gpu_cl/Config.h @@ -36,7 +36,7 @@ public: public: std::string id() override { return "gpu_cl"; } bool initialize() override; - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportPermutation() override { return true; } bool supportDynamicTensor() override { return false; } bool supportFP16() override { return true; } diff --git a/runtime/onert/backend/gpu_cl/TensorManager.cc b/runtime/onert/backend/gpu_cl/TensorManager.cc index 02e26ed91..02e3441ca 100644 --- a/runtime/onert/backend/gpu_cl/TensorManager.cc +++ b/runtime/onert/backend/gpu_cl/TensorManager.cc @@ -103,10 +103,10 @@ ir::OperandIndexMap<std::shared_ptr<operand::CLTensor>> &TensorManager::nonconst void TensorManager::iterate(const std::function<void(const ir::OperandIndex &)> &fn) { - for (auto it : _nonconst_mgr->tensors()) + for (auto &&it : _nonconst_mgr->tensors()) fn(it.first); - for (auto it : _const_mgr->tensors()) + for (auto &&it : _const_mgr->tensors()) fn(it.first); } diff --git a/runtime/onert/backend/ruy/BackendContext.cc b/runtime/onert/backend/ruy/BackendContext.cc index 48da91b50..1943f70c7 100644 --- a/runtime/onert/backend/ruy/BackendContext.cc +++ b/runtime/onert/backend/ruy/BackendContext.cc @@ -37,7 +37,7 @@ FunctionMap BackendContext::genKernels() { FunctionMap ret; - for (auto op_ind : _data.op_order) + for (auto &&op_ind : _data.op_order) { auto fn_seq = kernel_gen->generate(op_ind); ret.emplace_back(op_ind, std::move(fn_seq)); diff --git a/runtime/onert/backend/ruy/Config.cc b/runtime/onert/backend/ruy/Config.cc index c794f89bf..fbeb2f7f0 100644 --- a/runtime/onert/backend/ruy/Config.cc +++ b/runtime/onert/backend/ruy/Config.cc @@ -25,7 +25,7 @@ namespace ruy bool Config::initialize() { return true; } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout) { return ir::Layout::NHWC; } +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout) { return ir::Layout::NHWC; } } // namespace ruy } // namespace backend diff --git a/runtime/onert/backend/ruy/Config.h b/runtime/onert/backend/ruy/Config.h index 9160dd5b1..fa6415b14 100644 --- a/runtime/onert/backend/ruy/Config.h +++ b/runtime/onert/backend/ruy/Config.h @@ -33,7 +33,7 @@ class Config : public IConfig public: std::string id() override { return "ruy"; } bool initialize() override; - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportPermutation() override { return true; } bool supportDynamicTensor() override { return true; } bool supportFP16() override { return false; } diff --git 
a/runtime/onert/backend/ruy/KernelGenerator.cc b/runtime/onert/backend/ruy/KernelGenerator.cc index b2bbf9bfc..ae7ec28fd 100644 --- a/runtime/onert/backend/ruy/KernelGenerator.cc +++ b/runtime/onert/backend/ruy/KernelGenerator.cc @@ -55,7 +55,7 @@ std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationI assert(_return_fn); // _return_fn must have been generated ret->append(std::move(_return_fn)); - for (auto ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) + for (auto &&ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) { auto portable_tensor = _tensor_reg->getPortableTensor(ind); if (portable_tensor) diff --git a/runtime/onert/backend/train/Backend.h b/runtime/onert/backend/train/Backend.h new file mode 100644 index 000000000..9b8d50a56 --- /dev/null +++ b/runtime/onert/backend/train/Backend.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_BACKEND_H__ +#define __ONERT_BACKEND_TRAIN_BACKEND_H__ + +#include "BackendContext.h" +#include "Config.h" +#include "KernelGenerator.h" + +#include <backend/Backend.h> +#include <backend/train/ITrainableBackend.h> + +#include <memory> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +// TODO Unify TensorBuilder +// TODO Unify TensorRegistry +class Backend : public ::onert::backend::Backend, public backend::train::ITrainableBackend +{ +public: + Backend() : _config{std::make_shared<Config>()} {} + + std::shared_ptr<IConfig> config() const override { return _config; } + + std::unique_ptr<onert::backend::BackendContext> newContext(ContextData &&data) const override + { + return std::make_unique<DummyBackendContext>(this, std::move(data)); + } + + std::unique_ptr<backend::train::TrainableBackendContext> + newContext(backend::train::TrainableContextData &&tdata) const override + { + const auto &tgraph = *tdata.tgraph; + auto tr = std::make_shared<TensorRegistry>(); + auto tb = std::make_shared<TensorBuilder>(tr, "Bump"); + auto tdata_ptr = std::make_unique<backend::train::TrainableContextData>(std::move(tdata)); + auto context = std::make_unique<train::BackendContext>(this, std::move(tdata_ptr), tr, tb); + + context->kernel_gen = std::make_shared<train::KernelGenerator>( + tgraph, tr, context->external_context(), context->data()->optimizer); + return context; + } + +private: + std::shared_ptr<IConfig> _config; +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_BACKEND_H__ diff --git a/runtime/onert/backend/train/BackendContext.cc b/runtime/onert/backend/train/BackendContext.cc new file mode 100644 index 000000000..3ee9a7233 --- /dev/null +++ b/runtime/onert/backend/train/BackendContext.cc @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "BackendContext.h" + +#include "TensorBuilder.h" +#include "KernelGenerator.h" + +#include <backend/basic/train/TrainableBackendContextHelpers.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +backend::ITensorRegistry *BackendContext::genTensors() +{ + return basic::train::genTensors(*this, _tensor_builder); +} + +backend::train::ITensorRegistry *BackendContext::genTrainingTensors() +{ + const ir::train::TrainableGraph &tgraph = *trainable_graph(); + auto tensor_builder = _tensor_builder; + + tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) { + if (external_operands().contains(ind)) + return; + // NOTE Assuming there is no layout changes (Always assume NHWC or UNKNOWN) + assert(tgraph.layout() != ir::Layout::NCHW); + + // TODO Different shape of deriv tensor + ir::OperandInfo backend_info{obj.shape(), obj.typeInfo(), obj.info().memAllocType(), + obj.isConstant()}; + tensor_builder->registerBackwardTensorInfo(ind, backend_info, ir::Layout::NHWC); + }); + + // TODO Plan tensor builds to reduce peak memory usage + tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) { + if (tensor_builder->isRegisteredBackward(ind)) + tensor_builder->notifyBackwardFirstUse(ind); + }); + + tensor_builder->allocateBackward(); + + return _tensor_registry.get(); +} + +FunctionMap BackendContext::genKernels() +{ + train::FunctionMap ret; + + for (const auto &op_ind : _tdata->op_order) + { + auto fn_seq = kernel_gen->generate(op_ind); + ret.emplace_back(op_ind, std::move(fn_seq)); + } + + // Initialize TrainableTensors + trainable_graph()->operands().iterate( + [&](const ir::OperandIndex &ind, const ir::Operand &operand) { + if (external_operands().contains(ind) || !operand.isConstant()) + return; + + auto tensor = tensor_registry()->getNativeITensor(ind); + assert(tensor != nullptr); + + VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl; + + auto data = operand.shareData(); + assert(data && data->base()); + auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor); + + if (trainable_tensor == nullptr) + throw std::runtime_error{"This tensor is not trainable tensor"}; + + trainable_tensor->fillBuffer(data); + }); + + // NOTE For memory optimization, we want to free some operand data + const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph) + .operands() + .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); }); + + // TODO Enable + // for (auto &&it : ret) + // { + // auto &fn_seq = it.second; + // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); }); + // } + + return ret; +} + +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/BackendContext.h b/runtime/onert/backend/train/BackendContext.h new file mode 100644 index 000000000..b5b572b35 --- /dev/null +++ b/runtime/onert/backend/train/BackendContext.h @@ -0,0 +1,90 @@ +/* + * 
Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_BACKEND_CONTEXT_H__ +#define __ONERT_BACKEND_TRAIN_BACKEND_CONTEXT_H__ + +#include <backend/train/TrainableBackendContext.h> + +#include "ExternalContext.h" +#include "KernelGenerator.h" +#include "TensorBuilder.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ + +// TODO Remove this class if ExecutorFactory creates trainable context only once instead of +// replacing BackendContext +class DummyBackendContext : public backend::BackendContext +{ +public: + DummyBackendContext(const Backend *backend, ContextData &&data, + std::shared_ptr<backend::ITensorRegistry> tensor_registry = nullptr) + : backend::BackendContext(backend, std::move(data), tensor_registry) + { + } + + backend::ITensorRegistry *genTensors() override { return nullptr; } + + backend::FunctionMap genKernels() override { return backend::FunctionMap{}; } +}; + +// TODO Unify TensorBuilder +// TODO Unify TensorRegistry +class BackendContext : public onert::backend::train::TrainableBackendContext +{ +public: + BackendContext(const ITrainableBackend *backend, std::unique_ptr<TrainableContextData> &&tdata, + std::shared_ptr<backend::train::ITensorRegistry> tensor_registry = nullptr, + std::shared_ptr<TensorBuilder> tensor_builder = nullptr, + std::shared_ptr<KernelGenerator> kernel_gen = nullptr) + : onert::backend::train::TrainableBackendContext(backend, std::move(tdata), tensor_registry), + kernel_gen{kernel_gen}, + _external_context(new ExternalContext), _tensor_builder{tensor_builder} + { + } + + backend::ITensorRegistry *genTensors() override; + backend::train::ITensorRegistry *genTrainingTensors() override; + +public: + FunctionMap genKernels() override; + + std::shared_ptr<ExternalContext> external_context() { return _external_context; } + +public: + // TODO Make it private + std::shared_ptr<KernelGenerator> kernel_gen; + +private: + // NOTE ruy context has a thread pool, and when multiple ruy contexts are created, + // the thread pool is also created in duplicate + // TODO Create one ruy context for session + std::shared_ptr<ExternalContext> _external_context; + +private: + std::shared_ptr<TensorBuilder> _tensor_builder; +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_BACKEND_CONTEXT_H__ diff --git a/runtime/onert/backend/train/CMakeLists.txt b/runtime/onert/backend/train/CMakeLists.txt new file mode 100644 index 000000000..fd50685b4 --- /dev/null +++ b/runtime/onert/backend/train/CMakeLists.txt @@ -0,0 +1,20 @@ +set(LIB_ONERT_BACKEND_TRAIN onert_backend_train) + +file(GLOB_RECURSE SOURCES "*.cc") + +add_library(${LIB_ONERT_BACKEND_TRAIN} SHARED ${SOURCES}) + +target_link_libraries(${LIB_ONERT_BACKEND_TRAIN} PRIVATE ${LIB_ONERT_BACKEND_CPU}) +target_link_libraries(${LIB_ONERT_BACKEND_TRAIN} PRIVATE onert_core) +target_link_libraries(${LIB_ONERT_BACKEND_TRAIN} PRIVATE nnfw_lib_cker 
nnfw_lib_misc) +target_link_libraries(${LIB_ONERT_BACKEND_TRAIN} PRIVATE nnfw_common) +target_link_libraries(${LIB_ONERT_BACKEND_TRAIN} PRIVATE nnfw_coverage) + +set_target_properties(${LIB_ONERT_BACKEND_TRAIN} PROPERTIES OUTPUT_NAME backend_train) + +if(CMAKE_BUILD_TYPE_LC STREQUAL "release") + add_custom_command(TARGET ${LIB_ONERT_BACKEND_TRAIN} POST_BUILD + COMMAND ${CMAKE_STRIP} "--strip-unneeded" $<TARGET_FILE_NAME:${LIB_ONERT_BACKEND_TRAIN}>) +endif() + +install(TARGETS ${LIB_ONERT_BACKEND_TRAIN} DESTINATION lib) diff --git a/runtime/onert/backend/train/Config.cc b/runtime/onert/backend/train/Config.cc new file mode 100644 index 000000000..57a68adc4 --- /dev/null +++ b/runtime/onert/backend/train/Config.cc @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Config.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ + +bool Config::initialize() { return true; } + +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout) { return ir::Layout::NHWC; } + +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/Config.h b/runtime/onert/backend/train/Config.h new file mode 100644 index 000000000..c8cf52b4d --- /dev/null +++ b/runtime/onert/backend/train/Config.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_TRAIN_CONFIG_H__ +#define __ONERT_BACKEND_TRAIN_CONFIG_H__ + +#include <backend/IConfig.h> +#include <util/ITimer.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +class Config : public IConfig +{ +public: + std::string id() override { return "train"; } + bool initialize() override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; + bool supportPermutation() override { return true; } + bool supportDynamicTensor() override { return false; } + bool supportFP16() override { return false; } + + std::unique_ptr<util::ITimer> timer() override { return std::make_unique<util::CPUTimer>(); } +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_CONFIG_H__ diff --git a/runtime/onert/backend/train/ExternalContext.h b/runtime/onert/backend/train/ExternalContext.h new file mode 100644 index 000000000..c24010ea2 --- /dev/null +++ b/runtime/onert/backend/train/ExternalContext.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_EXTERNAL_CONTEXT_H__ +#define __ONERT_BACKEND_TRAIN_EXTERNAL_CONTEXT_H__ + +#include <ExternalContext.h> // From cpu backend + +namespace onert +{ +namespace backend +{ +namespace train +{ + +using ExternalContext = cpu::ExternalContext; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_EXTERNAL_CONTEXT_H__ diff --git a/runtime/onert/backend/train/KernelGenerator.cc b/runtime/onert/backend/train/KernelGenerator.cc new file mode 100644 index 000000000..d3114e822 --- /dev/null +++ b/runtime/onert/backend/train/KernelGenerator.cc @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "KernelGenerator.h" + +#include "ops/ConvolutionLayer.h" +#include "ops/ElementwiseActivationLayer.h" +#include "ops/FullyConnectedLayer.h" +#include "ops/LossLayer.h" +#include "ops/GradientApplier.h" +#include "ops/PoolLayer.h" +#include "ops/ReshapeLayer.h" + +#include <backend/Backend.h> +#include <backend/IConfig.h> +#include <memory> +#include <util/Utils.h> +#include <util/logging.h> +#include <exec/DynamicShapeInferer.h> + +#include <stdexcept> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +namespace +{ +ops::ElementwiseActivationType +convertElementwiseActivationType(ir::operation::ElementwiseActivation::Type type_ir) +{ + switch (type_ir) + { + case ir::operation::ElementwiseActivation::Type::RELU: + return ops::ElementwiseActivationType::kReLU; + default: + throw std::runtime_error("train KernelGenerator : Not supported operation yet"); + } +} + +ops::LossType convertLossType(ir::operation::Loss::Type type_ir) +{ + switch (type_ir) + { + case ir::operation::Loss::Type::MEAN_SQUARED_ERROR: + return ops::LossType::kMSE; + default: + throw std::runtime_error("train KernelGenerator : Not supported operation yet"); + } +} + +ops::PoolType convertPoolType(ir::operation::Pool2D::PoolType type_ir) +{ + switch (type_ir) + { + // TODO Implement AVG PoolType + case ir::operation::Pool2D::PoolType::MAX: + return ops::PoolType::kMax; + default: + throw std::runtime_error("train KernelGenerator : Not supported operation yet"); + } +} + +std::unique_ptr<ops::GradientApplier> +generateGradientApplier(const std::shared_ptr<exec::train::optimizer::Optimizer> optimizer, + const IPortableTensor *gradient, ITrainableTensor *trainable) +{ + auto update_fn = std::make_unique<ops::GradientApplier>(); + update_fn->configure(optimizer, gradient, trainable); + return update_fn; +} +} // namespace + +std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex idx) +{ + auto ret = std::make_unique<exec::train::TrainableFnSequence>(); + + const auto &op = _tgraph.operation(idx); + op.accept(*this); + assert(_return_fn); + ret->append(std::move(_return_fn)); + + for (auto &&update_fn : _update_funcs) + ret->append(std::move(update_fn)); + _update_funcs.clear(); + + for (auto &&ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) + { + auto portable_tensor = _tensor_reg->getPortableTensor(ind); + if (portable_tensor) + { + assert(portable_tensor->layout() == ir::Layout::NHWC); + } + auto tensor = _tensor_reg->getNonConstTensor(ind); + if (tensor) + { + tensor->increase_ref(); + } + } + return ret; +} + +KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph, + const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::shared_ptr<ExternalContext> &external_context, + std::shared_ptr<exec::train::optimizer::Optimizer> optimizer) + : backend::train::KernelGeneratorBase{tgraph}, _current_layout{tgraph.layout()}, + _tensor_reg{tensor_reg}, + _external_context(external_context), _optimizer{optimizer}, _update_funcs{} +{ + // DO NOTHING +} + +void KernelGenerator::visit(const ir::train::operation::Conv2D &node) +{ + // TODO Generate kernel + + // Generate GradientApplier + const auto ker_index{node.getInputs().at(ir::train::operation::Conv2D::Input::KERNEL)}; + + auto grad_tensor = _tensor_reg->getGradientTensor(ker_index); + auto ker_tensor = _tensor_reg->getTrainableTensor(ker_index); + + auto update_fn = std::make_unique<ops::GradientApplier>(); + + update_fn->configure(_optimizer, grad_tensor, ker_tensor); 
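+  // NOTE A GradientApplier binds the optimizer to one (gradient, trainable) tensor pair;
+  // every applier queued in _update_funcs runs after the backward pass to update the kernel weights.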
+ + _update_funcs.emplace_back(generateGradientApplier(_optimizer, grad_tensor, ker_tensor)); +} + +void KernelGenerator::visit(const ir::train::operation::ElementwiseActivation &node) +{ + using ir::train::operation::ElementwiseActivation; + + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(ElementwiseActivation::Input::INPUT)}; + + auto output_tensor = _tensor_reg->getPortableTensor(output_index); + auto input_tensor = _tensor_reg->getPortableTensor(input_index); + + auto deriv_input_tensor = _tensor_reg->getDerivativeTensor(input_index); + auto deriv_output_tensor = _tensor_reg->getDerivativeTensor(output_index); + + auto fn = std::make_unique<ops::ElementwiseActivationLayer>(); + + fn->configure(input_tensor, output_tensor, deriv_input_tensor, deriv_output_tensor, + node.param().alpha, node.param().beta, + convertElementwiseActivationType(node.param().op_type)); + + _return_fn = std::move(fn); +} + +void KernelGenerator::visit(const ir::train::operation::FullyConnected &node) +{ + using ir::train::operation::FullyConnected; + + const auto out_index{node.getOutputs().at(0)}; + const auto in_index{node.getInputs().at(FullyConnected::Input::INPUT)}; + const auto weights_index{node.getInputs().at(FullyConnected::Input::WEIGHT)}; + const auto bias_index{node.getInputs().at(FullyConnected::Input::BIAS)}; + + auto out_tensor = _tensor_reg->getPortableTensor(out_index); + auto in_tensor = _tensor_reg->getPortableTensor(in_index); + auto weights_tensor = _tensor_reg->getTrainableTensor(weights_index); + auto bias_tensor = _tensor_reg->getTrainableTensor(bias_index); + + auto out_deriv_tensor = _tensor_reg->getDerivativeTensor(out_index); + auto in_deriv_tensor = _tensor_reg->getDerivativeTensor(in_index); + auto weights_grad_tensor = _tensor_reg->getGradientTensor(weights_index); + auto bias_grad_tensor = _tensor_reg->getGradientTensor(bias_index); + + // Generate kernel + const auto activation = node.param().activation; + const auto weights_format = node.param().weights_format; + + auto fn = std::make_unique<ops::FullyConnectedLayer>(); + + fn->configure(in_tensor, weights_tensor, bias_tensor, out_tensor, in_deriv_tensor, + weights_grad_tensor, bias_grad_tensor, out_deriv_tensor, activation, weights_format, + _external_context); + + _return_fn = std::move(fn); + + // Generate GradientAppliers + if (bias_tensor) + _update_funcs.emplace_back(generateGradientApplier(_optimizer, bias_grad_tensor, bias_tensor)); + _update_funcs.emplace_back( + generateGradientApplier(_optimizer, weights_grad_tensor, weights_tensor)); +} + +void KernelGenerator::visit(const ir::train::operation::Loss &node) +{ + using ir::train::operation::Loss; + + const auto output_index{node.getOutputs().at(0)}; + const auto y_pred_index{node.getInputs().at(Loss::Y_PRED)}; + const auto y_true_index{node.getInputs().at(Loss::Y_TRUE)}; + + auto output_tensor = _tensor_reg->getPortableTensor(output_index); + auto y_pred_tensor = _tensor_reg->getPortableTensor(y_pred_index); + auto y_true_tensor = _tensor_reg->getPortableTensor(y_true_index); + + auto deriv_y_pred_tensor = _tensor_reg->getDerivativeTensor(y_pred_index); + auto fn = std::make_unique<ops::LossLayer>(); + + fn->configure(y_pred_tensor, y_true_tensor, output_tensor, deriv_y_pred_tensor, + convertLossType(node.param().op_type)); + + _return_fn = std::move(fn); + + UNUSED_RELEASE(convertPoolType); +} + +void KernelGenerator::visit(const ir::train::operation::Reshape &node) +{ + using ir::train::operation::Reshape; + + const 
auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(ir::operation::Reshape::Input::INPUT)}; + + auto output_tensor = _tensor_reg->getPortableTensor(output_index); + auto input_tensor = _tensor_reg->getPortableTensor(input_index); + + auto output_deriv_tensor = _tensor_reg->getDerivativeTensor(output_index); + auto input_deriv_tensor = _tensor_reg->getDerivativeTensor(input_index); + + // optional 2nd input + IPortableTensor *shape_tensor = nullptr; + + if (node.getInputs().size() == 2) + { + const auto shape_index{node.getInputs().at(ir::operation::Reshape::Input::SHAPE)}; + shape_tensor = _tensor_reg->getPortableTensor(shape_index); + } + + auto fn = std::make_unique<ops::ReshapeLayer>(); + + fn->configure(input_tensor, shape_tensor, output_tensor, input_deriv_tensor, output_deriv_tensor); + _return_fn = std::move(fn); +} + +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/KernelGenerator.h b/runtime/onert/backend/train/KernelGenerator.h new file mode 100644 index 000000000..660dc5d70 --- /dev/null +++ b/runtime/onert/backend/train/KernelGenerator.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_TRAIN_KERNEL_GENERATOR_H__ +#define __ONERT_BACKEND_TRAIN_KERNEL_GENERATOR_H__ + +#include "ExternalContext.h" +#include "backend/basic/TensorRegistry.h" +#include "TensorBuilder.h" +#include "Tensor.h" + +#include <backend/train/KernelGeneratorBase.h> +#include <exec/train/IGradientApplier.h> +#include <exec/train/optimizer/Optimizer.h> +#include <ir/Operands.h> +#include <ir/Operations.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +// TODO Unify TensorRegistry +class KernelGenerator : public backend::train::KernelGeneratorBase +{ +public: + KernelGenerator(const ir::train::TrainableGraph &tgraph, + const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::shared_ptr<ExternalContext> &external_context, + std::shared_ptr<exec::train::optimizer::Optimizer> optimizer); + + std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex op_ind) override; + + void visit(const ir::train::operation::Conv2D &) override; + void visit(const ir::train::operation::ElementwiseActivation &) override; + void visit(const ir::train::operation::FullyConnected &) override; + void visit(const ir::train::operation::Loss &) override; + void visit(const ir::train::operation::Reshape &node) override; + +private: + ir::Layout _current_layout; + std::shared_ptr<TensorRegistry> _tensor_reg; + const std::shared_ptr<ExternalContext> _external_context; + std::shared_ptr<exec::train::optimizer::Optimizer> _optimizer; + std::vector<std::unique_ptr<exec::train::IGradientApplier>> _update_funcs; +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_KERNEL_GENERATOR_H__ diff --git a/runtime/onert/backend/train/MemoryManager.h b/runtime/onert/backend/train/MemoryManager.h new file mode 100644 index 000000000..6ac57996f --- /dev/null +++ b/runtime/onert/backend/train/MemoryManager.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_MEMORY_MANAGER_H__ +#define __ONERT_BACKEND_TRAIN_MEMORY_MANAGER_H__ + +#include <backend/basic/MemoryManager.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +using MemoryManager = backend::basic::MemoryManager; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_MEMORY_MANAGER_H__ diff --git a/runtime/onert/backend/train/Tensor.h b/runtime/onert/backend/train/Tensor.h new file mode 100644 index 000000000..34a3cc191 --- /dev/null +++ b/runtime/onert/backend/train/Tensor.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_TENSOR_H__ +#define __ONERT_BACKEND_TRAIN_TENSOR_H__ + +#include <backend/basic/Tensor.h> +#include <backend/basic/train/TrainableTensor.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +// NOTE This class can be replaced with basic::Tensor if this backend supports dynamic tensors. +class Tensor : public basic::Tensor +{ +public: + Tensor() = delete; + +public: + Tensor(const ir::OperandInfo &info, const ir::Layout layout) + : basic::Tensor{info, layout, nullptr} + { + // DO NOTHING + } + +public: + bool applyShape(const ir::Shape &) override { return false; } +}; + +using TrainableTensor = basic::train::TrainableTensor; +using DerivativeTensor = Tensor; +using GradientTensor = Tensor; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_TENSOR_H__ diff --git a/runtime/onert/backend/train/TensorBuilder.cc b/runtime/onert/backend/train/TensorBuilder.cc new file mode 100644 index 000000000..99e06d3a4 --- /dev/null +++ b/runtime/onert/backend/train/TensorBuilder.cc @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TensorBuilder.h" + +#include "Tensor.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ + +TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::string planner_id) + : _tensor_reg{tensor_reg}, _tensor_mgr{new TensorManager(tensor_reg, planner_id)} +{ + /* empty */ +} + +void TensorBuilder::registerTensorInfo(const ir::OperandIndex &index, const ir::OperandInfo &info, + ir::Layout layout) +{ + _tensor_info_map.emplace(index, info); + _as_constants[index] = info.isConstant(); + + // Train backend supports only one layout as NHWC + assert(layout == ir::Layout::NHWC); + assert(!info.isDynamic()); + + // NOTE For now, whether or not to build operands to trainable tensor depends on whether + // the corresponding operand is constant. 
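+  // e.g. constant operands (weights, bias) become TrainableTensor; non-constant ones become Tensor.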
+ if (_as_constants[index]) + { + auto tensor = std::make_unique<TrainableTensor>(info, layout); + _tensor_reg->setTrainableTensor(index, std::move(tensor)); + } + else + { + auto tensor = std::make_unique<Tensor>(info, layout); + _tensor_reg->setNonConstTensor(index, std::move(tensor)); + } +} + +void TensorBuilder::registerBackwardTensorInfo(const ir::OperandIndex &index, + const ir::OperandInfo &info, ir::Layout layout) +{ + _backward_tensor_info_map.emplace(index, info); + + // Train backend supports only one layout as NHWC + assert(layout == ir::Layout::NHWC); + assert(!info.isDynamic()); + + // NOTE For now, whether or not to build operands to trainable tensor depends on whether + // the corresponding operand is constant. + assert(_as_constants[index] == info.isConstant()); + if (_as_constants[index]) + { + auto tensor = std::make_unique<GradientTensor>(info, layout); + _tensor_reg->setGradientTensor(index, std::move(tensor)); + } + else + { + auto tensor = std::make_unique<DerivativeTensor>(info, layout); + _tensor_reg->setDerivativeTensor(index, std::move(tensor)); + } +} + +void TensorBuilder::notifyFirstUse(const ir::OperandIndex &index) +{ + // TODO Support memory plan + if (_as_constants[index]) + { + _tensor_mgr->claimTrainablePlan(index); + } + else + { + _tensor_mgr->claimNonConstPlan(index); + } +} + +void TensorBuilder::notifyLastUse(const ir::OperandIndex &) +{ + // TODO Support memory plan +} + +void TensorBuilder::notifyBackwardFirstUse(const ir::OperandIndex &index) +{ + // TODO Support memory plan + if (_as_constants[index]) + { + _tensor_mgr->claimGradientPlan(index); + } + else + { + _tensor_mgr->claimDerivativePlan(index); + } +} + +bool TensorBuilder::isRegistered(const ir::OperandIndex &index) const +{ + return _tensor_info_map.find(index) != _tensor_info_map.end(); +} + +bool TensorBuilder::isRegisteredBackward(const ir::OperandIndex &index) const +{ + return _backward_tensor_info_map.find(index) != _backward_tensor_info_map.end(); +} + +void TensorBuilder::allocate(void) +{ + _tensor_mgr->allocateNonConstTensors(); + _tensor_mgr->allocateTrainableTensors(); +} + +void TensorBuilder::allocateBackward(void) +{ + _tensor_mgr->allocateDerivativeTensors(); + _tensor_mgr->allocateGradientTensors(); +} + +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/TensorBuilder.h b/runtime/onert/backend/train/TensorBuilder.h new file mode 100644 index 000000000..d0738fe68 --- /dev/null +++ b/runtime/onert/backend/train/TensorBuilder.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_TRAIN_TENSOR_BUILDER_H__ +#define __ONERT_BACKEND_TRAIN_TENSOR_BUILDER_H__ + +#include "TensorManager.h" +#include "TensorRegistry.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ + +// TODO Support dynamic tensors +class TensorBuilder +{ +public: + TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg, const std::string planner_id); + + /** + * @brief Register tensor information to allocate on train backend + * @param[in] ind Operand index + * @param[in] info Operand information + * @param[in] backend_layout Operand data layout + */ + void registerTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info, + ir::Layout backend_layout); + + /** + * @brief Register information of tensors used only in backward to allocate on train backend + * @param[in] ind Operand index + * @param[in] info Operand information + * @param[in] backend_layout Operand data layout + */ + void registerBackwardTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info, + ir::Layout backend_layout); + + // TODO Support memory plan of all tensors + void notifyFirstUse(const ir::OperandIndex &); + void notifyLastUse(const ir::OperandIndex &); + void notifyBackwardFirstUse(const ir::OperandIndex &); + + bool isRegistered(const ir::OperandIndex &) const; + bool isRegisteredBackward(const ir::OperandIndex &) const; + + void allocate(void); + void allocateBackward(void); + +private: + const std::shared_ptr<TensorRegistry> _tensor_reg; + std::unique_ptr<TensorManager> _tensor_mgr; + ir::OperandIndexMap<ir::OperandInfo> _tensor_info_map; + ir::OperandIndexMap<ir::OperandInfo> _backward_tensor_info_map; + ir::OperandIndexMap<bool> _as_constants; +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_TENSOR_BUILDER_H__ diff --git a/runtime/onert/backend/train/TensorManager.cc b/runtime/onert/backend/train/TensorManager.cc new file mode 100644 index 000000000..50144a78f --- /dev/null +++ b/runtime/onert/backend/train/TensorManager.cc @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TensorManager.h" + +#include <util/logging.h> + +namespace +{ + +using namespace onert; + +template <typename Tensor> +void allocateMemory(backend::train::MemoryManager *mgr, + const ir::OperandIndexMap<std::unique_ptr<Tensor>> &tensors, + const std::string tensor_type) +{ + mgr->allocate(); + + for (auto &&pair : tensors) + { + const auto &index = pair.first; + auto tensor = pair.second.get(); + assert(!tensor->is_dynamic()); + + auto *buffer = mgr->getBuffer(index); + tensor->setBuffer(buffer); + VERBOSE(TensorManager) << tensor_type << index << " : " << static_cast<void *>(buffer) + << std::endl; + } +} + +} // namespace + +namespace onert +{ +namespace backend +{ +namespace train +{ + +TensorManager::TensorManager(const std::shared_ptr<TensorRegistry> ®, + const std::string planner_id) + : _nonconst_mgr{new MemoryManager(planner_id)}, _trainable_mgr{new MemoryManager(planner_id)}, + _derivative_mgr{new MemoryManager(planner_id)}, + _gradient_mgr{new MemoryManager(planner_id)}, _tensors{reg} +{ + // DO NOTHING +} + +void TensorManager::allocateNonConstTensors() +{ + allocateMemory(_nonconst_mgr.get(), _tensors->nonconst_tensors(), + std::string{" TENSOR "}); +} + +void TensorManager::allocateTrainableTensors() +{ + allocateMemory(_trainable_mgr.get(), _tensors->trainable_tensors(), + std::string{"TRAINABLE TENSOR "}); +} + +void TensorManager::allocateDerivativeTensors() +{ + allocateMemory(_derivative_mgr.get(), _tensors->derivative_tensors(), + std::string{"DERIVATIVE TENSOR "}); +} + +void TensorManager::allocateGradientTensors() +{ + allocateMemory(_gradient_mgr.get(), _tensors->gradient_tensors(), + std::string{"GRADIENT TENSOR "}); +} + +void TensorManager::claimNonConstPlan(const ir::OperandIndex &index) +{ + auto tensor = _tensors->getNonConstTensor(index); + assert(tensor && !tensor->is_dynamic()); + + auto size = tensor->total_size(); + _nonconst_mgr->claimPlan(index, size); +} + +void TensorManager::releaseNonConstPlan(const ir::OperandIndex &index) +{ + assert(_tensors->getNonConstTensor(index) && !_tensors->getNonConstTensor(index)->is_dynamic()); + + _nonconst_mgr->releasePlan(index); +} + +void TensorManager::claimTrainablePlan(const ir::OperandIndex &index) +{ + auto tensor = _tensors->getTrainableTensor(index); + assert(tensor && !tensor->is_dynamic()); + + auto size = tensor->total_size(); + _trainable_mgr->claimPlan(index, size); +} + +void TensorManager::releaseTrainablePlan(const ir::OperandIndex &index) +{ + assert(_tensors->getTrainableTensor(index) && !_tensors->getTrainableTensor(index)->is_dynamic()); + + _trainable_mgr->releasePlan(index); +} + +void TensorManager::claimDerivativePlan(const ir::OperandIndex &index) +{ + auto tensor = _tensors->getDerivativeTensor(index); + assert(tensor && !tensor->is_dynamic()); + + auto size = tensor->total_size(); + _derivative_mgr->claimPlan(index, size); +} + +void TensorManager::releaseDerivativePlan(const ir::OperandIndex &index) +{ + assert(_tensors->getDerivativeTensor(index) && + !_tensors->getDerivativeTensor(index)->is_dynamic()); + + _derivative_mgr->releasePlan(index); +} + +void TensorManager::claimGradientPlan(const ir::OperandIndex &index) +{ + auto tensor = _tensors->getGradientTensor(index); + assert(tensor && !tensor->is_dynamic()); + + auto size = tensor->total_size(); + _gradient_mgr->claimPlan(index, size); +} + +void TensorManager::releaseGradientPlan(const ir::OperandIndex &index) +{ + assert(_tensors->getGradientTensor(index) && !_tensors->getGradientTensor(index)->is_dynamic()); + + 
_gradient_mgr->releasePlan(index); +} + +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/TensorManager.h b/runtime/onert/backend/train/TensorManager.h new file mode 100644 index 000000000..06da3edd7 --- /dev/null +++ b/runtime/onert/backend/train/TensorManager.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_TENSOR_MANAGER_H__ +#define __ONERT_BACKEND_TRAIN_TENSOR_MANAGER_H__ + +#include "MemoryManager.h" +#include "TensorRegistry.h" + +#include <ir/OperandIndexMap.h> +#include <ir/OperandInfo.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +class TensorManager +{ +public: + TensorManager(const std::shared_ptr<TensorRegistry> ®, const std::string planner_id); + virtual ~TensorManager() = default; + + void allocateNonConstTensors(); + void allocateTrainableTensors(); + void allocateDerivativeTensors(); + void allocateGradientTensors(); + // TODO Add member functions to deallocate tensors + + void claimNonConstPlan(const ir::OperandIndex &ind); + void releaseNonConstPlan(const ir::OperandIndex &ind); + void claimTrainablePlan(const ir::OperandIndex &ind); + void releaseTrainablePlan(const ir::OperandIndex &ind); + void claimDerivativePlan(const ir::OperandIndex &ind); + void releaseDerivativePlan(const ir::OperandIndex &ind); + void claimGradientPlan(const ir::OperandIndex &ind); + void releaseGradientPlan(const ir::OperandIndex &ind); + +private: + std::unique_ptr<MemoryManager> _nonconst_mgr; + std::unique_ptr<MemoryManager> _trainable_mgr; + std::unique_ptr<MemoryManager> _derivative_mgr; + std::unique_ptr<MemoryManager> _gradient_mgr; + const std::shared_ptr<TensorRegistry> _tensors; +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_TENSOR_MANAGER_H__ diff --git a/runtime/onert/backend/train/TensorRegistry.h b/runtime/onert/backend/train/TensorRegistry.h new file mode 100644 index 000000000..34aeb0fcd --- /dev/null +++ b/runtime/onert/backend/train/TensorRegistry.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_TRAIN_TENSOR_REGISTRY__ +#define __ONERT_BACKEND_TRAIN_TENSOR_REGISTRY__ + +#include <backend/train/ITensorRegistry.h> + +#include "Tensor.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ + +using TensorRegistry = + PortableTensorRegistryTemplate<Tensor, TrainableTensor, DerivativeTensor, GradientTensor>; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_TENSOR_REGISTRY__ diff --git a/runtime/onert/backend/train/ops/ConvolutionLayer.cc b/runtime/onert/backend/train/ops/ConvolutionLayer.cc new file mode 100644 index 000000000..ac736c34d --- /dev/null +++ b/runtime/onert/backend/train/ops/ConvolutionLayer.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConvolutionLayer.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ +ConvolutionLayer::ConvolutionLayer() : cpu::ops::ConvolutionLayer() +{ + // DO NOTHING +} + +ConvolutionLayer::~ConvolutionLayer() = default; + +void ConvolutionLayer::configure(const IPortableTensor *input, const IPortableTensor *kernel, + const IPortableTensor *bias, const ir::PaddingType paddingType, + const uint32_t paddingLeft, const uint32_t paddingRight, + const uint32_t paddingTop, const uint32_t paddingBottom, + const uint32_t strideWidth, const uint32_t strideHeight, + const uint32_t dilationWidthFactor, + const uint32_t dilationHeightFactor, + const ir::Activation activation, IPortableTensor *output) +{ + cpu::ops::ConvolutionLayer::configure( + input, kernel, bias, paddingType, paddingLeft, paddingRight, paddingTop, paddingBottom, + strideWidth, strideHeight, dilationWidthFactor, dilationHeightFactor, activation, output); +} + +void ConvolutionLayer::forward(bool) { cpu::ops::ConvolutionLayer::run(); } +void ConvolutionLayer::backward() +{ + // TODO Implement detail +} + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/ops/ConvolutionLayer.h b/runtime/onert/backend/train/ops/ConvolutionLayer.h new file mode 100644 index 000000000..ed42a2099 --- /dev/null +++ b/runtime/onert/backend/train/ops/ConvolutionLayer.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_TRAIN_OPS_CONVOLUTIONLAYER_H__ +#define __ONERT_BACKEND_TRAIN_OPS_CONVOLUTIONLAYER_H__ + +#include <ops/ConvolutionLayer.h> + +#include <exec/train/ITrainableFunction.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +class ConvolutionLayer : public ::onert::exec::train::ITrainableFunction, + public cpu::ops::ConvolutionLayer +{ +public: + ConvolutionLayer(); + ~ConvolutionLayer(); + + void configure(const IPortableTensor *input, const IPortableTensor *kernel, + const IPortableTensor *bias, ir::PaddingType _paddingType, + const uint32_t paddingLeft, const uint32_t paddingRight, const uint32_t paddingTop, + const uint32_t paddingBottom, const uint32_t strideWidth, + const uint32_t strideHeight, const uint32_t dilationWidthFactor, + const uint32_t dilationHeightFactor, const ir::Activation activation, + IPortableTensor *output); + void forward(bool training) override; + void backward() override; +}; + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_OPS_CONVOLUTIONLAYER_H__ diff --git a/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc new file mode 100644 index 000000000..860eca43c --- /dev/null +++ b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ElementwiseActivationLayer.h" + +#include "OperationUtils.h" + +#include <cker/train/operation/ReLU.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +ElementwiseActivationLayer::ElementwiseActivationLayer() : cpu::ops::ElementwiseActivationLayer() +{ + // DO NOTHING +} + +void ElementwiseActivationLayer::configure(const IPortableTensor *input, IPortableTensor *output, + IPortableTensor *deriv_input, + const IPortableTensor *deriv_output, float alpha, + float beta, ElementwiseActivationType op_type) +{ + assert(input != nullptr); + assert(output != nullptr); + assert(deriv_input != nullptr); + assert(deriv_output != nullptr); + + _deriv_input = deriv_input; + _deriv_output = deriv_output; + + _op_type = op_type; + + switch (op_type) + { + case ElementwiseActivationType::kReLU: + if (input->data_type() == OperandType::FLOAT32) + { + if (alpha == std::numeric_limits<float>::infinity() && beta == 0.f) + { + cpu::ops::ElementwiseActivationLayer::configure( + input, output, alpha, beta, cpu::ops::ElementwiseActivationType::kReLU); + + _backward_kernel = [](const IPortableTensor *output, const IPortableTensor *incoming, + IPortableTensor *outgoing) { + nnfw::cker::train::ReLUGrad(getShape(output), getBuffer<float>(output), + getShape(incoming), getBuffer<float>(incoming), + getShape(outgoing), getBuffer<float>(outgoing)); + }; + } + else + { + throw std::runtime_error("train ElementwiseActivationLayer : This layer does not " + "support ReLU other than ReLU(0-inf)"); + } + } + else + { + throw std::runtime_error("train ElementwiseActivationLayer: Unsupported datatype"); + } + break; + default: + throw std::runtime_error("train ElementwiseActivationLayer: Unsupported activation type yet"); + } +} + +void ElementwiseActivationLayer::forward(bool) { cpu::ops::ElementwiseActivationLayer::run(); } + +void ElementwiseActivationLayer::backward() +{ + _backward_kernel(_output, _deriv_output, _deriv_input); +} + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/ops/ElementwiseActivationLayer.h b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.h new file mode 100644 index 000000000..dac1efe92 --- /dev/null +++ b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/runtime/onert/backend/train/ops/ElementwiseActivationLayer.h b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.h
new file mode 100644
index 000000000..dac1efe92
--- /dev/null
+++ b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_TRAIN_OPS_ELEMENTWISEACTIVATIONLAYER_H__
+#define __ONERT_BACKEND_TRAIN_OPS_ELEMENTWISEACTIVATIONLAYER_H__
+
+#include <backend/IPortableTensor.h>
+#include <ops/ElementwiseActivationLayer.h>
+
+#include <exec/train/ITrainableFunction.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace train
+{
+namespace ops
+{
+
+enum class ElementwiseActivationType
+{
+  kReLU,
+};
+
+class ElementwiseActivationLayer : public ::onert::exec::train::ITrainableFunction,
+                                   public cpu::ops::ElementwiseActivationLayer
+{
+public:
+  ElementwiseActivationLayer();
+
+  void configure(const IPortableTensor *input, IPortableTensor *output,
+                 IPortableTensor *deriv_input, const IPortableTensor *deriv_output, float alpha,
+                 float beta, ElementwiseActivationType op_type);
+  void forward(bool training) override;
+  void backward() override;
+
+private:
+  IPortableTensor *_deriv_input;
+  const IPortableTensor *_deriv_output;
+
+  ElementwiseActivationType _op_type;
+  std::function<void(const IPortableTensor *output, const IPortableTensor *incoming,
+                     IPortableTensor *outgoing)>
+    _backward_kernel;
+};
+
+} // namespace ops
+} // namespace train
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_TRAIN_OPS_ELEMENTWISEACTIVATIONLAYER_H__
diff --git a/runtime/onert/backend/train/ops/FullyConnectedLayer.cc b/runtime/onert/backend/train/ops/FullyConnectedLayer.cc
new file mode 100644
index 000000000..8fdc822d2
--- /dev/null
+++ b/runtime/onert/backend/train/ops/FullyConnectedLayer.cc
@@ -0,0 +1,209 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FullyConnectedLayer.h"
+
+#include "OperationUtils.h"
+
+#include <cker/operation/FullyConnected.h>
+#include <cker/operation/Transpose.h>
+#include <cker/train/operation/FullyConnected.h>
+#include <cker/train/operation/ReLU.h>
+
+namespace
+{
+
+using namespace onert;
+
+std::unique_ptr<backend::train::Tensor>
+createTransposedTensor(const backend::IPortableTensor *origin_tensor)
+{
+  const auto &origin_shape = origin_tensor->getShape();
+  assert(origin_shape.rank() == 2);
+
+  auto transposed_info = origin_tensor->get_info();
+  auto transposed_shape = ir::Shape{origin_shape.dim(1), origin_shape.dim(0)};
+  transposed_info.shape(transposed_shape);
+
+  return std::make_unique<backend::train::Tensor>(transposed_info, origin_tensor->layout());
+}
+
+} // namespace
+
+namespace onert
+{
+namespace backend
+{
+namespace train
+{
+namespace ops
+{
+
+FullyConnectedLayer::FullyConnectedLayer()
+  : cpu::ops::FullyConnectedLayer{}, _grad_weights{nullptr}, _grad_bias{nullptr},
+    _deriv_input{nullptr}, _deriv_output{nullptr}, _transposed_weights{nullptr},
+    _transposed_input{nullptr}, _transposed_deriv_output{nullptr}, _act_deriv_output{nullptr}
+{
+  // DO NOTHING
+}
+
+FullyConnectedLayer::~FullyConnectedLayer() = default;
+
+void FullyConnectedLayer::configure(const IPortableTensor *input, const IPortableTensor *weights,
+                                    const IPortableTensor *bias, IPortableTensor *output,
+                                    IPortableTensor *deriv_input, IPortableTensor *grad_weights,
+                                    IPortableTensor *grad_bias, const IPortableTensor *deriv_output,
+                                    ir::Activation activation,
+                                    ir::FullyConnectedWeightsFormat weights_format,
+                                    const std::shared_ptr<train::ExternalContext> &external_context)
+{
+  cpu::ops::FullyConnectedLayer::configure(input, weights, bias, activation, weights_format, output,
+                                           external_context);
+
+  _deriv_input = deriv_input;
+  _grad_weights = grad_weights;
+  _grad_bias = grad_bias;
+  _deriv_output = deriv_output;
+
+  if (weights_format != ir::FullyConnectedWeightsFormat::Default)
+    throw std::runtime_error{
+      "train FullyConnectedLayer: Weight formats other than default are not supported."};
+
+  if (input->get_info().shape().rank() != 2 || weights->get_info().shape().rank() != 2 ||
+      output->get_info().shape().rank() != 2 || deriv_input->get_info().shape().rank() != 2 ||
+      grad_weights->get_info().shape().rank() != 2 || deriv_output->get_info().shape().rank() != 2)
+    throw std::runtime_error{
+      "train FullyConnectedLayer: Tensors with ranks other than 2 are not supported."};
+
+  _transposed_weights = createTransposedTensor(weights);
+  _transposed_weights->setBuffer(std::make_shared<basic::Allocator>(weights->total_size()));
+
+  _transposed_input = createTransposedTensor(input);
+  _transposed_input->setBuffer(std::make_shared<basic::Allocator>(input->total_size()));
+
+  _transposed_deriv_output = createTransposedTensor(deriv_output);
+  _transposed_deriv_output->setBuffer(
+    std::make_shared<basic::Allocator>(deriv_output->total_size()));
+
+  if (activation != ir::Activation::NONE)
+  {
+    _act_deriv_output =
+      std::make_unique<Tensor>(_deriv_output->get_info(), _deriv_output->layout());
+    _act_deriv_output->setBuffer(std::make_shared<basic::Allocator>(_deriv_output->total_size()));
+  }
+}
+
+void FullyConnectedLayer::forward(bool) { cpu::ops::FullyConnectedLayer::run(); }
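// A hedged reference sketch of the gradients that backward() below computes
// for y = x * W^T + b, with x:[N,I], W:[O,I] and incoming gradient dY:[N,O]:
//   dL/dx = dY * W, dL/dW = dY^T * x, dL/db = column sum of dY.
// The implementation expresses both matrix products through cker's
// FullyConnected kernel (which computes lhs * rhs^T), transposing one operand
// first, hence the scratch tensors allocated in configure(). Naive loop
// equivalent, for illustration only:
void fc_backward_sketch(const float *x, const float *W, const float *dY, float *dX, float *dW,
                        float *db, int N, int I, int O)
{
  for (int n = 0; n < N; ++n) // dX = dY * W
    for (int i = 0; i < I; ++i)
    {
      float s = 0.f;
      for (int o = 0; o < O; ++o)
        s += dY[n * O + o] * W[o * I + i];
      dX[n * I + i] = s;
    }
  for (int o = 0; o < O; ++o) // dW = dY^T * x
    for (int i = 0; i < I; ++i)
    {
      float s = 0.f;
      for (int n = 0; n < N; ++n)
        s += dY[n * O + o] * x[n * I + i];
      dW[o * I + i] = s;
    }
  for (int o = 0; o < O; ++o) // db = column sum of dY (when bias exists)
  {
    float s = 0.f;
    for (int n = 0; n < N; ++n)
      s += dY[n * O + o];
    db[o] = s;
  }
}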
+
+void FullyConnectedLayer::backward()
+{
+  const auto data_type = _deriv_output->data_type();
+  assert(data_type == _input->data_type());
+  switch (data_type)
+  {
+    case OperandType::FLOAT32:
+    {
+      assert(data_type == _grad_weights->data_type());
+      assert(data_type == _grad_bias->data_type());
+      backwardFloat32();
+      break;
+    }
+    default:
+      throw std::runtime_error{"train FullyConnectedLayer: unsupported data type"};
+  }
+}
+
+void FullyConnectedLayer::backwardFloat32()
+{
+  // Calculate gradient for activation
+  const IPortableTensor *backprop_act;
+  switch (_activation)
+  {
+    case ir::Activation::NONE:
+      backprop_act = _deriv_output;
+      break;
+    case ir::Activation::RELU:
+      nnfw::cker::train::ReLUGrad(getShape(_output), getBuffer<float>(_output),
+                                  getShape(_deriv_output), getBuffer<float>(_deriv_output),
+                                  getShape(_act_deriv_output.get()),
+                                  getBuffer<float>(_act_deriv_output.get()));
+      backprop_act = _act_deriv_output.get();
+      break;
+    default:
+      throw std::runtime_error("train FullyConnectedLayer: Unsupported activation type");
+  }
+
+  // Initialize TransposeParams
+  nnfw::cker::TransposeParams transpose_param;
+  transpose_param.perm_count = 2;
+  transpose_param.perm[0] = 1;
+  transpose_param.perm[1] = 0;
+
+  // Initialize FullyConnectedParams
+  nnfw::cker::FullyConnectedParams op_params;
+  float output_activation_min = 0;
+  float output_activation_max = 0;
+  CalculateActivationRange(ir::Activation::NONE, &output_activation_min, &output_activation_max);
+  op_params.activation = nnfw::cker::FusedActivationFunctionType::kNone;
+  op_params.float_activation_min = output_activation_min;
+  op_params.float_activation_max = output_activation_max;
+  op_params.lhs_cacheable = false;
+  op_params.rhs_cacheable = false;
+
+  // Transpose and compute gradient for input
+  // ∂L/∂X = fc(Incoming gradient, transposed W)
+  auto transposed_weights = _transposed_weights.get();
+  assert(transposed_weights->getShape().rank() == 2);
+  nnfw::cker::Transpose(transpose_param, getShape(_weights), getBuffer<float>(_weights),
+                        getShape(transposed_weights), getBuffer<float>(transposed_weights));
+
+  nnfw::cker::FullyConnected(op_params, getShape(backprop_act), getBuffer<float>(backprop_act),
+                             getShape(transposed_weights), getBuffer<float>(transposed_weights),
+                             getShape(nullptr), nullptr, getShape(_deriv_input),
+                             getBuffer<float>(_deriv_input));
+
+  // Transpose and compute gradient for weights
+  // ∂L/∂W = fc(transposed incoming gradient, transposed X)
+  auto transposed_input = _transposed_input.get();
+  assert(transposed_input->getShape().rank() == 2);
+  nnfw::cker::Transpose(transpose_param, getShape(_input), getBuffer<float>(_input),
+                        getShape(transposed_input), getBuffer<float>(transposed_input));
+
+  auto transposed_deriv_output = _transposed_deriv_output.get();
+  assert(transposed_deriv_output->getShape().rank() == 2);
+  nnfw::cker::Transpose(transpose_param, getShape(backprop_act), getBuffer<float>(backprop_act),
+                        getShape(transposed_deriv_output),
+                        getBuffer<float>(transposed_deriv_output));
+
+  nnfw::cker::FullyConnected(op_params, getShape(transposed_deriv_output),
+                             getBuffer<float>(transposed_deriv_output), getShape(transposed_input),
+                             getBuffer<float>(transposed_input), getShape(nullptr), nullptr,
+                             getShape(_grad_weights), getBuffer<float>(_grad_weights));
+
+  // Compute gradient for bias
+  if (_bias)
+  {
+    assert(_grad_bias);
+    nnfw::cker::train::FullyConnectedBiasGrad(getShape(backprop_act),
+                                              getBuffer<float>(backprop_act), getShape(_grad_bias),
+                                              getBuffer<float>(_grad_bias));
+  }
+}
+
+} // namespace ops
+} // namespace train
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/backend/train/ops/FullyConnectedLayer.h b/runtime/onert/backend/train/ops/FullyConnectedLayer.h new
file mode 100644 index 000000000..1d9b30a23 --- /dev/null +++ b/runtime/onert/backend/train/ops/FullyConnectedLayer.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_OPS_FULLYCONNECTEDLAYER_H__ +#define __ONERT_BACKEND_TRAIN_OPS_FULLYCONNECTEDLAYER_H__ + +#include "../ExternalContext.h" +#include "../Tensor.h" + +#include <exec/train/ITrainableFunction.h> +#include <ops/FullyConnectedLayer.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +class FullyConnectedLayer : public exec::train::ITrainableFunction, + public cpu::ops::FullyConnectedLayer +{ +public: + FullyConnectedLayer(); + ~FullyConnectedLayer(); + +public: + void configure(const IPortableTensor *input, const IPortableTensor *weights, + const IPortableTensor *bias, IPortableTensor *output, IPortableTensor *deriv_input, + IPortableTensor *grad_weights, IPortableTensor *grad_bias, + const IPortableTensor *deriv_output, ir::Activation activation, + ir::FullyConnectedWeightsFormat weights_format, + const std::shared_ptr<train::ExternalContext> &external_context); + + void forward(bool training) override; + void backward() override; + +private: + void backwardFloat32(); + +private: + IPortableTensor *_grad_weights; + IPortableTensor *_grad_bias; + IPortableTensor *_deriv_input; + const IPortableTensor *_deriv_output; + + // TODO Optimize memory + std::unique_ptr<Tensor> _transposed_weights; + std::unique_ptr<Tensor> _transposed_input; + std::unique_ptr<Tensor> _transposed_deriv_output; + std::unique_ptr<Tensor> _act_deriv_output; +}; + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_OPS_FULLYCONNECTEDLAYER_H__ diff --git a/runtime/onert/backend/train/ops/GradientApplier.cc b/runtime/onert/backend/train/ops/GradientApplier.cc new file mode 100644 index 000000000..90d1bb9d0 --- /dev/null +++ b/runtime/onert/backend/train/ops/GradientApplier.cc @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "GradientApplier.h" + +#include <exec/train/optimizer/Optimizer.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +GradientApplier::GradientApplier() : _optimizer{nullptr}, _gradient_tensor{}, _trainable_tensor{} +{ + // DO NOTHING +} + +void GradientApplier::configure(std::shared_ptr<exec::train::optimizer::Optimizer> optimizer, + const IPortableTensor *gradient, ITrainableTensor *trainable) +{ + _optimizer = optimizer; + _gradient_tensor = gradient; + _trainable_tensor = trainable; +} + +void GradientApplier::applyGradient(uint32_t training_step) +{ + _optimizer->applyGradient( + std::forward_as_tuple(*_gradient_tensor, *_trainable_tensor, training_step)); +} + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/ops/GradientApplier.h b/runtime/onert/backend/train/ops/GradientApplier.h new file mode 100644 index 000000000..94234e182 --- /dev/null +++ b/runtime/onert/backend/train/ops/GradientApplier.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_OPS_GRADIENT_APPLIER_H__ +#define __ONERT_BACKEND_TRAIN_OPS_GRADIENT_APPLIER_H__ + +#include <exec/train/IGradientApplier.h> + +#include <exec/train/optimizer/Optimizer.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +class GradientApplier : public ::onert::exec::train::IGradientApplier +{ +public: + GradientApplier(); + ~GradientApplier() = default; + + void configure(std::shared_ptr<exec::train::optimizer::Optimizer> optimizer, + const IPortableTensor *gradient, ITrainableTensor *trainable); + void applyGradient(uint32_t training_step) override; + +private: + std::shared_ptr<exec::train::optimizer::Optimizer> _optimizer; + const IPortableTensor *_gradient_tensor; + ITrainableTensor *_trainable_tensor; +}; + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_OPS_GRADIENT_APPLIER_H__ diff --git a/runtime/onert/backend/train/ops/LossLayer.cc b/runtime/onert/backend/train/ops/LossLayer.cc new file mode 100644 index 000000000..d004722a0 --- /dev/null +++ b/runtime/onert/backend/train/ops/LossLayer.cc @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "LossLayer.h"
+#include "OperationUtils.h"
+
+#include <cker/train/operation/Loss.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace train
+{
+namespace ops
+{
+
+LossLayer::LossLayer()
+  : _y_pred(nullptr), _y_true(nullptr), _output(nullptr), _deriv_y_pred(nullptr),
+    _loss_type(LossType::kMSE)
+{
+  // DO NOTHING
+}
+
+void LossLayer::configure(const IPortableTensor *y_pred, const IPortableTensor *y_true,
+                          IPortableTensor *output, IPortableTensor *deriv_y_pred,
+                          LossType loss_type)
+{
+  assert(y_pred != nullptr);
+  assert(y_true != nullptr);
+  assert(output != nullptr);
+  assert(deriv_y_pred != nullptr);
+  switch (loss_type)
+  {
+    case LossType::kMSE:
+      break;
+    default:
+      throw std::runtime_error("LossLayer: unsupported loss type");
+  }
+
+  _y_pred = y_pred;
+  _y_true = y_true;
+  _output = output;
+  _deriv_y_pred = deriv_y_pred;
+  _loss_type = loss_type;
+}
+
+void LossLayer::forward(bool)
+{
+  // TODO Support other loss types and data types
+  switch (_loss_type)
+  {
+    case LossType::kMSE:
+      if (_y_pred->data_type() == OperandType::FLOAT32)
+      {
+        nnfw::cker::train::MSE(getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true),
+                               getBuffer<float>(_y_true), getShape(_output),
+                               getBuffer<float>(_output));
+      }
+      break;
+    default:
+      throw std::runtime_error("LossLayer: unsupported loss type");
+  }
+}
+
+void LossLayer::backward()
+{
+  switch (_loss_type)
+  {
+    case LossType::kMSE:
+      if (_y_pred->data_type() == OperandType::FLOAT32)
+      {
+        nnfw::cker::train::MSEGrad(getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true),
+                                   getBuffer<float>(_y_true), getShape(_deriv_y_pred),
+                                   getBuffer<float>(_deriv_y_pred));
+      }
+      break;
+    default:
+      throw std::runtime_error("LossLayer: unsupported loss type");
+  }
+}
+
+} // namespace ops
+} // namespace train
+} // namespace backend
+} // namespace onert
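Spelled out, the two cker kernels used above compute, for element count N (a sketch of the math rather than the cker code): forward() produces loss = (1/N)·Σ(ŷ − y)² and backward() produces ∂loss/∂ŷ = 2(ŷ − y)/N. A naive scalar equivalent:

#include <cstddef>

// Hedged sketch of MSE and its gradient w.r.t. y_pred over flat float buffers;
// the real kernels are nnfw::cker::train::MSE and MSEGrad. The 1/N scaling of
// the gradient is an assumption of this sketch.
void mse_sketch(const float *y_pred, const float *y_true, std::size_t n, float *loss,
                float *deriv_y_pred)
{
  float sum = 0.f;
  for (std::size_t i = 0; i < n; ++i)
  {
    const float diff = y_pred[i] - y_true[i];
    sum += diff * diff;
    deriv_y_pred[i] = 2.f * diff / static_cast<float>(n);
  }
  *loss = sum / static_cast<float>(n);
}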
diff --git a/runtime/onert/backend/train/ops/LossLayer.h b/runtime/onert/backend/train/ops/LossLayer.h
new file mode 100644
index 000000000..18c6b315b
--- /dev/null
+++ b/runtime/onert/backend/train/ops/LossLayer.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_TRAIN_OPS_LOSSLAYER_H__
+#define __ONERT_BACKEND_TRAIN_OPS_LOSSLAYER_H__
+
+#include <backend/IPortableTensor.h>
+#include <ops/ElementwiseActivationLayer.h>
+
+#include <exec/train/ITrainableFunction.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace train
+{
+namespace ops
+{
+
+enum class LossType
+{
+  kMSE,
+};
+
+class LossLayer : public ::onert::exec::train::ITrainableFunction
+{
+public:
+  LossLayer();
+
+  void configure(const IPortableTensor *y_pred, const IPortableTensor *y_true,
+                 IPortableTensor *output, IPortableTensor *deriv_y_pred, LossType loss_type);
+  void forward(bool training) override;
+  void backward() override;
+
+private:
+  const IPortableTensor *_y_pred;
+  const IPortableTensor *_y_true;
+  IPortableTensor *_output;
+  IPortableTensor *_deriv_y_pred;
+  LossType _loss_type;
+};
+
+} // namespace ops
+} // namespace train
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_TRAIN_OPS_LOSSLAYER_H__
diff --git a/runtime/onert/backend/train/ops/OperationUtils.h b/runtime/onert/backend/train/ops/OperationUtils.h
new file mode 100644
index 000000000..fe0a02340
--- /dev/null
+++ b/runtime/onert/backend/train/ops/OperationUtils.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_TRAIN_OPS_OPERATION_UTILS_H__
+#define __ONERT_BACKEND_TRAIN_OPS_OPERATION_UTILS_H__
+
+#include <ops/OperationUtils.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace train
+{
+namespace ops
+{
+
+using OperandType = onert::ir::DataType;
+using cpu::ops::getBuffer;
+using cpu::ops::getShape;
+
+} // namespace ops
+} // namespace train
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_TRAIN_OPS_OPERATION_UTILS_H__
diff --git a/runtime/onert/backend/train/ops/PoolLayer.cc b/runtime/onert/backend/train/ops/PoolLayer.cc
new file mode 100644
index 000000000..c8a8422aa
--- /dev/null
+++ b/runtime/onert/backend/train/ops/PoolLayer.cc
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "PoolLayer.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +PoolLayer::PoolLayer() : cpu::ops::PoolLayer() +{ + // DO NOTHING +} + +void PoolLayer::configure(const IPortableTensor *input, const uint32_t paddingLeft, + const uint32_t paddingRight, const uint32_t paddingTop, + const uint32_t paddingBottom, const uint32_t strideWidth, + const uint32_t strideHeight, const uint32_t kernelWidth, + const uint32_t kernelHeight, const ir::Activation activation, + IPortableTensor *output, const PoolType op_type) +{ + switch (op_type) + { + case PoolType::kMax: + cpu::ops::PoolLayer::configure(input, paddingLeft, paddingRight, paddingTop, paddingBottom, + strideWidth, strideHeight, kernelWidth, kernelHeight, + activation, output, cpu::ops::PoolType::kMax); + break; + default: + throw std::runtime_error("PoolLayer: Unsupported pool type"); + } +} + +void PoolLayer::forward(bool training) +{ + if (training) + { + // TODO Implement training pool layer + } + else + { + cpu::ops::PoolLayer::run(); + } +} + +void PoolLayer::backward() +{ + // TODO Implement detail +} + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/ops/PoolLayer.h b/runtime/onert/backend/train/ops/PoolLayer.h new file mode 100644 index 000000000..7f93b4a97 --- /dev/null +++ b/runtime/onert/backend/train/ops/PoolLayer.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_OPS_POOLLAYER_H__ +#define __ONERT_BACKEND_TRAIN_OPS_POOLLAYER_H__ + +#include <ops/PoolLayer.h> + +#include <exec/train/ITrainableFunction.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +enum class PoolType +{ + kMax, +}; + +class PoolLayer : public ::onert::exec::train::ITrainableFunction, public cpu::ops::PoolLayer +{ +public: + PoolLayer(); + +public: + void configure(const IPortableTensor *input, const uint32_t paddingLeft, + const uint32_t paddingRight, const uint32_t paddingTop, + const uint32_t paddingBottom, const uint32_t strideWidth, + const uint32_t strideHeight, const uint32_t kernelWidth, + const uint32_t kernelHeight, const ir::Activation activation, + IPortableTensor *output, const PoolType op_type); + void forward(bool training) override; + void backward() override; +}; + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_OPS_POOLLAYER_H__ diff --git a/runtime/onert/backend/train/ops/ReshapeLayer.cc b/runtime/onert/backend/train/ops/ReshapeLayer.cc new file mode 100644 index 000000000..1716174a9 --- /dev/null +++ b/runtime/onert/backend/train/ops/ReshapeLayer.cc @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ReshapeLayer.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +ReshapeLayer::ReshapeLayer() + : _input{nullptr}, _shape{nullptr}, _output{nullptr}, _deriv_input{nullptr}, _deriv_output{ + nullptr} +{ + // DO NOTHING +} + +void ReshapeLayer::reshapeGeneric(const IPortableTensor *input, IPortableTensor *output) +{ + size_t count = input->total_size(); + memcpy(output->buffer(), input->buffer(), count); +} + +void ReshapeLayer::configure(const IPortableTensor *input, const IPortableTensor *shape, + IPortableTensor *output, IPortableTensor *deriv_input, + const IPortableTensor *deriv_output) +{ + _input = input; + /* note : shape is optional. If not provided from model, _shape is nullptr. */ + _shape = shape; + _output = output; + + _deriv_input = deriv_input; + _deriv_output = deriv_output; +} + +void ReshapeLayer::forward(bool) { reshapeGeneric(_input, _output); } + +void ReshapeLayer::backward() { reshapeGeneric(_deriv_output, _deriv_input); } + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert diff --git a/runtime/onert/backend/train/ops/ReshapeLayer.h b/runtime/onert/backend/train/ops/ReshapeLayer.h new file mode 100644 index 000000000..e4f017225 --- /dev/null +++ b/runtime/onert/backend/train/ops/ReshapeLayer.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_TRAIN_OPS_RESHAPELAYER_H__ +#define __ONERT_BACKEND_TRAIN_OPS_RESHAPELAYER_H__ + +#include <backend/IPortableTensor.h> + +#include <exec/train/ITrainableFunction.h> + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +class ReshapeLayer : public ::onert::exec::train::ITrainableFunction +{ +public: + ReshapeLayer(); + +public: + void configure(const IPortableTensor *input, const IPortableTensor *shape, + IPortableTensor *output, IPortableTensor *deriv_input, + const IPortableTensor *deriv_output); + void forward(bool training) override; + void backward() override; + +private: + void reshapeGeneric(const IPortableTensor *input, IPortableTensor *output); + +private: + const IPortableTensor *_input; + const IPortableTensor *_shape; + IPortableTensor *_output; + + IPortableTensor *_deriv_input; + const IPortableTensor *_deriv_output; +}; + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_OPS_RESHAPELAYER_H__ diff --git a/runtime/onert/backend/train/train.cc b/runtime/onert/backend/train/train.cc new file mode 100644 index 000000000..a77f71c43 --- /dev/null +++ b/runtime/onert/backend/train/train.cc @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "Backend.h" + +extern "C" { + +onert::backend::Backend *onert_backend_create() { return new onert::backend::train::Backend; } + +void onert_backend_destroy(onert::backend::Backend *backend) { delete backend; } +} diff --git a/runtime/onert/backend/trix/BackendContext.cc b/runtime/onert/backend/trix/BackendContext.cc index 39048f2be..51571b458 100644 --- a/runtime/onert/backend/trix/BackendContext.cc +++ b/runtime/onert/backend/trix/BackendContext.cc @@ -37,7 +37,7 @@ FunctionMap BackendContext::genKernels() { FunctionMap ret; - for (auto op_ind : _data.op_order) + for (auto &&op_ind : _data.op_order) { auto fn_seq = kernel_gen->generate(op_ind); ret.emplace_back(op_ind, std::move(fn_seq)); diff --git a/runtime/onert/backend/trix/Config.cc b/runtime/onert/backend/trix/Config.cc index c23326423..b536fd58c 100644 --- a/runtime/onert/backend/trix/Config.cc +++ b/runtime/onert/backend/trix/Config.cc @@ -25,7 +25,7 @@ namespace trix bool Config::initialize() { return true; } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout) { return ir::Layout::NHWC; } +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout) { return ir::Layout::NHWC; } } // namespace trix } // namespace backend diff --git a/runtime/onert/backend/trix/Config.h b/runtime/onert/backend/trix/Config.h index 799047d6f..310c57b29 100644 --- a/runtime/onert/backend/trix/Config.h +++ b/runtime/onert/backend/trix/Config.h @@ -33,7 +33,7 @@ class Config : public IConfig public: std::string id() override { return "trix"; } bool initialize() override; - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportPermutation() override { return true; } bool supportDynamicTensor() override { return false; } bool supportFP16() override { return false; } diff --git a/runtime/onert/backend/trix/DevContext.cc b/runtime/onert/backend/trix/DevContext.cc index 059514878..4d58a7d9f 100644 --- a/runtime/onert/backend/trix/DevContext.cc +++ b/runtime/onert/backend/trix/DevContext.cc @@ -71,7 +71,13 @@ DevContext::~DevContext() ModelID DevContext::registerModel(const std::string &model_file_path) { - auto meta = getNPUmodel_metadata(model_file_path.c_str(), false); + if (_dev_handles.size() == 0) + { + throw std::runtime_error("No npu device is available"); + } + + std::unique_ptr<npubin_meta, decltype(&free)> meta( + getNPUmodel_metadata(model_file_path.c_str(), false), free); if (meta == nullptr) { @@ -83,7 +89,7 @@ ModelID DevContext::registerModel(const std::string &model_file_path) file_info.filepath = model_file_path.c_str(); file_info.size = meta->size; - ModelID model_id; + ModelID model_id = 0; for (uint32_t dev_num = 0; dev_num < _dev_handles.size(); ++dev_num) { @@ -97,7 +103,7 @@ ModelID DevContext::registerModel(const std::string &model_file_path) if (dev_num == 0) { model_id = model_id_at_device; - _meta_map[model_id_at_device] = std::shared_ptr<npubin_meta>(meta); + _meta_map[model_id_at_device] = std::shared_ptr<npubin_meta>(std::move(meta)); } else { diff --git a/runtime/onert/backend/xnnpack/BackendContext.cc b/runtime/onert/backend/xnnpack/BackendContext.cc index c52e275aa..b555a4ac6 100644 --- a/runtime/onert/backend/xnnpack/BackendContext.cc +++ b/runtime/onert/backend/xnnpack/BackendContext.cc @@ -37,7 +37,7 @@ FunctionMap BackendContext::genKernels() { FunctionMap ret; - for (auto op_ind : _data.op_order) + for (auto &&op_ind : _data.op_order) { auto 
fn_seq = kernel_gen->generate(op_ind); ret.emplace_back(op_ind, std::move(fn_seq)); diff --git a/runtime/onert/backend/xnnpack/Config.cc b/runtime/onert/backend/xnnpack/Config.cc index 8783ff390..cc27f717f 100644 --- a/runtime/onert/backend/xnnpack/Config.cc +++ b/runtime/onert/backend/xnnpack/Config.cc @@ -37,7 +37,7 @@ bool Config::initialize() return true; } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout) { return ir::Layout::NHWC; } +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout) { return ir::Layout::NHWC; } } // namespace xnnpack } // namespace backend diff --git a/runtime/onert/backend/xnnpack/Config.h b/runtime/onert/backend/xnnpack/Config.h index 2cf7406e5..4c5fba587 100644 --- a/runtime/onert/backend/xnnpack/Config.h +++ b/runtime/onert/backend/xnnpack/Config.h @@ -36,7 +36,7 @@ public: public: std::string id() override { return "xnnpack"; } bool initialize() override; - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportPermutation() override { return true; } bool supportDynamicTensor() override { return true; } bool supportFP16() override { return false; } diff --git a/runtime/onert/backend/xnnpack/KernelGenerator.cc b/runtime/onert/backend/xnnpack/KernelGenerator.cc index 9580bec8c..25f3fd238 100644 --- a/runtime/onert/backend/xnnpack/KernelGenerator.cc +++ b/runtime/onert/backend/xnnpack/KernelGenerator.cc @@ -69,7 +69,7 @@ std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationI assert(_return_fn); // _return_fn must have been generated ret->append(std::move(_return_fn)); - for (auto ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) + for (auto &&ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs()) { auto portable_tensor = _tensor_reg->getPortableTensor(ind); if (portable_tensor) diff --git a/runtime/onert/core/CMakeLists.txt b/runtime/onert/core/CMakeLists.txt index 8041ab5bc..8ff3fdf42 100644 --- a/runtime/onert/core/CMakeLists.txt +++ b/runtime/onert/core/CMakeLists.txt @@ -2,7 +2,19 @@ file(GLOB_RECURSE SOURCES "src/*.cc") file(GLOB_RECURSE TESTS "*.test.cc") list(REMOVE_ITEM SOURCES ${TESTS}) -nnfw_find_package(Ruy REQUIRED) +if(NOT BUILD_MINMAX_H5DUMPER) + file(GLOB_RECURSE SRC_TO_REMOVE "src/dumper/h5/*.cc") + list(REMOVE_ITEM SOURCES ${SRC_TO_REMOVE}) + file(GLOB_RECURSE SRC_TO_REMOVE "src/exec/MinMaxRecorder.cc") + list(REMOVE_ITEM SOURCES ${SRC_TO_REMOVE}) +endif(NOT BUILD_MINMAX_H5DUMPER) + +if(NOT ENABLE_ONERT_TRAIN) + file(GLOB_RECURSE SRC_TRAIN "src/*/train/*.cc") + list(REMOVE_ITEM SOURCES ${SRC_TRAIN}) + file(GLOB_RECURSE SRC_TRAIN "src/*/*/train/*.cc") + list(REMOVE_ITEM SOURCES ${SRC_TRAIN}) +endif(NOT ENABLE_ONERT_TRAIN) add_library(onert_core SHARED ${SOURCES}) set_target_properties(onert_core PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -18,9 +30,26 @@ target_link_libraries(onert_core PRIVATE nnfw_lib_misc nnfw_lib_cker) target_link_libraries(onert_core PRIVATE nnfw_common) target_link_libraries(onert_core PRIVATE nnfw_coverage) target_link_libraries(onert_core PRIVATE dl ${LIB_PTHREAD}) + +# Ruy +nnfw_find_package(Ruy REQUIRED) target_link_libraries(onert_core PRIVATE ruy) target_link_libraries(onert_core INTERFACE ruy_instrumentation) +# H5 Minmax Dumper +if(BUILD_MINMAX_H5DUMPER) + nnfw_find_package(HDF5 REQUIRED) + target_compile_definitions(onert_core PRIVATE MINMAX_H5DUMPER=1) + 
target_include_directories(onert_core PRIVATE ${HDF5_INCLUDE_DIRS}) + target_link_libraries(onert_core PRIVATE ${HDF5_CXX_LIBRARIES}) +endif(BUILD_MINMAX_H5DUMPER) + +# Training feature +# Use public to use this flag on all modules and tests +if(ENABLE_ONERT_TRAIN) + target_compile_definitions(onert_core PUBLIC ONERT_TRAIN) +endif(ENABLE_ONERT_TRAIN) + if(CMAKE_BUILD_TYPE_LC STREQUAL "release") add_custom_command(TARGET onert_core POST_BUILD COMMAND ${CMAKE_STRIP} "--strip-unneeded" $<TARGET_FILE_NAME:onert_core>) diff --git a/runtime/onert/core/include/backend/IConfig.h b/runtime/onert/core/include/backend/IConfig.h index 409fd3d9f..e297c5f1e 100644 --- a/runtime/onert/core/include/backend/IConfig.h +++ b/runtime/onert/core/include/backend/IConfig.h @@ -18,7 +18,7 @@ #define __ONERT_BACKEND_ICONFIG_H__ #include "ir/Layout.h" -#include "ir/Operation.h" +#include "ir/IOperation.h" #include "util/ITimer.h" #include <memory> @@ -48,11 +48,11 @@ struct IConfig /** * @brief Returns supported layout for the given \p node and \p frontend_layout * - * @param node Operation + * @param node IOperation * @param frontend_layout The layout defined in the model * @return ir::Layout The layout that the backend kernel actually uses */ - virtual ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) = 0; + virtual ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) = 0; /** * @brief The function that is called after each Operation run on profiling mode. * This may be useful for profiling GPU-based or special computing units. diff --git a/runtime/onert/core/include/backend/basic/BackendContextHelpers.h b/runtime/onert/core/include/backend/basic/BackendContextHelpers.h index 970a9f71c..9992ca140 100644 --- a/runtime/onert/core/include/backend/basic/BackendContextHelpers.h +++ b/runtime/onert/core/include/backend/basic/BackendContextHelpers.h @@ -226,13 +226,15 @@ template <typename T_BackendContext> ITensorRegistry *genTensors(T_BackendContex return ctx.tensor_registry.get(); } -inline void initConsts(BackendContext &ctx) +inline void initConsts(const ir::Operands &operands, + const util::Set<ir::OperandIndex> &external_operands, + ITensorRegistry *tensor_registry) { - ctx.graph()->operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &operand) { - if (ctx.external_operands().contains(ind) || !operand.isConstant()) + operands.iterate([&](const ir::OperandIndex &ind, const ir::Operand &operand) { + if (external_operands.contains(ind) || !operand.isConstant()) return; - auto tensor = ctx.tensor_registry->getNativeITensor(ind); + auto tensor = tensor_registry->getNativeITensor(ind); assert(tensor != nullptr); VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl; @@ -248,6 +250,11 @@ inline void initConsts(BackendContext &ctx) }); } +inline void initConsts(BackendContext &ctx) +{ + initConsts(ctx.graph()->operands(), ctx.external_operands(), ctx.tensor_registry.get()); +} + } // namespace basic } // namespace backend } // namespace onert diff --git a/runtime/onert/core/include/backend/basic/DynamicTensorManager.h b/runtime/onert/core/include/backend/basic/DynamicTensorManager.h index 92d8ee3ab..0535dd5e1 100644 --- a/runtime/onert/core/include/backend/basic/DynamicTensorManager.h +++ b/runtime/onert/core/include/backend/basic/DynamicTensorManager.h @@ -21,9 +21,11 @@ #include "TensorRegistry.h" #include <ir/OperandInfo.h> -#include <ir/Operation.h> +#include <ir/IOperation.h> #include <ir/Index.h> +#include <unordered_set> + 
namespace onert { namespace backend diff --git a/runtime/onert/core/include/backend/basic/StaticTensorManager.h b/runtime/onert/core/include/backend/basic/StaticTensorManager.h index f35dbdfe4..6088306ec 100644 --- a/runtime/onert/core/include/backend/basic/StaticTensorManager.h +++ b/runtime/onert/core/include/backend/basic/StaticTensorManager.h @@ -38,6 +38,8 @@ class StaticTensorManager public: StaticTensorManager(const std::shared_ptr<TensorRegistry> ®, DynamicTensorManager *dynamic_tensor_manager); + StaticTensorManager(const std::shared_ptr<TensorRegistry> ®, const std::string planner_id, + DynamicTensorManager *dynamic_tensor_manager); virtual ~StaticTensorManager() = default; void allocateNonconsts(void); diff --git a/runtime/onert/core/include/backend/basic/TensorBuilder.h b/runtime/onert/core/include/backend/basic/TensorBuilder.h index a8014e55d..8ea114912 100644 --- a/runtime/onert/core/include/backend/basic/TensorBuilder.h +++ b/runtime/onert/core/include/backend/basic/TensorBuilder.h @@ -38,6 +38,7 @@ class TensorBuilder { public: TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg); + TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg, const std::string planner_id); /** * @brief Register tensor information to allocate on CPU backend diff --git a/runtime/onert/core/include/backend/basic/train/TrainableBackendContextHelpers.h b/runtime/onert/core/include/backend/basic/train/TrainableBackendContextHelpers.h new file mode 100644 index 000000000..e1d3b034a --- /dev/null +++ b/runtime/onert/core/include/backend/basic/train/TrainableBackendContextHelpers.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_BACKEND_CONTEXT_HELPERS_H__
+#define __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_BACKEND_CONTEXT_HELPERS_H__
+
+#include "backend/basic/BackendContextHelpers.h"
+#include "backend/train/TrainableBackendContext.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace basic
+{
+namespace train
+{
+
+// TODO Unify with the `genTensors()` function in `BackendContextHelpers.h`
+template <typename TensorBuilder>
+ITensorRegistry *genTensors(backend::train::TrainableBackendContext &ctx,
+                            const std::shared_ptr<TensorBuilder> &tensor_builder)
+{
+  const auto &tgraph = *ctx.trainable_graph();
+
+  auto model_io =
+    (tgraph.getInputs() + tgraph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
+  tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
+    if (ctx.external_operands().contains(ind))
+      return;
+    // NOTE Assuming there are no layout changes (always NHWC or UNKNOWN)
+    assert(tgraph.layout() != ir::Layout::NCHW);
+    ir::OperandInfo backend_info{obj.shape(), obj.typeInfo(), obj.info().memAllocType(),
+                                 obj.isConstant()};
+    tensor_builder->registerTensorInfo(ind, backend_info, ir::Layout::NHWC);
+  });
+
+  // For executors that do not have a fixed linear execution order:
+  // as a workaround, use the static memory planner so that tensors are never deallocated
+  tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
+    if (tensor_builder->isRegistered(ind))
+      tensor_builder->notifyFirstUse(ind);
+  });
+
+  tensor_builder->allocate();
+
+  return ctx.tensor_registry().get();
+}
+
+} // namespace train
+} // namespace basic
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_BACKEND_CONTEXT_HELPERS_H__
diff --git a/runtime/onert/core/include/backend/basic/train/TrainableTensor.h b/runtime/onert/core/include/backend/basic/train/TrainableTensor.h
new file mode 100644
index 000000000..e985f2930
--- /dev/null
+++ b/runtime/onert/core/include/backend/basic/train/TrainableTensor.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_TENSOR_H__
+#define __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_TENSOR_H__
+
+#include "backend/train/ITrainableTensor.h"
+
+#include "backend/basic/Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace basic
+{
+namespace train
+{
+
+class TrainableTensor : public backend::train::ITrainableTensor
+{
+public:
+  TrainableTensor() = delete;
+  virtual ~TrainableTensor() = default;
+
+public:
+  TrainableTensor(const ir::OperandInfo &info, const ir::Layout layout)
+    : ITrainableTensor{info}, _tensor{info, layout, nullptr}, _opt_vars{}
+  {
+    // DO NOTHING
+  }
+
+public:
+  /**
+   * @brief Set the Buffer object. This method is called for static and
+   *        non-const tensors.
+   */
+  void setBuffer(uint8_t *buffer) { _tensor.setBuffer(buffer); }
+
+public:
+  uint8_t *buffer() const override { return _tensor.buffer(); }
+  /**
+   * @brief Get the total size of the underlying tensor in bytes
+   *
+   * @return size_t Total size in bytes
+   */
+  size_t total_size() const override { return _tensor.total_size(); }
+  size_t calcOffset(const ir::Coordinates &coords) const override
+  {
+    return _tensor.calcOffset(coords);
+  }
+  ir::Layout layout() const override { return _tensor.layout(); }
+  ir::DataType data_type() const override { return _tensor.data_type(); }
+  bool is_constant() const override { return _tensor.is_constant(); }
+  bool is_dynamic() const override { return _tensor.is_dynamic(); }
+  ir::Shape getShape() const override { return _tensor.getShape(); }
+  const ir::OperandInfo &get_info() { return _tensor.get_info(); }
+
+public:
+  std::vector<ITensor *> optVars() override;
+  void appendOptVar(std::unique_ptr<Tensor> opt_var) { _opt_vars.emplace_back(std::move(opt_var)); }
+
+public:
+  void fillBuffer(const std::shared_ptr<ir::Data> &data);
+
+private:
+  using ITensor::setShape;
+  using ITensor::set_dynamic;
+  using ITensor::applyShape;
+
+protected:
+  Tensor _tensor;
+  std::vector<std::unique_ptr<Tensor>> _opt_vars; ///< Optimizer variables
+};
+
+} // namespace train
+} // namespace basic
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_TENSOR_H__
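The optVars()/appendOptVar() pair above is the hook for per-tensor optimizer state such as momentum buffers. A hedged sketch of how an optimizer-side helper might use it; the helper names and the raw-buffer update loop are illustrative assumptions, not onert API:

#include <cstddef>
#include <memory>

using onert::backend::basic::Tensor;
using onert::backend::basic::train::TrainableTensor;

// Attach one extra variable (e.g. a momentum buffer) to a parameter tensor.
// Buffer allocation is omitted; a real backend would go through its allocator.
void attachMomentumVar(TrainableTensor &param)
{
  auto momentum = std::make_unique<Tensor>(param.get_info(), param.layout(), nullptr);
  param.appendOptVar(std::move(momentum));
}

// Assumed SGD-with-momentum update over the raw float buffers.
void sgdMomentumStep(TrainableTensor &param, const float *grad, float lr, float beta)
{
  auto *velocity = reinterpret_cast<float *>(param.optVars().at(0)->buffer());
  auto *weights = reinterpret_cast<float *>(param.buffer());
  const std::size_t count = param.total_size() / sizeof(float);
  for (std::size_t i = 0; i < count; ++i)
  {
    velocity[i] = beta * velocity[i] + grad[i];
    weights[i] -= lr * velocity[i];
  }
}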
diff --git a/runtime/onert/core/include/backend/train/ITensorRegistry.h b/runtime/onert/core/include/backend/train/ITensorRegistry.h
new file mode 100644
index 000000000..72b8a35db
--- /dev/null
+++ b/runtime/onert/core/include/backend/train/ITensorRegistry.h
@@ -0,0 +1,246 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
+#define __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
+
+#include "backend/ITensorRegistry.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace train
+{
+
+class ITensorRegistry : public backend::ITensorRegistry
+{
+public:
+  /**
+   * @brief Returns a pointer to an ITensor among native and migrant tensors
+   *        (excluding derivative and gradient tensors)
+   */
+  using backend::ITensorRegistry::getITensor;
+
+  /**
+   * @brief Returns a pointer to an ITensor among native tensors
+   *        (excluding derivative and gradient tensors)
+   */
+  using backend::ITensorRegistry::getNativeITensor;
+
+  /**
+   * @brief Returns a pointer to the ITensor for the derivative
+   *
+   * @note The returned tensor must not outlive the dynamic tensor manager
+   */
+  virtual ITensor *getDerivativeITensor(const ir::OperandIndex &) = 0;
+
+  /**
+   * @brief Returns a pointer to the ITensor for the gradient
+   *
+   * @note The returned tensor must not outlive the dynamic tensor manager
+   */
+  virtual ITensor *getGradientITensor(const ir::OperandIndex &) = 0;
+};
+
+} // namespace train
+} // namespace backend
+} // namespace onert
+
+namespace onert
+{
+namespace backend
+{
+namespace train
+{
+
+template <typename Tensor, typename TrainableTensor, typename DerivativeTensor,
+          typename GradientTensor>
+class PortableTensorRegistryTemplate : public backend::train::ITensorRegistry
+{
+public:
+  using TrainingTensors = std::tuple<TrainableTensor *, GradientTensor *>;
+
+public:
+  ITensor *getITensor(const ir::OperandIndex &index) override
+  {
+    auto _migrant_tensor = _migrant.find(index);
+    if (_migrant_tensor != _migrant.end())
+      return _migrant_tensor->second;
+    return getNativeITensor(index);
+  }
+
+  ITensor *getNativeITensor(const ir::OperandIndex &index) override
+  {
+    ITensor *tensor = getTrainableTensor(index);
+    if (tensor == nullptr)
+      tensor = getNonConstTensor(index);
+    return tensor;
+  }
+
+  ITensor *getDerivativeITensor(const ir::OperandIndex &index) override
+  {
+    return getDerivativeTensor(index);
+  }
+
+  ITensor *getGradientITensor(const ir::OperandIndex &index) override
+  {
+    return getGradientTensor(index);
+  }
+
+  IPortableTensor *getPortableTensor(const ir::OperandIndex &index)
+  {
+    auto tensor = _trainable.find(index);
+    if (tensor != _trainable.end())
+    {
+      if (tensor->second)
+        return tensor->second.get();
+    }
+    return getNonConstTensor(index);
+  }
+
+  Tensor *getNonConstTensor(const ir::OperandIndex &index)
+  {
+    auto tensor = _non_const.find(index);
+    if (tensor != _non_const.end())
+      return tensor->second.get();
+    return nullptr;
+  }
+
+  TrainableTensor *getTrainableTensor(const ir::OperandIndex &index)
+  {
+    auto tensor = _trainable.find(index);
+    if (tensor != _trainable.end())
+      return tensor->second.get();
+
+    return nullptr;
+  }
+
+  DerivativeTensor *getDerivativeTensor(const ir::OperandIndex &index)
+  {
+    auto tensor = _derivative.find(index);
+    if (tensor != _derivative.end())
+      return tensor->second.get();
+    return nullptr;
+  }
+
+  GradientTensor *getGradientTensor(const ir::OperandIndex &index)
+  {
+    auto tensor = _gradient.find(index);
+    if (tensor != _gradient.end())
+      return tensor->second.get();
+    return nullptr;
+  }
+
+  TrainingTensors getTrainingTensors(const ir::OperandIndex &index)
+  {
+    auto trainable = getTrainableTensor(index);
+    if (trainable == nullptr)
+      throw std::runtime_error{
+        "Tried to get a trainable tensor but the corresponding tensor does not exist."};
+
+    auto gradient = getGradientTensor(index);
+    if (gradient == nullptr)
+      throw std::runtime_error{
+        "Tried to get a gradient tensor but the corresponding tensor does not exist."};
+
+    return TrainingTensors{std::make_pair(trainable, gradient)};
+  }
+
+  bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
+  {
+    assert(tensor != nullptr);
+    if (getITensor(index) != nullptr)
+      throw std::runtime_error{"Tried to set a migrant tensor but another tensor already exists."};
+
+    _migrant[index] = tensor;
+    return true;
+  }
+
+  void setNonConstTensor(const ir::OperandIndex &index, std::unique_ptr<Tensor> tensor)
+  {
+    assert(tensor != nullptr);
+    if (getITensor(index) != nullptr)
+      throw std::runtime_error{
+        "Tried to set a non-const tensor but another tensor already exists."};
+
+    _non_const[index] = std::move(tensor);
+  }
+
+  void setTrainableTensor(const ir::OperandIndex &index, std::unique_ptr<TrainableTensor> tensor)
+  {
+    assert(tensor != nullptr);
+    if (getITensor(index) != nullptr)
+      throw std::runtime_error{
+        "Tried to set a trainable tensor but another tensor already exists."};
+
+    _trainable[index] = std::move(tensor);
+  }
+
+  void setDerivativeTensor(const ir::OperandIndex &index, std::unique_ptr<DerivativeTensor> tensor)
+  {
+    assert(tensor != nullptr);
+    auto itr = _derivative.find(index);
+    if (itr != _derivative.end())
+      throw std::runtime_error{
+        "Tried to set a derivative tensor but another derivative tensor already exists."};
+
+    _derivative[index] = std::move(tensor);
+  }
+
+  void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr<GradientTensor> tensor)
+  {
+    assert(tensor != nullptr);
+    auto itr = _gradient.find(index);
+    if (itr != _gradient.end())
+      throw std::runtime_error{
+        "Tried to set a gradient tensor but another gradient tensor already exists."};
+
+    _gradient[index] = std::move(tensor);
+  }
+
+  const ir::OperandIndexMap<std::unique_ptr<TrainableTensor>> &trainable_tensors()
+  {
+    return _trainable;
+  }
+  const ir::OperandIndexMap<std::unique_ptr<Tensor>> &nonconst_tensors() { return _non_const; }
+  const ir::OperandIndexMap<std::unique_ptr<DerivativeTensor>> &derivative_tensors()
+  {
+    return _derivative;
+  }
+  const ir::OperandIndexMap<std::unique_ptr<GradientTensor>> &gradient_tensors()
+  {
+    return _gradient;
+  }
+
+private:
+  // Native tensors
+  ir::OperandIndexMap<std::unique_ptr<Tensor>> _non_const;
+  ir::OperandIndexMap<std::unique_ptr<TrainableTensor>> _trainable;
+
+  // Migrant tensors
+  ir::OperandIndexMap<IPortableTensor *> _migrant;
+
+  // Tensors for backpropagation
+  ir::OperandIndexMap<std::unique_ptr<DerivativeTensor>> _derivative;
+
+  // Tensors for updating trainable tensors
+  ir::OperandIndexMap<std::unique_ptr<GradientTensor>> _gradient;
+};
+
+} // namespace train
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_TRAIN_ITENSOR_REGISTRY_H__
diff --git a/runtime/onert/core/include/backend/train/ITrainableBackend.h b/runtime/onert/core/include/backend/train/ITrainableBackend.h
new file mode 100644
index 000000000..76e394216
--- /dev/null
+++ b/runtime/onert/core/include/backend/train/ITrainableBackend.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_ITRAINABLE_BACKEND_H__ +#define __ONERT_BACKEND_TRAIN_ITRAINABLE_BACKEND_H__ + +#include <memory> + +namespace onert +{ +namespace backend +{ +namespace train +{ + +class TrainableBackendContext; +struct TrainableContextData; + +struct ITrainableBackend +{ + virtual ~ITrainableBackend() = default; + virtual std::unique_ptr<TrainableBackendContext> newContext(TrainableContextData &&) const = 0; +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_ITRAINABLE_BACKEND_H__ diff --git a/runtime/onert/core/include/backend/train/ITrainableTensor.h b/runtime/onert/core/include/backend/train/ITrainableTensor.h new file mode 100644 index 000000000..9d7ab345b --- /dev/null +++ b/runtime/onert/core/include/backend/train/ITrainableTensor.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_ITRAINABLE_TENSOR_H__ +#define __ONERT_BACKEND_TRAIN_ITRAINABLE_TENSOR_H__ + +#include "backend/IPortableTensor.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ + +/** + * @brief A tensor class that can be trained + * + */ +// NOTE It is more appropriate to inherit ITensor, but there is no easy way +// except for virtual inheritance. +class ITrainableTensor : public IPortableTensor +{ +public: + using IPortableTensor::IPortableTensor; + virtual ~ITrainableTensor() = default; + + /** + * @brief Get optimizer variables of this trainable tensor + * + * @return Optimizer variables + */ + virtual std::vector<ITensor *> optVars() = 0; +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_ITRAINABLE_TENSOR_H__ diff --git a/runtime/onert/core/include/backend/train/KernelGeneratorBase.h b/runtime/onert/core/include/backend/train/KernelGeneratorBase.h new file mode 100644 index 000000000..b5031a5cd --- /dev/null +++ b/runtime/onert/core/include/backend/train/KernelGeneratorBase.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_TRAIN_KERNEL_GENERATOR_BASE_H__ +#define __ONERT_BACKEND_TRAIN_KERNEL_GENERATOR_BASE_H__ + +#include <memory> + +#include "backend/ITensorRegistry.h" +#include "exec/train/TrainableFnSequence.h" +#include "ir/train/TrainableGraph.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ + +class KernelGeneratorBase : public ir::train::TrainableOperationVisitor +{ +public: + virtual ~KernelGeneratorBase() = default; + KernelGeneratorBase(const ir::train::TrainableGraph &tgraph) : _tgraph{tgraph} {} + + virtual std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex ind) = 0; + +protected: +#define OP(InternalName) \ + void visit(const ir::train::operation::InternalName &) override \ + { \ + throw std::runtime_error("KernelGenerator: NYI for operation '" #InternalName "'"); \ + } +#include "ir/train/Operations.lst" +#undef OP + +protected: + const ir::train::TrainableGraph &_tgraph; + std::unique_ptr<exec::train::ITrainableFunction> _return_fn; +}; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_TRAIN_KERNEL_GENERATOR_BASE_H__ diff --git a/runtime/onert/core/include/backend/train/TrainableBackendContext.h b/runtime/onert/core/include/backend/train/TrainableBackendContext.h new file mode 100644 index 000000000..3f47af747 --- /dev/null +++ b/runtime/onert/core/include/backend/train/TrainableBackendContext.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_BACKEND_TRAIN_TRAINABLE_CONTEXT_H__ +#define __ONERT_BACKEND_BACKEND_TRAIN_TRAINABLE_CONTEXT_H__ + +#include "backend/Backend.h" +#include "backend/train/ITensorRegistry.h" +#include "backend/train/ITrainableBackend.h" +#include "exec/train/optimizer/Optimizer.h" +#include "exec/train/TrainableFnSequence.h" +#include "ir/OperandIndexMap.h" +#include "ir/train/TrainableGraph.h" +#include "util/Set.h" + +namespace onert +{ +namespace backend +{ +namespace train +{ + +using FunctionMap = + std::vector<std::pair<ir::OperationIndex, std::unique_ptr<exec::train::TrainableFnSequence>>>; + +struct TrainableContextData +{ + // A partial and trainable graph that only includes used operand/operations of the original graph + std::unique_ptr<ir::train::TrainableGraph> tgraph; + /* A linear order of operations. 
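[Editor's note] The OP macro in KernelGeneratorBase gives every trainable operation a visit() that throws "NYI", so a concrete generator only overrides what it supports. A hedged sketch (the FullyConnected body is a placeholder, not taken from this patch):

namespace onert { namespace backend { namespace train {
// Hedged sketch: a generator supporting only FullyConnected; all other
// operations fall back to the NYI-throwing defaults above.
class MyKernelGenerator : public KernelGeneratorBase
{
public:
  using KernelGeneratorBase::KernelGeneratorBase;

  std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex idx) override
  {
    auto seq = std::make_unique<exec::train::TrainableFnSequence>();
    _tgraph.operation(idx).accept(*this); // dispatches to a visit() below, filling _return_fn
    seq->append(std::move(_return_fn));
    return seq;
  }

  void visit(const ir::train::operation::FullyConnected &) override
  {
    _return_fn = nullptr; // placeholder: build the real ITrainableFunction here
  }
};
}}} // namespace onert::backend::train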
This is necessary when a graph is not fully connected */ + std::vector<onert::ir::OperationIndex> op_order; + /* Operands that are defined by other backends */ + util::Set<ir::OperandIndex> external_operands; + /* Operand layout info */ + ir::OperandIndexMap<ir::Layout> operand_layouts; + /* Custom kernel builder */ + std::shared_ptr<custom::IKernelBuilder> custom_kernel_builder; + /* Is linear executor or not */ + bool is_linear_executor; + /* Optimizer */ + std::shared_ptr<exec::train::optimizer::Optimizer> optimizer; +}; + +class TrainableBackendContext +{ +public: + TrainableBackendContext(const ITrainableBackend *backend, + std::unique_ptr<TrainableContextData> &&tdata, + std::shared_ptr<ITensorRegistry> tensor_registry = nullptr) + : _backend{backend}, _tdata{std::move(tdata)}, _tensor_registry{tensor_registry} + { + assert(_tdata); + } + virtual ~TrainableBackendContext() = default; + + const ir::train::TrainableGraph *trainable_graph() const { return _tdata->tgraph.get(); } + + const TrainableContextData *data() const { return _tdata.get(); } + + const ITrainableBackend *backend() const { return _backend; } + const util::Set<ir::OperandIndex> &external_operands() const { return _tdata->external_operands; } + const ir::OperandIndexMap<ir::Layout> &operand_layouts() const { return _tdata->operand_layouts; } + + std::shared_ptr<ITensorRegistry> tensor_registry() { return _tensor_registry; } + + virtual ITensorRegistry *genTrainingTensors() = 0; + virtual backend::ITensorRegistry *genTensors() = 0; + virtual FunctionMap genKernels() = 0; + +private: + const ITrainableBackend *_backend{nullptr}; + +protected: + std::unique_ptr<TrainableContextData> _tdata; + +protected: + std::shared_ptr<ITensorRegistry> _tensor_registry; +}; + +using TrainableBackendContexts = + std::unordered_map<const Backend *, std::unique_ptr<TrainableBackendContext>>; + +} // namespace train +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BACKEND_TRAIN_TRAINABLE_CONTEXT_H__ diff --git a/runtime/onert/core/include/compiler/CodeMap.h b/runtime/onert/core/include/compiler/CodeMap.h index b1d861cf8..93fe43cfd 100644 --- a/runtime/onert/core/include/compiler/CodeMap.h +++ b/runtime/onert/core/include/compiler/CodeMap.h @@ -19,7 +19,7 @@ #include <unordered_map> #include "ir/Index.h" -#include "ir/Operation.h" +#include "ir/IOperation.h" #include "exec/FunctionSequence.h" #include "OperationLowerInfo.h" @@ -31,11 +31,11 @@ namespace compiler struct CodeAndInfo { ir::OperationIndex op_ind; - const ir::Operation *op; + const ir::IOperation *op; const OperationLowerInfo *lower_info; std::unique_ptr<exec::FunctionSequence> fn_seq; - CodeAndInfo(const ir::OperationIndex op_ind, const ir::Operation *op, + CodeAndInfo(const ir::OperationIndex op_ind, const ir::IOperation *op, const OperationLowerInfo *lower_info, std::unique_ptr<exec::FunctionSequence> &&fn_seq) : op_ind{op_ind}, op{op}, lower_info{lower_info}, fn_seq{std::move(fn_seq)} diff --git a/runtime/onert/core/include/compiler/CompilerFactory.h b/runtime/onert/core/include/compiler/CompilerFactory.h index 4894366a2..5a8886aa1 100644 --- a/runtime/onert/core/include/compiler/CompilerFactory.h +++ b/runtime/onert/core/include/compiler/CompilerFactory.h @@ -19,6 +19,7 @@ #include "ICompiler.h" #include "CompilerOptions.h" +#include "compiler/train/TrainingInfo.h" #include "ir/NNPkg.h" namespace onert @@ -34,7 +35,8 @@ public: public: std::unique_ptr<ICompiler> create(const std::shared_ptr<ir::NNPkg> &nnpkg, -
std::vector<std::unique_ptr<CompilerOptions>> &copts); + std::vector<std::unique_ptr<CompilerOptions>> &copts, + const compiler::train::TrainingInfo *training_info = nullptr); private: // It is not allowed to use CompilerFactory without get() diff --git a/runtime/onert/core/include/compiler/CompilerOptions.h b/runtime/onert/core/include/compiler/CompilerOptions.h index bbe15fc06..bb0d0a430 100644 --- a/runtime/onert/core/include/compiler/CompilerOptions.h +++ b/runtime/onert/core/include/compiler/CompilerOptions.h @@ -74,6 +74,7 @@ public: public: // GENERAL OPTIONS std::vector<std::string> backend_list; + std::string minmax_filepath; //< File path to save minmax // OPTIONS ONLY FOR DEBUGGING/PROFILING std::string trace_filepath; //< File path to save trace records diff --git a/runtime/onert/core/include/compiler/ILoweredGraph.h b/runtime/onert/core/include/compiler/ILoweredGraph.h new file mode 100644 index 000000000..bc49fa1d7 --- /dev/null +++ b/runtime/onert/core/include/compiler/ILoweredGraph.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_ILOWERED_GRAPH_H__ +#define __ONERT_COMPILER_ILOWERED_GRAPH_H__ + +#include "ir/Graph.h" +#include "compiler/GraphLowerInfo.h" + +namespace onert +{ +namespace compiler +{ + +struct ILoweredGraph +{ + virtual ~ILoweredGraph() = default; + virtual ir::Graph &graph() = 0; + virtual const ir::Graph &graph() const = 0; + virtual const compiler::GraphLowerInfo &lower_info() const = 0; + virtual compiler::GraphLowerInfo &lower_info() = 0; + virtual void setHasDynamicTensor(ir::OperationIndex ind, bool val) = 0; + virtual bool getHasDynamicTensor(ir::OperationIndex ind) const = 0; +}; + +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_ILOWERED_GRAPH_H__ diff --git a/runtime/onert/core/include/compiler/LoweredGraph.h b/runtime/onert/core/include/compiler/LoweredGraph.h index e9f0ae0de..b970a884b 100644 --- a/runtime/onert/core/include/compiler/LoweredGraph.h +++ b/runtime/onert/core/include/compiler/LoweredGraph.h @@ -17,10 +17,11 @@ #ifndef __ONERT_COMPILER_LOWERED_GRAPH_H__ #define __ONERT_COMPILER_LOWERED_GRAPH_H__ -#include "ir/Graph.h" -#include "compiler/GraphLowerInfo.h" #include "compiler/BackendResolver.h" #include "compiler/Compiler.h" +#include "compiler/GraphLowerInfo.h" +#include "compiler/ILoweredGraph.h" +#include "ir/Graph.h" namespace onert { @@ -32,22 +33,22 @@ namespace compiler * In addition, after lowering, operands in graph will be set to "dynamic" * if the shape of output of an operation cannot be decided at compilation time. 
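[Editor's note] With the new defaulted parameter, requesting a training-enabled compiler looks roughly like the sketch below; nnpkg and copts are assumed to be built elsewhere, and the singleton get() accessor is inferred from CompilerFactory's existing interface.

// Hedged sketch: TrainingInfo's setters appear later in this patch.
onert::compiler::train::TrainingInfo tinfo;
tinfo.setBatchSize(32); // illustrative value
auto compiler = onert::compiler::CompilerFactory::get().create(nnpkg, copts, &tinfo);
// Omitting the last argument (default nullptr) keeps the inference-only path.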
*/ -class LoweredGraph +class LoweredGraph : public ILoweredGraph { public: LoweredGraph(const ir::Graph &graph, const compiler::CompilerOptions &options); - ir::Graph &graph() { return _graph; } - const ir::Graph &graph() const { return _graph; } - const compiler::GraphLowerInfo &lower_info() const { return _lower_info_map; } - compiler::GraphLowerInfo &lower_info() { return _lower_info_map; } + ir::Graph &graph() override { return _graph; } + const ir::Graph &graph() const override { return _graph; } + const compiler::GraphLowerInfo &lower_info() const override { return _lower_info_map; } + compiler::GraphLowerInfo &lower_info() override { return _lower_info_map; } std::shared_ptr<ir::OperationIndexMap<int64_t>> indexed_ranks() { return _indexed_ranks; } - void setHasDynamicTensor(ir::OperationIndex ind, bool val) + void setHasDynamicTensor(ir::OperationIndex ind, bool val) override { _has_dynamic_tensor_map.emplace(ind, val); } - bool getHasDynamicTensor(ir::OperationIndex ind) const + bool getHasDynamicTensor(ir::OperationIndex ind) const override { auto itr = _has_dynamic_tensor_map.find(ind); return (itr == _has_dynamic_tensor_map.end()) ? false : itr->second; diff --git a/runtime/onert/core/include/compiler/StaticShapeInferer.h b/runtime/onert/core/include/compiler/StaticShapeInferer.h index 94d6ba1a7..83dede726 100644 --- a/runtime/onert/core/include/compiler/StaticShapeInferer.h +++ b/runtime/onert/core/include/compiler/StaticShapeInferer.h @@ -68,7 +68,7 @@ private: class StaticShapeInferer : public ir::OperationVisitor { public: - StaticShapeInferer(compiler::LoweredGraph *lowered_subg) + StaticShapeInferer(compiler::ILoweredGraph *lowered_subg) : _lowered_subg{lowered_subg}, _subg_input_observers{}, _controlflow_output_observer{nullptr}, _child_inferers{} { @@ -102,18 +102,18 @@ public: void dump(); /** - * @brief Create a lowered model shape inferer map - * @param[in] lowered_subgs lowered model subgraph map + * @brief Create a shape inferer map for a lowered model + * @param[in] lowered_subgs lowered model map * @return Shape inferer map */ static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> createStaticShapeInferers( - const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs); + const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs); private: - bool checkDynamicInput(const ir::Operation &op); - bool checkDynamicOutput(const ir::Operation &op); - void setDynamicOutput(const ir::Operation &op); + bool checkDynamicInput(const ir::IOperation &op); + bool checkDynamicOutput(const ir::IOperation &op); + void setDynamicOutput(const ir::IOperation &op); private: // TODO Define visitors for operations. List them in alphabetic order. 
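[Editor's note] Because StaticShapeInferer now works against ILoweredGraph, one inferer map can cover both inference and training graphs. A hedged sketch (lowered_graph is a hypothetical unique_ptr to either lowered-graph type):

// Hedged sketch: LoweredGraph and LoweredTrainableGraph both implement
// ILoweredGraph, so either pointer fits in this map.
std::unordered_map<onert::ir::SubgraphIndex, onert::compiler::ILoweredGraph *> subgs;
subgs[onert::ir::SubgraphIndex{0}] = lowered_graph.get();
auto inferers = onert::compiler::StaticShapeInferer::createStaticShapeInferers(subgs);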
@@ -136,6 +136,7 @@ private: void visit(const ir::operation::Gather &op) override; void visit(const ir::operation::If &op) override; void visit(const ir::operation::L2Normalization &op) override; + void visit(const ir::operation::Loss &op) override; void visit(const ir::operation::LSTM &op) override; void visit(const ir::operation::MatrixBandPart &op) override; void visit(const ir::operation::OneHot &op) override; @@ -178,7 +179,7 @@ private: void handleSimpleUnaryOp(const ir::Operation &op, const ir::OperandIndex input_idx); private: - compiler::LoweredGraph *_lowered_subg; + compiler::ILoweredGraph *_lowered_subg; std::unordered_map<ir::SubgraphIndex, std::unique_ptr<OperandObserver>> _subg_input_observers; // child subg input std::unique_ptr<OperandObserver> _controlflow_output_observer; // parent controlflow op output diff --git a/runtime/onert/core/include/compiler/train/LoweredTrainableGraph.h b/runtime/onert/core/include/compiler/train/LoweredTrainableGraph.h new file mode 100644 index 000000000..a49d1c6a8 --- /dev/null +++ b/runtime/onert/core/include/compiler/train/LoweredTrainableGraph.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_TRAIN_LOWERED_TRAINABLE_GRAPH_H__ +#define __ONERT_COMPILER_TRAIN_LOWERED_TRAINABLE_GRAPH_H__ + +#include "compiler/BackendResolver.h" +#include "compiler/CompilerOptions.h" +#include "compiler/GraphLowerInfo.h" +#include "compiler/ILoweredGraph.h" +#include "ir/train/TrainableGraph.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +// TODO Unify with LoweredGraph +/** + * @brief Class that contains lowering information on graph. + * In addition, after lowering, operands in graph will be set to "dynamic" + * if the shape of output of an operation cannot be decided at compilation time. 
+ */ +class LoweredTrainableGraph : public ILoweredGraph +{ +public: + LoweredTrainableGraph(ir::train::TrainableGraph &graph, const compiler::CompilerOptions &options); + + // TODO Remove const_cast + ir::Graph &graph() override { return const_cast<ir::Graph &>(_trainable_graph.graph()); } + const ir::Graph &graph() const override { return _trainable_graph.graph(); } + ir::train::TrainableGraph &trainable_graph() { return _trainable_graph; } + const ir::train::TrainableGraph &trainable_graph() const { return _trainable_graph; } + const compiler::GraphLowerInfo &lower_info() const override { return _lower_info_map; } + compiler::GraphLowerInfo &lower_info() override { return _lower_info_map; } + std::shared_ptr<ir::OperationIndexMap<int64_t>> indexed_ranks() { return _indexed_ranks; } + + void setHasDynamicTensor(ir::OperationIndex, bool has_dynamic) override + { + if (has_dynamic) + throw std::runtime_error("LoweredTrainableGraph does not support dynamic tensors yet"); + } + bool getHasDynamicTensor(ir::OperationIndex) const override { return false; } + +private: + void makeLowerInfo(const compiler::BackendResolver &backend_resolver); + void dumpLowerInfo(); + void lowerGraph(const compiler::CompilerOptions &options); + +private: + /** + * @brief Copy of target graph for lowering + * @note It uses a copy of the graph, not a reference. + * This allows the original graph to be compiled multiple times. + */ + ir::train::TrainableGraph _trainable_graph; + std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks; + compiler::GraphLowerInfo _lower_info_map; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_LOWERED_TRAINABLE_GRAPH_H__ diff --git a/runtime/onert/core/include/compiler/train/TrainableCodeMap.h b/runtime/onert/core/include/compiler/train/TrainableCodeMap.h new file mode 100644 index 000000000..1069a47c9 --- /dev/null +++ b/runtime/onert/core/include/compiler/train/TrainableCodeMap.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef __ONERT_COMPILER_TRAIN_TRAINABLE_CODE_MAP_H__ +#define __ONERT_COMPILER_TRAIN_TRAINABLE_CODE_MAP_H__ + +#include <unordered_map> +#include "compiler/OperationLowerInfo.h" +#include "exec/train/TrainableFnSequence.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +struct TrainableCodeAndInfo +{ + ir::OperationIndex op_ind; + const ir::train::ITrainableOperation *op; + const OperationLowerInfo *lower_info; + // TODO Change to TrainableFnSequence + std::unique_ptr<exec::train::TrainableFnSequence> tn_seq; + + TrainableCodeAndInfo(const ir::OperationIndex op_ind, const ir::train::ITrainableOperation *op, + const OperationLowerInfo *lower_info, + std::unique_ptr<exec::train::TrainableFnSequence> &&tn_seq) + : op_ind{op_ind}, op{op}, lower_info{lower_info}, tn_seq{std::move(tn_seq)} + { + } +}; + +using TrainableCodeMap = std::unordered_map<ir::OperationIndex, TrainableCodeAndInfo>; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_TRAINABLE_CODE_MAP_H__ diff --git a/runtime/onert/core/include/compiler/train/TrainingInfo.h b/runtime/onert/core/include/compiler/train/TrainingInfo.h new file mode 100644 index 000000000..3b77c838c --- /dev/null +++ b/runtime/onert/core/include/compiler/train/TrainingInfo.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_TRAIN_TRAINING_INFO_H__ +#define __ONERT_COMPILER_TRAIN_TRAINING_INFO_H__ + +#include "ir/Index.h" +#include "exec/train/optimizer/OptimizerCode.h" +#include "ir/operation/Loss.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +struct LossInfo +{ + ir::operation::Loss::Type type; + // TODO Add members for loss +}; + +struct OptimizerInfo +{ + exec::train::optimizer::OptimizerCode optim_code; + float learning_rate; + // TODO Add properties +}; + +class TrainingInfo +{ +public: + TrainingInfo() {} + TrainingInfo(const TrainingInfo &obj) = default; + TrainingInfo(TrainingInfo &&) = default; + TrainingInfo &operator=(const TrainingInfo &) = default; + TrainingInfo &operator=(TrainingInfo &&) = default; + ~TrainingInfo() = default; + + uint32_t batchSize() const { return _batch_size; } + void setBatchSize(const uint32_t batch_size) { _batch_size = batch_size; } + const LossInfo &lossInfo() const { return _loss_info; } + void setLossInfo(const LossInfo &loss_info) { _loss_info = loss_info; } + const OptimizerInfo &optimizerInfo() const { return _optimizer_info; } + void setOptimizerInfo(const OptimizerInfo &optimizer_info) { _optimizer_info = optimizer_info; } + +private: + LossInfo _loss_info; + OptimizerInfo _optimizer_info; + uint32_t _batch_size; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_TRAINING_INFO_H__ diff --git a/runtime/onert/core/include/exec/Execution.h b/runtime/onert/core/include/exec/Execution.h index ba3edcdd6..da4d20dbe 100644 --- a/runtime/onert/core/include/exec/Execution.h +++ b/runtime/onert/core/include/exec/Execution.h @@ -142,8 +142,28 @@ public: */ bool isFinished(void) const; +#ifdef ONERT_TRAIN + /** + * @brief Train + * @note It should be called after setting input and output buffer + * @param training_step The number of iterations of the training process. + * In other words, the number of gradient update. + */ + void train(uint32_t training_step); + + /** + * @brief Get loss + * @note It should be called after training + * @param[in] ind Output index + * @return @c float Loss value + */ + float getLoss(const ir::IOIndex &ind); +#endif // ONERT_TRAIN + ir::Shape getInputShape(ir::IOIndex ind) const; ir::Shape getOutputShape(ir::IOIndex ind) const; + size_t getInputTotalSize(ir::IOIndex ind) const; + size_t getOutputTotalSize(ir::IOIndex ind) const; private: const IExecutor *entryExecutor() const { return _executors->entryExecutor(); }; diff --git a/runtime/onert/core/include/exec/FunctionSequence.h b/runtime/onert/core/include/exec/FunctionSequence.h index a7020d425..f3384be3c 100644 --- a/runtime/onert/core/include/exec/FunctionSequence.h +++ b/runtime/onert/core/include/exec/FunctionSequence.h @@ -75,7 +75,7 @@ public: public: // methods related to dynamic tensor struct DynamicTensorCtx { - const ir::Operation *op = nullptr; + const ir::IOperation *op = nullptr; std::shared_ptr<exec::DynamicShapeInferer> dynamic_shape_inferer = nullptr; }; diff --git a/runtime/onert/core/include/exec/MinMaxMap.h b/runtime/onert/core/include/exec/MinMaxMap.h new file mode 100644 index 000000000..fc6849e74 --- /dev/null +++ b/runtime/onert/core/include/exec/MinMaxMap.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_MINMAX_MAP_H__ +#define __ONERT_EXEC_MINMAX_MAP_H__ + +#include "ir/Index.h" +#include "util/MinMaxMap.h" + +namespace onert +{ +namespace exec +{ +struct SMHash +{ + size_t operator()(const std::pair<ir::SubgraphIndex, ir::OperationIndex> &k) const noexcept + { + return std::hash<ir::SubgraphIndex>()(k.first) ^ std::hash<ir::OperationIndex>()(k.second); + } +}; +// SM means single model +using SMMinMaxMap = util::MinMaxMap<std::pair<ir::SubgraphIndex, ir::OperationIndex>, SMHash>; +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_MINMAX_MAP_H__ diff --git a/runtime/onert/core/include/exec/train/IGradientApplier.h b/runtime/onert/core/include/exec/train/IGradientApplier.h new file mode 100644 index 000000000..65e931e0e --- /dev/null +++ b/runtime/onert/core/include/exec/train/IGradientApplier.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_IGRADIENT_APPLIER_H__ +#define __ONERT_EXEC_TRAIN_IGRADIENT_APPLIER_H__ + +#include <cstdint> + +namespace onert +{ +namespace exec +{ +namespace train +{ + +class IGradientApplier +{ +public: + virtual ~IGradientApplier() = default; + + /** + * @brief Apply gradients to a trainable tensor + * + * @param training_step The number of iterations of the training process. + */ + virtual void applyGradient(uint32_t training_step) = 0; +}; + +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_IGRADIENT_APPLIER_H__ diff --git a/runtime/onert/core/include/exec/train/ITrainableFunction.h b/runtime/onert/core/include/exec/train/ITrainableFunction.h new file mode 100644 index 000000000..45adc258f --- /dev/null +++ b/runtime/onert/core/include/exec/train/ITrainableFunction.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_EXEC_TRAIN_I_TRAINABLE_FUNCTION_H__ +#define __ONERT_EXEC_TRAIN_I_TRAINABLE_FUNCTION_H__ + +#include <cstdint> + +namespace onert +{ +namespace exec +{ +namespace train +{ + +class ITrainableFunction +{ +public: + virtual ~ITrainableFunction() = default; + virtual void forward(bool training) = 0; + virtual void backward() = 0; +}; + +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_I_TRAINABLE_FUNCTION_H__ diff --git a/runtime/onert/core/include/exec/train/TrainableFnSequence.h b/runtime/onert/core/include/exec/train/TrainableFnSequence.h new file mode 100644 index 000000000..8be1b1e5d --- /dev/null +++ b/runtime/onert/core/include/exec/train/TrainableFnSequence.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_FN_SEQUENCE_H__ +#define __ONERT_EXEC_TRAIN_TRAINABLE_FN_SEQUENCE_H__ + +#include "exec/train/ITrainableFunction.h" +#include "exec/train/IGradientApplier.h" + +#include <memory> +#include <vector> +#include <functional> + +namespace onert +{ +namespace exec +{ +namespace train +{ +class TrainableFnSequence +{ +public: + void forward(bool training); + void backward(uint32_t training_step); + + void append(std::unique_ptr<ITrainableFunction> &&fn); + void append(std::unique_ptr<IGradientApplier> &&applier); + void iterate(const std::function<void(ITrainableFunction &)> &fn); + +public: + // TODO Change members + std::vector<std::unique_ptr<ITrainableFunction>> _functions; + std::vector<std::unique_ptr<IGradientApplier>> _appliers; +}; +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_TRAINABLE_FN_SEQUENCE_H__ diff --git a/runtime/onert/core/include/exec/train/optimizer/Optimizer.h b/runtime/onert/core/include/exec/train/optimizer/Optimizer.h new file mode 100644 index 000000000..05f2ee19b --- /dev/null +++ b/runtime/onert/core/include/exec/train/optimizer/Optimizer.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_H__ +#define __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_H__ + +#include "backend/IPortableTensor.h" +#include "backend/train/ITrainableTensor.h" + +#include <string> + +namespace onert +{ +namespace exec +{ +namespace train +{ +namespace optimizer +{ + +// Gradient tensor, Trainable Tensor, Number of training steps +using UpdateFactors = + std::tuple<const backend::IPortableTensor &, backend::train::ITrainableTensor &, size_t>; + +/** + * @class Optimizer Base class for optimizers + * @brief Base class for all optimizers + */ +class Optimizer +{ +public: + virtual ~Optimizer() = default; + + /** + * @brief Get the name of optimizer + * + * @return The name of optimizer + */ + virtual std::string name() const { return std::string{"Invalid"}; } + + /** + * @brief Get the Learning Rate + * + * @param iteration The number of training steps + * @return Learning rate + */ + virtual double getLearningRate(uint32_t iteration) const = 0; + + /** + * @brief Apply gradient to a trainable tensor + * + * @param factors UpdateFactors to be used for applying gradient to a trainable tensor + */ + virtual void applyGradient(const UpdateFactors &factors) const = 0; + + // TODO Add member functions for exporting optimizer information +}; + +} // namespace optimizer +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_H__ diff --git a/runtime/onert/core/include/exec/train/optimizer/OptimizerCode.h b/runtime/onert/core/include/exec/train/optimizer/OptimizerCode.h new file mode 100644 index 000000000..3e4a8f2a6 --- /dev/null +++ b/runtime/onert/core/include/exec/train/optimizer/OptimizerCode.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_CODE_H__ +#define __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_CODE_H__ + +#include <functional> +#include <stdint.h> +#include <string> + +namespace onert +{ +namespace exec +{ +namespace train +{ +namespace optimizer +{ + +enum class OptimizerCode +{ + Invalid, //< Invalid + SGD, //< SGD optimizer + Adam //< Adam optimizer +}; + +/** + * @brief Convert the optimizer code to the name + * + * @param opcode The optimizer code + * @return The name of the optimizer + */ +std::string toString(OptimizerCode opcode); + +} // namespace optimizer +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_CODE_H__ diff --git a/runtime/onert/core/include/exec/train/optimizer/SGD.h b/runtime/onert/core/include/exec/train/optimizer/SGD.h new file mode 100644 index 000000000..6a1a5c9b8 --- /dev/null +++ b/runtime/onert/core/include/exec/train/optimizer/SGD.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_OPTIMIZER_SGD_H__ +#define __ONERT_EXEC_TRAIN_OPTIMIZER_SGD_H__ + +#include "exec/train/optimizer/Optimizer.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ +namespace optimizer +{ + +/** + * @class SGD optimizer class + * @brief SGD optimizer + */ +class SGD : public Optimizer +{ +public: + struct Property + { + double momentum{0.0}; + bool nesterov{false}; + }; + +public: + explicit SGD() : _props{}, _learning_rate{0.01} {} + explicit SGD(const Property &props) : _props{props}, _learning_rate{0.01} {} + explicit SGD(double lr) : _props{}, _learning_rate{lr} {} + explicit SGD(const Property &props, double lr) : _props{props}, _learning_rate{lr} {} + +public: + /** + * @brief Get the name of optimizer + * + * @return The name of optimizer + */ + std::string name() const override { return std::string{"SGD"}; } + + /** + * @brief Get the Learning Rate + * + * @param iteration The number of training steps + * @return Learning rate + */ + double getLearningRate(uint32_t iteration = 0) const override; + + /** + * @brief Apply gradient to a trainable tensor + * + * @param factors UpdateFactors to be used for applying gradient to a trainable tensor + */ + void applyGradient(const UpdateFactors &factors) const override; + +private: + Property _props; + double _learning_rate; +}; + +} // namespace optimizer +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_OPTIMIZER_SGD_H__ diff --git a/runtime/onert/core/include/ir/Graph.h b/runtime/onert/core/include/ir/Graph.h index 1783cdca0..641698eb2 100644 --- a/runtime/onert/core/include/ir/Graph.h +++ b/runtime/onert/core/include/ir/Graph.h @@ -20,27 +20,17 @@ #include <functional> #include <unordered_map> +#include "ir/IGraph.h" #include "ir/Model.h" #include "ir/Operands.h" #include "ir/Operations.h" namespace onert { -namespace backend -{ -namespace custom -{ -class IKernelBuilder; -} // namespace custom -} // namespace backend -} // namespace onert - -namespace onert -{ namespace ir { -class Graph +class Graph : public IGraph { private: enum class Phase @@ -70,7 +60,7 @@ public: * @return OperandIndex @c index if successful, Undefined otherwise */ OperandIndex addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand); - OperationIndex addOperation(std::unique_ptr<Operation> &&node); + OperationIndex addOperation(std::unique_ptr<IOperation> &&node); /** * @brief Add an operation to the graph with the given index and object * @@ -79,52 +69,50 @@ public: * moved so the caller's pointer will be still valid. 
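[Editor's note] A hedged sketch of configuring the SGD optimizer declared above and wiring it into the trainable context data introduced earlier in this patch (values are illustrative):

onert::exec::train::optimizer::SGD::Property prop;
prop.momentum = 0.9; // Property defaults to plain SGD (no momentum)
prop.nesterov = true;
auto sgd = std::make_shared<onert::exec::train::optimizer::SGD>(prop, 0.01);
// e.g. tdata.optimizer = sgd; // TrainableContextData::optimizer, see above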
* * @param index Index to be added - * @param operation Operation to be added + * @param operation IOperation to be added * @return OperandIndex @c index if successful, Undefined otherwise */ - OperationIndex addOperation(OperationIndex index, std::unique_ptr<Operation> &&operation); + OperationIndex addOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation); + /** + * @brief Replace an operation which the graph already has + * + * If the given @c index is available, it succeeds. And @c operation is moved which invalidates + * the caller's pointer. If the given @c operation has at least one invalid operand index, it + * fails. And @c operation will not be moved so the caller's pointer will be still valid. + * + * No information in the graph is changed except for replacing an operation. + * + * @param operation Operation to be added + * @return OperationIndex @c index if successful, UNDEFINED otherwise + */ + OperationIndex replaceOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation); void setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data); + void changeShape(const OperandIndex &ind, const ir::Shape &new_shape) override; void addInput(const OperandIndex &ind, const std::string &name = ""); void addOutput(const OperandIndex &ind, const std::string &name = ""); - void verify(void); + void verify(void) const; void removeOperand(const OperandIndex &ind) { _operands.remove(ind); } void setLayout(Layout layout) { _layout = layout; } private: - bool checkOperandsForOperation(const Operation &operation); - void linkOperandToOperation(OperationIndex index, const Operation &operation); + bool checkOperandsForOperation(const IOperation &operation); + void linkOperandToOperation(OperationIndex index, const IOperation &operation); void initializeUseDef(); // TODO Rename to `sweepUnusedOperands` // TODO Make this public void sweepGarbageOperands(); - // Custom operations support -public: - void - bindKernelBuilder(const std::shared_ptr<onert::backend::custom::IKernelBuilder> &kernel_builder) - { - _kernel_builder = kernel_builder; - } - - const std::shared_ptr<backend::custom::IKernelBuilder> &getKernelBuilder() const - { - return _kernel_builder; - } - -private: - std::shared_ptr<backend::custom::IKernelBuilder> _kernel_builder; - // Accessors public: - const OperandIndexSequence &getInputs() const { return _inputs; } + const OperandIndexSequence &getInputs() const override { return _inputs; } OperandIndexSequence &getInputs() { return _inputs; } - const OperandIndexSequence &getOutputs() const { return _outputs; } + const OperandIndexSequence &getOutputs() const override { return _outputs; } OperandIndexSequence &getOutputs() { return _outputs; } - IOIndex getInputIndex(const std::string &name) const; - IOIndex getOutputIndex(const std::string &name) const; - const Operands &operands() const { return _operands; } + IOIndex getInputIndex(const std::string &name) const override; + IOIndex getOutputIndex(const std::string &name) const override; + const Operands &operands() const override { return _operands; } Operands &operands() { return _operands; } // TODO Remove this non-const accessor - const Operations &operations() const { return _operations; } + const Operations &operations() const override { return _operations; } Operations &operations() { return _operations; } Layout layout() const { return _layout; } diff --git a/runtime/onert/core/include/ir/IGraph.h b/runtime/onert/core/include/ir/IGraph.h new file mode 100644 index 000000000..34fb20188 
--- /dev/null +++ b/runtime/onert/core/include/ir/IGraph.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_IGRAPH_H__ +#define __ONERT_IR_IGRAPH_H__ + +#include "ir/Operands.h" +#include "ir/Operations.h" + +namespace onert +{ +namespace ir +{ + +struct IGraph +{ + virtual ~IGraph() = default; + + // Accessors + virtual const OperandIndexSequence &getInputs() const = 0; + virtual const OperandIndexSequence &getOutputs() const = 0; + virtual IOIndex getInputIndex(const std::string &name) const = 0; + virtual IOIndex getOutputIndex(const std::string &name) const = 0; + virtual const Operands &operands() const = 0; + virtual const Operations &operations() const = 0; + + // Methods that can change graph + virtual void changeShape(const OperandIndex &index, const ir::Shape &new_shape) = 0; +}; + +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_IGRAPH_H__ diff --git a/runtime/onert/core/include/ir/IOperation.h b/runtime/onert/core/include/ir/IOperation.h new file mode 100644 index 000000000..be0dd939d --- /dev/null +++ b/runtime/onert/core/include/ir/IOperation.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_IR_IOPERATION_H__ +#define __ONERT_IR_IOPERATION_H__ + +#include <memory> + +#include "ir/Index.h" +#include "ir/OpCode.h" +#include "ir/OperandIndexSequence.h" + +namespace onert +{ +namespace ir +{ + +struct OperationVisitor; + +struct IOperation +{ + virtual ~IOperation() = default; + + virtual void accept(OperationVisitor &v) const = 0; + virtual std::string name() const { return std::string{toString(opcode())}; } + virtual OpCode opcode() const = 0; + + virtual void replaceInputs(const OperandIndex &from, const OperandIndex &to) = 0; + virtual void replaceOutputs(const OperandIndex &from, const OperandIndex &to) = 0; + virtual const OperandIndexSequence &getInputs() const = 0; + virtual const OperandIndexSequence &getOutputs() const = 0; +}; + +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_IOPERATION_H__ diff --git a/runtime/onert/core/include/ir/Model.h b/runtime/onert/core/include/ir/Model.h index c3c0d87b8..950d28850 100644 --- a/runtime/onert/core/include/ir/Model.h +++ b/runtime/onert/core/include/ir/Model.h @@ -20,15 +20,25 @@ #include <memory> #include <unordered_map> +#include "ir/IGraph.h" #include "ir/Index.h" #include "util/ObjectManager.h" namespace onert { -namespace ir +namespace backend +{ +namespace custom { +class IKernelBuilder; +} // namespace custom +} // namespace backend +} // namespace onert -class Graph; +namespace onert +{ +namespace ir +{ class Model { @@ -47,7 +57,7 @@ public: * @param[in] index Index of subgraph to be pushed * @return Created */ - void push(SubgraphIndex index, const std::shared_ptr<Graph> &subg) { _subgraphs[index] = subg; } + void push(SubgraphIndex index, const std::shared_ptr<IGraph> &subg) { _subgraphs[index] = subg; } /** * @brief Remove the subgraph that is associated with the given index @@ -61,9 +71,9 @@ public: * @brief Get the subgraph that is associated with the given index * * @param[in] index Index of the subgraph to be returned - * @return Graph + * @return IGraph */ - const std::shared_ptr<Graph> &at(const SubgraphIndex &index) const + const std::shared_ptr<IGraph> &at(const SubgraphIndex &index) const { return _subgraphs.at(index); } @@ -71,9 +81,9 @@ public: * @brief Get the subgraph that is associated with the given index * * @param[in] index Index of the subgraph to be returned - * @return Graph + * @return IGraph */ - std::shared_ptr<Graph> &at(const SubgraphIndex &index) { return _subgraphs.at(index); } + std::shared_ptr<IGraph> &at(const SubgraphIndex &index) { return _subgraphs.at(index); } /** * @brief Get the subgraph that is associated with the given index @@ -93,7 +103,7 @@ public: * @param[in] fn Function to be run for every container entry * @return N/A */ - void iterate(const std::function<void(const SubgraphIndex &, const Graph &)> &fn) const + void iterate(const std::function<void(const SubgraphIndex &, const IGraph &)> &fn) const { for (const auto &e : _subgraphs) { @@ -107,7 +117,7 @@ public: * @param[in] fn Function to be run for every container entry * @return N/A */ - void iterate(const std::function<void(const SubgraphIndex &, Graph &)> &fn) + void iterate(const std::function<void(const SubgraphIndex &, IGraph &)> &fn) { for (const auto &e : _subgraphs) { @@ -125,12 +135,46 @@ public: /** * @brief Return the primary subgraph * - * @return std::shared_ptr<Graph> Primary subgraph + * @return std::shared_ptr<IGraph> Primary subgraph */ - std::shared_ptr<Graph> primary_subgraph() const { return _subgraphs.at(SubgraphIndex{0}); } + std::shared_ptr<IGraph> 
primary_subgraph() const { return _subgraphs.at(SubgraphIndex{0}); } + + /** + * @brief Return whether the model has only typename Graph + * + * @tparam Graph Type that inherits from IGraph + * + * @return true if the model has only typename Graph, otherwise false + */ + template <typename Graph, std::enable_if_t<std::is_base_of<IGraph, Graph>::value, bool> = true> + bool hasOnly() + { + for (const auto &e : _subgraphs) + { + if (std::dynamic_pointer_cast<Graph>(e.second) == nullptr) + return false; + } + return true; + } + +private: + std::unordered_map<SubgraphIndex, std::shared_ptr<IGraph>> _subgraphs; + + // Custom operations support +public: + void + bindKernelBuilder(const std::shared_ptr<onert::backend::custom::IKernelBuilder> &kernel_builder) + { + _kernel_builder = kernel_builder; + } + + const std::shared_ptr<backend::custom::IKernelBuilder> &getKernelBuilder() const + { + return _kernel_builder; + } private: - std::unordered_map<SubgraphIndex, std::shared_ptr<Graph>> _subgraphs; + std::shared_ptr<backend::custom::IKernelBuilder> _kernel_builder; }; } // namespace ir diff --git a/runtime/onert/core/include/ir/NNPkg.h b/runtime/onert/core/include/ir/NNPkg.h index b23745d55..5df58bde7 100644 --- a/runtime/onert/core/include/ir/NNPkg.h +++ b/runtime/onert/core/include/ir/NNPkg.h @@ -21,7 +21,6 @@ #include <unordered_set> #include <vector> -#include "ir/Graph.h" #include "ir/Index.h" #include "ir/Model.h" @@ -233,7 +232,7 @@ public: /** * @brief Get model input info */ - OperandInfo &inputInfo(uint32_t index) const + const OperandInfo &inputInfo(uint32_t index) const { if (_models.size() == 1) { @@ -251,7 +250,7 @@ public: /** * @brief Get model output info */ - OperandInfo &outputInfo(uint32_t index) const + const OperandInfo &outputInfo(uint32_t index) const { if (_models.size() == 1) { @@ -266,6 +265,31 @@ public: return graph->operands().at(operand_index).info(); } + void changeInputShape(uint32_t index, const ir::Shape &new_shape) + { + if (_models.size() == 1) + { + auto graph = primary_model()->primary_subgraph(); + auto const operand_index = graph->getInputs().at(index); + graph->changeShape(operand_index, new_shape); + return; + } + + auto const &desc = input(index); + auto graph = model(std::get<ModelIndex>(desc))->primary_subgraph(); + auto const operand_index = graph->getInputs().at(std::get<IOIndex>(desc).value()); + graph->changeShape(operand_index, new_shape); + } + + /** + * @brief Replace model + * + * @param[in] model Model to be replaced + * + * TODO: Support multiple models + */ + void replaceModel(std::shared_ptr<Model> model) { _models[ModelIndex{0}] = model; } + // TODO: Add iterate() or getter for edges private: diff --git a/runtime/onert/core/include/ir/OperandIndexSequence.h b/runtime/onert/core/include/ir/OperandIndexSequence.h index 846c3f950..66d00761b 100644 --- a/runtime/onert/core/include/ir/OperandIndexSequence.h +++ b/runtime/onert/core/include/ir/OperandIndexSequence.h @@ -76,6 +76,7 @@ public: } public: + bool operator==(const OperandIndexSequence &other) const; OperandIndexSequence operator+(const OperandIndexSequence &other) const; friend std::ostream &operator<<(std::ostream &o, const OperandIndexSequence &operand_seq); diff --git a/runtime/onert/core/include/ir/Operation.h b/runtime/onert/core/include/ir/Operation.h index 89f7e340d..06ab29ecb 100644 --- a/runtime/onert/core/include/ir/Operation.h +++ b/runtime/onert/core/include/ir/Operation.h @@ -19,9 +19,8 @@ #include <memory> -#include "ir/OpCode.h" +#include "ir/IOperation.h" #include 
"ir/Operand.h" -#include "ir/OperandIndexSequence.h" #include "ir/OperandConstraint.h" namespace onert @@ -29,9 +28,9 @@ namespace onert namespace ir { -struct OperationVisitor; - -class Operation +// NOTE Virtual inheritance is introduced because trainable operations inherit +// `ITrainableOperation` and `Operation` which inherit `IOperation`. +class Operation : virtual public IOperation { public: // TODO Remove default parameter @@ -49,16 +48,11 @@ public: virtual ~Operation(); public: - virtual void accept(OperationVisitor &v) const = 0; - virtual std::string name() const { return std::string{toString(opcode())}; } - virtual OpCode opcode() const = 0; - -public: - void replaceInputs(const OperandIndex &from, const OperandIndex &to); - void replaceOutputs(const OperandIndex &from, const OperandIndex &to); + void replaceInputs(const OperandIndex &from, const OperandIndex &to) override; + void replaceOutputs(const OperandIndex &from, const OperandIndex &to) override; OperandIndexSequence &getInputs() { return _inputs; } - const OperandIndexSequence &getInputs() const { return _inputs; } - const OperandIndexSequence &getOutputs() const { return _outputs; } + const OperandIndexSequence &getInputs() const override { return _inputs; } + const OperandIndexSequence &getOutputs() const override { return _outputs; } // It's for only input/output tensors but const data. void setInputs(const OperandIndexSequence &indexes); void setOutputs(const OperandIndexSequence &indexes); diff --git a/runtime/onert/core/include/ir/Operations.Include.h b/runtime/onert/core/include/ir/Operations.Include.h index 4602fafec..6352b8ed9 100644 --- a/runtime/onert/core/include/ir/Operations.Include.h +++ b/runtime/onert/core/include/ir/Operations.Include.h @@ -49,6 +49,7 @@ #include "ir/operation/L2Normalization.h" #include "ir/operation/LocalResponseNormalization.h" #include "ir/operation/LogSoftmax.h" +#include "ir/operation/Loss.h" #include "ir/operation/LSTM.h" #include "ir/operation/MatrixBandPart.h" #include "ir/operation/DetectionPostProcess.h" diff --git a/runtime/onert/core/include/ir/Operations.h b/runtime/onert/core/include/ir/Operations.h index 0b5fbf529..4102fcebe 100644 --- a/runtime/onert/core/include/ir/Operations.h +++ b/runtime/onert/core/include/ir/Operations.h @@ -18,7 +18,7 @@ #define __ONERT_IR_OPERATIONS_H__ #include "ir/Index.h" -#include "ir/Operation.h" +#include "ir/IOperation.h" #include "util/ObjectManager.h" namespace onert @@ -26,7 +26,7 @@ namespace onert namespace ir { -class Operations : public util::ObjectManager<OperationIndex, Operation> +class Operations : public util::ObjectManager<OperationIndex, IOperation> { public: Operations() = default; diff --git a/runtime/onert/core/include/ir/Operations.lst b/runtime/onert/core/include/ir/Operations.lst index f37d89505..1f91aecb2 100644 --- a/runtime/onert/core/include/ir/Operations.lst +++ b/runtime/onert/core/include/ir/Operations.lst @@ -88,3 +88,6 @@ OP(Transpose) OP(TransposeConv) OP(Unpack) OP(While) + +// Training Only +OP(Loss) diff --git a/runtime/onert/core/include/ir/operation/BinaryArithmetic.h b/runtime/onert/core/include/ir/operation/BinaryArithmetic.h index 110fff565..3dca80bbc 100644 --- a/runtime/onert/core/include/ir/operation/BinaryArithmetic.h +++ b/runtime/onert/core/include/ir/operation/BinaryArithmetic.h @@ -27,7 +27,7 @@ namespace ir namespace operation { -class BinaryArithmetic final : public Operation +class BinaryArithmetic : public Operation { public: enum Input diff --git 
a/runtime/onert/core/include/ir/operation/Loss.h b/runtime/onert/core/include/ir/operation/Loss.h new file mode 100644 index 000000000..73f1aed59 --- /dev/null +++ b/runtime/onert/core/include/ir/operation/Loss.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_OPERATION_LOSS_H__ +#define __ONERT_IR_OPERATION_LOSS_H__ + +#include "ir/Operation.h" + +namespace onert +{ +namespace ir +{ +namespace operation +{ + +class Loss : public Operation +{ +public: + enum Input + { + Y_PRED = 0, + Y_TRUE = 1 + // TODO Add more inputs if necessary + }; + + // NOTE It is not yet determined how to get the information of the previous activation when + // generating kernels of Loss operation for each backend. If it is determined to get it + // from the object of this class, we have to consider whether to change this enum class. + enum class Type + { + MEAN_SQUARED_ERROR, + CATEGORICAL_CROSSENTROPY + }; + + struct Param + { + Type op_type; + // TODO Add more params if necessary + Param() : op_type(Type::MEAN_SQUARED_ERROR) {} + }; + +public: + Loss(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param &param); + +public: + void accept(OperationVisitor &v) const override; + std::string name() const override; + OpCode opcode() const final { return OpCode::Loss; } + +public: + const Param &param() const { return _param; } + +private: + Param _param; +}; + +} // namespace operation +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_OPERATION_LOSS_H__ diff --git a/runtime/onert/core/include/ir/train/ITrainableOperation.h b/runtime/onert/core/include/ir/train/ITrainableOperation.h new file mode 100644 index 000000000..590bed45d --- /dev/null +++ b/runtime/onert/core/include/ir/train/ITrainableOperation.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__ +#define __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__ + +#include "ir/IOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ + +struct TrainableOperationVisitor; + +// NOTE Virtual inheritance is introduced because trainable operations inherit +// `ITrainableOperation` and `Operation` which inherit `IOperation`.
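[Editor's note] A hedged sketch of creating the Loss operation added above; y_pred, y_true, and loss_out are hypothetical operand indices, and initializer-list construction of OperandIndexSequence is assumed to be available:

onert::ir::operation::Loss::Param param; // op_type defaults to MEAN_SQUARED_ERROR
param.op_type = onert::ir::operation::Loss::Type::CATEGORICAL_CROSSENTROPY;
auto loss = std::make_unique<onert::ir::operation::Loss>(
  onert::ir::OperandIndexSequence{y_pred, y_true}, // Input::Y_PRED, Input::Y_TRUE
  onert::ir::OperandIndexSequence{loss_out}, param);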
+class ITrainableOperation : virtual public IOperation +{ +public: + virtual ~ITrainableOperation() = default; + +public: + virtual std::unique_ptr<ITrainableOperation> clone() const = 0; + virtual void accept(OperationVisitor &v) const override = 0; + virtual void accept(TrainableOperationVisitor &v) const = 0; + // TODO Add virtual methods related to training +}; + +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__ diff --git a/runtime/onert/core/include/ir/train/Operations.Include.h b/runtime/onert/core/include/ir/train/Operations.Include.h new file mode 100644 index 000000000..56e752f94 --- /dev/null +++ b/runtime/onert/core/include/ir/train/Operations.Include.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__ +#define __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__ + +#include "ir/train/operation/Conv2D.h" +#include "ir/train/operation/ElementwiseActivation.h" +#include "ir/train/operation/FullyConnected.h" +#include "ir/train/operation/Loss.h" +#include "ir/train/operation/Permute.h" +#include "ir/train/operation/Pool2D.h" +#include "ir/train/operation/Reshape.h" +#include "ir/train/operation/Softmax.h" + +#endif // __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__ diff --git a/runtime/onert/core/include/ir/train/Operations.lst b/runtime/onert/core/include/ir/train/Operations.lst new file mode 100644 index 000000000..14dc38819 --- /dev/null +++ b/runtime/onert/core/include/ir/train/Operations.lst @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OP +#error Define OP before including this file +#endif + +OP(Conv2D) +OP(ElementwiseActivation) +OP(FullyConnected) +OP(Loss) +OP(Permute) +OP(Pool2D) +OP(Reshape) +OP(Softmax) diff --git a/runtime/onert/core/include/ir/train/TrainableGraph.h b/runtime/onert/core/include/ir/train/TrainableGraph.h new file mode 100644 index 000000000..90c49e212 --- /dev/null +++ b/runtime/onert/core/include/ir/train/TrainableGraph.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
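[Editor's note] Operations.lst is consumed as an X-macro: a client defines OP, includes the list, and undefines it, exactly as KernelGeneratorBase does earlier in this patch. A hedged sketch of a visitor that gets one empty override per trainable operation this way (it assumes TrainableOperationVisitor's pure virtuals mirror the OP list):

namespace onert { namespace ir { namespace train {
class NopVisitor : public TrainableOperationVisitor
{
public:
#define OP(InternalName) \
  void visit(const operation::InternalName &) override {}
#include "ir/train/Operations.lst"
#undef OP
};
}}} // namespace onert::ir::train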
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__ +#define __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__ + +#include <functional> +#include <unordered_map> + +#include "ir/Graph.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ + +class TrainableGraph : public IGraph +{ +public: + /** + * @brief Construct a new Trainable Graph object + * + * @param graph + */ + explicit TrainableGraph(); + explicit TrainableGraph(const TrainableGraph &tgraph); + explicit TrainableGraph(const Graph &graph); + ~TrainableGraph() = default; + + // TrainableGraph Building +public: + OperandIndex addOperand(const Shape &shape, const TypeInfo &type); + /** + * @brief Add an operand to the graph with the given index and object + * + * If the given index is available, it succeeds. And @c operand is moved, which invalidates the + * caller's pointer. If the given index is already taken, it fails. And @c operand will not be + * moved so the caller's pointer will still be valid. + * + * @param[in] index Index to be added + * @param[in] operand Operand to be added + * @return OperandIndex @c index if successful, UNDEFINED otherwise + */ + OperandIndex addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand); + /** + * @brief Add a new trainable operation to the graph + * + * If the given @c operation has at least one invalid operand index, it fails. And @c operation + * will not be moved so the caller's pointer will still be valid. + * + * @param operation Operation to be added + * @return OperationIndex @c index if successful, UNDEFINED otherwise + */ + OperationIndex addOperation(std::unique_ptr<ITrainableOperation> &&operation); + /** + * @brief Replace a trainable operation that the graph already has + * + * If the given @c index is available, it succeeds. And @c operation is moved, which invalidates + * the caller's pointer. If the given @c operation has at least one invalid operand index, it + * fails. And @c operation will not be moved so the caller's pointer will still be valid. + * + * No information in the graph is changed except for replacing an operation. + * + * @param index Index of the operation to be replaced + * @param operation Operation to replace the existing one + * @return OperationIndex @c index if successful, UNDEFINED otherwise + */ + OperationIndex replaceOperation(OperationIndex index, + std::unique_ptr<ITrainableOperation> &&operation); + + /** + * @brief Add a derivative to the graph with the given index and object + * + * If the given index is available, it succeeds. And @c derivative is moved, which invalidates the + * caller's pointer. If the given index is already taken, it fails. And @c derivative will not be + * moved so the caller's pointer will still be valid.
+ * + * @param[in] index Index to be added + * @param[in] derivative Derivative operand to be added + * @return OperandIndex @c index if successful, UNDEFINED otherwise + */ + OperandIndex addDerivative(OperandIndex index, std::unique_ptr<Operand> &&derivative); + +public: + void changeShape(const OperandIndex &ind, const ir::Shape &new_shape) override; + void changeDerivativeShape(const OperandIndex &ind, const ir::Shape &new_shape); + void addInput(const OperandIndex &ind, const std::string &name = ""); + void addOutput(const OperandIndex &ind, const std::string &name = ""); + void addLoss(const OperandIndex &loss_ind, const IOIndex &pred_io_ind); + void verify() const; + void removeOperand(const OperandIndex &ind); + void setLayout(Layout layout); + void setInputs(OperandIndexSequence inputs, + std::unordered_map<std::string, IOIndex> name_to_input); + void setOutputs(OperandIndexSequence outputs, + std::unordered_map<std::string, IOIndex> name_to_output); + + // Accessors +public: + const OperandIndexSequence &getInputs() const override { return _graph.getInputs(); } + const OperandIndexSequence &getOutputs() const override { return _graph.getOutputs(); } + IOIndex getInputIndex(const std::string &name) const override; + IOIndex getOutputIndex(const std::string &name) const override; + const Operands &operands() const override { return _graph.operands(); } + Operands &operands() { return _graph.operands(); } // TODO Remove this non-const accessor + const Operations &operations() const override { return _graph.operations(); } + const Operands &derivatives() const { return _derivatives; } + OperandIndex getLossIndex(const IOIndex &pred_io_ind) const; + Layout layout() const { return _graph.layout(); } + const Graph &graph() const { return _graph; } + +public: + const ITrainableOperation &operation(OperationIndex index) const; + +public: + std::vector<ir::OperationIndex> topolSortOperations() const; + // TODO Support topological sort for backwarding + +private: + Graph _graph; + Operands _derivatives; + + std::unordered_map<IOIndex, OperandIndex> _losses; +}; + +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__ diff --git a/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h b/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h new file mode 100644 index 000000000..fc58c351d --- /dev/null +++ b/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__ +#define __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__ + +#include "ir/train/Operations.Include.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ + +struct TrainableOperationVisitor +{ + virtual ~TrainableOperationVisitor() = default; + +#define OP(InternalName) \ + virtual void visit(const operation::InternalName &) {} +#include "ir/train/Operations.lst" +#undef OP +}; + +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Conv2D.h b/runtime/onert/core/include/ir/train/operation/Conv2D.h new file mode 100644 index 000000000..b8968926a --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Conv2D.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_CONV2D_H__ +#define __ONERT_IR_TRAIN_OPERATION_CONV2D_H__ + +#include "ir/operation/Conv2D.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Conv2D : public ir::operation::Conv2D, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Conv2D; + +public: + Conv2D(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_CONV2D_H__ diff --git a/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h b/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h new file mode 100644 index 000000000..97ab54d17 --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__ +#define __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__ + +#include "ir/operation/ElementwiseActivation.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class ElementwiseActivation : public ir::operation::ElementwiseActivation, + public ITrainableOperation +{ +private: + using OperationType = ir::operation::ElementwiseActivation; + +public: + ElementwiseActivation(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__ diff --git a/runtime/onert/core/include/ir/train/operation/FullyConnected.h b/runtime/onert/core/include/ir/train/operation/FullyConnected.h new file mode 100644 index 000000000..bede58d69 --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/FullyConnected.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__ +#define __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__ + +#include "ir/operation/FullyConnected.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class FullyConnected : public ir::operation::FullyConnected, public ITrainableOperation +{ +private: + using OperationType = ir::operation::FullyConnected; + +public: + FullyConnected(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Loss.h b/runtime/onert/core/include/ir/train/operation/Loss.h new file mode 100644 index 000000000..c7cc4213a --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Loss.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_LOSS_H__ +#define __ONERT_IR_TRAIN_OPERATION_LOSS_H__ + +#include "ir/operation/Loss.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Loss : public ir::operation::Loss, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Loss; + +public: + Loss(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_LOSS_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Permute.h b/runtime/onert/core/include/ir/train/operation/Permute.h new file mode 100644 index 000000000..e652b136d --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Permute.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__ +#define __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__ + +#include "ir/operation/Permute.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Permute : public ir::operation::Permute, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Permute; + +public: + Permute(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Pool2D.h b/runtime/onert/core/include/ir/train/operation/Pool2D.h new file mode 100644 index 000000000..024997074 --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Pool2D.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_POOL2D_H__ +#define __ONERT_IR_TRAIN_OPERATION_POOL2D_H__ + +#include "ir/operation/Pool2D.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Pool2D : public ir::operation::Pool2D, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Pool2D; + +public: + Pool2D(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_POOL2D_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Reshape.h b/runtime/onert/core/include/ir/train/operation/Reshape.h new file mode 100644 index 000000000..1efd62cfe --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Reshape.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__ +#define __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__ + +#include "ir/operation/Reshape.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Reshape : public ir::operation::Reshape, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Reshape; + +public: + Reshape(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Softmax.h b/runtime/onert/core/include/ir/train/operation/Softmax.h new file mode 100644 index 000000000..b12e6abc1 --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Softmax.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__ +#define __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__ + +#include "ir/operation/Softmax.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Softmax : public ir::operation::Softmax, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Softmax; + +public: + Softmax(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__ diff --git a/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h b/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h new file mode 100644 index 000000000..7cda0ec0c --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__ +#define __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__ + +#include "ir/train/ITrainableOperation.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +#include <type_traits> + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +// `UntrainableOperation` wraps operations that are not yet supported for training. +// This class can be removed if all operations are supported for training. +template <typename OperationType, + typename = std::enable_if_t<std::is_base_of<Operation, OperationType>::value>> +class UntrainableOperation : public OperationType, public ITrainableOperation +{ +public: + UntrainableOperation(const OperationType &operation) : OperationType{operation} {} + virtual ~UntrainableOperation() = default; + +public: + std::unique_ptr<ITrainableOperation> clone() const override + { + return std::make_unique<UntrainableOperation<OperationType>>(*this); + } + void accept(OperationVisitor &v) const override { v.visit(*this); } + void accept(TrainableOperationVisitor &) const override + { + throw std::runtime_error(OperationType::name() + " operation is not trainable yet"); + } +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__ diff --git a/runtime/onert/core/include/odc/IQuantizer.h b/runtime/onert/core/include/odc/IQuantizer.h new file mode 100644 index 000000000..d698d9ef0 --- /dev/null +++ b/runtime/onert/core/include/odc/IQuantizer.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd.
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_ODC_IQUANTIZER_H__ +#define __ONERT_ODC_IQUANTIZER_H__ + +namespace onert +{ +namespace odc +{ + +class IQuantizer +{ +public: + virtual ~IQuantizer() = default; + + virtual int quantize(const char *in, const char *out, bool is_q16) = 0; +}; + +} // namespace odc +} // namespace onert + +#endif // __ONERT_ODC_IQUANTIZER_H__ diff --git a/runtime/onert/core/include/odc/QuantizeManager.h b/runtime/onert/core/include/odc/QuantizeManager.h new file mode 100644 index 000000000..a749c0ee1 --- /dev/null +++ b/runtime/onert/core/include/odc/QuantizeManager.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_ODC_QUANTIZE_MANAGER_H__ +#define __ONERT_ODC_QUANTIZE_MANAGER_H__ + +#include "IQuantizer.h" + +#include <functional> +#include <string> + +namespace onert +{ +namespace odc +{ + +class Quantize; + +class QuantizeManager +{ +public: + // Non-copyable + QuantizeManager() = delete; + QuantizeManager(const std::string &model_path) : _model_path(model_path) {} + QuantizeManager(QuantizeManager const &) = delete; + QuantizeManager &operator=(QuantizeManager const &) = delete; + +public: + /** + * @brief Set model path to export quantized model + * + * @param model_path Model path to export quantized model + */ + void exportModelPath(const std::string &model_path) { _export_model_path = model_path; } + + /** + * @brief Get model path to export quantized model + * + * @return Model path to export quantized model + */ + std::string &exportModelPath() { return _export_model_path; } + + /** + * @brief Set quantize type + * + * @param is_q16 true if q16, false if q8 + * + * @todo Support more general quantize type + */ + void quantizeType(bool is_q16) { _is_q16 = is_q16; } + + /** + * @brief Quantize model + * + * @return true if success, otherwise false + */ + bool quantize(); + +private: + std::string _model_path = ""; + std::string _export_model_path = ""; + bool _is_q16 = false; +}; + +} // namespace odc +} // namespace onert + +#endif // __ONERT_ODC_QUANTIZE_MANAGER_H__ diff --git a/runtime/onert/core/include/util/Config.lst b/runtime/onert/core/include/util/Config.lst index b9bad1b59..d3e37ce8f 100644 --- a/runtime/onert/core/include/util/Config.lst +++ b/runtime/onert/core/include/util/Config.lst @@ -31,6 +31,7 @@ CONFIG(NCNN_LAYOUT , std::string , "NCHW") CONFIG(PROFILING_MODE , bool , "0") CONFIG(USE_SCHEDULER , bool , "0") CONFIG(TRACE_FILEPATH , std::string , "") +CONFIG(MINMAX_FILEPATH , std::string , "") CONFIG(FP16_ENABLE , bool , "0") CONFIG(RUY_THREADS , int , "-1") CONFIG(XNNPACK_THREADS , int , "-1") diff --git a/runtime/onert/core/include/util/MinMaxMap.h b/runtime/onert/core/include/util/MinMaxMap.h new file mode 100644 index 000000000..2245f84b0 --- /dev/null +++ b/runtime/onert/core/include/util/MinMaxMap.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_UTIL_MINMAX_MAP_H_ +#define __ONERT_UTIL_MINMAX_MAP_H_ + +#include <unordered_map> +#include <utility> + +namespace onert +{ +namespace util +{ + +template <typename N, typename Hash = std::hash<N>> class MinMaxMap +{ + struct MinMaxPair + { + float data[2]; // [0] = min, [1] = max + }; + +public: + void append(N node, float min, float max) { _minmax_map[node] = {min, max}; } + auto begin() const { return _minmax_map.begin(); } + auto end() const { return _minmax_map.end(); } + +private: + std::unordered_map<N, MinMaxPair, Hash> _minmax_map; +}; + +} // namespace util +} // namespace onert + +#endif // __ONERT_UTIL_MINMAX_MAP_H_ diff --git a/runtime/onert/core/include/util/Set.h b/runtime/onert/core/include/util/Set.h index ee4062d25..73d43d4f0 100644 --- a/runtime/onert/core/include/util/Set.h +++ b/runtime/onert/core/include/util/Set.h @@ -53,6 +53,16 @@ public: public: /** + * @brief copy assignment operator + */ + Set<Element> &operator=(const Set<Element> &) = default; + /** + * @brief move assignment operator + */ + Set<Element> &operator=(Set<Element> &&) = default; + +public: + /** * @brief Add a given element to the set * * @param e Element added @@ -104,7 +114,7 @@ public: Set<Element> operator|(const Set<Element> &other) const // Union { auto ret = *this; - for (auto e : other) + for (auto &&e : other) { ret.add(e); } @@ -118,7 +128,7 @@ public: Set<Element> operator&(const Set<Element> &other) const // Intersect { Set<Element> ret; - for (auto e : other) + for (auto &&e : other) { if (contains(e)) { @@ -135,7 +145,7 @@ public: Set<Element> operator-(const Set<Element> &other) const // Minus { auto ret = *this; - for (auto e : other) + for (auto &&e : other) { ret.remove(e); } diff --git a/runtime/onert/core/include/util/Utils.h b/runtime/onert/core/include/util/Utils.h index 505f5a9b3..6b6bc2400 100644 --- a/runtime/onert/core/include/util/Utils.h +++ b/runtime/onert/core/include/util/Utils.h @@ -27,73 +27,56 @@ #define UNUSED_RELEASE(a) (void)(a) -template <size_t from, size_t to, typename Enable = void> struct ForEachDimension +template <size_t rest> struct ForEachDimension { template <typename L> static void unroll(const onert::ir::Shape &shape, onert::ir::Coordinates &coords, L lambda_function) { - static_assert(from < to, "from must not be less than to"); - assert(static_cast<int>(to) <= shape.rank()); - const auto &d = shape.dim(from); + if (static_cast<int>(rest) > shape.rank()) + { + ForEachDimension<rest - 1>::unroll(shape, coords, lambda_function); + return; + } + + const auto axis = shape.rank() - rest; + const auto &d = shape.dim(axis); for (auto v = 0; v < d; v++) { - coords.set(from, v); - ForEachDimension<from + 1, to>::unroll(shape, coords, lambda_function); + coords.set(axis, v); + ForEachDimension<rest - 1>::unroll(shape, coords, lambda_function); } } }; -template <size_t from, size_t to> -struct ForEachDimension<from, to, typename std::enable_if<from == to>::type> +template <> struct ForEachDimension<0> { template <typename L> static void unroll(const onert::ir::Shape &shape, onert::ir::Coordinates &coords, L lambda_function) { UNUSED_RELEASE(shape); - assert(static_cast<int>(to) <= shape.rank()); lambda_function(coords); } }; template <typename L> inline void ShapeLoop(const onert::ir::Shape &shape, L lambda_function) { - assert(shape.rank() > 0); - for (auto i = 0; i < shape.rank(); ++i) + int32_t rank = shape.rank(); + assert(rank > 0); + for (int32_t i = 0; i < rank; ++i) { assert(shape.dim(i) > 0); } onert::ir::Coordinates coords; 
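// A minimal usage sketch of the reworked ShapeLoop (illustrative; assumes ir::Shape
// can be list-initialized with its dimensions): ForEachDimension<6> is instantiated
// once for the maximum supported rank, each level skips itself while `rest` exceeds
// shape.rank(), and the remaining levels iterate axis = rank - rest, so for a 2-D
// shape {2, 3} the lambda is invoked at (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).
//
//   onert::ir::Shape shape{2, 3};
//   ShapeLoop(shape, [](const onert::ir::Coordinates &coords) {
//     // visit the element at `coords` here
//   });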
- switch (shape.rank()) + if (rank == 0) { - case 0: - coords.set(0, 0); - ForEachDimension<0, 0>::unroll(shape, coords, lambda_function); - break; - case 1: - ForEachDimension<0, 1>::unroll(shape, coords, lambda_function); - break; - case 2: - ForEachDimension<0, 2>::unroll(shape, coords, lambda_function); - break; - case 3: - ForEachDimension<0, 3>::unroll(shape, coords, lambda_function); - break; - case 4: - ForEachDimension<0, 4>::unroll(shape, coords, lambda_function); - break; - case 5: - ForEachDimension<0, 5>::unroll(shape, coords, lambda_function); - break; - case 6: - ForEachDimension<0, 6>::unroll(shape, coords, lambda_function); - break; - default: - assert(false && "ShapeLoop, 1 <= Shape'rank <= 6"); - break; + coords.set(0, 0); } + // TODO Change 6 to onert::ir::Shape::kMaxRank if onert::ir::Shape::kMaxRank is modified as a + // constant expression + ForEachDimension<6>::unroll(shape, coords, lambda_function); } #endif // __ONERT_UTIL_UTILS_H__ diff --git a/runtime/onert/core/src/backend/BackendContext.cc b/runtime/onert/core/src/backend/BackendContext.cc index b9aab7994..7b36f106d 100644 --- a/runtime/onert/core/src/backend/BackendContext.cc +++ b/runtime/onert/core/src/backend/BackendContext.cc @@ -16,8 +16,6 @@ #include "backend/BackendContext.h" -#include "ir/Operation.h" - namespace onert { namespace backend diff --git a/runtime/onert/core/src/backend/basic/StaticTensorManager.cc b/runtime/onert/core/src/backend/basic/StaticTensorManager.cc index b03eb607c..71cde4cde 100644 --- a/runtime/onert/core/src/backend/basic/StaticTensorManager.cc +++ b/runtime/onert/core/src/backend/basic/StaticTensorManager.cc @@ -35,6 +35,15 @@ StaticTensorManager::StaticTensorManager(const std::shared_ptr<TensorRegistry> & // DO NOTHING } +StaticTensorManager::StaticTensorManager(const std::shared_ptr<TensorRegistry> &reg, + const std::string planner_id, + DynamicTensorManager *dynamic_tensor_manager) + : _nonconst_mgr{new MemoryManager(planner_id)}, _tensors{reg}, _dynamic_tensor_manager{ dynamic_tensor_manager} +{ + // DO NOTHING +} + void StaticTensorManager::allocateNonconsts(void) { _nonconst_mgr->allocate(); diff --git a/runtime/onert/core/src/backend/basic/Tensor.cc b/runtime/onert/core/src/backend/basic/Tensor.cc index c2bbc5a66..de1cff4f4 100644 --- a/runtime/onert/core/src/backend/basic/Tensor.cc +++ b/runtime/onert/core/src/backend/basic/Tensor.cc @@ -51,6 +51,7 @@ bool Tensor::applyShape(const ir::Shape &new_shape) auto allocTensorMem = [&]() { auto capacity = total_size(); + assert(_dynamic_mem_mgr); auto alloc = _dynamic_mem_mgr->allocate(this, capacity); setBuffer(alloc); }; @@ -68,6 +69,7 @@ bool Tensor::applyShape(const ir::Shape &new_shape) auto new_size = new_shape.num_elements() * ir::sizeOfDataType(data_type()); if (previous_size != new_size) { + assert(_dynamic_mem_mgr); _dynamic_mem_mgr->deallocate(this); setShape(new_shape); diff --git a/runtime/onert/core/src/backend/basic/TensorBuilder.cc b/runtime/onert/core/src/backend/basic/TensorBuilder.cc index a10cc2bf9..f9d83875d 100644 --- a/runtime/onert/core/src/backend/basic/TensorBuilder.cc +++ b/runtime/onert/core/src/backend/basic/TensorBuilder.cc @@ -34,6 +34,14 @@ TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg) /* empty */ } +TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::string planner_id) + : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg)}, + _static_tensor_mgr{new StaticTensorManager(_tensor_reg,
planner_id, _dynamic_tensor_mgr.get())} +{ + /* empty */ +} + void TensorBuilder::registerTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info, ir::Layout layout) { diff --git a/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc b/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc new file mode 100644 index 000000000..d09604224 --- /dev/null +++ b/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <backend/basic/train/TrainableTensor.h> + +namespace onert +{ +namespace backend +{ +namespace basic +{ +namespace train +{ + +std::vector<ITensor *> TrainableTensor::optVars() +{ + std::vector<ITensor *> ret; + for (auto &&e : _opt_vars) + { + ret.emplace_back(e.get()); + } + return ret; +} + +void TrainableTensor::fillBuffer(const std::shared_ptr<ir::Data> &data) +{ + auto *buffer = _tensor.buffer(); + assert(buffer); + assert(total_size() == data->size()); + std::memcpy(buffer, data->base(), data->size()); +} + +} // namespace train +} // namespace basic +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/Backend.h b/runtime/onert/core/src/backend/builtin/Backend.h index 3791f3ffa..c05494a6a 100644 --- a/runtime/onert/core/src/backend/builtin/Backend.h +++ b/runtime/onert/core/src/backend/builtin/Backend.h @@ -22,8 +22,16 @@ #include "KernelGenerator.h" #include "TensorBuilder.h" #include "Tensor.h" +#ifdef ONERT_TRAIN +#include "train/BackendContext.h" +#include "train/KernelGenerator.h" +#include "train/TensorRegistry.h" +#endif // ONERT_TRAIN #include <backend/Backend.h> +#ifdef ONERT_TRAIN +#include <backend/train/ITrainableBackend.h> +#endif // ONERT_TRAIN #include <memory> @@ -35,6 +43,10 @@ namespace builtin { class Backend : public ::onert::backend::Backend +#ifdef ONERT_TRAIN + , + public backend::train::ITrainableBackend +#endif // ONERT_TRAIN { public: Backend() : _config{std::make_shared<Config>()} {} @@ -70,6 +82,22 @@ public: return context; } +#ifdef ONERT_TRAIN + std::unique_ptr<backend::train::TrainableBackendContext> + newContext(backend::train::TrainableContextData &&tdata) const override + { + const auto &tgraph = *tdata.tgraph; + auto tr = std::make_shared<train::TensorRegistry>(); + // TODO Create TensorBuilder if necessary + auto tdata_ptr = std::make_unique<backend::train::TrainableContextData>(std::move(tdata)); + auto context = std::make_unique<train::BackendContext>(this, std::move(tdata_ptr), tr); + + context->kernel_gen = + std::make_shared<train::KernelGenerator>(tgraph, tr, context->external_context()); + return context; + } +#endif // ONERT_TRAIN + private: std::shared_ptr<IConfig> _config; }; diff --git a/runtime/onert/core/src/backend/builtin/BackendContext.cc b/runtime/onert/core/src/backend/builtin/BackendContext.cc index c1a2ed537..573617e28 100644 --- 
a/runtime/onert/core/src/backend/builtin/BackendContext.cc +++ b/runtime/onert/core/src/backend/builtin/BackendContext.cc @@ -32,7 +32,7 @@ FunctionMap BackendContext::genKernels() { FunctionMap ret; - for (auto op_ind : _data.op_order) + for (auto &&op_ind : _data.op_order) { auto fn_seq = kernel_gen->generate(op_ind); ret.emplace_back(op_ind, std::move(fn_seq)); diff --git a/runtime/onert/core/src/backend/builtin/Config.cc b/runtime/onert/core/src/backend/builtin/Config.cc index f792c0c36..e5f6d4c21 100644 --- a/runtime/onert/core/src/backend/builtin/Config.cc +++ b/runtime/onert/core/src/backend/builtin/Config.cc @@ -27,7 +27,7 @@ std::string Config::ID = "builtin"; bool Config::initialize() { return true; } -ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout frontend_layout) +ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout frontend_layout) { return frontend_layout; } diff --git a/runtime/onert/core/src/backend/builtin/Config.h b/runtime/onert/core/src/backend/builtin/Config.h index 5226eba69..196b299d3 100644 --- a/runtime/onert/core/src/backend/builtin/Config.h +++ b/runtime/onert/core/src/backend/builtin/Config.h @@ -34,7 +34,7 @@ public: static std::string ID; std::string id() override { return ID; } bool initialize() override; - ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override; + ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override; bool supportPermutation() override { return false; } bool supportDynamicTensor() override { diff --git a/runtime/onert/core/src/backend/builtin/KernelGenerator.cc b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc index 4533703a6..00c200a92 100644 --- a/runtime/onert/core/src/backend/builtin/KernelGenerator.cc +++ b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc @@ -71,14 +71,14 @@ void KernelGenerator::visit(const ir::operation::If &node) const auto else_subg_index = node.param().else_subg_index; std::vector<backend::IPortableTensor *> input_tensors; - for (const auto input_index : node.getInputs()) + for (const auto &input_index : node.getInputs()) { auto input_tensor = getPortableTensor(input_index); input_tensors.emplace_back(input_tensor); } std::vector<backend::IPortableTensor *> output_tensors; - for (const auto output_index : node.getOutputs()) + for (const auto &output_index : node.getOutputs()) { auto output_tensor = getPortableTensor(output_index); output_tensors.emplace_back(output_tensor); @@ -117,14 +117,14 @@ void KernelGenerator::visit(const ir::operation::While &node) // This op does not support input as a constant, because builtin backend does not have // TensorBuilder std::vector<backend::IPortableTensor *> input_tensors; - for (const auto input_index : node.getInputs()) + for (const auto &input_index : node.getInputs()) { auto input_tensor = getPortableTensor(input_index); input_tensors.emplace_back(input_tensor); } std::vector<backend::IPortableTensor *> output_tensors; - for (const auto output_index : node.getOutputs()) + for (const auto &output_index : node.getOutputs()) { auto output_tensor = getPortableTensor(output_index); output_tensors.emplace_back(output_tensor); diff --git a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc index c0ca4046c..8b00db468 100644 --- a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc +++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc @@ -96,7 +96,7 @@ void 
WhileLayer::run() // Need some temp tensors to hold the body subgraph output std::vector<std::unique_ptr<Tensor>> temp_outputs_o; std::vector<IPortableTensor *> temp_outputs; - for (auto io_tensor : body_exec->getOutputTensors()) + for (auto &&io_tensor : body_exec->getOutputTensors()) { auto tensor = std::make_unique<Tensor>(io_tensor->orig_info(), io_tensor->orig_layout(), _dyn_memory_manager); @@ -139,7 +139,7 @@ void WhileLayer::run() // Clean-up the temp tensors _dyn_memory_manager->deallocate(cond_output_tensor.get()); - for (auto tensor : temp_outputs) + for (auto &&tensor : temp_outputs) { _dyn_memory_manager->deallocate(tensor); } diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.cc b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc new file mode 100644 index 000000000..fa9131f4d --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "BackendContext.h" + +#include "backend/basic/train/TrainableBackendContextHelpers.h" +#include "exec/FunctionSequence.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +backend::ITensorRegistry *BackendContext::genTensors() +{ + // For now, there is no need to generate tensors for forwarding. + // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`. + // `Permute`: Tensor generation is not required. + // `IF`, `WHILE`: Not supported yet + return tensor_registry().get(); +} + +backend::train::ITensorRegistry *BackendContext::genTrainingTensors() +{ + // For now, there is no need to generate tensors for backwarding. + return tensor_registry().get(); +} + +backend::train::FunctionMap BackendContext::genKernels() +{ + backend::train::FunctionMap ret; + + for (auto &&op_ind : _tdata->op_order) + { + auto tn_seq = kernel_gen->generate(op_ind); + ret.emplace_back(op_ind, std::move(tn_seq)); + } + + trainable_graph()->operands().iterate( + [&](const ir::OperandIndex &ind, const ir::Operand &operand) { + if (!external_operands().contains(ind) && operand.isConstant()) + { + throw std::runtime_error( + "BackendContext: builtin backend does not support updatable weights yet"); + } + }); + + // TODO Enable prepare() + // for (auto &&it : ret) + // { + // auto &fn_seq = it.second; + // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); }); + // } + + return ret; +} + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.h b/runtime/onert/core/src/backend/builtin/train/BackendContext.h new file mode 100644 index 000000000..6f8ce4cae --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__ + +#include <backend/train/TrainableBackendContext.h> + +#include "KernelGenerator.h" +#include "../ExternalContext.h" +#include "../TensorBuilder.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +class BackendContext : public backend::train::TrainableBackendContext +{ +public: + BackendContext(const backend::train::ITrainableBackend *backend, + std::unique_ptr<backend::train::TrainableContextData> &&data, + std::shared_ptr<backend::train::ITensorRegistry> tensor_registry = nullptr, + std::shared_ptr<TensorBuilder> tensor_builder = nullptr, + std::shared_ptr<KernelGenerator> kernel_gen = nullptr) + : backend::train::TrainableBackendContext(backend, std::move(data), tensor_registry), + kernel_gen{kernel_gen}, + _external_context(new ExternalContext), _tensor_builder{tensor_builder} + { + } + + backend::ITensorRegistry *genTensors() override; + backend::train::ITensorRegistry *genTrainingTensors() override; + +public: + backend::train::FunctionMap genKernels() override; + + std::shared_ptr<ExternalContext> external_context() { return _external_context; } + +public: + // TODO Make it private + std::shared_ptr<KernelGenerator> kernel_gen; + +private: + // NOTE ruy context has a thread pool, and when multiple ruy contexts are created, + // the thread pool is also created in duplicate + // TODO Create one ruy context for session + std::shared_ptr<ExternalContext> _external_context; + +private: + std::shared_ptr<TensorBuilder> _tensor_builder; +}; + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc new file mode 100644 index 000000000..6f2c0a3b9 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "KernelGenerator.h" + +#include "kernel/PermuteLayer.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph, + const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::shared_ptr<ExternalContext> &external_context) + : KernelGeneratorBase{tgraph}, _tensor_reg{tensor_reg}, _external_context(external_context) +{ +} + +std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex ind) +{ + auto ret = std::make_unique<exec::train::TrainableFnSequence>(); + const auto &op = _tgraph.operation(ind); + op.accept(*this); + // _return_fn must have been generated + if (_return_fn == nullptr) + { + throw std::runtime_error(op.name() + " op does not supported trainable kernel yet"); + } + + ret->_functions.emplace_back(std::move(_return_fn)); + + return ret; +} + +void KernelGenerator::visit(const ir::train::operation::Permute &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(0)}; + + // Add PermuteLayer + std::vector<ITensor *> output_tensors{getTensor(output_index)}; + std::vector<ITensor *> input_tensors{getTensor(input_index)}; + + std::vector<ITensor *> output_deriv_tensors; + std::vector<ITensor *> input_deriv_tensors; + + auto input_deriv_tensor = getDerivativeTensor(input_index); + auto output_deriv_tensor = getDerivativeTensor(output_index); + output_deriv_tensors.emplace_back(output_deriv_tensor); + input_deriv_tensors.emplace_back(input_deriv_tensor); + + // NOTE IOTensors of graph outputs for passing data to users must be ignored in training + // because the buffers of those IOTensors are unnecessary and nullptr + bool ignore_forward_in_training = _whole_graph_outputs.contains(output_index); + auto fn = std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors, + input_deriv_tensors, output_deriv_tensors, + ignore_forward_in_training, _external_context); + + _return_fn = std::move(fn); +} + +backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index) +{ + // Get Tensor from all tensor registries (for Permute op) + auto ret = _tensor_registries.getITensor(index); + assert(ret != nullptr); + return ret; +} + +backend::ITensor *KernelGenerator::getDerivativeTensor(const ir::OperandIndex &index) +{ + // Get derivative Tensor from all tensor registries (for Permute op) + auto ret = _tensor_registries.getDerivativeITensor(index); + return ret; +} + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h new file mode 100644 index 000000000..d8781c0d0 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__ + +#include "../ExternalContext.h" +#include "../train/TensorRegistry.h" +#include "../../../compiler/train/TensorRegistries.h" + +#include <backend/train/KernelGeneratorBase.h> +#include <exec/train/TrainableFnSequence.h> +#include <ir/train/TrainableGraph.h> + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +class KernelGenerator : public backend::train::KernelGeneratorBase +{ +public: + KernelGenerator(const ir::train::TrainableGraph &tgraph, + const std::shared_ptr<TensorRegistry> &tensor_reg, + const std::shared_ptr<ExternalContext> &external_context); + + std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex ind) override; + + void setTensorRegistries(const compiler::train::TensorRegistries &tensor_registries) + { + _tensor_registries = tensor_registries; + } + + void setWholeGraphOutputs(const ir::OperandIndexSequence &outputs) + { + _whole_graph_outputs = outputs; + } + +private: + void visit(const ir::train::operation::Permute &) override; + +private: + backend::ITensor *getTensor(const ir::OperandIndex &index); + backend::ITensor *getDerivativeTensor(const ir::OperandIndex &index); + +private: + std::shared_ptr<TensorRegistry> _tensor_reg; + compiler::train::TensorRegistries _tensor_registries; + const std::shared_ptr<ExternalContext> _external_context; + ir::OperandIndexSequence _whole_graph_outputs; +}; + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/Tensor.h b/runtime/onert/core/src/backend/builtin/train/Tensor.h new file mode 100644 index 000000000..611407bd2 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/Tensor.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__ + +#include <backend/basic/train/TrainableTensor.h> + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +using TrainableTensor = basic::train::TrainableTensor; +using DerivativeTensor = basic::Tensor; +using GradientTensor = basic::Tensor; + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h b/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h new file mode 100644 index 000000000..c48e5fe93 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__ + +#include <backend/train/ITensorRegistry.h> + +#include "../IOTensor.h" +#include "../Tensor.h" +#include "Tensor.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ + +using BaseTensorRegistry = + backend::train::PortableTensorRegistryTemplate<Tensor, TrainableTensor, DerivativeTensor, + GradientTensor>; + +class TensorRegistry : public backend::train::ITensorRegistry +{ +public: + TensorRegistry() : _base_reg{new BaseTensorRegistry} {} + + ITensor *getITensor(const ir::OperandIndex &index) override + { + auto base_tensor = _base_reg->getITensor(index); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(index); + } + + ITensor *getNativeITensor(const ir::OperandIndex &index) override + { + auto base_tensor = _base_reg->getNativeITensor(index); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(index); + } + + IPortableTensor *getPortableTensor(const ir::OperandIndex &index) + { + auto base_tensor = _base_reg->getPortableTensor(index); + if (base_tensor) + return base_tensor; + return getNativeIOTensor(index); + } + + IOTensor *getNativeIOTensor(const ir::OperandIndex &index) + { + auto tensor = _native_io_tensors.find(index); + if (tensor != _native_io_tensors.end()) + return tensor->second.get(); + return nullptr; + } + + ITensor *getDerivativeITensor(const ir::OperandIndex &index) override + { + return _base_reg->getDerivativeTensor(index); + } + + ITensor *getGradientITensor(const ir::OperandIndex &index) override + { + return _base_reg->getGradientTensor(index); + } + + DerivativeTensor *getDerivativeTensor(const ir::OperandIndex &index) + { + return _base_reg->getDerivativeTensor(index); + } + + bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override + { + assert(tensor); + assert(!getITensor(index)); // For the index, tensor is not registered yet + _base_reg->setMigrantTensor(index, tensor); + 
return true; + } + + void setDerivativeTensor(const ir::OperandIndex &index, std::unique_ptr<DerivativeTensor> tensor) + { + _base_reg->setDerivativeTensor(index, std::move(tensor)); + } + + void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr<GradientTensor> tensor) + { + _base_reg->setGradientTensor(index, std::move(tensor)); + } + + void setNativeIOTensor(ir::OperandIndex index, std::unique_ptr<IOTensor> &&tensor) + { + assert(tensor); + assert(!getITensor(index)); // For the index, tensor is not registered yet + _native_io_tensors[index] = std::move(tensor); + } + + const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors() + { + return _native_io_tensors; + } + std::shared_ptr<BaseTensorRegistry> base_reg() { return _base_reg; } + +private: + std::shared_ptr<BaseTensorRegistry> _base_reg; + ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors; +}; + +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__ diff --git a/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc new file mode 100644 index 000000000..929092dde --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc @@ -0,0 +1,85 @@ + + +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "PermuteLayer.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ +namespace kernel +{ + +PermuteLayer::PermuteLayer(const std::vector<ITensor *> &src_tensors, + const std::vector<ITensor *> &dst_tensors, + const std::vector<ITensor *> &input_deriv_tensors, + const std::vector<ITensor *> &output_deriv_tensors, + bool ignore_forward_in_training, + const std::shared_ptr<ExternalContext> &external_context) + : builtin::kernel::PermuteLayer{src_tensors, dst_tensors, external_context}, + _input_deriv_tensors{input_deriv_tensors}, _output_deriv_tensors{output_deriv_tensors}, + _ignore_forward_in_training{ignore_forward_in_training} +{ + assert(input_deriv_tensors.size() == output_deriv_tensors.size()); + assert(src_tensors.size() == dst_tensors.size()); +} + +void PermuteLayer::optimize() +{ + builtin::kernel::PermuteLayer::optimize(); + + // TODO Calculate offsets of derivative tensors if necessary +} + +void PermuteLayer::forward(bool training) +{ + if (training && _ignore_forward_in_training) + return; + + builtin::kernel::PermuteLayer::run(); +} + +void PermuteLayer::backward() +{ + for (uint32_t i = 0; i < _output_deriv_tensors.size(); ++i) + { + auto src_deriv = _output_deriv_tensors.at(i); + auto dst_deriv = _input_deriv_tensors.at(i); + + // NOTE The derivative tensors corresponding to inputs/outputs of model are nullptr + // because permuting those tensors is meaningless + if (src_deriv && dst_deriv) + { + const auto rank = src_deriv->getShape().rank(); + auto output_offsets = _dst_tensors_offsets.at(i); + auto input_offsets = _src_tensors_offsets.at(i); + + exec::IPermuteFunction::permute(src_deriv, dst_deriv, rank, output_offsets, input_offsets); + } + } +} + +} // namespace kernel +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert diff --git a/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h new file mode 100644 index 000000000..de8063a21 --- /dev/null +++ b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__ +#define __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__ + +#include "../../kernel/PermuteLayer.h" + +#include "exec/train/ITrainableFunction.h" + +namespace onert +{ +namespace backend +{ +namespace builtin +{ +namespace train +{ +namespace kernel +{ + +class PermuteLayer : public builtin::kernel::PermuteLayer, public exec::train::ITrainableFunction +{ +public: + PermuteLayer(const std::vector<ITensor *> &src_tensors, const std::vector<ITensor *> &dst_tensors, + const std::vector<ITensor *> &input_deriv_tensors, + const std::vector<ITensor *> &output_deriv_tensors, bool ignore_forward_in_training, + const std::shared_ptr<ExternalContext> &external_context); + + void optimize() override; + + void forward(bool training) override; + void backward() override; + +private: + std::vector<ITensor *> _input_deriv_tensors; + std::vector<ITensor *> _output_deriv_tensors; + bool _ignore_forward_in_training; +}; + +} // namespace kernel +} // namespace train +} // namespace builtin +} // namespace backend +} // namespace onert + +#endif // __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__ diff --git a/runtime/onert/core/src/compiler/Compiler.cc b/runtime/onert/core/src/compiler/Compiler.cc index 45124556b..ba621bb4f 100644 --- a/runtime/onert/core/src/compiler/Compiler.cc +++ b/runtime/onert/core/src/compiler/Compiler.cc @@ -16,6 +16,7 @@ #include "compiler/Compiler.h" +#include "CompilerHelpers.h" #include "ExecutorFactory.h" #include "ShapeValidator.h" #include "pass/ConstantOutputPass.h" @@ -30,6 +31,7 @@ #include "compiler/StaticShapeInferer.h" #include <misc/string_helpers.h> +#include <misc/polymorphic_downcast.h> namespace onert { @@ -69,10 +71,25 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void) throw std::runtime_error("Profiling mode works only with 'Dataflow' executor"); } + if (!_options->minmax_filepath.empty()) + { + if (_options->executor != "Linear") + throw std::runtime_error("Recording minmax works only with Linear executor"); + } + + if (!_model->hasOnly<ir::Graph>()) + { + throw std::runtime_error("Compiler can only compile models for inference."); + } + _options->forceInternalOptions(); _options->verboseOptions(); - _model->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) { + auto custom_kernel_builder = _model->getKernelBuilder(); + + _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); + // Mandatory passes pass::PassRunner{} .append(std::make_unique<pass::ConstantOutputPass>(subg)) @@ -96,7 +113,9 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void) // Lower: Assign backend std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>> lowered_subgs; { - _model->iterate([&](const ir::SubgraphIndex &subg_index, ir::Graph &subg) { + _model->iterate([&](const ir::SubgraphIndex &subg_index, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); + // Lower: Assign backend lowered_subgs[subg_index] = std::make_unique<compiler::LoweredGraph>(subg, *_options); // Set tracing_ctx for copied graph @@ -119,7 +138,7 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void) // Run the StaticShapeInfer of primary subg. 
All child StaticShapeInferers are called // recursively std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers = - StaticShapeInferer::createStaticShapeInferers(lowered_subgs); + createStaticShapeInferers(lowered_subgs); const auto primary_subg_idx = ir::SubgraphIndex{0}; inferers.at(primary_subg_idx)->infer(); @@ -158,10 +177,15 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void) ir::OperationDumper dumper("Executor generation of Subgraph " + std::to_string(subg_index.value())); lowered_subg->graph().operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); }); - - auto executor = std::unique_ptr<exec::IExecutor>{ExecutorFactory::get().create( - std::move(lowered_subg), tracing_ctx.get(), *_options, executors, model_index)}; + [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); }); + + ExecutorFactoryArgs args; + args.tracing_ctx = tracing_ctx.get(); + args.options = _options; + args.model_index = model_index; + args.custom_kernel_builder = custom_kernel_builder; + auto executor = std::unique_ptr<exec::IExecutor>{ + ExecutorFactory::get().create(std::move(lowered_subg), executors, args)}; executor->setIndexedRanks(indexed_ranks); executors->emplace(model_index, subg_index, std::move(executor)); } diff --git a/runtime/onert/core/src/compiler/CompilerFactory.cc b/runtime/onert/core/src/compiler/CompilerFactory.cc index d8d4bb277..aeb0876c4 100644 --- a/runtime/onert/core/src/compiler/CompilerFactory.cc +++ b/runtime/onert/core/src/compiler/CompilerFactory.cc @@ -17,6 +17,9 @@ #include "compiler/CompilerFactory.h" #include "MultiModelCompiler.h" +#ifdef ONERT_TRAIN +#include "train/TrainingCompiler.h" +#endif // ONERT_TRAIN #include "compiler/Compiler.h" @@ -33,8 +36,18 @@ CompilerFactory &CompilerFactory::get() std::unique_ptr<ICompiler> CompilerFactory::create(const std::shared_ptr<ir::NNPkg> &nnpkg, - std::vector<std::unique_ptr<CompilerOptions>> &copts) + std::vector<std::unique_ptr<CompilerOptions>> &copts, + const compiler::train::TrainingInfo *training_info) { +#ifdef ONERT_TRAIN + // Returning compiler for training + if (training_info) + return std::make_unique<train::TrainingCompiler>(nnpkg, copts, *training_info); +#else // ONERT_TRAIN + (void)training_info; +#endif // ONERT_TRAIN + + // Returning compiler for inference if (nnpkg->model_count() == 1) return std::make_unique<Compiler>(nnpkg, copts); diff --git a/runtime/onert/core/src/compiler/CompilerHelpers.h b/runtime/onert/core/src/compiler/CompilerHelpers.h new file mode 100644 index 000000000..798334b3b --- /dev/null +++ b/runtime/onert/core/src/compiler/CompilerHelpers.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef __ONERT_COMPILER_COMPILER_HELPERS_H__ +#define __ONERT_COMPILER_COMPILER_HELPERS_H__ + +#include <compiler/ILoweredGraph.h> +#include <compiler/StaticShapeInferer.h> +#include <ir/Index.h> + +#include <memory> +#include <unordered_map> + +namespace onert +{ +namespace compiler +{ + +/** + * @brief Create a shape inferer map for a lowered model + * @param[in] lowered_subgs lowered model map + * @return Shape inferer map + */ +template <typename LoweredGraphType, + typename = std::enable_if_t<std::is_base_of<ILoweredGraph, LoweredGraphType>::value>> +static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> +createStaticShapeInferers( + const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraphType>> &lowered_subgs) +{ + std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> lsubgs; + for (auto &&e : lowered_subgs) + lsubgs[e.first] = e.second.get(); + return StaticShapeInferer::createStaticShapeInferers(lsubgs); +} + +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_COMPILER_HELPERS_H__ diff --git a/runtime/onert/core/src/compiler/CompilerOptions.cc b/runtime/onert/core/src/compiler/CompilerOptions.cc index b5fd392e0..830d9dd00 100644 --- a/runtime/onert/core/src/compiler/CompilerOptions.cc +++ b/runtime/onert/core/src/compiler/CompilerOptions.cc @@ -75,6 +75,7 @@ std::unique_ptr<CompilerOptions> CompilerOptions::fromGlobalConfig() { auto o = std::make_unique<CompilerOptions>(); o->backend_list = nnfw::misc::split(util::getConfigString(util::config::BACKENDS), ';'); + o->minmax_filepath = util::getConfigString(util::config::MINMAX_FILEPATH); o->trace_filepath = util::getConfigString(util::config::TRACE_FILEPATH); o->graph_dump_level = util::getConfigInt(util::config::GRAPH_DOT_DUMP); o->executor = util::getConfigString(util::config::EXECUTOR); diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc index b09d6b021..6a08524cc 100644 --- a/runtime/onert/core/src/compiler/ExecutorFactory.cc +++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc @@ -25,6 +25,9 @@ #include "../exec/ExecTime.h" #include "../exec/ExecutionObservers.h" #include "../exec/LinearExecutor.h" +#ifdef MINMAX_H5DUMPER +#include "../exec/MinMaxRecorder.h" +#endif #include "../exec/ParallelExecutor.h" #include "../ir/OperationCloner.h" @@ -36,6 +39,14 @@ #include <functional> #include <memory> +#ifdef ONERT_TRAIN +#include "../backend/builtin/train/BackendContext.h" +#include "../exec/train/TrainableExecutor.h" + +#include <backend/train/TrainableBackendContext.h> +#include <backend/train/ITrainableBackend.h> +#endif // ONERT_TRAIN + namespace onert { namespace @@ -74,7 +85,7 @@ public: void run() override { - for (auto tensor : _dealloc_list) + for (auto &&tensor : _dealloc_list) { if (!tensor->is_dynamic()) continue; @@ -86,7 +97,8 @@ private: DeallocList _dealloc_list; }; -void initializeSubgraphIOTensors(compiler::LoweredGraph &lowered_graph, +// TODO Unify initializeSubgraphIOTensors +void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph, const backend::BackendContexts &backend_contexts, const ir::OperandIndexSequence &indices) { @@ -104,7 +116,38 @@ void initializeSubgraphIOTensors(compiler::LoweredGraph &lowered_graph, } assert(builtin_tensor_reg); - for (auto ind : indices) + for (auto &&ind : indices) + { + const auto &operand = lowered_graph.graph().operands().at(ind); + auto tensor = std::make_unique<backend::builtin::IOTensor>( + operand.info(), + 
ir::Layout::NHWC /* FIXME find operation for this operand and use frontend_layout */ + ); + + // Add tensor to builtin TensorRegistry. + builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor)); + } +} + +#ifdef ONERT_TRAIN +void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph, + const backend::train::TrainableBackendContexts &backend_contexts, + const ir::OperandIndexSequence &indices) +{ + std::shared_ptr<backend::builtin::train::TensorRegistry> builtin_tensor_reg; + for (const auto &e : backend_contexts) + { + auto backend = e.first; + auto &context = e.second; + if (backend->config()->id() == backend::builtin::Config::ID) + { + builtin_tensor_reg = std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>( + context->tensor_registry()); + } + } + assert(builtin_tensor_reg); + + for (auto &&ind : indices) { const auto &operand = lowered_graph.graph().operands().at(ind); auto tensor = std::make_unique<backend::builtin::IOTensor>( @@ -116,8 +159,11 @@ void initializeSubgraphIOTensors(compiler::LoweredGraph &lowered_graph, builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor)); } } +#endif // ONERT_TRAIN -backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, bool linear_executor) +backend::BackendContexts +createBackendContexts(compiler::ILoweredGraph &lgraph, bool linear_executor, + std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder) { backend::BackendContexts contexts; auto &backend_manager = compiler::BackendManager::get(); @@ -125,7 +171,7 @@ backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, b std::unordered_map<const backend::Backend *, backend::ContextData> context_data_map; // Generate partial graphs for each backend - for (auto backend : backend_manager.getAll()) + for (auto &&backend : backend_manager.getAll()) { auto &data = context_data_map[backend]; auto graph = std::make_unique<ir::Graph>(); @@ -157,7 +203,7 @@ backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, b }); // Separate operations into partial graphs whole_graph.operations().iterate( - [&](const ir::OperationIndex &op_ind, const ir::Operation &operation) { + [&](const ir::OperationIndex &op_ind, const ir::IOperation &operation) { auto &op_li = lgraph.lower_info().operation; auto backend = op_li.at(op_ind).backend(); auto &partial_graph = *context_data_map[backend].graph; @@ -168,7 +214,7 @@ backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, b // Add missing operands (externals) auto io_list = (operation.getInputs() + operation.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; - for (auto operand_ind : io_list) + for (auto &&operand_ind : io_list) { if (partial_graph.operands().exist(operand_ind)) continue; @@ -217,12 +263,33 @@ backend::BackendContexts createBackendContexts(compiler::LoweredGraph &lgraph, b std::copy_if(whole_op_order.begin(), whole_op_order.end(), std::back_inserter(data.op_order), [&](const auto &ind) { return data.graph->operations().exist(ind); }); data.is_linear_executor = linear_executor; - data.custom_kernel_builder = lgraph.graph().getKernelBuilder(); + data.custom_kernel_builder = custom_kernel_builder; contexts.emplace(backend, backend->newContext(std::move(data))); } return contexts; } +template <typename Context> +std::deque<std::pair<const backend::Backend *, Context *>> orderBackendContext( + const std::unordered_map<const backend::Backend *, std::unique_ptr<Context>> &tbackend_contexts) +{ + 
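// NOTE The context type is a template parameter so that this ordering rule can be shared + // between plain BackendContexts and trainable backend contexts. +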
std::deque<std::pair<const backend::Backend *, Context *>> ordered_contexts; + + for (auto &&pair : tbackend_contexts) + { + // NOTE builtin backend must be processed lastly. + // This is because of the Permute layer's specialty: it is the only operation that could + // have different ITensor objects for its input and output, and it requires that all other + // backends' tensors be ready to use. + if (pair.first->config()->id() == "builtin") + ordered_contexts.emplace_back(pair.first, pair.second.get()); + else + ordered_contexts.emplace_front(pair.first, pair.second.get()); + } + + return ordered_contexts; +} + } // namespace } // namespace onert @@ -240,34 +307,30 @@ ExecutorFactory &ExecutorFactory::get() ExecutorFactory::ExecutorFactory() { _map["Linear"] = createLinearExecutor; - _map["Dataflow"] = - std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2, - std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, false); - _map["Parallel"] = - std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2, - std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, true); + _map["Dataflow"] = std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, false); + _map["Parallel"] = std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, true); } exec::IExecutor *ExecutorFactory::create(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const util::TracingCtx *tracing_ctx, - const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors, - const ir::ModelIndex &index) + const ExecutorFactoryArgs &args) { - return _map.at(options.executor)(std::move(lowered_graph), tracing_ctx, options, executors, - index); + assert(args.options != nullptr); + return _map.at(args.options->executor)(std::move(lowered_graph), executors, args); } -void ExecutorFactory::prepareMigrantTensors(compiler::LoweredGraph &lowered_graph, +void ExecutorFactory::prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph, const backend::BackendContexts &backend_contexts) { TensorRegistries tensor_regs{backend_contexts, true}; lowered_graph.graph().operations().iterate( - [&](const ir::OperationIndex &op_ind, const ir::Operation &op) { + [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) { auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind); auto &backend_ctx = backend_contexts.at(lower_info->backend()); - for (auto ind : + for (auto &&ind : (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { // If an Operation's input/output tensor does not have an own tensor object, @@ -307,7 +370,6 @@ std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> ExecutorFactory::orderBackendContext(const backend::BackendContexts &backend_contexts) { std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> ordered_contexts; - for (auto &&pair : backend_contexts) { // NOTE builtin backend must be processed lastly.
@@ -319,19 +381,22 @@ ExecutorFactory::orderBackendContext(const backend::BackendContexts &backend_con else ordered_contexts.emplace_front(pair.first, pair.second.get()); } - return ordered_contexts; } -exec::IExecutor *ExecutorFactory::createLinearExecutor( - std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx, - const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors, - const ir::ModelIndex &index) +exec::IExecutor * +ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args) { + const auto options = args.options; + const auto &model_index = args.model_index; + const auto tracing_ctx = args.tracing_ctx; + auto custom_kernel_builder = args.custom_kernel_builder; auto &graph = lowered_graph->graph(); backend::BackendContexts backend_contexts = - createBackendContexts(*lowered_graph, options.executor == "Linear"); + createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder); TensorRegistries tensor_regs{backend_contexts, true}; @@ -352,7 +417,7 @@ exec::IExecutor *ExecutorFactory::createLinearExecutor( prepareMigrantTensors(*lowered_graph, backend_contexts); // Give some runtime objects to builtin KernelGenerator - prepareBuiltinBackend(tensor_regs, executors, backend_contexts, index); + prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index); ExecutionBuilder builder; @@ -382,7 +447,7 @@ exec::IExecutor *ExecutorFactory::createLinearExecutor( uses_map[ind]++; } - for (const auto op_ind : order) + for (const auto &op_ind : order) { const auto &op = graph.operations().at(op_ind); auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; @@ -422,7 +487,7 @@ exec::IExecutor *ExecutorFactory::createLinearExecutor( auto &fn_seq = pair.second; auto &op = lowered_graph->graph().operations().at(op_ind); auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind); - if (options.he_profiling_mode) + if (options->he_profiling_mode) fn_seq->wrap<SyncFunction>(lower_info->backend()->config()); if (!dealloc_list_map[op_ind].empty()) fn_seq->append(std::make_unique<DeallocFunction>(dealloc_list_map[op_ind])); @@ -439,23 +504,33 @@ exec::IExecutor *ExecutorFactory::createLinearExecutor( order, tracing_ctx}; - if (!options.trace_filepath.empty()) + if (!options->trace_filepath.empty()) { std::unique_ptr<exec::IExecutionObserver> ctp = - std::make_unique<exec::TracingObserver>(options.trace_filepath, exec->graph(), tracing_ctx); + std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx); exec->addObserver(std::move(ctp)); } +#ifdef MINMAX_H5DUMPER + if (!options->minmax_filepath.empty()) + exec->addObserver(std::make_unique<exec::MinMaxRecorder>( + options->minmax_filepath, exec->graph(), exec->getBackendContexts())); +#endif return exec; } -exec::IExecutor *ExecutorFactory::createDataflowExecutor( - std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx, - const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors, - const ir::ModelIndex &index, bool parallel) +exec::IExecutor * +ExecutorFactory::createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, bool parallel) { + const auto options = args.options; + 
const auto &model_index = args.model_index; + const auto tracing_ctx = args.tracing_ctx; + auto custom_kernel_builder = args.custom_kernel_builder; + backend::BackendContexts backend_contexts = - createBackendContexts(*lowered_graph, options.executor == "Linear"); + createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder); TensorRegistries tensor_regs{backend_contexts, true}; @@ -472,7 +547,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor( prepareMigrantTensors(*lowered_graph, backend_contexts); // Give some runtime objects to builtin KernelGenerator - prepareBuiltinBackend(tensor_regs, executors, backend_contexts, index); + prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index); ExecutionBuilder builder; @@ -489,7 +564,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor( auto &fn_seq = pair.second; auto &op = lowered_graph->graph().operations().at(op_ind); auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind); - if (options.he_profiling_mode) + if (options->he_profiling_mode) fn_seq->wrap<SyncFunction>(lower_info->backend()->config()); builder.append(op_ind, {op_ind, &op, lower_info, std::move(fn_seq)}); } @@ -508,7 +583,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor( auto dataflow_exec = new exec::DataflowExecutor{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, std::move(code_map), tracing_ctx}; - if (options.he_profiling_mode) + if (options->he_profiling_mode) { std::vector<const backend::Backend *> backends; for (const auto &pair : backend_contexts) @@ -523,15 +598,304 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor( exec = dataflow_exec; } - if (!options.trace_filepath.empty()) + if (!options->trace_filepath.empty()) + { + std::unique_ptr<exec::IExecutionObserver> ctp = + std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx); + exec->addObserver(std::move(ctp)); + } + + return exec; +} + +#ifdef ONERT_TRAIN +exec::IExecutor * +ExecutorFactory::create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, + const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer) +{ + assert(args.options != nullptr); + + if (args.options->executor != "Linear") + throw std::runtime_error("ExecutorFactory: TrainableExecutor supports only 'Linear' now"); + + return createTrainableExecutor(std::move(lowered_graph), executors, args, optimizer); +} + +void ExecutorFactory::prepareMigrantTensors( + compiler::ILoweredGraph &lowered_graph, + const backend::train::TrainableBackendContexts &backend_contexts) +{ + train::TensorRegistries tensor_regs{backend_contexts, true}; + + lowered_graph.graph().operations().iterate( + [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) { + auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind); + auto &backend_ctx = backend_contexts.at(lower_info->backend()); + for (auto &&ind : + (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + { + // If an Operation's input/output tensor does not have an own tensor object, + // it must be using migrant tensors, so find the tensor from other tensor registries and + // register it to the current tensor registry if it is portable + if (!backend_ctx->tensor_registry()->getITensor(ind)) + { + auto tensor = tensor_regs.getITensor(ind); + assert(tensor); // The tensor must have been 
registered + auto ptensor = dynamic_cast<backend::IPortableTensor *>(tensor); + if (ptensor) + backend_ctx->tensor_registry()->setMigrantTensor(ind, ptensor); + } + } + }); +} + +exec::IExecutor *ExecutorFactory::createTrainableExecutor( + std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &, const ExecutorFactoryArgs &args, + const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer) +{ + const auto options = args.options; + const auto tracing_ctx = args.tracing_ctx; + auto custom_kernel_builder = args.custom_kernel_builder; + + auto &graph = lowered_graph->graph(); + + lowered_graph->trainable_graph().operations().iterate([](const onert::ir::OperationIndex &, + const onert::ir::IOperation &op) { + try + { + UNUSED_RELEASE(dynamic_cast<const ir::train::ITrainableOperation &>(op)); + } + catch (std::bad_cast &) + { + throw std::runtime_error("ExecutorFactory: " + op.name() + " is not a trainable operation yet"); + } + }); + + // TODO Create context only once instead of replacing + backend::train::TrainableBackendContexts tbackend_contexts; + backend::BackendContexts base_backend_contexts = + createBackendContexts(*lowered_graph, true, custom_kernel_builder); + + // Replace BackendContext with TrainableBackendContext + for (auto &&pair : base_backend_contexts) + { + auto ctx = pair.second.get(); + const auto &data = ctx->data(); + + // Create partial and trainable graphs + auto tgraph = std::make_unique<ir::train::TrainableGraph>(*data.graph); + data.graph->operations().iterate( + [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &) { + const auto &orig_tgraph = lowered_graph->trainable_graph(); + const auto &trainable_op = orig_tgraph.operation(op_index); + auto gen_index = tgraph->replaceOperation(op_index, trainable_op.clone()); + UNUSED_RELEASE(gen_index); + assert(gen_index == op_index); + }); + data.graph->operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) { + const auto &orig_tgraph = lowered_graph->trainable_graph(); + if (orig_tgraph.derivatives().exist(index)) + { + const auto &deriv = orig_tgraph.derivatives().at(index); + auto new_deriv = std::make_unique<ir::Operand>(deriv); + auto gen_index = tgraph->addDerivative(index, std::move(new_deriv)); + UNUSED_RELEASE(gen_index); + assert(gen_index == index); + } + }); + + // Remove outputs of whole graph from external_operands + auto external_operands = data.external_operands; + for (const auto &index : lowered_graph->trainable_graph().getOutputs()) + { + if (external_operands.contains(index)) + external_operands.remove(index); + } + + // Set trainable context data + backend::train::TrainableContextData tdata; + tdata.tgraph = std::move(tgraph); + tdata.op_order = std::move(data.op_order); + tdata.external_operands = std::move(external_operands); + tdata.operand_layouts = std::move(data.operand_layouts); + tdata.custom_kernel_builder = std::move(data.custom_kernel_builder); + tdata.is_linear_executor = data.is_linear_executor; + tdata.optimizer = optimizer; + + // TODO Remove dynamic_cast + // NOTE A pointer dynamic_cast returns nullptr on failure instead of throwing + // std::bad_cast, so the result has to be checked explicitly. + const auto backend = pair.first; + const auto tbackend = dynamic_cast<const backend::train::ITrainableBackend *>(backend); + if (tbackend == nullptr) + { + throw std::runtime_error("ExecutorFactory: Invalid backend - TrainableExecutor does not " + "support non-trainable backends"); + } + tbackend_contexts.emplace(backend, tbackend->newContext(std::move(tdata))); + } + base_backend_contexts.clear(); + 
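+ // From this point only the trainable backend contexts are used: the steps below register + // whole-graph I/O tensors, linearize the graph, create forward and training tensors, and + // then generate the kernels. +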
train::TensorRegistries tensor_regs{tbackend_contexts, true}; + + initializeSubgraphIOTensors( + *lowered_graph, tbackend_contexts, + (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) | + ir::Remove::DUPLICATED | ir::Remove::UNDEFINED); + + // linearize + auto order = Linear::linearize(*lowered_graph); + Linear::dump(*lowered_graph, order); + + for (auto &&pair : tbackend_contexts) + { + pair.second->genTensors(); + } + + for (auto &&pair : tbackend_contexts) + { + auto tctx = pair.second.get(); + tctx->genTrainingTensors(); + } + + prepareMigrantTensors(*lowered_graph, tbackend_contexts); + + // Give some runtime objects to builtin KernelGenerator + for (auto &&pair : tbackend_contexts) + { + auto builtin_context = + dynamic_cast<backend::builtin::train::BackendContext *>(pair.second.get()); + if (builtin_context != nullptr) + { + auto builtin_kernel_gen = builtin_context->kernel_gen; + builtin_kernel_gen->setTensorRegistries(tensor_regs); + builtin_kernel_gen->setWholeGraphOutputs(lowered_graph->trainable_graph().getOutputs()); + } + } + + // Adjust the order of backends for the upcoming iteration + auto ordered_contexts = + onert::orderBackendContext<backend::train::TrainableBackendContext>(tbackend_contexts); + + // TODO Remove this simulation + // Simulate the execution for deallocation of tensors + std::unordered_map<ir::OperationIndex, DeallocList> dealloc_list_map; + { + ir::OperandIndexMap<uint32_t> uses_map; + ir::OperandIndexSequence constants; + + auto model_io = + (graph.getInputs() + graph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + + // Prepare scanning + graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) { + uses_map[ind] = obj.getUses().size(); + + if (obj.isConstant()) + constants.append(ind); + }); + + // A trick to consider constants as an exception + for (const auto &ind : constants) + { + uses_map[ind]++; + } + + for (const auto op_ind : order) + { + const auto &op = graph.operations().at(op_ind); + auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; + auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED; + + for (const auto &ind : op_inputs) + { + const auto &operand = graph.operands().at(ind); + assert(uses_map.find(ind) != uses_map.end()); + assert(uses_map[ind] > 0); + uses_map[ind]--; + if (uses_map[ind] == 0 && !operand.info().isVariable() && !model_io.contains(ind)) + { + dealloc_list_map[op_ind].emplace_back(tensor_regs.getITensor(ind)); + } + } + } + + // Dispose and validate + for (const auto &ind : constants) + { + --uses_map[ind]; + } + + assert( + std::all_of(uses_map.begin(), uses_map.end(), + [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; })); + } + + // Check derivative tensors + { + // TODO Support multiple subgraphs + // Check if the derivative tensors corresponding to inputs of model are nullptr + // NOTE The derivative tensors corresponding to inputs of model are for inputs of PermuteLayers + // and they are nullptr because they are meaningless.
+ assert(std::all_of(lowered_graph->trainable_graph().getInputs().begin(), + lowered_graph->trainable_graph().getInputs().end(), + [&](const auto &input_idx) { + return tensor_regs.getDerivativeITensor(input_idx) == nullptr; + })); + + // Check if the derivative tensors corresponding to outputs of model are nullptr as well + assert(std::all_of(lowered_graph->trainable_graph().getOutputs().begin(), + lowered_graph->trainable_graph().getOutputs().end(), + [&](const auto &output_idx) { + return tensor_regs.getDerivativeITensor(output_idx) == nullptr; + })); + } + + train::TrainableCodeMap code_map; + // Generate kernels + for (auto &&pair : ordered_contexts) + { + auto codes = pair.second->genKernels(); + for (auto &&pair : codes) + { + auto &op_ind = pair.first; + auto &tn_seq = pair.second; + auto &op = lowered_graph->trainable_graph().operation(op_ind); + auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind); + + assert(code_map.find(op_ind) == code_map.end()); + code_map.insert( + {op_ind, train::TrainableCodeAndInfo{op_ind, &op, lower_info, std::move(tn_seq)}}); + } + } + + if (order.size() != code_map.size()) + { + throw std::runtime_error("ExecutorFactory: Some kernels are not generated"); + } + + auto exec = new exec::train::TrainableExecutor{std::move(lowered_graph), + std::move(tbackend_contexts), + tensor_regs, + std::move(code_map), + order, + tracing_ctx}; + + if (!options->trace_filepath.empty()) { std::unique_ptr<exec::IExecutionObserver> ctp = - std::make_unique<exec::TracingObserver>(options.trace_filepath, exec->graph(), tracing_ctx); + std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx); exec->addObserver(std::move(ctp)); } + // TODO Support MINMAX_H5DUMPER return exec; } +#endif // ONERT_TRAIN } // namespace compiler } // namespace onert diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.h b/runtime/onert/core/src/compiler/ExecutorFactory.h index f8f989043..cc621bccf 100644 --- a/runtime/onert/core/src/compiler/ExecutorFactory.h +++ b/runtime/onert/core/src/compiler/ExecutorFactory.h @@ -20,7 +20,15 @@ #include "TensorRegistries.h" #include "backend/ITensor.h" + +#ifdef ONERT_TRAIN +#include "backend/train/TrainableBackendContext.h" +#endif // ONERT_TRAIN #include "compiler/LoweredGraph.h" +#ifdef ONERT_TRAIN +#include "compiler/train/LoweredTrainableGraph.h" +#include "exec/train/optimizer/Optimizer.h" +#endif // ONERT_TRAIN #include "exec/IExecutors.h" #include <deque> @@ -31,6 +39,15 @@ namespace onert namespace compiler { +// TODO Change to a better name +struct ExecutorFactoryArgs +{ + const util::TracingCtx *tracing_ctx; + const compiler::CompilerOptions *options; + ir::ModelIndex model_index; + std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder; +}; + class ExecutorFactory { public: @@ -38,16 +55,22 @@ public: public: exec::IExecutor *create(std::unique_ptr<compiler::LoweredGraph> lowered_graph, - const util::TracingCtx *tracing_ctx, - const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors, - const ir::ModelIndex &index); + const ExecutorFactoryArgs &args); + +#ifdef ONERT_TRAIN + // TODO Unify create() + exec::IExecutor *create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, + const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer); +#endif // ONERT_TRAIN private: ExecutorFactory(); private: - static void
prepareMigrantTensors(compiler::LoweredGraph &lowered_graph, + static void prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph, const backend::BackendContexts &backend_contexts); static void prepareBuiltinBackend(const TensorRegistries &tensor_regs, const std::shared_ptr<exec::IExecutors> &executors, @@ -56,22 +79,31 @@ private: static std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> orderBackendContext(const backend::BackendContexts &backend_contexts); - static exec::IExecutor *createLinearExecutor( - std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx, - const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors, - const ir::ModelIndex &index); - static exec::IExecutor *createDataflowExecutor( - std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx, - const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors, - const ir::ModelIndex &index, bool parallel); + static exec::IExecutor * + createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args); + static exec::IExecutor * + createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, bool parallel); +#ifdef ONERT_TRAIN + // TODO Unify prepareMigrantTensors + static void + prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph, + const backend::train::TrainableBackendContexts &backend_contexts); + static exec::IExecutor * + createTrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args, + const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer); +#endif // ONERT_TRAIN private: std::unordered_map< - std::string, - std::function<exec::IExecutor *( - std::unique_ptr<compiler::LoweredGraph>, const util::TracingCtx *tracing_ctx, - const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors, - const ir::ModelIndex &index)>> + std::string, std::function<exec::IExecutor *(std::unique_ptr<compiler::LoweredGraph>, + const std::shared_ptr<exec::IExecutors> &executors, + const ExecutorFactoryArgs &args)>> _map; }; diff --git a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc index fdf4e24f0..ce9b09c2d 100644 --- a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc +++ b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc @@ -776,7 +776,7 @@ Fp32ToFp16Converter::InputToOpSeqs Fp32ToFp16Converter::prepareInputToOpSeqs() c InputToOpSeqs input_to_op_seqs; op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_idx, const ir::OpSequence &op_seq) { - for (auto input : op_seq.getInputs() | ir::Remove::UNDEFINED) + for (auto &&input : op_seq.getInputs() | ir::Remove::UNDEFINED) { auto it = input_to_op_seqs.find(input); if (it == input_to_op_seqs.end()) @@ -862,7 +862,7 @@ void Fp32ToFp16Converter::manipulateContiguousOpSequences( // | // [OPERATION] // op_seq_ind_next_to_fp16 // - for (auto it : opseq_map_to_delete) + for (auto &&it : opseq_map_to_delete) { // fp16_to_fp32's input/output num is always 1 auto &op_seq_ind_fp16_to_fp32 = it.first; diff --git a/runtime/onert/core/src/compiler/HEScheduler.cc 
b/runtime/onert/core/src/compiler/HEScheduler.cc index 65fd4cd77..f662ef5b9 100644 --- a/runtime/onert/core/src/compiler/HEScheduler.cc +++ b/runtime/onert/core/src/compiler/HEScheduler.cc @@ -28,7 +28,7 @@ namespace using namespace onert; -uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::Operation &node) +uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::IOperation &node) { uint32_t size = 0; for (const auto &ind : @@ -39,7 +39,7 @@ uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::Operatio return size; } -bool isQuant(const ir::Graph &graph, const ir::Operation &node) +bool isQuant(const ir::Graph &graph, const ir::IOperation &node) { for (const auto &input : node.getInputs() | ir::Remove::UNDEFINED) { @@ -52,14 +52,14 @@ bool isQuant(const ir::Graph &graph, const ir::Operation &node) return false; } -bool isWorkaroundSkip(const ir::Graph &, const backend::Backend *, const ir::Operation &, bool) +bool isWorkaroundSkip(const ir::Graph &, const backend::Backend *, const ir::IOperation &, bool) { // Now, there is no workaround return false; } // if a node can be merged into op_seq -bool isMergeable(const ir::Graph &graph, const ir::Operation &node) +bool isMergeable(const ir::Graph &graph, const ir::IOperation &node) { size_t prev_op_cnt = 0; for (const auto &input : node.getInputs() | ir::Remove::UNDEFINED) @@ -137,7 +137,7 @@ void HEScheduler::scheduleShufflingBackends() } } -bool HEScheduler::isNodeProfiled(const ir::Operation &node) +bool HEScheduler::isNodeProfiled(const ir::IOperation &node) { const bool quant = isQuant(*_graph, node); const auto size = getOperationsFlattenedIOSize(*_graph, node); @@ -207,7 +207,7 @@ std::unique_ptr<compiler::BackendResolver> HEScheduler::schedule(const ir::Graph { // Check if profiling info about all backend/node pairs already exists bool all_nodes_are_profiled = true; - _graph->operations().iterate([&](const ir::OperationIndex &, const ir::Operation &op) { + _graph->operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &op) { if (all_nodes_are_profiled) all_nodes_are_profiled = isNodeProfiled(op); }); @@ -224,7 +224,7 @@ std::unique_ptr<compiler::BackendResolver> HEScheduler::schedule(const ir::Graph ir::OperationIndexMap<bool> visited; graph.operations().iterate( - [&](const ir::OperationIndex &index, const ir::Operation &) { visited[index] = false; }); + [&](const ir::OperationIndex &index, const ir::IOperation &) { visited[index] = false; }); // for each task select the backend with the smallest earliest finishing time(eft) for (const auto &rank : _rank_to_op) { @@ -258,7 +258,7 @@ int64_t HEScheduler::getPermuteTime(const backend::Backend *src_backend, return size / 400; } -int64_t HEScheduler::tryBackend(const ir::Operation &node, const backend::Backend *backend) +int64_t HEScheduler::tryBackend(const ir::IOperation &node, const backend::Backend *backend) { // if there is no profiling info don't use this backend during scheduling if (!_is_profiling_mode) @@ -297,10 +297,10 @@ void HEScheduler::makeRank() VERBOSE(HEScheduler::makeRank) << "task prioritizing" << std::endl; _graph->operations().iterate( - [&](const ir::OperationIndex &index, const ir::Operation &) { DFSMaxRank(index); }); + [&](const ir::OperationIndex &index, const ir::IOperation &) { DFSMaxRank(index); }); // Check that ranks are calculated for all operations(nodes) - _graph->operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) { + 
_graph->operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &) { UNUSED_RELEASE(index); assert(_op_to_rank->find(index) != _op_to_rank->end()); }); @@ -564,7 +564,7 @@ HEScheduler::ESTAndExecTime(const backend::Backend *backend, const ir::Operation return {prev_op_ft, exec_time}; } -int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::Operation &node, +int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::IOperation &node, std::multimap<int64_t, int64_t> &transfer_st_exec_time) { int64_t max_pred_eft = 0; diff --git a/runtime/onert/core/src/compiler/HEScheduler.h b/runtime/onert/core/src/compiler/HEScheduler.h index 18ea388fd..df6c07926 100644 --- a/runtime/onert/core/src/compiler/HEScheduler.h +++ b/runtime/onert/core/src/compiler/HEScheduler.h @@ -58,7 +58,7 @@ public: _is_profiling_mode{options.he_profiling_mode}, _is_linear_exec{options.executor == "Linear"}, _is_parallel_exec{options.executor == "Parallel"} { - for (auto entry : backends) + for (auto &&entry : backends) { if (entry->config()->id() == backend::builtin::Config::ID) continue; @@ -88,7 +88,7 @@ public: std::shared_ptr<ir::OperationIndexMap<int64_t>> getIndexedRanks() { return _op_to_rank; } private: - bool isNodeProfiled(const ir::Operation &); + bool isNodeProfiled(const ir::IOperation &); bool schedule(const ir::OperationIndex &, const backend::Backend *parent_backend); /** @@ -115,7 +115,7 @@ private: * * @return earliest finishing time of parent nodes */ - int64_t predMaxEFT(const backend::Backend *backend, const ir::Operation &node, + int64_t predMaxEFT(const backend::Backend *backend, const ir::IOperation &node, std::multimap<int64_t, int64_t> &transfer_st_exec_time); void makeRank(); @@ -146,7 +146,7 @@ private: void scheduleShufflingBackends(); - int64_t tryBackend(const ir::Operation &node, const backend::Backend *backend); + int64_t tryBackend(const ir::IOperation &node, const backend::Backend *backend); /** * @brief Schedule a node and its successor until: diff --git a/runtime/onert/core/src/compiler/HEScheduler.test.cc b/runtime/onert/core/src/compiler/HEScheduler.test.cc index 589331b49..1654bfc8b 100644 --- a/runtime/onert/core/src/compiler/HEScheduler.test.cc +++ b/runtime/onert/core/src/compiler/HEScheduler.test.cc @@ -43,7 +43,7 @@ struct MockConfigCPU : public IConfig std::string id() override { return "cpu"; } bool initialize() override { return true; }; bool supportPermutation() override { return false; } - Layout supportLayout(const Operation &, Layout) override { return Layout::UNKNOWN; } + Layout supportLayout(const IOperation &, Layout) override { return Layout::UNKNOWN; } bool supportDynamicTensor() override { return false; } bool supportFP16() override { return false; } }; @@ -70,7 +70,7 @@ struct MockConfigGPU : public IConfig std::string id() override { return "gpu"; } bool initialize() override { return true; }; bool supportPermutation() override { return false; } - ir::Layout supportLayout(const ir::Operation &, ir::Layout) override + ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override { return ir::Layout::UNKNOWN; } @@ -92,7 +92,7 @@ struct MockConfigNPU : public IConfig std::string id() override { return "npu"; } bool initialize() override { return true; }; bool supportPermutation() override { return false; } - ir::Layout supportLayout(const ir::Operation &, ir::Layout) override + ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override { return ir::Layout::UNKNOWN; } diff --git 
a/runtime/onert/core/src/compiler/Linear.cc b/runtime/onert/core/src/compiler/Linear.cc index f85b8d1bd..4dbe229c8 100644 --- a/runtime/onert/core/src/compiler/Linear.cc +++ b/runtime/onert/core/src/compiler/Linear.cc @@ -28,16 +28,16 @@ namespace compiler { // TODO(easy) Change the LoweredGraph param to Graph -std::vector<ir::OperationIndex> Linear::linearize(const compiler::LoweredGraph &lowered_graph) +std::vector<ir::OperationIndex> Linear::linearize(const compiler::ILoweredGraph &lowered_graph) { return lowered_graph.graph().topolSortOperations(); } // TODO(easy) Change the LoweredGraph param to Graph -void Linear::dump(const compiler::LoweredGraph &lowered_graph, +void Linear::dump(const compiler::ILoweredGraph &lowered_graph, const std::vector<ir::OperationIndex> &order) { - for (const auto ind : order) + for (const auto &ind : order) { // TODO Could logging system can handle this? (Inserting prefix for each line) std::istringstream iss{dumper::text::formatOperation(lowered_graph.graph(), ind)}; diff --git a/runtime/onert/core/src/compiler/Linear.h b/runtime/onert/core/src/compiler/Linear.h index 9ac9a0139..4f92dc88d 100644 --- a/runtime/onert/core/src/compiler/Linear.h +++ b/runtime/onert/core/src/compiler/Linear.h @@ -21,7 +21,7 @@ #include <memory> #include "ir/Index.h" -#include "compiler/LoweredGraph.h" +#include "compiler/ILoweredGraph.h" namespace onert { @@ -31,8 +31,8 @@ namespace compiler class Linear { public: - static std::vector<ir::OperationIndex> linearize(const compiler::LoweredGraph &lowered_graph); - static void dump(const compiler::LoweredGraph &lowered_graph, + static std::vector<ir::OperationIndex> linearize(const compiler::ILoweredGraph &lowered_graph); + static void dump(const compiler::ILoweredGraph &lowered_graph, const std::vector<ir::OperationIndex> &order); }; diff --git a/runtime/onert/core/src/compiler/LoweredGraph.cc b/runtime/onert/core/src/compiler/LoweredGraph.cc index d53d0ed00..46a45e44a 100644 --- a/runtime/onert/core/src/compiler/LoweredGraph.cc +++ b/runtime/onert/core/src/compiler/LoweredGraph.cc @@ -49,7 +49,7 @@ void LoweredGraph::lowerGraph(const CompilerOptions &options) // Build backend contexts auto &backend_manager = BackendManager::get(); // Create contexts for other backends - for (auto backend_str : options.backend_list) + for (auto &&backend_str : options.backend_list) { backend_manager.loadBackend(backend_str); auto backend = backend_manager.get(backend_str); @@ -100,9 +100,9 @@ void LoweredGraph::lowerGraph(const CompilerOptions &options) pass::PassRunner{}.append(std::make_unique<pass::PermutationEliminationPass>(*this)).run(); VERBOSE(LoweredGraph) << "Dump after all the passes" << std::endl; - for (auto operand : _graph.getInputs()) + for (auto &&operand : _graph.getInputs()) VERBOSE(LoweredGraph) << "Graph Input : " << operand << std::endl; - for (auto operand : _graph.getOutputs()) + for (auto &&operand : _graph.getOutputs()) VERBOSE(LoweredGraph) << "Graph Output : " << operand << std::endl; dumper::text::dumpLoweredGraph(*this); @@ -121,8 +121,8 @@ void LoweredGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolv }); // Set operand lower info using assigned backends to operations - _graph.operations().iterate([&](const ir::OperationIndex &op_ind, const ir::Operation &) { - const ir::Operation &op = _graph.operations().at(op_ind); + _graph.operations().iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &) { + const ir::IOperation &op = _graph.operations().at(op_ind); auto backend = 
backend_resolver.getBackend(op_ind); if (!backend) { @@ -135,12 +135,12 @@ void LoweredGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolv // TODO Change setting layout of each backend at another place auto backend_layout = backend->config()->supportLayout(op, frontend_layout); - for (auto ind : op.getInputs() | ir::Remove::UNDEFINED) + for (auto &&ind : op.getInputs() | ir::Remove::UNDEFINED) { auto &operand_li = lower_info().operand.at(ind); operand_li.addUsePermuteFactor(PermuteFactor{backend, backend_layout}); } - for (auto ind : op.getOutputs() | ir::Remove::UNDEFINED) + for (auto &&ind : op.getOutputs() | ir::Remove::UNDEFINED) { auto &operand_li = lower_info().operand.at(ind); operand_li.addDefPermuteFactor(PermuteFactor{backend, backend_layout}); @@ -152,13 +152,13 @@ void LoweredGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolv // Handle graph inputs and outputs const auto builtin_backend = BackendManager::get().getBuiltin(); auto factor = PermuteFactor{builtin_backend, _graph.layout()}; - for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED) + for (auto &&index : _graph.getInputs() | ir::Remove::UNDEFINED) { auto &operand_li = lower_info().operand.at(index); assert(operand_li.def_factors().empty()); operand_li.addDefPermuteFactor(factor); } - for (auto index : _graph.getOutputs() | ir::Remove::UNDEFINED) + for (auto &&index : _graph.getOutputs() | ir::Remove::UNDEFINED) { auto &operand_li = lower_info().operand.at(index); operand_li.addUsePermuteFactor(factor); @@ -204,7 +204,7 @@ void LoweredGraph::dumpLowerInfo() auto factors_to_string = [](const PermuteFactorSet &factors) { std::string str; - for (auto factor : factors) + for (auto &&factor : factors) { str += factor.backend()->config()->id(); str += "(" + to_string(factor.layout()) + ")"; @@ -216,7 +216,7 @@ void LoweredGraph::dumpLowerInfo() auto operation_index_set_to_string = [](const ir::OperationIndexSet &operations) { std::stringstream sstream; sstream << "{ "; - for (auto op : operations) + for (auto &&op : operations) sstream << op << " "; sstream << "}"; return sstream.str(); diff --git a/runtime/onert/core/src/compiler/ManualScheduler.cc b/runtime/onert/core/src/compiler/ManualScheduler.cc index 621f0c7b7..ccd08893f 100644 --- a/runtime/onert/core/src/compiler/ManualScheduler.cc +++ b/runtime/onert/core/src/compiler/ManualScheduler.cc @@ -42,7 +42,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap // This fallback will be used in case that `backend_for_all` is unavailable auto fallback = [&]() -> const backend::Backend * { - for (auto backend_id : _options.backend_list) + for (auto &&backend_id : _options.backend_list) { auto backend = resolveBackend(backend_id); if (backend) @@ -58,7 +58,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap VERBOSE(ManualScheduler) << "Default backend for all ops: " << backend_all->config()->id() << std::endl; - graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) { + graph.operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &) { backend_resolver->setBackend(index, backend_all); }); @@ -71,7 +71,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap // By default, Custom uses cpu backend op_type_map[ir::OpCode::Custom] = BackendManager::get().get("cpu"); - graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &operation) { + 
graph.operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &operation) { auto itr = op_type_map.find(operation.opcode()); if (itr != op_type_map.end()) { diff --git a/runtime/onert/core/src/compiler/MultiModelCompiler.cc b/runtime/onert/core/src/compiler/MultiModelCompiler.cc index fea6a7f25..141fdfe09 100644 --- a/runtime/onert/core/src/compiler/MultiModelCompiler.cc +++ b/runtime/onert/core/src/compiler/MultiModelCompiler.cc @@ -16,6 +16,7 @@ #include "MultiModelCompiler.h" +#include "CompilerHelpers.h" #include "ExecutorFactory.h" #include "ShapeValidator.h" #include "pass/ConstantOutputPass.h" @@ -30,6 +31,7 @@ #include "compiler/StaticShapeInferer.h" #include <misc/string_helpers.h> +#include <misc/polymorphic_downcast.h> namespace onert { @@ -53,7 +55,7 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void) /*************************************************** * Prepare compilation phase ***************************************************/ - for (auto options : _voptions) + for (auto &&options : _voptions) { if (!options) throw std::runtime_error{"Empty compile option"}; @@ -63,6 +65,9 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void) if (options->he_profiling_mode) throw std::runtime_error("NYI: Profiling mode for multiple model is not supported yet"); + if (!options->minmax_filepath.empty()) + throw std::runtime_error("Recording minmax is not supported for multiple models"); + options->forceInternalOptions(); options->verboseOptions(); } @@ -74,7 +79,15 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void) for (uint16_t i = 0; i < model_count; i++) { - _nnpkg->model(ir::ModelIndex{i})->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) { + if (!_nnpkg->model(ir::ModelIndex{i})->hasOnly<ir::Graph>()) + throw std::runtime_error("MultiModelCompiler can only compile models for inference."); + } + + for (uint16_t i = 0; i < model_count; i++) + { + _nnpkg->model(ir::ModelIndex{i})->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); + // Mandatory passes pass::PassRunner{} .append(std::make_unique<pass::ConstantOutputPass>(subg)) @@ -100,6 +113,15 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void) // Model edge context: copy model edge context auto model_edges = std::make_unique<ir::ModelEdges>(_nnpkg->model_edges()); + // Custom kernels + std::unordered_map<ir::ModelIndex, std::shared_ptr<backend::custom::IKernelBuilder>> + custom_kernel_builders; + for (uint16_t i = 0; i < model_count; i++) + { + auto const model_index = ir::ModelIndex{i}; + custom_kernel_builders[model_index] = _nnpkg->model(model_index)->getKernelBuilder(); + } + // Lower: Assign backend std::unordered_map<ir::ModelIndex, std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>>> @@ -110,7 +132,9 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void) auto const model_index = ir::ModelIndex{i}; auto model = _nnpkg->model(model_index); - model->iterate([&](const ir::SubgraphIndex &subg_index, ir::Graph &subg) { + model->iterate([&](const ir::SubgraphIndex &subg_index, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); + dot_dumper.dump(subg, nnfw::misc::str("before_lower_model-", i, "-subg-", subg_index.value())); // Lower: Assign backend @@ -146,7 +170,7 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void) // Run the StaticShapeInfer of primary 
subg. All child StaticShapeInferers are called // recursively std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers = - StaticShapeInferer::createStaticShapeInferers(model_lsubgs); + createStaticShapeInferers(model_lsubgs); const auto primary_subg_idx = ir::SubgraphIndex{0}; inferers.at(primary_subg_idx)->infer(); @@ -194,11 +218,15 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void) ir::OperationDumper dumper("Executor generation of Subgraph " + std::to_string(subg_index.value())); lowered_subg->graph().operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); }); - - auto &options = *_voptions[model_index.value()]; - auto executor = std::unique_ptr<exec::IExecutor>{ExecutorFactory::get().create( - std::move(lowered_subg), tracing_ctx.get(), options, executors, model_index)}; + [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); }); + + ExecutorFactoryArgs args; + args.tracing_ctx = tracing_ctx.get(); + args.options = _voptions[model_index.value()]; + args.model_index = model_index; + args.custom_kernel_builder = custom_kernel_builders[model_index]; + auto executor = std::unique_ptr<exec::IExecutor>{ + ExecutorFactory::get().create(std::move(lowered_subg), executors, args)}; executor->setIndexedRanks(indexed_ranks); executors->emplace(model_index, subg_index, std::move(executor)); } diff --git a/runtime/onert/core/src/compiler/MultiModelCompiler.h b/runtime/onert/core/src/compiler/MultiModelCompiler.h index 89af664f8..b282a5087 100644 --- a/runtime/onert/core/src/compiler/MultiModelCompiler.h +++ b/runtime/onert/core/src/compiler/MultiModelCompiler.h @@ -59,12 +59,6 @@ public: std::shared_ptr<CompilerArtifact> compile(void); private: - std::shared_ptr<ir::Graph> &primary_subgraph() - { - return _nnpkg->primary_model()->at(ir::SubgraphIndex{0}); - } - -private: std::shared_ptr<ir::NNPkg> _nnpkg; std::vector<CompilerOptions *> _voptions; }; diff --git a/runtime/onert/core/src/compiler/ShapeValidator.cc b/runtime/onert/core/src/compiler/ShapeValidator.cc index 8c6421744..3e940f037 100644 --- a/runtime/onert/core/src/compiler/ShapeValidator.cc +++ b/runtime/onert/core/src/compiler/ShapeValidator.cc @@ -52,7 +52,7 @@ void ShapeValidator::checkUnaryOp(const ir::Operation &node) void ShapeValidator::operator()() { _graph.operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); }); + [&](const ir::OperationIndex &, const ir::IOperation &node) { node.accept(*this); }); } void ShapeValidator::visit(const ir::operation::BatchMatMul &node) diff --git a/runtime/onert/core/src/compiler/StaticShapeInferer.cc b/runtime/onert/core/src/compiler/StaticShapeInferer.cc index 25747d950..a25b326f1 100644 --- a/runtime/onert/core/src/compiler/StaticShapeInferer.cc +++ b/runtime/onert/core/src/compiler/StaticShapeInferer.cc @@ -99,10 +99,10 @@ void StaticShapeInferer::infer() } } -bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op) +bool StaticShapeInferer::checkDynamicInput(const ir::IOperation &op) { const auto &operands = _lowered_subg->graph().operands(); - for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) + for (auto &&input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) { if (operands.at(input_idx).info().isDynamic()) { @@ -113,10 +113,10 @@ bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op) return false; } -bool 
StaticShapeInferer::checkDynamicOutput(const ir::Operation &op) +bool StaticShapeInferer::checkDynamicOutput(const ir::IOperation &op) { auto &operands = _lowered_subg->graph().operands(); - for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED) + for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED) { if (operands.at(output_idx).info().isDynamic()) { @@ -126,10 +126,10 @@ bool StaticShapeInferer::checkDynamicOutput(const ir::Operation &op) return false; } -void StaticShapeInferer::setDynamicOutput(const ir::Operation &op) +void StaticShapeInferer::setDynamicOutput(const ir::IOperation &op) { auto &operands = _lowered_subg->graph().operands(); - for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED) + for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED) { operands.at(output_idx).info().setDynamic(); } @@ -192,7 +192,7 @@ void StaticShapeInferer::dump() std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> StaticShapeInferer::createStaticShapeInferers( - const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs) + const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs) { // Allocate StaticShapeInferer per each subgraph std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers; @@ -200,7 +200,7 @@ StaticShapeInferer::createStaticShapeInferers( { const auto &subg_index = pair.first; auto &lowered_subg = pair.second; - inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg.get()); + inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg); } // Append observers in all StaticShapeInferers @@ -211,7 +211,7 @@ StaticShapeInferer::createStaticShapeInferers( // TODO: Change this iteration for all to controlflow iteration lowered_subg->graph().operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &op) { + [&](const ir::OperationIndex &, const ir::IOperation &op) { // A Function to append child inferers. 
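Note: the hunks above retarget `StaticShapeInferer::createStaticShapeInferers` from a map of owning `std::unique_ptr<LoweredGraph>` to a map of non-owning `ILoweredGraph *`, while `MultiModelCompiler` now calls a free `createStaticShapeInferers` helper, presumably brought in by the new `CompilerHelpers.h` include. A plausible shape for that helper - a sketch only, not verified against the actual header - is a small type-erasing adapter:

    // Hypothetical sketch: flatten an owning map of any concrete lowered-graph
    // type (LoweredGraph, LoweredTrainableGraph) to ILoweredGraph pointers,
    // then delegate to the static factory whose new signature is shown above.
    template <typename LoweredGraphType>
    std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
    createStaticShapeInferers(
      const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraphType>> &lowered_subgs)
    {
      std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> pointer_map;
      for (auto &&pair : lowered_subgs)
        pointer_map[pair.first] = pair.second.get();
      return StaticShapeInferer::createStaticShapeInferers(pointer_map);
    }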
These make it possible for a StaticShapeInferer to // call StaticShapeInferers of child subgraphs recursively auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) { @@ -251,7 +251,9 @@ StaticShapeInferer::createStaticShapeInferers( // Append Observers in a StaticShapeInferer if (op.opcode() == ir::OpCode::If) { - const auto &if_op = nnfw::misc::polymorphic_downcast<const ir::operation::If &>(op); + // TODO Remove dynamic_cast + // A virtual base class cannot be downcast with static_cast + const auto &if_op = dynamic_cast<const ir::operation::If &>(op); appendChildInferer(if_op.param().then_subg_index); appendChildInferer(if_op.param().else_subg_index); @@ -263,7 +265,8 @@ StaticShapeInferer::createStaticShapeInferers( } else if (op.opcode() == ir::OpCode::While) { - const auto &while_op = nnfw::misc::polymorphic_downcast<const ir::operation::While &>(op); + // TODO Remove dynamic_cast + const auto &while_op = dynamic_cast<const ir::operation::While &>(op); appendChildInferer(while_op.param().cond_subg_index); appendChildInferer(while_op.param().body_subg_index); @@ -602,6 +605,13 @@ void StaticShapeInferer::visit(const ir::operation::L2Normalization &op) handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT)); } +void StaticShapeInferer::visit(const ir::operation::Loss &) +{ + // TODO Consider SparseCategoricalCrossentropy case + + // TODO Consider output shape in case of reduction option +} + void StaticShapeInferer::visit(const ir::operation::LSTM &op) { auto &operands = _lowered_subg->graph().operands(); @@ -1119,7 +1129,7 @@ void StaticShapeInferer::visit(const ir::operation::Split &op) auto outputs = op.getOutputs(); if (!axis.isConstant()) { - for (auto output_idx : outputs) + for (auto &&output_idx : outputs) { ir::Operand &output = operands.at(output_idx); output.info().setDynamic(); @@ -1137,7 +1147,7 @@ ir::Shape new_shape = shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits); - for (auto output_idx : outputs) + for (auto &&output_idx : outputs) { ir::Operand &output = operands.at(output_idx); output.info().shape(new_shape); diff --git a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc index 89dd303d4..a6590b13f 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc +++ b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc @@ -28,14 +28,14 @@ namespace compiler namespace pass { -void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::Operation &node) +void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::IOperation &node) { const auto op_lower_info = _lowered_graph.lower_info().operation.getRawPtr(node_index); const auto backend = op_lower_info->backend(); const auto layout = op_lower_info->layout(); const auto factor = PermuteFactor{backend, layout}; - for (const auto input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { auto &object = _graph.operands().at(input); diff --git a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h index 4911ace2f..d5b9aa14e 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h +++ b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h @@ -39,7
+39,7 @@ public: std::string id() final { return "ConstantInsertionPass"; } public: - void callback(const ir::OperationIndex &index, ir::Operation &node) final; + void callback(const ir::OperationIndex &index, ir::IOperation &node) final; private: struct ReplaceKey diff --git a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc index 6ed154548..32e32d0ef 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc +++ b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc @@ -29,7 +29,7 @@ namespace compiler namespace pass { -void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::Operation &node) +void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::IOperation &node) { const auto op_lower_info = _lowered_graph.lower_info().operation.getRawPtr(node_index); const auto backend = op_lower_info->backend(); @@ -37,7 +37,7 @@ void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::Op const auto factor = PermuteFactor{backend, layout}; // Now this runtime does not support the node making output of operation as constant - for (const auto input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { auto &object = _graph.operands().at(input); if (object.isConstant()) diff --git a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h index e17d776d1..d60a1033f 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h +++ b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h @@ -36,7 +36,7 @@ public: std::string id() final { return "ConstantLoweringPass"; } public: - void callback(const ir::OperationIndex &index, ir::Operation &node) final; + void callback(const ir::OperationIndex &index, ir::IOperation &node) final; }; } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc index c176f6ffb..1448de473 100644 --- a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc +++ b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc @@ -49,7 +49,7 @@ void ConstantOutputPass::callback(const ir::OperandIndex &ind, ir::Operand &obj) // Make the operations that uses this operand to use the generated operand auto orig_uses = obj.getUses(); - for (auto use : orig_uses) + for (auto &&use : orig_uses) { permute_input_obj.insertUse(use); obj.removeUse(use); diff --git a/runtime/onert/core/src/compiler/pass/IPass.h b/runtime/onert/core/src/compiler/pass/IPass.h new file mode 100644 index 000000000..77f5916fd --- /dev/null +++ b/runtime/onert/core/src/compiler/pass/IPass.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_PASS_IPASS_H__ +#define __ONERT_COMPILER_PASS_IPASS_H__ + +#include <string> + +namespace onert +{ +namespace compiler +{ +namespace pass +{ + +struct IPass +{ + virtual ~IPass() = default; + + virtual std::string id() = 0; + virtual void run() = 0; +}; + +} // namespace pass +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_PASS_IPASS_H__ diff --git a/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h b/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h index 1f1f32f6d..64831a0ac 100644 --- a/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h +++ b/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h @@ -18,7 +18,7 @@ #define __ONERT_IR_PASS_LOWERED_OPERAND_PASS_H__ #include "OperandPass.h" -#include "compiler/LoweredGraph.h" +#include "compiler/ILoweredGraph.h" namespace onert { @@ -30,7 +30,7 @@ namespace pass class LoweredOperandPass : public OperandPass { public: - LoweredOperandPass(compiler::LoweredGraph &lowered_graph) + LoweredOperandPass(compiler::ILoweredGraph &lowered_graph) : OperandPass{lowered_graph.graph()}, _lowered_graph{lowered_graph} { // DO NOTHING @@ -42,7 +42,7 @@ public: void callback(const ir::OperandIndex &i, ir::Operand &o) override = 0; protected: - compiler::LoweredGraph &_lowered_graph; + compiler::ILoweredGraph &_lowered_graph; }; } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h b/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h index 76ee3d7ff..27ca77c91 100644 --- a/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h +++ b/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h @@ -18,7 +18,7 @@ #define __ONERT_IR_PASS_LOWERED_OPERATION_PASS_H__ #include "OperationPass.h" -#include "compiler/LoweredGraph.h" +#include "compiler/ILoweredGraph.h" namespace onert { @@ -30,7 +30,7 @@ namespace pass class LoweredOperationPass : public OperationPass { public: - LoweredOperationPass(LoweredGraph &lowered_graph) + LoweredOperationPass(ILoweredGraph &lowered_graph) : OperationPass{lowered_graph.graph()}, _lowered_graph{lowered_graph} { // DO NOTHING @@ -39,10 +39,10 @@ public: virtual ~LoweredOperationPass() = default; std::string id() override = 0; - void callback(const ir::OperationIndex &i, ir::Operation &o) override = 0; + void callback(const ir::OperationIndex &i, ir::IOperation &o) override = 0; protected: - LoweredGraph &_lowered_graph; + ILoweredGraph &_lowered_graph; }; } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/OperationPass.cc b/runtime/onert/core/src/compiler/pass/OperationPass.cc index 357a8798a..bd9bcb4a4 100644 --- a/runtime/onert/core/src/compiler/pass/OperationPass.cc +++ b/runtime/onert/core/src/compiler/pass/OperationPass.cc @@ -17,7 +17,7 @@ #include "OperationPass.h" #include "ir/Index.h" -#include "ir/Operation.h" +#include "ir/IOperation.h" #include "ir/Graph.h" namespace onert @@ -30,7 +30,7 @@ namespace pass void OperationPass::run() { _graph.operations().iterate( - [&](const ir::OperationIndex &index, ir::Operation &node) { callback(index, node); }); + [&](const ir::OperationIndex &index, ir::IOperation &node) { callback(index, node); }); } } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/OperationPass.h b/runtime/onert/core/src/compiler/pass/OperationPass.h index ac4d818a2..0a00b11d1 100644 --- a/runtime/onert/core/src/compiler/pass/OperationPass.h +++ b/runtime/onert/core/src/compiler/pass/OperationPass.h @@ -29,7 +29,7 @@ namespace onert { 
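Note: the new `IPass` interface above is what lets `PassRunner` (changed below) hold graph passes and lowered-graph passes in a single pipeline. A self-contained sketch of the pattern, with a hypothetical demo pass; the real `PassRunner` additionally logs each pass through `VERBOSE`:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Mirrors the IPass interface introduced in the diff.
    struct IPass
    {
      virtual ~IPass() = default;
      virtual std::string id() = 0;
      virtual void run() = 0;
    };

    // Hypothetical pass used only to demonstrate the pipeline.
    class HelloPass : public IPass
    {
    public:
      std::string id() final { return "HelloPass"; }
      void run() final { std::cout << "running " << id() << std::endl; }
    };

    // Minimal runner: owns type-erased passes and runs them in append order.
    class PassRunner
    {
    public:
      PassRunner &append(std::unique_ptr<IPass> pass)
      {
        _passes.emplace_back(std::move(pass));
        return *this;
      }
      void run()
      {
        for (auto &&pass : _passes)
          pass->run();
      }

    private:
      std::vector<std::unique_ptr<IPass>> _passes;
    };

    int main() { PassRunner{}.append(std::make_unique<HelloPass>()).run(); }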
namespace ir { -class Operation; +struct IOperation; } // namespace ir } // namespace onert @@ -62,7 +62,7 @@ public: * @param index is the index of a node in graph * @param node is the node in graph */ - virtual void callback(const ir::OperationIndex &index, ir::Operation &node) = 0; + virtual void callback(const ir::OperationIndex &index, ir::IOperation &node) = 0; /** * @brief Run the pass diff --git a/runtime/onert/core/src/compiler/pass/Pass.h b/runtime/onert/core/src/compiler/pass/Pass.h index 3016df490..b34695c97 100644 --- a/runtime/onert/core/src/compiler/pass/Pass.h +++ b/runtime/onert/core/src/compiler/pass/Pass.h @@ -17,6 +17,8 @@ #ifndef __ONERT_COMPILER_PASS_PASS_H__ #define __ONERT_COMPILER_PASS_PASS_H__ +#include "IPass.h" + #include <string> namespace onert @@ -34,7 +36,7 @@ namespace compiler namespace pass { -class Pass +class Pass : public IPass { public: Pass(ir::Graph &graph) : _graph{graph} {} diff --git a/runtime/onert/core/src/compiler/pass/PassRunner.cc b/runtime/onert/core/src/compiler/pass/PassRunner.cc index 2d11be201..cd1b82bb2 100644 --- a/runtime/onert/core/src/compiler/pass/PassRunner.cc +++ b/runtime/onert/core/src/compiler/pass/PassRunner.cc @@ -23,7 +23,7 @@ namespace compiler namespace pass { -PassRunner &PassRunner::append(std::unique_ptr<Pass> pass) +PassRunner &PassRunner::append(std::unique_ptr<IPass> pass) { _passes.emplace_back(std::move(pass)); return *this; diff --git a/runtime/onert/core/src/compiler/pass/PassRunner.h b/runtime/onert/core/src/compiler/pass/PassRunner.h index a43c83f89..03bfbe220 100644 --- a/runtime/onert/core/src/compiler/pass/PassRunner.h +++ b/runtime/onert/core/src/compiler/pass/PassRunner.h @@ -21,7 +21,7 @@ #include <memory> #include <vector> -#include "Pass.h" +#include "IPass.h" #include "util/logging.h" namespace onert @@ -38,12 +38,12 @@ class PassRunner { public: PassRunner() = default; - PassRunner &append(std::unique_ptr<Pass> pass); + PassRunner &append(std::unique_ptr<IPass> pass); void run(); private: - std::vector<std::unique_ptr<Pass>> _passes; + std::vector<std::unique_ptr<IPass>> _passes; }; } // namespace pass diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc index c27ce3d09..d9452c7f9 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc +++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc @@ -16,6 +16,7 @@ #include "PermutationEliminationPass.h" +#include "backend/Backend.h" #include "util/logging.h" namespace onert @@ -25,7 +26,7 @@ namespace compiler namespace pass { -void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::Operation &node) +void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::IOperation &node) { _op_ind = ind; node.accept(*this); @@ -73,7 +74,7 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node) auto &out_operand_obj = _graph.operands().at(out_operand); assert(out_operand_obj.getDef() == _op_ind); out_operand_obj.unsetDef(); - _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::Operation &op) { + _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) { if (!op.getOutputs().contains(in_operand)) return; // Update Operation and Operand edges @@ -87,7 +88,7 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node) _graph.operations().remove(_op_ind); } - _graph.operations().iterate([&](const ir::OperationIndex 
&op_ind, ir::Operation &op) { + _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) { if (!op.getInputs().contains(in_operand)) return; op.replaceInputs(in_operand, out_operand); @@ -106,7 +107,7 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node) in_operand_obj.removeUse(_op_ind); // Make operations(that use the output) use the input - _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::Operation &op) { + _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) { if (!op.getInputs().contains(out_operand)) return; op.replaceInputs(out_operand, in_operand); diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h index 50c38c53f..18ba99804 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h +++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h @@ -49,7 +49,7 @@ public: std::string id() final { return "PermutationEliminationPass"; } public: - void callback(const ir::OperationIndex &i, ir::Operation &n) final; + void callback(const ir::OperationIndex &i, ir::IOperation &n) final; private: void visit(const ir::operation::Permute &) final; diff --git a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc index 0da1e54df..39eb803f5 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc +++ b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc @@ -54,13 +54,13 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera std::unordered_map<PermuteFactor, ir::OperandIndex> factor_to_index; { assert(operand_li->def_factors().size() == 1); - for (auto factor : operand_li->def_factors()) + for (auto &&factor : operand_li->def_factors()) { factor_to_index.emplace(factor, index); } auto insert_set = operand_li->use_factors() - operand_li->def_factors(); - for (auto factor : insert_set) + for (auto &&factor : insert_set) { const auto permute_operation_index = insertPermute(index, factor); permute_indexes.push_back(permute_operation_index); @@ -75,7 +75,7 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera std::list<ir::OperationIndex> remove_list; auto uses = object.getUses(); - for (auto use : uses) + for (auto &&use : uses) { // If permute operation, ignore it if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end()) diff --git a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc index f83b1ba31..f014d29d3 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc +++ b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc @@ -30,7 +30,7 @@ namespace pass using namespace ir; -void PermutationOperationPass::callback(const OperationIndex &, Operation &node) +void PermutationOperationPass::callback(const OperationIndex &, IOperation &node) { node.accept(*this); } diff --git a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h index cea5de288..e253a77ad 100644 --- a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h +++ b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h @@ -36,7 +36,7 @@ public: std::string id() final { return "PermutationOperationPass"; 
} public: - void callback(const ir::OperationIndex &i, ir::Operation &n) final; + void callback(const ir::OperationIndex &i, ir::IOperation &n) final; public: void visit(const ir::operation::BinaryArithmetic &) final; diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc index 35fb575b0..162c4e7ef 100644 --- a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc +++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc @@ -37,15 +37,15 @@ void UnusedOperandEliminationPass::run() { util::Set<ir::OperandIndex> used; - _graph.operations().iterate([&](const ir::OperationIndex &, const ir::Operation &node) { - for (auto ind : (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED) + _graph.operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &node) { + for (auto &&ind : (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED) { used.add(ind); } }); // Graph's inputs/outputs are always considered as used - for (auto ind : (_graph.getInputs() + _graph.getOutputs()) | ir::Remove::UNDEFINED) + for (auto &&ind : (_graph.getInputs() + _graph.getOutputs()) | ir::Remove::UNDEFINED) { used.add(ind); } diff --git a/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc b/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc new file mode 100644 index 000000000..490c648cd --- /dev/null +++ b/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc @@ -0,0 +1,285 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "compiler/train/LoweredTrainableGraph.h" + +#include "../ManualScheduler.h" +#include "../pass/ConstantInsertionPass.h" +#include "../pass/ConstantLoweringPass.h" +#include "../pass/PassRunner.h" +#include "../pass/PermutationEliminationPass.h" +#include "../pass/PermutationInsertionPass.h" +#include "../pass/PermutationOperationPass.h" +#include "../../backend/builtin/Config.h" +#include "../../dumper/text/GraphDumper.h" +#include "../../ir/verifier/Verifier.h" +#include "TrainableOperationConverter.h" + +#include "backend/Backend.h" +#include "backend/train/ITrainableBackend.h" +#include "compiler/BackendResolver.h" +#include "util/logging.h" + +#include <cassert> +#include <sstream> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +LoweredTrainableGraph::LoweredTrainableGraph(ir::train::TrainableGraph &graph, + const CompilerOptions &options) + : _trainable_graph{graph} +{ + lowerGraph(options); +} + +void LoweredTrainableGraph::lowerGraph(const CompilerOptions &options) +{ + // Build backend contexts + auto &backend_manager = BackendManager::get(); + // Create contexts for other backends + for (auto &&backend_str : options.backend_list) + { + backend_manager.loadBackend(backend_str); + auto backend = backend_manager.get(backend_str); + + // TODO The default backend list contains "cpu", "acl_cl" and "acl_neon", but some of these + // are not available on x64 or some other platforms. So this is a workaround for such + // platforms, and we should change it back (throw if a backend is not loaded) later. + if (!backend) + { + VERBOSE(LoweredTrainableGraph) << "Cannot load backend - " << backend_str << std::endl; + continue; + } + } + if (backend_manager.num_backends() == 0) + throw std::runtime_error{"No available backends loaded."}; + + // TODO Move "schedule" phase out of here + // TODO Scheduling + std::unique_ptr<BackendResolver> backend_resolver; + auto all_backends = backend_manager.getAll(); + + auto scheduler = ManualScheduler(all_backends, options); + backend_resolver = scheduler.schedule(_trainable_graph.graph()); + + // Check if backends are trainable + _trainable_graph.operations().iterate( + [&](const ir::OperationIndex &op_ind, const ir::IOperation &) { + const auto backend = backend_resolver->getBackend(op_ind); + + // TODO Remove dynamic_cast + if (dynamic_cast<const backend::train::ITrainableBackend *>(backend) == nullptr) + { + throw std::runtime_error(backend->config()->id() + " backend does not support training"); + } + }); + + makeLowerInfo(*backend_resolver); + VERBOSE(LoweredTrainableGraph) << "dump before mandatory passes" << std::endl; + dumper::text::dumpLoweredGraph(*this); + + // Mandatory passes - kind of legalization(?)
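Note: the trainability check above relies on `dynamic_cast` because `Backend` and `ITrainableBackend` are unrelated base classes, and only `dynamic_cast` can perform such a runtime cross-cast; that is also why the `TODO Remove dynamic_cast` is not a mechanical cleanup. A self-contained illustration with stand-in types (names hypothetical):

    #include <stdexcept>
    #include <string>

    // Stand-ins for backend::Backend and backend::train::ITrainableBackend.
    struct Backend { virtual ~Backend() = default; };
    struct ITrainableBackend { virtual ~ITrainableBackend() = default; };

    // A backend that supports training derives from both interfaces.
    struct CpuBackend : Backend, ITrainableBackend {};

    void checkTrainable(const Backend *backend, const std::string &id)
    {
      // Cross-cast between unrelated bases: non-null only if the dynamic type
      // of *backend also implements ITrainableBackend.
      if (dynamic_cast<const ITrainableBackend *>(backend) == nullptr)
        throw std::runtime_error(id + " backend does not support training");
    }

    int main()
    {
      CpuBackend cpu;
      checkTrainable(&cpu, "cpu"); // passes: CpuBackend also derives ITrainableBackend
    }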
+ compiler::pass::PassRunner{} + .append(std::make_unique<compiler::pass::ConstantInsertionPass>(*this)) + .append(std::make_unique<compiler::pass::ConstantLoweringPass>(*this)) + .append(std::make_unique<compiler::pass::PermutationOperationPass>(*this)) + .append(std::make_unique<compiler::pass::PermutationInsertionPass>(*this)) + .run(); + + // TODO Move converting Permute op into PermutationInsertionPass + auto op_converter = TrainableOperationConverter{_trainable_graph, nullptr}; + _trainable_graph.operations().iterate( + [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) { + if (op.opcode() == ir::OpCode::Permute) + { + auto trainable_op = op_converter(op); + auto gen_index = _trainable_graph.replaceOperation(index, std::move(trainable_op)); + UNUSED_RELEASE(gen_index); + assert(gen_index == index); + } + }); + + dumpLowerInfo(); + + // Optimization passes (optional) + compiler::pass::PassRunner{} + .append(std::make_unique<compiler::pass::PermutationEliminationPass>(*this)) + .run(); + + // TODO Update LowerInfo for training + + VERBOSE(LoweredTrainableGraph) << "Dump after all the passes" << std::endl; + for (auto &&operand : _trainable_graph.getInputs()) + VERBOSE(LoweredTrainableGraph) << "Graph Input : " << operand << std::endl; + for (auto &&operand : _trainable_graph.getOutputs()) + VERBOSE(LoweredTrainableGraph) << "Graph Output : " << operand << std::endl; + dumper::text::dumpLoweredGraph(*this); + + // Graph verifications + { + assert(ir::verifier::InputOutputChecker().verify(_trainable_graph.graph())); + assert(ir::verifier::DAGChecker().verify(_trainable_graph.graph())); + assert(ir::verifier::EdgeChecker().verify(_trainable_graph.graph())); + } +} + +void LoweredTrainableGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolver) +{ + _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) { + lower_info().operand.set(index, std::make_unique<OperandLowerInfo>()); + }); + + // Set operand lower info using assigned backends to operations + _trainable_graph.operations().iterate( + [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) { + auto backend = backend_resolver.getBackend(op_ind); + if (!backend) + { + throw std::runtime_error{"Fail to find backend for " + op.name() + " operation"}; + } + + auto frontend_layout = _trainable_graph.layout(); + + // The layout of each backend should be set at another place + // TODO Change setting layout of each backend at another place + auto backend_layout = backend->config()->supportLayout(op, frontend_layout); + + for (auto &&ind : op.getInputs() | ir::Remove::UNDEFINED) + { + auto &operand_li = lower_info().operand.at(ind); + operand_li.addUsePermuteFactor(PermuteFactor{backend, backend_layout}); + } + for (auto &&ind : op.getOutputs() | ir::Remove::UNDEFINED) + { + auto &operand_li = lower_info().operand.at(ind); + operand_li.addDefPermuteFactor(PermuteFactor{backend, backend_layout}); + } + lower_info().operation.set( + op_ind, std::make_unique<compiler::OperationLowerInfo>(backend, backend_layout)); + }); + + // Handle graph inputs and outputs + const auto builtin_backend = BackendManager::get().getBuiltin(); + auto factor = PermuteFactor{builtin_backend, _trainable_graph.layout()}; + for (auto &&index : _trainable_graph.getInputs() | ir::Remove::UNDEFINED) + { + auto &operand_li = lower_info().operand.at(index); + assert(operand_li.def_factors().empty()); + operand_li.addDefPermuteFactor(factor); + } + for (auto &&index : 
_trainable_graph.getOutputs() | ir::Remove::UNDEFINED) + { + auto &operand_li = lower_info().operand.at(index); + operand_li.addUsePermuteFactor(factor); + } + + // Handle variable tensors + _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &operand) { + // Some inputs of an operation may be non-constant yet appear neither in the graph + // inputs/outputs nor as undefined operands - these are variable tensors. For example, + // UnidirectionalSequenceLSTM has such inputs. + if (operand.info().isVariable()) + { + // The variable operand with buffer is not supported yet + assert(operand.data() == nullptr); + assert(operand.getUses().size() == 1 && !operand.getDef().valid()); + // NOTE Bind by reference so the def factor is recorded on the stored lower info + auto &operand_li = lower_info().operand.at(index); + assert(operand_li.def_factors().empty()); + operand_li.addDefPermuteFactor(operand_li.use_factors().getOnlyElement()); + } + }); +} + +void LoweredTrainableGraph::dumpLowerInfo() +{ + if (::onert::util::logging::ctx.enabled() == false) + return; + + std::map<uint32_t, std::string> dumps; + + _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &object) { + const auto operand_lower_info = lower_info().operand.getRawPtr(index); + assert(operand_lower_info); + if (!operand_lower_info->def_factors().empty() || !operand_lower_info->use_factors().empty()) + { + auto shape_to_string = [](const ir::Shape &shape) { + std::stringstream sstream; + sstream << "{ "; + for (auto i = 0; i < shape.rank(); ++i) + sstream << (shape.dim(i)) << " "; + sstream << "}"; + return sstream.str(); + }; + + auto factors_to_string = [](const PermuteFactorSet &factors) { + std::string str; + for (auto &&factor : factors) + { + str += factor.backend()->config()->id(); + str += "(" + to_string(factor.layout()) + ")"; + str += " "; + } + return "{ " + str + "}"; + }; + + auto operation_index_set_to_string = [](const ir::OperationIndexSet &operations) { + std::stringstream sstream; + sstream << "{ "; + for (auto &&op : operations) + sstream << op << " "; + sstream << "}"; + return sstream.str(); + }; + + auto data_to_str = [](const ir::Data *data) { + return (data ?
(std::to_string(data->size()) + " bytes") : "N/A"); + }; + + std::string shape_str = shape_to_string(object.shape()); + std::string def_op = operation_index_set_to_string({object.getDef()}); + std::string use_ops = operation_index_set_to_string(object.getUses()); + std::string def_factors = factors_to_string(operand_lower_info->def_factors()); + std::string use_factors = factors_to_string(operand_lower_info->use_factors()); + std::stringstream sstream; + sstream << "Operand " << index << " Info" << std::endl; + sstream << " - Shape : " << shape_str << std::endl; + sstream << " - Def/Uses : Def " << def_op << " Uses " << use_ops << std::endl; + sstream << " - Data : " << data_to_str(object.data()) << std::endl; + sstream << " - LowerInfo : Def " << def_factors << " Uses " << use_factors << std::endl; + dumps.emplace(index.value(), sstream.str()); + } + }); + + for (const auto &e : dumps) + { + if (!e.second.empty()) + { + std::istringstream iss(e.second); + std::string line; + while (std::getline(iss, line)) + VERBOSE(Lower) << line << std::endl; + } + } +} + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.cc b/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.cc new file mode 100644 index 000000000..d2153296f --- /dev/null +++ b/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.cc @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "StaticDerivativeShapeInferer.h" +#include "util/ShapeInference.h" +#include "util/logging.h" + +#include <misc/polymorphic_downcast.h> + +#include <sstream> +#include <stdexcept> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +void StaticDerivativeShapeInferer::infer() +{ + // NOTE Iterates in reverse topological order; whether this ordering is strictly required has not been settled yet.
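Note: `StaticDerivativeShapeInferer::infer` (continued below) walks the topological sort backwards because the derivative shape at an operation's input is derived from the derivative at its output, mirroring backpropagation; the `Permute` visitor later in this file is the one case already implemented. A minimal, self-contained illustration of the traversal direction (op names hypothetical):

    #include <cstdio>
    #include <vector>

    int main()
    {
      // Forward (topological) order of a toy graph.
      std::vector<const char *> sorted_ops{"Conv2D", "ReLU", "Softmax", "Loss"};

      // Forward shapes flow first-to-last; derivative shapes flow last-to-first,
      // so the inferer visits the reversed topological order.
      for (auto it = sorted_ops.rbegin(); it != sorted_ops.rend(); ++it)
        std::printf("infer derivative shape for %s\n", *it);
    }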
+ auto sorted_ops = _lowered_subg->graph().topolSortOperations(); + for (auto it = sorted_ops.rbegin(); it != sorted_ops.rend(); ++it) + { + const auto op_idx = *it; + const auto &op = _lowered_subg->trainable_graph().operation(op_idx); + if (checkDynamicInput(op)) + { + std::stringstream msg; + msg << "StaticDerivativeShapeInferer does not support dynamic shape yet, "; + msg << op.name() << "(op index: " << op_idx << ") has dynamic shape."; + throw std::runtime_error(msg.str()); + } + + checkOutput(op); + + op.accept(*this); + } +} + +void StaticDerivativeShapeInferer::dump() +{ + // TODO dump +} + +bool StaticDerivativeShapeInferer::checkDynamicInput(const ir::IOperation &op) +{ + const auto &operands = _lowered_subg->graph().operands(); + for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) + { + if (operands.at(input_idx).info().isDynamic()) + { + return true; + } + } + + return false; +} + +void StaticDerivativeShapeInferer::checkOutput(const ir::IOperation &op) +{ + const auto &derivatives = _lowered_subg->trainable_graph().derivatives(); + for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) + { + if (!derivatives.exist(output_idx)) + { + std::stringstream msg; + msg << "StaticDerivativeShapeInferer : Invalid output, "; + msg << op.name() << "'s derivative output(index: " << output_idx << ") does not exist."; + throw std::runtime_error(msg.str()); + } + } +} + +void StaticDerivativeShapeInferer::setShape(const ir::OperandIndex &index, const ir::Shape &shape) +{ + auto &tgraph = _lowered_subg->trainable_graph(); + + if (tgraph.derivatives().exist(index)) + tgraph.changeDerivativeShape(index, shape); + else + { + // NOTE This code assumes the types are always the same, but I'm not sure. + const auto &type = tgraph.operands().at(index).typeInfo(); + const auto new_index = tgraph.addDerivative(index, std::make_unique<ir::Operand>(shape, type)); + assert(new_index == index); + UNUSED_RELEASE(new_index); + } +} + +void StaticDerivativeShapeInferer::visit(const ir::train::operation::Conv2D &) +{ + // NYI +} + +void StaticDerivativeShapeInferer::visit(const ir::train::operation::ElementwiseActivation &) +{ + // NYI +} + +void StaticDerivativeShapeInferer::visit(const ir::train::operation::Loss &) +{ + // NYI +} + +void StaticDerivativeShapeInferer::visit(const ir::train::operation::Permute &op) +{ + const auto &derivatives = _lowered_subg->trainable_graph().derivatives(); + + const auto &output_idx = op.getOutputs().at(0); + const auto &output = derivatives.at(output_idx); + + // re-sizing input derivative shape + const auto &input_idx = op.getInputs().at(0); + const auto &new_shape = output.info().shape(); + setShape(input_idx, new_shape); +} + +void StaticDerivativeShapeInferer::visit(const ir::train::operation::Pool2D &) +{ + // NYI +} + +void StaticDerivativeShapeInferer::visit(const ir::train::operation::Reshape &) +{ + // NYI +} + +void StaticDerivativeShapeInferer::visit(const ir::train::operation::Softmax &) +{ + // NYI +} + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.h b/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.h new file mode 100644 index 000000000..48b3172d2 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/StaticDerivativeShapeInferer.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__ +#define __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__ + +#include "ir/train/TrainableOperationVisitor.h" + +#include "compiler/train/LoweredTrainableGraph.h" +#include "ir/Index.h" + +#include <memory> +#include <unordered_map> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +/** + * @brief Class to infer shape before running kernels. It does the following: + * - re-calculate and set output shape at compile time (before running kernels) + * - if calculation cannot be done at compile time, mark the outputs to be dynamic, meaning + * shapes of outputs will be calculated during running kernels + */ +class StaticDerivativeShapeInferer : public ir::train::TrainableOperationVisitor +{ +public: + StaticDerivativeShapeInferer(compiler::train::LoweredTrainableGraph *lowered_subg) + : _lowered_subg{lowered_subg} + { + } + + /** + * @brief Infer shape of operands belonging to ops and set the output shape. + * If output shape cannot be known without running op, mark it so that it can be allocated + * when running kernel. + */ + void infer(void); + + void dump(); + +private: + bool checkDynamicInput(const ir::IOperation &op); + void checkOutput(const ir::IOperation &op); + void setShape(const ir::OperandIndex &index, const ir::Shape &shape); + +private: + void visit(const ir::train::operation::Conv2D &op) override; + void visit(const ir::train::operation::ElementwiseActivation &op) override; + void visit(const ir::train::operation::Loss &op) override; + void visit(const ir::train::operation::Permute &op) override; + void visit(const ir::train::operation::Pool2D &op) override; + void visit(const ir::train::operation::Reshape &op) override; + void visit(const ir::train::operation::Softmax &op) override; + +private: + compiler::train::LoweredTrainableGraph *_lowered_subg; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__ diff --git a/runtime/onert/core/src/compiler/train/TensorRegistries.h b/runtime/onert/core/src/compiler/train/TensorRegistries.h new file mode 100644 index 000000000..48eaf10a1 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TensorRegistries.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__ +#define __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__ + +#include "../../backend/builtin/Config.h" +#include "../../backend/builtin/train/TensorRegistry.h" + +#include <backend/train/TrainableBackendContext.h> + +#include <memory> +#include <unordered_set> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +class TensorRegistries +{ +public: + TensorRegistries() = default; + + TensorRegistries(const backend::train::TrainableBackendContexts &backend_contexts, + bool include_builtin) + { + for (const auto &e : backend_contexts) + { + auto tensor_reg = e.second->tensor_registry(); + if (e.first->config()->id() == backend::builtin::Config::ID) + { + _builtin_tensor_reg = + std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg); + if (include_builtin) + _tensor_regs.insert(tensor_reg); + } + else + { + _tensor_regs.insert(tensor_reg); + } + } + } + + std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator begin() const + { + return _tensor_regs.cbegin(); + } + std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator end() const + { + return _tensor_regs.cend(); + } + + std::shared_ptr<backend::builtin::train::TensorRegistry> getBuiltinTensorRegistry() const + { + return _builtin_tensor_reg; + } + + backend::ITensor *getITensor(ir::OperandIndex index) const + { + for (auto &&tensor_reg : _tensor_regs) + { + auto tensor = tensor_reg->getITensor(index); + if (tensor) + return tensor; + } + return nullptr; + } + + backend::ITensor *getDerivativeITensor(ir::OperandIndex index) const + { + for (auto &&tensor_reg : _tensor_regs) + { + auto tensor = tensor_reg->getDerivativeITensor(index); + if (tensor) + return tensor; + } + return nullptr; + } + +private: + std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>> _tensor_regs; + std::shared_ptr<backend::builtin::train::TensorRegistry> _builtin_tensor_reg; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__ diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc new file mode 100644 index 000000000..d20ae9fd3 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
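Note: `TensorRegistries` above singles out the builtin backend's registry with `std::dynamic_pointer_cast`, which returns a non-null pointer sharing ownership only when the pointee's dynamic type matches. A self-contained sketch with stand-in registry types (names hypothetical):

    #include <memory>
    #include <unordered_set>

    // Stand-ins for ITensorRegistry and the builtin TensorRegistry.
    struct ITensorRegistry { virtual ~ITensorRegistry() = default; };
    struct BuiltinTensorRegistry : ITensorRegistry {};

    int main()
    {
      std::unordered_set<std::shared_ptr<ITensorRegistry>> regs;
      regs.insert(std::make_shared<BuiltinTensorRegistry>());

      std::shared_ptr<BuiltinTensorRegistry> found;
      for (auto &&reg : regs)
      {
        // Non-null only for the registry that really is the builtin one;
        // the result shares ownership with the original shared_ptr.
        if (auto builtin = std::dynamic_pointer_cast<BuiltinTensorRegistry>(reg))
          found = builtin;
      }
      return found ? 0 : 1;
    }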
+ */ + +#include "TrainableOperationConverter.h" + +#include "ir/train/Operations.Include.h" +#include "util/Utils.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +TrainableOperationConverter::TrainableOperationConverter( + ir::train::TrainableGraph &tgraph, const compiler::train::TrainingInfo *training_info) + : UntrainableOperationConverter{tgraph}, _training_info{training_info} +{ + // Avoid unused-private-field error + UNUSED_RELEASE(_training_info); +} + +void TrainableOperationConverter::visit(const ir::operation::Conv2D &node) +{ + _return_op = std::make_unique<ir::train::operation::Conv2D>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::ElementwiseActivation &node) +{ + if (node.param().op_type == ir::operation::ElementwiseActivation::Type::RELU) + { + _return_op = std::make_unique<ir::train::operation::ElementwiseActivation>(node); + } + else + { + UntrainableOperationConverter::visit(node); + } +} + +void TrainableOperationConverter::visit(const ir::operation::FullyConnected &node) +{ + _return_op = std::make_unique<ir::train::operation::FullyConnected>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Loss &node) +{ + _return_op = std::make_unique<ir::train::operation::Loss>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Permute &node) +{ + _return_op = std::make_unique<ir::train::operation::Permute>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Pool2D &node) +{ + _return_op = std::make_unique<ir::train::operation::Pool2D>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Reshape &node) +{ + _return_op = std::make_unique<ir::train::operation::Reshape>(node); +} + +void TrainableOperationConverter::visit(const ir::operation::Softmax &node) +{ + _return_op = std::make_unique<ir::train::operation::Softmax>(node); +} + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h new file mode 100644 index 000000000..5f6fc10c3 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
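Note: `TrainableOperationConverter` above overrides `visit` per supported operation and inherits its call operator from `UntrainableOperationConverter`. The call path is presumably plain visitor dispatch: invoking the converter accepts the operation, the matching `visit` stashes the trainable clone in `_return_op`, and the call operator hands it back. A self-contained miniature of that pattern (all names hypothetical, not the runtime's actual classes):

    #include <iostream>
    #include <memory>

    struct Op { virtual ~Op() = default; virtual void accept(struct Visitor &v) const = 0; };
    struct Visitor { virtual ~Visitor() = default; virtual void visit(const struct Conv2D &) = 0; };
    struct Conv2D : Op { void accept(Visitor &v) const override { v.visit(*this); } };

    struct TrainableOp { const char *name; };

    class Converter : public Visitor
    {
    public:
      std::unique_ptr<TrainableOp> operator()(const Op &op)
      {
        op.accept(*this);            // dispatches to the matching visit() overload
        return std::move(_return_op);
      }

    private:
      void visit(const Conv2D &) override { _return_op.reset(new TrainableOp{"TrainableConv2D"}); }
      std::unique_ptr<TrainableOp> _return_op;
    };

    int main()
    {
      Conv2D conv;
      std::cout << Converter{}(conv)->name << std::endl;
    }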
+ */ + +#ifndef __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__ +#define __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__ + +#include "UntrainableOperationConverter.h" + +#include "compiler/train/TrainingInfo.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +class TrainableOperationConverter : public UntrainableOperationConverter +{ +public: + TrainableOperationConverter(ir::train::TrainableGraph &trainable_graph, + const compiler::train::TrainingInfo *training_info); + + using UntrainableOperationConverter::operator(); + +private: + void visit(const ir::operation::Conv2D &) override; + void visit(const ir::operation::ElementwiseActivation &) override; + void visit(const ir::operation::FullyConnected &) override; + void visit(const ir::operation::Loss &node) override; + void visit(const ir::operation::Permute &node) override; + void visit(const ir::operation::Pool2D &node) override; + void visit(const ir::operation::Reshape &) override; + void visit(const ir::operation::Softmax &) override; + +private: + const compiler::train::TrainingInfo *_training_info; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__ diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc new file mode 100644 index 000000000..711af1651 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TrainingCompiler.h" + +#include "StaticDerivativeShapeInferer.h" +#include "TrainableOperationConverter.h" +#include "pass/LossInsertionPass.h" +#include "../CompilerHelpers.h" +#include "../ExecutorFactory.h" +#include "../pass/ConstantOutputPass.h" +#include "../pass/OddOutputPass.h" +#include "../pass/PassRunner.h" +#include "../pass/UnusedOperandEliminationPass.h" +#include "../ShapeValidator.h" +#include "../../dumper/dot/DotDumper.h" +#include "../../exec/train/TrainableExecutors.h" +#include "../../ir/OperationDumper.h" +#include "../../ir/verifier/Verifier.h" + +#include <compiler/StaticShapeInferer.h> +#include <compiler/train/LoweredTrainableGraph.h> +#include <ir/train/TrainableGraph.h> +#include <exec/train/optimizer/SGD.h> + +#include <misc/polymorphic_downcast.h> +#include <misc/string_helpers.h> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +TrainingCompiler::TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, + std::vector<std::unique_ptr<CompilerOptions>> &copts, + const TrainingInfo &training_info) + : _model{nnpkg->primary_model()}, _options{copts[0].get()}, _training_info{training_info} +{ + if (nnpkg->model_count() > 1) + throw std::runtime_error("TrainingCompiler does not support multiple models yet"); + + if (nnpkg->primary_model()->subgraphs_count() > 1) + throw std::runtime_error("TrainingCompiler does not support multiple subgraphs yet"); +} + +std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void) +{ + /*************************************************** + * Prepare compilation phase + ***************************************************/ + if (!_options) + throw std::runtime_error{"Empty compile option"}; + + // Mode check + // TODO handle option for each model + if (_options->he_profiling_mode) + { + if (!_options->he_scheduler) + throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling."); + + if (_options->executor != "Dataflow") + throw std::runtime_error("Profiling mode works only with 'Dataflow' executor"); + } + + if (!_options->minmax_filepath.empty()) + { + if (_options->executor != "Linear") + throw std::runtime_error("Recording minmax works only with Linear executor"); + } + + _options->forceInternalOptions(); + _options->verboseOptions(); + + auto custom_kernel_builder = _model->getKernelBuilder(); + + _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) { + auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph); + // Mandatory passes + compiler::pass::PassRunner{} + .append(std::make_unique<compiler::pass::ConstantOutputPass>(subg)) + .append(std::make_unique<compiler::pass::OddOutputPass>(subg)) + .run(); + + // Optimizations + compiler::pass::PassRunner{} + .append(std::make_unique<compiler::pass::UnusedOperandEliminationPass>(subg)) + .run(); + }); + + std::unordered_map<ir::SubgraphIndex, std::shared_ptr<ir::train::TrainableGraph>> + trainable_subgraphs; + + if (_model->hasOnly<ir::Graph>()) + { + // Create trainable subgraphs by copy and converting inference model + _model->iterate([&](const ir::SubgraphIndex &subg_index, const ir::IGraph &graph) { + const auto &subg = nnfw::misc::polymorphic_downcast<const ir::Graph &>(graph); + // Create TrainableGraph by copying Graph + auto trainable_subg = std::make_shared<ir::train::TrainableGraph>(subg); + + // Convert operations to trainable operations + auto converter = TrainableOperationConverter{*trainable_subg, &_training_info}; + subg.operations().iterate( + [&](const onert::ir::OperationIndex 
&op_index, const onert::ir::IOperation &op) { + auto trainable_op = converter(op); + auto gen_index = trainable_subg->replaceOperation(op_index, std::move(trainable_op)); + UNUSED_RELEASE(gen_index); + assert(gen_index == op_index); + }); + + trainable_subgraphs[subg_index] = std::move(trainable_subg); + }); + } + else + { + // TODO Support models that have TrainableGraphs + throw std::runtime_error("TrainingCompiler: Invalid model"); + } + + // Release the original model; only the trainable subgraphs are used from here on + _model.reset(); + + // Apply pass for trainable subgraphs + for (auto &&pair : trainable_subgraphs) + { + auto trainable_subg = pair.second; + auto subg_index = pair.first; + + compiler::pass::PassRunner{} + .append(std::make_unique<train::pass::LossInsertionPass>(*trainable_subg, &_training_info, + subg_index)) + .run(); + } + + // Change input shape according to batch_size + for (auto &&pair : trainable_subgraphs) + { + auto trainable_subg = pair.second; + + for (const auto &ind : trainable_subg->getInputs()) + { + auto &input = trainable_subg->operands().at(ind); + auto new_shape = input.info().shape(); + // TODO Consider batch size index + if (new_shape.dim(0) != 1) + throw std::runtime_error("The first dim of the input shape is not 1; such models are not supported yet."); + new_shape.dim(0) = _training_info.batchSize(); + input.info().shape(new_shape); + } + } + + /*************************************************** + * Backend independent analysis & optimization phase + ***************************************************/ + // TODO Handle dump level for each model + auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level); + onert::dumper::dot::DotDumper dot_dumper(dump_level); + + // Tracing context + auto tracing_ctx = std::make_unique<util::TracingCtx>(); + + // Lower: Assign backend + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::train::LoweredTrainableGraph>> + lowered_subgs; + { + for (auto &&pair : trainable_subgraphs) + { + auto &subg_index = pair.first; + auto trainable_subg = pair.second; + + // Lower: Assign backend + lowered_subgs[subg_index] = + std::make_unique<compiler::train::LoweredTrainableGraph>(*trainable_subg, *_options); + // Set tracing_ctx for copied graph + if (tracing_ctx != nullptr) + tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value()); + } + } + + for (const auto &pair : lowered_subgs) + { + const auto &subg_index = pair.first; + const auto &lowered_subg = pair.second; + dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value())); + } + + // Set derivatives as default tensor info + for (const auto &pair : lowered_subgs) + { + auto lowered_subg = pair.second.get(); + auto &tgraph = lowered_subg->trainable_graph(); + tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) { + if (!obj.isConstant()) + { + auto deriv = std::make_unique<ir::Operand>(obj); + const auto gen_index = tgraph.addDerivative(index, std::move(deriv)); + assert(gen_index == index); + UNUSED_RELEASE(gen_index); + } + }); + } + + // Shape inference. + { + // Run the StaticShapeInferer of the primary subg.
All child StaticShapeInferers are called + // recursively + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers = + createStaticShapeInferers(lowered_subgs); + + const auto primary_subg_idx = ir::SubgraphIndex{0}; + inferers.at(primary_subg_idx)->infer(); + + for (const auto &pair_inferer : inferers) + { + const auto inferer = pair_inferer.second.get(); + inferer->dump(); + } + + // NOTE StaticDerivativeShapeInferer is allocated for each subgraph, + // so it does not support models that have controlflow operations yet. + for (auto &&pair : lowered_subgs) + { + auto &lowered_subg = pair.second; + auto inferer = std::make_unique<StaticDerivativeShapeInferer>(lowered_subg.get()); + inferer->infer(); + inferer->dump(); + } + } + + // Shape validation + for (const auto &pair : lowered_subgs) + { + auto &lowered_subg = pair.second; + compiler::ShapeValidator{lowered_subg->graph()}(); + } + + // TODO Validate shapes of derivative tensors + + // Create optimizer + // TODO Set properties of optimizer + std::shared_ptr<exec::train::optimizer::Optimizer> optimizer; + const auto &optim_info = _training_info.optimizerInfo(); + if (optim_info.optim_code == exec::train::optimizer::OptimizerCode::SGD) + optimizer = std::make_shared<exec::train::optimizer::SGD>(optim_info.learning_rate); + else + throw std::runtime_error("Invalid optimizer type, " + + exec::train::optimizer::toString(optim_info.optim_code)); + + /************************************************************* + * Backend independent analysis & optimization phase finished + *************************************************************/ + auto executors = std::make_shared<exec::train::TrainableExecutors>(); + for (auto &&pair : lowered_subgs) + { + auto const model_index = ir::ModelIndex{0}; + auto const subg_index = pair.first; + auto &lowered_subg = pair.second; + auto const indexed_ranks = lowered_subg->indexed_ranks(); + + ir::OperationDumper dumper("Executor generation of Subgraph " + + std::to_string(subg_index.value())); + lowered_subg->graph().operations().iterate( + [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); }); + + ExecutorFactoryArgs args; + args.tracing_ctx = tracing_ctx.get(); + args.options = _options; + args.model_index = model_index; + args.custom_kernel_builder = custom_kernel_builder; + auto executor = std::unique_ptr<exec::IExecutor>{ + ExecutorFactory::get().create(std::move(lowered_subg), executors, args, optimizer)}; + executor->setIndexedRanks(indexed_ranks); + executors->emplace(model_index, subg_index, std::move(executor)); + } + + /******************************** + * Code generation phase finished + ********************************/ + return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx)); +} + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.h b/runtime/onert/core/src/compiler/train/TrainingCompiler.h new file mode 100644 index 000000000..b93437217 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
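For reference, the compile() implementation above proceeds as: mandatory graph passes, conversion to a TrainableGraph, loss insertion, batch-size resizing, lowering, shape inference and validation, and finally executor generation. A minimal caller-side sketch (nnpkg, copts, and tinfo are hypothetical variables, and it assumes CompilerArtifact exposes its executors as a public member, as elsewhere in this runtime):

    #include "compiler/train/TrainingCompiler.h"

    using namespace onert;

    std::shared_ptr<exec::IExecutors>
    buildTrainableExecutors(const std::shared_ptr<ir::NNPkg> &nnpkg,
                            std::vector<std::unique_ptr<compiler::CompilerOptions>> &copts,
                            const compiler::train::TrainingInfo &tinfo)
    {
      // Single model, single subgraph only; the constructor throws otherwise.
      compiler::train::TrainingCompiler training_compiler{nnpkg, copts, tinfo};
      auto artifact = training_compiler.compile();
      return artifact->_executors;
    }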
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file TrainingCompiler.h + * @brief This file contains the TrainingCompiler class that defines and runs the compilation phase + */ + +#ifndef __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_ +#define __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_ + +#include "compiler/CompilerOptions.h" +#include "compiler/ICompiler.h" +#include "compiler/train/TrainingInfo.h" +#include "ir/NNPkg.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +/** + * @brief Class to compile NN package + */ +class TrainingCompiler : public ICompiler +{ +public: + /** + * @brief Construct a new TrainingCompiler object for a single model + * @param[in] nnpkg NN package to compile + * @param[in] copts Compiler options + * @param[in] training_info Training information + */ + explicit TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, + std::vector<std::unique_ptr<CompilerOptions>> &copts, + const TrainingInfo &training_info); + + /** + * @brief Deleted default constructor + */ + TrainingCompiler(void) = delete; + + /** + * @brief Destroy the TrainingCompiler object + */ + ~TrainingCompiler() = default; + +public: + /** + * @brief Do compilation with the options + * + * @return std::shared_ptr<CompilerArtifact> Executors as a result of compilation + */ + std::shared_ptr<CompilerArtifact> compile(void); + +private: + std::shared_ptr<ir::Model> _model; + CompilerOptions *_options; + const TrainingInfo _training_info; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_ diff --git a/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc new file mode 100644 index 000000000..6a5a052b6 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
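The UntrainableOperationConverter that follows generates one visit() overload per opcode by expanding an OP() X-macro over ir/Operations.lst. A small self-contained illustration of the same pattern, with a hypothetical three-entry list standing in for the real header:

    // Stand-in for ir/Operations.lst, which holds one OP(...) entry per opcode.
    #define OPERATIONS_LST OP(Conv2D) OP(Relu) OP(Softmax)

    struct Printer
    {
    #define OP(InternalName) \
      void visit##InternalName() { /* one generated handler per listed operation */ }
      OPERATIONS_LST
    #undef OP
    };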
+ */ + +#include "UntrainableOperationConverter.h" + +#include "ir/train/operation/UntrainableOperation.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +UntrainableOperationConverter::UntrainableOperationConverter(ir::train::TrainableGraph &tgraph) + : _tgraph{tgraph}, _return_op{nullptr} +{ +} + +std::unique_ptr<ir::train::ITrainableOperation> UntrainableOperationConverter:: +operator()(const ir::IOperation &op) +{ + op.accept(*this); + + return std::move(_return_op); +} + +#define OP(InternalName) \ + void UntrainableOperationConverter::visit(const ir::operation::InternalName &node) \ + { \ + _return_op = \ + std::make_unique<ir::train::operation::UntrainableOperation<ir::operation::InternalName>>( \ + node); \ + } +#include "ir/Operations.lst" +#undef OP + +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h new file mode 100644 index 000000000..e960b3831 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__ +#define __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__ + +#include "ir/Operations.Include.h" +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableGraph.h" + +#include <memory> + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +class UntrainableOperationConverter : public ir::OperationVisitor +{ +public: + UntrainableOperationConverter(ir::train::TrainableGraph &tgraph); + std::unique_ptr<ir::train::ITrainableOperation> operator()(const ir::IOperation &op); + +#define OP(InternalName) void visit(const ir::operation::InternalName &node); +#include "ir/Operations.lst" +#undef OP + +protected: + ir::train::TrainableGraph &_tgraph; + std::unique_ptr<ir::train::ITrainableOperation> _return_op; +}; + +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__ diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc new file mode 100644 index 000000000..3e01a9739 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "LossInsertionPass.h" + +#include "ir/train/TrainableGraph.h" +#include "ir/train/operation/Loss.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ +namespace pass +{ + +void LossInsertionPass::run() +{ + const auto &loss_info = _training_info->lossInfo(); + + ir::operation::Loss::Param param; + param.op_type = loss_info.type; + + if (_trainable_graph.getOutputs().size() != 1) + { + throw std::runtime_error("LossInsertionPass: Multiple outputs are not supported yet"); + } + + // TODO Consider SparseCategoricalCrossentropy y_true shape + // SparseCategoricalCrossentropy loss has a different y_true shape than y_pred. + + // TODO Implement Loop [0, getOutputs().size()) + // index: a loop index + const auto index = 0; + const auto &y_pred_index = _trainable_graph.getOutputs().at(index); + const auto &y_pred = _trainable_graph.operands().at(y_pred_index); + const auto &shape = y_pred.shape(); + const auto &type_info = y_pred.typeInfo(); + auto y_true_index = _trainable_graph.addOperand(shape, type_info); + ir::OperandIndexSequence inputs{y_pred_index, y_true_index}; + + // TODO Consider Reduction + // Some types of Reduction have the same shape for y_true and output. + + const ir::TypeInfo float_op(ir::DataType::FLOAT32); + auto output_index = _trainable_graph.addOperand(ir::Shape{1}, float_op); + ir::OperandIndexSequence outputs{output_index}; + + auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs, param); + auto trainable_loss_op = std::make_unique<ir::train::operation::Loss>(*loss_op); + + _trainable_graph.addOperation(std::move(trainable_loss_op)); + + _trainable_graph.addInput(y_true_index); + + // TODO Add as many losses as the output size + _trainable_graph.addLoss(output_index, ir::IOIndex{index}); +} + +} // namespace pass +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h new file mode 100644 index 000000000..ed4d60c96 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
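The header below relies on the small train::pass::Pass base class (added later in this patch), which stores the target TrainableGraph and a TrainingInfo pointer and leaves id()/run() to the concrete pass. A hypothetical sibling pass under the same interface would look like:

    class GradientClipPass : public Pass // illustrative only, not part of this patch
    {
    public:
      using Pass::Pass; // inherit the (trainable_graph, training_info) constructor
      std::string id() final { return "GradientClipPass"; }
      void run() final
      {
        // Mutate _trainable_graph here, consulting _training_info as needed.
      }
    };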
+ */ + +#ifndef __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__ +#define __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__ + +#include "Pass.h" + +#include "compiler/train/TrainingInfo.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ +namespace pass +{ + +class LossInsertionPass : public Pass +{ +public: + LossInsertionPass(ir::train::TrainableGraph &trainable_graph, const TrainingInfo *training_info, + const ir::SubgraphIndex &subg_index) + : Pass{trainable_graph, training_info}, _subg_index{subg_index} + { + } + +public: + std::string id() final { return "LossInsertionPass"; } + void run() final; + +private: + ir::SubgraphIndex _subg_index; +}; + +} // namespace pass +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__ diff --git a/runtime/onert/core/src/compiler/train/pass/Pass.h b/runtime/onert/core/src/compiler/train/pass/Pass.h new file mode 100644 index 000000000..d64c06cf4 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/pass/Pass.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_COMPILER_TRAIN_PASS_PASS_H__ +#define __ONERT_COMPILER_TRAIN_PASS_PASS_H__ + +#include "../../pass/IPass.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +class TrainableGraph; +} // namespace train +} // namespace ir +} // namespace onert + +namespace onert +{ +namespace compiler +{ +namespace train +{ + +class TrainingInfo; + +namespace pass +{ + +class Pass : public compiler::pass::IPass +{ +public: + Pass(ir::train::TrainableGraph &trainable_graph, const TrainingInfo *training_info) + : _trainable_graph{trainable_graph}, _training_info{training_info} + { + } + virtual ~Pass() = default; + +protected: + ir::train::TrainableGraph &_trainable_graph; + const TrainingInfo *_training_info; +}; + +} // namespace pass +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_PASS_PASS_H__ diff --git a/runtime/onert/core/src/dumper/dot/DotBuilder.cc b/runtime/onert/core/src/dumper/dot/DotBuilder.cc index d4e4d5484..9257434fa 100644 --- a/runtime/onert/core/src/dumper/dot/DotBuilder.cc +++ b/runtime/onert/core/src/dumper/dot/DotBuilder.cc @@ -29,7 +29,7 @@ DotBuilder::DotBuilder() {} void DotBuilder::update(const Node &node_info) { add(node_info); - for (auto edge : node_info.out_edges()) + for (auto &&edge : node_info.out_edges()) { addEdge(node_info, *edge); } @@ -47,7 +47,7 @@ void DotBuilder::add(const Node &node) _dot << node.id(); std::stringstream ss; _dot << "["; - for (auto attr : node.attributes()) + for (auto &&attr : node.attributes()) { _dot << attr.first << "=\"" << attr.second << "\" "; } diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.cc b/runtime/onert/core/src/dumper/dot/DotDumper.cc index 0bb2fa11f..ab77a6c62 100644 --- a/runtime/onert/core/src/dumper/dot/DotDumper.cc +++ 
b/runtime/onert/core/src/dumper/dot/DotDumper.cc @@ -98,10 +98,10 @@ generate_dot_operations(const ir::Graph &graph, { ir::OperationIndexMap<std::unique_ptr<Operation>> dot_operations; const auto &operations = graph.operations(); - operations.iterate([&](const ir::OperationIndex &index, const ir::Operation &op) { + operations.iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) { auto node = std::make_unique<Operation>(index, op); - for (auto input : op.getInputs()) + for (auto &&input : op.getInputs()) { using onert::dumper::dot::Operand; @@ -113,7 +113,7 @@ generate_dot_operations(const ir::Graph &graph, input_node->addOutEdge(node.get()); } - for (auto output : op.getOutputs() | ir::Remove::UNDEFINED) + for (auto &&output : op.getOutputs() | ir::Remove::UNDEFINED) { using onert::dumper::dot::Operand; auto &output_node = dot_operands.at(output); @@ -126,7 +126,7 @@ generate_dot_operations(const ir::Graph &graph, return dot_operations; } -void update_lower_info(const compiler::LoweredGraph &lowered_graph, +void update_lower_info(const compiler::ILoweredGraph &lowered_graph, ir::OperandIndexMap<std::unique_ptr<Operand>> *dot_operands) { const auto &operands = lowered_graph.graph().operands(); @@ -153,11 +153,11 @@ void update_lower_info(const compiler::LoweredGraph &lowered_graph, }); } -void update_lower_info(const compiler::LoweredGraph &lowered_graph, +void update_lower_info(const compiler::ILoweredGraph &lowered_graph, ir::OperationIndexMap<std::unique_ptr<Operation>> *dot_operations) { const auto &operations = lowered_graph.graph().operations(); - operations.iterate([&](const ir::OperationIndex &index, const ir::Operation &) { + operations.iterate([&](const ir::OperationIndex &index, const ir::IOperation &) { const auto lower_info = lowered_graph.lower_info().operation.getRawPtr(index); if (lower_info) { @@ -213,7 +213,8 @@ void DotDumper::dump(const ir::Graph &graph, const std::string &tag) dump_to_file(dot_operands, dot_operations, tag); } -void DotDumper::dump(const compiler::LoweredGraph &lowered_graph, const std::string &tag) +// TODO Support derivative tensors +void DotDumper::dump(const compiler::ILoweredGraph &lowered_graph, const std::string &tag) { if (_level == Level::OFF) { diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.h b/runtime/onert/core/src/dumper/dot/DotDumper.h index 6249010d3..fca5f356c 100644 --- a/runtime/onert/core/src/dumper/dot/DotDumper.h +++ b/runtime/onert/core/src/dumper/dot/DotDumper.h @@ -15,7 +15,7 @@ */ #include "ir/Graph.h" -#include "compiler/LoweredGraph.h" +#include "compiler/ILoweredGraph.h" #ifndef __ONERT_DUMPER_DOT_DOT_DUMPER_H__ #define __ONERT_DUMPER_DOT_DOT_DUMPER_H__ @@ -57,7 +57,7 @@ public: * @param[in] tag The name of dot file that would be created * @return N/A */ - void dump(const compiler::LoweredGraph &lowered_graph, const std::string &tag); + void dump(const compiler::ILoweredGraph &lowered_graph, const std::string &tag); private: Level _level; diff --git a/runtime/onert/core/src/dumper/dot/OperationNode.cc b/runtime/onert/core/src/dumper/dot/OperationNode.cc index 87c5ba148..2ef08c9c6 100644 --- a/runtime/onert/core/src/dumper/dot/OperationNode.cc +++ b/runtime/onert/core/src/dumper/dot/OperationNode.cc @@ -31,7 +31,7 @@ namespace dot const std::string Operation::OPERATION_SHAPE = "rect"; const std::string Operation::BG_COLOR_SCHEME = "pastel18"; -Operation::Operation(const ir::OperationIndex &index, const ir::Operation &node) +Operation::Operation(const ir::OperationIndex &index, const ir::IOperation 
&node) : Node{"operation" + std::to_string(index.value())} { setAttribute("label", std::to_string(index.value()) + " : " + node.name()); diff --git a/runtime/onert/core/src/dumper/dot/OperationNode.h b/runtime/onert/core/src/dumper/dot/OperationNode.h index 74a37d3fb..d9292ad0c 100644 --- a/runtime/onert/core/src/dumper/dot/OperationNode.h +++ b/runtime/onert/core/src/dumper/dot/OperationNode.h @@ -25,7 +25,7 @@ #define __ONERT_DUMPER_DOT_DOT_NODE_INFO_H__ #include "Node.h" -#include "ir/Operation.h" +#include "ir/IOperation.h" #include "ir/Index.h" namespace onert @@ -52,7 +52,7 @@ public: * @param[in] index operation index * @param[in] node operation object */ - Operation(const ir::OperationIndex &index, const ir::Operation &node); + Operation(const ir::OperationIndex &index, const ir::IOperation &node); }; } // namespace dot diff --git a/runtime/onert/core/src/dumper/h5/Dumper.cc b/runtime/onert/core/src/dumper/h5/Dumper.cc new file mode 100644 index 000000000..5e12c2dbb --- /dev/null +++ b/runtime/onert/core/src/dumper/h5/Dumper.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Dumper.h" + +#include <iostream> +#include <sstream> +#include <stdexcept> + +namespace onert +{ +namespace dumper +{ +namespace h5 +{ + +Dumper::Dumper(const std::string &filepath) : _file{filepath, H5F_ACC_CREAT | H5F_ACC_RDWR} {} + +} // namespace h5 +} // namespace dumper +} // namespace onert diff --git a/runtime/onert/core/src/dumper/h5/Dumper.h b/runtime/onert/core/src/dumper/h5/Dumper.h new file mode 100644 index 000000000..53d0e0332 --- /dev/null +++ b/runtime/onert/core/src/dumper/h5/Dumper.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_DUMPER_H5_DUMPER_H__ +#define __ONERT_DUMPER_H5_DUMPER_H__ + +#include "exec/MinMaxMap.h" + +#include <H5Cpp.h> +#include <string> + +namespace onert +{ +namespace dumper +{ +namespace h5 +{ + +class Dumper +{ +public: + /** + * @brief Construct dumper + * + * @param[in] filepath file path to dump to + * @throw H5::FileIException on error during file open/create + */ + Dumper(const std::string &filepath); + +protected: + H5::H5File _file; +}; + +} // namespace h5 +} // namespace dumper +} // namespace onert + +#endif // __ONERT_DUMPER_H5_DUMPER_H__ diff --git a/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc b/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc new file mode 100644 index 000000000..8a9de9f95 --- /dev/null +++ b/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MinMaxDumper.h" + +#include <iostream> +#include <sstream> +#include <stdexcept> + +namespace onert +{ +namespace dumper +{ +namespace h5 +{ + +static const char *h5_value_grpname = "value"; + +/* + * Ensure that the group named `child` exists in `parent`: open it if it already + * exists, create it otherwise. + */ +H5::Group ensureGroup(H5::Group parent, const char *child) +{ + H5::Exception::dontPrint(); + try + { + return parent.openGroup(child); + } + catch (H5::Exception &e) + { + return parent.createGroup(child); + } +} + +MinMaxDumper::MinMaxDumper(const std::string &filepath) : Dumper(filepath) +{ + auto root_grp = _file.openGroup("/"); + ensureGroup(root_grp, h5_value_grpname); +} + +void MinMaxDumper::dump(const exec::SMMinMaxMap &mmmap) const +{ + auto val_grp = _file.openGroup(h5_value_grpname); + auto num_run = val_grp.getNumObjs(); + auto num_grp = val_grp.createGroup(std::to_string(num_run)); + auto model_grp = ensureGroup(num_grp, "0"); + hsize_t dims[] = {2}; + H5::DataSpace dspace(1, dims); // rank=1, dim(0)=2, {min, max} + for (auto &&e : mmmap) + { + // key = {subg_idx, op_idx} = e.first + const auto subg_idx = e.first.first.value(); + const auto op_idx = e.first.second.value(); + auto subg_grp = ensureGroup(model_grp, std::to_string(subg_idx).c_str()); + auto op_dset = subg_grp.createDataSet(std::to_string(op_idx), H5::PredType::IEEE_F32BE, dspace); + op_dset.write(e.second.data, H5::PredType::NATIVE_FLOAT); + } +} + +} // namespace h5 +} // namespace dumper +} // namespace onert diff --git a/runtime/onert/core/src/dumper/h5/MinMaxDumper.h b/runtime/onert/core/src/dumper/h5/MinMaxDumper.h new file mode 100644 index 000000000..1f1b27c6e --- /dev/null +++ b/runtime/onert/core/src/dumper/h5/MinMaxDumper.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_DUMPER_H5_MINMAX_DUMPER_H__ +#define __ONERT_DUMPER_H5_MINMAX_DUMPER_H__ + +#include "exec/MinMaxMap.h" +#include "Dumper.h" + +#include <H5Cpp.h> +#include <string> + +namespace onert +{ +namespace dumper +{ +namespace h5 +{ + +// The hierarchy of a single-model minmax h5 file +// +// GROUP / +// GROUP value +// └── GROUP run_idx +// └── GROUP model_idx +// └── GROUP subg_idx +// └── DATASET op_idx +// DATATYPE Float32 +// DATASPACE (2) +// DATA { min, max } +// GROUP name (optional, for debug) +// └── GROUP model_idx +// └── GROUP subg_idx +// └── ATTRIBUTE op_idx +// DATATYPE String +// DATA { "model/your/op/name"} +// +class MinMaxDumper : private Dumper +{
public: + MinMaxDumper(const std::string &filepath); + /** + * @brief Dump minmax map + * + * @param[in] map single model minmax map + */ + void dump(const exec::SMMinMaxMap &map) const; + +private: + H5::Group _val_grp; +}; + +} // namespace h5 +} // namespace dumper +} // namespace onert + +#endif // __ONERT_DUMPER_H5_MINMAX_DUMPER_H__ diff --git a/runtime/onert/core/src/dumper/text/GraphDumper.cc b/runtime/onert/core/src/dumper/text/GraphDumper.cc index 80cfbbc34..6bd7904aa 100644 --- a/runtime/onert/core/src/dumper/text/GraphDumper.cc +++ b/runtime/onert/core/src/dumper/text/GraphDumper.cc @@ -18,6 +18,9 @@ #include "ir/Graph.h" #include "compiler/LoweredGraph.h" +#ifdef ONERT_TRAIN +#include "compiler/train/LoweredTrainableGraph.h" +#endif // ONERT_TRAIN #include "util/logging.h" #include "misc/string_helpers.h" @@ -34,7 +37,7 @@ namespace std::string formatOperandIndexSequence(const ir::OperandIndexSequence &seq) { std::vector<std::string> strs; - for (auto ind : seq) + for (auto &&ind : seq) strs.push_back(dumper::text::formatOperandBrief(ind)); return nnfw::misc::join(strs.begin(), strs.end(), ", "); } @@ -56,10 +59,9 @@ std::string formatOperand(const ir::Graph &, ir::OperandIndex ind) return ss.str(); } -std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind) +std::string formatOperation(const ir::IOperation &op, ir::OperationIndex ind) { std::stringstream ss; - const auto &op = graph.operations().at(ind); ss << formatOperandIndexSequence(op.getOutputs()); ss << " = "; @@ -69,13 +71,21 @@ std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind) return ss.str(); } +std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind) +{ + const auto &op = graph.operations().at(ind); + return formatOperation(op, ind); +} + void dumpGraph(const ir::Graph &graph) { VERBOSE(GraphDumper) << "{\n"; auto ops_topol = graph.topolSortOperations(); - for (auto op_ind : ops_topol) + for (auto &&op_ind : ops_topol) { - VERBOSE(GraphDumper) << " " << formatOperation(graph, op_ind) << "\n"; + const auto &op = graph.operations().at(op_ind); + VERBOSE(GraphDumper) << " " << formatOperation(op, op_ind) << "\n"; } VERBOSE(GraphDumper) << "}\n"; VERBOSE(GraphDumper) << std::endl; @@ -87,6 +97,14 @@ void dumpLoweredGraph(const compiler::LoweredGraph &lgraph) dumpGraph(lgraph.graph()); } +#ifdef ONERT_TRAIN +void dumpLoweredGraph(const
compiler::train::LoweredTrainableGraph &lgraph) +{ + // TODO Graph dump with backend info + dumpGraph(lgraph.graph()); +} +#endif // ONERT_TRAIN + } // namespace text } // namespace dumper } // namespace onert diff --git a/runtime/onert/core/src/dumper/text/GraphDumper.h b/runtime/onert/core/src/dumper/text/GraphDumper.h index 0501ff050..ab0061465 100644 --- a/runtime/onert/core/src/dumper/text/GraphDumper.h +++ b/runtime/onert/core/src/dumper/text/GraphDumper.h @@ -24,7 +24,8 @@ namespace onert namespace ir { class Graph; -} +struct IOperation; +} // namespace ir } // namespace onert namespace onert @@ -32,7 +33,14 @@ namespace onert namespace compiler { class LoweredGraph; -} + +#ifdef ONERT_TRAIN +namespace train +{ +class LoweredTrainableGraph; +} // namespace train +#endif // ONERT_TRAIN +} // namespace compiler } // namespace onert namespace onert @@ -47,6 +55,9 @@ std::string formatOperand(const ir::Graph &, ir::OperandIndex ind); std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind); void dumpGraph(const ir::Graph &graph); void dumpLoweredGraph(const compiler::LoweredGraph &lgraph); +#ifdef ONERT_TRAIN +void dumpLoweredGraph(const compiler::train::LoweredTrainableGraph &lgraph); +#endif // ONERT_TRAIN } // namespace text } // namespace dumper diff --git a/runtime/onert/core/src/exec/DataflowExecutor.cc b/runtime/onert/core/src/exec/DataflowExecutor.cc index 8dac1219e..e0b00077f 100644 --- a/runtime/onert/core/src/exec/DataflowExecutor.cc +++ b/runtime/onert/core/src/exec/DataflowExecutor.cc @@ -60,7 +60,7 @@ void DataflowExecutor::emplaceToReadyJobs(const uint32_t &id) void DataflowExecutor::notify(uint32_t finished_job_id) { - for (auto id : _output_info[finished_job_id]) + for (auto &&id : _output_info[finished_job_id]) { assert(_input_info[id] > 0); auto count = --_input_info[id]; @@ -90,7 +90,7 @@ DataflowExecutor::DataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lower uint32_t next_job_index = 0; std::unordered_map<ir::OperationIndex, uint32_t> op_to_job; const auto &operations = _lowered_graph->graph().operations(); - operations.iterate([&](const ir::OperationIndex &op_ind, const ir::Operation &) { + operations.iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &) { VERBOSE(DataflowExecutor) << "Create a job " << next_job_index << " with Operation " << op_ind << std::endl; _finished_jobs.emplace_back( @@ -102,12 +102,12 @@ DataflowExecutor::DataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lower _output_info.resize(next_job_index); _initial_input_info.resize(next_job_index, 0); - operations.iterate([&](const ir::OperationIndex &op_ind, const ir::Operation &op) { + operations.iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &op) { auto job_index = op_to_job[op_ind]; - for (auto output : op.getOutputs()) + for (auto &&output : op.getOutputs()) { // Update output and input info - operations.iterate([&](const ir::OperationIndex &op_cur_ind, const ir::Operation &op_cur) { + operations.iterate([&](const ir::OperationIndex &op_cur_ind, const ir::IOperation &op_cur) { if (op_cur.getInputs().contains(output)) { auto dep_index = op_to_job[op_cur_ind]; diff --git a/runtime/onert/core/src/exec/DynamicShapeInferer.cc b/runtime/onert/core/src/exec/DynamicShapeInferer.cc index fb8058d23..78b21cf49 100644 --- a/runtime/onert/core/src/exec/DynamicShapeInferer.cc +++ b/runtime/onert/core/src/exec/DynamicShapeInferer.cc @@ -253,7 +253,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op) So, only when all 
inputs are static, we can skip dynamic shape inference. */ bool all_static = true; - for (auto input_ind : op.getInputs()) + for (auto &&input_ind : op.getInputs()) { auto input = _tensor_registry->getITensor(input_ind); if (input->is_dynamic()) @@ -290,7 +290,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op) auto first_input_ind = op.getInputs().at(0); auto first_input = _tensor_registry->getITensor(first_input_ind); - for (auto input_ind : op.getInputs()) + for (auto &&input_ind : op.getInputs()) { auto input = _tensor_registry->getITensor(input_ind); if (input != first_input && !isConcatible(first_input, input, op.param().axis)) @@ -300,7 +300,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op) // getting output shape onert::shape_inference::Shapes in_shapes; - for (auto input_ind : op.getInputs()) + for (auto &&input_ind : op.getInputs()) { auto input = _tensor_registry->getITensor(input_ind); ir::Shape shape = input->getShape(); @@ -1042,7 +1042,7 @@ void DynamicShapeInferer::visit(const ir::operation::Split &op) // Return if all tensors are not dynamic bool has_dynamic = false; - for (const auto output_idx : op.getOutputs()) + for (const auto &output_idx : op.getOutputs()) { auto output = _tensor_registry->getITensor(output_idx); has_dynamic |= output->is_dynamic(); diff --git a/runtime/onert/core/src/exec/ExecTime.test.cc b/runtime/onert/core/src/exec/ExecTime.test.cc index 1f7152e7b..939184e4e 100644 --- a/runtime/onert/core/src/exec/ExecTime.test.cc +++ b/runtime/onert/core/src/exec/ExecTime.test.cc @@ -34,7 +34,7 @@ struct MockConfig : public IConfig std::string id() override { return "b1"; } bool initialize() override { return true; }; bool supportPermutation() override { return false; } - ir::Layout supportLayout(const ir::Operation &, ir::Layout) override + ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override { return ir::Layout::UNKNOWN; } diff --git a/runtime/onert/core/src/exec/Execution.cc b/runtime/onert/core/src/exec/Execution.cc index 7d5b406ef..1384c9fdc 100644 --- a/runtime/onert/core/src/exec/Execution.cc +++ b/runtime/onert/core/src/exec/Execution.cc @@ -16,6 +16,8 @@ #include "exec/Execution.h" +#include "train/TrainableExecutors.h" + #include "util/logging.h" namespace onert @@ -151,6 +153,35 @@ void Execution::waitFinish() bool Execution::isFinished(void) const { return finished; } +#ifdef ONERT_TRAIN +void Execution::train(uint32_t training_step) +{ + auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get()); + if (!execs) + { + throw std::runtime_error{"Only TrainableExecutors is supported"}; + } + + VERBOSE(Execution) << "Start training" << std::endl; + + execs->train(_io_desc, training_step); + finished = true; + + VERBOSE(Execution) << "Training finished" << std::endl; +} + +float Execution::getLoss(const ir::IOIndex &ind) +{ + auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get()); + if (!execs) + { + throw std::runtime_error{"Only TrainableExecutors is supported"}; + } + + return execs->getLoss(ind); +} +#endif // ONERT_TRAIN + ir::Shape Execution::getInputShape(ir::IOIndex ind) const { auto itr = _io_desc.dynamic_input_shapes.find(ind); @@ -180,5 +211,16 @@ ir::Shape Execution::getOutputShape(ir::IOIndex ind) const return output_desc->info.shape(); } +size_t Execution::getInputTotalSize(ir::IOIndex ind) const +{ + // TODO Support dynamic shape + return _executors->inputInfo(ind).total_size(); +} + +size_t Execution::getOutputTotalSize(ir::IOIndex ind) const +{ +
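// For context, the ONERT_TRAIN entry points added above compose into a training loop
// like the following sketch (session setup and buffer binding omitted; names hypothetical):
//
//   for (uint32_t step = 0; step < num_steps; ++step)
//   {
//     execution.train(step);                          // forward(training=true) + backward
//     VERBOSE(Train) << execution.getLoss(ir::IOIndex{0}) << std::endl;
//   }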
return _executors->outputInfo(ind).total_size(); +} + } // namespace exec } // namespace onert diff --git a/runtime/onert/core/src/exec/ExecutionObservers.cc b/runtime/onert/core/src/exec/ExecutionObservers.cc index 9abde7ba4..5245518a0 100644 --- a/runtime/onert/core/src/exec/ExecutionObservers.cc +++ b/runtime/onert/core/src/exec/ExecutionObservers.cc @@ -28,7 +28,7 @@ namespace { -void setUserData(const onert::ir::Graph &g, const onert::ir::Operation *op, +void setUserData(const onert::ir::Graph &g, const onert::ir::IOperation *op, decltype(EventCollector::Event::userData) &data) { // From a tensor of shape [a, b, c], this will return a string "shape(a b c)". diff --git a/runtime/onert/core/src/exec/ExecutionObservers.h b/runtime/onert/core/src/exec/ExecutionObservers.h index 91fbac323..7e93ecf7c 100644 --- a/runtime/onert/core/src/exec/ExecutionObservers.h +++ b/runtime/onert/core/src/exec/ExecutionObservers.h @@ -24,7 +24,7 @@ #include "exec/IExecutor.h" #include "ir/Index.h" -#include "ir/Operation.h" +#include "ir/IOperation.h" #include "util/ITimer.h" #include "util/TracingCtx.h" diff --git a/runtime/onert/core/src/exec/ExecutorBase.cc b/runtime/onert/core/src/exec/ExecutorBase.cc index 515cf8e48..ad0073477 100644 --- a/runtime/onert/core/src/exec/ExecutorBase.cc +++ b/runtime/onert/core/src/exec/ExecutorBase.cc @@ -35,7 +35,7 @@ ExecutorBase::ExecutorBase(std::unique_ptr<compiler::LoweredGraph> &&lowered_gra { auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) { assert(tensors.empty()); - for (auto ind : ind_seq) + for (auto &&ind : ind_seq) { backend::ITensor *tensor = tensor_regs.getITensor(ind); assert(tensor != nullptr); diff --git a/runtime/onert/core/src/exec/ExecutorBase.h b/runtime/onert/core/src/exec/ExecutorBase.h index 7aee3d9ee..4f97de922 100644 --- a/runtime/onert/core/src/exec/ExecutorBase.h +++ b/runtime/onert/core/src/exec/ExecutorBase.h @@ -77,6 +77,7 @@ public: { return _output_tensors; } + backend::BackendContexts &getBackendContexts() { return _backend_contexts; } protected: /** diff --git a/runtime/onert/core/src/exec/Executors.cc b/runtime/onert/core/src/exec/Executors.cc index 7edd5aaa2..8a1be3df4 100644 --- a/runtime/onert/core/src/exec/Executors.cc +++ b/runtime/onert/core/src/exec/Executors.cc @@ -147,7 +147,7 @@ void Executors::checkSupportedMultimodel() const // Assumption: edges // m1 < m2, s1 == 0 and s2 == 0 if edge 'm1:s1:o1 -> m2:s2:o2' - for (auto edge : _model_edges->edges) + for (auto &&edge : _model_edges->edges) { auto const model_from = std::get<ir::ModelIndex>(edge.from); auto const model_to = std::get<ir::ModelIndex>(edge.to); diff --git a/runtime/onert/core/src/exec/FunctionSequence.cc b/runtime/onert/core/src/exec/FunctionSequence.cc index f87c271f7..578123a54 100644 --- a/runtime/onert/core/src/exec/FunctionSequence.cc +++ b/runtime/onert/core/src/exec/FunctionSequence.cc @@ -16,7 +16,6 @@ #include "exec/FunctionSequence.h" -#include "ir/Operation.h" #include "backend/ITensorRegistry.h" #include "util/logging.h" diff --git a/runtime/onert/core/src/exec/LinearExecutor.h b/runtime/onert/core/src/exec/LinearExecutor.h index a833466da..cc073411a 100644 --- a/runtime/onert/core/src/exec/LinearExecutor.h +++ b/runtime/onert/core/src/exec/LinearExecutor.h @@ -52,7 +52,7 @@ public: const std::vector<ir::OperationIndex> &order, const util::TracingCtx *tracing_ctx) : ExecutorBase{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, tracing_ctx} { - for (auto index : order) + for (auto &&index : order) { 
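// Note: `order` is the linear schedule chosen at compile time, so code objects are
// appended, and later executed, in exactly that topological order. For example:
//   order = {#2, #0, #1}  =>  _code = [code(#2), code(#0), code(#1)]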
_code.emplace_back(std::move(code_map.at(index))); } diff --git a/runtime/onert/core/src/exec/MinMaxRecorder.cc b/runtime/onert/core/src/exec/MinMaxRecorder.cc new file mode 100644 index 000000000..88fc104d1 --- /dev/null +++ b/runtime/onert/core/src/exec/MinMaxRecorder.cc @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MinMaxRecorder.h" + +#include "backend/ITensor.h" + +#include <cassert> +#include <cmath> +#include <limits> + +namespace onert +{ +namespace exec +{ + +MinMaxRecorder::MinMaxRecorder(const std::string &minmax_filepath, const ir::Graph &graph, + const backend::BackendContexts &backend_contexts) + : _graph{graph}, _backend_contexts{backend_contexts}, _h5dumper(minmax_filepath) +{ +} + +void MinMaxRecorder::handleJobEnd(IExecutor *, ir::SubgraphIndex subg_idx, + ir::OperationIndex op_idx, const backend::Backend *backend) +{ + const auto &tensor_reg = _backend_contexts.at(backend)->tensor_registry; + const auto &op = _graph.operations().at(op_idx); + const auto &outputs = op.getOutputs(); + // TODO: Support multiple outputs + if (outputs.size() != 1) + throw std::runtime_error("Only operators with a single output are supported for recording minmax."); + + auto tensor = tensor_reg->getITensor(outputs.at(0)); + + // Logic copied from MinMaxObserver.cpp. + + // Filter Ops + if (tensor->is_constant()) + return; + + if (tensor->data_type() != ir::DataType::FLOAT32) + return; + + switch (op.opcode()) + { + // Operators with multiple outputs + case ir::OpCode::If: + case ir::OpCode::Split: + case ir::OpCode::SplitV: + case ir::OpCode::TopKV2: + case ir::OpCode::Unpack: + case ir::OpCode::While: + return; + // NOTE: Sin, Cos, Tanh's output is in [-1, 1] + // We may not need to dump those operators. + default:; // Do Nothing + } + + // Otherwise, dump! + assert(tensor->data_type() == ir::DataType::FLOAT32); + const auto data = reinterpret_cast<float *>(tensor->buffer()); + const auto num_elements = tensor->total_size() / sizeof(float); + + float max = std::numeric_limits<float>::lowest(); + float min = std::numeric_limits<float>::max(); + + bool all_nan = true; + for (size_t i = 0; i < num_elements; ++i) + { + const float number = data[i]; + if (std::isnan(number)) + continue; + + if (number == std::numeric_limits<float>::lowest()) + continue; + + all_nan = false; + + if (number > max) + max = number; + + if (number < min) + min = number; + } + + if (all_nan) + throw std::runtime_error("All values are NaN (Not a Number)"); + + _minmax_map.append({subg_idx, op_idx}, min, max); +} + +void MinMaxRecorder::handleSubgraphEnd(ir::SubgraphIndex) +{ + // It would be better to dump at the end of model execution rather than at the end of + // each subgraph, but that requires more changes.
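// For orientation, each dump below appends one new run group under /value; with the
// MinMaxDumper layout above, a recorded model ends up as, e.g. (indices illustrative):
//   /value/0/0/0/5  -> { min, max }   // run 0, model group "0", subgraph 0, operation #5
//   /value/1/0/0/5  -> { min, max }   // same operation, next recorded run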
+ _h5dumper.dump(_minmax_map); +} + +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/MinMaxRecorder.h b/runtime/onert/core/src/exec/MinMaxRecorder.h new file mode 100644 index 000000000..7a0817f5f --- /dev/null +++ b/runtime/onert/core/src/exec/MinMaxRecorder.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_MINMAX_RECORDER__ +#define __ONERT_EXEC_MINMAX_RECORDER__ + +#include "ExecutionObservers.h" +#include "ir/Index.h" +#include "exec/MinMaxMap.h" +#include "../dumper/h5/MinMaxDumper.h" + +#include <memory> + +namespace onert +{ +namespace exec +{ + +class MinMaxRecorder : public IExecutionObserver +{ +public: + MinMaxRecorder(const std::string &minmax_filepath, const ir::Graph &graph, + const backend::BackendContexts &backend_contexts); + void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) override + { + return; + } + void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex, + const backend::Backend *) override; + void handleSubgraphEnd(ir::SubgraphIndex) override; + +private: + const ir::Graph &_graph; + const backend::BackendContexts &_backend_contexts; + dumper::h5::MinMaxDumper _h5dumper; + SMMinMaxMap _minmax_map; +}; + +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_MINMAX_RECORDER__ diff --git a/runtime/onert/core/src/exec/ParallelScheduler.cc b/runtime/onert/core/src/exec/ParallelScheduler.cc index 456663f91..538945631 100644 --- a/runtime/onert/core/src/exec/ParallelScheduler.cc +++ b/runtime/onert/core/src/exec/ParallelScheduler.cc @@ -30,7 +30,7 @@ ParallelScheduler::ParallelScheduler(const BackendSet &backends) { assert(!backends.empty()); - for (auto backend : backends) + for (auto &&backend : backends) { _thread_pools[backend] = std::make_unique<ThreadPool>(); } diff --git a/runtime/onert/core/src/exec/train/TrainableExecutor.cc b/runtime/onert/core/src/exec/train/TrainableExecutor.cc new file mode 100644 index 000000000..9c7e70c29 --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableExecutor.cc @@ -0,0 +1,204 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
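TrainableExecutor, implemented in this file, splits one training step into a forward sweep over the scheduled code list and a reverse sweep that propagates gradients. A minimal sketch of that control flow, assuming the TrainableFnSequence interface added by this patch:

    void trainOneStep(std::vector<onert::exec::train::TrainableFnSequence *> &seqs,
                      uint32_t training_step)
    {
      for (auto *seq : seqs)                                 // forward, in scheduled order
        seq->forward(/*training=*/true);
      for (auto it = seqs.rbegin(); it != seqs.rend(); ++it) // backward, in reverse order
        (*it)->backward(training_step);                      // also runs gradient appliers
    }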
+ */ + +#include "TrainableExecutor.h" +#ifdef RUY_PROFILER +#include "ruy/profiler/instrumentation.h" +#endif + +#include <misc/polymorphic_downcast.h> + +namespace onert +{ +namespace exec +{ +namespace train +{ + +TrainableExecutor::TrainableExecutor( + std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + backend::train::TrainableBackendContexts &&backend_contexts, + const compiler::train::TensorRegistries &tensor_regs, + compiler::train::TrainableCodeMap &&code_map, const std::vector<ir::OperationIndex> &order, + const util::TracingCtx *tracing_ctx) + : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)}, + _trainable_graph{_lowered_graph->trainable_graph()}, _tensor_regs{std::move(tensor_regs)}, + _mutex(), _tracing_ctx(tracing_ctx) +{ + auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) { + assert(tensors.empty()); + for (auto &&ind : ind_seq) + { + backend::ITensor *tensor = tensor_regs.getITensor(ind); + assert(tensor != nullptr); + auto io_tensor = nnfw::misc::polymorphic_downcast<backend::builtin::IOTensor *>(tensor); + tensors.push_back(io_tensor); + } + }; + build_tensor_list(_trainable_graph.getInputs(), _input_tensors); + build_tensor_list(_trainable_graph.getOutputs(), _output_tensors); + + for (auto &&index : order) + { + auto &trainable_code = code_map.at(index); + _code.emplace_back(std::move(trainable_code)); + } +} + +void TrainableExecutor::execute(const std::vector<backend::IPortableTensor *> &, + const std::vector<backend::IPortableTensor *> &) +{ + throw std::runtime_error("TrainableExecutor does not support multiple subgraphs yet"); +} + +void TrainableExecutor::forward(const IODescription &desc, bool training) +{ + // For thread-safe, use mutex + // TODO: if all used backends on this executor are thread-safe, + // do not need to use mutex (otherwise, use mutex) + std::lock_guard<std::mutex> lock(_mutex); + + // TODO Update IO tensors if desc has dynamic input + // Set input(s) + assert(_input_tensors.size() == desc.inputs.size()); + for (uint32_t i = 0; i < _input_tensors.size(); ++i) + { + auto tensor = _input_tensors[i]; + + // TODO Check if (desc.inputs[i] == nullptr) + // TODO Better design for ITensor? 
(we need const_cast as ITensor is writable) + tensor->setUserTensor(static_cast<uint8_t *>(const_cast<void *>(desc.inputs[i]->buffer)), + desc.inputs[i]->size); + } + + if (!training) + { + // Set output(s) + assert(_output_tensors.size() == desc.outputs.size()); + for (uint32_t i = 0; i < _output_tensors.size(); ++i) + { + auto tensor = _output_tensors[i]; + + if (desc.outputs[i] == nullptr) + throw std::runtime_error{"Output " + std::to_string(i) + "'s buffer is not set."}; + tensor->setUserTensor(static_cast<uint8_t *>(desc.outputs[i]->buffer), desc.outputs[i]->size); + } + } + + forwardImpl(training); + + // TODO Update output(s) desc if desc has dynamic input +} + +void TrainableExecutor::forwardImpl(bool training) +{ + if (_tracing_ctx) + { + auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph()); + + _subject.notifySubgraphBegin(profiling_subg_index); + for (auto &&code : _code) + { + const auto backend = code.lower_info->backend(); +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); +#endif + _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend); + + auto &tn_seq = code.tn_seq; + tn_seq->forward(training); + + _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend); + } + _subject.notifySubgraphEnd(profiling_subg_index); + } + else + { + for (auto &&code : _code) + { +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); +#endif + auto &tn_seq = code.tn_seq; + tn_seq->forward(training); + } + } +} + +void TrainableExecutor::backward(const IODescription &, uint32_t training_step) +{ + // For thread-safe, use mutex + // TODO: if all used backends on this executor are thread-safe, + // do not need to use mutex (otherwise, use mutex) + std::lock_guard<std::mutex> lock(_mutex); + + backwardImpl(training_step); +} + +void TrainableExecutor::backwardImpl(uint32_t training_step) +{ + if (_tracing_ctx) + { + auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph()); + + _subject.notifySubgraphBegin(profiling_subg_index); + for (auto it = _code.rbegin(); it != _code.rend(); ++it) + { + const auto &code = *it; + const auto backend = code.lower_info->backend(); +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); +#endif + _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend); + + auto &tn_seq = code.tn_seq; + tn_seq->backward(training_step); + + _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend); + } + _subject.notifySubgraphEnd(profiling_subg_index); + } + else + { + for (auto it = _code.rbegin(); it != _code.rend(); ++it) + { + const auto &code = *it; +// TODO : Move ruy profiler into ExecutionObserver +#ifdef RUY_PROFILER + ruy::profiler::ScopeLabel label(code.op->name()); +#endif + auto &tn_seq = code.tn_seq; + tn_seq->backward(training_step); + } + } +} + +float TrainableExecutor::getLoss(const ir::IOIndex &pred_io_ind) const +{ + const auto &loss_ind = _trainable_graph.getLossIndex(pred_io_ind); + if (loss_ind.undefined()) + throw std::runtime_error{"Loss " + std::to_string(loss_ind.value()) + " is not defined."}; + backend::ITensor *tensor = _tensor_regs.getITensor(loss_ind); + auto loss_buf = reinterpret_cast<float *>(tensor->buffer()); + return *loss_buf; +} + +} // namespace train +} // namespace exec +} // namespace onert diff --git 
a/runtime/onert/core/src/exec/train/TrainableExecutor.h b/runtime/onert/core/src/exec/train/TrainableExecutor.h new file mode 100644 index 000000000..6b645305f --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableExecutor.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_ +#define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_ + +#include "exec/IExecutor.h" + +#include "../ExecutionObservee.h" +#include "../../compiler/train/TensorRegistries.h" + +#include "backend/train/TrainableBackendContext.h" +#include "compiler/train/TrainableCodeMap.h" +#include "compiler/train/LoweredTrainableGraph.h" +#include "ir/Index.h" +#include "util/TracingCtx.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ + +class TrainableExecutor : public IExecutor +{ +public: + /** + * @brief Construct a new TrainableExecutor object + * @param lowered_graph LoweredTrainableGraph object + * @param backend_contexts Backend contexts that are currently used + * @param tensor_regs Tensor registries that are currently used + * @param code_map @c ir::Operation and its code map + */ + TrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph, + backend::train::TrainableBackendContexts &&backend_contexts, + const compiler::train::TensorRegistries &tensor_regs, + compiler::train::TrainableCodeMap &&code_map, + const std::vector<ir::OperationIndex> &order, + const util::TracingCtx *tracing_ctx); + +public: + const ir::Graph &graph() const final { return _trainable_graph.graph(); } + + void execute(const IODescription &desc) override { forward(desc, false); }; + + void execute(const std::vector<backend::IPortableTensor *> &inputs, + const std::vector<backend::IPortableTensor *> &outputs) override; + + void forward(const IODescription &desc, bool training); + void backward(const IODescription &desc, uint32_t training_step); + + // Used only in Dataflow and Parallel Executors + void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final + { + _indexed_ranks = std::move(ranks); + }; + + void addObserver(std::unique_ptr<IExecutionObserver> ref) { _subject.add(std::move(ref)); }; + + const std::vector<backend::builtin::IOTensor *> &getInputTensors() const override + { + return _input_tensors; + } + + const std::vector<backend::builtin::IOTensor *> &getOutputTensors() const override + { + return _output_tensors; + } + + float getLoss(const ir::IOIndex &pred_io_ind) const; + + backend::train::TrainableBackendContexts &getBackendContexts() { return _backend_contexts; } + +private: + void forwardImpl(bool training); + void backwardImpl(uint32_t training_step); + +private: + std::vector<compiler::train::TrainableCodeAndInfo> _code; + ExecutionObservee _subject; + std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks; + std::unique_ptr<compiler::train::LoweredTrainableGraph> _lowered_graph; + backend::train::TrainableBackendContexts _backend_contexts; + const ir::train::TrainableGraph &_trainable_graph; + compiler::train::TensorRegistries _tensor_regs; + std::vector<backend::builtin::IOTensor *> _input_tensors; + std::vector<backend::builtin::IOTensor *> _output_tensors; + std::mutex _mutex; + const util::TracingCtx *_tracing_ctx; +}; + +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_ diff --git a/runtime/onert/core/src/exec/train/TrainableExecutors.cc b/runtime/onert/core/src/exec/train/TrainableExecutors.cc new file mode 100644 index 000000000..ba39bf0f0 --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableExecutors.cc @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TrainableExecutors.h" + +#include "../../backend/builtin/IOTensor.h" + +#include <misc/polymorphic_downcast.h> + +namespace onert +{ +namespace exec +{ +namespace train +{ + +void TrainableExecutors::emplace(const ir::ModelIndex &, const ir::SubgraphIndex &subg_index, + std::unique_ptr<IExecutor> exec) +{ + std::unique_ptr<TrainableExecutor> t_exec{ + nnfw::misc::polymorphic_downcast<TrainableExecutor *>(exec.release())}; + _executors.emplace(subg_index, std::move(t_exec)); +} + +TrainableExecutor *TrainableExecutors::at(const ir::ModelIndex &, + const ir::SubgraphIndex &subg_index) const +{ + return _executors.at(subg_index).get(); +} + +uint32_t TrainableExecutors::inputSize() const { return entryExecutor()->getInputTensors().size(); } + +uint32_t TrainableExecutors::outputSize() const +{ + return entryExecutor()->getOutputTensors().size(); +} + +const ir::OperandInfo &TrainableExecutors::inputInfo(const ir::IOIndex &index) const +{ + return entryExecutor()->getInputTensors().at(index.value())->orig_info(); +} + +const ir::OperandInfo &TrainableExecutors::outputInfo(const ir::IOIndex &index) const +{ + return entryExecutor()->getOutputTensors().at(index.value())->orig_info(); +} + +void TrainableExecutors::execute(const IODescription &desc) +{ + if (_executors.size() > 1) + throw std::runtime_error("TrainableExecutors does not support multiple executors yet"); + entryExecutor()->forward(desc, false); + + // TODO Support multiple executors +} + +void TrainableExecutors::train(const IODescription &desc, uint32_t training_step) +{ + if (_executors.size() > 1) + throw std::runtime_error("TrainableExecutors does not support multiple executors yet"); + entryExecutor()->forward(desc, true); + entryExecutor()->backward(desc, training_step); + + // TODO Support multiple executors +} + +float TrainableExecutors::getLoss(const ir::IOIndex &index) const +{ + if (_executors.size() > 1) + throw std::runtime_error("TrainableExecutors does not support multiple executors yet"); + return entryExecutor()->getLoss(index); +} + +} // namespace train +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/train/TrainableExecutors.h
b/runtime/onert/core/src/exec/train/TrainableExecutors.h new file mode 100644 index 000000000..db6d198b1 --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableExecutors.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__ +#define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__ + +#include "TrainableExecutor.h" +#include "exec/IExecutors.h" +#include "ir/NNPkg.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ + +/** + * @brief Class to gather executor set for trainable model NN package + */ +class TrainableExecutors : public IExecutors +{ +public: + /** + * @brief Construct a new TrainableExecutors object + */ + TrainableExecutors(void) = default; + TrainableExecutors(const TrainableExecutors &) = delete; + TrainableExecutors(TrainableExecutors &&) = default; + + /** + * @brief Destroy the TrainableExecutors object + */ + ~TrainableExecutors() = default; + +public: + TrainableExecutors &operator=(const TrainableExecutors &) = delete; + TrainableExecutors &operator=(TrainableExecutors &&) = default; + +public: + void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index, + std::unique_ptr<IExecutor> exec) override; + + TrainableExecutor *at(const ir::ModelIndex &model_index, + const ir::SubgraphIndex &subg_index) const override; + + TrainableExecutor *entryExecutor() const { return at(ir::ModelIndex{0}, ir::SubgraphIndex{0}); } + + uint32_t inputSize() const override; + + uint32_t outputSize() const override; + + const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override; + + const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override; + + void execute(const IODescription &desc) override; + + /** + * @brief Train + * + * @param desc IO information + * @param training_step The number of iterations of an training process. + * In other words, the number of gradient update. + */ + void train(const IODescription &desc, uint32_t training_step); + + float getLoss(const ir::IOIndex &index) const; + +private: + // TODO Append model index to ModelIndex + std::unordered_map<ir::SubgraphIndex, std::unique_ptr<TrainableExecutor>> _executors; +}; + +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__ diff --git a/runtime/onert/core/src/exec/train/TrainableFnSequence.cc b/runtime/onert/core/src/exec/train/TrainableFnSequence.cc new file mode 100644 index 000000000..084b3d708 --- /dev/null +++ b/runtime/onert/core/src/exec/train/TrainableFnSequence.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
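Note: the following is an illustrative, self-contained sketch of the call pattern TrainableExecutors::train() implements above: one forward pass with training enabled, then one backward pass per step. The types are simplified stand-ins, not onert's real interfaces.

#include <cstdint>
#include <iostream>

struct FakeExecutor // stand-in for TrainableExecutor
{
  void forward(bool training) { std::cout << "forward(training=" << training << ")\n"; }
  void backward(uint32_t step) { std::cout << "backward(step=" << step << ")\n"; }
  float loss() const { return 0.42f; } // placeholder value
};

int main()
{
  FakeExecutor exec;
  for (uint32_t step = 0; step < 3; ++step)
  {
    exec.forward(true);  // training == true: keep data needed by the backward pass
    exec.backward(step); // the optimizer receives the step count (e.g. for decay)
  }
  std::cout << "loss: " << exec.loss() << "\n";
}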
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exec/train/TrainableFnSequence.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ + +void TrainableFnSequence::forward(bool training) +{ + for (const auto &function : _functions) + { + function->forward(training); + } +} + +void TrainableFnSequence::backward(uint32_t training_step) +{ + for (auto it = _functions.rbegin(); it != _functions.rend(); ++it) + { + (*it)->backward(); + } + + for (const auto &applier : _appliers) + { + applier->applyGradient(training_step); + } +} + +void TrainableFnSequence::append(std::unique_ptr<ITrainableFunction> &&function) +{ + _functions.push_back(std::move(function)); +} + +void TrainableFnSequence::append(std::unique_ptr<IGradientApplier> &&applier) +{ + _appliers.push_back(std::move(applier)); +} + +void TrainableFnSequence::iterate(const std::function<void(ITrainableFunction &)> &fn) +{ + for (const auto &func : _functions) + { + fn(*func); + } +} + +} // namespace train +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/train/optimizer/OptimizerCode.cc b/runtime/onert/core/src/exec/train/optimizer/OptimizerCode.cc new file mode 100644 index 000000000..72b581bf6 --- /dev/null +++ b/runtime/onert/core/src/exec/train/optimizer/OptimizerCode.cc @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exec/train/optimizer/OptimizerCode.h" + +#include <unordered_map> + +namespace onert +{ +namespace exec +{ +namespace train +{ +namespace optimizer +{ + +std::string toString(OptimizerCode code) +{ + static const std::unordered_map<OptimizerCode, const char *> map{ + {OptimizerCode::Invalid, "Invalid"}, + {OptimizerCode::SGD, "SGD"}, + {OptimizerCode::Adam, "Adam"}}; + return map.at(code); +} + +} // namespace optimizer +} // namespace train +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/exec/train/optimizer/OptimizerHelpers.h b/runtime/onert/core/src/exec/train/optimizer/OptimizerHelpers.h new file mode 100644 index 000000000..66a08b50f --- /dev/null +++ b/runtime/onert/core/src/exec/train/optimizer/OptimizerHelpers.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
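Note: a minimal sketch of the traversal order TrainableFnSequence::forward()/backward() use above. Forward visits functions front-to-back, backward walks the same list in reverse (as in reverse-mode autodiff), and gradient appliers only run after the whole backward sweep. Simplified stand-in code, not onert's API.

#include <functional>
#include <iostream>
#include <vector>

int main()
{
  // Three "operations"; forward must run 1,2,3 and backward 3,2,1.
  std::vector<std::function<void()>> fwd, bwd;
  for (int i = 1; i <= 3; ++i)
  {
    fwd.push_back([i] { std::cout << "forward op " << i << "\n"; });
    bwd.push_back([i] { std::cout << "backward op " << i << "\n"; });
  }

  for (const auto &f : fwd)
    f(); // first op to last op

  for (auto it = bwd.rbegin(); it != bwd.rend(); ++it)
    (*it)(); // last op back to first op

  // weight updates (gradient appliers) would run here, once all gradients exist
}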
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_HELPERS_H__ +#define __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_HELPERS_H__ + +#include "backend/IPortableTensor.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ +namespace optimizer +{ + +template <typename T, typename L> +void elementwise(const ir::Shape &shape, const backend::ITensor &src, backend::ITensor &dst, + const L &f) +{ + ShapeLoop(shape, [&](const ir::Coordinates &coords) { + const T src_val = *reinterpret_cast<const T *>(src.buffer() + src.calcOffset(coords)); + T *dst_data = reinterpret_cast<T *>(dst.buffer() + dst.calcOffset(coords)); + *dst_data = f(src_val, *dst_data); + }); +} + +} // namespace optimizer +} // namespace train +} // namespace exec +} // namespace onert + +#endif // __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_HELPERS_H__ diff --git a/runtime/onert/core/src/exec/train/optimizer/SGD.cc b/runtime/onert/core/src/exec/train/optimizer/SGD.cc new file mode 100644 index 000000000..abfbc1b4b --- /dev/null +++ b/runtime/onert/core/src/exec/train/optimizer/SGD.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
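Note: an illustrative sketch of the per-element update SGD::applyGradient() performs in SGD.cc below, via the elementwise() helper above: w <- w - lr * g. Plain float arrays stand in for onert's tensor types.

#include <cstddef>
#include <iostream>

void sgd_step(float *weights, const float *grads, std::size_t n, float lr)
{
  for (std::size_t i = 0; i < n; ++i)
    weights[i] -= lr * grads[i]; // same formula as the lambda passed to elementwise()
}

int main()
{
  float w[3] = {1.0f, 2.0f, 3.0f};
  const float g[3] = {0.1f, 0.1f, 0.1f};
  sgd_step(w, g, 3, 0.01f);
  std::cout << w[0] << " " << w[1] << " " << w[2] << "\n"; // 0.999 1.999 2.999
}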
+ */ + +#include <exec/train/optimizer/SGD.h> + +#include "OptimizerHelpers.h" + +namespace onert +{ +namespace exec +{ +namespace train +{ +namespace optimizer +{ + +double SGD::getLearningRate(uint32_t) const +{ + // TODO Use iteration, momentum, and nesterov + return _learning_rate; +} + +void SGD::applyGradient(const UpdateFactors &factors) const +{ + const auto lr = getLearningRate(std::get<size_t>(factors)); + const auto &grad_tensor = std::get<const backend::IPortableTensor &>(factors); + auto &trainable_tensor = std::get<backend::train::ITrainableTensor &>(factors); + assert(trainable_tensor.data_type() == grad_tensor.data_type()); + + const auto shape = trainable_tensor.getShape(); + const auto &grad_shape = grad_tensor.get_info().shape(); + + // TODO Support for different shapes + if (shape != grad_shape) + { + throw std::runtime_error("SGD: Invalid gradient tensor"); + } + + switch (grad_tensor.data_type()) + { + case ir::DataType::FLOAT32: + elementwise<float>(shape, grad_tensor, trainable_tensor, + [&](float src, float dst) -> float { return dst - src * lr; }); + break; + default: + throw std::runtime_error("SGD: Not supported data type"); + } +} + +} // namespace optimizer +} // namespace train +} // namespace exec +} // namespace onert diff --git a/runtime/onert/core/src/ir/Graph.cc b/runtime/onert/core/src/ir/Graph.cc index 28cf4137d..ef0f988fa 100644 --- a/runtime/onert/core/src/ir/Graph.cc +++ b/runtime/onert/core/src/ir/Graph.cc @@ -42,33 +42,33 @@ OperandIndex Graph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&op return _operands.push(std::move(operand), index); } -bool Graph::checkOperandsForOperation(const Operation &operation) +bool Graph::checkOperandsForOperation(const IOperation &operation) { auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; - for (auto input : inputs) + for (auto &&input : inputs) if (!operands().exist(input)) return false; - for (auto input : outputs) + for (auto &&input : outputs) if (!operands().exist(input)) return false; return true; } -void Graph::linkOperandToOperation(OperationIndex index, const Operation &operation) +void Graph::linkOperandToOperation(OperationIndex index, const IOperation &operation) { auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; - for (auto input : inputs) + for (auto &&input : inputs) operands().at(input).insertUse(index); - for (auto output : outputs) + for (auto &&output : outputs) operands().at(output).setDef(index); } -OperationIndex Graph::addOperation(std::unique_ptr<Operation> &&operation) +OperationIndex Graph::addOperation(std::unique_ptr<IOperation> &&operation) { - const Operation &op_ref = *operation; + const IOperation &op_ref = *operation; if (!checkOperandsForOperation(op_ref)) return OperationIndex{}; auto ind = _operations.push(std::move(operation)); @@ -77,9 +77,9 @@ OperationIndex Graph::addOperation(std::unique_ptr<Operation> &&operation) return ind; } -OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<Operation> &&operation) +OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation) { - const Operation &op_ref = *operation; + const IOperation &op_ref = *operation; if (!checkOperandsForOperation(op_ref)) return OperationIndex{}; auto ind_gen = 
_operations.push(std::move(operation), index); @@ -91,12 +91,35 @@ OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<Operati return index; } +OperationIndex Graph::replaceOperation(OperationIndex index, + std::unique_ptr<IOperation> &&operation) +{ + const IOperation &op_ref = *operation; + if (!checkOperandsForOperation(op_ref) || !_operations.exist(index)) + return OperationIndex{}; + + // Check the new operation has the same inputs/outputs as the existing operation + const auto &old_op = _operations.at(index); + if (!(old_op.getInputs() == op_ref.getInputs() && old_op.getOutputs() == op_ref.getOutputs())) + { + return OperationIndex{}; + } + + return _operations.set(index, std::move(operation)); +} + void Graph::setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data) { assert(_operands.exist(ind)); _operands.at(ind).data(std::move(data)); } +void Graph::changeShape(const OperandIndex &ind, const ir::Shape &new_shape) +{ + assert(_operands.exist(ind)); + _operands.at(ind).info().shape(new_shape); +} + void Graph::addInput(const OperandIndex &ind, const std::string &name) { if (!name.empty()) @@ -123,7 +146,7 @@ IOIndex Graph::getOutputIndex(const std::string &name) const return (itr == _name_to_output.end()) ? IOIndex{} : itr->second; } -void Graph::verify(void) +void Graph::verify(void) const { // Call graph verifications for the MODEL phase { @@ -144,14 +167,14 @@ void Graph::verify(void) void Graph::initializeUseDef() { - operations().iterate([&](const OperationIndex &index, const Operation &node) -> void { + operations().iterate([&](const OperationIndex &index, const IOperation &node) -> void { auto outputs = node.getOutputs(); - for (auto output : outputs | ir::Remove::UNDEFINED) + for (auto &&output : outputs | ir::Remove::UNDEFINED) { operands().at(output).setDef(index); } - for (auto input : node.getInputs() | ir::Remove::UNDEFINED) + for (auto &&input : node.getInputs() | ir::Remove::UNDEFINED) { operands().at(input).insertUse(index); } @@ -163,15 +186,15 @@ std::vector<ir::OperationIndex> Graph::topolSortOperations() const std::vector<ir::OperationIndex> ret; util::Set<ir::OperationIndex> unvisited; operations().iterate( - [&](const ir::OperationIndex &index, const ir::Operation &) { unvisited.add(index); }); + [&](const ir::OperationIndex &index, const ir::IOperation &) { unvisited.add(index); }); - std::function<void(const ir::OperationIndex &, const ir::Operation &)> dfs = - [&](const ir::OperationIndex &index, const ir::Operation &op) -> void { + std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs = + [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void { if (!unvisited.contains(index)) return; unvisited.remove(index); - for (const auto output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { const auto &operand = operands().at(output); for (const auto &use : operand.getUses()) diff --git a/runtime/onert/core/src/ir/LayoutSet.cc b/runtime/onert/core/src/ir/LayoutSet.cc index bd3f438ad..732460aa2 100644 --- a/runtime/onert/core/src/ir/LayoutSet.cc +++ b/runtime/onert/core/src/ir/LayoutSet.cc @@ -23,7 +23,7 @@ namespace ir LayoutSet::LayoutSet(std::initializer_list<Layout> layouts) { - for (auto layout : layouts) + for (auto &&layout : layouts) { _set.insert(layout); } @@ -32,7 +32,7 @@ LayoutSet::LayoutSet(std::initializer_list<Layout> layouts) LayoutSet LayoutSet::operator|(const 
LayoutSet &other) const { auto ret = *this; - for (auto layout : other) + for (auto &&layout : other) { ret.add(layout); } @@ -42,7 +42,7 @@ LayoutSet LayoutSet::operator|(const LayoutSet &other) const LayoutSet LayoutSet::operator&(const LayoutSet &other) const { LayoutSet ret; - for (auto layout : other) + for (auto &&layout : other) { if (contains(layout)) { @@ -55,7 +55,7 @@ LayoutSet LayoutSet::operator&(const LayoutSet &other) const LayoutSet LayoutSet::operator-(const LayoutSet &other) const { auto ret = *this; - for (auto layout : other) + for (auto &&layout : other) { ret.remove(layout); } diff --git a/runtime/onert/core/src/ir/LayoutSet.h b/runtime/onert/core/src/ir/LayoutSet.h index 6ce4e38c6..be077f2f0 100644 --- a/runtime/onert/core/src/ir/LayoutSet.h +++ b/runtime/onert/core/src/ir/LayoutSet.h @@ -17,6 +17,7 @@ #ifndef __ONERT_IR_LAYOUT_SET_H__ #define __ONERT_IR_LAYOUT_SET_H__ +#include <cstdint> #include <initializer_list> #include <unordered_set> diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.cc b/runtime/onert/core/src/ir/OperandIndexSequence.cc index b092f5cee..a15b6d0d6 100644 --- a/runtime/onert/core/src/ir/OperandIndexSequence.cc +++ b/runtime/onert/core/src/ir/OperandIndexSequence.cc @@ -31,7 +31,7 @@ OperandIndexSequence::OperandIndexSequence(std::initializer_list<OperandIndex> l OperandIndexSequence::OperandIndexSequence(std::initializer_list<int32_t> list) { - for (auto val : list) + for (auto &&val : list) { _vec.emplace_back(static_cast<uint32_t>(val)); } @@ -39,7 +39,7 @@ OperandIndexSequence::OperandIndexSequence(std::initializer_list<int32_t> list) OperandIndexSequence::OperandIndexSequence(std::initializer_list<uint32_t> list) { - for (auto val : list) + for (auto &&val : list) { _vec.emplace_back(val); } @@ -55,6 +55,11 @@ void OperandIndexSequence::replace(const OperandIndex &from, const OperandIndex std::replace(_vec.begin(), _vec.end(), from, to); } +bool OperandIndexSequence::operator==(const OperandIndexSequence &other) const +{ + return _vec == other._vec; +} + OperandIndexSequence OperandIndexSequence::operator+(const OperandIndexSequence &other) const { OperandIndexSequence ret = *this; diff --git a/runtime/onert/core/src/ir/OperationCloner.cc b/runtime/onert/core/src/ir/OperationCloner.cc index c06315814..64e1cc807 100644 --- a/runtime/onert/core/src/ir/OperationCloner.cc +++ b/runtime/onert/core/src/ir/OperationCloner.cc @@ -57,7 +57,7 @@ std::unique_ptr<Operation> OperationCloner::releaseClone() } // namespace -std::unique_ptr<Operation> clone(const Operation &operation) +std::unique_ptr<Operation> clone(const IOperation &operation) { OperationCloner cloner; operation.accept(cloner); diff --git a/runtime/onert/core/src/ir/OperationCloner.h b/runtime/onert/core/src/ir/OperationCloner.h index 6424549e9..49297a05c 100644 --- a/runtime/onert/core/src/ir/OperationCloner.h +++ b/runtime/onert/core/src/ir/OperationCloner.h @@ -26,7 +26,7 @@ namespace onert namespace ir { -std::unique_ptr<Operation> clone(const Operation &operation); +std::unique_ptr<Operation> clone(const IOperation &operation); } // namespace ir } // namespace onert diff --git a/runtime/onert/core/src/ir/OperationDumper.cc b/runtime/onert/core/src/ir/OperationDumper.cc index 0b596ff13..5e6d700f3 100644 --- a/runtime/onert/core/src/ir/OperationDumper.cc +++ b/runtime/onert/core/src/ir/OperationDumper.cc @@ -202,6 +202,14 @@ void OperationDumper::visit(const L2Normalization &node) { dumpOpGeneric(node); void OperationDumper::visit(const LocalResponseNormalization &node) { 
dumpOpGeneric(node); } +void OperationDumper::visit(const Loss &node) +{ + VERBOSE(LIR) << "* " << node.name() << std::endl; + VERBOSE(LIR) << " - Inputs : Prediction(" << node.getInputs().at(Loss::Input::Y_PRED) << ") True(" + << node.getInputs().at(Loss::Input::Y_TRUE) << ")" << std::endl; + VERBOSE(LIR) << " - Outputs : Output(" << node.getOutputs().at(0) << ")" << std::endl; +} + void OperationDumper::visit(const LSTM &node) { VERBOSE(LIR) << "* " << node.name() << std::endl; diff --git a/runtime/onert/core/src/ir/OperationDumper.h b/runtime/onert/core/src/ir/OperationDumper.h index fe18307b9..99bf869d5 100644 --- a/runtime/onert/core/src/ir/OperationDumper.h +++ b/runtime/onert/core/src/ir/OperationDumper.h @@ -55,6 +55,7 @@ public: void visit(const operation::InstanceNorm &) override; void visit(const operation::L2Normalization &) override; void visit(const operation::LocalResponseNormalization &) override; + void visit(const operation::Loss &node) override; void visit(const operation::LSTM &) override; void visit(const operation::Pack &) override; void visit(const operation::Pad &) override; diff --git a/runtime/onert/core/src/ir/OperationValidator.cc b/runtime/onert/core/src/ir/OperationValidator.cc index 094dbc0d5..cf7323d77 100644 --- a/runtime/onert/core/src/ir/OperationValidator.cc +++ b/runtime/onert/core/src/ir/OperationValidator.cc @@ -38,7 +38,7 @@ OperationValidator::OperationValidator(const Graph &graph) void OperationValidator::operator()() { - _operations.iterate([&](const OperationIndex &, const Operation &node) { node.accept(*this); }); + _operations.iterate([&](const OperationIndex &, const IOperation &node) { node.accept(*this); }); } DataType OperationValidator::operandType(const OperandIndex &idx) @@ -75,7 +75,7 @@ bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &ty bool OperationValidator::isValidType(const OperandIndex &idx, std::initializer_list<DataType> valid_types) { - for (auto type_to_check : valid_types) + for (auto &&type_to_check : valid_types) { if (isValidType(idx, type_to_check)) { @@ -163,7 +163,7 @@ void OperationValidator::visit(const operation::Concat &node) { const auto output_index{node.getOutputs().at(0)}; - for (auto input_index : node.getInputs()) + for (auto &&input_index : node.getInputs()) { OP_REQUIRES(isSameType(input_index, output_index)); diff --git a/runtime/onert/core/src/ir/Operations.cc b/runtime/onert/core/src/ir/Operations.cc index e7e0c88cf..1b4691f58 100644 --- a/runtime/onert/core/src/ir/Operations.cc +++ b/runtime/onert/core/src/ir/Operations.cc @@ -26,7 +26,7 @@ namespace ir Operations::Operations(const Operations &obj) { obj.iterate( - [&](const OperationIndex &index, const Operation &op) { _objects.emplace(index, clone(op)); }); + [&](const OperationIndex &index, const IOperation &op) { _objects.emplace(index, clone(op)); }); _next_index = obj._next_index; } diff --git a/runtime/onert/core/src/ir/operation/Loss.cc b/runtime/onert/core/src/ir/operation/Loss.cc new file mode 100644 index 000000000..fa3520b2c --- /dev/null +++ b/runtime/onert/core/src/ir/operation/Loss.cc @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
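Note: for orientation, this is what the two Loss::Type values handled below compute mathematically. These are the textbook formulas over std::vector<float>, not onert's kernel code.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

float mean_squared_error(const std::vector<float> &y_pred, const std::vector<float> &y_true)
{
  float sum = 0.0f;
  for (std::size_t i = 0; i < y_pred.size(); ++i)
  {
    const float d = y_pred[i] - y_true[i];
    sum += d * d;
  }
  return sum / y_pred.size();
}

float categorical_crossentropy(const std::vector<float> &y_pred, const std::vector<float> &y_true)
{
  float sum = 0.0f;
  for (std::size_t i = 0; i < y_pred.size(); ++i)
    sum -= y_true[i] * std::log(y_pred[i] + 1e-7f); // epsilon guards log(0)
  return sum;
}

int main()
{
  std::cout << mean_squared_error({0.9f, 0.1f}, {1.0f, 0.0f}) << "\n";
  std::cout << categorical_crossentropy({0.9f, 0.1f}, {1.0f, 0.0f}) << "\n";
}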
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/operation/Loss.h"
+#include "ir/OperationVisitor.h"
+
+#include <unordered_map>
+
+namespace onert
+{
+namespace ir
+{
+namespace operation
+{
+
+void Loss::accept(OperationVisitor &v) const { v.visit(*this); }
+
+Loss::Loss(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
+           const Param &param)
+  : Operation{OperandConstraint::createAtLeast(2u), inputs, outputs}, _param{param}
+{
+  if (param.op_type == Type::CATEGORICAL_CROSSENTROPY)
+  {
+    assert(inputs.size() == 2 && "CategoricalCrossentropy Loss has 2 inputs");
+  }
+}
+
+std::string Loss::name() const
+{
+  using LossType = onert::ir::operation::Loss::Type;
+  static const std::unordered_map<Type, std::string> name_map{
+    {LossType::MEAN_SQUARED_ERROR, "MeanSquaredError Loss"},
+    {LossType::CATEGORICAL_CROSSENTROPY, "CategoricalCrossentropy Loss"}};
+  return name_map.at(_param.op_type);
+}
+
+} // namespace operation
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.cc b/runtime/onert/core/src/ir/train/TrainableGraph.cc
new file mode 100644
index 000000000..781f04956
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/TrainableGraph.cc
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/TrainableGraph.h"
+#include "util/Utils.h"
+
+#include <algorithm>
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+TrainableGraph::TrainableGraph() : _graph{} {}
+
+TrainableGraph::TrainableGraph(const TrainableGraph &tgraph)
+  : _graph{tgraph._graph}, _derivatives{tgraph._derivatives}, _losses{tgraph._losses}
+{
+  tgraph.operations().iterate(
+    [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) {
+      replaceOperation(index, dynamic_cast<const ITrainableOperation &>(op).clone());
+    });
+}
+
+TrainableGraph::TrainableGraph(const Graph &graph) : _graph{graph} {}
+
+OperandIndex TrainableGraph::addOperand(const Shape &shape, const TypeInfo &type)
+{
+  return _graph.addOperand(shape, type);
+}
+
+OperandIndex TrainableGraph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand)
+{
+  return _graph.addOperand(index, std::move(operand));
+}
+
+OperationIndex TrainableGraph::addOperation(std::unique_ptr<ITrainableOperation> &&operation)
+{
+  return _graph.addOperation(std::move(operation));
+}
+
+OperationIndex TrainableGraph::replaceOperation(OperationIndex index,
+                                                std::unique_ptr<ITrainableOperation> &&operation)
+{
+  return _graph.replaceOperation(index, std::move(operation));
+}
+
+OperandIndex TrainableGraph::addDerivative(OperandIndex index,
+                                           std::unique_ptr<Operand> &&derivative)
+{
+  return _derivatives.push(std::move(derivative), index);
+}
+
+IOIndex TrainableGraph::getInputIndex(const std::string &name) const
+{
+  return _graph.getInputIndex(name);
+}
+
+IOIndex TrainableGraph::getOutputIndex(const std::string &name) const
+{
+  return _graph.getOutputIndex(name);
+}
+
+void TrainableGraph::changeShape(const OperandIndex &index, const ir::Shape &new_shape)
+{
+  _graph.changeShape(index, new_shape);
+}
+
+void TrainableGraph::changeDerivativeShape(const OperandIndex &index, const ir::Shape &new_shape)
+{
+  assert(_derivatives.exist(index));
+  _derivatives.at(index).info().shape(new_shape);
+}
+
+void TrainableGraph::addInput(const OperandIndex &ind, const std::string &name)
+{
+  _graph.addInput(ind, name);
+}
+
+void TrainableGraph::addOutput(const OperandIndex &ind, const std::string &name)
+{
+  _graph.addOutput(ind, name);
+}
+
+void TrainableGraph::verify(void) const
+{
+  _graph.verify();
+
+  operations().iterate([](const onert::ir::OperationIndex &, const onert::ir::IOperation &op) {
+    try
+    {
+      UNUSED_RELEASE(dynamic_cast<const onert::ir::train::ITrainableOperation &>(op));
+    }
+    catch (const std::bad_cast &)
+    {
+      throw std::runtime_error("TrainableGraph: " + op.name() + " is not a trainable operation");
+    }
+  });
+}
+
+void TrainableGraph::removeOperand(const OperandIndex &ind) { _graph.removeOperand(ind); }
+
+void TrainableGraph::setLayout(Layout layout) { _graph.setLayout(layout); }
+
+const ITrainableOperation &TrainableGraph::operation(OperationIndex index) const
+{
+  // NOTE Virtual inherited objects cannot be static_casted.
+  return dynamic_cast<const ITrainableOperation &>(_graph.operations().at(index));
+}
+
+std::vector<ir::OperationIndex> TrainableGraph::topolSortOperations() const
+{
+  return _graph.topolSortOperations();
+}
+
+void TrainableGraph::addLoss(const OperandIndex &loss_ind, const IOIndex &pred_ioind)
+{
+  _losses.emplace(pred_ioind, loss_ind);
+}
+
+OperandIndex TrainableGraph::getLossIndex(const IOIndex &pred_ioind) const
+{
+  auto itr = _losses.find(pred_ioind);
+  return (itr == _losses.end()) ? OperandIndex{} : itr->second;
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Conv2D.cc b/runtime/onert/core/src/ir/train/operation/Conv2D.cc
new file mode 100644
index 000000000..923861ae3
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Conv2D.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Conv2D.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Conv2D::clone() const
+{
+  return std::make_unique<Conv2D>(*this);
+}
+
+void Conv2D::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Conv2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Conv2D::Conv2D(const OperationType &operation)
+  : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+  // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc b/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc
new file mode 100644
index 000000000..1dae3f674
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
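Note: a simplified sketch of the wrapper pattern the ir::train::operation classes above and below follow. A trainable node reuses the inference node's inputs/outputs/params and adds clone() plus acceptance of a second, training-specific visitor. Stand-in types, not the real onert hierarchy.

#include <iostream>
#include <memory>

struct TrainConv2D;

struct InferenceVisitor
{
  virtual ~InferenceVisitor() = default;
  virtual void visit(const TrainConv2D &) = 0;
};

struct TrainingVisitor
{
  virtual ~TrainingVisitor() = default;
  virtual void visit(const TrainConv2D &) = 0;
};

struct Conv2D // inference-side node
{
  int stride = 1; // stand-in for the real Param struct
};

struct TrainConv2D : Conv2D // trainable wrapper
{
  explicit TrainConv2D(const Conv2D &op) : Conv2D{op} {}
  std::unique_ptr<TrainConv2D> clone() const { return std::make_unique<TrainConv2D>(*this); }
  void accept(InferenceVisitor &v) const { v.visit(*this); }
  void accept(TrainingVisitor &v) const { v.visit(*this); }
};

struct Dumper : InferenceVisitor
{
  void visit(const TrainConv2D &op) override { std::cout << "Conv2D stride=" << op.stride << "\n"; }
};

int main()
{
  TrainConv2D op{Conv2D{}};
  auto copy = op.clone(); // deep copy preserving the trainable type
  Dumper d;
  copy->accept(d);
}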
+ */ + +#include "ir/train/operation/ElementwiseActivation.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> ElementwiseActivation::clone() const +{ + return std::make_unique<ElementwiseActivation>(*this); +} + +void ElementwiseActivation::accept(OperationVisitor &v) const { v.visit(*this); } + +void ElementwiseActivation::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +ElementwiseActivation::ElementwiseActivation(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/FullyConnected.cc b/runtime/onert/core/src/ir/train/operation/FullyConnected.cc new file mode 100644 index 000000000..a26f7c489 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/FullyConnected.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/train/operation/FullyConnected.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> FullyConnected::clone() const +{ + return std::make_unique<FullyConnected>(*this); +} + +void FullyConnected::accept(OperationVisitor &v) const { v.visit(*this); } + +void FullyConnected::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +FullyConnected::FullyConnected(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Loss.cc b/runtime/onert/core/src/ir/train/operation/Loss.cc new file mode 100644 index 000000000..abd79929b --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Loss.cc @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
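Note: a small sketch of why TrainableGraph::operation() above must use dynamic_cast. When the trainable interface inherits its operation base virtually, a static downcast from the virtual base is ill-formed, so the cast has to be resolved at runtime. Stand-in types.

#include <memory>

struct IOperation
{
  virtual ~IOperation() = default;
};

struct ITrainableOperation : virtual IOperation
{
};

struct TrainConv : ITrainableOperation
{
};

int main()
{
  std::unique_ptr<IOperation> op = std::make_unique<TrainConv>();
  // static_cast<ITrainableOperation *>(op.get()); // ill-formed: IOperation is a virtual base
  auto *trainable = dynamic_cast<ITrainableOperation *>(op.get()); // resolved at runtime
  return trainable != nullptr ? 0 : 1;
}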
+ */ + +#include "ir/train/operation/Loss.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +#include <misc/polymorphic_downcast.h> + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Loss::clone() const { return std::make_unique<Loss>(*this); } + +void Loss::accept(OperationVisitor &v) const { v.visit(*this); } + +void Loss::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Loss::Loss(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Permute.cc b/runtime/onert/core/src/ir/train/operation/Permute.cc new file mode 100644 index 000000000..adc23aa49 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Permute.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/train/operation/Permute.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Permute::clone() const +{ + return std::make_unique<Permute>(*this); +} + +void Permute::accept(OperationVisitor &v) const { v.visit(*this); } + +void Permute::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Permute::Permute(const OperationType &operation) + : OperationType{operation.getInputs().at(0), operation.getOutputs().at(0), + operation.getPermuteType()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Pool2D.cc b/runtime/onert/core/src/ir/train/operation/Pool2D.cc new file mode 100644 index 000000000..021574f19 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Pool2D.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Pool2D.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Pool2D::clone() const +{ + return std::make_unique<Pool2D>(*this); +} + +void Pool2D::accept(OperationVisitor &v) const { v.visit(*this); } + +void Pool2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Pool2D::Pool2D(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Reshape.cc b/runtime/onert/core/src/ir/train/operation/Reshape.cc new file mode 100644 index 000000000..c76158607 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Reshape.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/train/operation/Reshape.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Reshape::clone() const +{ + return std::make_unique<Reshape>(*this); +} + +void Reshape::accept(OperationVisitor &v) const { v.visit(*this); } + +void Reshape::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Reshape::Reshape(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/train/operation/Softmax.cc b/runtime/onert/core/src/ir/train/operation/Softmax.cc new file mode 100644 index 000000000..dbd403879 --- /dev/null +++ b/runtime/onert/core/src/ir/train/operation/Softmax.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/train/operation/Softmax.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +std::unique_ptr<ITrainableOperation> Softmax::clone() const +{ + return std::make_unique<Softmax>(*this); +} + +void Softmax::accept(OperationVisitor &v) const { v.visit(*this); } + +void Softmax::accept(TrainableOperationVisitor &v) const { v.visit(*this); } + +Softmax::Softmax(const OperationType &operation) + : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()} +{ + // DO NOTHING +} + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/verifier/Verifier.cc b/runtime/onert/core/src/ir/verifier/Verifier.cc index 25a82d5a2..6260d29ff 100644 --- a/runtime/onert/core/src/ir/verifier/Verifier.cc +++ b/runtime/onert/core/src/ir/verifier/Verifier.cc @@ -39,11 +39,11 @@ bool DAGChecker::verify(const Graph &graph) const noexcept OperationIndexMap<bool> visited; operations.iterate( - [&](const OperationIndex &index, const Operation &) { visited[index] = false; }); + [&](const OperationIndex &index, const IOperation &) { visited[index] = false; }); OperationIndexMap<bool> on_stack = visited; // Copy from visited - std::function<void(const OperationIndex &index, const Operation &)> dfs_recursive = - [&](const OperationIndex &index, const Operation &node) -> void { + std::function<void(const OperationIndex &index, const IOperation &)> dfs_recursive = + [&](const OperationIndex &index, const IOperation &node) -> void { if (on_stack[index]) cyclic = true; if (visited[index]) @@ -51,7 +51,7 @@ bool DAGChecker::verify(const Graph &graph) const noexcept visited[index] = true; on_stack[index] = true; - for (auto output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED) + for (auto &&output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED) { const auto &operand = graph.operands().at(output); for (const auto &use : operand.getUses()) @@ -76,8 +76,8 @@ bool EdgeChecker::verify(const Graph &graph) const noexcept { auto &operations = graph.operations(); uint32_t errors = 0; - operations.iterate([&](const OperationIndex &index, const Operation &node) { - for (auto operand_index : node.getInputs() | ir::Remove::UNDEFINED) + operations.iterate([&](const OperationIndex &index, const IOperation &node) { + for (auto &&operand_index : node.getInputs() | ir::Remove::UNDEFINED) { try { @@ -98,7 +98,7 @@ bool EdgeChecker::verify(const Graph &graph) const noexcept errors += 1; } } - for (auto operand_index : node.getOutputs() | ir::Remove::UNDEFINED) + for (auto &&operand_index : node.getOutputs() | ir::Remove::UNDEFINED) { try { @@ -127,7 +127,7 @@ bool EdgeChecker::verify(const Graph &graph) const noexcept bool InputOutputChecker::verify(const Graph &graph) const noexcept { - for (auto operand_ind : + for (auto &&operand_ind : (graph.getInputs() + graph.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED) { if (!graph.operands().exist(operand_ind)) diff --git a/runtime/onert/core/src/odc/QuantizeManager.cc b/runtime/onert/core/src/odc/QuantizeManager.cc new file mode 100644 index 000000000..71572a7e0 --- /dev/null +++ b/runtime/onert/core/src/odc/QuantizeManager.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. 
All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizerLoader.h"
+#include "odc/QuantizeManager.h"
+
+#include <iostream>
+#include <mutex>
+
+namespace onert
+{
+namespace odc
+{
+
+bool QuantizeManager::quantize()
+{
+  // Quantize function is thread-unsafe
+  static std::mutex lock;
+  std::lock_guard<std::mutex> guard(lock);
+
+  if (_export_model_path.empty())
+    throw std::runtime_error("Export model path is not set");
+
+  auto &quantize_loader = QuantizerLoader::instance();
+  if (quantize_loader.loadLibrary() != 0)
+    return false;
+
+  auto quantizer = quantize_loader.get();
+  auto result = quantizer->quantize(_model_path.c_str(), _export_model_path.c_str(), _is_q16);
+
+  // TODO Unload quantize library to reduce memory usage
+
+  return (result == 0);
+}
+
+} // namespace odc
+} // namespace onert
diff --git a/runtime/onert/core/src/odc/QuantizeManager.test.cc b/runtime/onert/core/src/odc/QuantizeManager.test.cc
new file mode 100644
index 000000000..4e155a6ef
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizeManager.test.cc
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "odc/QuantizeManager.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::odc;
+
+// Test export model path is not set
+TEST(odc_QuantizeManager, neg_export_model_path)
+{
+  QuantizeManager manager("model_path");
+  ASSERT_THROW(manager.quantize(), std::runtime_error);
+}
+
+// Test invalid model path
+TEST(odc_QuantizeManager, neg_invalid_model_path)
+{
+  QuantizeManager manager("invalid_model_path.circle");
+  manager.exportModelPath("export_model_path.circle");
+  ASSERT_EQ(manager.quantize(), false);
+}
diff --git a/runtime/onert/core/src/odc/QuantizerLoader.cc b/runtime/onert/core/src/odc/QuantizerLoader.cc
new file mode 100644
index 000000000..8a972e97e
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizerLoader.cc
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
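Note: a minimal sketch of the dlopen/dlsym plugin pattern QuantizerLoader.cc below relies on: resolve a factory function (and a matching destroyer) from a shared object, and keep the handle open while objects from the library are alive. The library and symbol names here are illustrative, not onert's.

#include <dlfcn.h>
#include <iostream>

int main()
{
  void *handle = dlopen("libplugin.so", RTLD_LAZY | RTLD_LOCAL); // hypothetical library
  if (handle == nullptr)
  {
    std::cerr << dlerror() << "\n";
    return 1;
  }

  using factory_t = void *(*)();
  auto factory = reinterpret_cast<factory_t>(dlsym(handle, "create_object")); // hypothetical symbol
  if (factory == nullptr)
  {
    std::cerr << dlerror() << "\n";
    dlclose(handle);
    return 1;
  }

  void *obj = factory();
  (void)obj;       // use obj, then destroy it via the library's own destroyer
  dlclose(handle); // only after every object from the library is gone
  return 0;
}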
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizerLoader.h" + +#include <dlfcn.h> +#include <iostream> +#include <string> + +static const char *SHARED_LIB_EXT = +#if defined(__APPLE__) && defined(__MACH__) + ".dylib"; +#else + ".so"; +#endif + +namespace onert +{ +namespace odc +{ + +QuantizerLoader &QuantizerLoader::instance() +{ + static QuantizerLoader singleton; + return singleton; +} + +int32_t QuantizerLoader::loadLibrary() +{ + if (get() != nullptr) + return 0; + + const std::string quantize_so = std::string("libonert_odc") + SHARED_LIB_EXT; + void *handle = dlopen(quantize_so.c_str(), RTLD_LAZY | RTLD_LOCAL); + auto dlerror_msg = dlerror(); + + if (handle == nullptr) + { + std::cerr << "Failed to load " << quantize_so << std::endl; + std::cerr << dlerror_msg << std::endl; + return 1; + } + + { + const char *factory_name = "create_quantizer"; + auto factory = (factory_t)dlsym(handle, factory_name); + dlerror_msg = dlerror(); + + if (factory == nullptr) + { + std::cerr << "QuantizerLoader: unable to find function " << factory_name << dlerror_msg + << std::endl; + dlclose(handle); + return 1; + } + + auto destroyer = (quantizer_destory_t)dlsym(handle, "destroy_quantizer"); + _quantizer = std::unique_ptr<IQuantizer, quantizer_destory_t>(factory(), destroyer); + + if (_quantizer == nullptr) + { + std::cerr << "QuantizerLoader: unable to create quantizer" << std::endl; + dlclose(handle); + return 1; + } + } + + // Save quantize library handle (avoid warning by handle lost without dlclose()) + // clang-format off + _dlhandle = std::unique_ptr<void, dlhandle_destroy_t>{handle, [filename = quantize_so](void *h) { + if (dlclose(h) != 0) + std::cerr << "Failed to unload backend " << filename << std::endl; + }}; + // clang-format on + + return 0; +} + +int32_t QuantizerLoader::unloadLibrary() +{ + if (get() == nullptr) + return 0; + + _quantizer.reset(nullptr); + _dlhandle.reset(nullptr); + + return 0; +} + +} // namespace odc +} // namespace onert diff --git a/runtime/onert/core/src/odc/QuantizerLoader.h b/runtime/onert/core/src/odc/QuantizerLoader.h new file mode 100644 index 000000000..36a9f2996 --- /dev/null +++ b/runtime/onert/core/src/odc/QuantizerLoader.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
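Note: QuantizerLoader::instance() above is a function-local-static ("Meyers") singleton. A stripped-down sketch of the idiom; initialization is thread-safe since C++11, and deleting the copy operations guarantees a single instance.

#include <iostream>

class Loader
{
public:
  static Loader &instance()
  {
    static Loader singleton; // constructed once, on first use
    return singleton;
  }
  Loader(const Loader &) = delete;
  Loader &operator=(const Loader &) = delete;

private:
  Loader() = default;
};

int main()
{
  std::cout << (&Loader::instance() == &Loader::instance()) << "\n"; // prints 1
}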
+ */ + +#ifndef __ONERT_ODC_QUANTIZER_LOADER_H__ +#define __ONERT_ODC_QUANTIZER_LOADER_H__ + +#include "odc/IQuantizer.h" + +#include <functional> +#include <memory> + +namespace onert +{ +namespace odc +{ + +/** + * @brief Class to manage loading and unloading of dynamic library containing + * implementation of IQuantizer interface + */ +class QuantizerLoader +{ +public: + /** + * @brief Typedef for function pointer to destroy loaded library handle + */ + using dlhandle_destroy_t = std::function<void(void *)>; + /** + * @brief Typedef for function pointer to create instance of IQuantizer + */ + using factory_t = IQuantizer *(*)(); + /** + * @brief Typedef for function pointer to destroy instance of IQuantizer + */ + using quantizer_destory_t = void (*)(IQuantizer *); + + /** + * @brief Get singleton instance of QuantizerLoader + * @return Reference to singleton instance of QuantizerLoader + */ + static QuantizerLoader &instance(); + +private: + // Cannot create instance of QuantizerLoader outside of this class + QuantizerLoader() = default; + QuantizerLoader(QuantizerLoader const &) = delete; + QuantizerLoader &operator=(QuantizerLoader const &) = delete; + ~QuantizerLoader() = default; + +public: + /** + * @brief Load dynamic library containing implementation of IQuantizer + * @return 0 if success, otherwise errno value + */ + int32_t loadLibrary(); + /** + * @brief Unload dynamic library containing implementation of IQuantizer + * @return 0 if success, otherwise errno value + */ + int32_t unloadLibrary(); + /** + * @brief Get instance of IQuantizer created through factory method + * @return Pointer to instance of IQuantizer + */ + IQuantizer *get() const { return _quantizer.get(); } + +private: + // Note: Keep handle to avoid svace warning of "handle lost without dlclose()" + std::unique_ptr<void, dlhandle_destroy_t> _dlhandle; + std::unique_ptr<IQuantizer, quantizer_destory_t> _quantizer{nullptr, nullptr}; +}; + +} // namespace odc +} // namespace onert + +#endif // __ONERT_ODC_QUANTIZER_LOADER_H__ diff --git a/runtime/onert/core/src/odc/QuantizerLoader.test.cc b/runtime/onert/core/src/odc/QuantizerLoader.test.cc new file mode 100644 index 000000000..112e65b27 --- /dev/null +++ b/runtime/onert/core/src/odc/QuantizerLoader.test.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
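Note: the _dlhandle member declared above is a std::unique_ptr<void, ...> whose deleter calls dlclose(), so the library handle cannot leak. A simplified sketch of that RAII idiom, with an illustrative library name:

#include <dlfcn.h>
#include <functional>
#include <iostream>
#include <memory>

using dlhandle_ptr = std::unique_ptr<void, std::function<void(void *)>>;

dlhandle_ptr open_library(const char *name)
{
  void *h = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
  return dlhandle_ptr{h, [](void *p) {
                        if (p != nullptr && dlclose(p) != 0)
                          std::cerr << "dlclose failed\n";
                      }};
}

int main()
{
  auto handle = open_library("libm.so.6"); // commonly present on Linux; illustrative
  std::cout << (handle ? "loaded" : "not loaded") << "\n";
} // handle released automatically here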
+ */
+
+#include "QuantizerLoader.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::odc;
+
+// Test QuantizerLoader singleton
+TEST(odc_QuantizerLoader, singleton)
+{
+  QuantizerLoader &loader1 = QuantizerLoader::instance();
+  QuantizerLoader &loader2 = QuantizerLoader::instance();
+  ASSERT_EQ(&loader1, &loader2);
+}
+
+// Test load quantizer library
+TEST(odc_QuantizerLoader, load)
+{
+  QuantizerLoader &loader = QuantizerLoader::instance();
+  // Unload because it may be loaded on previous tests
+  ASSERT_EQ(loader.unloadLibrary(), 0);
+
+  if (loader.loadLibrary() == 0)
+  {
+    // Load twice to check if it is thread-safe
+    ASSERT_EQ(loader.loadLibrary(), 0);
+  }
+}
+
+// Get quantizer function without loading quantizer library
+TEST(odc_QuantizerLoader, neg_get)
+{
+  QuantizerLoader &loader = QuantizerLoader::instance();
+  // Unload because it may be loaded on previous tests
+  ASSERT_EQ(loader.unloadLibrary(), 0);
+  ASSERT_EQ(loader.get(), nullptr);
+}
+
+// Check quantizer function pointer when QuantizerLoader is unloaded
+TEST(odc_QuantizerLoader, neg_unload)
+{
+  QuantizerLoader &loader = QuantizerLoader::instance();
+  if (loader.loadLibrary() == 0)
+    ASSERT_NE(loader.get(), nullptr);
+
+  ASSERT_EQ(loader.unloadLibrary(), 0);
+  ASSERT_EQ(loader.get(), nullptr);
+}
diff --git a/runtime/onert/core/src/util/MDTableEventWriter.cc b/runtime/onert/core/src/util/MDTableEventWriter.cc
index 13dab5b77..e7d90eec4 100644
--- a/runtime/onert/core/src/util/MDTableEventWriter.cc
+++ b/runtime/onert/core/src/util/MDTableEventWriter.cc
@@ -124,7 +124,7 @@ struct Graph : public MDContent
   void setOperations(const std::map<std::string, Operation> &name_to_op)
   {
     uint64_t graph_latency = end_ts - begin_ts;
-    for (auto it : name_to_op)
+    for (auto &&it : name_to_op)
     {
       auto op = it.second;
       op.graph_latency = graph_latency;
@@ -172,7 +172,7 @@ struct Graph : public MDContent
     writeMDTableRow(os, op_headers_line);
 
     // Operation's contents
-    for (auto op : ops)
+    for (auto &&op : ops)
     {
       op.write(os);
     }
diff --git a/runtime/onert/frontend/base_loader/include/base_loader.h b/runtime/onert/frontend/base_loader/include/base_loader.h
index 878a594cc..a6b1fb4a1 100644
--- a/runtime/onert/frontend/base_loader/include/base_loader.h
+++ b/runtime/onert/frontend/base_loader/include/base_loader.h
@@ -513,11 +513,12 @@ void BaseLoader<LoaderDomain>::loadSparsity(const Tensor *tensor, ir::TypeInfo &
   if (src_metadata->array_segments() == nullptr || src_metadata->array_indices() == nullptr)
     return false;
   bool status = true;
+  /* `onert` internally uses uint16 type regardless of the value of
+     the array_segments_type and array_indices_type */
   switch (src_metadata->array_segments_type())
   {
     case SparseIndexVector::SparseIndexVector_Int32Vector:
-      status = Copy(src_metadata->array_segments_as_Int32Vector(), w1_segments);
-      break;
+      throw std::runtime_error("sparse tensor with int32 segment type is not supported");
    case SparseIndexVector::SparseIndexVector_Uint16Vector:
      status = Copy(src_metadata->array_segments_as_Uint16Vector(), w1_segments);
      break;
@@ -532,7 +533,7 @@ void BaseLoader<LoaderDomain>::loadSparsity(const Tensor *tensor, ir::TypeInfo &
   switch (src_metadata->array_indices_type())
   {
     case SparseIndexVector::SparseIndexVector_Int32Vector:
-      return Copy(src_metadata->array_indices_as_Int32Vector(), w1_indices);
+      throw std::runtime_error("sparse tensor with int32 indices type is not supported");
    case SparseIndexVector::SparseIndexVector_Uint16Vector:
      return Copy(src_metadata->array_indices_as_Uint16Vector(), w1_indices);
    case SparseIndexVector::SparseIndexVector_Uint8Vector:
@@ -650,7 +651,19 @@ void BaseLoader<LoaderDomain>::loadConv2D(const Operator *op, ir::Graph &subg)
   param.dilation.width_factor = options->dilation_w_factor();
   param.dilation.height_factor = options->dilation_h_factor();

-  loadOperationTo<ir::operation::Conv2D>(op, subg, param);
+  const auto conv = loadOperationTo<ir::operation::Conv2D>(op, subg, param);
+
+  // TFLite supports old hybrid quantization (float input/output, uint8 kernel)
+  // but it interprets weight type as int8 internally
+  const auto &input_operand =
+    subg.operands().at(conv->getInputs().at(ir::operation::Conv2D::INPUT));
+  auto &weights_operand = subg.operands().at(conv->getInputs().at(ir::operation::Conv2D::KERNEL));
+  if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
+      ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
+       weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
+  {
+    weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
+  }
 }

 template <typename LoaderDomain>
@@ -665,7 +678,21 @@ void BaseLoader<LoaderDomain>::loadDepthwiseConv2D(const Operator *op, ir::Graph
   param.dilation.width_factor = options->dilation_w_factor();
   param.dilation.height_factor = options->dilation_h_factor();

-  loadOperationTo<ir::operation::DepthwiseConv2D>(op, subg, param);
+  const auto dconv = loadOperationTo<ir::operation::DepthwiseConv2D>(op, subg, param);
+
+  // TFLite does not support old hybrid quantization (float input/output, uint8 kernel)
+  // for depthwise convolution.
+  // But for consistency with Conv2D and FC, we interpret weight type as int8 internally
+  const auto &input_operand =
+    subg.operands().at(dconv->getInputs().at(ir::operation::DepthwiseConv2D::INPUT));
+  auto &weights_operand =
+    subg.operands().at(dconv->getInputs().at(ir::operation::DepthwiseConv2D::KERNEL));
+  if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
+      ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
+       weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
+  {
+    weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
+  }
 }

 template <typename LoaderDomain>
@@ -745,6 +772,8 @@ void BaseLoader<LoaderDomain>::loadFC(const Operator *op, ir::Graph &subg)

   const auto fc = loadOperationTo<ir::operation::FullyConnected>(op, subg, param);

+  // TFLite supports old hybrid quantization (float input/output, uint8 kernel)
+  // but it interprets weight type as int8 internally
   const auto &input_operand =
     subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::INPUT));
   auto &weights_operand =
diff --git a/runtime/onert/frontend/circle_schema/include/circle_schema_generated.h b/runtime/onert/frontend/circle_schema/include/circle_schema_generated.h
index e3c92eae0..dd6f9dcd7 100644
--- a/runtime/onert/frontend/circle_schema/include/circle_schema_generated.h
+++ b/runtime/onert/frontend/circle_schema/include/circle_schema_generated.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2019-2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors.
All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -45,6 +45,9 @@ struct DimensionMetadataBuilder; struct SparsityParameters; struct SparsityParametersBuilder; +struct VariantSubType; +struct VariantSubTypeBuilder; + struct Tensor; struct TensorBuilder; @@ -390,6 +393,42 @@ struct AssignVariableOptionsBuilder; struct RandomOptions; struct RandomOptionsBuilder; +struct BucketizeOptions; +struct BucketizeOptionsBuilder; + +struct GeluOptions; +struct GeluOptionsBuilder; + +struct DynamicUpdateSliceOptions; +struct DynamicUpdateSliceOptionsBuilder; + +struct UnsortedSegmentProdOptions; +struct UnsortedSegmentProdOptionsBuilder; + +struct UnsortedSegmentMaxOptions; +struct UnsortedSegmentMaxOptionsBuilder; + +struct UnsortedSegmentSumOptions; +struct UnsortedSegmentSumOptionsBuilder; + +struct ATan2Options; +struct ATan2OptionsBuilder; + +struct UnsortedSegmentMinOptions; +struct UnsortedSegmentMinOptionsBuilder; + +struct SignOptions; +struct SignOptionsBuilder; + +struct BitcastOptions; +struct BitcastOptionsBuilder; + +struct BitwiseXorOptions; +struct BitwiseXorOptionsBuilder; + +struct RightShiftOptions; +struct RightShiftOptionsBuilder; + struct BCQGatherOptions; struct BCQGatherOptionsBuilder; @@ -441,32 +480,35 @@ enum TensorType : int8_t TensorType_RESOURCE = 13, TensorType_VARIANT = 14, TensorType_UINT32 = 15, + TensorType_UINT16 = 16, + TensorType_INT4 = 17, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_UINT32 + TensorType_MAX = TensorType_INT4 }; -inline const TensorType (&EnumValuesTensorType())[16] +inline const TensorType (&EnumValuesTensorType())[18] { static const TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, TensorType_INT32, TensorType_UINT8, TensorType_INT64, TensorType_STRING, TensorType_BOOL, TensorType_INT16, TensorType_COMPLEX64, TensorType_INT8, TensorType_FLOAT64, TensorType_COMPLEX128, - TensorType_UINT64, TensorType_RESOURCE, TensorType_VARIANT, TensorType_UINT32}; + TensorType_UINT64, TensorType_RESOURCE, TensorType_VARIANT, TensorType_UINT32, + TensorType_UINT16, TensorType_INT4}; return values; } inline const char *const *EnumNamesTensorType() { - static const char *const names[17] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", "INT64", + static const char *const names[19] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", "INT64", "STRING", "BOOL", "INT16", "COMPLEX64", "INT8", "FLOAT64", "COMPLEX128", "UINT64", "RESOURCE", "VARIANT", - "UINT32", nullptr}; + "UINT32", "UINT16", "INT4", nullptr}; return names; } inline const char *EnumNameTensorType(TensorType e) { - if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT32)) + if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT4)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesTensorType()[index]; @@ -756,11 +798,26 @@ enum BuiltinOperator : int32_t BuiltinOperator_ASSIGN_VARIABLE = 144, BuiltinOperator_BROADCAST_ARGS = 145, BuiltinOperator_RANDOM_STANDARD_NORMAL = 146, + BuiltinOperator_BUCKETIZE = 147, + BuiltinOperator_RANDOM_UNIFORM = 148, + BuiltinOperator_MULTINOMIAL = 149, + BuiltinOperator_GELU = 150, + BuiltinOperator_DYNAMIC_UPDATE_SLICE = 151, + BuiltinOperator_RELU_0_TO_1 = 152, + BuiltinOperator_UNSORTED_SEGMENT_PROD = 153, + BuiltinOperator_UNSORTED_SEGMENT_MAX = 154, + BuiltinOperator_UNSORTED_SEGMENT_SUM = 155, + BuiltinOperator_ATAN2 = 156, + BuiltinOperator_UNSORTED_SEGMENT_MIN = 157, + BuiltinOperator_SIGN = 158, + BuiltinOperator_BITCAST = 159, + BuiltinOperator_BITWISE_XOR = 
160, + BuiltinOperator_RIGHT_SHIFT = 161, BuiltinOperator_MIN = BuiltinOperator_BCQ_GATHER, - BuiltinOperator_MAX = BuiltinOperator_RANDOM_STANDARD_NORMAL + BuiltinOperator_MAX = BuiltinOperator_RIGHT_SHIFT }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[150] +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[165] { static const BuiltinOperator values[] = {BuiltinOperator_BCQ_GATHER, BuiltinOperator_BCQ_FULLY_CONNECTED, @@ -911,13 +968,28 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[150] BuiltinOperator_READ_VARIABLE, BuiltinOperator_ASSIGN_VARIABLE, BuiltinOperator_BROADCAST_ARGS, - BuiltinOperator_RANDOM_STANDARD_NORMAL}; + BuiltinOperator_RANDOM_STANDARD_NORMAL, + BuiltinOperator_BUCKETIZE, + BuiltinOperator_RANDOM_UNIFORM, + BuiltinOperator_MULTINOMIAL, + BuiltinOperator_GELU, + BuiltinOperator_DYNAMIC_UPDATE_SLICE, + BuiltinOperator_RELU_0_TO_1, + BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOperator_UNSORTED_SEGMENT_MAX, + BuiltinOperator_UNSORTED_SEGMENT_SUM, + BuiltinOperator_ATAN2, + BuiltinOperator_UNSORTED_SEGMENT_MIN, + BuiltinOperator_SIGN, + BuiltinOperator_BITCAST, + BuiltinOperator_BITWISE_XOR, + BuiltinOperator_RIGHT_SHIFT}; return values; } inline const char *const *EnumNamesBuiltinOperator() { - static const char *const names[152] = {"BCQ_GATHER", + static const char *const names[167] = {"BCQ_GATHER", "BCQ_FULLY_CONNECTED", "INSTANCE_NORM", "", @@ -1068,14 +1140,28 @@ inline const char *const *EnumNamesBuiltinOperator() "ASSIGN_VARIABLE", "BROADCAST_ARGS", "RANDOM_STANDARD_NORMAL", + "BUCKETIZE", + "RANDOM_UNIFORM", + "MULTINOMIAL", + "GELU", + "DYNAMIC_UPDATE_SLICE", + "RELU_0_TO_1", + "UNSORTED_SEGMENT_PROD", + "UNSORTED_SEGMENT_MAX", + "UNSORTED_SEGMENT_SUM", + "ATAN2", + "UNSORTED_SEGMENT_MIN", + "SIGN", + "BITCAST", + "BITWISE_XOR", + "RIGHT_SHIFT", nullptr}; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_BCQ_GATHER, - BuiltinOperator_RANDOM_STANDARD_NORMAL)) + if (flatbuffers::IsOutRange(e, BuiltinOperator_BCQ_GATHER, BuiltinOperator_RIGHT_SHIFT)) return ""; const size_t index = static_cast<size_t>(e) - static_cast<size_t>(BuiltinOperator_BCQ_GATHER); return EnumNamesBuiltinOperator()[index]; @@ -1198,6 +1284,18 @@ enum BuiltinOptions : uint8_t BuiltinOptions_ReadVariableOptions = 112, BuiltinOptions_AssignVariableOptions = 113, BuiltinOptions_RandomOptions = 114, + BuiltinOptions_BucketizeOptions = 115, + BuiltinOptions_GeluOptions = 116, + BuiltinOptions_DynamicUpdateSliceOptions = 117, + BuiltinOptions_UnsortedSegmentProdOptions = 118, + BuiltinOptions_UnsortedSegmentMaxOptions = 119, + BuiltinOptions_UnsortedSegmentMinOptions = 120, + BuiltinOptions_UnsortedSegmentSumOptions = 121, + BuiltinOptions_ATan2Options = 122, + BuiltinOptions_SignOptions = 123, + BuiltinOptions_BitcastOptions = 124, + BuiltinOptions_BitwiseXorOptions = 125, + BuiltinOptions_RightShiftOptions = 126, BuiltinOptions_BCQGatherOptions = 252, BuiltinOptions_BCQFullyConnectedOptions = 253, BuiltinOptions_InstanceNormOptions = 254, @@ -1205,7 +1303,7 @@ enum BuiltinOptions : uint8_t BuiltinOptions_MAX = BuiltinOptions_InstanceNormOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[118] +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[130] { static const BuiltinOptions values[] = {BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1322,6 +1420,18 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[118] 
BuiltinOptions_ReadVariableOptions, BuiltinOptions_AssignVariableOptions, BuiltinOptions_RandomOptions, + BuiltinOptions_BucketizeOptions, + BuiltinOptions_GeluOptions, + BuiltinOptions_DynamicUpdateSliceOptions, + BuiltinOptions_UnsortedSegmentProdOptions, + BuiltinOptions_UnsortedSegmentMaxOptions, + BuiltinOptions_UnsortedSegmentMinOptions, + BuiltinOptions_UnsortedSegmentSumOptions, + BuiltinOptions_ATan2Options, + BuiltinOptions_SignOptions, + BuiltinOptions_BitcastOptions, + BuiltinOptions_BitwiseXorOptions, + BuiltinOptions_RightShiftOptions, BuiltinOptions_BCQGatherOptions, BuiltinOptions_BCQFullyConnectedOptions, BuiltinOptions_InstanceNormOptions}; @@ -1445,18 +1555,18 @@ inline const char *const *EnumNamesBuiltinOptions() "ReadVariableOptions", "AssignVariableOptions", "RandomOptions", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", + "BucketizeOptions", + "GeluOptions", + "DynamicUpdateSliceOptions", + "UnsortedSegmentProdOptions", + "UnsortedSegmentMaxOptions", + "UnsortedSegmentMinOptions", + "UnsortedSegmentSumOptions", + "ATan2Options", + "SignOptions", + "BitcastOptions", + "BitwiseXorOptions", + "RightShiftOptions", "", "", "", @@ -2172,6 +2282,66 @@ template <> struct BuiltinOptionsTraits<circle::RandomOptions> static const BuiltinOptions enum_value = BuiltinOptions_RandomOptions; }; +template <> struct BuiltinOptionsTraits<circle::BucketizeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BucketizeOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::GeluOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_GeluOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::DynamicUpdateSliceOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_DynamicUpdateSliceOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::UnsortedSegmentProdOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentProdOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::UnsortedSegmentMaxOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMaxOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::UnsortedSegmentMinOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMinOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::UnsortedSegmentSumOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentSumOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::ATan2Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ATan2Options; +}; + +template <> struct BuiltinOptionsTraits<circle::SignOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_SignOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::BitcastOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::BitwiseXorOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions; +}; + +template <> struct BuiltinOptionsTraits<circle::RightShiftOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_RightShiftOptions; +}; + template <> struct BuiltinOptionsTraits<circle::BCQGatherOptions> { static const BuiltinOptions enum_value = BuiltinOptions_BCQGatherOptions; @@ -3103,6 +3273,81 @@ inline flatbuffers::Offset<SparsityParameters> CreateSparsityParametersDirect( return circle::CreateSparsityParameters(_fbb, traversal_order__, 
block_map__, dim_metadata__); } +struct VariantSubType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef VariantSubTypeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_SHAPE = 4, + VT_TYPE = 6, + VT_HAS_RANK = 8 + }; + const flatbuffers::Vector<int32_t> *shape() const + { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE); + } + circle::TensorType type() const + { + return static_cast<circle::TensorType>(GetField<int8_t>(VT_TYPE, 0)); + } + bool has_rank() const { return GetField<uint8_t>(VT_HAS_RANK, 0) != 0; } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && VerifyField<int8_t>(verifier, VT_TYPE) && + VerifyField<uint8_t>(verifier, VT_HAS_RANK) && verifier.EndTable(); + } +}; + +struct VariantSubTypeBuilder +{ + typedef VariantSubType Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape) + { + fbb_.AddOffset(VariantSubType::VT_SHAPE, shape); + } + void add_type(circle::TensorType type) + { + fbb_.AddElement<int8_t>(VariantSubType::VT_TYPE, static_cast<int8_t>(type), 0); + } + void add_has_rank(bool has_rank) + { + fbb_.AddElement<uint8_t>(VariantSubType::VT_HAS_RANK, static_cast<uint8_t>(has_rank), 0); + } + explicit VariantSubTypeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<VariantSubType> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<VariantSubType>(end); + return o; + } +}; + +inline flatbuffers::Offset<VariantSubType> +CreateVariantSubType(flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0, + circle::TensorType type = circle::TensorType_FLOAT32, bool has_rank = false) +{ + VariantSubTypeBuilder builder_(_fbb); + builder_.add_shape(shape); + builder_.add_has_rank(has_rank); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset<VariantSubType> CreateVariantSubTypeDirect( + flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *shape = nullptr, + circle::TensorType type = circle::TensorType_FLOAT32, bool has_rank = false) +{ + auto shape__ = shape ? 
_fbb.CreateVector<int32_t>(*shape) : 0; + return circle::CreateVariantSubType(_fbb, shape__, type, has_rank); +} + struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef TensorBuilder Builder; @@ -3115,7 +3360,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table VT_QUANTIZATION = 12, VT_IS_VARIABLE = 14, VT_SPARSITY = 16, - VT_SHAPE_SIGNATURE = 18 + VT_SHAPE_SIGNATURE = 18, + VT_HAS_RANK = 20, + VT_VARIANT_TENSORS = 22 }; const flatbuffers::Vector<int32_t> *shape() const { @@ -3143,6 +3390,12 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE); } + bool has_rank() const { return GetField<uint8_t>(VT_HAS_RANK, 0) != 0; } + const flatbuffers::Vector<flatbuffers::Offset<circle::VariantSubType>> *variant_tensors() const + { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<circle::VariantSubType>> *>( + VT_VARIANT_TENSORS); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -3152,7 +3405,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table verifier.VerifyTable(quantization()) && VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) && VerifyOffset(verifier, VT_SPARSITY) && verifier.VerifyTable(sparsity()) && VerifyOffset(verifier, VT_SHAPE_SIGNATURE) && verifier.VerifyVector(shape_signature()) && - verifier.EndTable(); + VerifyField<uint8_t>(verifier, VT_HAS_RANK) && + VerifyOffset(verifier, VT_VARIANT_TENSORS) && verifier.VerifyVector(variant_tensors()) && + verifier.VerifyVectorOfTables(variant_tensors()) && verifier.EndTable(); } }; @@ -3190,6 +3445,16 @@ struct TensorBuilder { fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature); } + void add_has_rank(bool has_rank) + { + fbb_.AddElement<uint8_t>(Tensor::VT_HAS_RANK, static_cast<uint8_t>(has_rank), 0); + } + void add_variant_tensors( + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<circle::VariantSubType>>> + variant_tensors) + { + fbb_.AddOffset(Tensor::VT_VARIANT_TENSORS, variant_tensors); + } explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -3202,22 +3467,25 @@ struct TensorBuilder } }; -inline flatbuffers::Offset<Tensor> -CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0, - circle::TensorType type = circle::TensorType_FLOAT32, uint32_t buffer = 0, - flatbuffers::Offset<flatbuffers::String> name = 0, - flatbuffers::Offset<circle::QuantizationParameters> quantization = 0, - bool is_variable = false, flatbuffers::Offset<circle::SparsityParameters> sparsity = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0) +inline flatbuffers::Offset<Tensor> CreateTensor( + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0, + circle::TensorType type = circle::TensorType_FLOAT32, uint32_t buffer = 0, + flatbuffers::Offset<flatbuffers::String> name = 0, + flatbuffers::Offset<circle::QuantizationParameters> quantization = 0, bool is_variable = false, + flatbuffers::Offset<circle::SparsityParameters> sparsity = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0, bool has_rank = false, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<circle::VariantSubType>>> + variant_tensors = 0) { TensorBuilder builder_(_fbb); + builder_.add_variant_tensors(variant_tensors); 
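  // NOTE (editorial, not flatc output): the Add calls in CreateTensor are
  // ordered by field size (4-byte offsets and scalars first), so the new
  // 1-byte has_rank flag is grouped with is_variable and type further down.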
builder_.add_shape_signature(shape_signature); builder_.add_sparsity(sparsity); builder_.add_quantization(quantization); builder_.add_name(name); builder_.add_buffer(buffer); builder_.add_shape(shape); + builder_.add_has_rank(has_rank); builder_.add_is_variable(is_variable); builder_.add_type(type); return builder_.Finish(); @@ -3228,13 +3496,18 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect( circle::TensorType type = circle::TensorType_FLOAT32, uint32_t buffer = 0, const char *name = nullptr, flatbuffers::Offset<circle::QuantizationParameters> quantization = 0, bool is_variable = false, flatbuffers::Offset<circle::SparsityParameters> sparsity = 0, - const std::vector<int32_t> *shape_signature = nullptr) + const std::vector<int32_t> *shape_signature = nullptr, bool has_rank = false, + const std::vector<flatbuffers::Offset<circle::VariantSubType>> *variant_tensors = nullptr) { auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; auto name__ = name ? _fbb.CreateString(name) : 0; auto shape_signature__ = shape_signature ? _fbb.CreateVector<int32_t>(*shape_signature) : 0; + auto variant_tensors__ = + variant_tensors + ? _fbb.CreateVector<flatbuffers::Offset<circle::VariantSubType>>(*variant_tensors) + : 0; return circle::CreateTensor(_fbb, shape__, type, buffer, name__, quantization, is_variable, - sparsity, shape_signature__); + sparsity, shape_signature__, has_rank, variant_tensors__); } struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table @@ -4561,7 +4834,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8, VT_TIME_MAJOR = 10, - VT_ASYMMETRIC_QUANTIZE_INPUTS = 12 + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12, + VT_DIAGONAL_RECURRENT_TENSORS = 14 }; circle::ActivationFunctionType fused_activation_function() const { @@ -4575,6 +4849,10 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb { return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; } + bool diagonal_recurrent_tensors() const + { + return GetField<uint8_t>(VT_DIAGONAL_RECURRENT_TENSORS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4582,7 +4860,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb VerifyField<float>(verifier, VT_CELL_CLIP) && VerifyField<float>(verifier, VT_PROJ_CLIP) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && - VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); + VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && + VerifyField<uint8_t>(verifier, VT_DIAGONAL_RECURRENT_TENSORS) && verifier.EndTable(); } }; @@ -4614,6 +4893,11 @@ struct UnidirectionalSequenceLSTMOptionsBuilder fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0); } + void add_diagonal_recurrent_tensors(bool diagonal_recurrent_tensors) + { + fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_DIAGONAL_RECURRENT_TENSORS, + static_cast<uint8_t>(diagonal_recurrent_tensors), 0); + } explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { @@ -4632,11 +4916,12 @@ CreateUnidirectionalSequenceLSTMOptions( flatbuffers::FlatBufferBuilder &_fbb, circle::ActivationFunctionType fused_activation_function = circle::ActivationFunctionType_NONE, float cell_clip = 0.0f, float proj_clip = 0.0f, bool time_major = false, - bool 
asymmetric_quantize_inputs = false) + bool asymmetric_quantize_inputs = false, bool diagonal_recurrent_tensors = false) { UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_diagonal_recurrent_tensors(diagonal_recurrent_tensors); builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_time_major(time_major); builder_.add_fused_activation_function(fused_activation_function); @@ -6350,7 +6635,8 @@ struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_PADDING = 4, VT_STRIDE_W = 6, - VT_STRIDE_H = 8 + VT_STRIDE_H = 8, + VT_FUSED_ACTIVATION_FUNCTION = 10 }; circle::Padding padding() const { @@ -6358,11 +6644,17 @@ struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table } int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); } int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); } + circle::ActivationFunctionType fused_activation_function() const + { + return static_cast<circle::ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING) && VerifyField<int32_t>(verifier, VT_STRIDE_W) && - VerifyField<int32_t>(verifier, VT_STRIDE_H) && verifier.EndTable(); + VerifyField<int32_t>(verifier, VT_STRIDE_H) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } }; @@ -6383,6 +6675,11 @@ struct TransposeConvOptionsBuilder { fbb_.AddElement<int32_t>(TransposeConvOptions::VT_STRIDE_H, stride_h, 0); } + void add_fused_activation_function(circle::ActivationFunctionType fused_activation_function) + { + fbb_.AddElement<int8_t>(TransposeConvOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } explicit TransposeConvOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -6395,14 +6692,15 @@ struct TransposeConvOptionsBuilder } }; -inline flatbuffers::Offset<TransposeConvOptions> -CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, - circle::Padding padding = circle::Padding_SAME, int32_t stride_w = 0, - int32_t stride_h = 0) +inline flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions( + flatbuffers::FlatBufferBuilder &_fbb, circle::Padding padding = circle::Padding_SAME, + int32_t stride_w = 0, int32_t stride_h = 0, + circle::ActivationFunctionType fused_activation_function = circle::ActivationFunctionType_NONE) { TransposeConvOptionsBuilder builder_(_fbb); builder_.add_stride_h(stride_h); builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); builder_.add_padding(padding); return builder_.Finish(); } @@ -8506,12 +8804,12 @@ struct RandomOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table VT_SEED = 4, VT_SEED2 = 6 }; - int32_t seed() const { return GetField<int32_t>(VT_SEED, 0); } - int32_t seed2() const { return GetField<int32_t>(VT_SEED2, 0); } + int64_t seed() const { return GetField<int64_t>(VT_SEED, 0); } + int64_t seed2() const { return GetField<int64_t>(VT_SEED2, 0); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_SEED) && - VerifyField<int32_t>(verifier, VT_SEED2) && verifier.EndTable(); + return VerifyTableStart(verifier) && VerifyField<int64_t>(verifier, VT_SEED) && 
+ VerifyField<int64_t>(verifier, VT_SEED2) && verifier.EndTable(); } }; @@ -8520,8 +8818,8 @@ struct RandomOptionsBuilder typedef RandomOptions Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_seed(int32_t seed) { fbb_.AddElement<int32_t>(RandomOptions::VT_SEED, seed, 0); } - void add_seed2(int32_t seed2) { fbb_.AddElement<int32_t>(RandomOptions::VT_SEED2, seed2, 0); } + void add_seed(int64_t seed) { fbb_.AddElement<int64_t>(RandomOptions::VT_SEED, seed, 0); } + void add_seed2(int64_t seed2) { fbb_.AddElement<int64_t>(RandomOptions::VT_SEED2, seed2, 0); } explicit RandomOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -8535,7 +8833,7 @@ struct RandomOptionsBuilder }; inline flatbuffers::Offset<RandomOptions> CreateRandomOptions(flatbuffers::FlatBufferBuilder &_fbb, - int32_t seed = 0, int32_t seed2 = 0) + int64_t seed = 0, int64_t seed2 = 0) { RandomOptionsBuilder builder_(_fbb); builder_.add_seed2(seed2); @@ -8543,6 +8841,434 @@ inline flatbuffers::Offset<RandomOptions> CreateRandomOptions(flatbuffers::FlatB return builder_.Finish(); } +struct BucketizeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef BucketizeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_BOUNDARIES = 4 + }; + const flatbuffers::Vector<float> *boundaries() const + { + return GetPointer<const flatbuffers::Vector<float> *>(VT_BOUNDARIES); + } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_BOUNDARIES) && + verifier.VerifyVector(boundaries()) && verifier.EndTable(); + } +}; + +struct BucketizeOptionsBuilder +{ + typedef BucketizeOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_boundaries(flatbuffers::Offset<flatbuffers::Vector<float>> boundaries) + { + fbb_.AddOffset(BucketizeOptions::VT_BOUNDARIES, boundaries); + } + explicit BucketizeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<BucketizeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<BucketizeOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<BucketizeOptions> +CreateBucketizeOptions(flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<float>> boundaries = 0) +{ + BucketizeOptionsBuilder builder_(_fbb); + builder_.add_boundaries(boundaries); + return builder_.Finish(); +} + +inline flatbuffers::Offset<BucketizeOptions> +CreateBucketizeOptionsDirect(flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<float> *boundaries = nullptr) +{ + auto boundaries__ = boundaries ? 
_fbb.CreateVector<float>(*boundaries) : 0; + return circle::CreateBucketizeOptions(_fbb, boundaries__); +} + +struct GeluOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef GeluOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_APPROXIMATE = 4 + }; + bool approximate() const { return GetField<uint8_t>(VT_APPROXIMATE, 0) != 0; } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_APPROXIMATE) && + verifier.EndTable(); + } +}; + +struct GeluOptionsBuilder +{ + typedef GeluOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_approximate(bool approximate) + { + fbb_.AddElement<uint8_t>(GeluOptions::VT_APPROXIMATE, static_cast<uint8_t>(approximate), 0); + } + explicit GeluOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<GeluOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<GeluOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<GeluOptions> CreateGeluOptions(flatbuffers::FlatBufferBuilder &_fbb, + bool approximate = false) +{ + GeluOptionsBuilder builder_(_fbb); + builder_.add_approximate(approximate); + return builder_.Finish(); +} + +struct DynamicUpdateSliceOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef DynamicUpdateSliceOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct DynamicUpdateSliceOptionsBuilder +{ + typedef DynamicUpdateSliceOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit DynamicUpdateSliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<DynamicUpdateSliceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<DynamicUpdateSliceOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<DynamicUpdateSliceOptions> +CreateDynamicUpdateSliceOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + DynamicUpdateSliceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentProdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef UnsortedSegmentProdOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentProdOptionsBuilder +{ + typedef UnsortedSegmentProdOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit UnsortedSegmentProdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<UnsortedSegmentProdOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<UnsortedSegmentProdOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<UnsortedSegmentProdOptions> +CreateUnsortedSegmentProdOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentProdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentMaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef UnsortedSegmentMaxOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && 
verifier.EndTable(); + } +}; + +struct UnsortedSegmentMaxOptionsBuilder +{ + typedef UnsortedSegmentMaxOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit UnsortedSegmentMaxOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<UnsortedSegmentMaxOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<UnsortedSegmentMaxOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<UnsortedSegmentMaxOptions> +CreateUnsortedSegmentMaxOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentMaxOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentSumOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef UnsortedSegmentSumOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentSumOptionsBuilder +{ + typedef UnsortedSegmentSumOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit UnsortedSegmentSumOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<UnsortedSegmentSumOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<UnsortedSegmentSumOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<UnsortedSegmentSumOptions> +CreateUnsortedSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentSumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ATan2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef ATan2OptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ATan2OptionsBuilder +{ + typedef ATan2Options Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ATan2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<ATan2Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ATan2Options>(end); + return o; + } +}; + +inline flatbuffers::Offset<ATan2Options> CreateATan2Options(flatbuffers::FlatBufferBuilder &_fbb) +{ + ATan2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef UnsortedSegmentMinOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentMinOptionsBuilder +{ + typedef UnsortedSegmentMinOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit UnsortedSegmentMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<UnsortedSegmentMinOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<UnsortedSegmentMinOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<UnsortedSegmentMinOptions> +CreateUnsortedSegmentMinOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentMinOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct SignOptions FLATBUFFERS_FINAL_CLASS : private 
flatbuffers::Table +{ + typedef SignOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct SignOptionsBuilder +{ + typedef SignOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit SignOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<SignOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<SignOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<SignOptions> CreateSignOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + SignOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct BitcastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef BitcastOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct BitcastOptionsBuilder +{ + typedef BitcastOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit BitcastOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<BitcastOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<BitcastOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<BitcastOptions> +CreateBitcastOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + BitcastOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct BitwiseXorOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef BitwiseXorOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct BitwiseXorOptionsBuilder +{ + typedef BitwiseXorOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit BitwiseXorOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<BitwiseXorOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<BitwiseXorOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<BitwiseXorOptions> +CreateBitwiseXorOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + BitwiseXorOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct RightShiftOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef RightShiftOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct RightShiftOptionsBuilder +{ + typedef RightShiftOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit RightShiftOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<RightShiftOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<RightShiftOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<RightShiftOptions> +CreateRightShiftOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + RightShiftOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + struct BCQGatherOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef BCQGatherOptionsBuilder Builder; @@ -9513,6 +10239,78 @@ struct Operator 
FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table ? static_cast<const circle::RandomOptions *>(builtin_options()) : nullptr; } + const circle::BucketizeOptions *builtin_options_as_BucketizeOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_BucketizeOptions + ? static_cast<const circle::BucketizeOptions *>(builtin_options()) + : nullptr; + } + const circle::GeluOptions *builtin_options_as_GeluOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_GeluOptions + ? static_cast<const circle::GeluOptions *>(builtin_options()) + : nullptr; + } + const circle::DynamicUpdateSliceOptions *builtin_options_as_DynamicUpdateSliceOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_DynamicUpdateSliceOptions + ? static_cast<const circle::DynamicUpdateSliceOptions *>(builtin_options()) + : nullptr; + } + const circle::UnsortedSegmentProdOptions *builtin_options_as_UnsortedSegmentProdOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_UnsortedSegmentProdOptions + ? static_cast<const circle::UnsortedSegmentProdOptions *>(builtin_options()) + : nullptr; + } + const circle::UnsortedSegmentMaxOptions *builtin_options_as_UnsortedSegmentMaxOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_UnsortedSegmentMaxOptions + ? static_cast<const circle::UnsortedSegmentMaxOptions *>(builtin_options()) + : nullptr; + } + const circle::UnsortedSegmentMinOptions *builtin_options_as_UnsortedSegmentMinOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_UnsortedSegmentMinOptions + ? static_cast<const circle::UnsortedSegmentMinOptions *>(builtin_options()) + : nullptr; + } + const circle::UnsortedSegmentSumOptions *builtin_options_as_UnsortedSegmentSumOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_UnsortedSegmentSumOptions + ? static_cast<const circle::UnsortedSegmentSumOptions *>(builtin_options()) + : nullptr; + } + const circle::ATan2Options *builtin_options_as_ATan2Options() const + { + return builtin_options_type() == circle::BuiltinOptions_ATan2Options + ? static_cast<const circle::ATan2Options *>(builtin_options()) + : nullptr; + } + const circle::SignOptions *builtin_options_as_SignOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_SignOptions + ? static_cast<const circle::SignOptions *>(builtin_options()) + : nullptr; + } + const circle::BitcastOptions *builtin_options_as_BitcastOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_BitcastOptions + ? static_cast<const circle::BitcastOptions *>(builtin_options()) + : nullptr; + } + const circle::BitwiseXorOptions *builtin_options_as_BitwiseXorOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_BitwiseXorOptions + ? static_cast<const circle::BitwiseXorOptions *>(builtin_options()) + : nullptr; + } + const circle::RightShiftOptions *builtin_options_as_RightShiftOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_RightShiftOptions + ? 
static_cast<const circle::RightShiftOptions *>(builtin_options()) + : nullptr; + } const circle::BCQGatherOptions *builtin_options_as_BCQGatherOptions() const { return builtin_options_type() == circle::BuiltinOptions_BCQGatherOptions @@ -10301,6 +11099,86 @@ inline const circle::RandomOptions *Operator::builtin_options_as<circle::RandomO } template <> +inline const circle::BucketizeOptions * +Operator::builtin_options_as<circle::BucketizeOptions>() const +{ + return builtin_options_as_BucketizeOptions(); +} + +template <> +inline const circle::GeluOptions *Operator::builtin_options_as<circle::GeluOptions>() const +{ + return builtin_options_as_GeluOptions(); +} + +template <> +inline const circle::DynamicUpdateSliceOptions * +Operator::builtin_options_as<circle::DynamicUpdateSliceOptions>() const +{ + return builtin_options_as_DynamicUpdateSliceOptions(); +} + +template <> +inline const circle::UnsortedSegmentProdOptions * +Operator::builtin_options_as<circle::UnsortedSegmentProdOptions>() const +{ + return builtin_options_as_UnsortedSegmentProdOptions(); +} + +template <> +inline const circle::UnsortedSegmentMaxOptions * +Operator::builtin_options_as<circle::UnsortedSegmentMaxOptions>() const +{ + return builtin_options_as_UnsortedSegmentMaxOptions(); +} + +template <> +inline const circle::UnsortedSegmentMinOptions * +Operator::builtin_options_as<circle::UnsortedSegmentMinOptions>() const +{ + return builtin_options_as_UnsortedSegmentMinOptions(); +} + +template <> +inline const circle::UnsortedSegmentSumOptions * +Operator::builtin_options_as<circle::UnsortedSegmentSumOptions>() const +{ + return builtin_options_as_UnsortedSegmentSumOptions(); +} + +template <> +inline const circle::ATan2Options *Operator::builtin_options_as<circle::ATan2Options>() const +{ + return builtin_options_as_ATan2Options(); +} + +template <> +inline const circle::SignOptions *Operator::builtin_options_as<circle::SignOptions>() const +{ + return builtin_options_as_SignOptions(); +} + +template <> +inline const circle::BitcastOptions *Operator::builtin_options_as<circle::BitcastOptions>() const +{ + return builtin_options_as_BitcastOptions(); +} + +template <> +inline const circle::BitwiseXorOptions * +Operator::builtin_options_as<circle::BitwiseXorOptions>() const +{ + return builtin_options_as_BitwiseXorOptions(); +} + +template <> +inline const circle::RightShiftOptions * +Operator::builtin_options_as<circle::RightShiftOptions>() const +{ + return builtin_options_as_RightShiftOptions(); +} + +template <> inline const circle::BCQGatherOptions * Operator::builtin_options_as<circle::BCQGatherOptions>() const { @@ -11667,6 +12545,66 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const circle::RandomOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_BucketizeOptions: + { + auto ptr = reinterpret_cast<const circle::BucketizeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GeluOptions: + { + auto ptr = reinterpret_cast<const circle::GeluOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DynamicUpdateSliceOptions: + { + auto ptr = reinterpret_cast<const circle::DynamicUpdateSliceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentProdOptions: + { + auto ptr = reinterpret_cast<const circle::UnsortedSegmentProdOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentMaxOptions: + { + auto ptr = 
reinterpret_cast<const circle::UnsortedSegmentMaxOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentMinOptions: + { + auto ptr = reinterpret_cast<const circle::UnsortedSegmentMinOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentSumOptions: + { + auto ptr = reinterpret_cast<const circle::UnsortedSegmentSumOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ATan2Options: + { + auto ptr = reinterpret_cast<const circle::ATan2Options *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SignOptions: + { + auto ptr = reinterpret_cast<const circle::SignOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BitcastOptions: + { + auto ptr = reinterpret_cast<const circle::BitcastOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BitwiseXorOptions: + { + auto ptr = reinterpret_cast<const circle::BitwiseXorOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RightShiftOptions: + { + auto ptr = reinterpret_cast<const circle::RightShiftOptions *>(obj); + return verifier.VerifyTable(ptr); + } case BuiltinOptions_BCQGatherOptions: { auto ptr = reinterpret_cast<const circle::BCQGatherOptions *>(obj); diff --git a/runtime/onert/frontend/nnapi/wrapper/ANeuralNetworksModel.cc b/runtime/onert/frontend/nnapi/wrapper/ANeuralNetworksModel.cc index a641368ec..837dac954 100644 --- a/runtime/onert/frontend/nnapi/wrapper/ANeuralNetworksModel.cc +++ b/runtime/onert/frontend/nnapi/wrapper/ANeuralNetworksModel.cc @@ -261,8 +261,8 @@ void ANeuralNetworksModel::setOptionalOperand(const onert::ir::OperandIndex idx) void ANeuralNetworksModel::fillOptionalOperand(void) { - _graph->operations().iterate([&](const onert::ir::OperationIndex &, onert::ir::Operation &node) { - for (auto input : node.getInputs()) + _graph->operations().iterate([&](const onert::ir::OperationIndex &, onert::ir::IOperation &node) { + for (auto &&input : node.getInputs()) { // TODO fill default value for optional operands if (_optional_operands.find(input) != _optional_operands.end()) diff --git a/runtime/onert/frontend/tflite/src/tflite_schema_generated.h b/runtime/onert/frontend/tflite/src/tflite_schema_generated.h index cec5bce74..7ad3c75bd 100644 --- a/runtime/onert/frontend/tflite/src/tflite_schema_generated.h +++ b/runtime/onert/frontend/tflite/src/tflite_schema_generated.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2019-2023 Samsung Electronics Co., Ltd. All Rights Reserved * Copyright 2018 The TensorFlow Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - // automatically generated by the FlatBuffers compiler, do not modify #ifndef FLATBUFFERS_GENERATED_TFLITESCHEMA_ONERT_TFLITE_H_ @@ -391,6 +390,27 @@ struct AssignVariableOptionsBuilder; struct RandomOptions; struct RandomOptionsBuilder; +struct BucketizeOptions; +struct BucketizeOptionsBuilder; + +struct GeluOptions; +struct GeluOptionsBuilder; + +struct DynamicUpdateSliceOptions; +struct DynamicUpdateSliceOptionsBuilder; + +struct UnsortedSegmentProdOptions; +struct UnsortedSegmentProdOptionsBuilder; + +struct UnsortedSegmentMaxOptions; +struct UnsortedSegmentMaxOptionsBuilder; + +struct UnsortedSegmentSumOptions; +struct UnsortedSegmentSumOptionsBuilder; + +struct ATan2Options; +struct ATan2OptionsBuilder; + struct OperatorCode; struct OperatorCodeBuilder; @@ -433,32 +453,34 @@ enum TensorType : int8_t TensorType_RESOURCE = 13, TensorType_VARIANT = 14, TensorType_UINT32 = 15, + TensorType_UINT16 = 16, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_UINT32 + TensorType_MAX = TensorType_UINT16 }; -inline const TensorType (&EnumValuesTensorType())[16] +inline const TensorType (&EnumValuesTensorType())[17] { static const TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, TensorType_INT32, TensorType_UINT8, TensorType_INT64, TensorType_STRING, TensorType_BOOL, TensorType_INT16, TensorType_COMPLEX64, TensorType_INT8, TensorType_FLOAT64, TensorType_COMPLEX128, - TensorType_UINT64, TensorType_RESOURCE, TensorType_VARIANT, TensorType_UINT32}; + TensorType_UINT64, TensorType_RESOURCE, TensorType_VARIANT, TensorType_UINT32, + TensorType_UINT16}; return values; } inline const char *const *EnumNamesTensorType() { - static const char *const names[17] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", "INT64", + static const char *const names[18] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", "INT64", "STRING", "BOOL", "INT16", "COMPLEX64", "INT8", "FLOAT64", "COMPLEX128", "UINT64", "RESOURCE", "VARIANT", - "UINT32", nullptr}; + "UINT32", "UINT16", nullptr}; return names; } inline const char *EnumNameTensorType(TensorType e) { - if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT32)) + if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT16)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesTensorType()[index]; @@ -745,11 +767,21 @@ enum BuiltinOperator : int32_t BuiltinOperator_ASSIGN_VARIABLE = 144, BuiltinOperator_BROADCAST_ARGS = 145, BuiltinOperator_RANDOM_STANDARD_NORMAL = 146, + BuiltinOperator_BUCKETIZE = 147, + BuiltinOperator_RANDOM_UNIFORM = 148, + BuiltinOperator_MULTINOMIAL = 149, + BuiltinOperator_GELU = 150, + BuiltinOperator_DYNAMIC_UPDATE_SLICE = 151, + BuiltinOperator_RELU_0_TO_1 = 152, + BuiltinOperator_UNSORTED_SEGMENT_PROD = 153, + BuiltinOperator_UNSORTED_SEGMENT_MAX = 154, + BuiltinOperator_UNSORTED_SEGMENT_SUM = 155, + BuiltinOperator_ATAN2 = 156, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_RANDOM_STANDARD_NORMAL + BuiltinOperator_MAX = BuiltinOperator_ATAN2 }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[147] +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[157] { static const BuiltinOperator values[] = {BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -897,13 +929,23 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[147] BuiltinOperator_READ_VARIABLE, BuiltinOperator_ASSIGN_VARIABLE, BuiltinOperator_BROADCAST_ARGS, - BuiltinOperator_RANDOM_STANDARD_NORMAL}; + BuiltinOperator_RANDOM_STANDARD_NORMAL, + 
BuiltinOperator_BUCKETIZE, + BuiltinOperator_RANDOM_UNIFORM, + BuiltinOperator_MULTINOMIAL, + BuiltinOperator_GELU, + BuiltinOperator_DYNAMIC_UPDATE_SLICE, + BuiltinOperator_RELU_0_TO_1, + BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOperator_UNSORTED_SEGMENT_MAX, + BuiltinOperator_UNSORTED_SEGMENT_SUM, + BuiltinOperator_ATAN2}; return values; } inline const char *const *EnumNamesBuiltinOperator() { - static const char *const names[148] = {"ADD", + static const char *const names[158] = {"ADD", "AVERAGE_POOL_2D", "CONCATENATION", "CONV_2D", @@ -1050,13 +1092,23 @@ inline const char *const *EnumNamesBuiltinOperator() "ASSIGN_VARIABLE", "BROADCAST_ARGS", "RANDOM_STANDARD_NORMAL", + "BUCKETIZE", + "RANDOM_UNIFORM", + "MULTINOMIAL", + "GELU", + "DYNAMIC_UPDATE_SLICE", + "RELU_0_TO_1", + "UNSORTED_SEGMENT_PROD", + "UNSORTED_SEGMENT_MAX", + "UNSORTED_SEGMENT_SUM", + "ATAN2", nullptr}; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_RANDOM_STANDARD_NORMAL)) + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_ATAN2)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesBuiltinOperator()[index]; @@ -1179,11 +1231,18 @@ enum BuiltinOptions : uint8_t BuiltinOptions_ReadVariableOptions = 112, BuiltinOptions_AssignVariableOptions = 113, BuiltinOptions_RandomOptions = 114, + BuiltinOptions_BucketizeOptions = 115, + BuiltinOptions_GeluOptions = 116, + BuiltinOptions_DynamicUpdateSliceOptions = 117, + BuiltinOptions_UnsortedSegmentProdOptions = 118, + BuiltinOptions_UnsortedSegmentMaxOptions = 119, + BuiltinOptions_UnsortedSegmentSumOptions = 120, + BuiltinOptions_ATan2Options = 121, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_RandomOptions + BuiltinOptions_MAX = BuiltinOptions_ATan2Options }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[115] +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[122] { static const BuiltinOptions values[] = {BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1299,13 +1358,20 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[115] BuiltinOptions_VarHandleOptions, BuiltinOptions_ReadVariableOptions, BuiltinOptions_AssignVariableOptions, - BuiltinOptions_RandomOptions}; + BuiltinOptions_RandomOptions, + BuiltinOptions_BucketizeOptions, + BuiltinOptions_GeluOptions, + BuiltinOptions_DynamicUpdateSliceOptions, + BuiltinOptions_UnsortedSegmentProdOptions, + BuiltinOptions_UnsortedSegmentMaxOptions, + BuiltinOptions_UnsortedSegmentSumOptions, + BuiltinOptions_ATan2Options}; return values; } inline const char *const *EnumNamesBuiltinOptions() { - static const char *const names[116] = {"NONE", + static const char *const names[123] = {"NONE", "Conv2DOptions", "DepthwiseConv2DOptions", "ConcatEmbeddingsOptions", @@ -1420,13 +1486,20 @@ inline const char *const *EnumNamesBuiltinOptions() "ReadVariableOptions", "AssignVariableOptions", "RandomOptions", + "BucketizeOptions", + "GeluOptions", + "DynamicUpdateSliceOptions", + "UnsortedSegmentProdOptions", + "UnsortedSegmentMaxOptions", + "UnsortedSegmentSumOptions", + "ATan2Options", nullptr}; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_RandomOptions)) + if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_ATan2Options)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesBuiltinOptions()[index]; 
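The regenerated lookup tables above stay in lock-step with the enums: every BuiltinOperator and BuiltinOptions value added in this revision gets matching entries in its EnumValues and EnumNames arrays, and the flatbuffers::IsOutRange guard keeps stray values from indexing past the table. A minimal sketch of that behavior (editorial illustration against the regenerated onert_tflite header, not part of the commit):

#include <cassert>
#include <string>
// assumes the regenerated tflite_schema_generated.h shown in this diff

void check_new_builtin_operator_names()
{
  using namespace onert_tflite;

  // New in this regeneration: ATAN2 (= 156) resolves to a real name.
  assert(std::string(EnumNameBuiltinOperator(BuiltinOperator_ATAN2)) == "ATAN2");

  // Anything beyond BuiltinOperator_MAX trips IsOutRange and comes back as
  // an empty string instead of reading past the 158-entry name table.
  auto bogus = static_cast<BuiltinOperator>(999);
  assert(std::string(EnumNameBuiltinOperator(bogus)) == "");
}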
@@ -2007,6 +2080,41 @@ template <> struct BuiltinOptionsTraits<onert_tflite::RandomOptions> static const BuiltinOptions enum_value = BuiltinOptions_RandomOptions; }; +template <> struct BuiltinOptionsTraits<onert_tflite::BucketizeOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_BucketizeOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::GeluOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_GeluOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::DynamicUpdateSliceOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_DynamicUpdateSliceOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentProdOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentProdOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentMaxOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMaxOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentSumOptions> +{ + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentSumOptions; +}; + +template <> struct BuiltinOptionsTraits<onert_tflite::ATan2Options> +{ + static const BuiltinOptions enum_value = BuiltinOptions_ATan2Options; +}; + bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); bool VerifyBuiltinOptionsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, @@ -2917,7 +3025,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table VT_QUANTIZATION = 12, VT_IS_VARIABLE = 14, VT_SPARSITY = 16, - VT_SHAPE_SIGNATURE = 18 + VT_SHAPE_SIGNATURE = 18, + VT_HAS_RANK = 20 }; const flatbuffers::Vector<int32_t> *shape() const { @@ -2945,6 +3054,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE); } + bool has_rank() const { return GetField<uint8_t>(VT_HAS_RANK, 0) != 0; } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -2954,7 +3064,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table verifier.VerifyTable(quantization()) && VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) && VerifyOffset(verifier, VT_SPARSITY) && verifier.VerifyTable(sparsity()) && VerifyOffset(verifier, VT_SHAPE_SIGNATURE) && verifier.VerifyVector(shape_signature()) && - verifier.EndTable(); + VerifyField<uint8_t>(verifier, VT_HAS_RANK) && verifier.EndTable(); } }; @@ -2992,6 +3102,10 @@ struct TensorBuilder { fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature); } + void add_has_rank(bool has_rank) + { + fbb_.AddElement<uint8_t>(Tensor::VT_HAS_RANK, static_cast<uint8_t>(has_rank), 0); + } explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -3010,7 +3124,7 @@ inline flatbuffers::Offset<Tensor> CreateTensor( flatbuffers::Offset<flatbuffers::String> name = 0, flatbuffers::Offset<onert_tflite::QuantizationParameters> quantization = 0, bool is_variable = false, flatbuffers::Offset<onert_tflite::SparsityParameters> sparsity = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0) + flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0, bool has_rank = false) { TensorBuilder builder_(_fbb); builder_.add_shape_signature(shape_signature); @@ -3019,6 +3133,7 @@ inline 
flatbuffers::Offset<Tensor> CreateTensor( builder_.add_name(name); builder_.add_buffer(buffer); builder_.add_shape(shape); + builder_.add_has_rank(has_rank); builder_.add_is_variable(is_variable); builder_.add_type(type); return builder_.Finish(); @@ -3030,13 +3145,13 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect( const char *name = nullptr, flatbuffers::Offset<onert_tflite::QuantizationParameters> quantization = 0, bool is_variable = false, flatbuffers::Offset<onert_tflite::SparsityParameters> sparsity = 0, - const std::vector<int32_t> *shape_signature = nullptr) + const std::vector<int32_t> *shape_signature = nullptr, bool has_rank = false) { auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; auto name__ = name ? _fbb.CreateString(name) : 0; auto shape_signature__ = shape_signature ? _fbb.CreateVector<int32_t>(*shape_signature) : 0; return onert_tflite::CreateTensor(_fbb, shape__, type, buffer, name__, quantization, is_variable, - sparsity, shape_signature__); + sparsity, shape_signature__, has_rank); } struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table @@ -8325,12 +8440,12 @@ struct RandomOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table VT_SEED = 4, VT_SEED2 = 6 }; - int32_t seed() const { return GetField<int32_t>(VT_SEED, 0); } - int32_t seed2() const { return GetField<int32_t>(VT_SEED2, 0); } + int64_t seed() const { return GetField<int64_t>(VT_SEED, 0); } + int64_t seed2() const { return GetField<int64_t>(VT_SEED2, 0); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_SEED) && - VerifyField<int32_t>(verifier, VT_SEED2) && verifier.EndTable(); + return VerifyTableStart(verifier) && VerifyField<int64_t>(verifier, VT_SEED) && + VerifyField<int64_t>(verifier, VT_SEED2) && verifier.EndTable(); } }; @@ -8339,8 +8454,8 @@ struct RandomOptionsBuilder typedef RandomOptions Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_seed(int32_t seed) { fbb_.AddElement<int32_t>(RandomOptions::VT_SEED, seed, 0); } - void add_seed2(int32_t seed2) { fbb_.AddElement<int32_t>(RandomOptions::VT_SEED2, seed2, 0); } + void add_seed(int64_t seed) { fbb_.AddElement<int64_t>(RandomOptions::VT_SEED, seed, 0); } + void add_seed2(int64_t seed2) { fbb_.AddElement<int64_t>(RandomOptions::VT_SEED2, seed2, 0); } explicit RandomOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -8354,7 +8469,7 @@ struct RandomOptionsBuilder }; inline flatbuffers::Offset<RandomOptions> CreateRandomOptions(flatbuffers::FlatBufferBuilder &_fbb, - int32_t seed = 0, int32_t seed2 = 0) + int64_t seed = 0, int64_t seed2 = 0) { RandomOptionsBuilder builder_(_fbb); builder_.add_seed2(seed2); @@ -8362,6 +8477,270 @@ inline flatbuffers::Offset<RandomOptions> CreateRandomOptions(flatbuffers::FlatB return builder_.Finish(); } +struct BucketizeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef BucketizeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_BOUNDARIES = 4 + }; + const flatbuffers::Vector<float> *boundaries() const + { + return GetPointer<const flatbuffers::Vector<float> *>(VT_BOUNDARIES); + } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_BOUNDARIES) && + verifier.VerifyVector(boundaries()) && verifier.EndTable(); + } +}; + +struct BucketizeOptionsBuilder +{ + typedef 
BucketizeOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_boundaries(flatbuffers::Offset<flatbuffers::Vector<float>> boundaries) + { + fbb_.AddOffset(BucketizeOptions::VT_BOUNDARIES, boundaries); + } + explicit BucketizeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<BucketizeOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<BucketizeOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<BucketizeOptions> +CreateBucketizeOptions(flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<float>> boundaries = 0) +{ + BucketizeOptionsBuilder builder_(_fbb); + builder_.add_boundaries(boundaries); + return builder_.Finish(); +} + +inline flatbuffers::Offset<BucketizeOptions> +CreateBucketizeOptionsDirect(flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<float> *boundaries = nullptr) +{ + auto boundaries__ = boundaries ? _fbb.CreateVector<float>(*boundaries) : 0; + return onert_tflite::CreateBucketizeOptions(_fbb, boundaries__); +} + +struct GeluOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef GeluOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_APPROXIMATE = 4 + }; + bool approximate() const { return GetField<uint8_t>(VT_APPROXIMATE, 0) != 0; } + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_APPROXIMATE) && + verifier.EndTable(); + } +}; + +struct GeluOptionsBuilder +{ + typedef GeluOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_approximate(bool approximate) + { + fbb_.AddElement<uint8_t>(GeluOptions::VT_APPROXIMATE, static_cast<uint8_t>(approximate), 0); + } + explicit GeluOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<GeluOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<GeluOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<GeluOptions> CreateGeluOptions(flatbuffers::FlatBufferBuilder &_fbb, + bool approximate = false) +{ + GeluOptionsBuilder builder_(_fbb); + builder_.add_approximate(approximate); + return builder_.Finish(); +} + +struct DynamicUpdateSliceOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef DynamicUpdateSliceOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct DynamicUpdateSliceOptionsBuilder +{ + typedef DynamicUpdateSliceOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit DynamicUpdateSliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<DynamicUpdateSliceOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<DynamicUpdateSliceOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<DynamicUpdateSliceOptions> +CreateDynamicUpdateSliceOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + DynamicUpdateSliceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentProdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef UnsortedSegmentProdOptionsBuilder Builder; + bool 
Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentProdOptionsBuilder +{ + typedef UnsortedSegmentProdOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit UnsortedSegmentProdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<UnsortedSegmentProdOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<UnsortedSegmentProdOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<UnsortedSegmentProdOptions> +CreateUnsortedSegmentProdOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentProdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentMaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef UnsortedSegmentMaxOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentMaxOptionsBuilder +{ + typedef UnsortedSegmentMaxOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit UnsortedSegmentMaxOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<UnsortedSegmentMaxOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<UnsortedSegmentMaxOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<UnsortedSegmentMaxOptions> +CreateUnsortedSegmentMaxOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentMaxOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct UnsortedSegmentSumOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef UnsortedSegmentSumOptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct UnsortedSegmentSumOptionsBuilder +{ + typedef UnsortedSegmentSumOptions Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit UnsortedSegmentSumOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<UnsortedSegmentSumOptions> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<UnsortedSegmentSumOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<UnsortedSegmentSumOptions> +CreateUnsortedSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb) +{ + UnsortedSegmentSumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +struct ATan2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table +{ + typedef ATan2OptionsBuilder Builder; + bool Verify(flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && verifier.EndTable(); + } +}; + +struct ATan2OptionsBuilder +{ + typedef ATan2Options Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ATan2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<ATan2Options> Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ATan2Options>(end); + return o; + } +}; + +inline flatbuffers::Offset<ATan2Options> CreateATan2Options(flatbuffers::FlatBufferBuilder &_fbb) +{ + ATan2OptionsBuilder builder_(_fbb); 
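+ // An empty options table: the builder's constructor already called
+ // StartTable(), so Finish() below just calls EndTable() and returns the
+ // table's offset, typically stored as an Operator's builtin_options value.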
+ return builder_.Finish(); +} + struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef OperatorCodeBuilder Builder; @@ -9173,6 +9552,52 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table ? static_cast<const onert_tflite::RandomOptions *>(builtin_options()) : nullptr; } + const onert_tflite::BucketizeOptions *builtin_options_as_BucketizeOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_BucketizeOptions + ? static_cast<const onert_tflite::BucketizeOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::GeluOptions *builtin_options_as_GeluOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_GeluOptions + ? static_cast<const onert_tflite::GeluOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::DynamicUpdateSliceOptions * + builtin_options_as_DynamicUpdateSliceOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_DynamicUpdateSliceOptions + ? static_cast<const onert_tflite::DynamicUpdateSliceOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UnsortedSegmentProdOptions * + builtin_options_as_UnsortedSegmentProdOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentProdOptions + ? static_cast<const onert_tflite::UnsortedSegmentProdOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UnsortedSegmentMaxOptions * + builtin_options_as_UnsortedSegmentMaxOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentMaxOptions + ? static_cast<const onert_tflite::UnsortedSegmentMaxOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::UnsortedSegmentSumOptions * + builtin_options_as_UnsortedSegmentSumOptions() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentSumOptions + ? static_cast<const onert_tflite::UnsortedSegmentSumOptions *>(builtin_options()) + : nullptr; + } + const onert_tflite::ATan2Options *builtin_options_as_ATan2Options() const + { + return builtin_options_type() == onert_tflite::BuiltinOptions_ATan2Options + ? 
static_cast<const onert_tflite::ATan2Options *>(builtin_options()) + : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); @@ -10004,6 +10429,55 @@ Operator::builtin_options_as<onert_tflite::RandomOptions>() const return builtin_options_as_RandomOptions(); } +template <> +inline const onert_tflite::BucketizeOptions * +Operator::builtin_options_as<onert_tflite::BucketizeOptions>() const +{ + return builtin_options_as_BucketizeOptions(); +} + +template <> +inline const onert_tflite::GeluOptions * +Operator::builtin_options_as<onert_tflite::GeluOptions>() const +{ + return builtin_options_as_GeluOptions(); +} + +template <> +inline const onert_tflite::DynamicUpdateSliceOptions * +Operator::builtin_options_as<onert_tflite::DynamicUpdateSliceOptions>() const +{ + return builtin_options_as_DynamicUpdateSliceOptions(); +} + +template <> +inline const onert_tflite::UnsortedSegmentProdOptions * +Operator::builtin_options_as<onert_tflite::UnsortedSegmentProdOptions>() const +{ + return builtin_options_as_UnsortedSegmentProdOptions(); +} + +template <> +inline const onert_tflite::UnsortedSegmentMaxOptions * +Operator::builtin_options_as<onert_tflite::UnsortedSegmentMaxOptions>() const +{ + return builtin_options_as_UnsortedSegmentMaxOptions(); +} + +template <> +inline const onert_tflite::UnsortedSegmentSumOptions * +Operator::builtin_options_as<onert_tflite::UnsortedSegmentSumOptions>() const +{ + return builtin_options_as_UnsortedSegmentSumOptions(); +} + +template <> +inline const onert_tflite::ATan2Options * +Operator::builtin_options_as<onert_tflite::ATan2Options>() const +{ + return builtin_options_as_ATan2Options(); +} + struct OperatorBuilder { typedef Operator Table; @@ -11351,6 +11825,41 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const onert_tflite::RandomOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_BucketizeOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::BucketizeOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GeluOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::GeluOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DynamicUpdateSliceOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::DynamicUpdateSliceOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentProdOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentProdOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentMaxOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentMaxOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentSumOptions: + { + auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentSumOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ATan2Options: + { + auto ptr = reinterpret_cast<const onert_tflite::ATan2Options *>(obj); + return verifier.VerifyTable(ptr); + } default: return true; } diff --git a/runtime/onert/frontend/tflite/tflite_schema.fbs b/runtime/onert/frontend/tflite/tflite_schema.fbs index 9bffb4f3c..f7997528e 100644 --- a/runtime/onert/frontend/tflite/tflite_schema.fbs +++ b/runtime/onert/frontend/tflite/tflite_schema.fbs @@ -18,6 +18,10 @@ // Version 1: Add subgraphs to schema. 
// Version 2: Rename operators to conform to NN API. // Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. +// Version 3a: Add new builtin op code field. Has backward compatibility with +// version 3. +// Version 3b: Rename fields in SignatureDef. Has backward compatibility with +// version 3 and 3a. // Change namespace to onert_tflite namespace onert_tflite; @@ -43,6 +47,15 @@ enum TensorType : byte { COMPLEX64 = 8, INT8 = 9, FLOAT64 = 10, + COMPLEX128 = 11, + UINT64 = 12, + // Experimental: Resource and variant types are experimental, that are subject + // to change. Do not implement custom kernels using resource & variant types + // now. + RESOURCE = 13, + VARIANT = 14, + UINT32 = 15, + UINT16 = 16 } // Custom quantization parameters for experimenting with new quantization @@ -209,14 +222,18 @@ table Tensor { // Encodes `shape` with unknown dimensions. Unknown dimensions are // represented with -1. shape_signature:[int]; // Optional. + + // If false, the rank or the number of tensor dimensions is unknown. + // If false, "shape" must be []. + has_rank: bool = false; } // A list of builtin operators. Builtin operators are slightly faster than custom // ones, but not by much. Moreover, while custom operators accept an opaque // object containing configuration parameters, builtins have a predetermined // set of acceptable options. - -enum BuiltinOperator : byte { +// LINT.IfChange +enum BuiltinOperator : int32 { ADD = 0, AVERAGE_POOL_2D = 1, CONCATENATION = 2, @@ -249,7 +266,6 @@ enum BuiltinOperator : byte { SPACE_TO_DEPTH = 26, SVDF = 27, TANH = 28, - // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS CONCAT_EMBEDDINGS = 29, SKIP_GRAM = 30, CALL = 31, @@ -350,9 +366,39 @@ enum BuiltinOperator : byte { SELECT_V2 = 123, DENSIFY = 124, SEGMENT_SUM = 125, - BATCH_MATMUL = 126 -} - + BATCH_MATMUL = 126, + PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + CUMSUM = 128, + CALL_ONCE = 129, + BROADCAST_TO = 130, + RFFT2D = 131, + CONV_3D = 132, + IMAG=133, + REAL=134, + COMPLEX_ABS=135, + HASHTABLE = 136, + HASHTABLE_FIND = 137, + HASHTABLE_IMPORT = 138, + HASHTABLE_SIZE = 139, + REDUCE_ALL = 140, + CONV_3D_TRANSPOSE = 141, + VAR_HANDLE = 142, + READ_VARIABLE = 143, + ASSIGN_VARIABLE = 144, + BROADCAST_ARGS = 145, + RANDOM_STANDARD_NORMAL = 146, + BUCKETIZE = 147, + RANDOM_UNIFORM = 148, + MULTINOMIAL = 149, + GELU = 150, + DYNAMIC_UPDATE_SLICE = 151, + RELU_0_TO_1 = 152, + UNSORTED_SEGMENT_PROD = 153, + UNSORTED_SEGMENT_MAX = 154, + UNSORTED_SEGMENT_SUM = 155, + ATAN2 = 156 +} +// LINT.ThenChange(nnapi_linter/linter.proto) // Options for the builtin operators. 
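// Note: not every operator added above gets its own entry in this union.
// RANDOM_UNIFORM and MULTINOMIAL reuse RandomOptions, and RELU_0_TO_1 takes
// no parameters at all, so it is encoded with BuiltinOptions_NONE.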
union BuiltinOptions { @@ -456,11 +502,34 @@ union BuiltinOptions { SelectV2Options, DensifyOptions, SegmentSumOptions, - BatchMatMulOptions -} - + BatchMatMulOptions, + CumsumOptions, + CallOnceOptions, + BroadcastToOptions, + Rfft2dOptions, + Conv3DOptions, + HashtableOptions, + HashtableFindOptions, + HashtableImportOptions, + HashtableSizeOptions, + VarHandleOptions, + ReadVariableOptions, + AssignVariableOptions, + RandomOptions, + BucketizeOptions, + GeluOptions, + DynamicUpdateSliceOptions, + UnsortedSegmentProdOptions, + UnsortedSegmentMaxOptions, + UnsortedSegmentSumOptions, + ATan2Options +} + +// LINT.IfChange enum Padding : byte { SAME, VALID } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) +// LINT.IfChange enum ActivationFunctionType : byte { NONE = 0, RELU = 1, @@ -469,6 +538,7 @@ enum ActivationFunctionType : byte { TANH = 4, SIGN_BIT = 5, } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) table Conv2DOptions { padding:Padding; @@ -479,6 +549,18 @@ table Conv2DOptions { dilation_h_factor:int = 1; } +// Options for both Conv3D and Conv3DTranspose. +table Conv3DOptions { + padding:Padding; + stride_d:int; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_d_factor:int = 1; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + table Pool2DOptions { padding:Padding; stride_w:int; @@ -548,10 +630,12 @@ table BidirectionalSequenceRNNOptions { asymmetric_quantize_inputs:bool; } +// LINT.IfChange enum FullyConnectedOptionsWeightsFormat: byte { DEFAULT = 0, SHUFFLED4x16INT8 = 1, } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) // An implementation of TensorFlow fully_connected (a.k.a Dense) layer. table FullyConnectedOptions { @@ -584,6 +668,8 @@ table ConcatenationOptions { table AddOptions { fused_activation_function:ActivationFunctionType; + // Parameters supported by version 3. + pot_scale_int16:bool = true; } table MulOptions { @@ -591,6 +677,7 @@ table MulOptions { } table L2NormOptions { + // This field is currently ignored in the L2 Norm Op. fused_activation_function:ActivationFunctionType; } @@ -601,12 +688,14 @@ table LocalResponseNormalizationOptions { beta:float; } +// LINT.IfChange enum LSTMKernelType : byte { // Full LSTM kernel which supports peephole and projection. FULL = 0, // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. BASIC = 1, } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell table LSTMOptions { @@ -664,6 +753,7 @@ table ResizeBilinearOptions { table ResizeNearestNeighborOptions { align_corners: bool; + half_pixel_centers: bool; } // A call operation options @@ -704,6 +794,8 @@ table DepthToSpaceOptions { table SubOptions { fused_activation_function:ActivationFunctionType; + // Parameters supported by version 5 + pot_scale_int16:bool = true; } table DivOptions { @@ -725,6 +817,8 @@ table EmbeddingLookupSparseOptions { table GatherOptions { axis: int; + // Parameters for Gather version 5 or above. + batch_dims: int = 0; } table TransposeOptions { @@ -901,12 +995,14 @@ table LeakyReluOptions { table SquaredDifferenceOptions { } +// LINT.IfChange enum MirrorPadMode : byte { // Doesn't include borders. REFLECT = 0, // Includes borders. 
  SYMMETRIC = 1,
}
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)

table MirrorPadOptions {
  mode:MirrorPadMode;
@@ -947,6 +1043,10 @@ table IfOptions {
  else_subgraph_index:int;
}

+table CallOnceOptions {
+  init_subgraph_index:int;
+}
+
table WhileOptions {
  cond_subgraph_index:int;
  body_subgraph_index:int;
@@ -971,19 +1071,100 @@ table SegmentSumOptions {
}

table BatchMatMulOptions {
-  adjoint_lhs:bool;
-  adjoint_rhs:bool;
+  adj_x:bool;
+  adj_y:bool;
+  // Parameters for BatchMatMul version 4 or above.
+  // If set to true, then the weights-only op will use asymmetric quantization
+  // for inputs.
+  asymmetric_quantize_inputs: bool;
+}
+
+table CumsumOptions {
+  exclusive:bool;
+  reverse:bool;
+}
+
+table BroadcastToOptions {
+}
+
+table Rfft2dOptions {
+}
+
+table HashtableOptions {
+  // The identity of hash tables. This identity will be used across different
+  // subgraphs in the same interpreter instance.
+  table_id:int;
+  key_dtype:TensorType;
+  value_dtype:TensorType;
+}
+
+table HashtableFindOptions {
+}
+
+table HashtableImportOptions {
+}
+
+table HashtableSizeOptions {
+}
+
+table VarHandleOptions {
+  container:string;
+  shared_name:string;
}

+table ReadVariableOptions {
+}
+
+table AssignVariableOptions {
+}
+
+table RandomOptions {
+  seed: long;
+  seed2: long;
+}
+
+table BucketizeOptions {
+  boundaries: [float];  // The bucket boundaries.
+}
+
+table GeluOptions {
+  approximate: bool;
+}
+
+table DynamicUpdateSliceOptions {
+}
+
+table UnsortedSegmentProdOptions {
+}
+
+table UnsortedSegmentMaxOptions {
+}
+
+table UnsortedSegmentSumOptions {
+}
+
+table ATan2Options {
+}
+

// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
-  builtin_code:BuiltinOperator;
+  // This field is for backward compatibility. This field will be used when
+  // the value of the extended builtin_code field is less than
+  // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
+  deprecated_builtin_code:byte;
  custom_code:string;

  // The version of the operator. The version needs to be bumped whenever new
  // parameters are introduced into an op.
  version:int = 1;
+
+  // This field is introduced for resolving the op builtin code shortage
+  // problem (the original BuiltinOperator enum field was represented as a
+  // byte). This field will be used when the value of the extended
+  // builtin_code field is greater than
+  // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
+  builtin_code:BuiltinOperator;
}

enum CustomOptionsFormat : byte {
@@ -1062,6 +1243,35 @@ table Metadata {
  buffer:uint;
}

+// Map from an alias name of a tensor to its tensor index in the graph.
+// This is used in SignatureDef.
+table TensorMap {
+  // Represents the alias to use for this tensor.
+  name:string;
+
+  // The actual tensor index in the primary graph that 'name' corresponds to.
+  tensor_index:uint;
+}
+
+// This corresponds to SignatureDef in TensorFlow SavedModel.
+// The SignatureDef will be part of the SavedModel provided for conversion.
+table SignatureDef {
+  // Named inputs for this signature.
+  inputs:[TensorMap];
+
+  // Named outputs for this signature.
+  outputs:[TensorMap];
+
+  // Key value which was in the TensorFlow SavedModel SignatureDef map.
+  signature_key:string;
+
+  // Model tag, deprecated.
+  deprecated_tag:string (deprecated);
+
+  // Index of the subgraph that corresponds to the exported method.
+  subgraph_index:uint;
+}
+
table Model {
  // Version of the schema.
  version:uint;
@@ -1090,6 +1300,9 @@ table Model {

  // Metadata about the model.
  metadata:[Metadata];
+
+  // Optional SignatureDefs for the model.
+  signature_defs:[SignatureDef];
}

root_type Model;
diff --git a/runtime/onert/odc/CMakeLists.txt b/runtime/onert/odc/CMakeLists.txt
new file mode 100644
index 000000000..e48878dc3
--- /dev/null
+++ b/runtime/onert/odc/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Luci library is not supported on cross build
+if(CMAKE_CROSSCOMPILING)
+  return()
+endif()
+
+nnfw_find_package(Luci QUIET)
+if(NOT Luci_FOUND)
+  message(STATUS "Luci not found. Skip onert_odc")
+  return()
+endif()
+
+file(GLOB_RECURSE SOURCES "*.cc")
+file(GLOB_RECURSE TESTS "*.test.cc")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(onert_odc SHARED ${SOURCES})
+target_link_libraries(onert_odc PRIVATE onert_core luci::import luci::export luci::pass luci::loco)
+target_link_libraries(onert_odc PRIVATE nnfw_common)
+target_link_libraries(onert_odc PRIVATE nnfw_coverage)
+
+install(TARGETS onert_odc LIBRARY DESTINATION lib/nnfw/odc)
+
+if(NOT ENABLE_TEST)
+  return()
+endif(NOT ENABLE_TEST)
+
+# Unit Tests
+set(TEST_ONERT_ODC test_onert_odc)
+
+add_executable(${TEST_ONERT_ODC} ${TESTS})
+
+target_link_libraries(${TEST_ONERT_ODC} onert_odc)
+# Requires linking nnfw_coverage: check header coverage
+target_link_libraries(${TEST_ONERT_ODC} nnfw_coverage)
+target_link_libraries(${TEST_ONERT_ODC} gtest gtest_main dl ${LIB_PTHREAD})
+target_include_directories(${TEST_ONERT_ODC} PRIVATE $<TARGET_PROPERTY:onert_odc,INCLUDE_DIRECTORIES>)
+
+add_test(${TEST_ONERT_ODC} ${TEST_ONERT_ODC})
+install(TARGETS ${TEST_ONERT_ODC} DESTINATION unittest)
diff --git a/runtime/onert/odc/Quantizer.cc b/runtime/onert/odc/Quantizer.cc
new file mode 100644
index 000000000..b8aec97ce
--- /dev/null
+++ b/runtime/onert/odc/Quantizer.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Quantizer.h"
+
+#include <luci/ImporterEx.h>
+#include <luci/CircleQuantizer.h>
+#include <luci/CircleExporter.h>
+#include <luci/CircleFileExpContract.h>
+
+#include <iostream>
+
+extern "C" onert::odc::IQuantizer *create_quantizer() { return new onert::odc::Quantizer(); }
+extern "C" void destroy_quantizer(onert::odc::IQuantizer *quantizer) { delete quantizer; }
+
+namespace onert
+{
+namespace odc
+{
+
+int Quantizer::quantize(const char *in, const char *out, bool is_q16)
+{
+  // Load model from the file
+  luci::ImporterEx importerex;
+  auto module = importerex.importVerifyModule(std::string(in));
+  if (module.get() == nullptr)
+    return 1;
+
+  luci::CircleQuantizer quantizer;
+  auto options = quantizer.options();
+  {
+    options->enable(luci::CircleQuantizer::Options::Algorithm::QuantizeWeights);
+
+    using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
+    options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+    options->param(AlgorithmParameters::Quantize_output_model_dtype, is_q16 ? "int16" : "int8");
+    options->param(AlgorithmParameters::Quantize_granularity, "channel");
+  }
+
+  for (size_t idx = 0; idx < module->size(); ++idx)
+  {
+    auto graph = module->graph(idx);
+
+    // quantize the graph
+    quantizer.quantize(graph);
+
+    // Skip validate
+    // TODO Validate if needed
+#if 0
+    if (!luci::validate(graph))
+    {
+      std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
+      return 1;
+    }
+#endif
+  }
+
+  // Export to output Circle file
+  luci::CircleExporter exporter;
+  luci::CircleFileExpContract contract(module.get(), std::string(out));
+
+  if (!exporter.invoke(&contract))
+    return 1;
+
+  // Return 0 when luci::CircleQuantizer::Options::Algorithm::QuantizeWeights is ready
+  return 0;
+}
+
+} // namespace odc
+} // namespace onert
diff --git a/runtime/onert/odc/Quantizer.h b/runtime/onert/odc/Quantizer.h
new file mode 100644
index 000000000..8a03f59d5
--- /dev/null
+++ b/runtime/onert/odc/Quantizer.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_ODC_QUANTIZE_H__
+#define __ONERT_ODC_QUANTIZE_H__
+
+#include "odc/IQuantizer.h"
+
+namespace onert
+{
+namespace odc
+{
+
+class Quantizer : public IQuantizer
+{
+public:
+  Quantizer() = default;
+  ~Quantizer() = default;
+
+  int quantize(const char *in, const char *out, bool is_q16);
+};
+
+} // namespace odc
+} // namespace onert
+
+#endif // __ONERT_ODC_QUANTIZE_H__
diff --git a/runtime/onert/odc/Quantizer.test.cc b/runtime/onert/odc/Quantizer.test.cc
new file mode 100644
index 000000000..22baed576
--- /dev/null
+++ b/runtime/onert/odc/Quantizer.test.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Quantizer.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::odc;
+
+// Test model input path is not set
+TEST(odc_Quantizer, neg_model_input_path)
+{
+  Quantizer quantizer;
+  ASSERT_THROW(quantizer.quantize(nullptr, "out", false), std::logic_error);
+}
+
+// Test model output path is not set
+TEST(odc_Quantizer, neg_model_output_path)
+{
+  Quantizer quantizer;
+  ASSERT_NE(quantizer.quantize("in", nullptr, false), 0);
+}
+
+// Test invalid model input path
+TEST(odc_Quantizer, neg_invalid_model_input_path)
+{
+  Quantizer quantizer;
+  ASSERT_NE(quantizer.quantize("invalid_model_input_path.circle", "out", false), 0);
+}
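
Taken together, the odc pieces form a small plugin: the CMake file installs the shared library under lib/nnfw/odc, Quantizer.cc exports plain C factory symbols, and IQuantizer is the interface the runtime sees. A sketch of a driver under two assumptions — that IQuantizer declares the virtual quantize(in, out, is_q16) entry point that Quantizer implements, and that the library name and model paths below (all placeholders) resolve on the target:

    // Hypothetical driver (link with -ldl); the library name and model file
    // names are placeholders, not values from the commit.
    #include "odc/IQuantizer.h"

    #include <dlfcn.h>

    int main()
    {
      void *handle = dlopen("libonert_odc.so", RTLD_NOW);
      if (handle == nullptr)
        return 1;

      // Resolve the plain C factory symbols exported by Quantizer.cc.
      auto create = reinterpret_cast<onert::odc::IQuantizer *(*)()>(dlsym(handle, "create_quantizer"));
      auto destroy = reinterpret_cast<void (*)(onert::odc::IQuantizer *)>(dlsym(handle, "destroy_quantizer"));
      if (create == nullptr || destroy == nullptr)
        return 1;

      onert::odc::IQuantizer *q = create();
      // false requests int8 weights; true would request int16 (see Quantizer::quantize).
      const int ret = q->quantize("model.circle", "model.q8.circle", /*is_q16=*/false);
      destroy(q);
      dlclose(handle);
      return ret; // 0 on success, 1 on import/export failure
    }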