diff options
Diffstat (limited to 'compiler/luci/service/src/GraphBlock.test.cpp')
-rw-r--r-- | compiler/luci/service/src/GraphBlock.test.cpp | 246 |
1 files changed, 246 insertions, 0 deletions
diff --git a/compiler/luci/service/src/GraphBlock.test.cpp b/compiler/luci/service/src/GraphBlock.test.cpp new file mode 100644 index 000000000..1da8c18fa --- /dev/null +++ b/compiler/luci/service/src/GraphBlock.test.cpp @@ -0,0 +1,246 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "GraphBlock.h" + +#include "Check.h" + +#include <loco.h> + +#include <memory> + +// TODO Change all Canonical nodes to Circle nodes + +namespace +{ + +template <luci::FeatureLayout T> loco::Permutation<loco::Domain::Feature> perm(); + +template <> loco::Permutation<loco::Domain::Feature> perm<luci::FeatureLayout::NHWC>() +{ + // Make NHWC permutation for encoder and decoder + loco::Permutation<loco::Domain::Feature> NHWC; + + NHWC.axis(loco::FeatureAxis::Count) = 0; + NHWC.axis(loco::FeatureAxis::Height) = 1; + NHWC.axis(loco::FeatureAxis::Width) = 2; + NHWC.axis(loco::FeatureAxis::Depth) = 3; + + return NHWC; +} + +template <luci::FilterLayout T> loco::Permutation<loco::Domain::Filter> perm(); + +template <> loco::Permutation<loco::Domain::Filter> perm<luci::FilterLayout::HWIO>() +{ + loco::Permutation<loco::Domain::Filter> HWIO; // a.k.a., HWCN + + HWIO.axis(loco::FilterAxis::Height) = 0; + HWIO.axis(loco::FilterAxis::Width) = 1; + HWIO.axis(loco::FilterAxis::Depth) = 2; + HWIO.axis(loco::FilterAxis::Count) = 3; + + return HWIO; +} + +template <> loco::Permutation<loco::Domain::Filter> perm<luci::FilterLayout::OHWI>() +{ + + // Make NHWC permutation for encoder and decoder + loco::Permutation<loco::Domain::Filter> OHWI; // a.k.a., NHWC + + OHWI.axis(loco::FilterAxis::Count) = 0; + OHWI.axis(loco::FilterAxis::Height) = 1; + OHWI.axis(loco::FilterAxis::Width) = 2; + OHWI.axis(loco::FilterAxis::Depth) = 3; + + return OHWI; +} + +template <luci::DepthwiseFilterLayout T> loco::Permutation<loco::Domain::DepthwiseFilter> perm(); + +template <> +loco::Permutation<loco::Domain::DepthwiseFilter> perm<luci::DepthwiseFilterLayout::HWCM>() +{ + loco::Permutation<loco::Domain::DepthwiseFilter> HWCM; + + HWCM.axis(loco::DepthwiseFilterAxis::Height) = 0; + HWCM.axis(loco::DepthwiseFilterAxis::Width) = 1; + HWCM.axis(loco::DepthwiseFilterAxis::Depth) = 2; + HWCM.axis(loco::DepthwiseFilterAxis::Multiplier) = 3; + + return HWCM; +} + +template <luci::MatrixLayout T> loco::Permutation<loco::Domain::Matrix> perm(); + +template <> loco::Permutation<loco::Domain::Matrix> perm<luci::MatrixLayout::HW>() +{ + loco::Permutation<loco::Domain::Matrix> HW; + + HW.axis(loco::MatrixAxis::Height) = 0; + HW.axis(loco::MatrixAxis::Width) = 1; + + return HW; +} + +template <> loco::Permutation<loco::Domain::Matrix> perm<luci::MatrixLayout::WH>() +{ + loco::Permutation<loco::Domain::Matrix> WH; + + WH.axis(loco::MatrixAxis::Height) = 1; + WH.axis(loco::MatrixAxis::Width) = 0; + + return WH; +} + +} // namespace + +namespace luci +{ + +template <FeatureLayout T> loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode) +{ + LUCI_ASSERT(input_for_encode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_encode->graph(); + + auto encoder = std::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>(); + + encoder->perm(perm<T>()); + + auto enc = g->nodes()->create<loco::FeatureEncode>(); + enc->input(input_for_encode); + enc->encoder(std::move(encoder)); + + return enc; +} + +template <FeatureLayout T> loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode) +{ + LUCI_ASSERT(input_for_decode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = std::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>(); + + decoder->perm(perm<T>()); + + auto dec = g->nodes()->create<loco::FeatureDecode>(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode) +{ + LUCI_ASSERT(input_for_encode != nullptr, "filter should not be nullptr"); + loco::Graph *g = input_for_encode->graph(); + + auto encoder = std::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>(); + + encoder->perm(perm<T>()); + + auto enc = g->nodes()->create<loco::FilterEncode>(); + enc->input(input_for_encode); + enc->encoder(std::move(encoder)); + + return enc; +} + +template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode) +{ + LUCI_ASSERT(input_for_decode != nullptr, "filter should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = std::make_unique<loco::PermutingDecoder<loco::Domain::Filter>>(); + + decoder->perm(perm<T>()); + + auto dec = g->nodes()->create<loco::FilterDecode>(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +template <DepthwiseFilterLayout T> +loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode) +{ + LUCI_ASSERT(input_for_decode != nullptr, "filter should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = std::make_unique<loco::PermutingDecoder<loco::Domain::DepthwiseFilter>>(); + + decoder->perm(perm<T>()); + + auto dec = g->nodes()->create<loco::DepthwiseFilterDecode>(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +template <MatrixLayout T> loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode) +{ + LUCI_ASSERT(input_for_encode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_encode->graph(); + + auto encoder = std::make_unique<loco::PermutingEncoder<loco::Domain::Matrix>>(); + + encoder->perm(perm<T>()); + + auto enc = g->nodes()->create<loco::MatrixEncode>(); + enc->input(input_for_encode); + enc->encoder(std::move(encoder)); + + return enc; +} + +template <MatrixLayout T> loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode) +{ + LUCI_ASSERT(input_for_decode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = std::make_unique<loco::PermutingDecoder<loco::Domain::Matrix>>(); + + decoder->perm(perm<T>()); + + auto dec = g->nodes()->create<loco::MatrixDecode>(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +// template instantiation +template loco::FeatureEncode * +make_feature_encode<FeatureLayout::NHWC>(loco::Node *input_for_encode); + +template loco::FeatureDecode * +make_feature_decode<FeatureLayout::NHWC>(loco::Node *input_for_encode); + +template loco::FilterEncode *make_filter_encode<FilterLayout::HWIO>(loco::Node *input_for_encode); +template loco::FilterDecode *make_filter_decode<FilterLayout::OHWI>(loco::Node *input_for_decode); + +template loco::DepthwiseFilterDecode * +make_dw_filter_decode<DepthwiseFilterLayout::HWCM>(loco::Node *input_for_decode); + +template loco::MatrixEncode *make_matrix_encode<MatrixLayout::HW>(loco::Node *input_for_encode); +template loco::MatrixEncode *make_matrix_encode<MatrixLayout::WH>(loco::Node *input_for_encode); +template loco::MatrixDecode *make_matrix_decode<MatrixLayout::HW>(loco::Node *input_for_decode); +template loco::MatrixDecode *make_matrix_decode<MatrixLayout::WH>(loco::Node *input_for_decode); + +} // namespace luci |