summaryrefslogtreecommitdiff
path: root/compiler/exo/src/GraphBlock.h
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/exo/src/GraphBlock.h')
-rw-r--r--compiler/exo/src/GraphBlock.h199
1 files changed, 199 insertions, 0 deletions
diff --git a/compiler/exo/src/GraphBlock.h b/compiler/exo/src/GraphBlock.h
new file mode 100644
index 000000000..b771c821b
--- /dev/null
+++ b/compiler/exo/src/GraphBlock.h
@@ -0,0 +1,199 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GRAPH_BLOCK_H__
+#define __GRAPH_BLOCK_H__
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <functional>
+
+namespace exo
+{
+
+/// @brief feature layout of TFLITE file
+enum class FeatureLayout
+{
+ NHWC,
+};
+
+/// @brief Creates a loco::FeatureEncode with T layout (NHWC for tflite) and add it to graph.
+template <FeatureLayout T> loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode);
+
+/// @brief Creates a loco::FeatureDecode with T layout (NHWC for tflite) and add it to graph.
+template <FeatureLayout T> loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode);
+
+enum class FilterLayout
+{
+ OHWI, // a.k.a., NHWC, Tensorflow Lite uses this layout for filter
+ HWIO, // a.k.a., HWCN, Tensorflow uses this layout for filter
+};
+
+/// @brief Create a loco::FilterEncode of given layout
+template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode);
+
+/// @brief Create a loco::FilterDecode of given layout
+template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode);
+
+enum class DepthwiseFilterLayout
+{
+ HWCM,
+};
+
+/// @brief Create a loco::DepthwiseFilterDecode of given layout
+template <DepthwiseFilterLayout T>
+loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode);
+
+enum class MatrixLayout
+{
+ HW,
+ WH
+};
+
+/// @brief Create a loco::MatrixEncode of given layout
+template <MatrixLayout T> loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode);
+
+/// @brief Create a loco::MatrixDecode of given layout
+template <MatrixLayout T> loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode);
+
+} // exo
+
+//
+// DomainConverter
+//
+
+/**
+ * Some canonical nodes can have input of various loco::Domain, e.g., loco::Domain::Tensor,
+ * loco::Domain::Feature, etc. However, TFL node accepts only loco::Domain::Tensor.
+ * So, When converting such canonical node to TFL node and input(s) of a canonical node are not
+ * loco::Domain::Tensor, additional nodes need to be inserted.
+ *
+ * The following two classes helps this insertion.
+ *
+ * For example, in case of loco::Relu conversion,
+ *
+ * Before:
+ *
+ * A (output: feature) -- loco::ReLU --- B (input:feature)
+ *
+ * After:
+ *
+ * A -- loco::FeatureDecode -- locoex::TFLRelu -- loco::FeatureEncode --- B
+ *
+ * loco::ReLU (dead node)
+ */
+
+namespace exo
+{
+
+/**
+ * @brief Handles input(s) while converting a canonical node to TFL node(s).
+ * This class informs DomainConverter how to handle inputs of a specific canonical node.
+ */
+template <class CanonicalT, class TFLT> class InputHandler
+{
+public:
+ /**
+ * @brief Assign origin's inputs to replacer's inputs.
+ * (This is called when origin belongs in Tensor domain.)
+ */
+ virtual void handover(CanonicalT *origin, TFLT *replacer) = 0;
+
+ /**
+ * @brief Returns the list of inputs that needs to have FeatureDecode as its input.
+ * (This is called when origin belongs in Feature domain.)
+ */
+ virtual std::vector<loco::Node *> getInputsToConvert(CanonicalT *origin) = 0;
+
+ /// @brief Set the inputs of replacer to new_inputs
+ virtual void set(TFLT *replacer, std::vector<loco::Node *> &new_inputs) = 0;
+
+ /// @brief Set the inputs to nullptr
+ virtual void nullify(CanonicalT *origin) = 0;
+};
+
+/**
+ * @brief Class to handle domain conversion while converting a canonical node to TFL node(s)
+ */
+template <class CanonicalT, class TFLT> class DomainConverter
+{
+public:
+ template <FeatureLayout FeatureLayoutT>
+ TFLT *convert(CanonicalT *origin, InputHandler<CanonicalT, TFLT> &input_handler);
+};
+
+/**
+ * @brief Performs domain conversion
+ *
+ * 1. if origin belong to loco::Domain::Tensor, and replace origin to a TFL node.
+ * 2. if origin belong to loco::Domain::Feature, insert loco::FeatureDecode for input(s) and
+ * insert loco::FeatureEncode for output. Then replace origin to a TFL node.
+ *
+ * @return new TFL node; nullptr if shape of origin cannot be known
+ */
+template <class CanonicalT, class TFLT>
+template <FeatureLayout FeatureLayoutT>
+TFLT *DomainConverter<CanonicalT, TFLT>::convert(CanonicalT *origin,
+ InputHandler<CanonicalT, TFLT> &input_handler)
+{
+ static_assert(FeatureLayoutT == FeatureLayout::NHWC, "Feature layout should be NHWC");
+
+ if (!loco::shape_known(origin))
+ {
+ return nullptr;
+ }
+
+ auto tfl_node = origin->graph()->nodes()->template create<TFLT>();
+
+ // when the input is Tensor, just replace canonical node to TFL node.
+ if (loco::shape_get(origin).domain() == loco::Domain::Tensor)
+ {
+ input_handler.handover(origin, tfl_node);
+
+ loco::replace(origin).with(tfl_node);
+ input_handler.nullify(origin);
+
+ return tfl_node;
+ }
+ else if (loco::shape_get(origin).domain() == loco::Domain::Feature)
+ {
+ std::vector<loco::Node *> feature_decodes;
+
+ for (auto input : input_handler.getInputsToConvert(origin))
+ {
+ auto dec = make_feature_decode<FeatureLayoutT>(input);
+ feature_decodes.emplace_back(dec);
+ }
+
+ input_handler.set(tfl_node, feature_decodes);
+
+ auto enc = make_feature_encode<FeatureLayoutT>(tfl_node);
+
+ loco::replace(origin).with(enc);
+ input_handler.nullify(origin);
+
+ return tfl_node;
+ }
+ else
+ INTERNAL_EXN_V("Unsupported loco::Domain", oops::to_uint32(loco::shape_get(origin).domain()));
+}
+
+} // namespace exo
+
+#endif //__GRAPH_BLOCK_H__