diff options
Diffstat (limited to 'compiler/luci/pass/src/PropagateQuantParamPass.test.cpp')
-rw-r--r-- | compiler/luci/pass/src/PropagateQuantParamPass.test.cpp | 118 |
1 files changed, 118 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp new file mode 100644 index 000000000..15adbfc01 --- /dev/null +++ b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/PropagateQuantParamPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale, + const std::vector<int64_t> &zp) +{ + assert(node->quantparam() == nullptr); + + auto quantparam = std::make_unique<luci::CircleQuantParam>(); + quantparam->scale = scale; + quantparam->zerop = zp; + node->quantparam(std::move(quantparam)); +} + +/** + * Simple graph for test + * + * BEFORE + * + * [Conv] (qparam 1) + * | + * [Reshape] (qparam 2) + * + * AFTER + * + * [Conv] (qparam 2) + * | + * [Reshape] (qparam 2) + * + */ +class SimpleGraph +{ +public: + SimpleGraph() + { + input = g.nodes()->create<luci::CircleInput>(); + conv = g.nodes()->create<luci::CircleConv2D>(); + reshape = g.nodes()->create<luci::CircleReshape>(); + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20}); + addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10}); + + conv->input(input); + reshape->tensor(conv); + output->from(reshape); + } + +public: + loco::Graph g; + luci::CircleInput *input; + luci::CircleConv2D *conv; + luci::CircleReshape *reshape; + luci::CircleOutput *output; +}; + +} // namespace + +TEST(PropagateQuantParam, simple) +{ + SimpleGraph g; + + luci::PropagateQuantParamPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[0]); + EXPECT_FLOAT_EQ(0.4, g.conv->quantparam()->scale[1]); + EXPECT_FLOAT_EQ(0.6, g.conv->quantparam()->scale[2]); + EXPECT_EQ(-10, g.conv->quantparam()->zerop[0]); + EXPECT_EQ(0, g.conv->quantparam()->zerop[1]); + EXPECT_EQ(10, g.conv->quantparam()->zerop[2]); +} + +TEST(PropagateQuantParam, wrong_op_NEG) +{ + SimpleGraph g; + g.output->from(g.conv); + g.reshape->drop(); + + luci::PropagateQuantParamPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]); + EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]); + EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]); + EXPECT_EQ(0, g.conv->quantparam()->zerop[0]); + EXPECT_EQ(10, g.conv->quantparam()->zerop[1]); + EXPECT_EQ(20, g.conv->quantparam()->zerop[2]); +} |