summaryrefslogtreecommitdiff
path: root/compiler/luci/import/src/Nodes
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/import/src/Nodes')
-rw-r--r--compiler/luci/import/src/Nodes/CircleAbs.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleAdd.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleArgMax.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleArgMin.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleBCQGather.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp124
-rw-r--r--compiler/luci/import/src/Nodes/CircleCast.cpp19
-rw-r--r--compiler/luci/import/src/Nodes/CircleCeil.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleConst.cpp98
-rw-r--r--compiler/luci/import/src/Nodes/CircleConv2D.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleCos.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleCustom.cpp69
-rw-r--r--compiler/luci/import/src/Nodes/CircleDensify.cpp43
-rw-r--r--compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp28
-rw-r--r--compiler/luci/import/src/Nodes/CircleDequantize.cpp43
-rw-r--r--compiler/luci/import/src/Nodes/CircleDiv.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleElu.cpp25
-rw-r--r--compiler/luci/import/src/Nodes/CircleEqual.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleExp.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleExpandDims.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleFakeQuant.cpp49
-rw-r--r--compiler/luci/import/src/Nodes/CircleFill.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloor.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloorDiv.cpp31
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloorMod.cpp16
-rw-r--r--compiler/luci/import/src/Nodes/CircleFullyConnected.cpp22
-rw-r--r--compiler/luci/import/src/Nodes/CircleGather.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleGatherNd.cpp16
-rw-r--r--compiler/luci/import/src/Nodes/CircleGelu.cpp44
-rw-r--r--compiler/luci/import/src/Nodes/CircleGreater.cpp20
-rw-r--r--compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp22
-rw-r--r--compiler/luci/import/src/Nodes/CircleHardSwish.cpp41
-rw-r--r--compiler/luci/import/src/Nodes/CircleIf.cpp74
-rw-r--r--compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleL2Normalize.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleLess.cpp27
-rw-r--r--compiler/luci/import/src/Nodes/CircleLessEqual.cpp22
-rw-r--r--compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleLog.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalNot.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalOr.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogistic.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp17
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp17
-rw-r--r--compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleMean.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleMirrorPad.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleMul.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleNeg.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp98
-rw-r--r--compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp103
-rw-r--r--compiler/luci/import/src/Nodes/CircleNotEqual.cpp22
-rw-r--r--compiler/luci/import/src/Nodes/CircleOneHot.cpp35
-rw-r--r--compiler/luci/import/src/Nodes/CirclePRelu.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CirclePad.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CirclePadV2.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CirclePow.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleQuantize.cpp43
-rw-r--r--compiler/luci/import/src/Nodes/CircleRange.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleRank.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleReduceAny.cpp25
-rw-r--r--compiler/luci/import/src/Nodes/CircleReduceProd.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleRelu.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleRelu6.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleReluN1To1.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleReshape.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp11
-rw-r--r--compiler/luci/import/src/Nodes/CircleReverseSequence.cpp24
-rw-r--r--compiler/luci/import/src/Nodes/CircleReverseV2.cpp24
-rw-r--r--compiler/luci/import/src/Nodes/CircleRound.cpp21
-rw-r--r--compiler/luci/import/src/Nodes/CircleRsqrt.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleSVDF.cpp64
-rw-r--r--compiler/luci/import/src/Nodes/CircleScatterNd.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleSegmentSum.cpp23
-rw-r--r--compiler/luci/import/src/Nodes/CircleSelect.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleSelectV2.cpp21
-rw-r--r--compiler/luci/import/src/Nodes/CircleShape.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleSin.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleSlice.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleSoftmax.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSparseToDense.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleSplit.cpp63
-rw-r--r--compiler/luci/import/src/Nodes/CircleSplitV.cpp76
-rw-r--r--compiler/luci/import/src/Nodes/CircleSqrt.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleSquare.cpp16
-rw-r--r--compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp27
-rw-r--r--compiler/luci/import/src/Nodes/CircleSqueeze.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleStridedSlice.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleSub.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleSum.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleTanh.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleTile.cpp20
-rw-r--r--compiler/luci/import/src/Nodes/CircleTopKV2.cpp76
-rw-r--r--compiler/luci/import/src/Nodes/CircleTranspose.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleTransposeConv.cpp27
-rw-r--r--compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp76
-rw-r--r--compiler/luci/import/src/Nodes/CircleUnique.cpp55
-rw-r--r--compiler/luci/import/src/Nodes/CircleUnpack.cpp72
-rw-r--r--compiler/luci/import/src/Nodes/CircleVariable.cpp80
-rw-r--r--compiler/luci/import/src/Nodes/CircleWhere.cpp22
-rw-r--r--compiler/luci/import/src/Nodes/CircleWhile.cpp18
-rw-r--r--compiler/luci/import/src/Nodes/CircleZerosLike.cpp8
112 files changed, 1343 insertions, 1218 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleAbs.cpp b/compiler/luci/import/src/Nodes/CircleAbs.cpp
index 3556dc7fa..2a1601a21 100644
--- a/compiler/luci/import/src/Nodes/CircleAbs.cpp
+++ b/compiler/luci/import/src/Nodes/CircleAbs.cpp
@@ -24,11 +24,8 @@ namespace luci
{
bool CircleAbsGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO Support type check
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleAbsGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleAdd.cpp b/compiler/luci/import/src/Nodes/CircleAdd.cpp
index b767d4af2..94cbdf081 100644
--- a/compiler/luci/import/src/Nodes/CircleAdd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleAdd.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleAddGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleAddGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleArgMax.cpp b/compiler/luci/import/src/Nodes/CircleArgMax.cpp
index 10e8516f4..fd8a84289 100644
--- a/compiler/luci/import/src/Nodes/CircleArgMax.cpp
+++ b/compiler/luci/import/src/Nodes/CircleArgMax.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleArgMaxGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleArgMaxGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleArgMin.cpp b/compiler/luci/import/src/Nodes/CircleArgMin.cpp
index 5ff534dbb..63ca8db03 100644
--- a/compiler/luci/import/src/Nodes/CircleArgMin.cpp
+++ b/compiler/luci/import/src/Nodes/CircleArgMin.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleArgMinGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleArgMinGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp b/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp
index ad011f71f..a351cf5e7 100644
--- a/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp
@@ -23,10 +23,7 @@ namespace luci
bool CircleAveragePool2DGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleAveragePool2DGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp
index 16ecebd5c..4c86399ce 100644
--- a/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp
+++ b/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleBCQFullyConnectedGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 5)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 5);
}
CircleNode *CircleBCQFullyConnectedGraphBuilder::build_node(const circle::OperatorT &op,
@@ -43,15 +40,6 @@ CircleNode *CircleBCQFullyConnectedGraphBuilder::build_node(const circle::Operat
node->bias(inputs.at(3));
node->weights_clusters(inputs.at(4));
- // TODO Find and move to appropriate place for setting optional input
- if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias()))
- {
- // bias is not used for type inference, but node itself should have a type
- bias->dtype(loco::DataType::FLOAT32);
-
- // bias is not used for shape inference
- }
-
const auto *options = op.builtin_options.AsBCQFullyConnectedOptions();
node->weights_hidden_size(options->weights_hidden_size);
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
diff --git a/compiler/luci/import/src/Nodes/CircleBCQGather.cpp b/compiler/luci/import/src/Nodes/CircleBCQGather.cpp
index 464f1ac18..ee1358197 100644
--- a/compiler/luci/import/src/Nodes/CircleBCQGather.cpp
+++ b/compiler/luci/import/src/Nodes/CircleBCQGather.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleBCQGatherGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 4)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 4);
}
CircleNode *CircleBCQGatherGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp b/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp
index 330775691..390719061 100644
--- a/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp
+++ b/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp
@@ -23,10 +23,7 @@ namespace luci
bool CircleBatchMatMulGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleBatchMatMulGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp b/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp
new file mode 100644
index 000000000..c04b957bb
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h"
+
+#include <luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h>
+#include <luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleBidirectionalSequenceLSTMGraphBuilder::validate(const ValidateArgs &args) const
+{
+ if (args.op.inputs.size() != 48)
+ return false;
+ if (args.op.outputs.size() != 2)
+ return false;
+
+ return true;
+}
+
+CircleNode *CircleBidirectionalSequenceLSTMGraphBuilder::build_node(const BuildNodeArgs &bna) const
+{
+ auto *node = bna.context->graph()->nodes()->create<CircleBidirectionalSequenceLSTM>();
+ auto &inputs = bna.input_nodes;
+ node->input(inputs.at(0));
+
+ node->fw_input_to_input_weights(inputs.at(1)); // Optional
+ node->fw_input_to_cell_weights(inputs.at(2));
+ node->fw_input_to_forget_weights(inputs.at(3));
+ node->fw_input_to_output_weights(inputs.at(4));
+
+ node->fw_recurrent_to_input_weights(inputs.at(5)); // Optional
+ node->fw_recurrent_to_cell_weights(inputs.at(6));
+ node->fw_recurrent_to_forget_weights(inputs.at(7));
+ node->fw_recurrent_to_output_weights(inputs.at(8));
+
+ node->fw_cell_to_input_weights(inputs.at(9)); // Optional
+ node->fw_cell_to_forget_weights(inputs.at(10)); // Optional
+ node->fw_cell_to_output_weights(inputs.at(11)); // Optional
+
+ node->fw_input_gate_bias(inputs.at(12)); // Optional
+ node->fw_forget_gate_bias(inputs.at(13));
+ node->fw_cell_gate_bias(inputs.at(14));
+ node->fw_output_gate_bias(inputs.at(15));
+
+ node->fw_projection_weights(inputs.at(16)); // Optional
+ node->fw_projection_bias(inputs.at(17)); // Optional
+
+ node->bw_input_to_input_weights(inputs.at(18)); // Optional
+ node->bw_input_to_cell_weights(inputs.at(19));
+ node->bw_input_to_forget_weights(inputs.at(20));
+ node->bw_input_to_output_weights(inputs.at(21));
+
+ node->bw_recurrent_to_input_weights(inputs.at(22)); // Optional
+ node->bw_recurrent_to_cell_weights(inputs.at(23));
+ node->bw_recurrent_to_forget_weights(inputs.at(24));
+ node->bw_recurrent_to_output_weights(inputs.at(25));
+
+ node->bw_cell_to_input_weights(inputs.at(26)); // Optional
+ node->bw_cell_to_forget_weights(inputs.at(27)); // Optional
+ node->bw_cell_to_output_weights(inputs.at(28)); // Optional
+
+ node->bw_input_gate_bias(inputs.at(29)); // Optional
+ node->bw_forget_gate_bias(inputs.at(30));
+ node->bw_cell_gate_bias(inputs.at(31));
+ node->bw_output_gate_bias(inputs.at(32));
+
+ node->bw_projection_weights(inputs.at(33)); // Optional
+ node->bw_projection_bias(inputs.at(34)); // Optional
+
+ node->fw_activation_state(inputs.at(35));
+ node->fw_cell_state(inputs.at(36));
+ node->bw_activation_state(inputs.at(37));
+ node->bw_cell_state(inputs.at(38));
+
+ node->auxillary_input(inputs.at(39)); // Optional
+ node->fw_auxillary_input_to_input_weights(inputs.at(40)); // Optional
+ node->fw_auxillary_input_to_forget_weights(inputs.at(41)); // Optional
+ node->fw_auxillary_input_to_cell_weights(inputs.at(42)); // Optional
+ node->fw_auxillary_input_to_output_weights(inputs.at(43)); // Optional
+
+ node->bw_auxillary_input_to_input_weights(inputs.at(44)); // Optional
+ node->bw_auxillary_input_to_forget_weights(inputs.at(45)); // Optional
+ node->bw_auxillary_input_to_cell_weights(inputs.at(46)); // Optional
+ node->bw_auxillary_input_to_output_weights(inputs.at(47)); // Optional
+
+ const auto *options = bna.op.builtin_options.AsBidirectionalSequenceLSTMOptions();
+ node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
+ node->cell_clip(options->cell_clip);
+ node->proj_clip(options->proj_clip);
+ node->merge_outputs(options->merge_outputs);
+ node->time_major(options->time_major);
+ node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs);
+
+ return node;
+}
+
+CircleNode *CircleBidirectionalSequenceLSTMGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleBidirectionalSequenceLSTMOut>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleCast.cpp b/compiler/luci/import/src/Nodes/CircleCast.cpp
index 7bdb63044..acde823b1 100644
--- a/compiler/luci/import/src/Nodes/CircleCast.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCast.cpp
@@ -30,25 +30,26 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
{
LOGGER(l);
+ if (!GraphBuilder::validate(args, 1))
+ return false;
+
auto settings = luci::UserSettings::settings();
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
- return false;
// NOTE real models do have type mismatch
const auto *options = args.op.builtin_options.AsCastOptions();
if (options != nullptr)
{
- const auto &tensors = args.reader.tensors();
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto tensors = args.reader.tensors();
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
auto name = tensor_name(output_tensor);
- const auto &tensor_in = tensors.at(inputs.at(0));
- if (tensor_in->type != options->in_data_type)
+ const auto tensor_in = tensors.at(inputs.at(0));
+ assert(tensor_in != nullptr);
+ if (tensor_in->type() != options->in_data_type)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
@@ -58,7 +59,7 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
const auto &tensor_out = tensors.at(outputs[0]);
- if (tensor_out->type != options->out_data_type)
+ if (tensor_out->type() != options->out_data_type)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
diff --git a/compiler/luci/import/src/Nodes/CircleCeil.cpp b/compiler/luci/import/src/Nodes/CircleCeil.cpp
index 2e1aaa295..d439f41cd 100644
--- a/compiler/luci/import/src/Nodes/CircleCeil.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCeil.cpp
@@ -25,16 +25,8 @@ namespace luci
bool CircleCeilGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
- return false;
-
// TODO dtype check
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleCeilGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleConst.cpp b/compiler/luci/import/src/Nodes/CircleConst.cpp
index fad7a0757..88f2ae3d0 100644
--- a/compiler/luci/import/src/Nodes/CircleConst.cpp
+++ b/compiler/luci/import/src/Nodes/CircleConst.cpp
@@ -23,14 +23,17 @@
#include <oops/UserExn.h>
#include <cassert>
+#include <ostream>
+#include <string>
+#include <vector>
namespace
{
-std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
+std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect)
{
uint32_t seq = 0;
- for (auto &v : vect)
+ for (const auto &v : vect)
{
if (seq)
os << ", ";
@@ -40,17 +43,20 @@ std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
return os;
}
-} // namespace
-
-namespace luci
-{
+using namespace luci;
template <loco::DataType DT>
-static void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements,
- CircleConst *const_node)
+void copy_data(const VectorWrapper<uint8_t> &raw_data, uint32_t num_elements,
+ CircleConst *const_node)
{
using T = typename loco::DataTypeImpl<DT>::Type;
+ // TODO calculate the exact buffer size of sparse tensor
+ if (const_node->sparsityparam())
+ {
+ num_elements = raw_data.size() / sizeof(T);
+ }
+
assert(raw_data.size() == num_elements * sizeof(T));
const auto *data = reinterpret_cast<const T *>(raw_data.data());
@@ -61,23 +67,69 @@ static void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_element
}
}
-//
-// circleconst_from_tensor() ?
-//
-CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index)
+template <>
+void copy_data<loco::DataType::STRING>(const VectorWrapper<uint8_t> &raw_data,
+ uint32_t num_elements, CircleConst *const_node)
{
+ assert(const_node->sparsityparam() == nullptr);
+
+ const auto *data = reinterpret_cast<const char *>(raw_data.data());
+ const auto *i32d = reinterpret_cast<const int32_t *>(raw_data.data());
+
+ // de-serialize string data
+ // int32_t count
+ // int32_t offsets[count + 1]
+ // string values[count]
+ assert(static_cast<uint32_t>(*i32d) == num_elements);
+ i32d++; // skip count
+
+ std::vector<int32_t> offsets;
+ offsets.push_back(*i32d++);
+ for (uint32_t i = 0; i < num_elements; ++i)
+ {
+ offsets.push_back(*i32d++);
+ }
+ assert(offsets.size() == num_elements + 1);
+
+ const_node->size<loco::DataType::STRING>(num_elements);
+ for (uint32_t i = 0; i < num_elements; ++i)
+ {
+ int32_t start = offsets[i];
+ int32_t next = offsets[i + 1];
+
+ std::string value(data + start, next - start);
+ const_node->at<loco::DataType::STRING>(i) = value;
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+CircleNode *CircleConstNodeBuilder::build(TensorIndex tensor_index,
+ GraphBuilderContext *context) const
+{
+ assert(tensor_index >= 0);
LOGGER(l);
auto graph = context->graph();
auto reader = context->reader();
- const auto &tensors = reader->tensors();
- const circle::TensorT &const_tensor = *tensors[tensor_index];
+ const auto tensors = reader->tensors();
+ const auto const_tensor = tensors[tensor_index];
+ assert(const_tensor != nullptr);
+ if (const_tensor->is_variable())
+ {
+ // Create CircleVariable for variable
+ return nullptr;
+ }
- const std::vector<uint8_t> &buffer = reader->buffers()[const_tensor.buffer]->data;
- std::vector<int32_t> const_dims = const_tensor.shape; // in NHWC
+ assert(reader->buffers()[const_tensor->buffer()] != nullptr);
+ const auto buffer = wrap(reader->buffers()[const_tensor->buffer()]->data());
+ const auto const_dims = wrap(const_tensor->shape()); // in NHWC
if (const_dims.size() == 0 && buffer.empty())
{
- // unknown shape tensor
+ // unknown shape tensor and scalar tensor
return nullptr;
}
@@ -108,12 +160,16 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind
<< const_dims << std::endl;
if (num_elements > 0)
{
- switch (luci_datatype(const_tensor.type))
+ switch (luci_datatype(const_tensor->type()))
{
case loco::DataType::FLOAT32:
copy_data<loco::DataType::FLOAT32>(buffer, num_elements, const_node);
break;
+ case loco::DataType::FLOAT16:
+ copy_data<loco::DataType::FLOAT16>(buffer, num_elements, const_node);
+ break;
+
case loco::DataType::U8:
copy_data<loco::DataType::U8>(buffer, num_elements, const_node);
break;
@@ -138,9 +194,13 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind
copy_data<loco::DataType::BOOL>(buffer, num_elements, const_node);
break;
+ case loco::DataType::STRING:
+ copy_data<loco::DataType::STRING>(buffer, num_elements, const_node);
+ break;
+
default:
throw oops::UserExn("Unsupported tensor type",
- circle::EnumNameTensorType(const_tensor.type));
+ circle::EnumNameTensorType(const_tensor->type()));
}
}
diff --git a/compiler/luci/import/src/Nodes/CircleConv2D.cpp b/compiler/luci/import/src/Nodes/CircleConv2D.cpp
index 9516ef16a..8cbecdc00 100644
--- a/compiler/luci/import/src/Nodes/CircleConv2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleConv2D.cpp
@@ -28,10 +28,7 @@ namespace luci
bool CircleConv2DGraphBuilder::validate(const ValidateArgs &args) const
{
// Circle Conv2D may not have a bias but we won't support this
- if (args.op.inputs.size() != 3)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleConv2DGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleCos.cpp b/compiler/luci/import/src/Nodes/CircleCos.cpp
index 27d60c62c..9705202ee 100644
--- a/compiler/luci/import/src/Nodes/CircleCos.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCos.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleCosGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleCosGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleCustom.cpp b/compiler/luci/import/src/Nodes/CircleCustom.cpp
index d541ee87b..4e78d5fb7 100644
--- a/compiler/luci/import/src/Nodes/CircleCustom.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCustom.cpp
@@ -27,62 +27,41 @@ bool CircleCustomGraphBuilder::validate(const ValidateArgs &) const
return true;
}
-void CircleCustomGraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleCustomGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ uint32_t input_count = bna.op.inputs.size();
+ uint32_t output_count = bna.op.outputs.size();
- auto graph = context->graph();
+ auto *node = bna.context->graph()->nodes()->create<CircleCustom>(input_count, output_count);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- // Create CircleCustom
- const auto &opcodes = context->reader()->opcodes();
- const uint32_t opcode_index = op.opcode_index;
- const circle::OperatorCodeT &opcode = *opcodes[opcode_index];
-
- auto *node = graph->nodes()->create<CircleCustom>(inputs.size());
- uint32_t input_idx = 0;
- for (const int32_t input_tensor_index : inputs)
+ for (uint32_t idx = 0; idx < input_count; ++idx)
{
- node->inputs(input_idx++, context->nodefinder()->node(input_tensor_index));
+ node->inputs(idx, bna.input_nodes[idx]);
}
- node->custom_options(std::vector<uint8_t>{op.custom_options.begin(), op.custom_options.end()});
- node->custom_code(opcode.custom_code);
- // Operator version of custom is always 1, so do nothing
- uint32_t output_count = outputs.size();
+ const auto opcodes = bna.context->reader()->opcodes();
+ const uint32_t opcode_index = bna.op.opcode_index;
+ const auto opcode = opcodes[opcode_index];
+ assert(opcode != nullptr);
- assert(output_count > 0);
- {
- // Let's use attributes from output 0 for this node
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->dtype(luci_datatype(output_tensor.type));
- }
+ node->custom_options(
+ std::vector<uint8_t>{bna.op.custom_options.begin(), bna.op.custom_options.end()});
+ assert(opcode->custom_code() != nullptr);
+ node->custom_code(opcode->custom_code()->c_str());
- // Create virtual outputs of Custom
- for (uint32_t n = 0; n < output_count; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ // NOTE Operator version of custom is always 1
- auto *nodeout = graph->nodes()->create<CircleCustomOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
+ return node;
+}
+
+CircleNode *CircleCustomGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleCustomOut>();
- nodeout->input(node);
- nodeout->index(n);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleDensify.cpp b/compiler/luci/import/src/Nodes/CircleDensify.cpp
new file mode 100644
index 000000000..0a4b2186f
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleDensify.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleDensify.h"
+
+#include <luci/IR/Nodes/CircleDensify.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleDensifyGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 1);
+}
+
+CircleNode *CircleDensifyGraphBuilder::build_node(const circle::OperatorT &,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleDensify>();
+ node->input(inputs.at(0));
+
+ // No options for Densify
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
index 49d31bb99..83fc2e37d 100644
--- a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
@@ -27,20 +27,17 @@ namespace luci
bool CircleDepthToSpaceGraphBuilder::validate(const ValidateArgs &args) const
{
+ if (!GraphBuilder::validate(args, 1))
+ return false;
+
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto *options = args.op.builtin_options.AsDepthToSpaceOptions();
+ const auto tensors = args.reader.tensors();
+ assert(tensors[outputs[0]] != nullptr && tensors[inputs.at(0)] != nullptr);
- if (inputs.size() != 1)
- return false;
-
- if (outputs.size() != 1)
- return false;
-
- const auto &tensors = args.reader.tensors();
-
- if (tensors[outputs[0]]->type != tensors[inputs.at(0)]->type)
+ if (tensors[outputs[0]]->type() != tensors[inputs.at(0)]->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
index 53f85f2f5..a24e4160d 100644
--- a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
@@ -32,6 +32,34 @@ bool CircleDepthwiseConv2DGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.outputs.size() != 1)
return false;
+ const auto tensors = args.reader.tensors();
+
+ // input shape
+ const auto input = tensors.at(args.op.inputs.at(0));
+ assert(input != nullptr);
+ const auto input_shape = wrap(input->shape());
+
+ // input shape must be rank 4
+ if (input_shape.size() != 4)
+ return false;
+
+ // filter shape
+ const auto filter = tensors.at(args.op.inputs.at(1));
+ assert(filter != nullptr);
+ const auto filter_shape = wrap(filter->shape());
+
+ // filter shape must be rank 4
+ if (filter_shape.size() != 4)
+ return false;
+
+ // multiplier
+ const auto *options = args.op.builtin_options.AsDepthwiseConv2DOptions();
+ const auto &multiplier = options->depth_multiplier;
+
+ // filter represents as [1, H, W, C*M] where M is multiplier.
+ if (filter_shape.at(3) != input_shape.at(3) * multiplier)
+ return false;
+
return true;
}
diff --git a/compiler/luci/import/src/Nodes/CircleDequantize.cpp b/compiler/luci/import/src/Nodes/CircleDequantize.cpp
new file mode 100644
index 000000000..3db546bd0
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleDequantize.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleDequantize.h"
+
+#include <luci/IR/Nodes/CircleDequantize.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleDequantizeGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 1);
+}
+
+CircleNode *CircleDequantizeGraphBuilder::build_node(const circle::OperatorT &,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleDequantize>();
+ node->input(inputs.at(0));
+
+ // No options for Dequantize
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleDiv.cpp b/compiler/luci/import/src/Nodes/CircleDiv.cpp
index 615c224d7..7ea1afd95 100644
--- a/compiler/luci/import/src/Nodes/CircleDiv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDiv.cpp
@@ -23,13 +23,7 @@ namespace luci
bool CircleDivGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleDivGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleElu.cpp b/compiler/luci/import/src/Nodes/CircleElu.cpp
index 919e95ee4..e5d7a4c7a 100644
--- a/compiler/luci/import/src/Nodes/CircleElu.cpp
+++ b/compiler/luci/import/src/Nodes/CircleElu.cpp
@@ -25,27 +25,32 @@ namespace luci
bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
- if (outputs.size() != 1)
- return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
- switch (tensor->type)
+ switch (tensor->type())
{
+ case circle::TensorType_FLOAT64:
+ break;
case circle::TensorType_FLOAT32:
break;
+ case circle::TensorType_INT16:
+ break;
+ case circle::TensorType_UINT8:
+ break;
default:
return false;
}
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleEqual.cpp b/compiler/luci/import/src/Nodes/CircleEqual.cpp
index 1db33b8ac..b326d9b5d 100644
--- a/compiler/luci/import/src/Nodes/CircleEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleEqual.cpp
@@ -25,16 +25,14 @@ namespace luci
bool CircleEqualGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
- const auto &tensors = args.reader.tensors();
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
- return tensors[inputs.at(0)]->type == tensors[inputs.at(1)]->type;
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ return tensors[inputs.at(0)]->type() == tensors[inputs.at(1)]->type();
}
CircleNode *CircleEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleExp.cpp b/compiler/luci/import/src/Nodes/CircleExp.cpp
index 2c031d6b3..82c26f0e5 100644
--- a/compiler/luci/import/src/Nodes/CircleExp.cpp
+++ b/compiler/luci/import/src/Nodes/CircleExp.cpp
@@ -25,19 +25,24 @@ namespace luci
bool CircleExpGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// input type check
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
case circle::TensorType_FLOAT64:
break;
+ // Additional support for quantized tensors
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
+ break;
// TODO support TensorType_COMPLEX64, complex128, bfloat16
default:
return false;
diff --git a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
index ab537c710..67d9b7e9e 100644
--- a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
+++ b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
@@ -25,16 +25,14 @@ namespace luci
bool CircleExpandDimsGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
- const auto &tensors = args.reader.tensors();
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
- return tensors[inputs.at(1)]->type == circle::TensorType_INT32;
+ assert(tensors[inputs.at(1)] != nullptr);
+ return tensors[inputs.at(1)]->type() == circle::TensorType_INT32;
}
CircleNode *CircleExpandDimsGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp b/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp
new file mode 100644
index 000000000..7cf40b225
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleFakeQuant.h"
+
+#include <luci/IR/Nodes/CircleFullyConnected.h>
+#include <luci/IR/Nodes/CircleOutput.h>
+
+#include <loco.h>
+#include <oops/UserExn.h>
+
+namespace luci
+{
+
+bool CircleFakeQuantGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 1);
+}
+
+CircleNode *CircleFakeQuantGraphBuilder::build_node(const circle::OperatorT &op,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleFakeQuant>();
+ node->inputs(inputs.at(0));
+
+ const auto *options = op.builtin_options.AsFakeQuantOptions();
+ node->min(options->min);
+ node->max(options->max);
+ node->num_bits(options->num_bits);
+ node->narrow_range(options->narrow_range);
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleFill.cpp b/compiler/luci/import/src/Nodes/CircleFill.cpp
index 95d5b876b..9aacddcbe 100644
--- a/compiler/luci/import/src/Nodes/CircleFill.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFill.cpp
@@ -23,13 +23,7 @@ namespace luci
bool CircleFillGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleFillGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleFloor.cpp b/compiler/luci/import/src/Nodes/CircleFloor.cpp
index ce756b3b1..9651259c7 100644
--- a/compiler/luci/import/src/Nodes/CircleFloor.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloor.cpp
@@ -25,16 +25,8 @@ namespace luci
bool CircleFloorGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
- return false;
-
// TODO dtype check
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleFloorGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
index 55f385d60..67eeddf91 100644
--- a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
@@ -25,28 +25,23 @@ namespace luci
bool CircleFloorDivGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
- return false;
- }
-
- if (outputs.size() != 1)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in_0 = tensors.at(inputs.at(0));
- const auto &tensor_in_1 = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
-
- if (tensor_in_0->type != tensor_in_1->type)
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in_0 = tensors.at(inputs.at(0));
+ const auto tensor_in_1 = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in_0 != nullptr);
+ assert(tensor_in_1 != nullptr);
+ assert(tensor_out != nullptr);
+
+ if (tensor_in_0->type() != tensor_in_1->type())
return false;
- if (tensor_out->type != tensor_in_1->type)
+ if (tensor_out->type() != tensor_in_1->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
index 2101e417e..d2a275b62 100644
--- a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
@@ -25,17 +25,15 @@ namespace luci
bool CircleFloorModGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in_0 = tensors.at(inputs.at(0));
- const auto &tensor_in_1 = tensors.at(inputs.at(1));
- if (tensor_in_0->type != tensor_in_1->type)
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in_0 = tensors.at(inputs.at(0));
+ const auto tensor_in_1 = tensors.at(inputs.at(1));
+ assert(tensor_in_0 != nullptr && tensor_in_1 != nullptr);
+ if (tensor_in_0->type() != tensor_in_1->type())
return false;
// TODO dtype check
diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
index 65a863bde..cc7be1693 100644
--- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
@@ -27,10 +27,7 @@ namespace luci
bool CircleFullyConnectedGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT &op,
@@ -42,23 +39,10 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT
node->weights(inputs.at(1));
node->bias(inputs.at(2)); // bias is optional
- // TODO Find and move to appropriate place for setting optional input
- if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias()))
- {
- // bias is not used for type inference, but node itself should have a type
- bias->dtype(loco::DataType::FLOAT32);
-
- // bias is not used for shape inference
- }
-
const auto *options = op.builtin_options.AsFullyConnectedOptions();
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
- if (options->weights_format != circle::FullyConnectedOptionsWeightsFormat_DEFAULT)
- {
- throw oops::UserExn(
- "Unsupported weights format",
- circle::EnumNameFullyConnectedOptionsWeightsFormat(options->weights_format));
- }
+ node->weights_format(luci_weights_format(options->weights_format));
+ node->keep_num_dims(options->keep_num_dims);
return node;
}
diff --git a/compiler/luci/import/src/Nodes/CircleGather.cpp b/compiler/luci/import/src/Nodes/CircleGather.cpp
index 75447a38a..8317a3340 100644
--- a/compiler/luci/import/src/Nodes/CircleGather.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGather.cpp
@@ -26,18 +26,14 @@ namespace luci
bool CircleGatherGraphBuilder::validate(const ValidateArgs &args) const
{
+ if (!GraphBuilder::validate(args, 2))
+ return false;
+
const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
const auto *options = args.op.builtin_options.AsGatherOptions();
int32_t axis = options->axis;
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
- return false;
-
if (axis < 0)
axis += inputs.size();
diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
index 981adbf63..d336878ad 100644
--- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
@@ -27,19 +27,15 @@ namespace luci
bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
+ if (!GraphBuilder::validate(args, 2))
return false;
- if (outputs.size() != 1)
- return false;
-
- auto &indices_tensor = args.reader.tensors()[inputs.at(1)];
+ const auto &inputs = args.op.inputs;
+ auto indices_tensor = args.reader.tensors()[inputs.at(1)];
+ assert(indices_tensor != nullptr);
- if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 ||
- indices_tensor->type == circle::TensorType::TensorType_INT64))
+ if (!(indices_tensor->type() == circle::TensorType::TensorType_INT32 ||
+ indices_tensor->type() == circle::TensorType::TensorType_INT64))
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleGelu.cpp b/compiler/luci/import/src/Nodes/CircleGelu.cpp
new file mode 100644
index 000000000..89b325f82
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleGelu.cpp
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleGelu.h"
+
+#include <luci/IR/Nodes/CircleGelu.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleGeluGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 1);
+}
+
+CircleNode *CircleGeluGraphBuilder::build_node(const circle::OperatorT &op,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleGelu>();
+ node->features(inputs.at(0));
+
+ const auto *options = op.builtin_options.AsGeluOptions();
+ node->approximate(options->approximate);
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleGreater.cpp b/compiler/luci/import/src/Nodes/CircleGreater.cpp
index 1ad0467e4..7f031b0ba 100644
--- a/compiler/luci/import/src/Nodes/CircleGreater.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGreater.cpp
@@ -30,28 +30,26 @@ bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const
{
LOGGER(l);
+ if (!GraphBuilder::validate(args, 2))
+ return false;
+
auto settings = luci::UserSettings::settings();
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
- return false;
-
- const auto &tensors = args.reader.tensors();
-
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
return false;
// NOTE: real models do have output dtype NOT BOOL
- if (tensors[outputs[0]]->type != circle::TensorType_BOOL)
+ assert(tensors[outputs[0]] != nullptr);
+ if (tensors[outputs[0]]->type() != circle::TensorType_BOOL)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
auto name = tensor_name(output_tensor);
WARN(l) << "Warning: import Greater(" << name << ") output dtype is not boolean";
}
diff --git a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
index 0ac63b017..ac4ce62f5 100644
--- a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
@@ -25,27 +25,21 @@ namespace luci
bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
- if (outputs.size() != 1)
- {
- return false;
- }
-
- const auto &tensors = args.reader.tensors();
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleGreaterEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleHardSwish.cpp b/compiler/luci/import/src/Nodes/CircleHardSwish.cpp
new file mode 100644
index 000000000..47fc1c92c
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleHardSwish.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleHardSwish.h"
+
+#include <luci/IR/Nodes/CircleHardSwish.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleHardSwishGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 1);
+}
+
+CircleNode *CircleHardSwishGraphBuilder::build_node(const circle::OperatorT &,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleHardSwish>();
+ node->features(inputs.at(0));
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleIf.cpp b/compiler/luci/import/src/Nodes/CircleIf.cpp
index db9ffe1cd..e8a50ff32 100644
--- a/compiler/luci/import/src/Nodes/CircleIf.cpp
+++ b/compiler/luci/import/src/Nodes/CircleIf.cpp
@@ -42,12 +42,13 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const
return false;
// input 0 should be BOOL type
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_BOOL)
return false;
- const auto &shape = tensor->shape;
+ const auto shape = wrap(tensor->shape());
if (shape.size() != 1 && shape.size() != 0)
return false;
@@ -70,69 +71,34 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleIfOut --- Node ---
*/
-void CircleIfGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
+CircleNode *CircleIfGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ uint32_t input_count = bna.op.inputs.size() - 1;
+ uint32_t output_count = bna.op.outputs.size();
- auto graph = context->graph();
+ auto *node = bna.context->graph()->nodes()->create<CircleIf>(input_count, output_count);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- uint32_t input_count = inputs.size() - 1;
- uint32_t output_count = outputs.size();
-
- // Create CircleIf
- CircleIf *node = graph->nodes()->create<CircleIf>(input_count, output_count);
-
- node->cond(input_nodes[0]);
+ node->cond(bna.input_nodes[0]);
for (uint32_t idx = 0; idx < input_count; ++idx)
{
- node->input(idx, input_nodes[idx + 1]);
+ node->input(idx, bna.input_nodes[idx + 1]);
}
- const auto *options = op.builtin_options.AsIfOptions();
+ const auto *options = bna.op.builtin_options.AsIfOptions();
node->then_branch(options->then_subgraph_index);
node->else_branch(options->else_subgraph_index);
- assert(outputs.size() > 0);
- {
- // Lets use name of output 0 as If name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for If itself but to virtual outputs
- }
-
- // Create virtual outputs of If
- for (uint32_t n = 0; n < output_count; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ return node;
+}
- auto *nodeout = graph->nodes()->create<CircleIfOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
+CircleNode *CircleIfGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleIfOut>();
- nodeout->input(node);
- nodeout->index(n);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp b/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp
index 6349fd3b7..977b53406 100644
--- a/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp
+++ b/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleInstanceNormGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
-
// TODO check dtypes
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleInstanceNormGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp b/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp
index e4fdc200c..7e1faedfb 100644
--- a/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp
+++ b/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp
@@ -25,20 +25,7 @@ namespace luci
bool CircleL2NormalizeGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
- {
- return false;
- }
-
- if (outputs.size() != 1)
- {
- return false;
- }
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleL2NormalizeGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp b/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp
index 202d9d6fb..849c7c5ed 100644
--- a/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleL2Pool2DGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO check dtypes
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleL2Pool2DGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp b/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp
index ad4979f39..880fa6428 100644
--- a/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleLeakyReluGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleLeakyReluGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleLess.cpp b/compiler/luci/import/src/Nodes/CircleLess.cpp
index 506036908..5c5ae51e1 100644
--- a/compiler/luci/import/src/Nodes/CircleLess.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLess.cpp
@@ -25,23 +25,16 @@ namespace luci
bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
- if (outputs.size() != 1)
- {
- return false;
- }
-
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
- switch (tensor->type)
+ switch (tensor->type())
{
case circle::TensorType_FLOAT32:
case circle::TensorType_FLOAT64:
@@ -56,12 +49,14 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensors[inputs.at(1)]->type != tensor->type)
+ assert(tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(1)]->type() != tensor->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType_BOOL;
}
CircleNode *CircleLessGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
index 9b4f934a5..8a2aea8db 100644
--- a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
@@ -25,27 +25,21 @@ namespace luci
bool CircleLessEqualGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
- if (outputs.size() != 1)
- {
- return false;
- }
-
- const auto &tensors = args.reader.tensors();
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleLessEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp b/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp
index 0e32f62de..d03c47d12 100644
--- a/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp
@@ -25,16 +25,12 @@ namespace luci
bool CircleLocalResponseNormalizationGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleLocalResponseNormalizationGraphBuilder::build_node(
- const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
+ const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
{
auto *node = graph->nodes()->create<CircleLocalResponseNormalization>();
node->input(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleLog.cpp b/compiler/luci/import/src/Nodes/CircleLog.cpp
index 346fc43bb..f41926829 100644
--- a/compiler/luci/import/src/Nodes/CircleLog.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLog.cpp
@@ -25,18 +25,17 @@ namespace luci
bool CircleLogGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
- if (args.op.outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// input type check
// Must be one of bfloat16, half, float32, float64, complex64, complex128.
// Currently circle supports half(float16), float32, float64, complex64.
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
diff --git a/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp b/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp
index ef69e868a..4361db691 100644
--- a/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleLogSoftmaxGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleLogSoftmaxGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
index 7844da0f6..b61fb6f3e 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
@@ -25,16 +25,17 @@ namespace luci
bool CircleLogicalAndGraphBuilder::validate(const ValidateArgs &args) const
{
- // Only BOOL type is allowed for inputs
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 2)
+ if (!GraphBuilder::validate(args, 2))
return false;
- const auto &tensors = args.reader.tensors();
+ // Only BOOL type is allowed for inputs
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
for (auto input : inputs)
{
- const auto &tensor = tensors.at(input);
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensor = tensors.at(input);
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
index 3758642e4..43e9ed39f 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
@@ -25,14 +25,15 @@ namespace luci
bool CircleLogicalNotGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
// Only BOOL type is allowed for the input
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
index 1b87e6f9c..6354e7dc1 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
@@ -25,16 +25,17 @@ namespace luci
bool CircleLogicalOrGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
+ if (!GraphBuilder::validate(args, 2))
return false;
// Only BOOL type is allowed for inputs
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
for (auto input : inputs)
{
- const auto &tensor = tensors.at(input);
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensor = tensors.at(input);
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleLogistic.cpp b/compiler/luci/import/src/Nodes/CircleLogistic.cpp
index 9606e19cd..b0d08e039 100644
--- a/compiler/luci/import/src/Nodes/CircleLogistic.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogistic.cpp
@@ -25,15 +25,14 @@ namespace luci
bool CircleLogisticGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
- const auto &outputs = args.op.outputs;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
- const auto &tensors = args.reader.tensors();
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
index a4a21a8b7..384b98586 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
@@ -25,19 +25,16 @@ namespace luci
bool CircleMatrixDiagGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
- if (outputs.size() != 1)
- return false;
-
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
index cf0313149..64870c057 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
@@ -25,19 +25,16 @@ namespace luci
bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
+ if (!GraphBuilder::validate(args, 2))
return false;
- if (outputs.size() != 1)
- return false;
-
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp b/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp
index 4bca0f40b..5c03fff18 100644
--- a/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleMaxPool2DGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleMaxPool2DGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleMean.cpp b/compiler/luci/import/src/Nodes/CircleMean.cpp
index d8fa9a53d..7882f17fc 100644
--- a/compiler/luci/import/src/Nodes/CircleMean.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMean.cpp
@@ -23,10 +23,7 @@ namespace luci
bool CircleMeanGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleMeanGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp b/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp
index e0ddd4c11..e40ce2249 100644
--- a/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleMirrorPadGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
// TODO check others
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleMirrorPadGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleMul.cpp b/compiler/luci/import/src/Nodes/CircleMul.cpp
index e3c4a7ee5..28421f8c4 100644
--- a/compiler/luci/import/src/Nodes/CircleMul.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMul.cpp
@@ -23,13 +23,7 @@ namespace luci
bool CircleMulGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleMulGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleNeg.cpp b/compiler/luci/import/src/Nodes/CircleNeg.cpp
index a64a69560..9dd1458f4 100644
--- a/compiler/luci/import/src/Nodes/CircleNeg.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNeg.cpp
@@ -24,11 +24,8 @@ namespace luci
{
bool CircleNegGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO Support type check
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleNegGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
index a4ad4a53d..e86f2ba81 100644
--- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
@@ -35,20 +35,26 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c
if (outputs.size() != 2)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &boxes_tensor = tensors.at(inputs[0]);
- if (boxes_tensor->shape.size() != 2)
+ const auto tensors = args.reader.tensors();
+ const auto boxes_tensor = tensors.at(inputs[0]);
+ assert(boxes_tensor != nullptr);
+ const auto boxes_tensor_shape = wrap(boxes_tensor->shape());
+ if (boxes_tensor_shape.size() != 2)
return false;
- if (boxes_tensor->shape.at(1) != 4)
+ if (boxes_tensor_shape.at(1) != 4)
return false;
- if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0))
+ assert(tensors.at(inputs[1]) != nullptr);
+ if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0))
return false;
- if (tensors.at(inputs[2])->type != circle::TensorType_INT32)
+ assert(tensors.at(inputs[2]) != nullptr);
+ if (tensors.at(inputs[2])->type() != circle::TensorType_INT32)
return false;
- if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[3]) != nullptr);
+ if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[4]) != nullptr);
+ if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32)
return false;
return true;
@@ -61,63 +67,27 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c
* We will create multiple NonMasSuppressionV4Oout nodes to emulate this
*/
-void CircleNonMaxSuppressionV4GraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleNonMaxSuppressionV4GraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
-
- auto graph = context->graph();
-
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleNonMaxSuppressionV4
- auto node = graph->nodes()->create<CircleNonMaxSuppressionV4>();
- node->boxes(input_nodes[0]);
- node->scores(input_nodes[1]);
- node->max_output_size(input_nodes[2]);
- node->iou_threshold(input_nodes[3]);
- node->score_threshold(input_nodes[4]);
-
- assert(outputs.size() == 2);
- {
- // Let's use name of output 0 as NonMaxSuppressionV4 name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for NonMaxSuppressionV4 itself but to virtual outputs
- }
-
- // Create virtual outputs of NonMaxSuppressionV4
- for (size_t n = 0; n < outputs.size(); ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleNonMaxSuppressionV4Out>();
- copy_tensor_attributes(output_tensor, nodeout);
-
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ auto node = bna.context->graph()->nodes()->create<CircleNonMaxSuppressionV4>();
+
+ node->boxes(bna.input_nodes[0]);
+ node->scores(bna.input_nodes[1]);
+ node->max_output_size(bna.input_nodes[2]);
+ node->iou_threshold(bna.input_nodes[3]);
+ node->score_threshold(bna.input_nodes[4]);
+
+ return node;
+}
+
+CircleNode *CircleNonMaxSuppressionV4GraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleNonMaxSuppressionV4Out>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
index 241dbf5ff..a60eed4e4 100644
--- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
@@ -35,22 +35,29 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c
if (outputs.size() != 3)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &boxes_tensor = tensors.at(inputs[0]);
- if (boxes_tensor->shape.size() != 2)
+ const auto tensors = args.reader.tensors();
+ const auto boxes_tensor = tensors.at(inputs[0]);
+ assert(boxes_tensor != nullptr);
+ const auto boxes_tensor_shape = wrap(boxes_tensor->shape());
+ if (boxes_tensor_shape.size() != 2)
return false;
- if (boxes_tensor->shape.at(1) != 4)
+ if (boxes_tensor_shape.at(1) != 4)
return false;
- if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0))
+ assert(tensors.at(inputs[1]) != nullptr);
+ if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0))
return false;
- if (tensors.at(inputs[2])->type != circle::TensorType_INT32)
+ assert(tensors.at(inputs[2]) != nullptr);
+ if (tensors.at(inputs[2])->type() != circle::TensorType_INT32)
return false;
- if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[3]) != nullptr);
+ if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[4]) != nullptr);
+ if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[5])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[5]) != nullptr);
+ if (tensors.at(inputs[5])->type() != circle::TensorType_FLOAT32)
return false;
return true;
@@ -63,64 +70,28 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c
* We will create multiple NonMasSuppressionV5Oout nodes to emulate this
*/
-void CircleNonMaxSuppressionV5GraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleNonMaxSuppressionV5GraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
-
- auto graph = context->graph();
-
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleNonMaxSuppressionV5
- auto node = graph->nodes()->create<CircleNonMaxSuppressionV5>();
- node->boxes(input_nodes[0]);
- node->scores(input_nodes[1]);
- node->max_output_size(input_nodes[2]);
- node->iou_threshold(input_nodes[3]);
- node->score_threshold(input_nodes[4]);
- node->soft_nms_sigma(input_nodes[5]);
-
- assert(outputs.size() == 3);
- {
- // Let's use name of output 0 as NonMaxSuppressionV5 name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for NonMaxSuppressionV5 itself but to virtual outputs
- }
-
- // Create virtual outputs of NonMaxSuppressionV5
- for (size_t n = 0; n < outputs.size(); ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleNonMaxSuppressionV5Out>();
- copy_tensor_attributes(output_tensor, nodeout);
-
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ auto node = bna.context->graph()->nodes()->create<CircleNonMaxSuppressionV5>();
+
+ node->boxes(bna.input_nodes[0]);
+ node->scores(bna.input_nodes[1]);
+ node->max_output_size(bna.input_nodes[2]);
+ node->iou_threshold(bna.input_nodes[3]);
+ node->score_threshold(bna.input_nodes[4]);
+ node->soft_nms_sigma(bna.input_nodes[5]);
+
+ return node;
+}
+
+CircleNode *CircleNonMaxSuppressionV5GraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleNonMaxSuppressionV5Out>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
index 77e986de1..3f5c1e033 100644
--- a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
@@ -25,27 +25,21 @@ namespace luci
bool CircleNotEqualGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
- if (outputs.size() != 1)
- {
- return false;
- }
-
- const auto &tensors = args.reader.tensors();
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleNotEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleOneHot.cpp b/compiler/luci/import/src/Nodes/CircleOneHot.cpp
index 69294e1ed..6e5f8e16f 100644
--- a/compiler/luci/import/src/Nodes/CircleOneHot.cpp
+++ b/compiler/luci/import/src/Nodes/CircleOneHot.cpp
@@ -26,32 +26,31 @@ namespace luci
bool CircleOneHotGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- const auto *options = args.op.builtin_options.AsOneHotOptions();
-
// Only 4 Input come refered from
- if (inputs.size() != 4)
+ if (!GraphBuilder::validate(args, 4))
return false;
- if (outputs.size() != 1)
- return false;
-
- const auto &tensors = args.reader.tensors();
- const auto &indices = tensors.at(inputs.at(0));
- const auto &depth = tensors.at(inputs.at(1));
- const auto &on_value = tensors.at(inputs.at(2));
- const auto &off_value = tensors.at(inputs.at(3));
+ const auto &inputs = args.op.inputs;
+ const auto *options = args.op.builtin_options.AsOneHotOptions();
+ const auto tensors = args.reader.tensors();
+ const auto indices = tensors.at(inputs.at(0));
+ const auto depth = tensors.at(inputs.at(1));
+ const auto on_value = tensors.at(inputs.at(2));
+ const auto off_value = tensors.at(inputs.at(3));
+ assert(indices != nullptr);
+ assert(depth != nullptr);
+ assert(on_value != nullptr);
+ assert(off_value != nullptr);
- if (options->axis < -1 || options->axis > static_cast<int32_t>(indices->shape.size()))
+ if (options->axis < -1 || options->axis > static_cast<int32_t>(wrap(indices->shape()).size()))
return false;
- if (depth->shape.size() != 0)
+ if (wrap(depth->shape()).size() != 0)
return false;
- if (on_value->shape.size() != 0)
+ if (wrap(on_value->shape()).size() != 0)
return false;
- if (off_value->shape.size() != 0)
+ if (wrap(off_value->shape()).size() != 0)
return false;
- if (on_value->type != off_value->type)
+ if (on_value->type() != off_value->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CirclePRelu.cpp b/compiler/luci/import/src/Nodes/CirclePRelu.cpp
index c07920f7c..7c81f04bb 100644
--- a/compiler/luci/import/src/Nodes/CirclePRelu.cpp
+++ b/compiler/luci/import/src/Nodes/CirclePRelu.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CirclePReluGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CirclePReluGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CirclePad.cpp b/compiler/luci/import/src/Nodes/CirclePad.cpp
index 999173b90..67dce6dee 100644
--- a/compiler/luci/import/src/Nodes/CirclePad.cpp
+++ b/compiler/luci/import/src/Nodes/CirclePad.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CirclePadGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CirclePadGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CirclePadV2.cpp b/compiler/luci/import/src/Nodes/CirclePadV2.cpp
index 493876e68..84a45722a 100644
--- a/compiler/luci/import/src/Nodes/CirclePadV2.cpp
+++ b/compiler/luci/import/src/Nodes/CirclePadV2.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CirclePadV2GraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CirclePadV2GraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CirclePow.cpp b/compiler/luci/import/src/Nodes/CirclePow.cpp
index def012614..1d2d41607 100644
--- a/compiler/luci/import/src/Nodes/CirclePow.cpp
+++ b/compiler/luci/import/src/Nodes/CirclePow.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CirclePowGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CirclePowGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleQuantize.cpp b/compiler/luci/import/src/Nodes/CircleQuantize.cpp
new file mode 100644
index 000000000..9247a76d9
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleQuantize.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleQuantize.h"
+
+#include <luci/IR/Nodes/CircleQuantize.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleQuantizeGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 1);
+}
+
+CircleNode *CircleQuantizeGraphBuilder::build_node(const circle::OperatorT &,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleQuantize>();
+ node->input(inputs.at(0));
+
+ // No options for Quantize
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleRange.cpp b/compiler/luci/import/src/Nodes/CircleRange.cpp
index 38dc44ed6..d3b5afc95 100644
--- a/compiler/luci/import/src/Nodes/CircleRange.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRange.cpp
@@ -24,11 +24,8 @@ namespace luci
{
bool CircleRangeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
-
// TODO Support type check
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleRangeGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleRank.cpp b/compiler/luci/import/src/Nodes/CircleRank.cpp
index 12658b192..afebb9509 100644
--- a/compiler/luci/import/src/Nodes/CircleRank.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRank.cpp
@@ -24,13 +24,7 @@ namespace luci
{
bool CircleRankGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleRankGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
index 21a821951..ebe2368e0 100644
--- a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
@@ -23,24 +23,25 @@ namespace luci
bool CircleReduceAnyGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_0 = tensors.at(inputs.at(0));
- const auto &tensor_1 = tensors.at(inputs.at(1));
- const auto &tensor_o = tensors.at(outputs[0]);
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_0 = tensors.at(inputs.at(0));
+ const auto tensor_1 = tensors.at(inputs.at(1));
+ const auto tensor_o = tensors.at(outputs[0]);
+ assert(tensor_0 != nullptr);
+ assert(tensor_1 != nullptr);
+ assert(tensor_o != nullptr);
- if (tensor_0->type != circle::TensorType_BOOL)
+ if (tensor_0->type() != circle::TensorType_BOOL)
return false;
- if (tensor_o->type != circle::TensorType_BOOL)
+ if (tensor_o->type() != circle::TensorType_BOOL)
return false;
- switch (tensor_1->type)
+ switch (tensor_1->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
diff --git a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
index 5f054586e..3b874b7c9 100644
--- a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
@@ -23,19 +23,18 @@ namespace luci
bool CircleReduceProdGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 2)
- return false;
- if (args.op.outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_1 = tensors.at(inputs.at(1));
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_1 = tensors.at(inputs.at(1));
+ assert(tensor_1 != nullptr);
// TODO check input types
// Check for reduction_indices types
- switch (tensor_1->type)
+ switch (tensor_1->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
diff --git a/compiler/luci/import/src/Nodes/CircleRelu.cpp b/compiler/luci/import/src/Nodes/CircleRelu.cpp
index 8e1c32a3a..73b8ffee8 100644
--- a/compiler/luci/import/src/Nodes/CircleRelu.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRelu.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleReluGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleReluGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleRelu6.cpp b/compiler/luci/import/src/Nodes/CircleRelu6.cpp
index 0283d7350..ab957eda8 100644
--- a/compiler/luci/import/src/Nodes/CircleRelu6.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRelu6.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleRelu6GraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleRelu6GraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp
index 7f517bc0d..4987f3be2 100644
--- a/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp
@@ -25,15 +25,8 @@ namespace luci
bool CircleReluN1To1GraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
// TODO check dtypes
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleReluN1To1GraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleReshape.cpp b/compiler/luci/import/src/Nodes/CircleReshape.cpp
index 996ae9d20..12da54ef7 100644
--- a/compiler/luci/import/src/Nodes/CircleReshape.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReshape.cpp
@@ -30,6 +30,19 @@ bool CircleReshapeGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.outputs.size() != 1)
return false;
+ // for two inputs, check if type is S32 or S64
+ if (args.op.inputs.size() == 2)
+ {
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(1));
+ assert(tensor_in != nullptr);
+
+ if (tensor_in->type() != circle::TensorType::TensorType_INT32 &&
+ tensor_in->type() != circle::TensorType::TensorType_INT64)
+ return false;
+ }
+
return true;
}
@@ -53,6 +66,7 @@ static CircleNode *create_shape_node(const std::vector<int32_t> &shape, loco::Gr
{
shape_node->at<loco::DataType::S32>(i) = shape[i];
}
+ shape_node->name("Reshape/shape");
return shape_node;
}
@@ -73,6 +87,7 @@ CircleNode *CircleReshapeGraphBuilder::build_node(const circle::OperatorT &op,
shape_node = graph->nodes()->create<CircleOutputDummy>();
shape_node->dtype(loco::DataType::S32);
shape_node->rank(0);
+ shape_node->name("Reshape/dummy");
}
}
diff --git a/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp b/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp
index 0fccb7b44..c751b245c 100644
--- a/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp
+++ b/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp
@@ -16,7 +16,6 @@
#include "luci/Import/Nodes/CircleResizeBilinear.h"
-#include <luci/IR/Nodes/CircleConst.h>
#include <luci/IR/Nodes/CircleResizeBilinear.h>
namespace luci
@@ -24,13 +23,7 @@ namespace luci
bool CircleResizeBilinearGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleResizeBilinearGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp b/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp
index 324323f59..df7517fe9 100644
--- a/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp
+++ b/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp
@@ -16,7 +16,6 @@
#include "luci/Import/Nodes/CircleResizeNearestNeighbor.h"
-#include <luci/IR/Nodes/CircleConst.h>
#include <luci/IR/Nodes/CircleResizeNearestNeighbor.h>
namespace luci
@@ -24,17 +23,11 @@ namespace luci
bool CircleResizeNearestNeighborGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleResizeNearestNeighborGraphBuilder::build_node(
- const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
+ const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
{
auto *node = graph->nodes()->create<CircleResizeNearestNeighbor>();
node->input(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
index ad11d4c63..c9cc792bb 100644
--- a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
@@ -25,20 +25,20 @@ namespace luci
bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_lengths = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_lengths = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_lengths != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_lengths->type)
+ switch (tensor_lengths->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -47,7 +47,7 @@ bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_in->type != tensor_out->type)
+ if (tensor_in->type() != tensor_out->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
index e2e53bb4b..c19a0fdd2 100644
--- a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
@@ -25,20 +25,20 @@ namespace luci
bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_axis = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_axis = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_axis != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_axis->type)
+ switch (tensor_axis->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -47,7 +47,7 @@ bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleRound.cpp b/compiler/luci/import/src/Nodes/CircleRound.cpp
index ad77f9f03..08cfae6c2 100644
--- a/compiler/luci/import/src/Nodes/CircleRound.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRound.cpp
@@ -25,22 +25,21 @@ namespace luci
bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_in->type)
+ switch (tensor_in->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
@@ -52,7 +51,7 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
index ae05fbbf9..e3bc68f8b 100644
--- a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
@@ -25,17 +25,20 @@ namespace luci
bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
case circle::TensorType_COMPLEX64:
diff --git a/compiler/luci/import/src/Nodes/CircleSVDF.cpp b/compiler/luci/import/src/Nodes/CircleSVDF.cpp
new file mode 100644
index 000000000..ef57a132a
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleSVDF.cpp
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleSVDF.h"
+
+#include <luci/IR/Nodes/CircleSVDF.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleSVDFBuilder::validate(const ValidateArgs &args) const
+{
+ const auto &inputs = args.op.inputs;
+ if (!(inputs.size() == 4 || inputs.size() == 5))
+ return false;
+
+ return true;
+}
+
+CircleNode *CircleSVDFBuilder::build_node(const circle::OperatorT &op,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleSVDF>();
+ node->input(inputs.at(0));
+ node->weight_feature(inputs.at(1));
+ node->weight_time(inputs.at(2));
+ if (inputs.size() == 4)
+ {
+ auto *bias = graph->nodes()->create<CircleOutputExclude>();
+ node->bias(bias);
+
+ node->input_activation_state(inputs.at(3));
+ }
+ else
+ {
+ node->bias(inputs.at(3));
+ node->input_activation_state(inputs.at(4));
+ }
+
+ const auto *options = op.builtin_options.AsSVDFOptions();
+ node->svdf_rank(options->rank);
+ node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
+ node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs);
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
index 7f86aeb74..ebe252527 100644
--- a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
@@ -25,19 +25,20 @@ namespace luci
bool CircleScatterNdGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 3)
+ if (!GraphBuilder::validate(args, 3))
return false;
+ const auto &inputs = args.op.inputs;
// indices must have the same type as shape
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(2)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(2)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(2)]->type())
return false;
// indices must be either int32 or int64
- if (tensors[inputs.at(0)]->type != circle::TensorType_INT32 &&
- tensors[inputs.at(0)]->type != circle::TensorType_INT64)
+ if (tensors[inputs.at(0)]->type() != circle::TensorType_INT32 &&
+ tensors[inputs.at(0)]->type() != circle::TensorType_INT64)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
index fb84e5d52..01d1aab44 100644
--- a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
@@ -25,19 +25,20 @@ namespace luci
bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
- const auto &tensor_ids = tensors.at(inputs.at(1));
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ const auto tensor_ids = tensors.at(inputs.at(1));
+ assert(tensor_in != nullptr);
+ assert(tensor_out != nullptr);
+ assert(tensor_ids != nullptr);
- switch (tensor_ids->type)
+ switch (tensor_ids->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -46,7 +47,7 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleSelect.cpp b/compiler/luci/import/src/Nodes/CircleSelect.cpp
index 1e649f1e0..002f62f6c 100644
--- a/compiler/luci/import/src/Nodes/CircleSelect.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSelect.cpp
@@ -25,16 +25,14 @@ namespace luci
bool CircleSelectGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 3)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 3))
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType_BOOL)
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_BOOL)
return false;
// TODO check dtypes for input 1, 2
diff --git a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
index e6dd04de0..062fdc143 100644
--- a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
@@ -25,21 +25,20 @@ namespace luci
bool CircleSelectV2GraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 3)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 3))
return false;
- const auto &tensors = args.reader.tensors();
- const auto &condition = tensors.at(inputs.at(0));
- if (condition->type != circle::TensorType_BOOL)
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
+ const auto condition = tensors.at(inputs.at(0));
+ assert(condition != nullptr);
+ if (condition->type() != circle::TensorType_BOOL)
return false;
- const auto &t = tensors.at(inputs.at(1));
- const auto &e = tensors.at(inputs.at(2));
- if (t->type != e->type)
+ const auto t = tensors.at(inputs.at(1));
+ const auto e = tensors.at(inputs.at(2));
+ assert(t != nullptr && e != nullptr);
+ if (t->type() != e->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleShape.cpp b/compiler/luci/import/src/Nodes/CircleShape.cpp
index bd7dfc9d9..86c0bf59b 100644
--- a/compiler/luci/import/src/Nodes/CircleShape.cpp
+++ b/compiler/luci/import/src/Nodes/CircleShape.cpp
@@ -25,16 +25,8 @@ namespace luci
bool CircleShapeGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
- return false;
-
// TODO check shape, dtype
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleShapeGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSin.cpp b/compiler/luci/import/src/Nodes/CircleSin.cpp
index 4b245ef6b..51ebf0355 100644
--- a/compiler/luci/import/src/Nodes/CircleSin.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSin.cpp
@@ -25,16 +25,15 @@ namespace luci
bool CircleSinGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
- if (args.op.outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// input type check
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
diff --git a/compiler/luci/import/src/Nodes/CircleSlice.cpp b/compiler/luci/import/src/Nodes/CircleSlice.cpp
index 8601fbf21..4166040b3 100644
--- a/compiler/luci/import/src/Nodes/CircleSlice.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSlice.cpp
@@ -27,14 +27,8 @@ namespace luci
bool CircleSliceGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
- if (args.op.outputs.size() != 1)
- return false;
-
// TODO check shapes and types
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleSliceGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleSoftmax.cpp b/compiler/luci/import/src/Nodes/CircleSoftmax.cpp
index 0ef0b5418..e79914455 100644
--- a/compiler/luci/import/src/Nodes/CircleSoftmax.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSoftmax.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleSoftmaxGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleSoftmaxGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp b/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp
index 8ccd55dc6..2152b65c9 100644
--- a/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp
@@ -27,13 +27,8 @@ namespace luci
bool CircleSpaceToDepthGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleSpaceToDepthGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp b/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp
index ac756b1f3..ce0688bb9 100644
--- a/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleSparseToDenseGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 4)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 4);
}
CircleNode *CircleSparseToDenseGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSplit.cpp b/compiler/luci/import/src/Nodes/CircleSplit.cpp
index 07b6cc939..d0a24aae3 100644
--- a/compiler/luci/import/src/Nodes/CircleSplit.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSplit.cpp
@@ -58,62 +58,27 @@ bool CircleSplitGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleSplitOut --- FullyConnected ---
*/
-void CircleSplitGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
+CircleNode *CircleSplitGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ auto node = bna.context->graph()->nodes()->create<CircleSplit>();
- auto graph = context->graph();
+ node->split_dim(bna.input_nodes[0]);
+ node->input(bna.input_nodes[1]);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto *options = bna.op.builtin_options.AsSplitOptions();
+ node->num_split(options->num_splits);
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
+ return node;
+}
- // Create CircleSplit
- auto node = graph->nodes()->create<CircleSplit>();
- node->split_dim(input_nodes[0]);
- node->input(input_nodes[1]);
+CircleNode *CircleSplitGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitOut>();
- const auto *options = op.builtin_options.AsSplitOptions();
- node->num_split(options->num_splits);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- assert(outputs.size() > 0);
- assert(int32_t(outputs.size()) == options->num_splits);
- {
- // Let's use name of output 0 as Split name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for Split itself but to virtual outputs
- }
-
- // Create virtual outputs of Split
- for (int32_t n = 0; n < options->num_splits; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleSplitOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleSplitV.cpp b/compiler/luci/import/src/Nodes/CircleSplitV.cpp
index 7c6e83e17..76cbf7046 100644
--- a/compiler/luci/import/src/Nodes/CircleSplitV.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSplitV.cpp
@@ -58,64 +58,30 @@ bool CircleSplitVGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleSplitVOut --- FullyConnected ---
*/
-void CircleSplitVGraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleSplitVGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
-
- auto graph = context->graph();
-
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleSplitV
- auto node = graph->nodes()->create<CircleSplitV>();
- node->input(input_nodes[0]);
- node->size_splits(input_nodes[1]);
- node->split_dim(input_nodes[2]);
-
- const auto *options = op.builtin_options.AsSplitVOptions();
+ auto node = bna.context->graph()->nodes()->create<CircleSplitV>();
+
+ node->input(bna.input_nodes[0]);
+ node->size_splits(bna.input_nodes[1]);
+ node->split_dim(bna.input_nodes[2]);
+
+ const auto *options = bna.op.builtin_options.AsSplitVOptions();
node->num_split(options->num_splits);
- assert(outputs.size() > 0);
- assert(int32_t(outputs.size()) == options->num_splits);
- {
- // Let's use name of output 0 as Split name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for Split itself but to virtual outputs
- }
-
- // Create virtual outputs of Split
- for (int32_t n = 0; n < options->num_splits; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleSplitVOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ assert(int32_t(bna.op.outputs.size()) == options->num_splits);
+
+ return node;
+}
+
+CircleNode *CircleSplitVGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitVOut>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleSqrt.cpp b/compiler/luci/import/src/Nodes/CircleSqrt.cpp
index c8beaee0d..b1fdf7996 100644
--- a/compiler/luci/import/src/Nodes/CircleSqrt.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSqrt.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleSqrtGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleSqrtGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleSquare.cpp b/compiler/luci/import/src/Nodes/CircleSquare.cpp
index b5ba048d7..bec84b4c0 100644
--- a/compiler/luci/import/src/Nodes/CircleSquare.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSquare.cpp
@@ -25,17 +25,17 @@ namespace luci
bool CircleSquareGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
- // Must be one of the following types
- // bfloat16, half (float16), float32, float64, complex64, complex128
- // Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto &inputs = args.op.inputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
case circle::TensorType_INT32:
case circle::TensorType_INT64:
case circle::TensorType_FLOAT16:
diff --git a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
index 6deae94c5..1983465d3 100644
--- a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
@@ -25,20 +25,17 @@ namespace luci
bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
// Inputs must be one of the following types
// bfloat16, half(float16), float32, float64, int32, int64, complex64, complex128
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
@@ -48,16 +45,22 @@ bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) con
case circle::TensorType_COMPLEX64:
break;
// TODO support bfloat16, complex128
+ // Additional support for quantized tensors
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
+ break;
default:
return false;
}
// Input types must match
- if (tensors.at(inputs.at(0))->type != tensors.at(inputs.at(1))->type)
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(inputs.at(1)) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(inputs.at(1))->type())
return false;
// Input and output types must match
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ assert(tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleSqueeze.cpp b/compiler/luci/import/src/Nodes/CircleSqueeze.cpp
index 32792c266..d24d8166c 100644
--- a/compiler/luci/import/src/Nodes/CircleSqueeze.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSqueeze.cpp
@@ -16,7 +16,6 @@
#include "luci/Import/Nodes/CircleSqueeze.h"
-#include <luci/IR/Nodes/CircleConst.h>
#include <luci/IR/Nodes/CircleSqueeze.h>
namespace luci
@@ -24,13 +23,7 @@ namespace luci
bool CircleSqueezeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleSqueezeGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp b/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp
index 8f943a682..ca8259cac 100644
--- a/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp
+++ b/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp
@@ -27,14 +27,8 @@ namespace luci
bool CircleStridedSliceGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 4)
- return false;
- if (args.op.outputs.size() != 1)
- return false;
-
// TODO check shapes and types
-
- return true;
+ return GraphBuilder::validate(args, 4);
}
CircleNode *CircleStridedSliceGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSub.cpp b/compiler/luci/import/src/Nodes/CircleSub.cpp
index 9acf83d40..c3978f218 100644
--- a/compiler/luci/import/src/Nodes/CircleSub.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSub.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleSubGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleSubGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSum.cpp b/compiler/luci/import/src/Nodes/CircleSum.cpp
index bd3cb6239..e348a62d9 100644
--- a/compiler/luci/import/src/Nodes/CircleSum.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSum.cpp
@@ -23,10 +23,7 @@ namespace luci
bool CircleSumGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleSumGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleTanh.cpp b/compiler/luci/import/src/Nodes/CircleTanh.cpp
index 018f5701b..80a0e887f 100644
--- a/compiler/luci/import/src/Nodes/CircleTanh.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTanh.cpp
@@ -25,15 +25,14 @@ namespace luci
bool CircleTanhGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
- const auto &outputs = args.op.outputs;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
- const auto &tensors = args.reader.tensors();
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTile.cpp b/compiler/luci/import/src/Nodes/CircleTile.cpp
index bc6f320ba..c41a6ba3f 100644
--- a/compiler/luci/import/src/Nodes/CircleTile.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTile.cpp
@@ -25,20 +25,17 @@ namespace luci
bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const
{
- auto inputs = args.op.inputs;
- auto outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ auto inputs = args.op.inputs;
+ auto outputs = args.op.outputs;
// Multiples (inputs.at(1)) must be one of the following types
// int32, int64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(1));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(1));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -48,7 +45,8 @@ bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const
}
// Type of input and output must be the same
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
index f0677de86..9f9173738 100644
--- a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
@@ -35,9 +35,10 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const
if (outputs.size() != 2)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(1));
- if (tensor->type != circle::TensorType_INT32)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(1));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_INT32)
return false;
return true;
@@ -59,59 +60,24 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const
* \- CircleTopKV2Out --- FullyConnected ---
*/
-void CircleTopKV2GraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleTopKV2GraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
-
- auto graph = context->graph();
-
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleTopKV2
- auto node = graph->nodes()->create<CircleTopKV2>();
- node->input(input_nodes[0]);
- node->k(input_nodes[1]);
-
- assert(outputs.size() == 2);
- {
- // Let's use name of output 0 as TopKV2 name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for TopKV2 itself but to virtual outputs
- }
-
- // Create virtual outputs of TopKV2
- for (size_t n = 0; n < outputs.size(); ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleTopKV2Out>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ auto node = bna.context->graph()->nodes()->create<CircleTopKV2>();
+
+ node->input(bna.input_nodes[0]);
+ node->k(bna.input_nodes[1]);
+
+ return node;
+}
+
+CircleNode *CircleTopKV2GraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleTopKV2Out>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleTranspose.cpp b/compiler/luci/import/src/Nodes/CircleTranspose.cpp
index cc3153085..01095239e 100644
--- a/compiler/luci/import/src/Nodes/CircleTranspose.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTranspose.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleTransposeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleTransposeGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
index c280faaf5..62326f435 100644
--- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
@@ -31,11 +31,13 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &filter_tensor = tensors.at(inputs.at(1));
- const auto &filter_shape = filter_tensor.get()->shape;
- const auto &ifm_tensor = tensors.at(inputs.at(2));
- const auto &ifm_shape = ifm_tensor.get()->shape;
+ const auto tensors = args.reader.tensors();
+ const auto filter_tensor = tensors.at(inputs.at(1));
+ assert(filter_tensor != nullptr);
+ const auto filter_shape = wrap(filter_tensor->shape());
+ const auto ifm_tensor = tensors.at(inputs.at(2));
+ assert(ifm_tensor != nullptr);
+ const auto ifm_shape = wrap(ifm_tensor->shape());
// ifm and filters must be 4-D tensor
if (ifm_shape.size() != 4)
@@ -45,7 +47,7 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const
// input shape : [batch, height, width, in_channels]
// filters shape : [output_channels, height, weight, in_channels]
- if (ifm_tensor.get()->shape.at(3) != filter_tensor.get()->shape.at(3))
+ if (ifm_shape.at(3) != filter_shape.at(3))
return false;
return true;
@@ -61,21 +63,18 @@ CircleNode *CircleTransposeConvGraphBuilder::build_node(const circle::OperatorT
node->filter(inputs.at(1));
node->outBackprop(inputs.at(2));
if (inputs.size() == 3)
- node->bias(graph->nodes()->create<CircleOutputExclude>());
- else
- node->bias(inputs.at(3));
-
- if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias()))
{
- // CircleOutputExclude doesn't need a type, but since all nodes must have a type, a dummy type
- // is inserted.
- bias->dtype(loco::DataType::FLOAT32);
+ auto *bias = graph->nodes()->create<CircleOutputExclude>();
+ node->bias(bias);
}
+ else
+ node->bias(inputs.at(3));
const auto *options = op.builtin_options.AsTransposeConvOptions();
node->padding(luci_padding(options->padding));
node->stride()->w(options->stride_w);
node->stride()->h(options->stride_h);
+ node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
return node;
}
diff --git a/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp b/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp
new file mode 100644
index 000000000..7ab6d6881
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleUnidirectionalSequenceLSTM.h"
+
+#include <luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleUnidirectionalSequenceLSTMGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 24);
+}
+
+CircleNode *CircleUnidirectionalSequenceLSTMGraphBuilder::build_node(
+ const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleUnidirectionalSequenceLSTM>();
+ node->input(inputs.at(0));
+ node->input_to_input_weights(inputs.at(1)); // Optional
+ node->input_to_forget_weights(inputs.at(2));
+ node->input_to_cell_weights(inputs.at(3));
+ node->input_to_output_weights(inputs.at(4));
+
+ node->recurrent_to_input_weights(inputs.at(5)); // Optional
+ node->recurrent_to_forget_weights(inputs.at(6));
+ node->recurrent_to_cell_weights(inputs.at(7));
+ node->recurrent_to_output_weights(inputs.at(8));
+
+ node->cell_to_input_weights(inputs.at(9)); // Optional
+ node->cell_to_forget_weights(inputs.at(10)); // Optional
+ node->cell_to_output_weights(inputs.at(11)); // Optional
+
+ node->input_gate_bias(inputs.at(12)); // Optional
+ node->forget_gate_bias(inputs.at(13));
+ node->cell_gate_bias(inputs.at(14));
+ node->output_gate_bias(inputs.at(15));
+
+ node->projection_weights(inputs.at(16)); // Optional
+ node->projection_bias(inputs.at(17)); // Optional
+
+ node->output_state(inputs.at(18));
+ node->cell_state(inputs.at(19));
+
+ node->input_layer_norm_coefficients(inputs.at(20)); // Optional
+ node->forget_layer_norm_coefficients(inputs.at(21)); // Optional
+ node->cell_layer_norm_coefficients(inputs.at(22)); // Optional
+ node->output_layer_norm_coefficients(inputs.at(23)); // Optional
+
+ const auto *options = op.builtin_options.AsUnidirectionalSequenceLSTMOptions();
+ node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
+ node->cell_clip(options->cell_clip);
+ node->proj_clip(options->proj_clip);
+ node->time_major(options->time_major);
+ node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs);
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleUnique.cpp b/compiler/luci/import/src/Nodes/CircleUnique.cpp
index 5e79a2920..f6914c24a 100644
--- a/compiler/luci/import/src/Nodes/CircleUnique.cpp
+++ b/compiler/luci/import/src/Nodes/CircleUnique.cpp
@@ -35,55 +35,26 @@ bool CircleUniqueGraphBuilder::validate(const ValidateArgs &args) const
return true;
}
-void CircleUniqueGraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleUniqueGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ auto node = bna.context->graph()->nodes()->create<CircleUnique>();
- auto graph = context->graph();
+ node->input(bna.input_nodes[0]);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto *options = bna.op.builtin_options.AsUniqueOptions();
+ node->idx_out_type(luci_datatype(options->idx_out_type));
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleUnique
- auto node = graph->nodes()->create<CircleUnique>();
- node->input(input_nodes[0]);
-
- const auto *options = op.builtin_options.AsUniqueOptions();
- node->output_type(luci_datatype(options->idx_out_type));
-
- assert(int32_t(outputs.size()) == 2);
- // Let's use name of output 0 as Unique name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
-
- // Create virtual outputs of Unique
- for (int32_t n = 0; n < 2; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ return node;
+}
- auto *nodeout = graph->nodes()->create<CircleUniqueOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
+CircleNode *CircleUniqueGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleUniqueOut>();
- nodeout->input(node);
- nodeout->index(n);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleUnpack.cpp b/compiler/luci/import/src/Nodes/CircleUnpack.cpp
index 9e7f3d3e1..6b3401609 100644
--- a/compiler/luci/import/src/Nodes/CircleUnpack.cpp
+++ b/compiler/luci/import/src/Nodes/CircleUnpack.cpp
@@ -46,8 +46,8 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
- const auto &tensors = args.reader.tensors();
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto tensors = args.reader.tensors();
+ const auto output_tensor = tensors[outputs[0]];
auto name = tensor_name(output_tensor);
WARN(l) << "Warning: import Unpack(" << name << ") 'num' is not same as outputs used";
}
@@ -58,9 +58,10 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
if (options->num < 0)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- const auto &shape = tensor->shape;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ const auto shape = wrap(tensor->shape());
auto shape_size = static_cast<int32_t>(shape.size());
if (shape_size > 0)
{
@@ -88,64 +89,27 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleUnpackOut --- FullyConnected ---
*/
-void CircleUnpackGraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleUnpackGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ auto node = bna.context->graph()->nodes()->create<CircleUnpack>();
- auto graph = context->graph();
+ node->value(bna.input_nodes[0]);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- // NOTE Unpack has only one input so running a loop is not necessary
- // This is provided as a reference for other Ops as a reference
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleUnpack
- CircleUnpack *node = graph->nodes()->create<CircleUnpack>();
- node->value(input_nodes[0]);
-
- const auto *options = op.builtin_options.AsUnpackOptions();
+ const auto *options = bna.op.builtin_options.AsUnpackOptions();
node->num(options->num);
node->axis(options->axis);
- assert(outputs.size() > 0);
- {
- // Let's use name of output 0 as Unpack name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for Unpack itself but to virtual outputs
- }
-
- // Create virtual outputs of Unpack
- for (int32_t n = 0; n < options->num; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ return node;
+}
- auto *nodeout = graph->nodes()->create<CircleUnpackOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
+CircleNode *CircleUnpackGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleUnpackOut>();
- nodeout->input(node);
- nodeout->index(n);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleVariable.cpp b/compiler/luci/import/src/Nodes/CircleVariable.cpp
new file mode 100644
index 000000000..23ae9e7be
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleVariable.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleVariable.h"
+
+#include <luci/IR/Nodes/CircleVariable.h>
+#include <luci/Log.h>
+
+#include <cassert>
+#include <ostream>
+#include <string>
+#include <vector>
+
+namespace
+{
+
+std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect)
+{
+ uint32_t seq = 0;
+ for (const auto &v : vect)
+ {
+ if (seq)
+ os << ", ";
+ os << v;
+ seq++;
+ }
+ return os;
+}
+
+} // namespace
+
+namespace luci
+{
+
+CircleVariable *create_circlevariable(GraphBuilderContext *context, int32_t tensor_index)
+{
+ LOGGER(l);
+
+ auto graph = context->graph();
+ auto reader = context->reader();
+ const auto tensors = reader->tensors();
+ const auto variable_tensor = tensors[tensor_index];
+ assert(variable_tensor != nullptr);
+
+ if (not variable_tensor->is_variable())
+ {
+ // not a variable
+ return nullptr;
+ }
+ {
+ // check if there is no buffer as we don't support this for now
+ // TODO use buffer when this is enabled in Kernel
+ assert(reader->buffers()[variable_tensor->buffer()] != nullptr);
+ assert(reader->buffers()[variable_tensor->buffer()]->data() == nullptr);
+ }
+
+ auto variable_node = graph->nodes()->create<CircleVariable>();
+ copy_tensor_attributes(variable_tensor, variable_node);
+ variable_node->shape_status(luci::ShapeStatus::VALID);
+
+ INFO(l) << "[luci] NodeFinder variable node(" << tensor_index << ") -> " << variable_node << " "
+ << wrap(variable_tensor->shape()) << std::endl;
+
+ return variable_node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleWhere.cpp b/compiler/luci/import/src/Nodes/CircleWhere.cpp
index f4c5f0c66..bc6199ace 100644
--- a/compiler/luci/import/src/Nodes/CircleWhere.cpp
+++ b/compiler/luci/import/src/Nodes/CircleWhere.cpp
@@ -25,23 +25,21 @@ namespace luci
bool CircleWhereGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
- if (outputs.size() != 1)
- return false;
-
- const auto &tensors = args.reader.tensors();
- const auto &tensor_condition = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor_condition = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_condition != nullptr);
+ assert(tensor_out != nullptr);
- if (tensor_condition->type != circle::TensorType_BOOL)
+ if (tensor_condition->type() != circle::TensorType_BOOL)
return false;
- if (tensor_out->type != circle::TensorType_INT64)
+ if (tensor_out->type() != circle::TensorType_INT64)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleWhile.cpp b/compiler/luci/import/src/Nodes/CircleWhile.cpp
index aead25071..27a392b2a 100644
--- a/compiler/luci/import/src/Nodes/CircleWhile.cpp
+++ b/compiler/luci/import/src/Nodes/CircleWhile.cpp
@@ -58,7 +58,8 @@ bool CircleWhileGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleWhileOut --- Node ---
*/
-void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
+CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op,
+ GraphBuilderContext *context) const
{
assert(context != nullptr);
@@ -66,8 +67,8 @@ void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderCon
const std::vector<int32_t> &inputs = op.inputs;
const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
+ const auto tensors = context->reader()->tensors();
+ const auto opcodes = context->reader()->opcodes();
std::vector<CircleNode *> input_nodes;
for (const int32_t input_tensor_index : inputs)
@@ -95,9 +96,11 @@ void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderCon
assert(outputs.size() > 0);
{
// Lets use name of output 0 as While name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
+ assert(opcodes[op.opcode_index] != nullptr);
+ node->op_version(opcodes[op.opcode_index]->version());
// NOTE We don't set quantization for While itself but to virtual outputs
}
@@ -105,7 +108,8 @@ void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderCon
// Create virtual outputs of While
for (uint32_t n = 0; n < output_count; ++n)
{
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ const auto output_tensor = tensors[outputs[n]];
+ assert(output_tensor != nullptr);
auto *nodeout = graph->nodes()->create<CircleWhileOut>();
@@ -118,6 +122,8 @@ void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderCon
context->nodefinder()->enroll(outputs[n], nodeout);
}
+
+ return node;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleZerosLike.cpp b/compiler/luci/import/src/Nodes/CircleZerosLike.cpp
index e60424def..ddb05e8a4 100644
--- a/compiler/luci/import/src/Nodes/CircleZerosLike.cpp
+++ b/compiler/luci/import/src/Nodes/CircleZerosLike.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleZerosLikeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleZerosLikeGraphBuilder::build_node(const circle::OperatorT &,