summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-17 07:01:23 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-17 07:01:23 +0900
commitd96fbc0347d2f3455209e72e4478ee6840b58b4f (patch)
tree643864a03c82f83cda38c33d4b4e7cc2bbbdc83f
parent8808d0dd07b15e53d76bdb8e39169d8a5ab09e59 (diff)
downloadnnfw-d96fbc0347d2f3455209e72e4478ee6840b58b4f.tar.gz
nnfw-d96fbc0347d2f3455209e72e4478ee6840b58b4f.tar.bz2
nnfw-d96fbc0347d2f3455209e72e4478ee6840b58b4f.zip
[exo-tflite] Test of broadcast algorithm (#7471)
This will add a test of broadcast algorithm with TFLAdd Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r--compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp76
1 files changed, 76 insertions, 0 deletions
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
index eca72e618..bdacaf0b5 100644
--- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
+++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
@@ -20,6 +20,8 @@
#include "Dialect/IR/TFLDialect.h"
#include "Dialect/Service/TFLShapeInferenceRule.h"
+#include "Conversion/ShapeInferencePass.h"
+
#include <loco.h>
#include <loco/IR/CanonicalDialect.h>
#include <loco/Service/ShapeInference.h>
@@ -171,3 +173,77 @@ TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
ASSERT_EQ(shape.dim(3).value(), 1);
}
}
+
+/**
+ * @note Function to test: Shape inference of two different input shapes
+ *
+ * Rank expansion to higher input side
+ * x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5)
+ * Do output shape inference like numpy
+ * x(2,1,5) + y(1,3,5) --> output(2,3,5)
+ * For each axis, dim value should be same OR one of them should be 1
+ */
+TEST(TFLShapeInferenceRuleTest, TFAdd_shapeinf_different)
+{
+ auto g = loco::make_graph();
+
+ auto x_node = g->nodes()->create<loco::Pull>();
+ {
+ x_node->rank(3);
+ x_node->dim(0) = 2;
+ x_node->dim(1) = 1;
+ x_node->dim(2) = 5;
+ }
+ auto y_node = g->nodes()->create<loco::Pull>();
+ {
+ y_node->rank(2);
+ y_node->dim(0) = 3;
+ y_node->dim(1) = 5;
+ }
+ auto tfl_node = g->nodes()->create<locoex::TFLAdd>();
+ {
+ tfl_node->x(x_node);
+ tfl_node->y(y_node);
+ }
+ auto push_node = g->nodes()->create<loco::Push>();
+ {
+ push_node->from(tfl_node);
+ }
+
+ auto x_input = g->inputs()->create();
+ {
+ x_input->name("x");
+ loco::link(x_input, x_node);
+ }
+ auto y_input = g->inputs()->create();
+ {
+ y_input->name("y");
+ loco::link(y_input, y_node);
+ }
+ auto output = g->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push_node);
+ }
+
+ // pre-check
+ ASSERT_FALSE(loco::shape_known(tfl_node));
+
+ exo::ShapeInferencePass pass;
+ while (pass.run(g.get()) == true)
+ {
+ ;
+ }
+
+ // Verify
+ {
+ ASSERT_TRUE(loco::shape_known(tfl_node));
+ ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+ auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 3);
+ ASSERT_EQ(shape.dim(0), 2);
+ ASSERT_EQ(shape.dim(1), 3);
+ ASSERT_EQ(shape.dim(2), 5);
+ }
+}