diff options
Diffstat (limited to 'tests/nnfw_api')
-rw-r--r-- | tests/nnfw_api/src/CircleGen.cc | 8 | ||||
-rw-r--r-- | tests/nnfw_api/src/CircleGen.h | 1 | ||||
-rw-r--r-- | tests/nnfw_api/src/one_op_tests/If.cc | 76 |
3 files changed, 85 insertions, 0 deletions
diff --git a/tests/nnfw_api/src/CircleGen.cc b/tests/nnfw_api/src/CircleGen.cc index 19cb95f37..6ebd5a945 100644 --- a/tests/nnfw_api/src/CircleGen.cc +++ b/tests/nnfw_api/src/CircleGen.cc @@ -183,6 +183,14 @@ uint32_t CircleGen::addOperatorWhile(const OperatorParams ¶ms, uint32_t cond circle::BuiltinOptions_WhileOptions, options); } +uint32_t CircleGen::addOperatorIf(const OperatorParams ¶ms, uint32_t then_subg, + uint32_t else_subg) +{ + auto options = circle::CreateIfOptions(_fbb, then_subg, else_subg).Union(); + return addOperatorWithOptions(params, circle::BuiltinOperator_IF, + circle::BuiltinOptions_IfOptions, options); +} + // NOTE Please add addOperator functions ABOVE this lie // // % How to add a new addOperatorXXX fuction diff --git a/tests/nnfw_api/src/CircleGen.h b/tests/nnfw_api/src/CircleGen.h index 09ca5a5db..8cb83bce3 100644 --- a/tests/nnfw_api/src/CircleGen.h +++ b/tests/nnfw_api/src/CircleGen.h @@ -108,6 +108,7 @@ public: uint32_t addOperatorRank(const OperatorParams ¶ms); uint32_t addOperatorResizeNearestNeighbor(const OperatorParams ¶ms); uint32_t addOperatorWhile(const OperatorParams ¶ms, uint32_t cond_subg, uint32_t body_subg); + uint32_t addOperatorIf(const OperatorParams ¶ms, uint32_t cond_subg, uint32_t body_subg); // NOTE Please add addOperator functions ABOVE this lie // ===== Add Operator methods end ===== diff --git a/tests/nnfw_api/src/one_op_tests/If.cc b/tests/nnfw_api/src/one_op_tests/If.cc new file mode 100644 index 000000000..2eb1d3420 --- /dev/null +++ b/tests/nnfw_api/src/one_op_tests/If.cc @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "GenModelTest.h" + +#include <memory> + +TEST_F(GenModelTest, OneOp_If) +{ + // The model looks just like the below pseudocode + // + // function model(x) + // { + // if (x < 0.0) + // return -100.0; + // else + // return 100.0; + // } + + CircleGen cgen; + + // constant buffers + std::vector<float> comp_data{0.0}; + uint32_t comp_buf = cgen.addBuffer(comp_data); + std::vector<float> then_data{-100}; + uint32_t then_buf = cgen.addBuffer(then_data); + std::vector<float> else_data{100}; + uint32_t else_buf = cgen.addBuffer(else_data); + + // primary subgraph + { + int x = cgen.addTensor({{1}, circle::TensorType_FLOAT32}); + int comp = cgen.addTensor({{1}, circle::TensorType_FLOAT32, comp_buf}); + int cond = cgen.addTensor({{1}, circle::TensorType_BOOL}); + cgen.addOperatorLess({{x, comp}, {cond}}); + + int ret = cgen.addTensor({{1}, circle::TensorType_FLOAT32}); + cgen.addOperatorIf({{cond}, {ret}}, 1, 2); + + cgen.setInputsAndOutputs({x}, {ret}); + } + + // then subgraph + { + cgen.nextSubgraph(); + int ret = cgen.addTensor({{1}, circle::TensorType_FLOAT32, then_buf}); + cgen.setInputsAndOutputs({}, {ret}); + } + + // else subgraph + { + cgen.nextSubgraph(); + int ret = cgen.addTensor({{1}, circle::TensorType_FLOAT32, else_buf}); + cgen.setInputsAndOutputs({}, {ret}); + } + + _context = std::make_unique<GenModelTestContext>(cgen.finish()); + _context->addTestCase({{{-1.0}}, {{-100.0}}}); + _context->addTestCase({{{1.0}}, {{100.0}}}); + _context->setBackends({"cpu"}); + + SUCCEED(); +} |