summaryrefslogtreecommitdiff
path: root/lib/jxl/dec_modular.cc
diff options
context:
space:
mode:
Diffstat (limited to 'lib/jxl/dec_modular.cc')
-rw-r--r--lib/jxl/dec_modular.cc369
1 files changed, 240 insertions, 129 deletions
diff --git a/lib/jxl/dec_modular.cc b/lib/jxl/dec_modular.cc
index f8b4c0a..bf85eaa 100644
--- a/lib/jxl/dec_modular.cc
+++ b/lib/jxl/dec_modular.cc
@@ -7,6 +7,8 @@
#include <stdint.h>
+#include <atomic>
+#include <sstream>
#include <vector>
#include "lib/jxl/frame_header.h"
@@ -18,6 +20,7 @@
#include "lib/jxl/alpha.h"
#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/printf_macros.h"
#include "lib/jxl/base/span.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/compressed_dc.h"
@@ -31,6 +34,8 @@ namespace jxl {
namespace HWY_NAMESPACE {
// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Add;
+using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::Rebind;
void MultiplySum(const size_t xsize,
@@ -41,49 +46,39 @@ void MultiplySum(const size_t xsize,
const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float
const auto factor_v = Set(df, factor);
for (size_t x = 0; x < xsize; x += Lanes(di)) {
- const auto in = Load(di, row_in + x) + Load(di, row_in_Y + x);
- const auto out = ConvertTo(df, in) * factor_v;
+ const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x));
+ const auto out = Mul(ConvertTo(df, in), factor_v);
Store(out, df, row_out + x);
}
}
void RgbFromSingle(const size_t xsize,
const pixel_type* const JXL_RESTRICT row_in,
- const float factor, Image3F* decoded, size_t /*c*/, size_t y,
- Rect& rect) {
- JXL_DASSERT(xsize <= rect.xsize());
+ const float factor, float* out_r, float* out_g,
+ float* out_b) {
const HWY_FULL(float) df;
const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float
- float* const JXL_RESTRICT row_out_r = rect.PlaneRow(decoded, 0, y);
- float* const JXL_RESTRICT row_out_g = rect.PlaneRow(decoded, 1, y);
- float* const JXL_RESTRICT row_out_b = rect.PlaneRow(decoded, 2, y);
-
const auto factor_v = Set(df, factor);
for (size_t x = 0; x < xsize; x += Lanes(di)) {
const auto in = Load(di, row_in + x);
- const auto out = ConvertTo(df, in) * factor_v;
- Store(out, df, row_out_r + x);
- Store(out, df, row_out_g + x);
- Store(out, df, row_out_b + x);
+ const auto out = Mul(ConvertTo(df, in), factor_v);
+ Store(out, df, out_r + x);
+ Store(out, df, out_g + x);
+ Store(out, df, out_b + x);
}
}
-// Same signature as RgbFromSingle so we can assign to the same pointer.
void SingleFromSingle(const size_t xsize,
const pixel_type* const JXL_RESTRICT row_in,
- const float factor, Image3F* decoded, size_t c, size_t y,
- Rect& rect) {
- JXL_DASSERT(xsize <= rect.xsize());
+ const float factor, float* row_out) {
const HWY_FULL(float) df;
const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float
- float* const JXL_RESTRICT row_out = rect.PlaneRow(decoded, c, y);
-
const auto factor_v = Set(df, factor);
for (size_t x = 0; x < xsize; x += Lanes(di)) {
const auto in = Load(di, row_in + x);
- const auto out = ConvertTo(df, in) * factor_v;
+ const auto out = Mul(ConvertTo(df, in), factor_v);
Store(out, df, row_out + x);
}
}
@@ -98,6 +93,16 @@ HWY_EXPORT(MultiplySum); // Local function
HWY_EXPORT(RgbFromSingle); // Local function
HWY_EXPORT(SingleFromSingle); // Local function
+// Slow conversion using double precision multiplication, only
+// needed when the bit depth is too high for single precision
+void SingleFromSingleAccurate(const size_t xsize,
+ const pixel_type* const JXL_RESTRICT row_in,
+ const double factor, float* row_out) {
+ for (size_t x = 0; x < xsize; x++) {
+ row_out[x] = row_in[x] * factor;
+ }
+}
+
// convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int
// back to binary32 float
void int_to_float(const pixel_type* const JXL_RESTRICT row_in,
@@ -148,6 +153,28 @@ void int_to_float(const pixel_type* const JXL_RESTRICT row_in,
}
}
+std::string ModularStreamId::DebugString() const {
+ std::ostringstream os;
+ os << (kind == kGlobalData ? "ModularGlobal"
+ : kind == kVarDCTDC ? "VarDCTDC"
+ : kind == kModularDC ? "ModularDC"
+ : kind == kACMetadata ? "ACMeta"
+ : kind == kQuantTable ? "QuantTable"
+ : kind == kModularAC ? "ModularAC"
+ : "");
+ if (kind == kVarDCTDC || kind == kModularDC || kind == kACMetadata ||
+ kind == kModularAC) {
+ os << " group " << group_id;
+ }
+ if (kind == kModularAC) {
+ os << " pass " << pass_id;
+ }
+ if (kind == kQuantTable) {
+ os << " " << quant_table_id;
+ }
+ return os.str();
+}
+
Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
const FrameHeader& frame_header,
bool allow_truncated_group) {
@@ -158,17 +185,22 @@ Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
if (is_gray && frame_header.color_transform == ColorTransform::kNone) {
nb_chans = 1;
}
+ do_color = decode_color;
+ size_t nb_extra = metadata.extra_channel_info.size();
bool has_tree = reader->ReadBits(1);
- if (has_tree) {
- size_t tree_size_limit =
- 1024 + frame_dim.xsize * frame_dim.ysize * nb_chans / 16;
- JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit));
- JXL_RETURN_IF_ERROR(
- DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map));
+ if (!allow_truncated_group ||
+ reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) {
+ if (has_tree) {
+ size_t tree_size_limit =
+ std::min(static_cast<size_t>(1 << 22),
+ 1024 + frame_dim.xsize * frame_dim.ysize *
+ (nb_chans + nb_extra) / 16);
+ JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit));
+ JXL_RETURN_IF_ERROR(
+ DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map));
+ }
}
- do_color = decode_color;
if (!do_color) nb_chans = 0;
- size_t nb_extra = metadata.extra_channel_info.size();
bool fp = metadata.bit_depth.floating_point_sample;
@@ -212,13 +244,15 @@ Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
all_same_shift = false;
}
+ JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s",
+ gi.DebugString().c_str());
ModularOptions options;
options.max_chan_size = frame_dim.group_dim;
options.group_dim = frame_dim.group_dim;
Status dec_status = ModularGenericDecompress(
reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
&options,
- /*undo_transforms=*/-2, &tree, &code, &context_map,
+ /*undo_transforms=*/false, &tree, &code, &context_map,
allow_truncated_group);
if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
if (dec_status.IsFatalError()) {
@@ -242,12 +276,15 @@ Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
}
}
full_image = std::move(gi);
+ JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s",
+ full_image.DebugString().c_str());
return dec_status;
}
void ModularFrameDecoder::MaybeDropFullImage() {
if (full_image.transform.empty() && !have_something && all_same_shift) {
use_full_image = false;
+ JXL_DEBUG_V(6, "Dropping full image");
for (auto& ch : full_image.channel) {
// keep metadata on channels around, but dealloc their planes
ch.plane = Plane<pixel_type>();
@@ -255,12 +292,14 @@ void ModularFrameDecoder::MaybeDropFullImage() {
}
}
-Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader,
- int minShift, int maxShift,
- const ModularStreamId& stream,
- bool zerofill,
- PassesDecoderState* dec_state,
- ImageBundle* output) {
+Status ModularFrameDecoder::DecodeGroup(
+ const Rect& rect, BitReader* reader, int minShift, int maxShift,
+ const ModularStreamId& stream, bool zerofill, PassesDecoderState* dec_state,
+ RenderPipelineInput* render_pipeline_input, bool allow_truncated,
+ bool* should_run_pipeline) {
+ JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s",
+ stream.DebugString().c_str(), Description(rect).c_str(), minShift,
+ maxShift, zerofill ? "using zerofill" : "");
JXL_DASSERT(stream.kind == ModularStreamId::kModularDC ||
stream.kind == ModularStreamId::kModularAC);
const size_t xsize = rect.xsize();
@@ -297,22 +336,36 @@ Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader,
if (zerofill && use_full_image) return true;
// Return early if there's nothing to decode. Otherwise there might be
// problems later (in ModularImageToDecodedRect).
- if (gi.channel.empty()) return true;
+ if (gi.channel.empty()) {
+ if (dec_state && should_run_pipeline) {
+ const auto& frame_header = dec_state->shared->frame_header;
+ const auto* metadata = frame_header.nonserialized_metadata;
+ if (do_color || metadata->m.num_extra_channels > 0) {
+ // Signal to FrameDecoder that we do not have some of the required input
+ // for the render pipeline.
+ *should_run_pipeline = false;
+ }
+ }
+ JXL_DEBUG_V(6, "Nothing to decode, returning early.");
+ return true;
+ }
ModularOptions options;
if (!zerofill) {
- if (!ModularGenericDecompress(
- reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
- /*undo_transforms=*/-1, &tree, &code, &context_map)) {
- return JXL_FAILURE("Failed to decode modular group");
- }
+ auto status = ModularGenericDecompress(
+ reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
+ /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated);
+ if (!allow_truncated) JXL_RETURN_IF_ERROR(status);
+ if (status.IsFatalError()) return status;
}
// Undo global transforms that have been pushed to the group level
if (!use_full_image) {
+ JXL_ASSERT(render_pipeline_input);
for (auto t : global_transform) {
JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
}
- JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
- gi, dec_state, nullptr, output, rect.Crop(dec_state->decoded)));
+ JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(gi, dec_state, nullptr,
+ *render_pipeline_input,
+ Rect(0, 0, gi.w, gi.h)));
return true;
}
int gic = 0;
@@ -332,6 +385,7 @@ Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader,
}
return true;
}
+
Status ModularFrameDecoder::DecodeVarDCTDC(size_t group_id, BitReader* reader,
PassesDecoderState* dec_state) {
const Rect r = dec_state->shared->DCGroupRect(group_id);
@@ -355,7 +409,7 @@ Status ModularFrameDecoder::DecodeVarDCTDC(size_t group_id, BitReader* reader,
}
if (!ModularGenericDecompress(
reader, image, /*header=*/nullptr, stream_id, &options,
- /*undo_transforms=*/-1, &tree, &code, &context_map)) {
+ /*undo_transforms=*/true, &tree, &code, &context_map)) {
return JXL_FAILURE("Failed to decode modular DC group");
}
DequantDC(r, &dec_state->shared_storage.dc_storage,
@@ -384,7 +438,7 @@ Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader,
ModularOptions options;
if (!ModularGenericDecompress(
reader, image, /*header=*/nullptr, stream_id, &options,
- /*undo_transforms=*/-1, &tree, &code, &context_map)) {
+ /*undo_transforms=*/true, &tree, &code, &context_map)) {
return JXL_FAILURE("Failed to decode AC metadata");
}
ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr,
@@ -399,11 +453,11 @@ Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader,
uint32_t local_used_acs = 0;
for (size_t iy = 0; iy < r.ysize(); iy++) {
size_t y = r.y0() + iy;
- int* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
+ int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy);
- int* row_in_1 = image.channel[2].plane.Row(0);
- int* row_in_2 = image.channel[2].plane.Row(1);
- int* row_in_3 = image.channel[3].plane.Row(iy);
+ int32_t* row_in_1 = image.channel[2].plane.Row(0);
+ int32_t* row_in_2 = image.channel[2].plane.Row(1);
+ int32_t* row_in_3 = image.channel[3].plane.Row(iy);
for (size_t ix = 0; ix < r.xsize(); ix++) {
size_t x = r.x0() + ix;
int sharpness = row_in_3[ix];
@@ -440,8 +494,8 @@ Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader,
}
JXL_RETURN_IF_ERROR(
ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num])));
- row_qf[ix] =
- 1 + std::max(0, std::min(Quantizer::kQuantMax - 1, row_in_2[num]));
+ row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1,
+ row_in_2[num]));
num++;
}
}
@@ -454,16 +508,15 @@ Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader,
Status ModularFrameDecoder::ModularImageToDecodedRect(
Image& gi, PassesDecoderState* dec_state, jxl::ThreadPool* pool,
- ImageBundle* output, Rect rect) {
- auto& decoded = dec_state->decoded;
+ RenderPipelineInput& render_pipeline_input, Rect modular_rect) {
const auto& frame_header = dec_state->shared->frame_header;
const auto* metadata = frame_header.nonserialized_metadata;
- size_t xsize = rect.xsize();
- size_t ysize = rect.ysize();
- if (!xsize || !ysize) {
- return true;
- }
- JXL_DASSERT(rect.IsInside(decoded));
+ JXL_CHECK(gi.transform.empty());
+
+ auto get_row = [&](size_t c, size_t y) {
+ const auto& buffer = render_pipeline_input.GetBuffer(c);
+ return buffer.second.Row(buffer.first, y);
+ };
size_t c = 0;
if (do_color) {
@@ -473,9 +526,9 @@ Status ModularFrameDecoder::ModularImageToDecodedRect(
const bool fp = metadata->m.bit_depth.floating_point_sample &&
frame_header.color_transform != ColorTransform::kXYB;
for (; c < 3; c++) {
- float factor = full_image.bitdepth < 32
- ? 1.f / ((1u << full_image.bitdepth) - 1)
- : 0;
+ double factor = full_image.bitdepth < 32
+ ? 1.0 / ((1u << full_image.bitdepth) - 1)
+ : 0;
size_t c_in = c;
if (frame_header.color_transform == ColorTransform::kXYB) {
factor = dec_state->shared->matrices.DCQuants()[c];
@@ -490,59 +543,89 @@ Status ModularFrameDecoder::ModularImageToDecodedRect(
if (ch_in.w == 0 || ch_in.h == 0) {
return JXL_FAILURE("Empty image");
}
- size_t xsize_shifted = DivCeil(xsize, 1 << ch_in.hshift);
- size_t ysize_shifted = DivCeil(ysize, 1 << ch_in.vshift);
- Rect r(rect.x0() >> ch_in.hshift, rect.y0() >> ch_in.vshift,
- rect.xsize() >> ch_in.hshift, rect.ysize() >> ch_in.vshift,
- DivCeil(decoded.xsize(), 1 << ch_in.hshift),
- DivCeil(decoded.ysize(), 1 << ch_in.vshift));
- if (r.ysize() != ch_in.h || r.xsize() != ch_in.w) {
- return JXL_FAILURE(
- "Dimension mismatch: trying to fit a %zux%zu modular channel into "
- "a %zux%zu rect",
- ch_in.w, ch_in.h, r.xsize(), r.ysize());
+ JXL_CHECK(ch_in.hshift <= 3 && ch_in.vshift <= 3);
+ Rect r = render_pipeline_input.GetBuffer(c).second;
+ Rect mr(modular_rect.x0() >> ch_in.hshift,
+ modular_rect.y0() >> ch_in.vshift,
+ DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
+ DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
+ mr = mr.Crop(ch_in.plane);
+ size_t xsize_shifted = r.xsize();
+ size_t ysize_shifted = r.ysize();
+ if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
+ return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
+ "x%" PRIuS
+ " modular channel into "
+ "a %" PRIuS "x%" PRIuS " rect",
+ mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
}
if (frame_header.color_transform == ColorTransform::kXYB && c == 2) {
JXL_ASSERT(!fp);
- RunOnPool(
- pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
- [&](const int task, const int thread) {
+ JXL_RETURN_IF_ERROR(RunOnPool(
+ pool, 0, ysize_shifted, ThreadPool::NoInit,
+ [&](const uint32_t task, size_t /* thread */) {
const size_t y = task;
- const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y);
+ const pixel_type* const JXL_RESTRICT row_in =
+ mr.Row(&ch_in.plane, y);
const pixel_type* const JXL_RESTRICT row_in_Y =
- gi.channel[0].Row(y);
- float* const JXL_RESTRICT row_out = r.PlaneRow(&decoded, c, y);
+ mr.Row(&gi.channel[0].plane, y);
+ float* const JXL_RESTRICT row_out = get_row(c, y);
HWY_DYNAMIC_DISPATCH(MultiplySum)
(xsize_shifted, row_in, row_in_Y, factor, row_out);
},
- "ModularIntToFloat");
+ "ModularIntToFloat"));
} else if (fp) {
int bits = metadata->m.bit_depth.bits_per_sample;
int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample;
- RunOnPool(
- pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
- [&](const int task, const int thread) {
+ JXL_RETURN_IF_ERROR(RunOnPool(
+ pool, 0, ysize_shifted, ThreadPool::NoInit,
+ [&](const uint32_t task, size_t /* thread */) {
const size_t y = task;
- const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y);
- float* const JXL_RESTRICT row_out = r.PlaneRow(&decoded, c, y);
- int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
+ const pixel_type* const JXL_RESTRICT row_in =
+ mr.Row(&ch_in.plane, y);
+ if (rgb_from_gray) {
+ for (size_t cc = 0; cc < 3; cc++) {
+ float* const JXL_RESTRICT row_out = get_row(cc, y);
+ int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
+ }
+ } else {
+ float* const JXL_RESTRICT row_out = get_row(c, y);
+ int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
+ }
},
- "ModularIntToFloat_losslessfloat");
+ "ModularIntToFloat_losslessfloat"));
} else {
- RunOnPool(
- pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
- [&](const int task, const int thread) {
+ JXL_RETURN_IF_ERROR(RunOnPool(
+ pool, 0, ysize_shifted, ThreadPool::NoInit,
+ [&](const uint32_t task, size_t /* thread */) {
const size_t y = task;
- const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y);
+ const pixel_type* const JXL_RESTRICT row_in =
+ mr.Row(&ch_in.plane, y);
if (rgb_from_gray) {
- HWY_DYNAMIC_DISPATCH(RgbFromSingle)
- (xsize_shifted, row_in, factor, &decoded, c, y, r);
+ if (full_image.bitdepth < 23) {
+ HWY_DYNAMIC_DISPATCH(RgbFromSingle)
+ (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y),
+ get_row(2, y));
+ } else {
+ SingleFromSingleAccurate(xsize_shifted, row_in, factor,
+ get_row(0, y));
+ SingleFromSingleAccurate(xsize_shifted, row_in, factor,
+ get_row(1, y));
+ SingleFromSingleAccurate(xsize_shifted, row_in, factor,
+ get_row(2, y));
+ }
} else {
- HWY_DYNAMIC_DISPATCH(SingleFromSingle)
- (xsize_shifted, row_in, factor, &decoded, c, y, r);
+ float* const JXL_RESTRICT row_out = get_row(c, y);
+ if (full_image.bitdepth < 23) {
+ HWY_DYNAMIC_DISPATCH(SingleFromSingle)
+ (xsize_shifted, row_in, factor, row_out);
+ } else {
+ SingleFromSingleAccurate(xsize_shifted, row_in, factor,
+ row_out);
+ }
}
},
- "ModularIntToFloat");
+ "ModularIntToFloat"));
}
if (rgb_from_gray) {
break;
@@ -552,66 +635,94 @@ Status ModularFrameDecoder::ModularImageToDecodedRect(
c = 1;
}
}
- for (size_t ec = 0; ec < dec_state->extra_channels.size(); ec++, c++) {
- const ExtraChannelInfo& eci = output->metadata()->extra_channel_info[ec];
+ size_t num_extra_channels = metadata->m.num_extra_channels;
+ for (size_t ec = 0; ec < num_extra_channels; ec++, c++) {
+ const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec];
int bits = eci.bit_depth.bits_per_sample;
int exp_bits = eci.bit_depth.exponent_bits_per_sample;
bool fp = eci.bit_depth.floating_point_sample;
JXL_ASSERT(fp || bits < 32);
- const float mul = fp ? 0 : (1.0f / ((1u << bits) - 1));
- size_t ecups = frame_header.extra_channel_upsampling[ec];
- const size_t ec_xsize = DivCeil(frame_dim.xsize_upsampled, ecups);
- const size_t ec_ysize = DivCeil(frame_dim.ysize_upsampled, ecups);
+ const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1));
JXL_ASSERT(c < gi.channel.size());
Channel& ch_in = gi.channel[c];
- // For x0, y0 there's no need to do a DivCeil().
- JXL_DASSERT(rect.x0() % (1ul << ch_in.hshift) == 0);
- JXL_DASSERT(rect.y0() % (1ul << ch_in.vshift) == 0);
- Rect r(rect.x0() >> ch_in.hshift, rect.y0() >> ch_in.vshift,
- DivCeil(rect.xsize(), 1lu << ch_in.hshift),
- DivCeil(rect.ysize(), 1lu << ch_in.vshift), ec_xsize, ec_ysize);
-
- JXL_DASSERT(r.IsInside(dec_state->extra_channels[ec]));
- JXL_DASSERT(Rect(0, 0, r.xsize(), r.ysize()).IsInside(ch_in.plane));
+ Rect r = render_pipeline_input.GetBuffer(3 + ec).second;
+ Rect mr(modular_rect.x0() >> ch_in.hshift,
+ modular_rect.y0() >> ch_in.vshift,
+ DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
+ DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
+ mr = mr.Crop(ch_in.plane);
+ if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
+ return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
+ "x%" PRIuS
+ " modular channel into "
+ "a %" PRIuS "x%" PRIuS " rect",
+ mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
+ }
for (size_t y = 0; y < r.ysize(); ++y) {
float* const JXL_RESTRICT row_out =
- r.Row(&dec_state->extra_channels[ec], y);
- const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y);
+ r.Row(render_pipeline_input.GetBuffer(3 + ec).first, y);
+ const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
if (fp) {
int_to_float(row_in, row_out, r.xsize(), bits, exp_bits);
} else {
- for (size_t x = 0; x < r.xsize(); ++x) {
- row_out[x] = row_in[x] * mul;
+ if (full_image.bitdepth < 23) {
+ HWY_DYNAMIC_DISPATCH(SingleFromSingle)
+ (r.xsize(), row_in, factor, row_out);
+ } else {
+ SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out);
}
}
}
- JXL_CHECK_IMAGE_INITIALIZED(dec_state->extra_channels[ec], r);
}
return true;
}
Status ModularFrameDecoder::FinalizeDecoding(PassesDecoderState* dec_state,
jxl::ThreadPool* pool,
- ImageBundle* output) {
+ bool inplace) {
if (!use_full_image) return true;
- Image& gi = full_image;
+ Image gi = (inplace ? std::move(full_image) : full_image.clone());
size_t xsize = gi.w;
size_t ysize = gi.h;
+ JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s",
+ gi.DebugString().c_str());
+
// Don't use threads if total image size is smaller than a group
if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr;
// Undo the global transforms
- gi.undo_transforms(global_header.wp_header, -1, pool);
- for (auto t : global_transform) {
- JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
- }
+ gi.undo_transforms(global_header.wp_header, pool);
+ JXL_DASSERT(global_transform.empty());
if (gi.error) return JXL_FAILURE("Undoing transforms failed");
- auto& decoded = dec_state->decoded;
-
- JXL_RETURN_IF_ERROR(
- ModularImageToDecodedRect(gi, dec_state, pool, output, Rect(decoded)));
+ for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) {
+ dec_state->render_pipeline->ClearDone(i);
+ }
+ std::atomic<bool> has_error{false};
+ JXL_RETURN_IF_ERROR(RunOnPool(
+ pool, 0, dec_state->shared->frame_dim.num_groups,
+ [&](size_t num_threads) {
+ const auto& frame_header = dec_state->shared->frame_header;
+ bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT ||
+ (frame_header.flags & FrameHeader::kNoise));
+ return dec_state->render_pipeline->PrepareForThreads(num_threads,
+ use_group_ids);
+ },
+ [&](const uint32_t group, size_t thread_id) {
+ RenderPipelineInput input =
+ dec_state->render_pipeline->GetInputBuffers(group, thread_id);
+ if (!ModularImageToDecodedRect(gi, dec_state, nullptr, input,
+ dec_state->shared->GroupRect(group))) {
+ has_error = true;
+ return;
+ }
+ input.Done();
+ },
+ "ModularToRect"));
+ if (has_error) {
+ return JXL_FAILURE("Error producing input to render pipeline");
+ }
return true;
}
@@ -633,12 +744,12 @@ Status ModularFrameDecoder::DecodeQuantTable(
JXL_RETURN_IF_ERROR(ModularGenericDecompress(
br, image, /*header=*/nullptr,
ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim),
- &options, /*undo_transforms=*/-1, &modular_frame_decoder->tree,
+ &options, /*undo_transforms=*/true, &modular_frame_decoder->tree,
&modular_frame_decoder->code, &modular_frame_decoder->context_map));
} else {
JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr,
0, &options,
- /*undo_transforms=*/-1));
+ /*undo_transforms=*/true));
}
if (!encoding->qraw.qtable) {
encoding->qraw.qtable = new std::vector<int>();
@@ -646,7 +757,7 @@ Status ModularFrameDecoder::DecodeQuantTable(
encoding->qraw.qtable->resize(required_size_x * required_size_y * 3);
for (size_t c = 0; c < 3; c++) {
for (size_t y = 0; y < required_size_y; y++) {
- int* JXL_RESTRICT row = image.channel[c].Row(y);
+ int32_t* JXL_RESTRICT row = image.channel[c].Row(y);
for (size_t x = 0; x < required_size_x; x++) {
(*encoding->qraw.qtable)[c * required_size_x * required_size_y +
y * required_size_x + x] = row[x];