summaryrefslogtreecommitdiff
path: root/tests/AssetsLibrary.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/AssetsLibrary.cpp')
-rw-r--r--tests/AssetsLibrary.cpp77
1 files changed, 59 insertions, 18 deletions
diff --git a/tests/AssetsLibrary.cpp b/tests/AssetsLibrary.cpp
index 5660f3c06..111ae7709 100644
--- a/tests/AssetsLibrary.cpp
+++ b/tests/AssetsLibrary.cpp
@@ -120,15 +120,15 @@ void discard_comments_and_spaces(std::ifstream &fs)
}
}
-std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs)
+std::tuple<unsigned int, unsigned int, int> parse_netpbm_format_header(std::ifstream &fs, char number)
{
- // Check the PPM magic number is valid
+ // check file type magic number is valid
std::array<char, 2> magic_number{ { 0 } };
fs >> magic_number[0] >> magic_number[1];
- if(magic_number[0] != 'P' || magic_number[1] != '6')
+ if(magic_number[0] != 'P' || magic_number[1] != number)
{
- throw std::runtime_error("Only raw PPM format is suported");
+ throw std::runtime_error("File type magic number not supported");
}
discard_comments_and_spaces(fs);
@@ -160,7 +160,7 @@ std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs)
if(isspace(fs.peek()) == 0)
{
- throw std::runtime_error("Invalid PPM header");
+ throw std::runtime_error("Invalid image header");
}
fs.ignore(1);
@@ -168,6 +168,39 @@ std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs)
return std::make_tuple(width, height, max_value);
}
+std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs)
+{
+ return parse_netpbm_format_header(fs, '6');
+}
+
+std::tuple<unsigned int, unsigned int, int> parse_pgm_header(std::ifstream &fs)
+{
+ return parse_netpbm_format_header(fs, '5');
+}
+
+void check_image_size(std::ifstream &fs, size_t raw_size)
+{
+ const size_t current_position = fs.tellg();
+ fs.seekg(0, std::ios_base::end);
+ const size_t end_position = fs.tellg();
+ fs.seekg(current_position, std::ios_base::beg);
+
+ if((end_position - current_position) < raw_size)
+ {
+ throw std::runtime_error("Not enough data in file");
+ }
+}
+
+void read_image_buffer(std::ifstream &fs, RawTensor &raw)
+{
+ fs.read(reinterpret_cast<std::fstream::char_type *>(raw.data()), raw.size());
+
+ if(!fs.good())
+ {
+ throw std::runtime_error("Failure while reading image buffer");
+ }
+}
+
RawTensor load_ppm(const std::string &path)
{
std::ifstream file(path, std::ios::in | std::ios::binary);
@@ -184,24 +217,31 @@ RawTensor load_ppm(const std::string &path)
RawTensor raw(TensorShape(width, height), Format::RGB888);
- // Check if the file is large enough to fill the image
- const size_t current_position = file.tellg();
- file.seekg(0, std::ios_base::end);
- const size_t end_position = file.tellg();
- file.seekg(current_position, std::ios_base::beg);
+ check_image_size(file, raw.size());
+ read_image_buffer(file, raw);
- if((end_position - current_position) < raw.size())
- {
- throw std::runtime_error("Not enough data in file");
- }
+ return raw;
+}
- file.read(reinterpret_cast<std::fstream::char_type *>(raw.data()), raw.size());
+RawTensor load_pgm(const std::string &path)
+{
+ std::ifstream file(path, std::ios::in | std::ios::binary);
if(!file.good())
{
- throw std::runtime_error("Failure while reading image buffer");
+ throw framework::FileNotFound("Could not load PGM image: " + path);
}
+ unsigned int width = 0;
+ unsigned int height = 0;
+
+ std::tie(width, height, std::ignore) = parse_pgm_header(file);
+
+ RawTensor raw(TensorShape(width, height), Format::U8);
+
+ check_image_size(file, raw.size());
+ read_image_buffer(file, raw);
+
return raw;
}
} // namespace
@@ -243,7 +283,8 @@ const AssetsLibrary::Loader &AssetsLibrary::get_loader(const std::string &extens
{
static std::unordered_map<std::string, Loader> loaders =
{
- { "ppm", load_ppm }
+ { "ppm", load_ppm },
+ { "pgm", load_pgm }
};
const auto it = loaders.find(extension);
@@ -407,7 +448,7 @@ const RawTensor &AssetsLibrary::find_or_create_raw_tensor(const std::string &nam
}
const RawTensor &src = get(name, format);
- RawTensor dst(src.shape(), get_channel_format(channel));
+ RawTensor dst(src.shape(), get_channel_format(channel));
(*get_extractor(format, channel))(src, dst);