summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/FoldDensifyPass.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/FoldDensifyPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/FoldDensifyPass.test.cpp158
1 files changed, 158 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/FoldDensifyPass.test.cpp b/compiler/luci/pass/src/FoldDensifyPass.test.cpp
new file mode 100644
index 000000000..2f9736f49
--- /dev/null
+++ b/compiler/luci/pass/src/FoldDensifyPass.test.cpp
@@ -0,0 +1,158 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldDensifyPass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class FoldDensifyPassGraph : public luci::ConstantFoldingAddTestGraph
+{
+public:
+ FoldDensifyPassGraph(std::initializer_list<uint32_t> shape)
+ : luci::ConstantFoldingAddTestGraph(shape, loco::DataType::FLOAT32)
+ {
+ _densify = _g.nodes()->create<luci::CircleDensify>();
+ _x = _g.nodes()->create<luci::CircleConst>();
+
+ _densify->dtype(loco::DataType::FLOAT32);
+ _x->dtype(loco::DataType::FLOAT32);
+
+ _densify->shape(shape);
+ _x->shape(shape);
+
+ _densify->input(_x);
+
+ _densify->name("densify");
+ _x->name("x");
+ }
+
+ loco::Node *createFoldedPattern() override { return _densify; }
+
+public:
+ void fill_const_dense(void)
+ {
+ uint32_t num_elems = 1;
+ for (uint32_t r = 0; r < _x->rank(); ++r)
+ num_elems *= _x->dim(r).value();
+
+ _x->size<loco::DataType::FLOAT32>(num_elems);
+ for (uint32_t i = 0; i < num_elems; i++)
+ _x->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i + 1);
+ }
+
+ void fill_const_sparse(void)
+ {
+ // fill 4x4 of
+ // [[1 0 0 0]
+ // [0 2 0 0]
+ // [0 0 3 0]
+ // [0 0 0 4]]
+
+ // values of 1.0, 2.0, 3.0, 4.0
+ uint32_t udata[] = {0x3f800000, 0x40000000, 0x40400000, 0x40800000};
+ float *fdata = reinterpret_cast<float *>(udata);
+
+ _x->size<loco::DataType::FLOAT32>(4);
+ for (uint32_t i = 0; i < 4; i++)
+ _x->at<loco::DataType::FLOAT32>(i) = fdata[i];
+
+ auto sparsityparam = std::make_unique<luci::SparsityParam>();
+ sparsityparam->traversal_order = std::vector<int32_t>({0, 1});
+ sparsityparam->block_map = std::vector<int32_t>({});
+
+ auto dm0 = luci::DimMetaData(luci::DimensionType::DENSE, 4);
+
+ std::vector<int32_t> as_vec = {0, 1, 2, 3, 4};
+ std::vector<int32_t> ai_vec = {0, 1, 2, 3};
+ auto as = luci::SparseIndexVector(luci::SparseIndexVectorType::I32, as_vec);
+ auto ai = luci::SparseIndexVector(luci::SparseIndexVectorType::I32, ai_vec);
+ auto dm1 = luci::DimMetaData(luci::DimensionType::SPARSE_CSR, 0, as, ai);
+ sparsityparam->dim_metadata.emplace_back(dm0);
+ sparsityparam->dim_metadata.emplace_back(dm1);
+
+ _x->sparsityparam(std::move(sparsityparam));
+ }
+
+protected:
+ luci::CircleDensify *_densify = nullptr;
+ luci::CircleConst *_x = nullptr;
+};
+
+class FoldDensifyPassGraphTest : public FoldDensifyPassGraph, public ::testing::Test
+{
+public:
+ FoldDensifyPassGraphTest() : FoldDensifyPassGraph({4, 4}) {}
+
+ virtual void SetUp() { init(); }
+};
+
+} // namespace
+
+TEST(FoldDensifyPassGraph, name)
+{
+ luci::FoldDensifyPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(FoldDensifyPassGraphTest, no_sparsity_param_NEG)
+{
+ fill_const_dense();
+
+ luci::FoldDensifyPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(nullptr, folded_const);
+}
+
+TEST_F(FoldDensifyPassGraphTest, sparsity_param)
+{
+ fill_const_sparse();
+
+ luci::FoldDensifyPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ EXPECT_EQ(2, folded_const->rank());
+ EXPECT_EQ(4, folded_const->dim(0).value());
+ EXPECT_EQ(4, folded_const->dim(1).value());
+ EXPECT_EQ(16, folded_const->size<loco::DataType::FLOAT32>());
+ for (int y = 0; y < 4; ++y)
+ {
+ for (int x = 0; x < 4; ++x)
+ {
+ float ovalue = folded_const->at<loco::DataType::FLOAT32>(y * 4 + x);
+ float fvalue = 0.0;
+ if (x == y)
+ {
+ // diagonal position
+ fvalue = static_cast<float>(y + 1);
+ }
+ EXPECT_EQ(fvalue, ovalue);
+ }
+ }
+}