diff options
Diffstat (limited to 'compiler/luci/lang/src/CircleNodes.cpp')
-rw-r--r-- | compiler/luci/lang/src/CircleNodes.cpp | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/compiler/luci/lang/src/CircleNodes.cpp b/compiler/luci/lang/src/CircleNodes.cpp index 76ff7ec5a..c77c06861 100644 --- a/compiler/luci/lang/src/CircleNodes.cpp +++ b/compiler/luci/lang/src/CircleNodes.cpp @@ -37,6 +37,7 @@ void set_new_shape(CircleReshape *node, int32_t *base, uint32_t size) const_shape_node->dim(0) = size; const_shape_node->dtype(S32); const_shape_node->size<S32>(size); + const_shape_node->shape_status(luci::ShapeStatus::VALID); for (uint32_t axis = 0; axis < size; ++axis) const_shape_node->at<S32>(axis) = base[axis]; node->shape(const_shape_node); @@ -47,4 +48,38 @@ void set_new_shape(CircleReshape *node, int32_t *base, uint32_t size) node->newShape()->dim(axis) = base[axis]; } +void link(loco::GraphOutput *output, CircleOutput *node) { node->index(output->index()); } + +CircleOutput *output_node(loco::Graph *g, const loco::GraphOutputIndex &index) +{ + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + if (auto output = dynamic_cast<CircleOutput *>(g->nodes()->at(n))) + { + if (output->indexed() && output->index() == index) + { + return output; + } + } + } + return nullptr; +} + +void link(loco::GraphInput *input, CircleInput *node) { node->index(input->index()); } + +CircleInput *input_node(loco::Graph *g, const loco::GraphInputIndex &index) +{ + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + if (auto input = dynamic_cast<CircleInput *>(g->nodes()->at(n))) + { + if (input->indexed() && input->index() == index) + { + return input; + } + } + } + return nullptr; +} + } // namespace luci |