/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): in this copy of the file every template argument list
// (`<...>`) and both angle-bracket #include targets appear to be missing
// (e.g. `std::make_unique>()`, `perm()` with no layout argument, bare
// `#include`), so it is not well-formed C++ as shown. Restore them from
// version control; the comments below describe only what the remaining
// code shows.

#include "GraphBlock.h"

#include "Check.h"

#include
#include

namespace
{

// Permutation factories, one explicit specialization per layout.
// NOTE(review): each `template ... perm();` line below is presumably the
// primary-template declaration (declared, never defined), so only the
// explicit specializations can be used. Which layout each specialization
// serves is inferred from its local variable name (NHWC, HWIO, OHWI, HWCM,
// HW, WH) because the specialization arguments are missing here -- confirm.

template loco::Permutation perm();

// Feature permutation "NHWC": Count -> 0, Height -> 1, Width -> 2, Depth -> 3.
template <> loco::Permutation perm()
{
  // Make NHWC permutation for encoder and decoder
  loco::Permutation NHWC;

  NHWC.axis(loco::FeatureAxis::Count) = 0;
  NHWC.axis(loco::FeatureAxis::Height) = 1;
  NHWC.axis(loco::FeatureAxis::Width) = 2;
  NHWC.axis(loco::FeatureAxis::Depth) = 3;

  return NHWC;
}

template loco::Permutation perm();

// Filter permutation "HWIO": Height -> 0, Width -> 1, Depth -> 2, Count -> 3.
template <> loco::Permutation perm()
{
  loco::Permutation HWIO; // a.k.a., HWCN

  HWIO.axis(loco::FilterAxis::Height) = 0;
  HWIO.axis(loco::FilterAxis::Width) = 1;
  HWIO.axis(loco::FilterAxis::Depth) = 2;
  HWIO.axis(loco::FilterAxis::Count) = 3;

  return HWIO;
}

// Filter permutation "OHWI": Count -> 0, Height -> 1, Width -> 2, Depth -> 3.
template <> loco::Permutation perm()
{
  // Make OHWI (a.k.a. NHWC) permutation for filter encoder and decoder
  loco::Permutation OHWI; // a.k.a., NHWC

  OHWI.axis(loco::FilterAxis::Count) = 0;
  OHWI.axis(loco::FilterAxis::Height) = 1;
  OHWI.axis(loco::FilterAxis::Width) = 2;
  OHWI.axis(loco::FilterAxis::Depth) = 3;

  return OHWI;
}

template loco::Permutation perm();

// Depthwise-filter permutation "HWCM":
// Height -> 0, Width -> 1, Depth -> 2, Multiplier -> 3.
template <> loco::Permutation perm()
{
  loco::Permutation HWCM;

  HWCM.axis(loco::DepthwiseFilterAxis::Height) = 0;
  HWCM.axis(loco::DepthwiseFilterAxis::Width) = 1;
  HWCM.axis(loco::DepthwiseFilterAxis::Depth) = 2;
  HWCM.axis(loco::DepthwiseFilterAxis::Multiplier) = 3;

  return HWCM;
}

template loco::Permutation perm();

// Matrix permutation "HW": Height -> 0, Width -> 1.
template <> loco::Permutation perm()
{
  loco::Permutation HW;

  HW.axis(loco::MatrixAxis::Height) = 0;
  HW.axis(loco::MatrixAxis::Width) = 1;

  return HW;
}

// Matrix permutation "WH" (transposed): Height -> 1, Width -> 0.
template <> loco::Permutation perm()
{
  loco::Permutation WH;

  WH.axis(loco::MatrixAxis::Height) = 1;
  WH.axis(loco::MatrixAxis::Width) = 0;

  return WH;
}

} // namespace

namespace exo
{

/**
 * @brief Create a FeatureEncode node whose input is @p input_for_encode.
 *
 * The node is created in the same graph as the input (via
 * input_for_encode->graph()). It gets a freshly built encoder whose
 * permutation comes from perm() (presumably the specialization for this
 * function's layout template argument -- elided in this copy). Ownership of
 * the encoder is moved into the node.
 *
 * @param input_for_encode node to encode; must not be nullptr (EXO_ASSERT).
 * @return the newly created FeatureEncode node.
 */
template loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode)
{
  EXO_ASSERT(input_for_encode != nullptr, "input should not be nullptr");
  loco::Graph *g = input_for_encode->graph();

  auto encoder = std::make_unique>();

  encoder->perm(perm());

  auto enc = g->nodes()->create();
  enc->input(input_for_encode);
  enc->encoder(std::move(encoder));

  return enc;
}

/**
 * @brief Create a FeatureDecode node whose input is @p input_for_decode.
 *
 * Mirrors make_feature_encode. The node lives in the input's graph and owns a
 * decoder (moved in) whose permutation comes from perm().
 *
 * @param input_for_decode node to decode; must not be nullptr (EXO_ASSERT).
 * @return the newly created FeatureDecode node.
 */
template loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode)
{
  EXO_ASSERT(input_for_decode != nullptr, "input should not be nullptr");
  loco::Graph *g = input_for_decode->graph();

  auto decoder = std::make_unique>();

  decoder->perm(perm());

  auto dec = g->nodes()->create();
  dec->input(input_for_decode);
  dec->decoder(std::move(decoder));

  return dec;
}

/**
 * @brief Create a FilterEncode node whose input is @p input_for_encode.
 *
 * The node lives in the input's graph and owns an encoder (moved in) whose
 * permutation comes from perm().
 *
 * @param input_for_encode filter node to encode; must not be nullptr.
 * @return the newly created FilterEncode node.
 */
template loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode)
{
  EXO_ASSERT(input_for_encode != nullptr, "filter should not be nullptr");
  loco::Graph *g = input_for_encode->graph();

  auto encoder = std::make_unique>();

  encoder->perm(perm());

  auto enc = g->nodes()->create();
  enc->input(input_for_encode);
  enc->encoder(std::move(encoder));

  return enc;
}

/**
 * @brief Create a FilterDecode node whose input is @p input_for_decode.
 *
 * The node lives in the input's graph and owns a decoder (moved in) whose
 * permutation comes from perm().
 *
 * @param input_for_decode filter node to decode; must not be nullptr.
 * @return the newly created FilterDecode node.
 */
template loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode)
{
  EXO_ASSERT(input_for_decode != nullptr, "filter should not be nullptr");
  loco::Graph *g = input_for_decode->graph();

  auto decoder = std::make_unique>();

  decoder->perm(perm());

  auto dec = g->nodes()->create();
  dec->input(input_for_decode);
  dec->decoder(std::move(decoder));

  return dec;
}

/**
 * @brief Create a DepthwiseFilterDecode node whose input is @p input_for_decode.
 *
 * The node lives in the input's graph and owns a decoder (moved in) whose
 * permutation comes from perm(). Only a decode factory exists for depthwise
 * filters in this file.
 *
 * @param input_for_decode depthwise filter node to decode; must not be nullptr.
 * @return the newly created DepthwiseFilterDecode node.
 */
template loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode)
{
  EXO_ASSERT(input_for_decode != nullptr, "filter should not be nullptr");
  loco::Graph *g = input_for_decode->graph();

  auto decoder = std::make_unique>();

  decoder->perm(perm());

  auto dec = g->nodes()->create();
  dec->input(input_for_decode);
  dec->decoder(std::move(decoder));

  return dec;
}

/**
 * @brief Create a MatrixEncode node whose input is @p input_for_encode.
 *
 * The node lives in the input's graph and owns an encoder (moved in) whose
 * permutation comes from perm() (HW or WH above, presumably selected by the
 * elided layout template argument).
 *
 * @param input_for_encode node to encode; must not be nullptr (EXO_ASSERT).
 * @return the newly created MatrixEncode node.
 */
template loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode)
{
  EXO_ASSERT(input_for_encode != nullptr, "input should not be nullptr");
  loco::Graph *g = input_for_encode->graph();

  auto encoder = std::make_unique>();

  encoder->perm(perm());

  auto enc = g->nodes()->create();
  enc->input(input_for_encode);
  enc->encoder(std::move(encoder));

  return enc;
}

/**
 * @brief Create a MatrixDecode node whose input is @p input_for_decode.
 *
 * The node lives in the input's graph and owns a decoder (moved in) whose
 * permutation comes from perm().
 *
 * @param input_for_decode node to decode; must not be nullptr (EXO_ASSERT).
 * @return the newly created MatrixDecode node.
 */
template loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode)
{
  EXO_ASSERT(input_for_decode != nullptr, "input should not be nullptr");
  loco::Graph *g = input_for_decode->graph();

  auto decoder = std::make_unique>();

  decoder->perm(perm());

  auto dec = g->nodes()->create();
  dec->input(input_for_decode);
  dec->decoder(std::move(decoder));

  return dec;
}

// template instantiation
//
// The factory templates are defined only in this file, so the layouts other
// translation units may use are exactly those instantiated here. There is one
// instantiation for each feature/filter/depthwise factory and two for each
// matrix factory (presumably the HW and WH layouts -- the template arguments
// are elided in this copy).
template loco::FeatureEncode *
make_feature_encode(loco::Node *input_for_encode);

// NOTE(review): the parameter is named input_for_encode here although the
// definition of make_feature_decode calls it input_for_decode; parameter
// names in an explicit instantiation have no effect.
template loco::FeatureDecode *
make_feature_decode(loco::Node *input_for_encode);

template loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode);
template loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode);

template loco::DepthwiseFilterDecode *
make_dw_filter_decode(loco::Node *input_for_decode);

template loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode);
template loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode);
template loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode);
template loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode);

} // namespace exo