diff options
Diffstat (limited to 'runtime/onert/api/src/nnfw_api_internal.h')
-rw-r--r-- | runtime/onert/api/src/nnfw_api_internal.h | 46 |
1 files changed, 39 insertions, 7 deletions
diff --git a/runtime/onert/api/src/nnfw_api_internal.h b/runtime/onert/api/src/nnfw_api_internal.h index 8e2c2fba6..62791765e 100644 --- a/runtime/onert/api/src/nnfw_api_internal.h +++ b/runtime/onert/api/src/nnfw_api_internal.h @@ -39,7 +39,7 @@ class Execution; } // namespace exec namespace ir { -class Graph; +struct IGraph; class Model; class NNPkg; } // namespace ir @@ -48,6 +48,10 @@ namespace compiler struct CompilerArtifact; class CompilerOptions; } // namespace compiler +namespace odc +{ +class QuantizeManager; +} // namespace odc } // namespace onert struct nnfw_session @@ -90,11 +94,13 @@ private: */ enum class State { - INITIALIZED, //< Session is initialized and nothing has done to it - MODEL_LOADED, //< Model is loaded - PREPARED, //< Prepared(compiled) for execution - RUNNING, //< Execution is in progress (only for asynchronous execution) - FINISHED_RUN //< Executed at least once + INITIALIZED, //< Session is initialized and nothing has done to it + MODEL_LOADED, //< Model is loaded + PREPARED, //< Prepared(compiled) for execution + RUNNING, //< Execution is in progress (only for asynchronous execution) + FINISHED_RUN, //< Executed at least once + PREPARED_TRAINING, //< Prepared for training + FINISHED_TRAINING //< Trained at least once }; public: @@ -160,8 +166,25 @@ public: */ NNFW_STATUS set_backends_per_operation(const char *backend_settings); +#ifdef ONERT_TRAIN + NNFW_STATUS train_prepare(const nnfw_train_info *info); + NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti); + NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti); + NNFW_STATUS train_set_input(uint32_t index, const void *input, + const nnfw_tensorinfo *input_tensorinfo); + NNFW_STATUS train_set_expected(uint32_t index, const void *expected, + const nnfw_tensorinfo *expected_tensorinfo); + NNFW_STATUS train_run(bool update_weights); + NNFW_STATUS train_get_loss(uint32_t index, float *loss); + NNFW_STATUS train_export_circle(const char *path); +#endif // ONERT_TRAIN + + NNFW_STATUS set_quantization_type(NNFW_QUANTIZE_TYPE qtype); + NNFW_STATUS set_quantized_model_path(const char *path); + NNFW_STATUS quantize(); + private: - const onert::ir::Graph *primary_subgraph(); + const onert::ir::IGraph *primary_subgraph(); uint32_t getInputSize(); uint32_t getOutputSize(); @@ -171,6 +194,11 @@ private: bool isStateRunning(); bool isStateFinishedRun(); bool isStatePreparedOrFinishedRun(); +#ifdef ONERT_TRAIN + bool isStatePreparedTraining(); + bool isStateFinishedTraining(); + bool isStatePreparedOrFinishedTraining(); +#endif // ONERT_TRAIN private: State _state{State::INITIALIZED}; @@ -180,6 +208,10 @@ private: std::unique_ptr<onert::exec::Execution> _execution; std::shared_ptr<onert::api::CustomKernelRegistry> _kernel_registry; std::vector<std::thread> _threads; +#ifdef ONERT_TRAIN + uint32_t _training_step{0}; +#endif // ONERT_TRAIN + std::unique_ptr<onert::odc::QuantizeManager> _quant_manager; }; #endif // __API_NNFW_API_INTERNAL_H__ |