diff options
Diffstat (limited to 'compiler/luci/import/src/Nodes/CircleTopKV2.cpp')
-rw-r--r-- | compiler/luci/import/src/Nodes/CircleTopKV2.cpp | 69 |
1 files changed, 17 insertions, 52 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp index f0677de86..49f858798 100644 --- a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp @@ -59,59 +59,24 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const * \- CircleTopKV2Out --- FullyConnected --- */ -void CircleTopKV2GraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleTopKV2GraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleTopKV2 - auto node = graph->nodes()->create<CircleTopKV2>(); - node->input(input_nodes[0]); - node->k(input_nodes[1]); - - assert(outputs.size() == 2); - { - // Let's use name of output 0 as TopKV2 name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for TopKV2 itself but to virtual outputs - } - - // Create virtual outputs of TopKV2 - for (size_t n = 0; n < outputs.size(); ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleTopKV2Out>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + auto node = bna.context->graph()->nodes()->create<CircleTopKV2>(); + + node->input(bna.input_nodes[0]); + node->k(bna.input_nodes[1]); + + return node; +} + +CircleNode *CircleTopKV2GraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleTopKV2Out>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci |