summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorstruss <rrstrous@nate.com>2020-06-25 16:41:00 +0900
committerGitHub <noreply@github.com>2020-06-25 16:41:00 +0900
commitf527512bf2d9040463db7be328a86644d741305d (patch)
treeb88ce7e64ba0e52d7301ff2daf9786c7ba7d9339
parent7fc3a430ce69d913527e4c3cd3c00674be7b7dcc (diff)
downloadnnfw-f527512bf2d9040463db7be328a86644d741305d.tar.gz
nnfw-f527512bf2d9040463db7be328a86644d741305d.tar.bz2
nnfw-f527512bf2d9040463db7be328a86644d741305d.zip
[luci-interpreter]Add extractShape function on testutil (#2531)
This commit add extractShape function on luci-interpreter testutils. ONE-DCO-1.0-Signed-off-by: KiDeuk Bang <rrstrous@nate.com>
-rw-r--r--compiler/luci-interpreter/src/kernels/TestUtils.cpp11
-rw-r--r--compiler/luci-interpreter/src/kernels/TestUtils.h2
2 files changed, 13 insertions, 0 deletions
diff --git a/compiler/luci-interpreter/src/kernels/TestUtils.cpp b/compiler/luci-interpreter/src/kernels/TestUtils.cpp
index 2641a1625..2c8a6ae78 100644
--- a/compiler/luci-interpreter/src/kernels/TestUtils.cpp
+++ b/compiler/luci-interpreter/src/kernels/TestUtils.cpp
@@ -45,6 +45,17 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float> &values, flo
return matchers;
}
+std::vector<int32_t> extractTensorShape(const Tensor &tensor)
+{
+ std::vector<int32_t> result;
+ int dims = tensor.shape().num_dims();
+ for (int i = 0; i < dims; i++)
+ {
+ result.push_back(tensor.shape().dim(i));
+ }
+ return result;
+}
+
} // namespace testing
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/TestUtils.h b/compiler/luci-interpreter/src/kernels/TestUtils.h
index 9d9304ae2..ff223773b 100644
--- a/compiler/luci-interpreter/src/kernels/TestUtils.h
+++ b/compiler/luci-interpreter/src/kernels/TestUtils.h
@@ -43,6 +43,8 @@ Tensor makeInputTensor(const Shape &shape, const std::vector<typename DataTypeIm
Tensor makeOutputTensor(DataType element_type);
Tensor makeOutputTensor(DataType element_type, float scale, int32_t zero_point);
+std::vector<int32_t> extractTensorShape(const Tensor &tensor);
+
// Returns the corresponding DataType given the type T.
template <typename T> constexpr DataType getElementType()
{