summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-05 17:19:36 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-05 17:19:36 +0900
commit36f814f0dc5f2d035e9e10b64d6c41fa3b64fc6b (patch)
treec71160c3612052c1fe9f64ee3f33a185b0371d9e /compiler
parent3e42d442b47758c749508d717aded0256b56e37b (diff)
downloadnnfw-36f814f0dc5f2d035e9e10b64d6c41fa3b64fc6b.tar.gz
nnfw-36f814f0dc5f2d035e9e10b64d6c41fa3b64fc6b.tar.bz2
nnfw-36f814f0dc5f2d035e9e10b64d6c41fa3b64fc6b.zip
[moco-tf] Use loco::shape in import (#7217)
This will revise import to use loco shape instead of ShapeInferenceData - also add guards for safety Signed-off-by: SaeHie Park <saehie.park@samsung.com>
Diffstat (limited to 'compiler')
-rw-r--r--compiler/moco-tf/src/Frontend.cpp10
1 files changed, 7 insertions, 3 deletions
diff --git a/compiler/moco-tf/src/Frontend.cpp b/compiler/moco-tf/src/Frontend.cpp
index a0bf1b670..e76580785 100644
--- a/compiler/moco-tf/src/Frontend.cpp
+++ b/compiler/moco-tf/src/Frontend.cpp
@@ -26,6 +26,8 @@
#include "Op/COpCall.h"
+#include <loco/Service/ShapeInference.h>
+
#include <cwrap/Fildes.h>
#include <stdex/Memory.h>
@@ -124,15 +126,15 @@ moco::tf::GraphBuilderRegistry make_graph_builder_registry(const moco::tf::Model
} // namespace
// TODO Find a proper place for this function
-#include "Annotations/ShapeInferenceData.h"
namespace
{
loco::TensorShape tensor_shape(loco::Node *node)
{
- assert(node->annot<moco::tf::ShapeInferenceData>() != nullptr);
- return node->annot<moco::tf::ShapeInferenceData>()->tensor_shape();
+ assert(loco::shape_known(node));
+ auto node_shape = loco::shape_get(node);
+ return node_shape.as<loco::TensorShape>();
}
} // namespace
@@ -258,6 +260,7 @@ std::unique_ptr<loco::Graph> Frontend::import(const ModelSignature &signature,
{
auto input = graph->inputs()->at(n);
auto input_node = loco::pull_node(graph.get(), n);
+ assert(input_node != nullptr);
input->shape(stdex::make_unique<loco::TensorShape>(tensor_shape(input_node)));
}
@@ -265,6 +268,7 @@ std::unique_ptr<loco::Graph> Frontend::import(const ModelSignature &signature,
{
auto output = graph->outputs()->at(n);
auto output_node = loco::push_node(graph.get(), n);
+ assert(output_node != nullptr);
output->shape(stdex::make_unique<loco::TensorShape>(tensor_shape(output_node)));
}