summaryrefslogtreecommitdiff
path: root/lib/jxl/enc_noise.cc
blob: a12a9e6dc4c3d2d190d3aea1748bad9ca8d20db9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/enc_noise.h"

#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <numeric>
#include <utility>

#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/chroma_from_luma.h"
#include "lib/jxl/convolve.h"
#include "lib/jxl/enc_aux_out.h"
#include "lib/jxl/enc_optimize.h"
#include "lib/jxl/image_ops.h"

namespace jxl {
namespace {

// Parameter vector handed to the optimizer: one double per noise LUT point
// (NoiseParams::kNumNoisePoints entries).
using OptimizeArray = optimize::Array<double, NoiseParams::kNumNoisePoints>;

// Scores the self-similarity of the block_size x block_size block whose
// top-left corner is (x, y). A 3x4 patch is slid over the block and compared,
// using the sum of absolute differences (SAD) on the 0.5 * (plane 1 + plane 0)
// channel, against a fixed reference patch at offset (2, 2) inside the block.
// Returns the mean of the smaller half of the per-position SAD values.
float GetScoreSumsOfAbsoluteDifferences(const Image3F& opsin, const int x,
                                        const int y, const int block_size) {
  constexpr int kPatchWidth = 3;
  constexpr int kPatchHeight = 4;
  // Position of the reference patch relative to the block origin.
  constexpr int kReferenceOffset = 2;

  // Value of the mixed channel at an absolute image position.
  const auto mixed = [&opsin](const int row, const int col) {
    return 0.5f * (opsin.PlaneRow(1, row)[col] + opsin.PlaneRow(0, row)[col]);
  };

  // One SAD value per position of the sliding patch.
  const int num_positions =
      (block_size - kPatchWidth) * (block_size - kPatchHeight);
  std::vector<float> patch_sad(num_positions, 0);

  int filled = 0;
  for (int top = 0; top + kPatchHeight < block_size; ++top) {
    for (int left = 0; left + kPatchWidth < block_size; ++left) {
      // Compare the patch at (left, top) with the reference patch.
      float distance = 0;
      for (int dy = 0; dy < kPatchHeight; ++dy) {
        for (int dx = 0; dx < kPatchWidth; ++dx) {
          const float candidate = mixed(y + top + dy, x + left + dx);
          const float reference =
              mixed(y + kReferenceOffset + dy, x + kReferenceOffset + dx);
          distance += std::abs(reference - candidate);
        }
      }
      patch_sad[filled++] = distance;
    }
  }

  // As with ROAD (rank order absolute distance), only the smallest half of the
  // values contributes; patch SADs are used instead of single-pixel absolute
  // differences because they are more robust.
  std::sort(patch_sad.begin(), patch_sad.end());
  const int kept = num_positions / 2;
  const float kept_sum =
      std::accumulate(patch_sad.begin(), patch_sad.begin() + kept, 0.0f);
  return kept_sum / kept;
}

// Histogram with kBins unit-width bins covering [0, kBins). A sample x is
// counted in bin floor(x); samples below 0 (and NaN) go to the first bin,
// samples at or above kBins - 1 go to the last bin.
class NoiseHistogram {
 public:
  static constexpr int kBins = 256;

  NoiseHistogram() { std::fill(bins, bins + kBins, 0); }

  // Adds one sample to the bin that contains x.
  void Increment(const float x) { bins[Index(x)] += 1; }
  // Returns the number of samples in the bin that contains x.
  int Get(const float x) const { return bins[Index(x)]; }
  // Returns the number of samples in bin `bin`, which must be < kBins
  // (not checked).
  int Bin(const size_t bin) const { return bins[bin]; }

  // Returns the index of the fullest bin; the lowest index wins ties, so an
  // empty histogram yields 0.
  int Mode() const {
    size_t max_idx = 0;
    for (size_t i = 1; i < static_cast<size_t>(kBins); i++) {
      if (bins[i] > bins[max_idx]) max_idx = i;
    }
    return static_cast<int>(max_idx);
  }

