diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-05 17:19:36 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-05 17:19:36 +0900 |
commit | 36f814f0dc5f2d035e9e10b64d6c41fa3b64fc6b (patch) | |
tree | c71160c3612052c1fe9f64ee3f33a185b0371d9e /compiler | |
parent | 3e42d442b47758c749508d717aded0256b56e37b (diff) | |
download | nnfw-36f814f0dc5f2d035e9e10b64d6c41fa3b64fc6b.tar.gz nnfw-36f814f0dc5f2d035e9e10b64d6c41fa3b64fc6b.tar.bz2 nnfw-36f814f0dc5f2d035e9e10b64d6c41fa3b64fc6b.zip |
[moco-tf] Use loco::shape in import (#7217)
This will revise import to use loco shape instead of ShapeInferenceData
- also add guards for safety
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/moco-tf/src/Frontend.cpp | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/compiler/moco-tf/src/Frontend.cpp b/compiler/moco-tf/src/Frontend.cpp index a0bf1b670..e76580785 100644 --- a/compiler/moco-tf/src/Frontend.cpp +++ b/compiler/moco-tf/src/Frontend.cpp @@ -26,6 +26,8 @@ #include "Op/COpCall.h" +#include <loco/Service/ShapeInference.h> + #include <cwrap/Fildes.h> #include <stdex/Memory.h> @@ -124,15 +126,15 @@ moco::tf::GraphBuilderRegistry make_graph_builder_registry(const moco::tf::Model } // namespace // TODO Find a proper place for this function -#include "Annotations/ShapeInferenceData.h" namespace { loco::TensorShape tensor_shape(loco::Node *node) { - assert(node->annot<moco::tf::ShapeInferenceData>() != nullptr); - return node->annot<moco::tf::ShapeInferenceData>()->tensor_shape(); + assert(loco::shape_known(node)); + auto node_shape = loco::shape_get(node); + return node_shape.as<loco::TensorShape>(); } } // namespace @@ -258,6 +260,7 @@ std::unique_ptr<loco::Graph> Frontend::import(const ModelSignature &signature, { auto input = graph->inputs()->at(n); auto input_node = loco::pull_node(graph.get(), n); + assert(input_node != nullptr); input->shape(stdex::make_unique<loco::TensorShape>(tensor_shape(input_node))); } @@ -265,6 +268,7 @@ std::unique_ptr<loco::Graph> Frontend::import(const ModelSignature &signature, { auto output = graph->outputs()->at(n); auto output_node = loco::push_node(graph.get(), n); + assert(output_node != nullptr); output->shape(stdex::make_unique<loco::TensorShape>(tensor_shape(output_node))); } |