summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h')
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h375
1 files changed, 375 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
new file mode 100644
index 000000000..72ce5b8f8
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
@@ -0,0 +1,375 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+using Type = loco::DataType;
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace luci
+{
+
+/**
+ * @brief Verify the data type of UINT8 quantized node
+ * @details
+ *
+ * Targets to verify
+ * - node's output (i.e., node itself)
+ * - node's inputs
+ */
+struct VerifyQuantizedNodeU8Type final : public luci::CircleNodeVisitor<bool>
+{
+private:
+ bool has_type(const loco::Node *node, Type dtype)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->dtype() == dtype;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleConcatenation *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ for (uint32_t i = 0; i < node->numValues(); i++)
+ {
+ RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
+ }
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthToSpace *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->beta(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CirclePad *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->weights(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleAdd *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleAveragePool2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleBatchToSpaceND *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleLogicalOr *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
+ return true;
+ }
+
+ bool visit(const luci::CircleMaxPool2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleMean *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleMul *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleNotEqual *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleRelu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleReshape *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8))
+ luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
+ if (shape != nullptr)
+ RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleLogistic *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSoftmax *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->logits(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToBatchND *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToDepth *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSlice *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
+ RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleSplit *node)
+ {
+ // node's output is the input of CircleSplitOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSplitOut *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleStridedSlice *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleArgMax *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
+ has_type(node->dimension(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleTanh *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleTranspose *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->a(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleFloor *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleGreater *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleGreaterEqual *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleDiv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleFloorDiv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleRsqrt *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSqrt *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleElu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CirclePow *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleResizeBilinear *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ // TODO: Implement more Ops
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
+
+#endif // __LUCI_VERIFY_QUNTIZED_NODE_U8_TYPE_H__