diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-17 07:01:23 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-17 07:01:23 +0900 |
commit | d96fbc0347d2f3455209e72e4478ee6840b58b4f (patch) | |
tree | 643864a03c82f83cda38c33d4b4e7cc2bbbdc83f | |
parent | 8808d0dd07b15e53d76bdb8e39169d8a5ab09e59 (diff) | |
download | nnfw-d96fbc0347d2f3455209e72e4478ee6840b58b4f.tar.gz nnfw-d96fbc0347d2f3455209e72e4478ee6840b58b4f.tar.bz2 nnfw-d96fbc0347d2f3455209e72e4478ee6840b58b4f.zip |
[exo-tflite] Test of broadcast algorithm (#7471)
This will add a test of broadcast algorithm with TFLAdd
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r-- | compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp index eca72e618..bdacaf0b5 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp @@ -20,6 +20,8 @@ #include "Dialect/IR/TFLDialect.h" #include "Dialect/Service/TFLShapeInferenceRule.h" +#include "Conversion/ShapeInferencePass.h" + #include <loco.h> #include <loco/IR/CanonicalDialect.h> #include <loco/Service/ShapeInference.h> @@ -171,3 +173,77 @@ TEST(TFLShapeInferenceRuleTest, avgpool2d_same) ASSERT_EQ(shape.dim(3).value(), 1); } } + +/** + * @note Function to test: Shape inference of two different input shapes + * + * Rank expansion to higher input side + * x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5) + * Do output shape inference like numpy + * x(2,1,5) + y(1,3,5) --> output(2,3,5) + * For each axis, dim value should be same OR one of them should be 1 + */ +TEST(TFLShapeInferenceRuleTest, TFAdd_shapeinf_different) +{ + auto g = loco::make_graph(); + + auto x_node = g->nodes()->create<loco::Pull>(); + { + x_node->rank(3); + x_node->dim(0) = 2; + x_node->dim(1) = 1; + x_node->dim(2) = 5; + } + auto y_node = g->nodes()->create<loco::Pull>(); + { + y_node->rank(2); + y_node->dim(0) = 3; + y_node->dim(1) = 5; + } + auto tfl_node = g->nodes()->create<locoex::TFLAdd>(); + { + tfl_node->x(x_node); + tfl_node->y(y_node); + } + auto push_node = g->nodes()->create<loco::Push>(); + { + push_node->from(tfl_node); + } + + auto x_input = g->inputs()->create(); + { + x_input->name("x"); + loco::link(x_input, x_node); + } + auto y_input = g->inputs()->create(); + { + y_input->name("y"); + loco::link(y_input, y_node); + } + auto output = g->outputs()->create(); + { + output->name("output"); + loco::link(output, push_node); + } + + // pre-check + ASSERT_FALSE(loco::shape_known(tfl_node)); + + exo::ShapeInferencePass pass; + while (pass.run(g.get()) == true) + { + ; + } + + // Verify + { + ASSERT_TRUE(loco::shape_known(tfl_node)); + ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor); + + auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>(); + ASSERT_EQ(shape.rank(), 3); + ASSERT_EQ(shape.dim(0), 2); + ASSERT_EQ(shape.dim(1), 3); + ASSERT_EQ(shape.dim(2), 5); + } +} |