diff options
author | Jenkins <bsgcomp@arm.com> | 2018-05-23 11:36:53 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-05-23 14:55:11 +0100 |
commit | b3a371bc429d2ba45e56baaf239d8200c2662a74 (patch) | |
tree | 554525e415c303d64a08722a755397852ebbb8e4 /tests/AssetsLibrary.h | |
parent | 67c8c91522e5be8156b77f57e63c0253535c902a (diff) | |
download | armcl-master.tar.gz armcl-master.tar.bz2 armcl-master.zip |
Diffstat (limited to 'tests/AssetsLibrary.h')
-rw-r--r-- | tests/AssetsLibrary.h | 65 |
1 files changed, 61 insertions, 4 deletions
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h index 7e2a042e9..1fba3d4b4 100644 --- a/tests/AssetsLibrary.h +++ b/tests/AssetsLibrary.h @@ -66,27 +66,39 @@ public: */ AssetsLibrary(std::string path, std::random_device::result_type seed); - /** Path to assets directory used to initialise library. */ + /** Path to assets directory used to initialise library. + * + * @return the path to the assets directory. + */ std::string path() const; - /** Seed that is used to fill tensors with random values. */ + /** Seed that is used to fill tensors with random values. + * + * @return the initial random seed. + */ std::random_device::result_type seed() const; /** Provides a tensor shape for the specified image. * * @param[in] name Image file used to look up the raw tensor. + * + * @return the tensor shape for the specified image. */ TensorShape get_image_shape(const std::string &name); - /** Provides a contant raw tensor for the specified image. + /** Provides a constant raw tensor for the specified image. * * @param[in] name Image file used to look up the raw tensor. + * + * @return a raw tensor for the specified image. */ const RawTensor &get(const std::string &name) const; /** Provides a raw tensor for the specified image. * * @param[in] name Image file used to look up the raw tensor. + * + * @return a raw tensor for the specified image. */ RawTensor get(const std::string &name); @@ -96,6 +108,8 @@ public: * @param[in] name Image file used to initialise the tensor. * @param[in] data_type Data type used to initialise the tensor. * @param[in] num_channels Number of channels used to initialise the tensor. + * + * @return a raw tensor for the specified image. */ RawTensor get(const std::string &name, DataType data_type, int num_channels = 1) const; @@ -104,6 +118,8 @@ public: * * @param[in] name Image file used to look up the raw tensor. * @param[in] format Format used to look up the raw tensor. + * + * @return a raw tensor for the specified image. */ const RawTensor &get(const std::string &name, Format format) const; @@ -112,6 +128,8 @@ public: * * @param[in] name Image file used to look up the raw tensor. * @param[in] format Format used to look up the raw tensor. + * + * @return a raw tensor for the specified image. */ RawTensor get(const std::string &name, Format format); @@ -123,6 +141,8 @@ public: * * @note The channel has to be unambiguous so that the format can be * inferred automatically. + * + * @return a raw tensor for the specified image channel. */ const RawTensor &get(const std::string &name, Channel channel) const; @@ -134,6 +154,8 @@ public: * * @note The channel has to be unambiguous so that the format can be * inferred automatically. + * + * @return a raw tensor for the specified image channel. */ RawTensor get(const std::string &name, Channel channel); @@ -143,6 +165,8 @@ public: * @param[in] name Image file used to look up the raw tensor. * @param[in] format Format used to look up the raw tensor. * @param[in] channel Channel used to look up the raw tensor. + * + * @return a raw tensor for the specified image channel. */ const RawTensor &get(const std::string &name, Format format, Channel channel) const; @@ -152,6 +176,8 @@ public: * @param[in] name Image file used to look up the raw tensor. * @param[in] format Format used to look up the raw tensor. * @param[in] channel Channel used to look up the raw tensor. + * + * @return a raw tensor for the specified image channel. */ RawTensor get(const std::string &name, Format format, Channel channel); @@ -326,6 +352,16 @@ public: template <typename T> void fill_layer_data(T &&tensor, std::string name) const; + /** Fill a tensor with a constant value + * + * @param[in, out] tensor To be filled tensor. + * @param[in] value Value to be assigned to all elements of the input tensor. + * + * @note @p value must be of the same type as the data type of @p tensor + */ + template <typename T, typename D> + void fill_tensor_value(T &&tensor, D value) const; + private: // Function prototype to convert between image formats. using Converter = void (*)(const RawTensor &src, RawTensor &dst); @@ -420,10 +456,25 @@ void AssetsLibrary::fill(T &&tensor, D &&distribution, std::random_device::resul std::mt19937 gen(_seed + seed_offset); + const bool is_nhwc = tensor.data_layout() == DataLayout::NHWC; + TensorShape shape(tensor.shape()); + + if(is_nhwc) + { + // Ensure that the equivalent tensors will be filled for both data layouts + permute(shape, PermutationVector(1U, 2U, 0U)); + } + // Iterate over all elements for(int element_idx = 0; element_idx < tensor.num_elements(); ++element_idx) { - const Coordinates id = index2coord(tensor.shape(), element_idx); + Coordinates id = index2coord(shape, element_idx); + + if(is_nhwc) + { + // Write in the correct id for permuted shapes + permute(id, PermutationVector(2U, 0U, 1U)); + } // Iterate over all channels for(int channel = 0; channel < tensor.num_channels(); ++channel) @@ -748,6 +799,12 @@ void AssetsLibrary::fill_layer_data(T &&tensor, std::string name) const }); } } + +template <typename T, typename D> +void AssetsLibrary::fill_tensor_value(T &&tensor, D value) const +{ + fill_tensor_uniform(tensor, 0, value, value); +} } // namespace test } // namespace arm_compute #endif /* __ARM_COMPUTE_TEST_TENSOR_LIBRARY_H__ */ |