summaryrefslogtreecommitdiff
path: root/compiler/record-minmax/src/HDF5Importer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/record-minmax/src/HDF5Importer.cpp')
-rw-r--r--compiler/record-minmax/src/HDF5Importer.cpp28
1 files changed, 27 insertions, 1 deletions
diff --git a/compiler/record-minmax/src/HDF5Importer.cpp b/compiler/record-minmax/src/HDF5Importer.cpp
index a0e65eeb7..cfb270ce0 100644
--- a/compiler/record-minmax/src/HDF5Importer.cpp
+++ b/compiler/record-minmax/src/HDF5Importer.cpp
@@ -59,7 +59,30 @@ DataType toInternalDtype(const H5::DataType &h5_type)
{
return DataType::S64;
}
- // Only support three datatypes for now
+ if (h5_type.getClass() == H5T_class_t::H5T_ENUM)
+ {
+ // We follow the numpy format
+ // In numpy 1.19.0, np.bool_ is saved as H5T_ENUM
+ // - (name, value) -> (FALSE, 0) and (TRUE, 1)
+ // - value dtype is H5T_STD_I8LE
+ // TODO Find a general way to recognize BOOL type
+ char name[10];
+ int8_t value[2] = {0, 1};
+ if (H5Tenum_nameof(h5_type.getId(), value, name, 10) < 0)
+ return DataType::Unknown;
+
+ if (std::string(name) != "FALSE")
+ return DataType::Unknown;
+
+ if (H5Tenum_nameof(h5_type.getId(), value + 1, name, 10) < 0)
+ return DataType::Unknown;
+
+ if (std::string(name) != "TRUE")
+ return DataType::Unknown;
+
+ return DataType::BOOL;
+ }
+ // TODO Support more datatypes
return DataType::Unknown;
}
@@ -125,6 +148,9 @@ void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, DataType *d
case DataType::S64:
readTensorData(tensor, static_cast<int64_t *>(buffer));
break;
+ case DataType::BOOL:
+ readTensorData(tensor, static_cast<uint8_t *>(buffer));
+ break;
default:
throw std::runtime_error{"Unsupported data type for input data (.h5)"};
}