summaryrefslogtreecommitdiff
path: root/runtime/onert/api/src/nnfw_api_internal.h
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/onert/api/src/nnfw_api_internal.h')
-rw-r--r--runtime/onert/api/src/nnfw_api_internal.h46
1 files changed, 39 insertions, 7 deletions
diff --git a/runtime/onert/api/src/nnfw_api_internal.h b/runtime/onert/api/src/nnfw_api_internal.h
index 8e2c2fba6..62791765e 100644
--- a/runtime/onert/api/src/nnfw_api_internal.h
+++ b/runtime/onert/api/src/nnfw_api_internal.h
@@ -39,7 +39,7 @@ class Execution;
} // namespace exec
namespace ir
{
-class Graph;
+struct IGraph;
class Model;
class NNPkg;
} // namespace ir
@@ -48,6 +48,10 @@ namespace compiler
struct CompilerArtifact;
class CompilerOptions;
} // namespace compiler
+namespace odc
+{
+class QuantizeManager;
+} // namespace odc
} // namespace onert
struct nnfw_session
@@ -90,11 +94,13 @@ private:
*/
enum class State
{
- INITIALIZED, //< Session is initialized and nothing has done to it
- MODEL_LOADED, //< Model is loaded
- PREPARED, //< Prepared(compiled) for execution
- RUNNING, //< Execution is in progress (only for asynchronous execution)
- FINISHED_RUN //< Executed at least once
+ INITIALIZED, //< Session is initialized and nothing has done to it
+ MODEL_LOADED, //< Model is loaded
+ PREPARED, //< Prepared(compiled) for execution
+ RUNNING, //< Execution is in progress (only for asynchronous execution)
+ FINISHED_RUN, //< Executed at least once
+ PREPARED_TRAINING, //< Prepared for training
+ FINISHED_TRAINING //< Trained at least once
};
public:
@@ -160,8 +166,25 @@ public:
*/
NNFW_STATUS set_backends_per_operation(const char *backend_settings);
+#ifdef ONERT_TRAIN
+ NNFW_STATUS train_prepare(const nnfw_train_info *info);
+ NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+ NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+ NNFW_STATUS train_set_input(uint32_t index, const void *input,
+ const nnfw_tensorinfo *input_tensorinfo);
+ NNFW_STATUS train_set_expected(uint32_t index, const void *expected,
+ const nnfw_tensorinfo *expected_tensorinfo);
+ NNFW_STATUS train_run(bool update_weights);
+ NNFW_STATUS train_get_loss(uint32_t index, float *loss);
+ NNFW_STATUS train_export_circle(const char *path);
+#endif // ONERT_TRAIN
+
+ NNFW_STATUS set_quantization_type(NNFW_QUANTIZE_TYPE qtype);
+ NNFW_STATUS set_quantized_model_path(const char *path);
+ NNFW_STATUS quantize();
+
private:
- const onert::ir::Graph *primary_subgraph();
+ const onert::ir::IGraph *primary_subgraph();
uint32_t getInputSize();
uint32_t getOutputSize();
@@ -171,6 +194,11 @@ private:
bool isStateRunning();
bool isStateFinishedRun();
bool isStatePreparedOrFinishedRun();
+#ifdef ONERT_TRAIN
+ bool isStatePreparedTraining();
+ bool isStateFinishedTraining();
+ bool isStatePreparedOrFinishedTraining();
+#endif // ONERT_TRAIN
private:
State _state{State::INITIALIZED};
@@ -180,6 +208,10 @@ private:
std::unique_ptr<onert::exec::Execution> _execution;
std::shared_ptr<onert::api::CustomKernelRegistry> _kernel_registry;
std::vector<std::thread> _threads;
+#ifdef ONERT_TRAIN
+ uint32_t _training_step{0};
+#endif // ONERT_TRAIN
+ std::unique_ptr<onert::odc::QuantizeManager> _quant_manager;
};
#endif // __API_NNFW_API_INTERNAL_H__