diff options
Diffstat (limited to 'libkmod/libkmod-file.c')
-rw-r--r-- | libkmod/libkmod-file.c | 144 |
1 files changed, 144 insertions, 0 deletions
diff --git a/libkmod/libkmod-file.c b/libkmod/libkmod-file.c index 5eeba6a..b6a8cc9 100644 --- a/libkmod/libkmod-file.c +++ b/libkmod/libkmod-file.c @@ -26,6 +26,9 @@ #include <sys/stat.h> #include <sys/types.h> #include <unistd.h> +#ifdef ENABLE_ZSTD +#include <zstd.h> +#endif #ifdef ENABLE_XZ #include <lzma.h> #endif @@ -45,6 +48,9 @@ struct file_ops { }; struct kmod_file { +#ifdef ENABLE_ZSTD + bool zstd_used; +#endif #ifdef ENABLE_XZ bool xz_used; #endif @@ -60,6 +66,141 @@ struct kmod_file { struct kmod_elf *elf; }; +#ifdef ENABLE_ZSTD +static int zstd_read_block(struct kmod_file *file, size_t block_size, + ZSTD_inBuffer *input, size_t *input_capacity) +{ + ssize_t rdret; + int ret; + + if (*input_capacity < block_size) { + free((void *)input->src); + input->src = malloc(block_size); + if (input->src == NULL) { + ret = -errno; + ERR(file->ctx, "zstd: %m\n"); + return ret; + } + *input_capacity = block_size; + } + + rdret = read(file->fd, (void *)input->src, block_size); + if (rdret < 0) { + ret = -errno; + ERR(file->ctx, "zstd: %m\n"); + return ret; + } + + input->pos = 0; + input->size = rdret; + return 0; +} + +static int zstd_ensure_outbuffer_space(ZSTD_outBuffer *buffer, size_t min_free) +{ + uint8_t *old_buffer = buffer->dst; + int ret = 0; + + if (buffer->size - buffer->pos >= min_free) + return 0; + + buffer->size += min_free; + buffer->dst = realloc(buffer->dst, buffer->size); + if (buffer->dst == NULL) { + ret = -errno; + free(old_buffer); + } + + return ret; +} + +static int zstd_decompress_block(struct kmod_file *file, ZSTD_DStream *dstr, + ZSTD_inBuffer *input, ZSTD_outBuffer *output, + size_t *next_block_size) +{ + size_t out_buf_min_size = ZSTD_DStreamOutSize(); + int ret = 0; + + do { + ssize_t dsret; + + ret = zstd_ensure_outbuffer_space(output, out_buf_min_size); + if (ret) { + ERR(file->ctx, "zstd: %s\n", strerror(-ret)); + break; + } + + dsret = ZSTD_decompressStream(dstr, output, input); + if (ZSTD_isError(dsret)) { + ret = -EINVAL; + ERR(file->ctx, "zstd: %s\n", ZSTD_getErrorName(dsret)); + break; + } + if (dsret > 0) + *next_block_size = (size_t)dsret; + } while (input->pos < input->size + || output->pos > output->size + || output->size - output->pos < out_buf_min_size); + + return ret; +} + +static int load_zstd(struct kmod_file *file) +{ + ZSTD_DStream *dstr; + size_t next_block_size; + size_t zst_inb_capacity = 0; + ZSTD_inBuffer zst_inb = { 0 }; + ZSTD_outBuffer zst_outb = { 0 }; + int ret; + + dstr = ZSTD_createDStream(); + if (dstr == NULL) { + ret = -EINVAL; + ERR(file->ctx, "zstd: Failed to create decompression stream\n"); + goto out; + } + + next_block_size = ZSTD_initDStream(dstr); + + while (true) { + ret = zstd_read_block(file, next_block_size, &zst_inb, + &zst_inb_capacity); + if (ret != 0) + goto out; + if (zst_inb.size == 0) /* EOF */ + break; + + ret = zstd_decompress_block(file, dstr, &zst_inb, &zst_outb, + &next_block_size); + if (ret != 0) + goto out; + } + + ZSTD_freeDStream(dstr); + free((void *)zst_inb.src); + file->zstd_used = true; + file->memory = zst_outb.dst; + file->size = zst_outb.pos; + return 0; +out: + if (dstr != NULL) + ZSTD_freeDStream(dstr); + free((void *)zst_inb.src); + free((void *)zst_outb.dst); + return ret; +} + +static void unload_zstd(struct kmod_file *file) +{ + if (!file->zstd_used) + return; + free(file->memory); +} + +static const char magic_zstd[] = {0x28, 0xB5, 0x2F, 0xFD}; +#endif + #ifdef ENABLE_XZ static void xz_uncompress_belch(struct kmod_file *file, lzma_ret ret) { @@ -238,6 +379,9 @@ static const struct comp_type { const char *magic_bytes; const struct file_ops ops; } comp_types[] = { +#ifdef ENABLE_ZSTD + {sizeof(magic_zstd), magic_zstd, {load_zstd, unload_zstd}}, +#endif #ifdef ENABLE_XZ {sizeof(magic_xz), magic_xz, {load_xz, unload_xz}}, #endif |