diff options
Diffstat (limited to 'tests/AssetsLibrary.cpp')
-rw-r--r-- | tests/AssetsLibrary.cpp | 77 |
1 files changed, 59 insertions, 18 deletions
diff --git a/tests/AssetsLibrary.cpp b/tests/AssetsLibrary.cpp index 5660f3c06..111ae7709 100644 --- a/tests/AssetsLibrary.cpp +++ b/tests/AssetsLibrary.cpp @@ -120,15 +120,15 @@ void discard_comments_and_spaces(std::ifstream &fs) } } -std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs) +std::tuple<unsigned int, unsigned int, int> parse_netpbm_format_header(std::ifstream &fs, char number) { - // Check the PPM magic number is valid + // check file type magic number is valid std::array<char, 2> magic_number{ { 0 } }; fs >> magic_number[0] >> magic_number[1]; - if(magic_number[0] != 'P' || magic_number[1] != '6') + if(magic_number[0] != 'P' || magic_number[1] != number) { - throw std::runtime_error("Only raw PPM format is suported"); + throw std::runtime_error("File type magic number not supported"); } discard_comments_and_spaces(fs); @@ -160,7 +160,7 @@ std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs) if(isspace(fs.peek()) == 0) { - throw std::runtime_error("Invalid PPM header"); + throw std::runtime_error("Invalid image header"); } fs.ignore(1); @@ -168,6 +168,39 @@ std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs) return std::make_tuple(width, height, max_value); } +std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs) +{ + return parse_netpbm_format_header(fs, '6'); +} + +std::tuple<unsigned int, unsigned int, int> parse_pgm_header(std::ifstream &fs) +{ + return parse_netpbm_format_header(fs, '5'); +} + +void check_image_size(std::ifstream &fs, size_t raw_size) +{ + const size_t current_position = fs.tellg(); + fs.seekg(0, std::ios_base::end); + const size_t end_position = fs.tellg(); + fs.seekg(current_position, std::ios_base::beg); + + if((end_position - current_position) < raw_size) + { + throw std::runtime_error("Not enough data in file"); + } +} + +void read_image_buffer(std::ifstream &fs, RawTensor &raw) +{ + fs.read(reinterpret_cast<std::fstream::char_type *>(raw.data()), raw.size()); + + if(!fs.good()) + { + throw std::runtime_error("Failure while reading image buffer"); + } +} + RawTensor load_ppm(const std::string &path) { std::ifstream file(path, std::ios::in | std::ios::binary); @@ -184,24 +217,31 @@ RawTensor load_ppm(const std::string &path) RawTensor raw(TensorShape(width, height), Format::RGB888); - // Check if the file is large enough to fill the image - const size_t current_position = file.tellg(); - file.seekg(0, std::ios_base::end); - const size_t end_position = file.tellg(); - file.seekg(current_position, std::ios_base::beg); + check_image_size(file, raw.size()); + read_image_buffer(file, raw); - if((end_position - current_position) < raw.size()) - { - throw std::runtime_error("Not enough data in file"); - } + return raw; +} - file.read(reinterpret_cast<std::fstream::char_type *>(raw.data()), raw.size()); +RawTensor load_pgm(const std::string &path) +{ + std::ifstream file(path, std::ios::in | std::ios::binary); if(!file.good()) { - throw std::runtime_error("Failure while reading image buffer"); + throw framework::FileNotFound("Could not load PGM image: " + path); } + unsigned int width = 0; + unsigned int height = 0; + + std::tie(width, height, std::ignore) = parse_pgm_header(file); + + RawTensor raw(TensorShape(width, height), Format::U8); + + check_image_size(file, raw.size()); + read_image_buffer(file, raw); + return raw; } } // namespace @@ -243,7 +283,8 @@ const AssetsLibrary::Loader &AssetsLibrary::get_loader(const std::string &extens { static std::unordered_map<std::string, Loader> loaders = { - { "ppm", load_ppm } + { "ppm", load_ppm }, + { "pgm", load_pgm } }; const auto it = loaders.find(extension); @@ -407,7 +448,7 @@ const RawTensor &AssetsLibrary::find_or_create_raw_tensor(const std::string &nam } const RawTensor &src = get(name, format); - RawTensor dst(src.shape(), get_channel_format(channel)); + RawTensor dst(src.shape(), get_channel_format(channel)); (*get_extractor(format, channel))(src, dst); |