  // Returns the position (in bin units, within [0, kBins - 1]) below which a
  // fraction q01 of the samples lies. q01 is clamped to [0, 1]; samples are
  // treated as uniformly spread inside their bin, so the result is linearly
  // interpolated within the bin where the cumulative count reaches
  // q01 * total. Empty bins are skipped. Returns 0 for an empty histogram.
  double Quantile(double q01) const {
    const int64_t total = std::accumulate(bins, bins + kBins, int64_t{0});
    if (total == 0) return 0.0;

    // Written so that NaN also maps to 0.
    const double q = q01 > 0.0 ? std::min(q01, 1.0) : 0.0;
    const double target = q * static_cast<double>(total);

    int64_t below = 0;  // Samples in all bins before the current one.
    size_t last_used = 0;
    for (size_t i = 0; i < static_cast<size_t>(kBins); ++i) {
      const int64_t count = bins[i];
      if (count == 0) continue;
      last_used = i;
      if (static_cast<double>(below + count) >= target) {
        // Fraction of bin i needed to reach the target; in [0, 1] because
        // below <= target <= below + count.
        const double frac = (target - static_cast<double>(below)) /
                            static_cast<double>(count);
        return ClampX(static_cast<double>(i) + frac);
      }
      below += count;
    }
    // Not reachable for q <= 1 (the last non-empty bin always reaches the
    // target); kept as a safe fallback.
    return ClampX(static_cast<double>(last_used) + 1.0);
  }

  // Inter-quartile range
  double IQR() const { return Quantile(0.75) - Quantile(0.25); }

 private:
  // Clamps x to the valid bin range [0, kBins - 1].
  template <typename T>
  T ClampX(const T x) const {
    return std::min(std::max(T(0), x), T(kBins - 1));
  }

  // Maps a sample to its bin. The range checks happen in floating point,
  // before the conversion, because converting NaN or an out-of-range float to
  // an integer is undefined behavior.
  size_t Index(const float x) const {
    if (!(x > 0.0f)) return 0;  // Negative, zero or NaN.
    if (x >= static_cast<float>(kBins - 1)) {
      return static_cast<size_t>(kBins - 1);
    }
    return static_cast<size_t>(x);
  }

  uint32_t bins[kBins];
};

// Computes the SAD score of every complete block_s x block_s block of `opsin`,
// visiting blocks row by row. Each score is returned (one entry per block, in
// visiting order) and also recorded, scaled by num_bin, in `sad_histogram`.
std::vector<float> GetSADScoresForPatches(const Image3F& opsin,
                                          const size_t block_s,
                                          const size_t num_bin,
                                          NoiseHistogram* sad_histogram) {
  const size_t height = opsin.ysize();
  const size_t width = opsin.xsize();

  // Only blocks that fit entirely inside the image are scored.
  std::vector<float> sad_scores((height / block_s) * (width / block_s), 0.0f);

  int next_slot = 0;
  for (size_t top = 0; top + block_s <= height; top += block_s) {
    for (size_t left = 0; left + block_s <= width; left += block_s) {
      const float score =
          GetScoreSumsOfAbsoluteDifferences(opsin, left, top, block_s);
      sad_scores[next_slot] = score;
      ++next_slot;
      sad_histogram->Increment(score * num_bin);
    }
  }
  return sad_scores;
}

// Derives the SAD threshold separating "flat" patches from textured ones.
// `histogram` holds SAD scores that were multiplied by `num_bin` before being
// binned; the index of the fullest bin is mapped back to score units by
// dividing by that same factor. Returns 0 when num_bin is not positive.
float GetSADThreshold(const NoiseHistogram& histogram, const int num_bin) {
  // Here we assume that the patches sharing the most common SAD value are the
  // "flat" patches. However, some images contain regularly textured regions
  // that produce a second strong peak in the histogram.
  // TODO(user) handle bimodal and heavy-tailed case
  if (num_bin <= 0) return 0.0f;
  const int mode = histogram.Mode();
  // Undo the scaling applied when the scores were binned. This used to divide
  // by NoiseHistogram::kBins, which is only correct when num_bin == kBins.
  return static_cast<float>(mode) / static_cast<float>(num_bin);
}

// loss = sum asym * (F(x) - nl)^2 + kReg * num_points * sum (w[i] - w[i+1])^2
// where asym = 1 if F(x) < nl, kAsym if F(x) > nl.
struct LossFunction {
  explicit LossFunction(std::vector<NoiseLevel> nl0) : nl(std::move(nl0)) {}

