summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@samsung.com>2019-09-05 07:28:40 +0300
committer이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>2019-09-05 13:28:40 +0900
commit88d2864284b28bd44c8419a99337f5377040f379 (patch)
treedd41ca410d624eee3c63940f5dad40f67cf3be8c
parent562b8ee874fae59795d3ce48919217f5d5202378 (diff)
downloadnnfw-88d2864284b28bd44c8419a99337f5377040f379.tar.gz
nnfw-88d2864284b28bd44c8419a99337f5377040f379.tar.bz2
nnfw-88d2864284b28bd44c8419a99337f5377040f379.zip
[custom op] Implement custom op registration using public API (#7181)
* [custom op] Implement custom op registration using public API Adds public custom operation registration method Signed-off-by: Vladimir Plazun <v.plazun@samsung.com> * format fix
-rw-r--r--runtimes/include/nnfw_dev.h3
-rw-r--r--runtimes/neurun/frontend/api/nnfw_dev.cc13
-rw-r--r--runtimes/neurun/frontend/api/wrapper/nnfw_api.cc12
-rw-r--r--runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp5
4 files changed, 32 insertions, 1 deletions
diff --git a/runtimes/include/nnfw_dev.h b/runtimes/include/nnfw_dev.h
index b4e62fa3b..5886377d8 100644
--- a/runtimes/include/nnfw_dev.h
+++ b/runtimes/include/nnfw_dev.h
@@ -58,4 +58,7 @@ typedef struct
nnfw_custom_eval eval_function;
} custom_kernel_registration_info;
+NNFW_STATUS nnfw_register_custom_op_info(nnfw_session *session, const char *id,
+ custom_kernel_registration_info *info);
+
#endif // __NNFW_DEV_H__
diff --git a/runtimes/neurun/frontend/api/nnfw_dev.cc b/runtimes/neurun/frontend/api/nnfw_dev.cc
index 3c8178b43..642cb1fb3 100644
--- a/runtimes/neurun/frontend/api/nnfw_dev.cc
+++ b/runtimes/neurun/frontend/api/nnfw_dev.cc
@@ -163,3 +163,16 @@ NNFW_STATUS nnfw_output_tensorinfo(nnfw_session *session, uint32_t index,
{
return session->output_tensorinfo(index, tensor_info);
}
+
+/*
+ * Register custom operation
+ * @param session session to register this operation
+ * @param id operation id
+ * @param info registration info ( eval function, etc. )
+ * @return NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_register_custom_op_info(nnfw_session *session, const char *id,
+ custom_kernel_registration_info *info)
+{
+ return session->register_custom_operation(id, info->eval_function);
+}
diff --git a/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc b/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc
index 49474fd4c..6b8a53dda 100644
--- a/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc
+++ b/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc
@@ -23,7 +23,9 @@
#include <limits.h>
#include <stdint.h>
-nnfw_session::nnfw_session() : _graph{nullptr}, _execution{nullptr}
+nnfw_session::nnfw_session()
+ : _graph{nullptr}, _execution{nullptr},
+ _kernel_registry{new neurun::backend::custom::KernelRegistry}
{
// DO NOTHING
}
@@ -53,6 +55,7 @@ NNFW_STATUS nnfw_session::load_model_from_file(const char *package_dir)
auto model = nnfw::cpp14::make_unique<neurun::model::Model>();
_graph = std::make_shared<neurun::graph::Graph>(std::move(model));
+ _graph->bindKernelRegistry(_kernel_registry);
tflite_loader::Loader loader(*_graph);
auto model_file_path = package_dir + std::string("/") + models[0].asString(); // first model
loader.loadFromFile(model_file_path.c_str());
@@ -255,3 +258,10 @@ NNFW_STATUS nnfw_session::output_tensorinfo(uint32_t index, nnfw_tensorinfo *ti)
}
return NNFW_STATUS_NO_ERROR;
}
+
+NNFW_STATUS nnfw_session::register_custom_operation(const std::string &id,
+ nnfw_custom_eval eval_func)
+{
+ _kernel_registry->registerKernel(id, eval_func);
+ return NNFW_STATUS_NO_ERROR;
+}
diff --git a/runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp b/runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp
index 8efdd77b9..f616f8100 100644
--- a/runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp
+++ b/runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp
@@ -18,9 +18,11 @@
#define __API_NNFW_INTERNAL_HPP__
#include "nnfw.h"
+#include "nnfw_dev.h"
#include "compiler/Compiler.h"
#include "exec/Execution.h"
#include "graph/Graph.h"
+#include "backend/CustomKernelRegistry.h"
struct nnfw_session
{
@@ -40,9 +42,12 @@ public:
NNFW_STATUS input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
NNFW_STATUS output_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+ NNFW_STATUS register_custom_operation(const std::string &id, nnfw_custom_eval eval_func);
+
private:
std::shared_ptr<neurun::graph::Graph> _graph;
std::shared_ptr<neurun::exec::Execution> _execution;
+ std::shared_ptr<neurun::backend::custom::KernelRegistry> _kernel_registry;
};
#endif // __API_NNFW_INTERNAL_HPP__