summaryrefslogtreecommitdiff
path: root/compiler/luci/lang/src/CircleNodes.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/lang/src/CircleNodes.cpp')
-rw-r--r--compiler/luci/lang/src/CircleNodes.cpp35
1 files changed, 35 insertions, 0 deletions
diff --git a/compiler/luci/lang/src/CircleNodes.cpp b/compiler/luci/lang/src/CircleNodes.cpp
index 76ff7ec5a..c77c06861 100644
--- a/compiler/luci/lang/src/CircleNodes.cpp
+++ b/compiler/luci/lang/src/CircleNodes.cpp
@@ -37,6 +37,7 @@ void set_new_shape(CircleReshape *node, int32_t *base, uint32_t size)
const_shape_node->dim(0) = size;
const_shape_node->dtype(S32);
const_shape_node->size<S32>(size);
+ const_shape_node->shape_status(luci::ShapeStatus::VALID);
for (uint32_t axis = 0; axis < size; ++axis)
const_shape_node->at<S32>(axis) = base[axis];
node->shape(const_shape_node);
@@ -47,4 +48,38 @@ void set_new_shape(CircleReshape *node, int32_t *base, uint32_t size)
node->newShape()->dim(axis) = base[axis];
}
+void link(loco::GraphOutput *output, CircleOutput *node) { node->index(output->index()); }
+
+CircleOutput *output_node(loco::Graph *g, const loco::GraphOutputIndex &index)
+{
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ if (auto output = dynamic_cast<CircleOutput *>(g->nodes()->at(n)))
+ {
+ if (output->indexed() && output->index() == index)
+ {
+ return output;
+ }
+ }
+ }
+ return nullptr;
+}
+
+void link(loco::GraphInput *input, CircleInput *node) { node->index(input->index()); }
+
+CircleInput *input_node(loco::Graph *g, const loco::GraphInputIndex &index)
+{
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ if (auto input = dynamic_cast<CircleInput *>(g->nodes()->at(n)))
+ {
+ if (input->indexed() && input->index() == index)
+ {
+ return input;
+ }
+ }
+ }
+ return nullptr;
+}
+
} // namespace luci