  double Compute(const OptimizeArray& w, OptimizeArray* df,
                 bool skip_regularization = false) const {
    constexpr double kReg = 0.005;
    constexpr double kAsym = 1.1;
    double loss_function = 0;
    for (size_t i = 0; i < w.size(); i++) {
      (*df)[i] = 0;
    }
    for (auto ind : nl) {
      std::pair<int, float> pos = IndexAndFrac(ind.intensity);
      JXL_DASSERT(pos.first >= 0 && static_cast<size_t>(pos.first) <
                                        NoiseParams::kNumNoisePoints - 1);
      double low = w[pos.first];
      double hi = w[pos.first + 1];
      double val = low * (1.0f - pos.second) + hi * pos.second;
      double dist = val - ind.noise_level;
      if (dist > 0) {
        loss_function += kAsym * dist * dist;
        (*df)[pos.first] -= kAsym * (1.0f - pos.second) * dist;
        (*df)[pos.first + 1] -= kAsym * pos.second * dist;
      } else {
        loss_function += dist * dist;
        (*df)[pos.first] -= (1.0f - pos.second) * dist;
        (*df)[pos.first + 1] -= pos.second * dist;
      }
    }
    if (skip_regularization) return loss_function;
    for (size_t i = 0; i + 1 < w.size(); i++) {
      double diff = w[i] - w[i + 1];
      loss_function += kReg * nl.size() * diff * diff;
      (*df)[i] -= kReg * diff * nl.size();
      (*df)[i + 1] += kReg * diff * nl.size();
    }
    return loss_function;
  }

  std::vector<NoiseLevel> nl;
};

void OptimizeNoiseParameters(const std::vector<NoiseLevel>& noise_level,
                             NoiseParams* noise_params) {
  constexpr double kMaxError = 1e-3;
  static const double kPrecision = 1e-8;
  static const int kMaxIter = 40;

  float avg = 0;
  for (const NoiseLevel& nl : noise_level) {
    avg += nl.noise_level;
  }
  avg /= noise_level.size();

  LossFunction loss_function(noise_level);
  OptimizeArray parameter_vector;
  for (size_t i = 0; i < parameter_vector.size(); i++) {
    parameter_vector[i] = avg;
  }

  parameter_vector = optimize::OptimizeWithScaledConjugateGradientMethod(
      loss_function, parameter_vector, kPrecision, kMaxIter);

  OptimizeArray df = parameter_vector;
  float loss = loss_function.Compute(parameter_vector, &df,
                                     /*skip_regularization=*/true) /
               noise_level.size();

  // Approximation went too badly: escape with no noise at all.
  if (loss > kMaxError) {
    noise_params->Clear();
    return;
  }

  for (size_t i = 0; i < parameter_vector.size(); i++) {
    noise_params->lut[i] = std::max(parameter_vector[i], 0.0);
  }
}

// Measures intensity and noise level of every "flat" block of `opsin`.
// Blocks of block_s x block_s pixels are visited row by row;
// texture_strength[k] is the score of the k-th visited block and only blocks
// whose score is <= threshold are measured. For each such block the result
// holds its mean intensity and the mean absolute response of a 3x3 zero-sum
// filter, both on the 0.5 * (plane 1 + plane 0) channel.
std::vector<NoiseLevel> GetNoiseLevel(
    const Image3F& opsin, const std::vector<float>& texture_strength,
    const float threshold, const size_t block_s) {
  constexpr int kRadius = 1;
  constexpr int kTaps = 2 * kRadius + 1;
  static const float kLaplFilter[kTaps][kTaps] = {
      {-0.25f, -1.0f, -0.25f},
      {-1.0f, 5.0f, -1.0f},
      {-0.25f, -1.0f, -0.25f},
  };

  // The noise model is built based on channel 0.5 * (X+Y) as we notice that it
  // is similar to the model 0.5 * (Y-X)
  const auto mixed = [&opsin](const size_t row, const size_t col) {
    return 0.5f * (opsin.PlaneRow(1, row)[col] + opsin.PlaneRow(0, row)[col]);
  };

  // Applies the filter offset `delta` to the block-relative coordinate `pos`.
  // If that would leave the block [0, block_s), the offset is applied in the
  // opposite direction instead, so the filter never reads outside the block
  // it is measuring (for block_s >= 2).
  const auto reflect = [block_s](const size_t pos, const int delta) -> size_t {
    const ssize_t moved = static_cast<ssize_t>(pos) + delta;
    const bool inside = moved >= 0 && static_cast<size_t>(moved) < block_s;
    return inside ? pos + delta : pos - delta;
  };

  std::vector<NoiseLevel> noise_level_per_intensity;
  size_t patch_index = 0;
  for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
    for (size_t x = 0; x + block_s <= opsin.xsize();
         x += block_s, ++patch_index) {
      // Textured blocks would be mistaken for noise; skip them.
      if (!(texture_strength[patch_index] <= threshold)) continue;

      // Mean intensity of the block.
      float mean_int = 0;
      for (size_t dy = 0; dy < block_s; ++dy) {
        for (size_t dx = 0; dx < block_s; ++dx) {
          mean_int += mixed(y + dy, x + dx);
        }
      }
      mean_int /= block_s * block_s;

      // Mean absolute filter response over the block.
      float response_sum = 0;
      size_t num_pixels = 0;
      for (size_t dy = 0; dy < block_s; ++dy) {
        for (size_t dx = 0; dx < block_s; ++dx) {
          float filtered_value = 0;
          for (int fy = -kRadius; fy <= kRadius; ++fy) {
            const size_t row = y + reflect(dy, fy);
            for (int fx = -kRadius; fx <= kRadius; ++fx) {
              const size_t col = x + reflect(dx, fx);
              filtered_value +=
                  mixed(row, col) * kLaplFilter[fy + kRadius][fx + kRadius];
            }
          }
          response_sum += std::abs(filtered_value);
          ++num_pixels;
        }
      }

      NoiseLevel nl;
      nl.intensity = mean_int;
      nl.noise_level = response_sum / num_pixels;
      noise_level_per_intensity.push_back(nl);
    }
  }
  return noise_level_per_intensity;
}

// Writes `val` to `writer` as a 10-bit unsigned integer obtained by scaling
// with `precision` and rounding to nearest. `val` must be non-negative and the
// quantized value must be below 1 << 10; both are enforced by assertions.
void EncodeFloatParam(float val, float precision, BitWriter* writer) {
  JXL_ASSERT(val >= 0);
  // Adding 0.5 before the truncating cast rounds the non-negative value to the
  // nearest integer.
  const int absval_quant = static_cast<int>(val * precision + 0.5f);
  JXL_ASSERT(absval_quant < (1 << 10));
  writer->Write(10, absval_quant);
}

}  // namespace

