diff options
Diffstat (limited to 'compiler/luci/import/src/Nodes/CircleSplitV.cpp')
-rw-r--r-- | compiler/luci/import/src/Nodes/CircleSplitV.cpp | 76 |
1 files changed, 21 insertions, 55 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleSplitV.cpp b/compiler/luci/import/src/Nodes/CircleSplitV.cpp index 7c6e83e17..76cbf7046 100644 --- a/compiler/luci/import/src/Nodes/CircleSplitV.cpp +++ b/compiler/luci/import/src/Nodes/CircleSplitV.cpp @@ -58,64 +58,30 @@ bool CircleSplitVGraphBuilder::validate(const ValidateArgs &args) const * \- CircleSplitVOut --- FullyConnected --- */ -void CircleSplitVGraphBuilder::build(const circle::OperatorT &op, - GraphBuilderContext *context) const +CircleNode *CircleSplitVGraphBuilder::build_node(const BuildNodeArgs &bna) const { - assert(context != nullptr); - - auto graph = context->graph(); - - const std::vector<int32_t> &inputs = op.inputs; - const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); - - std::vector<CircleNode *> input_nodes; - for (const int32_t input_tensor_index : inputs) - { - input_nodes.push_back(context->nodefinder()->node(input_tensor_index)); - } - - // Create CircleSplitV - auto node = graph->nodes()->create<CircleSplitV>(); - node->input(input_nodes[0]); - node->size_splits(input_nodes[1]); - node->split_dim(input_nodes[2]); - - const auto *options = op.builtin_options.AsSplitVOptions(); + auto node = bna.context->graph()->nodes()->create<CircleSplitV>(); + + node->input(bna.input_nodes[0]); + node->size_splits(bna.input_nodes[1]); + node->split_dim(bna.input_nodes[2]); + + const auto *options = bna.op.builtin_options.AsSplitVOptions(); node->num_split(options->num_splits); - assert(outputs.size() > 0); - assert(int32_t(outputs.size()) == options->num_splits); - { - // Let's use name of output 0 as Split name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; - node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); - - // NOTE We don't set quantization for Split itself but to virtual outputs - } - - // Create virtual outputs of Split - for (int32_t n = 0; n < options->num_splits; ++n) - { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; - - auto *nodeout = graph->nodes()->create<CircleSplitVOut>(); - copy_tensor_attributes(output_tensor, nodeout); - // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) - nodeout->shape_status(ShapeStatus::NOSHAPE); - else - nodeout->shape_status(ShapeStatus::VALID); - - nodeout->input(node); - nodeout->index(n); - - context->nodefinder()->enroll(outputs[n], nodeout); - } + assert(int32_t(bna.op.outputs.size()) == options->num_splits); + + return node; +} + +CircleNode *CircleSplitVGraphBuilder::build_out(const BuildOutArgs &boa) const +{ + auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitVOut>(); + + nodeout->input(boa.node); + nodeout->index(boa.index); + + return nodeout; } } // namespace luci |