Diffstat (limited to 'lib/jxl/enc_optimize.h')
-rw-r--r--  lib/jxl/enc_optimize.h  | 216
1 file changed, 216 insertions(+), 0 deletions(-)
diff --git a/lib/jxl/enc_optimize.h b/lib/jxl/enc_optimize.h
new file mode 100644
index 0000000..9da523f
--- /dev/null
+++ b/lib/jxl/enc_optimize.h
@@ -0,0 +1,216 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Utility functions for optimizing multi-dimensional nonlinear functions.
+
+#ifndef LIB_JXL_ENC_OPTIMIZE_H_
+#define LIB_JXL_ENC_OPTIMIZE_H_
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <functional>
+#include <vector>
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+namespace optimize {
+
+// A fixed-size array of numeric values that supports element-wise arithmetic
+// via operator+, operator-, etc., plus a dot product via operator*.
+template <typename T, size_t N>
+class Array {
+ public:
+  Array() = default;
+  explicit Array(T v) {
+    for (size_t i = 0; i < N; i++) v_[i] = v;
+  }
+
+  size_t size() const { return N; }
+
+  T& operator[](size_t index) {
+    JXL_DASSERT(index < N);
+    return v_[index];
+  }
+  T operator[](size_t index) const {
+    JXL_DASSERT(index < N);
+    return v_[index];
+  }
+
+ private:
+  // The values used by this Array.
+  T v_[N];
+};
+
+template <typename T, size_t N>
+Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) {
+  Array<T, N> z;
+  for (size_t i = 0; i < N; ++i) {
+    z[i] = x[i] + y[i];
+  }
+  return z;
+}
+
+template <typename T, size_t N>
+Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) {
+  Array<T, N> z;
+  for (size_t i = 0; i < N; ++i) {
+    z[i] = x[i] - y[i];
+  }
+  return z;
+}
+
+template <typename T, size_t N>
+Array<T, N> operator*(T v, const Array<T, N>& x) {
+  Array<T, N> y;
+  for (size_t i = 0; i < N; ++i) {
+    y[i] = v * x[i];
+  }
+  return y;
+}
+
+// Dot product of two arrays.
+template <typename T, size_t N>
+T operator*(const Array<T, N>& x, const Array<T, N>& y) {
+  T r = 0.0;
+  for (size_t i = 0; i < N; ++i) {
+    r += x[i] * y[i];
+  }
+  return r;
+}
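+
+// Illustrative sketch of the operators above (values are hypothetical, not
+// part of the API):
+//
+//   optimize::Array<double, 3> a(1.0);       // (1, 1, 1)
+//   optimize::Array<double, 3> b(2.0);       // (2, 2, 2)
+//   optimize::Array<double, 3> sum = a + b;  // (3, 3, 3)
+//   double dot = a * b;                      // 1*2 + 1*2 + 1*2 == 6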
+
+// Runs a Nelder-Mead-style optimization for at most max_iterations steps.
+// fun is called with a vector of dim parameters and returns the score for
+// those parameters (lower is better). Returns a vector of dim+1 elements,
+// where the first element is the minimum value found and the remaining
+// elements are the argmin. The second overload takes init, an initial guess
+// of where the minimum lies.
+//
+// Usage example:
+//
+//   RunSimplex(2, 0.1, 100, [](const std::vector<double>& v) {
+//     return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7);
+//   });
+//
+// Returns approximately (0.0, 5, 7).
+std::vector<double> RunSimplex(
+    int dim, double amount, int max_iterations,
+    const std::function<double(const std::vector<double>&)>& fun);
+std::vector<double> RunSimplex(
+    int dim, double amount, int max_iterations, const std::vector<double>& init,
+    const std::function<double(const std::vector<double>&)>& fun);
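+
+// A sketch of the overload with init (values illustrative): seed the search
+// near a suspected minimum instead of the default starting point:
+//
+//   std::vector<double> init = {4.0, 8.0};
+//   std::vector<double> best = RunSimplex(
+//       2, 0.1, 100, init, [](const std::vector<double>& v) {
+//         return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7);
+//       });
+//   // best[0] is the minimum value found; best[1], best[2] are the argmin.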
+
+// Implementation of the Scaled Conjugate Gradient method described in the
+// following paper:
+//   Møller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
+//   Learning", Neural Networks, Vol. 6, pp. 525-533, 1993.
+//   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
+//
+// The Function template parameter is a class that has the following method:
+//
+//   // Returns the value of the function at point w and sets *df to be the
+//   // negative gradient vector of the function at point w.
+//   double Compute(const optimize::Array<T, N>& w,
+//                  optimize::Array<T, N>* df) const;
+//
+// Returns a vector w such that |df(w)| < grad_norm_threshold, or the last
+// accepted point once max_iters iterations have been performed.
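+//
+// Usage sketch (QuadraticFn is a hypothetical functor, not part of this
+// header):
+//
+//   struct QuadraticFn {
+//     // f(w) = (w[0] - 1)^2 + (w[1] + 2)^2; writes the negative gradient
+//     // into *df.
+//     double Compute(const optimize::Array<double, 2>& w,
+//                    optimize::Array<double, 2>* df) const {
+//       (*df)[0] = -2.0 * (w[0] - 1.0);
+//       (*df)[1] = -2.0 * (w[1] + 2.0);
+//       return (w[0] - 1.0) * (w[0] - 1.0) + (w[1] + 2.0) * (w[1] + 2.0);
+//     }
+//   };
+//
+//   optimize::Array<double, 2> w0(0.0);
+//   optimize::Array<double, 2> w =
+//       optimize::OptimizeWithScaledConjugateGradientMethod(
+//           QuadraticFn(), w0, 1e-8, 100);
+//   // w is now approximately (1, -2).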
+template <typename T, size_t N, typename Function>
+Array<T, N> OptimizeWithScaledConjugateGradientMethod(
+    const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
+    size_t max_iters) {
+  const size_t n = w0.size();
+  const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
+  const T sigma0 = static_cast<T>(0.0001);
+  const T l_min = static_cast<T>(1.0e-15);
+  const T l_max = static_cast<T>(1.0e15);
+
+  Array<T, N> w = w0;  // Current point.
+  Array<T, N> wp;      // Trial point.
+  Array<T, N> r;       // Negative gradient at w.
+  Array<T, N> rt;      // Negative gradient at the trial point.
+  Array<T, N> e;       // Negative gradient from the previous step.
+  Array<T, N> p;       // Search direction.
+  T psq;               // Squared norm of p.
+  T fp;                // Function value at the trial point.
+  T D;                 // Ratio of actual to predicted decrease.
+  T d;                 // Denominator of the step size.
+  T m;                 // Directional derivative of f along p.
+  T a;                 // Step size.
+  T b;                 // Coefficient of the conjugate direction update.
+  T s;                 // Finite-difference step length.
+  T t;                 // Curvature estimate along p.
+
+  T fw = f.Compute(w, &r);  // Function value and negative gradient at w0.
+  T rsq = r * r;
+  e = r;
+  p = r;
+  T l = static_cast<T>(1.0);  // Scale parameter (lambda in the paper).
+  bool success = true;
+  size_t n_success = 0;
+  size_t k = 0;
+
+  while (k++ < max_iters) {
+    if (success) {
+      m = -(p * r);
+      if (m >= 0) {
+        // p is not a descent direction; restart from steepest descent.
+        p = r;
+        m = -(p * r);
+      }
+      psq = p * p;
+      s = sigma0 / std::sqrt(psq);
+      // Estimate the curvature along p from a finite difference of gradients.
+      f.Compute(w + (s * p), &rt);
+      t = (p * (r - rt)) / s;
+    }
+
+    // Increase the denominator if needed to keep the quadratic model convex.
+    d = t + l * psq;
+    if (d <= 0) {
+      d = l * psq;
+      l = l - t / psq;
+    }
+
+    a = -m / d;  // Step size that minimizes the local quadratic model.
+    wp = w + a * p;
+    fp = f.Compute(wp, &rt);
+
+    // Comparison ratio: actual decrease relative to the predicted one.
+    D = 2.0 * (fp - fw) / (a * m);
+    if (D >= 0.0) {
+      success = true;
+      n_success++;
+      w = wp;
+    } else {
+      success = false;
+    }
+
+    if (success) {
+      e = r;
+      r = rt;
+      rsq = r * r;
+      fw = fp;
+      if (rsq <= rsq_threshold) {
+        break;
+      }
+    }
+
+    // Adjust the scale parameter based on the comparison ratio.
+    if (D < 0.25) {
+      l = std::min(static_cast<T>(4.0) * l, l_max);
+    } else if (D > 0.75) {
+      l = std::max(static_cast<T>(0.25) * l, l_min);
+    }
+
+    // Restart from steepest descent every n successful steps; otherwise
+    // update the conjugate search direction.
+    if ((n_success % n) == 0) {
+      p = r;
+      l = static_cast<T>(1.0);
+    } else if (success) {
+      b = ((e - r) * r) / m;
+      p = b * p + r;
+    }
+  }
+
+  return w;
+}
+
+} // namespace optimize
+} // namespace jxl
+
+#endif  // LIB_JXL_ENC_OPTIMIZE_H_