summaryrefslogtreecommitdiff
path: root/compiler/luci-interpreter/src/kernels/Sqrt.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci-interpreter/src/kernels/Sqrt.test.cpp')
-rw-r--r--compiler/luci-interpreter/src/kernels/Sqrt.test.cpp7
1 files changed, 2 insertions, 5 deletions
diff --git a/compiler/luci-interpreter/src/kernels/Sqrt.test.cpp b/compiler/luci-interpreter/src/kernels/Sqrt.test.cpp
index cdd208280..504db4493 100644
--- a/compiler/luci-interpreter/src/kernels/Sqrt.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Sqrt.test.cpp
@@ -29,17 +29,14 @@ using namespace testing;
void Check(std::initializer_list<int32_t> input_shape, std::initializer_list<int32_t> output_shape,
std::initializer_list<float> input_data, std::initializer_list<float> output_data)
{
- Tensor input_tensor{DataType::FLOAT32, input_shape, {}, ""};
- input_tensor.writeData(input_data.begin(), input_data.size() * sizeof(float));
-
+ Tensor input_tensor = makeInputTensor<DataType::FLOAT32>(input_shape, input_data);
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
Sqrt kernel(&input_tensor, &output_tensor);
kernel.configure();
kernel.execute();
- EXPECT_THAT(extractTensorData<float>(output_tensor),
- ::testing::ElementsAreArray(ArrayFloatNear(output_data)));
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(output_data));
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
}