// Estimates the noise LUT of `opsin` and stores it, scaled by
// quality_coef * 1.4, in `noise_params`. Returns whether any noise should be
// signaled; `noise_params` is cleared when the image looks too strongly
// patterned (or no usable threshold is found).
Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params,
                         float quality_coef) {
  // The size of a patch in decoder might be different from encoder's patch
  // size.
  // For encoder: the patch size should be big enough to estimate
  //              noise level, but, at the same time, it should be not too big
  //              to be able to estimate intensity value of the patch
  constexpr size_t kBlockSize = 8;
  constexpr size_t kNumBin = 256;

  // Score every block and collect the distribution of scores.
  NoiseHistogram sad_histogram;
  const std::vector<float> sad_scores =
      GetSADScoresForPatches(opsin, kBlockSize, kNumBin, &sad_histogram);
  const float sad_threshold = GetSADThreshold(sad_histogram, kNumBin);

  // If threshold is too large, the image has a strong pattern. This pattern
  // fools our model and it will add too much noise. Therefore, we do not add
  // noise for such images
  const bool unusable = sad_threshold <= 0.0f || sad_threshold > 0.15f;
  if (unusable) {
    noise_params->Clear();
    return false;
  }

  // Measure the flat blocks and fit the LUT to the measurements.
  const std::vector<NoiseLevel> measurements =
      GetNoiseLevel(opsin, sad_scores, sad_threshold, kBlockSize);
  OptimizeNoiseParameters(measurements, noise_params);

  const double scale = quality_coef * 1.4;
  for (float& entry : noise_params->lut) {
    entry *= scale;
  }
  return noise_params->HasAny();
}

// Writes every entry of the noise LUT to `writer` using EncodeFloatParam with
// kNoisePrecision. Must only be called when `noise_params` holds some noise.
// The bits written are charged to `layer` in `aux_out`.
void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer,
                 size_t layer, AuxOut* aux_out) {
  JXL_ASSERT(noise_params.HasAny());

  // Upper bound on the bits reserved per LUT entry.
  constexpr size_t kBitsReservedPerPoint = 16;
  BitWriter::Allotment allotment(
      writer, NoiseParams::kNumNoisePoints * kBitsReservedPerPoint);

  for (const float lut_entry : noise_params.lut) {
    EncodeFloatParam(lut_entry, kNoisePrecision, writer);
  }

  allotment.ReclaimAndCharge(writer, layer, aux_out);
}

}  // namespace jxl