summaryrefslogtreecommitdiff
path: root/compiler/luci/import/src/PostImport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/import/src/PostImport.cpp')
-rw-r--r--compiler/luci/import/src/PostImport.cpp47
1 files changed, 32 insertions, 15 deletions
diff --git a/compiler/luci/import/src/PostImport.cpp b/compiler/luci/import/src/PostImport.cpp
index f436b48e8..63b16bb95 100644
--- a/compiler/luci/import/src/PostImport.cpp
+++ b/compiler/luci/import/src/PostImport.cpp
@@ -130,7 +130,10 @@ private:
namespace
{
/**
- * @brief ValidateNodeProp will validate inter graph connections for each Nodes
+ * @brief ValidateNodeProp will validate inter graph connections for each Nodes.
+ * @note In here, only loco::GraphInput and loco::GraphOutput are validated,
+ * since this class is for checking inter graph connections.
+ * CircleNodes such as CircleInput and CircleOutput will be validated at later steps.
*/
class ValidateNodeProp final : public luci::CircleNodeMutableVisitor<void>
{
@@ -172,9 +175,19 @@ public:
auto then_graph_output = then_graph_outputs->at(then_out->index());
auto else_graph_output = else_graph_outputs->at(else_out->index());
- if (!(*then_graph_output->shape() == *else_graph_output->shape()))
+ if (then_graph_output->shape()->rank() != else_graph_output->shape()->rank())
{
- INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output rank mismatch ", idx);
+ }
+ for (uint32_t i = 0; i < then_graph_output->shape()->rank(); ++i)
+ {
+ if (then_graph_output->shape()->dim(i).known() &&
+ else_graph_output->shape()->dim(i).known() &&
+ then_graph_output->shape()->dim(i).value() !=
+ else_graph_output->shape()->dim(i).value())
+ {
+ INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output dimension mismatch ", idx);
+ }
}
if (then_graph_output->dtype() != else_graph_output->dtype())
{
@@ -231,18 +244,20 @@ public:
auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
auto body_graph_input = body_graph_inputs->at(body_in->index());
- if ((cond_in->rank() != body_in->rank()))
+ if (cond_graph_input->shape()->rank() != body_graph_input->shape()->rank())
{
- INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleWhile COND input and BODY input rank mismatch ", idx);
}
- if (cond_in->rank() > 0 && body_in->rank() > 0)
+ for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i)
{
- if (!(*cond_graph_input->shape() == *body_graph_input->shape()))
+ if (cond_graph_input->shape()->dim(i).known() &&
+ body_graph_input->shape()->dim(i).known() &&
+ cond_graph_input->shape()->dim(i).value() != body_graph_input->shape()->dim(i).value())
{
- INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleWhile COND input and BODY input dimension mismatch ", idx);
}
}
- if (cond_in->dtype() != body_in->dtype())
+ if (cond_graph_input->dtype() != body_graph_input->dtype())
{
INTERNAL_EXN_V("CircleWhile COND input and BODY input type mismatch ", idx);
}
@@ -257,18 +272,20 @@ public:
auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
auto body_graph_output = body_graph_outputs->at(body_out->index());
- if ((cond_in->rank() != body_out->rank()))
+ if (cond_graph_input->shape()->rank() != body_graph_output->shape()->rank())
{
- INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleWhile COND input and BODY output rank mismatch ", idx);
}
- if (cond_in->rank() > 0 && body_out->rank() > 0)
+ for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i)
{
- if (!(*cond_graph_input->shape() == *body_graph_output->shape()))
+ if (cond_graph_input->shape()->dim(i).known() &&
+ body_graph_output->shape()->dim(i).known() &&
+ cond_graph_input->shape()->dim(i).value() != body_graph_output->shape()->dim(i).value())
{
- INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleWhile COND input and BODY output dimension mismatch ", idx);
}
}
- if (cond_in->dtype() != body_out->dtype())
+ if (cond_graph_input->dtype() != body_graph_output->dtype())
{
INTERNAL_EXN_V("CircleWhile COND input and BODY output type mismatch ", idx);
}