diff options
Diffstat (limited to 'compiler/luci/import/src/PostImport.cpp')
-rw-r--r-- | compiler/luci/import/src/PostImport.cpp | 47 |
1 files changed, 32 insertions, 15 deletions
diff --git a/compiler/luci/import/src/PostImport.cpp b/compiler/luci/import/src/PostImport.cpp index f436b48e8..63b16bb95 100644 --- a/compiler/luci/import/src/PostImport.cpp +++ b/compiler/luci/import/src/PostImport.cpp @@ -130,7 +130,10 @@ private: namespace { /** - * @brief ValidateNodeProp will validate inter graph connections for each Nodes + * @brief ValidateNodeProp will validate inter graph connections for each Nodes. + * @note In here, only loco::GraphInput and loco::GraphOutput are validated, + * since this class is for checking inter graph connections. + * CircleNodes such as CircleInput and CircleOutput will be validated at later steps. */ class ValidateNodeProp final : public luci::CircleNodeMutableVisitor<void> { @@ -172,9 +175,19 @@ public: auto then_graph_output = then_graph_outputs->at(then_out->index()); auto else_graph_output = else_graph_outputs->at(else_out->index()); - if (!(*then_graph_output->shape() == *else_graph_output->shape())) + if (then_graph_output->shape()->rank() != else_graph_output->shape()->rank()) { - INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output shape mismatch ", idx); + INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output rank mismatch ", idx); + } + for (uint32_t i = 0; i < then_graph_output->shape()->rank(); ++i) + { + if (then_graph_output->shape()->dim(i).known() && + else_graph_output->shape()->dim(i).known() && + then_graph_output->shape()->dim(i).value() != + else_graph_output->shape()->dim(i).value()) + { + INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output dimension mismatch ", idx); + } } if (then_graph_output->dtype() != else_graph_output->dtype()) { @@ -231,18 +244,20 @@ public: auto cond_graph_input = cond_graph_inputs->at(cond_in->index()); auto body_graph_input = body_graph_inputs->at(body_in->index()); - if ((cond_in->rank() != body_in->rank())) + if (cond_graph_input->shape()->rank() != body_graph_input->shape()->rank()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY input rank mismatch ", idx); } - if (cond_in->rank() > 0 && body_in->rank() > 0) + for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i) { - if (!(*cond_graph_input->shape() == *body_graph_input->shape())) + if (cond_graph_input->shape()->dim(i).known() && + body_graph_input->shape()->dim(i).known() && + cond_graph_input->shape()->dim(i).value() != body_graph_input->shape()->dim(i).value()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY input dimension mismatch ", idx); } } - if (cond_in->dtype() != body_in->dtype()) + if (cond_graph_input->dtype() != body_graph_input->dtype()) { INTERNAL_EXN_V("CircleWhile COND input and BODY input type mismatch ", idx); } @@ -257,18 +272,20 @@ public: auto cond_graph_input = cond_graph_inputs->at(cond_in->index()); auto body_graph_output = body_graph_outputs->at(body_out->index()); - if ((cond_in->rank() != body_out->rank())) + if (cond_graph_input->shape()->rank() != body_graph_output->shape()->rank()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY output rank mismatch ", idx); } - if (cond_in->rank() > 0 && body_out->rank() > 0) + for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i) { - if (!(*cond_graph_input->shape() == *body_graph_output->shape())) + if (cond_graph_input->shape()->dim(i).known() && + body_graph_output->shape()->dim(i).known() && + cond_graph_input->shape()->dim(i).value() != body_graph_output->shape()->dim(i).value()) { - INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx); + INTERNAL_EXN_V("CircleWhile COND input and BODY output dimension mismatch ", idx); } } - if (cond_in->dtype() != body_out->dtype()) + if (cond_graph_input->dtype() != body_graph_output->dtype()) { INTERNAL_EXN_V("CircleWhile COND input and BODY output type mismatch ", idx); } |