summaryrefslogtreecommitdiff
path: root/compiler/souschef/src/Gaussian.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/souschef/src/Gaussian.cpp')
-rw-r--r--compiler/souschef/src/Gaussian.cpp45
1 files changed, 44 insertions, 1 deletions
diff --git a/compiler/souschef/src/Gaussian.cpp b/compiler/souschef/src/Gaussian.cpp
index 32cbcff4d..53a62cabf 100644
--- a/compiler/souschef/src/Gaussian.cpp
+++ b/compiler/souschef/src/Gaussian.cpp
@@ -23,6 +23,8 @@
#include <cassert>
#include <stdexcept>
+#include <fp16.h>
+
namespace souschef
{
@@ -36,7 +38,7 @@ static std::vector<uint8_t> generate_gaussian(int32_t count, float mean, float s
std::vector<uint8_t> res;
constexpr float max_cap = std::numeric_limits<T>::max();
- constexpr float min_cap = std::numeric_limits<T>::min();
+ constexpr float min_cap = std::numeric_limits<T>::lowest();
for (uint32_t n = 0; n < count; ++n)
{
float raw_value = dist(rand);
@@ -69,6 +71,34 @@ std::vector<uint8_t> GaussianFloat32DataChef::generate(int32_t count) const
return generate_gaussian<float>(count, _mean, _stddev);
}
+std::vector<uint8_t> GaussianFloat16DataChef::generate(int32_t count) const
+{
+ auto time_stamp = std::chrono::system_clock::now().time_since_epoch().count();
+ auto seed = static_cast<std::minstd_rand::result_type>(time_stamp);
+
+ std::minstd_rand rand{static_cast<std::minstd_rand::result_type>(seed)};
+ std::normal_distribution<float> dist{_mean, _stddev};
+
+ std::vector<uint8_t> res;
+
+ constexpr float max_cap = 1e9;
+ constexpr float min_cap = -1e9;
+ for (uint32_t n = 0; n < count; ++n)
+ {
+ float raw_value = dist(rand);
+ const float capped_value = std::max(min_cap, std::min(max_cap, raw_value));
+ const uint16_t value = fp16_ieee_from_fp32_value(capped_value);
+ auto const arr = reinterpret_cast<const uint8_t *>(&value);
+
+ for (uint32_t b = 0; b < sizeof(uint16_t); ++b)
+ {
+ res.emplace_back(arr[b]);
+ }
+ }
+
+ return res;
+}
+
std::vector<uint8_t> GaussianInt32DataChef::generate(int32_t count) const
{
return generate_gaussian<int32_t>(count, _mean, _stddev);
@@ -136,4 +166,17 @@ std::unique_ptr<DataChef> GaussianUint8DataChefFactory::create(const Arguments &
return std::unique_ptr<DataChef>{new GaussianUint8DataChef{mean, stddev}};
}
+std::unique_ptr<DataChef> GaussianFloat16DataChefFactory::create(const Arguments &args) const
+{
+ if (args.count() != 2)
+ {
+ throw std::runtime_error{"invalid argument count: two arguments (mean/stddev) are expected"};
+ }
+
+ auto const mean = to_number<float>(args.value(0));
+ auto const stddev = to_number<float>(args.value(1));
+
+ return std::unique_ptr<DataChef>{new GaussianFloat16DataChef{mean, stddev}};
+}
+
} // namespace souschef