summaryrefslogtreecommitdiff
path: root/tools/scripts/bisector
diff options
context:
space:
mode:
Diffstat (limited to 'tools/scripts/bisector')
-rwxr-xr-xtools/scripts/bisector287
1 files changed, 287 insertions, 0 deletions
diff --git a/tools/scripts/bisector b/tools/scripts/bisector
new file mode 100755
index 0000000..b6a82d0
--- /dev/null
+++ b/tools/scripts/bisector
@@ -0,0 +1,287 @@
+#!/usr/bin/env python
+#
+# Copyright (c) the JPEG XL Project Authors. All rights reserved.
+#
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+r"""General-purpose bisector
+
+Prints a space-separated list of values to stdout:
+1_if_success_0_otherwise left_x left_f(x) right_x right_f(x)
+
+Usage examples:
+
+# Finding the square root of 200 via bisection:
+bisector --var=BB --range=0.0,100.0 --target=200 --maxiter=100 \
+ --atol_val=1e-12 --rtol_val=0 --cmd='echo "$BB * $BB" | bc'
+# => 1 14.142135623730923 199.99999999999923 14.142135623731633 200.0000000000193
+
+# Finding an integer approximation to sqrt(200) via bisection:
+bisector --var=BB --range=0,100 --target=200 --maxiter=100 \
+ --atol_arg=1 --cmd='echo "$BB * $BB" | bc'
+# => 1 14 196.0 15 225.0
+
+# Finding a change-id that broke something via bisection:
+bisector --var=BB --range=0,1000000 --target=0.5 --maxiter=100 \
+ --atol_arg=1 \
+ --cmd='test $BB -gt 123456 && echo 1 || echo 0' --verbosity=3
+# => 1 123456 0.0 123457 1.0
+
+# Finding settings that compress /usr/share/dict/words to a given target size:
+bisector --var=BB --range=1,9 --target=250000 --atol_arg=1 \
+ --cmd='gzip -$BB </usr/share/dict/words >/tmp/w_$BB.gz; wc -c /tmp/w_$BB.gz' \
+ --final='mv /tmp/w_$BB.gz /tmp/words.gz; rm /tmp/w_*.gz' \
+ --verbosity=1
+# => 1 3 263170.0 4 240043.0
+
+# JXL-encoding with bisection-for-size (tolerance 0.5%):
+bisector --var=BB --range=0.1,3.0 --target=3500 --rtol_val=0.005 \
+ --cmd='(build/tools/cjxl --distance=$BB /tmp/baseball.png /tmp/baseball_$BB.jxl && wc -c /tmp/baseball_$BB.jxl)' \
+ --final='mv /tmp/baseball_$BB.jxl /tmp/baseball.jxl; rm -f /tmp/baseball_*.jxl' \
+ --verbosity=1
+# => 1 1.1875 3573.0 1.278125 3481.0
+
+# JXL-encoding with bisection-for-bits-per-pixel (tolerance 0.5%), using helper:
+bisector --var=BB --range=0.1,3.0 --target=1.2 --rtol_val=0.005 \
+ --cmd='(build/tools/cjxl --distance=$BB /tmp/baseball.png /tmp/baseball_$BB.jxl && get_bpp /tmp/baseball_$BB.jxl)' \
+ --final='mv /tmp/baseball_$BB.jxl /tmp/baseball.jxl; rm -f /tmp/baseball_*.jxl' \
+ --verbosity=1
+# => ...
+"""
+
+import argparse
+import os
+import re
+import subprocess
+import sys
+
+
+def _expandvars(vardef, env,
+ max_recursion=100,
+ max_length=10**6,
+ verbosity=0):
+ """os.path.expandvars() variant using parameter env rather than os.environ."""
+ current_expanded = vardef
+ for num_recursions in range(max_recursion):
+ if verbosity >= 3:
+ print(f'_expandvars(): num_recursions={num_recursions}, '
+ f'len={len(current_expanded)}' +
+ (', current: ' + current_expanded if verbosity >= 4 else ''))
+ if len > max_length:
+ break
+ current_expanded, num_replacements = re.subn(
+ r'$\{(\w+)\}|$(\w+)',
+ lambda m: env.get(m[1] if m[1] is not None else m[2], ''),
+ current_expanded)
+ if num_replacements == 0:
+ break
+ return current_expanded
+
+
+def _strtod(string):
+ """Extracts leftmost float from string (like strtod(3))."""
+ match = re.match(r'[+-]?\d*[.]?\d*(?:[eE][+-]?\d+)?', string)
+ return float(match[0]) if match[0] else None
+
+
+def run_shell_command(shell_command,
+ bisect_var, bisect_val,
+ extra_env_defs,
+ verbosity=0):
+ """Runs a shell command with env modifications, fetching return value."""
+ shell_env = dict(os.environ)
+ shell_env[bisect_var] = str(bisect_val)
+ for env_def in extra_env_defs:
+ varname, vardef = env_def.split('=', 1)
+ shell_env[varname] = _expandvars(vardev, shell_env,
+ verbosity=verbosity)
+ shell_ret = subprocess.run(shell_command,
+ # We explicitly want subshell semantics!
+ shell=True,
+ capture_output=True,
+ env=shell_env)
+ stdout = shell_ret.stdout.decode('utf-8')
+ score = _strtod(stdout)
+ if verbosity >= 2:
+ print(f'{bisect_var}={bisect_val} {shell_command} => '
+ f'{shell_ret.returncode} # {stdout.strip()}')
+ return (shell_ret.returncode == 0, # Command was successful?
+ score)
+
+
+def _bisect(*,
+ shell_command,
+ final_shell_command,
+ target,
+ int_args,
+ bisect_var, bisect_left, bisect_right,
+ rtol_val, atol_val, rtol_arg, atol_arg,
+ maxiter,
+ extra_env_defs,
+ verbosity=0
+ ):
+ """Performs bisection."""
+ def _get_val(x):
+ success, val = run_shell_command(shell_command,
+ bisect_var, x,
+ extra_env_defs,
+ verbosity=verbosity)
+ if not success:
+ raise RuntimeError(f'Bisection failed for: {bisect_var}={x}: '
+ f'success={success}, val={val}, '
+ f'cmd={shell_command}, var={bisect_var}')
+ return val
+ #
+ bisect_mid, value_mid = None, None
+ try:
+ value_left = _get_val(bisect_left)
+ value_right = _get_val(bisect_right)
+ if (value_left < target) != (target <= value_right):
+ raise RuntimeError(
+ f'Cannot bisect: target={target}, value_left={value_left}, '
+ f'value_right={value_right}')
+ for num_iter in range(maxiter):
+ bisect_mid_f = 0.5 * (bisect_left + bisect_right)
+ bisect_mid = round(bisect_mid_f) if int_args else bisect_mid_f
+ value_mid = _get_val(bisect_mid)
+ if (value_left < target) == (value_mid < target):
+ # Relative to target, `value_mid` is on the same side
+ # as `value_left`.
+ bisect_left = bisect_mid
+ value_left = value_mid
+ else:
+ # Otherwise, this situation must hold for value_right
+ # ("tertium non datur").
+ bisect_right = bisect_mid
+ value_right = value_mid
+ if verbosity >= 1:
+ print(f'bisect target={target}, '
+ f'left: {value_left} at {bisect_left}, '
+ f'right: {value_right} at {bisect_right}, '
+ f'mid: {value_mid} at {bisect_mid}')
+ delta_val = target - value_mid
+ if abs(delta_val) <= atol_val + rtol_val * abs(target):
+ return 1, bisect_left, value_left, bisect_right, value_right
+ delta_arg = bisect_right - bisect_left
+ # Also check whether the argument is "within tolerance".
+ # Here, we have to be careful if bisect_left and bisect_right
+ # have different signs: Then, their absolute magnitude
+ # "sets the relevant scale".
+ if abs(delta_arg) <= atol_arg + (
+ rtol_arg * 0.5 * (abs(bisect_left) + abs(bisect_right))):
+ return 1, bisect_left, value_left, bisect_right, value_right
+ return 0, bisect_left, value_left, bisect_right, value_right
+ finally:
+ # If cleanup is specified, always run it
+ if final_shell_command:
+ run_shell_command(
+ final_shell_command,
+ bisect_var,
+ bisect_mid if bisect_mid is not None else bisect_left,
+ extra_env_defs, verbosity=verbosity)
+
+
+def main(args):
+ """Main entry point."""
+ parser = argparse.ArgumentParser(description='mhtml_walk args')
+ parser.add_argument(
+ '--var',
+ help='The variable to use for bisection.',
+ default='BISECT')
+ parser.add_argument(
+ '--range',
+ help=('The argument range for bisecting, as {low},{high}. '
+ 'If no argument has a decimal dot, assume integer parameters.'),
+ default='0.0,1.0')
+ parser.add_argument(
+ '--max',
+ help='The maximal value for bisecting.',
+ type=float,
+ default=0.0)
+ parser.add_argument(
+ '--target',
+ help='The target value to aim for.',
+ type=float,
+ default=1.0)
+ parser.add_argument(
+ '--maxiter',
+ help='The maximal number of iterations to perform.',
+ type=int,
+ default=40)
+ parser.add_argument(
+ '--rtol_val',
+ help='Relative tolerance to accept for deviations from target value.',
+ type=float,
+ default=0.0)
+ parser.add_argument(
+ '--atol_val',
+ help='Absolute tolerance to accept for deviations from target value.',
+ type=float,
+ default=0.0)
+ parser.add_argument(
+ '--rtol_arg',
+ help='Relative tolerance to accept for the argument.',
+ type=float,
+ default=0.0)
+ parser.add_argument(
+ '--atol_arg',
+ help=('Absolute tolerance to accept for the argument '
+ '(e.g. for bisecting change-IDs).'),
+ type=float,
+ default=0.0)
+ parser.add_argument(
+ '--verbosity',
+ help='The verbosity level.',
+ type=int,
+ default=1)
+ parser.add_argument(
+ '--env',
+ help=('Comma-separated list of extra environment variables '
+ 'to incrementally add before executing the shell-command.'),
+ default='')
+ parser.add_argument(
+ '--cmd',
+ help=('The shell command to execute. Must print a numerical result '
+ 'to stdout.'))
+ parser.add_argument(
+ '--final',
+ help='The cleanup shell command to execute.')
+ #
+ parsed = parser.parse_args(args)
+ extra_env_defs = tuple(filter(None, parsed.env.split(',')))
+ try:
+ low_high = parsed.range.split(',')
+ if len(low_high) != 2:
+ raise ValueError('--range must be {low},{high}')
+ int_args = False
+ low_val, high_val = map(float, low_high)
+ low_val_int = round(low_val)
+ high_val_int = round(high_val)
+ if low_high == [str(low_val_int), str(high_val_int)]:
+ int_args = True
+ low_val = low_val_int
+ high_val = high_val_int
+ ret = _bisect(
+ shell_command=parsed.cmd,
+ final_shell_command=parsed.final,
+ target=parsed.target,
+ int_args=int_args,
+ bisect_var=parsed.var,
+ bisect_left=low_val,
+ bisect_right=high_val,
+ rtol_val=parsed.rtol_val,
+ atol_val=parsed.atol_val,
+ rtol_arg=parsed.rtol_arg,
+ atol_arg=parsed.atol_arg,
+ maxiter=parsed.maxiter,
+ extra_env_defs=extra_env_defs,
+ verbosity=parsed.verbosity,
+ )
+ print(' '.join(map(str, ret)))
+ except Exception as exn:
+ sys.exit(f'Problem: {exn}')
+
+
+if __name__ == '__main__':
+ main(sys.argv[1:])