summaryrefslogtreecommitdiff
path: root/tests/AssetsLibrary.h
diff options
context:
space:
mode:
authorJenkins <bsgcomp@arm.com>2018-05-23 11:36:53 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-05-23 14:55:11 +0100
commitb3a371bc429d2ba45e56baaf239d8200c2662a74 (patch)
tree554525e415c303d64a08722a755397852ebbb8e4 /tests/AssetsLibrary.h
parent67c8c91522e5be8156b77f57e63c0253535c902a (diff)
downloadarmcl-master.tar.gz
armcl-master.tar.bz2
armcl-master.zip
arm_compute v18.05HEADmaster
Diffstat (limited to 'tests/AssetsLibrary.h')
-rw-r--r--tests/AssetsLibrary.h65
1 files changed, 61 insertions, 4 deletions
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h
index 7e2a042e9..1fba3d4b4 100644
--- a/tests/AssetsLibrary.h
+++ b/tests/AssetsLibrary.h
@@ -66,27 +66,39 @@ public:
*/
AssetsLibrary(std::string path, std::random_device::result_type seed);
- /** Path to assets directory used to initialise library. */
+ /** Path to assets directory used to initialise library.
+ *
+ * @return the path to the assets directory.
+ */
std::string path() const;
- /** Seed that is used to fill tensors with random values. */
+ /** Seed that is used to fill tensors with random values.
+ *
+ * @return the initial random seed.
+ */
std::random_device::result_type seed() const;
/** Provides a tensor shape for the specified image.
*
* @param[in] name Image file used to look up the raw tensor.
+ *
+ * @return the tensor shape for the specified image.
*/
TensorShape get_image_shape(const std::string &name);
- /** Provides a contant raw tensor for the specified image.
+ /** Provides a constant raw tensor for the specified image.
*
* @param[in] name Image file used to look up the raw tensor.
+ *
+ * @return a raw tensor for the specified image.
*/
const RawTensor &get(const std::string &name) const;
/** Provides a raw tensor for the specified image.
*
* @param[in] name Image file used to look up the raw tensor.
+ *
+ * @return a raw tensor for the specified image.
*/
RawTensor get(const std::string &name);
@@ -96,6 +108,8 @@ public:
* @param[in] name Image file used to initialise the tensor.
* @param[in] data_type Data type used to initialise the tensor.
* @param[in] num_channels Number of channels used to initialise the tensor.
+ *
+ * @return a raw tensor for the specified image.
*/
RawTensor get(const std::string &name, DataType data_type, int num_channels = 1) const;
@@ -104,6 +118,8 @@ public:
*
* @param[in] name Image file used to look up the raw tensor.
* @param[in] format Format used to look up the raw tensor.
+ *
+ * @return a raw tensor for the specified image.
*/
const RawTensor &get(const std::string &name, Format format) const;
@@ -112,6 +128,8 @@ public:
*
* @param[in] name Image file used to look up the raw tensor.
* @param[in] format Format used to look up the raw tensor.
+ *
+ * @return a raw tensor for the specified image.
*/
RawTensor get(const std::string &name, Format format);
@@ -123,6 +141,8 @@ public:
*
* @note The channel has to be unambiguous so that the format can be
* inferred automatically.
+ *
+ * @return a raw tensor for the specified image channel.
*/
const RawTensor &get(const std::string &name, Channel channel) const;
@@ -134,6 +154,8 @@ public:
*
* @note The channel has to be unambiguous so that the format can be
* inferred automatically.
+ *
+ * @return a raw tensor for the specified image channel.
*/
RawTensor get(const std::string &name, Channel channel);
@@ -143,6 +165,8 @@ public:
* @param[in] name Image file used to look up the raw tensor.
* @param[in] format Format used to look up the raw tensor.
* @param[in] channel Channel used to look up the raw tensor.
+ *
+ * @return a raw tensor for the specified image channel.
*/
const RawTensor &get(const std::string &name, Format format, Channel channel) const;
@@ -152,6 +176,8 @@ public:
* @param[in] name Image file used to look up the raw tensor.
* @param[in] format Format used to look up the raw tensor.
* @param[in] channel Channel used to look up the raw tensor.
+ *
+ * @return a raw tensor for the specified image channel.
*/
RawTensor get(const std::string &name, Format format, Channel channel);
@@ -326,6 +352,16 @@ public:
template <typename T>
void fill_layer_data(T &&tensor, std::string name) const;
+ /** Fill a tensor with a constant value
+ *
+ * @param[in, out] tensor To be filled tensor.
+ * @param[in] value Value to be assigned to all elements of the input tensor.
+ *
+ * @note @p value must be of the same type as the data type of @p tensor
+ */
+ template <typename T, typename D>
+ void fill_tensor_value(T &&tensor, D value) const;
+
private:
// Function prototype to convert between image formats.
using Converter = void (*)(const RawTensor &src, RawTensor &dst);
@@ -420,10 +456,25 @@ void AssetsLibrary::fill(T &&tensor, D &&distribution, std::random_device::resul
std::mt19937 gen(_seed + seed_offset);
+ const bool is_nhwc = tensor.data_layout() == DataLayout::NHWC;
+ TensorShape shape(tensor.shape());
+
+ if(is_nhwc)
+ {
+ // Ensure that the equivalent tensors will be filled for both data layouts
+ permute(shape, PermutationVector(1U, 2U, 0U));
+ }
+
// Iterate over all elements
for(int element_idx = 0; element_idx < tensor.num_elements(); ++element_idx)
{
- const Coordinates id = index2coord(tensor.shape(), element_idx);
+ Coordinates id = index2coord(shape, element_idx);
+
+ if(is_nhwc)
+ {
+ // Write in the correct id for permuted shapes
+ permute(id, PermutationVector(2U, 0U, 1U));
+ }
// Iterate over all channels
for(int channel = 0; channel < tensor.num_channels(); ++channel)
@@ -748,6 +799,12 @@ void AssetsLibrary::fill_layer_data(T &&tensor, std::string name) const
});
}
}
+
+template <typename T, typename D>
+void AssetsLibrary::fill_tensor_value(T &&tensor, D value) const
+{
+ fill_tensor_uniform(tensor, 0, value, value);
+}
} // namespace test
} // namespace arm_compute
#endif /* __ARM_COMPUTE_TEST_TENSOR_LIBRARY_H__ */