diff options
Diffstat (limited to 'lib/jxl/dec_modular.cc')
-rw-r--r-- | lib/jxl/dec_modular.cc | 369 |
1 files changed, 240 insertions, 129 deletions
diff --git a/lib/jxl/dec_modular.cc b/lib/jxl/dec_modular.cc index f8b4c0a..bf85eaa 100644 --- a/lib/jxl/dec_modular.cc +++ b/lib/jxl/dec_modular.cc @@ -7,6 +7,8 @@ #include <stdint.h> +#include <atomic> +#include <sstream> #include <vector> #include "lib/jxl/frame_header.h" @@ -18,6 +20,7 @@ #include "lib/jxl/alpha.h" #include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" #include "lib/jxl/base/span.h" #include "lib/jxl/base/status.h" #include "lib/jxl/compressed_dc.h" @@ -31,6 +34,8 @@ namespace jxl { namespace HWY_NAMESPACE { // These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; using hwy::HWY_NAMESPACE::Rebind; void MultiplySum(const size_t xsize, @@ -41,49 +46,39 @@ void MultiplySum(const size_t xsize, const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float const auto factor_v = Set(df, factor); for (size_t x = 0; x < xsize; x += Lanes(di)) { - const auto in = Load(di, row_in + x) + Load(di, row_in_Y + x); - const auto out = ConvertTo(df, in) * factor_v; + const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x)); + const auto out = Mul(ConvertTo(df, in), factor_v); Store(out, df, row_out + x); } } void RgbFromSingle(const size_t xsize, const pixel_type* const JXL_RESTRICT row_in, - const float factor, Image3F* decoded, size_t /*c*/, size_t y, - Rect& rect) { - JXL_DASSERT(xsize <= rect.xsize()); + const float factor, float* out_r, float* out_g, + float* out_b) { const HWY_FULL(float) df; const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float - float* const JXL_RESTRICT row_out_r = rect.PlaneRow(decoded, 0, y); - float* const JXL_RESTRICT row_out_g = rect.PlaneRow(decoded, 1, y); - float* const JXL_RESTRICT row_out_b = rect.PlaneRow(decoded, 2, y); - const auto factor_v = Set(df, factor); for (size_t x = 0; x < xsize; x += Lanes(di)) { const auto in = Load(di, row_in + x); - const auto out = ConvertTo(df, in) * factor_v; - Store(out, df, row_out_r + x); - Store(out, df, row_out_g + x); - Store(out, df, row_out_b + x); + const auto out = Mul(ConvertTo(df, in), factor_v); + Store(out, df, out_r + x); + Store(out, df, out_g + x); + Store(out, df, out_b + x); } } -// Same signature as RgbFromSingle so we can assign to the same pointer. void SingleFromSingle(const size_t xsize, const pixel_type* const JXL_RESTRICT row_in, - const float factor, Image3F* decoded, size_t c, size_t y, - Rect& rect) { - JXL_DASSERT(xsize <= rect.xsize()); + const float factor, float* row_out) { const HWY_FULL(float) df; const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float - float* const JXL_RESTRICT row_out = rect.PlaneRow(decoded, c, y); - const auto factor_v = Set(df, factor); for (size_t x = 0; x < xsize; x += Lanes(di)) { const auto in = Load(di, row_in + x); - const auto out = ConvertTo(df, in) * factor_v; + const auto out = Mul(ConvertTo(df, in), factor_v); Store(out, df, row_out + x); } } @@ -98,6 +93,16 @@ HWY_EXPORT(MultiplySum); // Local function HWY_EXPORT(RgbFromSingle); // Local function HWY_EXPORT(SingleFromSingle); // Local function +// Slow conversion using double precision multiplication, only +// needed when the bit depth is too high for single precision +void SingleFromSingleAccurate(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const double factor, float* row_out) { + for (size_t x = 0; x < xsize; x++) { + row_out[x] = row_in[x] * factor; + } +} + // convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int // back to binary32 float void int_to_float(const pixel_type* const JXL_RESTRICT row_in, @@ -148,6 +153,28 @@ void int_to_float(const pixel_type* const JXL_RESTRICT row_in, } } +std::string ModularStreamId::DebugString() const { + std::ostringstream os; + os << (kind == kGlobalData ? "ModularGlobal" + : kind == kVarDCTDC ? "VarDCTDC" + : kind == kModularDC ? "ModularDC" + : kind == kACMetadata ? "ACMeta" + : kind == kQuantTable ? "QuantTable" + : kind == kModularAC ? "ModularAC" + : ""); + if (kind == kVarDCTDC || kind == kModularDC || kind == kACMetadata || + kind == kModularAC) { + os << " group " << group_id; + } + if (kind == kModularAC) { + os << " pass " << pass_id; + } + if (kind == kQuantTable) { + os << " " << quant_table_id; + } + return os.str(); +} + Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader, const FrameHeader& frame_header, bool allow_truncated_group) { @@ -158,17 +185,22 @@ Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader, if (is_gray && frame_header.color_transform == ColorTransform::kNone) { nb_chans = 1; } + do_color = decode_color; + size_t nb_extra = metadata.extra_channel_info.size(); bool has_tree = reader->ReadBits(1); - if (has_tree) { - size_t tree_size_limit = - 1024 + frame_dim.xsize * frame_dim.ysize * nb_chans / 16; - JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit)); - JXL_RETURN_IF_ERROR( - DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map)); + if (!allow_truncated_group || + reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) { + if (has_tree) { + size_t tree_size_limit = + std::min(static_cast<size_t>(1 << 22), + 1024 + frame_dim.xsize * frame_dim.ysize * + (nb_chans + nb_extra) / 16); + JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit)); + JXL_RETURN_IF_ERROR( + DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map)); + } } - do_color = decode_color; if (!do_color) nb_chans = 0; - size_t nb_extra = metadata.extra_channel_info.size(); bool fp = metadata.bit_depth.floating_point_sample; @@ -212,13 +244,15 @@ Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader, all_same_shift = false; } + JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s", + gi.DebugString().c_str()); ModularOptions options; options.max_chan_size = frame_dim.group_dim; options.group_dim = frame_dim.group_dim; Status dec_status = ModularGenericDecompress( reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim), &options, - /*undo_transforms=*/-2, &tree, &code, &context_map, + /*undo_transforms=*/false, &tree, &code, &context_map, allow_truncated_group); if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status); if (dec_status.IsFatalError()) { @@ -242,12 +276,15 @@ Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader, } } full_image = std::move(gi); + JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s", + full_image.DebugString().c_str()); return dec_status; } void ModularFrameDecoder::MaybeDropFullImage() { if (full_image.transform.empty() && !have_something && all_same_shift) { use_full_image = false; + JXL_DEBUG_V(6, "Dropping full image"); for (auto& ch : full_image.channel) { // keep metadata on channels around, but dealloc their planes ch.plane = Plane<pixel_type>(); @@ -255,12 +292,14 @@ void ModularFrameDecoder::MaybeDropFullImage() { } } -Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader, - int minShift, int maxShift, - const ModularStreamId& stream, - bool zerofill, - PassesDecoderState* dec_state, - ImageBundle* output) { +Status ModularFrameDecoder::DecodeGroup( + const Rect& rect, BitReader* reader, int minShift, int maxShift, + const ModularStreamId& stream, bool zerofill, PassesDecoderState* dec_state, + RenderPipelineInput* render_pipeline_input, bool allow_truncated, + bool* should_run_pipeline) { + JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s", + stream.DebugString().c_str(), Description(rect).c_str(), minShift, + maxShift, zerofill ? "using zerofill" : ""); JXL_DASSERT(stream.kind == ModularStreamId::kModularDC || stream.kind == ModularStreamId::kModularAC); const size_t xsize = rect.xsize(); @@ -297,22 +336,36 @@ Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader, if (zerofill && use_full_image) return true; // Return early if there's nothing to decode. Otherwise there might be // problems later (in ModularImageToDecodedRect). - if (gi.channel.empty()) return true; + if (gi.channel.empty()) { + if (dec_state && should_run_pipeline) { + const auto& frame_header = dec_state->shared->frame_header; + const auto* metadata = frame_header.nonserialized_metadata; + if (do_color || metadata->m.num_extra_channels > 0) { + // Signal to FrameDecoder that we do not have some of the required input + // for the render pipeline. + *should_run_pipeline = false; + } + } + JXL_DEBUG_V(6, "Nothing to decode, returning early."); + return true; + } ModularOptions options; if (!zerofill) { - if (!ModularGenericDecompress( - reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options, - /*undo_transforms=*/-1, &tree, &code, &context_map)) { - return JXL_FAILURE("Failed to decode modular group"); - } + auto status = ModularGenericDecompress( + reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options, + /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated); + if (!allow_truncated) JXL_RETURN_IF_ERROR(status); + if (status.IsFatalError()) return status; } // Undo global transforms that have been pushed to the group level if (!use_full_image) { + JXL_ASSERT(render_pipeline_input); for (auto t : global_transform) { JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header)); } - JXL_RETURN_IF_ERROR(ModularImageToDecodedRect( - gi, dec_state, nullptr, output, rect.Crop(dec_state->decoded))); + JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(gi, dec_state, nullptr, + *render_pipeline_input, + Rect(0, 0, gi.w, gi.h))); return true; } int gic = 0; @@ -332,6 +385,7 @@ Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader, } return true; } + Status ModularFrameDecoder::DecodeVarDCTDC(size_t group_id, BitReader* reader, PassesDecoderState* dec_state) { const Rect r = dec_state->shared->DCGroupRect(group_id); @@ -355,7 +409,7 @@ Status ModularFrameDecoder::DecodeVarDCTDC(size_t group_id, BitReader* reader, } if (!ModularGenericDecompress( reader, image, /*header=*/nullptr, stream_id, &options, - /*undo_transforms=*/-1, &tree, &code, &context_map)) { + /*undo_transforms=*/true, &tree, &code, &context_map)) { return JXL_FAILURE("Failed to decode modular DC group"); } DequantDC(r, &dec_state->shared_storage.dc_storage, @@ -384,7 +438,7 @@ Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader, ModularOptions options; if (!ModularGenericDecompress( reader, image, /*header=*/nullptr, stream_id, &options, - /*undo_transforms=*/-1, &tree, &code, &context_map)) { + /*undo_transforms=*/true, &tree, &code, &context_map)) { return JXL_FAILURE("Failed to decode AC metadata"); } ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr, @@ -399,11 +453,11 @@ Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader, uint32_t local_used_acs = 0; for (size_t iy = 0; iy < r.ysize(); iy++) { size_t y = r.y0() + iy; - int* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy); + int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy); uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy); - int* row_in_1 = image.channel[2].plane.Row(0); - int* row_in_2 = image.channel[2].plane.Row(1); - int* row_in_3 = image.channel[3].plane.Row(iy); + int32_t* row_in_1 = image.channel[2].plane.Row(0); + int32_t* row_in_2 = image.channel[2].plane.Row(1); + int32_t* row_in_3 = image.channel[3].plane.Row(iy); for (size_t ix = 0; ix < r.xsize(); ix++) { size_t x = r.x0() + ix; int sharpness = row_in_3[ix]; @@ -440,8 +494,8 @@ Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader, } JXL_RETURN_IF_ERROR( ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num]))); - row_qf[ix] = - 1 + std::max(0, std::min(Quantizer::kQuantMax - 1, row_in_2[num])); + row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1, + row_in_2[num])); num++; } } @@ -454,16 +508,15 @@ Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader, Status ModularFrameDecoder::ModularImageToDecodedRect( Image& gi, PassesDecoderState* dec_state, jxl::ThreadPool* pool, - ImageBundle* output, Rect rect) { - auto& decoded = dec_state->decoded; + RenderPipelineInput& render_pipeline_input, Rect modular_rect) { const auto& frame_header = dec_state->shared->frame_header; const auto* metadata = frame_header.nonserialized_metadata; - size_t xsize = rect.xsize(); - size_t ysize = rect.ysize(); - if (!xsize || !ysize) { - return true; - } - JXL_DASSERT(rect.IsInside(decoded)); + JXL_CHECK(gi.transform.empty()); + + auto get_row = [&](size_t c, size_t y) { + const auto& buffer = render_pipeline_input.GetBuffer(c); + return buffer.second.Row(buffer.first, y); + }; size_t c = 0; if (do_color) { @@ -473,9 +526,9 @@ Status ModularFrameDecoder::ModularImageToDecodedRect( const bool fp = metadata->m.bit_depth.floating_point_sample && frame_header.color_transform != ColorTransform::kXYB; for (; c < 3; c++) { - float factor = full_image.bitdepth < 32 - ? 1.f / ((1u << full_image.bitdepth) - 1) - : 0; + double factor = full_image.bitdepth < 32 + ? 1.0 / ((1u << full_image.bitdepth) - 1) + : 0; size_t c_in = c; if (frame_header.color_transform == ColorTransform::kXYB) { factor = dec_state->shared->matrices.DCQuants()[c]; @@ -490,59 +543,89 @@ Status ModularFrameDecoder::ModularImageToDecodedRect( if (ch_in.w == 0 || ch_in.h == 0) { return JXL_FAILURE("Empty image"); } - size_t xsize_shifted = DivCeil(xsize, 1 << ch_in.hshift); - size_t ysize_shifted = DivCeil(ysize, 1 << ch_in.vshift); - Rect r(rect.x0() >> ch_in.hshift, rect.y0() >> ch_in.vshift, - rect.xsize() >> ch_in.hshift, rect.ysize() >> ch_in.vshift, - DivCeil(decoded.xsize(), 1 << ch_in.hshift), - DivCeil(decoded.ysize(), 1 << ch_in.vshift)); - if (r.ysize() != ch_in.h || r.xsize() != ch_in.w) { - return JXL_FAILURE( - "Dimension mismatch: trying to fit a %zux%zu modular channel into " - "a %zux%zu rect", - ch_in.w, ch_in.h, r.xsize(), r.ysize()); + JXL_CHECK(ch_in.hshift <= 3 && ch_in.vshift <= 3); + Rect r = render_pipeline_input.GetBuffer(c).second; + Rect mr(modular_rect.x0() >> ch_in.hshift, + modular_rect.y0() >> ch_in.vshift, + DivCeil(modular_rect.xsize(), 1 << ch_in.hshift), + DivCeil(modular_rect.ysize(), 1 << ch_in.vshift)); + mr = mr.Crop(ch_in.plane); + size_t xsize_shifted = r.xsize(); + size_t ysize_shifted = r.ysize(); + if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) { + return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS + "x%" PRIuS + " modular channel into " + "a %" PRIuS "x%" PRIuS " rect", + mr.xsize(), mr.ysize(), r.xsize(), r.ysize()); } if (frame_header.color_transform == ColorTransform::kXYB && c == 2) { JXL_ASSERT(!fp); - RunOnPool( - pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(), - [&](const int task, const int thread) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { const size_t y = task; - const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y); + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); const pixel_type* const JXL_RESTRICT row_in_Y = - gi.channel[0].Row(y); - float* const JXL_RESTRICT row_out = r.PlaneRow(&decoded, c, y); + mr.Row(&gi.channel[0].plane, y); + float* const JXL_RESTRICT row_out = get_row(c, y); HWY_DYNAMIC_DISPATCH(MultiplySum) (xsize_shifted, row_in, row_in_Y, factor, row_out); }, - "ModularIntToFloat"); + "ModularIntToFloat")); } else if (fp) { int bits = metadata->m.bit_depth.bits_per_sample; int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample; - RunOnPool( - pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(), - [&](const int task, const int thread) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { const size_t y = task; - const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y); - float* const JXL_RESTRICT row_out = r.PlaneRow(&decoded, c, y); - int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits); + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); + if (rgb_from_gray) { + for (size_t cc = 0; cc < 3; cc++) { + float* const JXL_RESTRICT row_out = get_row(cc, y); + int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits); + } + } else { + float* const JXL_RESTRICT row_out = get_row(c, y); + int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits); + } }, - "ModularIntToFloat_losslessfloat"); + "ModularIntToFloat_losslessfloat")); } else { - RunOnPool( - pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(), - [&](const int task, const int thread) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { const size_t y = task; - const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y); + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); if (rgb_from_gray) { - HWY_DYNAMIC_DISPATCH(RgbFromSingle) - (xsize_shifted, row_in, factor, &decoded, c, y, r); + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(RgbFromSingle) + (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y), + get_row(2, y)); + } else { + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(0, y)); + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(1, y)); + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(2, y)); + } } else { - HWY_DYNAMIC_DISPATCH(SingleFromSingle) - (xsize_shifted, row_in, factor, &decoded, c, y, r); + float* const JXL_RESTRICT row_out = get_row(c, y); + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(SingleFromSingle) + (xsize_shifted, row_in, factor, row_out); + } else { + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + row_out); + } } }, - "ModularIntToFloat"); + "ModularIntToFloat")); } if (rgb_from_gray) { break; @@ -552,66 +635,94 @@ Status ModularFrameDecoder::ModularImageToDecodedRect( c = 1; } } - for (size_t ec = 0; ec < dec_state->extra_channels.size(); ec++, c++) { - const ExtraChannelInfo& eci = output->metadata()->extra_channel_info[ec]; + size_t num_extra_channels = metadata->m.num_extra_channels; + for (size_t ec = 0; ec < num_extra_channels; ec++, c++) { + const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec]; int bits = eci.bit_depth.bits_per_sample; int exp_bits = eci.bit_depth.exponent_bits_per_sample; bool fp = eci.bit_depth.floating_point_sample; JXL_ASSERT(fp || bits < 32); - const float mul = fp ? 0 : (1.0f / ((1u << bits) - 1)); - size_t ecups = frame_header.extra_channel_upsampling[ec]; - const size_t ec_xsize = DivCeil(frame_dim.xsize_upsampled, ecups); - const size_t ec_ysize = DivCeil(frame_dim.ysize_upsampled, ecups); + const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1)); JXL_ASSERT(c < gi.channel.size()); Channel& ch_in = gi.channel[c]; - // For x0, y0 there's no need to do a DivCeil(). - JXL_DASSERT(rect.x0() % (1ul << ch_in.hshift) == 0); - JXL_DASSERT(rect.y0() % (1ul << ch_in.vshift) == 0); - Rect r(rect.x0() >> ch_in.hshift, rect.y0() >> ch_in.vshift, - DivCeil(rect.xsize(), 1lu << ch_in.hshift), - DivCeil(rect.ysize(), 1lu << ch_in.vshift), ec_xsize, ec_ysize); - - JXL_DASSERT(r.IsInside(dec_state->extra_channels[ec])); - JXL_DASSERT(Rect(0, 0, r.xsize(), r.ysize()).IsInside(ch_in.plane)); + Rect r = render_pipeline_input.GetBuffer(3 + ec).second; + Rect mr(modular_rect.x0() >> ch_in.hshift, + modular_rect.y0() >> ch_in.vshift, + DivCeil(modular_rect.xsize(), 1 << ch_in.hshift), + DivCeil(modular_rect.ysize(), 1 << ch_in.vshift)); + mr = mr.Crop(ch_in.plane); + if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) { + return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS + "x%" PRIuS + " modular channel into " + "a %" PRIuS "x%" PRIuS " rect", + mr.xsize(), mr.ysize(), r.xsize(), r.ysize()); + } for (size_t y = 0; y < r.ysize(); ++y) { float* const JXL_RESTRICT row_out = - r.Row(&dec_state->extra_channels[ec], y); - const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y); + r.Row(render_pipeline_input.GetBuffer(3 + ec).first, y); + const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y); if (fp) { int_to_float(row_in, row_out, r.xsize(), bits, exp_bits); } else { - for (size_t x = 0; x < r.xsize(); ++x) { - row_out[x] = row_in[x] * mul; + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(SingleFromSingle) + (r.xsize(), row_in, factor, row_out); + } else { + SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out); } } } - JXL_CHECK_IMAGE_INITIALIZED(dec_state->extra_channels[ec], r); } return true; } Status ModularFrameDecoder::FinalizeDecoding(PassesDecoderState* dec_state, jxl::ThreadPool* pool, - ImageBundle* output) { + bool inplace) { if (!use_full_image) return true; - Image& gi = full_image; + Image gi = (inplace ? std::move(full_image) : full_image.clone()); size_t xsize = gi.w; size_t ysize = gi.h; + JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s", + gi.DebugString().c_str()); + // Don't use threads if total image size is smaller than a group if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr; // Undo the global transforms - gi.undo_transforms(global_header.wp_header, -1, pool); - for (auto t : global_transform) { - JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header)); - } + gi.undo_transforms(global_header.wp_header, pool); + JXL_DASSERT(global_transform.empty()); if (gi.error) return JXL_FAILURE("Undoing transforms failed"); - auto& decoded = dec_state->decoded; - - JXL_RETURN_IF_ERROR( - ModularImageToDecodedRect(gi, dec_state, pool, output, Rect(decoded))); + for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) { + dec_state->render_pipeline->ClearDone(i); + } + std::atomic<bool> has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, dec_state->shared->frame_dim.num_groups, + [&](size_t num_threads) { + const auto& frame_header = dec_state->shared->frame_header; + bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT || + (frame_header.flags & FrameHeader::kNoise)); + return dec_state->render_pipeline->PrepareForThreads(num_threads, + use_group_ids); + }, + [&](const uint32_t group, size_t thread_id) { + RenderPipelineInput input = + dec_state->render_pipeline->GetInputBuffers(group, thread_id); + if (!ModularImageToDecodedRect(gi, dec_state, nullptr, input, + dec_state->shared->GroupRect(group))) { + has_error = true; + return; + } + input.Done(); + }, + "ModularToRect")); + if (has_error) { + return JXL_FAILURE("Error producing input to render pipeline"); + } return true; } @@ -633,12 +744,12 @@ Status ModularFrameDecoder::DecodeQuantTable( JXL_RETURN_IF_ERROR(ModularGenericDecompress( br, image, /*header=*/nullptr, ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim), - &options, /*undo_transforms=*/-1, &modular_frame_decoder->tree, + &options, /*undo_transforms=*/true, &modular_frame_decoder->tree, &modular_frame_decoder->code, &modular_frame_decoder->context_map)); } else { JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr, 0, &options, - /*undo_transforms=*/-1)); + /*undo_transforms=*/true)); } if (!encoding->qraw.qtable) { encoding->qraw.qtable = new std::vector<int>(); @@ -646,7 +757,7 @@ Status ModularFrameDecoder::DecodeQuantTable( encoding->qraw.qtable->resize(required_size_x * required_size_y * 3); for (size_t c = 0; c < 3; c++) { for (size_t y = 0; y < required_size_y; y++) { - int* JXL_RESTRICT row = image.channel[c].Row(y); + int32_t* JXL_RESTRICT row = image.channel[c].Row(y); for (size_t x = 0; x < required_size_x; x++) { (*encoding->qraw.qtable)[c * required_size_x * required_size_y + y * required_size_x + x] = row[x]; |