diff options
Diffstat (limited to 'runtimes/neurun/frontend/api/wrapper/nnfw_api.cc')
-rw-r--r-- | runtimes/neurun/frontend/api/wrapper/nnfw_api.cc | 60 |
1 files changed, 14 insertions, 46 deletions
diff --git a/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc b/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc index eb8aef039..80a768e4e 100644 --- a/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc +++ b/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc @@ -275,38 +275,16 @@ NNFW_STATUS nnfw_session::register_custom_operation(const std::string &id, return NNFW_STATUS_NO_ERROR; } -static std::string get_backend_string(NNFW_BACKEND backend) -{ - static std::unordered_map<NNFW_BACKEND, std::string> backend_map = { - {NNFW_BACKEND_ACL_CL, "acl_cl"}, - {NNFW_BACKEND_ACL_NEON, "acl_neon"}, - {NNFW_BACKEND_CPU, "cpu"}, - {NNFW_BACKEND_SRCN, "srcn"}, - }; - - auto b = backend_map.find(backend); - - if (b == backend_map.end()) - { - // this return value is handled by a caller to return error code - return std::string(""); - } - else - { - return b->second; - } -} - -static std::string get_op_backend_string(NNFW_OP op) +static std::string get_op_backend_string(std::string op) { // TODO: Provide complete set of operations - static std::unordered_map<NNFW_OP, std::string> operation_map = { - {NNFW_OP_TransposeConvNode, "OP_BACKEND_TransposeConvNode"}, - {NNFW_OP_Conv2DNode, "OP_BACKEND_Conv2DNode"}, - {NNFW_OP_DepthwiseConv2DNode, "OP_BACKEND_DepthwiseConv2DNode"}, - {NNFW_OP_MeanNode, "OP_BACKEND_MeanNode"}, - {NNFW_OP_AvgPool2DNode, "OP_BACKEND_AvgPool2DNode"}, - {NNFW_OP_MaxPool2DNode, "OP_BACKEND_MaxPool2DNode"}, + static std::unordered_map<std::string, std::string> operation_map = { + {"TRANSPOSE_CONV", "OP_BACKEND_TransposeConvNode"}, + {"CONV_2D", "OP_BACKEND_Conv2DNode"}, + {"DEPTHWISE_CONV_2D", "OP_BACKEND_DepthwiseConv2DNode"}, + {"MEAN", "OP_BACKEND_MeanNode"}, + {"AVERAGE_POOL_2D", "OP_BACKEND_AvgPool2DNode"}, + {"MAX_POOL_2D", "OP_BACKEND_MaxPool2DNode"}, }; auto n = operation_map.find(op); @@ -322,19 +300,11 @@ static std::string get_op_backend_string(NNFW_OP op) } } -NNFW_STATUS nnfw_session::set_default_backend(NNFW_BACKEND backend) +NNFW_STATUS nnfw_session::set_default_backend(std::string backend) { try { - std::string bs = get_backend_string(backend); - if (bs.empty()) - { - return NNFW_STATUS_ERROR; - } - else - { - _source->set("OP_BACKEND_ALLOPS", bs); - } + _source->set("OP_BACKEND_ALLOPS", backend); } catch (...) { @@ -344,21 +314,19 @@ NNFW_STATUS nnfw_session::set_default_backend(NNFW_BACKEND backend) return NNFW_STATUS_NO_ERROR; } -NNFW_STATUS nnfw_session::set_op_backend(NNFW_OP op, NNFW_BACKEND backend) +NNFW_STATUS nnfw_session::set_op_backend(std::string op, std::string backend) { try { - std::string key, value; - value = get_backend_string(backend); - key = get_op_backend_string(op); + auto key = get_op_backend_string(op); - if (key.empty() || value.empty()) + if (key.empty()) { return NNFW_STATUS_ERROR; } - _source->set(key, value); + _source->set(key, backend); } catch (...) { |