diff options
Diffstat (limited to 'compiler/record-minmax/src/HDF5Importer.cpp')
-rw-r--r-- | compiler/record-minmax/src/HDF5Importer.cpp | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/compiler/record-minmax/src/HDF5Importer.cpp b/compiler/record-minmax/src/HDF5Importer.cpp index a0e65eeb7..cfb270ce0 100644 --- a/compiler/record-minmax/src/HDF5Importer.cpp +++ b/compiler/record-minmax/src/HDF5Importer.cpp @@ -59,7 +59,30 @@ DataType toInternalDtype(const H5::DataType &h5_type) { return DataType::S64; } - // Only support three datatypes for now + if (h5_type.getClass() == H5T_class_t::H5T_ENUM) + { + // We follow the numpy format + // In numpy 1.19.0, np.bool_ is saved as H5T_ENUM + // - (name, value) -> (FALSE, 0) and (TRUE, 1) + // - value dtype is H5T_STD_I8LE + // TODO Find a general way to recognize BOOL type + char name[10]; + int8_t value[2] = {0, 1}; + if (H5Tenum_nameof(h5_type.getId(), value, name, 10) < 0) + return DataType::Unknown; + + if (std::string(name) != "FALSE") + return DataType::Unknown; + + if (H5Tenum_nameof(h5_type.getId(), value + 1, name, 10) < 0) + return DataType::Unknown; + + if (std::string(name) != "TRUE") + return DataType::Unknown; + + return DataType::BOOL; + } + // TODO Support more datatypes return DataType::Unknown; } @@ -125,6 +148,9 @@ void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, DataType *d case DataType::S64: readTensorData(tensor, static_cast<int64_t *>(buffer)); break; + case DataType::BOOL: + readTensorData(tensor, static_cast<uint8_t *>(buffer)); + break; default: throw std::runtime_error{"Unsupported data type for input data (.h5)"}; } |