summaryrefslogtreecommitdiff
path: root/tools/nnpackage_tool/gen_golden/gen_golden.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/nnpackage_tool/gen_golden/gen_golden.py')
-rwxr-xr-xtools/nnpackage_tool/gen_golden/gen_golden.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/tools/nnpackage_tool/gen_golden/gen_golden.py b/tools/nnpackage_tool/gen_golden/gen_golden.py
index 125a69cac..d555419a6 100755
--- a/tools/nnpackage_tool/gen_golden/gen_golden.py
+++ b/tools/nnpackage_tool/gen_golden/gen_golden.py
@@ -91,9 +91,12 @@ if __name__ == '__main__':
if this_dtype == tf.uint8:
input_values.append(
np.random.randint(0, 255, this_shape).astype(np.uint8))
+ if this_dtype == tf.int8:
+ input_values.append(
+ np.random.randint(-127, 127, this_shape).astype(np.int8))
elif this_dtype == tf.float32:
input_values.append(
- np.random.random_sample(this_shape).astype(np.float32))
+ (10 * np.random.random_sample(this_shape) - 5).astype(np.float32))
elif this_dtype == tf.bool:
# generate random integer from [0, 2)
input_values.append(
@@ -134,9 +137,12 @@ if __name__ == '__main__':
if this_dtype == np.uint8:
input_values.append(
np.random.randint(0, 255, this_shape).astype(np.uint8))
+ if this_dtype == np.int8:
+ input_values.append(
+ np.random.randint(-127, 127, this_shape).astype(np.int8))
elif this_dtype == np.float32:
input_values.append(
- np.random.random_sample(this_shape).astype(np.float32))
+ (10 * np.random.random_sample(this_shape) - 5).astype(np.float32))
elif this_dtype == np.bool_:
# generate random integer from [0, 2)
input_values.append(
@@ -158,10 +164,11 @@ if __name__ == '__main__':
# dump input and output in h5
import h5py
- supported_dtypes = ("float32", "uint8", "bool", "int32", "int64")
+ supported_dtypes = ("float32", "uint8", "int8", "bool", "int32", "int64")
h5dtypes = {
"float32": ">f4",
"uint8": "u1",
+ "int8": "i1",
"bool": "u1",
"int32": "int32",
"int64": "int64"