From da7593d27a1018e3bb0f812ce15fa21c91413187 Mon Sep 17 00:00:00 2001 From: Naushir Patuck Date: Thu, 2 Apr 2026 10:51:06 +0100 Subject: [PATCH 1/2] utils: test_convert: Fix GStreamer pipeline colorimetry and byte ordering Set explicit JPEG colorimetry (1:4:0:0) on input and output caps to ensure the correct BT601 full-range matrix is used for YUV conversions, matching the convert binary behaviour. Use videoconvert to output BGR for RGB test cases, accounting for pispconvert's R/B channel swap on RGB output. Also run ruff check --fix and ruff format across all Python utility scripts to fix lint warnings and standardise formatting. Signed-off-by: Naushir Patuck --- utils/colourspace_calcs.py | 37 ++++++--- utils/generate_filter.py | 80 ++++++++++++------ utils/test_convert.py | 165 +++++++++++++++++++++++-------------- utils/version.py | 46 +++++++---- 4 files changed, 212 insertions(+), 116 deletions(-) diff --git a/utils/colourspace_calcs.py b/utils/colourspace_calcs.py index da22aee..edf55a7 100644 --- a/utils/colourspace_calcs.py +++ b/utils/colourspace_calcs.py @@ -9,18 +9,32 @@ import numpy as np -BT601 = np.array([[0.299, 0.5870, 0.1140], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]]) -REC709 = np.array([[0.2126, 0.7152, 0.0722], [-0.1146, -0.3854, 0.5], [0.5, -0.4542, -0.0458]]) -REC2020 = np.array([[0.2627, 0.6780, 0.0593], [-0.13963006, -0.36036994, 0.5], [0.5, -0.4597857, -0.0402143]]) +BT601 = np.array( + [[0.299, 0.5870, 0.1140], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]] +) +REC709 = np.array( + [[0.2126, 0.7152, 0.0722], [-0.1146, -0.3854, 0.5], [0.5, -0.4542, -0.0458]] +) +REC2020 = np.array( + [ + [0.2627, 0.6780, 0.0593], + [-0.13963006, -0.36036994, 0.5], + [0.5, -0.4597857, -0.0402143], + ] +) colour_spaces = {"select": "default"} + def flatten(array): return [round(num) for num in list(array.flatten())] + def add_entry(name, M, limited): offsets = np.array([0, 128, 128]) - scaling = np.array([[(235 - 16) / 255, 0, 0], [0, (240 - 16) / 255, 0], [0, 0, (240 - 16) / 255]]) + scaling = np.array( + [[(235 - 16) / 255, 0, 0], [0, (240 - 16) / 255, 0], [0, 0, (240 - 16) / 255]] + ) if limited: offsets = np.array([16, 128, 128]) M = np.matmul(scaling, M) @@ -28,14 +42,15 @@ def add_entry(name, M, limited): colour_spaces[name] = {} colour_spaces[name]["ycbcr"] = {} colour_spaces[name]["ycbcr"]["coeffs"] = flatten(M * 1024) - colour_spaces[name]["ycbcr"]["offsets"] = flatten(offsets * (2 ** 18)) + colour_spaces[name]["ycbcr"]["offsets"] = flatten(offsets * (2**18)) colour_spaces[name]["ycbcr_inverse"] = {} colour_spaces[name]["ycbcr_inverse"]["coeffs"] = flatten(Mi * 1024) - inv_offsets = np.rint(np.dot(Mi, -offsets) * (2 ** 18)) + inv_offsets = np.rint(np.dot(Mi, -offsets) * (2**18)) colour_spaces[name]["ycbcr_inverse"]["offsets"] = flatten(inv_offsets) - if inv_offsets.min() < -2 ** 26 or inv_offsets.max() >= 2 ** 26: + if inv_offsets.min() < -(2**26) or inv_offsets.max() >= 2**26: print("WARNING:", name, "will overflow!") + add_entry("default", BT601, limited=False) add_entry("jpeg", BT601, limited=False) add_entry("smpte170m", BT601, limited=True) @@ -44,18 +59,20 @@ def add_entry(name, M, limited): add_entry("bt2020", REC2020, limited=True) add_entry("bt2020_full", REC2020, limited=False) + def print_dict(d, indent=0): print("{") indent += 4 for i, (k, v) in enumerate(d.items()): if type(v) is dict: - print(" " * indent, f'"{k}"', ": ", end='', sep='') + print(" " * indent, f'"{k}"', ": ", end="", sep="") print_dict(v, indent) else: - print(" " * indent, f'"{k}"', ": ", v, end='', sep='') + print(" " * indent, f'"{k}"', ": ", v, end="", sep="") print("," if i < len(d) - 1 else "") indent -= 4 - print(" " * indent, "}", end ='', sep='') + print(" " * indent, "}", end="", sep="") + final_dict = {"colour_encoding": colour_spaces} print_dict(final_dict) diff --git a/utils/generate_filter.py b/utils/generate_filter.py index 26d74aa..4db5baf 100644 --- a/utils/generate_filter.py +++ b/utils/generate_filter.py @@ -28,9 +28,18 @@ def mitchell(B, C, N): for i in range(N): ax = abs(x[i]) if ax < 1: - h[i] = ((12 - 9 * B - 6 * C) * ax**3 + (-18 + 12 * B + 6 * C) * ax**2 + (6 - 2 * B)) / 6 + h[i] = ( + (12 - 9 * B - 6 * C) * ax**3 + + (-18 + 12 * B + 6 * C) * ax**2 + + (6 - 2 * B) + ) / 6 elif (ax >= 1) and (ax < 2): - h[i] = ((-B - 6 * C) * ax**3 + (6 * B + 30 * C) * ax**2 + (-12 * B - 48 * C) * ax + (8 * B + 24 * C)) / 6 + h[i] = ( + (-B - 6 * C) * ax**3 + + (6 * B + 30 * C) * ax**2 + + (-12 * B - 48 * C) * ax + + (8 * B + 24 * C) + ) / 6 return h @@ -50,14 +59,33 @@ def bicubic_spline(a, N): def main(): parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('--phases', metavar='P', type=int, help='Number of phases.', default=16) - parser.add_argument('--taps', metavar='T', type=int, help='Number of filter taps per phase.', default=6) - parser.add_argument('--precision', metavar='PR', type=int, help='Filter precision required.', default=10) - parser.add_argument('--filter', type=str, metavar='F', - help='Filter type and parameters, e.g.: \n' - '"Mitchell, b = 0.333, c = 0.333"\n' - '"Lanczos, order = 3"\n' - '"bicubic_spline, a=-0.5"', required=True) + parser.add_argument( + "--phases", metavar="P", type=int, help="Number of phases.", default=16 + ) + parser.add_argument( + "--taps", + metavar="T", + type=int, + help="Number of filter taps per phase.", + default=6, + ) + parser.add_argument( + "--precision", + metavar="PR", + type=int, + help="Filter precision required.", + default=10, + ) + parser.add_argument( + "--filter", + type=str, + metavar="F", + help="Filter type and parameters, e.g.: \n" + '"Mitchell, b = 0.333, c = 0.333"\n' + '"Lanczos, order = 3"\n' + '"bicubic_spline, a=-0.5"', + required=True, + ) args = parser.parse_args() @@ -66,24 +94,26 @@ def main(): precision = args.precision # Parse the filter string and pick out the needed parameters. - filt = args.filter.split(',') - params = {'a': 0., 'b': 0., 'c': 0., 'order': 0} + filt = args.filter.split(",") + params = {"a": 0.0, "b": 0.0, "c": 0.0, "order": 0} for param in filt[1:]: - p = param.replace(' ', '').split('=') + p = param.replace(" ", "").split("=") params[p[0]] = type(params[p[0]])(p[1]) # Generate the filter. - if (filt[0].lower() == 'mitchell'): - filter = f'"Michell - Netravali (B = {params["b"]:.3f}, C = {params["c"]:.3f})": [\n' - h = mitchell(params['b'], params['c'], phases * taps) - elif (filt[0].lower() == 'lanczos'): + if filt[0].lower() == "mitchell": + filter = ( + f'"Michell - Netravali (B = {params["b"]:.3f}, C = {params["c"]:.3f})": [\n' + ) + h = mitchell(params["b"], params["c"], phases * taps) + elif filt[0].lower() == "lanczos": filter = f'"Lanczos order {params["order"]}": [\n' - h = lanczos(params['order'], phases * taps) - elif (filt[0].lower() == 'bicubic_spline'): + h = lanczos(params["order"], phases * taps) + elif filt[0].lower() == "bicubic_spline": filter = f'"Bicubic-spline (a = {params["a"]:.3f})": [\n' - h = bicubic_spline(params['a'], phases * taps) + h = bicubic_spline(params["a"], phases * taps) else: - print(f'Invalid filter ({filt[0]}) selected!') + print(f"Invalid filter ({filt[0]}) selected!") exit() # Normalise and convert to fixed-point. @@ -98,13 +128,13 @@ def main(): max_index = np.nonzero(ppf[i] == ppf[i].max())[0] ppf[i, max_index] += (1 << precision) - np.int32(ppf[i].sum() / max_index.size) - nl = '\n' + nl = "\n" for i in range(phases): - phase = ', '.join([f'{c:>4}' for c in ppf[i, :]]) - filter += f' {phase}{nl+"]" if i==phases-1 else ","+nl}' + phase = ", ".join([f"{c:>4}" for c in ppf[i, :]]) + filter += f" {phase}{nl + ']' if i == phases - 1 else ',' + nl}" print(filter) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/utils/test_convert.py b/utils/test_convert.py index 85b2260..41e01c0 100644 --- a/utils/test_convert.py +++ b/utils/test_convert.py @@ -11,14 +11,19 @@ import os import subprocess import sys -import tempfile -import shutil -from pathlib import Path import hashlib class ConvertTester: - def __init__(self, convert_binary, output_dir=None, input_dir=None, reference_dir=None, use_gstreamer=False, gst_plugin_path=None): + def __init__( + self, + convert_binary, + output_dir=None, + input_dir=None, + reference_dir=None, + use_gstreamer=False, + gst_plugin_path=None, + ): """Initialize the tester with the path to the convert binary.""" self.convert_binary = convert_binary self.output_dir = output_dir @@ -38,7 +43,7 @@ def __init__(self, convert_binary, output_dir=None, input_dir=None, reference_di "input_format": "4056:3040:4056:YUV420P", "output_format": "4056:3040:12168:RGB888", "reference_file": "ref_4056x3050_12168s_rgb888.rgb", - "skip_gst": False + "skip_gst": False, }, { "input_file": "conv_800x600_1200s_422_yuyv.yuv", @@ -46,7 +51,7 @@ def __init__(self, convert_binary, output_dir=None, input_dir=None, reference_di "input_format": "800:600:1600:YUYV", "output_format": "1600:1200:1600:YUV422P", "reference_file": "ref_1600x1200_1600_422p.yuv", - "skip_gst": False + "skip_gst": False, }, { "input_file": "conv_rgb888_800x600_2432s.rgb", @@ -54,33 +59,33 @@ def __init__(self, convert_binary, output_dir=None, input_dir=None, reference_di "input_format": "800:600:2432:RGB888", "output_format": "4000:3000:4032:YUV444P", "reference_file": "ref_4000x3000_4032s.yuv", - "skip_gst": True + "skip_gst": True, }, # Add more test cases here as needed ] def _parse_format(self, format_str): """Parse format string like '4056:3040:4056:YUV420P' into components.""" - parts = format_str.split(':') + parts = format_str.split(":") if len(parts) != 4: raise ValueError(f"Invalid format string: {format_str}") return { - 'width': int(parts[0]), - 'height': int(parts[1]), - 'stride': int(parts[2]), - 'format': parts[3] + "width": int(parts[0]), + "height": int(parts[1]), + "stride": int(parts[2]), + "format": parts[3], } def _pisp_to_gst_format(self, pisp_format): """Convert PiSP format to GStreamer format string.""" format_map = { - 'YUV420P': 'I420', - 'YVU420P': 'YV12', - 'YUV422P': 'Y42B', - 'YUV444P': 'Y444', - 'YUYV': 'YUY2', - 'UYVY': 'UYVY', - 'RGB888': 'RGB', + "YUV420P": "I420", + "YVU420P": "YV12", + "YUV422P": "Y42B", + "YUV444P": "Y444", + "YUYV": "YUY2", + "UYVY": "UYVY", + "RGB888": "RGB", } return format_map.get(pisp_format, pisp_format) @@ -99,34 +104,50 @@ def run_gstreamer(self, input_file, output_file, input_format, output_format): out_fmt = self._parse_format(output_format) # Convert to GStreamer format names - gst_in_format = self._pisp_to_gst_format(in_fmt['format']) - gst_out_format = self._pisp_to_gst_format(out_fmt['format']) + gst_in_format = self._pisp_to_gst_format(in_fmt["format"]) + gst_out_format = self._pisp_to_gst_format(out_fmt["format"]) + # pispconvert swaps R/B for RGB, use BGR file output to match convert reference + gst_file_format = "BGR" if gst_out_format == "RGB" else gst_out_format # Build GStreamer pipeline pipeline = [ - 'gst-launch-1.0', - 'filesrc', f'location={input_file}', '!', - 'rawvideoparse', - f'width={in_fmt["width"]}', - f'height={in_fmt["height"]}', - f'format={gst_in_format.lower()}', - 'framerate=30/1', '!', - 'pispconvert', '!', - f'video/x-raw,format={gst_out_format},width={out_fmt["width"]},height={out_fmt["height"]}', '!', - 'filesink', f'location={output_file}' + "gst-launch-1.0", + "filesrc", + f"location={input_file}", + "!", + "rawvideoparse", + f"width={in_fmt['width']}", + f"height={in_fmt['height']}", + f"format={gst_in_format.lower()}", + "framerate=30/1", + "!", + "video/x-raw,colorimetry=1:4:0:0", + "!", + "pispconvert", + "!", + f"video/x-raw,format={gst_out_format},width={out_fmt['width']},height={out_fmt['height']},colorimetry=1:4:0:0", + "!", + "videoconvert", + "!", + f"video/x-raw,format={gst_file_format},width={out_fmt['width']},height={out_fmt['height']}", + "!", + "filesink", + f"location={output_file}", ] - print(f"Running GStreamer pipeline:") - print(' '.join(pipeline)) + print("Running GStreamer pipeline:") + print(" ".join(pipeline)) # Set GST_PLUGIN_PATH environment variable if specified env = os.environ.copy() if self.gst_plugin_path: - env['GST_PLUGIN_PATH'] = self.gst_plugin_path + env["GST_PLUGIN_PATH"] = self.gst_plugin_path print(f"GST_PLUGIN_PATH={self.gst_plugin_path}") try: - result = subprocess.run(pipeline, capture_output=True, text=True, check=True, env=env) + subprocess.run( + pipeline, capture_output=True, text=True, check=True, env=env + ) print("GStreamer pipeline completed successfully") return True except subprocess.CalledProcessError as e: @@ -149,14 +170,16 @@ def run_convert(self, input_file, output_file, input_format, output_format): self.convert_binary, input_file, output_file, - "--input-format", input_format, - "--output-format", output_format + "--input-format", + input_format, + "--output-format", + output_format, ] print(f"Running: {' '.join(cmd)}") try: - result = subprocess.run(cmd, capture_output=True, text=True, check=True) + subprocess.run(cmd, capture_output=True, text=True, check=True) print("Convert completed successfully") return True except subprocess.CalledProcessError as e: @@ -203,7 +226,7 @@ def _file_hash(self, filepath): def run_test_case(self, test_case): """Run a single test case.""" - print(f"\n=== Running test case ===") + print("\n=== Running test case ===") print(f"Input file: {test_case['input_file']}") print(f"Output file: {test_case['output_file']}") print(f"Input format: {test_case['input_format']}") @@ -211,49 +234,51 @@ def run_test_case(self, test_case): print(f"Reference file: {test_case['reference_file']}") # Check if input file exists - input_file = test_case['input_file'] + input_file = test_case["input_file"] if self.input_dir: - input_file = os.path.join(self.input_dir, test_case['input_file']) + input_file = os.path.join(self.input_dir, test_case["input_file"]) if not os.path.exists(input_file): print(f"Error: Input file {input_file} does not exist") return False # Skip GStreamer test if marked to skip - if self.use_gstreamer and test_case.get('skip_gst', False): - print(f"SKIPPED: Test case marked as skip_gst=True") + if self.use_gstreamer and test_case.get("skip_gst", False): + print("SKIPPED: Test case marked as skip_gst=True") return None # Return None to indicate skipped # Run the convert utility or GStreamer pipeline if self.use_gstreamer: success = self.run_gstreamer( - test_case['input_file'], - test_case['output_file'], - test_case['input_format'], - test_case['output_format'] + test_case["input_file"], + test_case["output_file"], + test_case["input_format"], + test_case["output_format"], ) else: success = self.run_convert( - test_case['input_file'], - test_case['output_file'], - test_case['input_format'], - test_case['output_format'] + test_case["input_file"], + test_case["output_file"], + test_case["input_format"], + test_case["output_format"], ) if not success: return False # Compare with reference file if it exists - reference_file = test_case['reference_file'] + reference_file = test_case["reference_file"] if self.reference_dir: - reference_file = os.path.join(self.reference_dir, test_case['reference_file']) + reference_file = os.path.join( + self.reference_dir, test_case["reference_file"] + ) if os.path.exists(reference_file): - print(f"Comparing output with reference file...") + print("Comparing output with reference file...") # Use output directory for the generated output file - output_file = test_case['output_file'] + output_file = test_case["output_file"] if self.output_dir: - output_file = os.path.join(self.output_dir, test_case['output_file']) + output_file = os.path.join(self.output_dir, test_case["output_file"]) return self.compare_files(output_file, reference_file) else: print(f"Reference file {reference_file} not found") @@ -293,7 +318,7 @@ def run_all_tests(self): failed += 1 print("✗ Test FAILED") - print(f"\n=== Test Summary ===") + print("\n=== Test Summary ===") print(f"Passed: {passed}") print(f"Failed: {failed}") print(f"Skipped: {skipped}") @@ -303,13 +328,25 @@ def run_all_tests(self): def main(): - parser = argparse.ArgumentParser(description="Test script for libpisp convert utility") - parser.add_argument("convert_binary", nargs='?', default=None, help="Path to the convert binary (not needed with --gst-plugin-path)") + parser = argparse.ArgumentParser( + description="Test script for libpisp convert utility" + ) + parser.add_argument( + "convert_binary", + nargs="?", + default=None, + help="Path to the convert binary (not needed with --gst-plugin-path)", + ) parser.add_argument("--test-dir", help="Directory containing test files") - parser.add_argument("--in", dest="input_dir", help="Directory containing input files") + parser.add_argument( + "--in", dest="input_dir", help="Directory containing input files" + ) parser.add_argument("--out", help="Directory where output files will be written") parser.add_argument("--ref", help="Directory containing reference files") - parser.add_argument("--gst-plugin-path", help="Path to GStreamer plugin directory (enables GStreamer testing)") + parser.add_argument( + "--gst-plugin-path", + help="Path to GStreamer plugin directory (enables GStreamer testing)", + ) args = parser.parse_args() @@ -319,7 +356,9 @@ def main(): # Validate arguments if not use_gstreamer and not args.convert_binary: - parser.error("convert_binary is required unless --gst-plugin-path is specified") + parser.error( + "convert_binary is required unless --gst-plugin-path is specified" + ) tester = ConvertTester( args.convert_binary, @@ -327,7 +366,7 @@ def main(): args.input_dir, args.ref, use_gstreamer=use_gstreamer, - gst_plugin_path=args.gst_plugin_path + gst_plugin_path=args.gst_plugin_path, ) # Change to test directory if specified diff --git a/utils/version.py b/utils/version.py index b017646..5fdec05 100755 --- a/utils/version.py +++ b/utils/version.py @@ -9,8 +9,6 @@ import subprocess import sys import time -from datetime import datetime -from string import hexdigits digits = 12 @@ -19,39 +17,51 @@ def generate_version(): try: if len(sys.argv) == 2: # Check if this is a git directory - r = subprocess.run(['git', 'rev-parse', '--git-dir'], - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, universal_newlines=True) + r = subprocess.run( + ["git", "rev-parse", "--git-dir"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + universal_newlines=True, + ) if r.returncode: - raise RuntimeError('Invalid git directory!') + raise RuntimeError("Invalid git directory!") # Get commit id - r = subprocess.run(['git', 'rev-parse', '--verify', 'HEAD'], - stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, universal_newlines=True) + r = subprocess.run( + ["git", "rev-parse", "--verify", "HEAD"], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + universal_newlines=True, + ) if r.returncode: - raise RuntimeError('Invalid git commit!') + raise RuntimeError("Invalid git commit!") - commit = r.stdout.strip('\n')[0:digits] + commit = r.stdout.strip("\n")[0:digits] # Check dirty status - r = subprocess.run(['git', 'diff-index', '--quiet', 'HEAD'], - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, universal_newlines=True) + r = subprocess.run( + ["git", "diff-index", "--quiet", "HEAD"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + universal_newlines=True, + ) if r.returncode: - commit = commit + '-dirty' + commit = commit + "-dirty" else: - raise RuntimeError('Invalid number of command line arguments') + raise RuntimeError("Invalid number of command line arguments") - commit = f'v{sys.argv[1]} {commit}' + commit = f"v{sys.argv[1]} {commit}" - except RuntimeError as e: - commit = f'v{sys.argv[1]}' + except RuntimeError: + commit = f"v{sys.argv[1]}" finally: date_str = time.strftime( "%d-%m-%Y (%H:%M:%S)", - time.gmtime(int(os.environ.get('SOURCE_DATE_EPOCH', time.time()))) + time.gmtime(int(os.environ.get("SOURCE_DATE_EPOCH", time.time()))), ) - print(f'{commit} {date_str}', end="") + print(f"{commit} {date_str}", end="") if __name__ == "__main__": From 2cb61f2d1cfa3edbd2de8f5a6178715768a0871a Mon Sep 17 00:00:00 2001 From: Naushir Patuck Date: Thu, 2 Apr 2026 10:51:10 +0100 Subject: [PATCH 2/2] ci: Replace pycodestyle with ruff check and ruff format Replace the Pep8Checker with a RuffChecker in checkstyle.py and add a RuffFormatter for Python files. Add ruff check and ruff format steps to the CI style checker workflow. Signed-off-by: Naushir Patuck --- .github/workflows/libpisp-style-checker.yml | 9 + utils/checkstyle.py | 351 +++++++++++++------- 2 files changed, 231 insertions(+), 129 deletions(-) diff --git a/.github/workflows/libpisp-style-checker.yml b/.github/workflows/libpisp-style-checker.yml index 026fb12..07c67f8 100644 --- a/.github/workflows/libpisp-style-checker.yml +++ b/.github/workflows/libpisp-style-checker.yml @@ -17,5 +17,14 @@ jobs: fetch-depth: 0 clean: true + - name: Install ruff + run: pip install ruff + + - name: Ruff check + run: ruff check utils/ + + - name: Ruff format + run: ruff format --check utils/ + - name: Check style run: ${{github.workspace}}/utils/checkstyle.py $(git log --format=%P -1 | awk '{print $1 ".." $2}') diff --git a/utils/checkstyle.py b/utils/checkstyle.py index 04f7215..693df8a 100755 --- a/utils/checkstyle.py +++ b/utils/checkstyle.py @@ -21,8 +21,8 @@ import sys dependencies = { - 'clang-format': True, - 'git': True, + "clang-format": True, + "git": True, } # ------------------------------------------------------------------------------ @@ -52,29 +52,30 @@ class Colours: @staticmethod def fg(colour): if sys.stdout.isatty(): - return '\033[%um' % colour + return "\033[%um" % colour else: - return '' + return "" @staticmethod def bg(colour): if sys.stdout.isatty(): - return '\033[%um' % (colour + 10) + return "\033[%um" % (colour + 10) else: - return '' + return "" @staticmethod def reset(): if sys.stdout.isatty(): - return '\033[0m' + return "\033[0m" else: - return '' + return "" # ------------------------------------------------------------------------------ # Diff parsing, handling and printing # + class DiffHunkSide(object): """A side of a diff hunk, recording line numbers""" @@ -88,8 +89,7 @@ def __len__(self): class DiffHunk(object): - diff_header_regex = re.compile( - r'@@ -([0-9]+),?([0-9]+)? \+([0-9]+),?([0-9]+)? @@') + diff_header_regex = re.compile(r"@@ -([0-9]+),?([0-9]+)? \+([0-9]+),?([0-9]+)? @@") def __init__(self, line): match = DiffHunk.diff_header_regex.match(line) @@ -104,18 +104,21 @@ def __init__(self, line): self.lines = [] def __repr__(self): - s = '%s@@ -%u,%u +%u,%u @@\n' % \ - (Colours.fg(Colours.Cyan), - self.__from.start, len(self.__from), - self.__to.start, len(self.__to)) + s = "%s@@ -%u,%u +%u,%u @@\n" % ( + Colours.fg(Colours.Cyan), + self.__from.start, + len(self.__from), + self.__to.start, + len(self.__to), + ) for line in self.lines: - if line[0] == '-': + if line[0] == "-": s += Colours.fg(Colours.Red) - elif line[0] == '+': + elif line[0] == "+": s += Colours.fg(Colours.Green) - if line[0] == '-': + if line[0] == "-": spaces = 0 for i in range(len(line)): if line[-i - 1].isspace(): @@ -127,24 +130,24 @@ def __repr__(self): s += line s += Colours.reset() - s += '\n' + s += "\n" return s[:-1] def append(self, line): - if line[0] == ' ': + if line[0] == " ": self.__from.untouched.append(self.__from_line) self.__from_line += 1 self.__to.untouched.append(self.__to_line) self.__to_line += 1 - elif line[0] == '-': + elif line[0] == "-": self.__from.touched.append(self.__from_line) self.__from_line += 1 - elif line[0] == '+': + elif line[0] == "+": self.__to.touched.append(self.__to_line) self.__to_line += 1 - self.lines.append(line.rstrip('\n')) + self.lines.append(line.rstrip("\n")) def intersects(self, lines): for line in lines: @@ -153,7 +156,7 @@ def intersects(self, lines): return False def side(self, side): - if side == 'from': + if side == "from": return self.__from else: return self.__to @@ -163,7 +166,7 @@ def parse_diff(diff): hunks = [] hunk = None for line in diff: - if line.startswith('@@'): + if line.startswith("@@"): if hunk: hunks.append(hunk) hunk = DiffHunk(line) @@ -181,13 +184,14 @@ def parse_diff(diff): # Commit, Staged Changes & Amendments # + class CommitFile: def __init__(self, name): info = name.split() self.__status = info[0][0] # For renamed files, store the new name - if self.__status == 'R': + if self.__status == "R": self.__filename = info[2] else: self.__filename = info[1] @@ -208,14 +212,15 @@ def __init__(self, commit): def _parse(self): # Get the commit title and list of files. - ret = subprocess.run(['git', 'show', '--pretty=oneline', '--name-status', - self.commit], - stdout=subprocess.PIPE).stdout.decode('utf-8') + ret = subprocess.run( + ["git", "show", "--pretty=oneline", "--name-status", self.commit], + stdout=subprocess.PIPE, + ).stdout.decode("utf-8") files = ret.splitlines() self._files = [CommitFile(f) for f in files[1:]] self._title = files[0] - def files(self, filter='AMR'): + def files(self, filter="AMR"): return [f.filename for f in self._files if f.status in filter] @property @@ -223,30 +228,40 @@ def title(self): return self._title def get_diff(self, top_level, filename): - diff = subprocess.run(['git', 'diff', '%s~..%s' % (self.commit, self.commit), - '--', '%s/%s' % (top_level, filename)], - stdout=subprocess.PIPE).stdout.decode('utf-8') + diff = subprocess.run( + [ + "git", + "diff", + "%s~..%s" % (self.commit, self.commit), + "--", + "%s/%s" % (top_level, filename), + ], + stdout=subprocess.PIPE, + ).stdout.decode("utf-8") return parse_diff(diff.splitlines(True)) def get_file(self, filename): - return subprocess.run(['git', 'show', '%s:%s' % (self.commit, filename)], - stdout=subprocess.PIPE).stdout.decode('utf-8') + return subprocess.run( + ["git", "show", "%s:%s" % (self.commit, filename)], stdout=subprocess.PIPE + ).stdout.decode("utf-8") class StagedChanges(Commit): def __init__(self): - Commit.__init__(self, '') + Commit.__init__(self, "") def _parse(self): - ret = subprocess.run(['git', 'diff', '--staged', '--name-status'], - stdout=subprocess.PIPE).stdout.decode('utf-8') + ret = subprocess.run( + ["git", "diff", "--staged", "--name-status"], stdout=subprocess.PIPE + ).stdout.decode("utf-8") self._title = "Staged changes" self._files = [CommitFile(f) for f in ret.splitlines()] def get_diff(self, top_level, filename): - diff = subprocess.run(['git', 'diff', '--staged', '--', - '%s/%s' % (top_level, filename)], - stdout=subprocess.PIPE).stdout.decode('utf-8') + diff = subprocess.run( + ["git", "diff", "--staged", "--", "%s/%s" % (top_level, filename)], + stdout=subprocess.PIPE, + ).stdout.decode("utf-8") return parse_diff(diff.splitlines(True)) @@ -256,18 +271,22 @@ def __init__(self): def _parse(self): # Create a title using HEAD commit - ret = subprocess.run(['git', 'show', '--pretty=oneline', '--no-patch'], - stdout=subprocess.PIPE).stdout.decode('utf-8') - self._title = 'Amendment of ' + ret.strip() + ret = subprocess.run( + ["git", "show", "--pretty=oneline", "--no-patch"], stdout=subprocess.PIPE + ).stdout.decode("utf-8") + self._title = "Amendment of " + ret.strip() # Extract the list of modified files - ret = subprocess.run(['git', 'diff', '--staged', '--name-status', 'HEAD~'], - stdout=subprocess.PIPE).stdout.decode('utf-8') + ret = subprocess.run( + ["git", "diff", "--staged", "--name-status", "HEAD~"], + stdout=subprocess.PIPE, + ).stdout.decode("utf-8") self._files = [CommitFile(f) for f in ret.splitlines()] def get_diff(self, top_level, filename): - diff = subprocess.run(['git', 'diff', '--staged', 'HEAD~', '--', - '%s/%s' % (top_level, filename)], - stdout=subprocess.PIPE).stdout.decode('utf-8') + diff = subprocess.run( + ["git", "diff", "--staged", "HEAD~", "--", "%s/%s" % (top_level, filename)], + stdout=subprocess.PIPE, + ).stdout.decode("utf-8") return parse_diff(diff.splitlines(True)) @@ -275,6 +294,7 @@ def get_diff(self, top_level, filename): # Helpers # + class ClassRegistry(type): def __new__(cls, clsname, bases, attrs): newclass = super().__new__(cls, clsname, bases, attrs) @@ -287,6 +307,7 @@ def __new__(cls, clsname, bases, attrs): # Commit Checkers # + class CommitChecker(metaclass=ClassRegistry): subclasses = [] @@ -311,6 +332,7 @@ def __init__(self, msg): # Style Checkers # + class StyleChecker(metaclass=ClassRegistry): subclasses = [] @@ -350,13 +372,31 @@ def __init__(self, line_number, line, msg): class IncludeChecker(StyleChecker): - patterns = ('*.cpp', '*.h', '*.hpp') - - headers = ('assert', 'ctype', 'errno', 'fenv', 'float', 'inttypes', - 'limits', 'locale', 'setjmp', 'signal', 'stdarg', 'stddef', - 'stdint', 'stdio', 'stdlib', 'string', 'time', 'uchar', 'wchar', - 'wctype') - include_regex = re.compile('^#include ') + patterns = ("*.cpp", "*.h", "*.hpp") + + headers = ( + "assert", + "ctype", + "errno", + "fenv", + "float", + "inttypes", + "limits", + "locale", + "setjmp", + "signal", + "stdarg", + "stddef", + "stdint", + "stdio", + "stdlib", + "string", + "time", + "uchar", + "wchar", + "wctype", + ) + include_regex = re.compile("^#include ") def __init__(self, content): super().__init__() @@ -375,15 +415,20 @@ def check(self, line_numbers): if header not in IncludeChecker.headers: continue - issues.append(StyleIssue(line_number, line, - 'C compatibility header <%s.h> is preferred' % header)) + issues.append( + StyleIssue( + line_number, + line, + "C compatibility header <%s.h> is preferred" % header, + ) + ) return issues -class Pep8Checker(StyleChecker): - patterns = ('*.py',) - results_regex = re.compile('stdin:([0-9]+):([0-9]+)(.*)') +class RuffChecker(StyleChecker): + patterns = ("*.py",) + results_regex = re.compile(r"^.+:(\d+):(\d+): (.+)$") def __init__(self, content): super().__init__() @@ -391,21 +436,27 @@ def __init__(self, content): def check(self, line_numbers): issues = [] - data = ''.join(self.__content).encode('utf-8') + data = "".join(self.__content).encode("utf-8") try: - ret = subprocess.run(['pycodestyle', '--ignore=E501', '-'], - input=data, stdout=subprocess.PIPE) + ret = subprocess.run( + ["ruff", "check", "--stdin-filename=stdin.py", "-"], + input=data, + stdout=subprocess.PIPE, + ) except FileNotFoundError: - issues.append(StyleIssue( - 0, None, "Please install pycodestyle to validate python additions")) + issues.append( + StyleIssue(0, None, "Please install ruff to validate python additions") + ) return issues - results = ret.stdout.decode('utf-8').splitlines() + results = ret.stdout.decode("utf-8").splitlines() for item in results: - search = re.search(Pep8Checker.results_regex, item) + search = re.search(RuffChecker.results_regex, item) + if not search: + continue + line_number = int(search.group(1)) - position = int(search.group(2)) msg = search.group(3) if line_number in line_numbers: @@ -416,8 +467,8 @@ def check(self, line_numbers): class ShellChecker(StyleChecker): - patterns = ('*.sh',) - results_line_regex = re.compile('In - line ([0-9]+):') + patterns = ("*.sh",) + results_line_regex = re.compile("In - line ([0-9]+):") def __init__(self, content): super().__init__() @@ -425,17 +476,23 @@ def __init__(self, content): def check(self, line_numbers): issues = [] - data = ''.join(self.__content).encode('utf-8') + data = "".join(self.__content).encode("utf-8") try: - ret = subprocess.run(['shellcheck', '-Cnever', '-'], - input=data, stdout=subprocess.PIPE) + ret = subprocess.run( + ["shellcheck", "-Cnever", "-"], input=data, stdout=subprocess.PIPE + ) except FileNotFoundError: - issues.append(StyleIssue( - 0, None, "Please install shellcheck to validate shell script additions")) + issues.append( + StyleIssue( + 0, + None, + "Please install shellcheck to validate shell script additions", + ) + ) return issues - results = ret.stdout.decode('utf-8').splitlines() + results = ret.stdout.decode("utf-8").splitlines() for nr, item in enumerate(results): search = re.search(ShellChecker.results_line_regex, item) if search is None: @@ -445,9 +502,6 @@ def check(self, line_numbers): line = results[nr + 1] msg = results[nr + 2] - # Determined, but not yet used - position = msg.find('^') + 1 - if line_number in line_numbers: issues.append(StyleIssue(line_number, line, msg)) @@ -458,6 +512,7 @@ def check(self, line_numbers): # Formatters # + class Formatter(metaclass=ClassRegistry): subclasses = [] @@ -490,18 +545,20 @@ def all_patterns(cls): class CLangFormatter(Formatter): - patterns = ('*.c', '*.cpp', '*.h', '*.hpp') + patterns = ("*.c", "*.cpp", "*.h", "*.hpp") @classmethod def format(cls, filename, data): - ret = subprocess.run(['clang-format', '-style=file', - '-assume-filename=' + filename], - input=data.encode('utf-8'), stdout=subprocess.PIPE) - return ret.stdout.decode('utf-8') + ret = subprocess.run( + ["clang-format", "-style=file", "-assume-filename=" + filename], + input=data.encode("utf-8"), + stdout=subprocess.PIPE, + ) + return ret.stdout.decode("utf-8") class IncludeOrderFormatter(Formatter): - patterns = ('*.cpp', '*.h', '*.hpp') + patterns = ("*.cpp", "*.h", "*.hpp") include_regex = re.compile('^#include ["<]([^">]*)[">]') @@ -512,7 +569,7 @@ def format(cls, filename, data): # Parse blocks of #include statements, and output them as a sorted list # when we reach a non #include statement. - for line in data.split('\n'): + for line in data.split("\n"): match = IncludeOrderFormatter.include_regex.match(line) if match: # If the current line is an #include statement, add it to the @@ -538,31 +595,48 @@ def format(cls, filename, data): lines.append(include[0]) includes = [] - return '\n'.join(lines) + return "\n".join(lines) class StripTrailingSpaceFormatter(Formatter): - patterns = ('*.c', '*.cpp', '*.h', '*.hpp', '*.py', 'CMakelists.txt') + patterns = ("*.c", "*.cpp", "*.h", "*.hpp", "CMakelists.txt") @classmethod def format(cls, filename, data): - lines = data.split('\n') + lines = data.split("\n") for i in range(len(lines)): - lines[i] = lines[i].rstrip() + '\n' - return ''.join(lines) + lines[i] = lines[i].rstrip() + "\n" + return "".join(lines) + + +class RuffFormatter(Formatter): + patterns = ("*.py",) + + @classmethod + def format(cls, filename, data): + try: + ret = subprocess.run( + ["ruff", "format", "--stdin-filename=" + filename, "-"], + input=data.encode("utf-8"), + stdout=subprocess.PIPE, + ) + return ret.stdout.decode("utf-8") + except FileNotFoundError: + return data # ------------------------------------------------------------------------------ # Style checking # + def check_file(top_level, commit, filename): # Extract the line numbers touched by the commit. commit_diff = commit.get_diff(top_level, filename) lines = [] for hunk in commit_diff: - lines.extend(hunk.side('to').touched) + lines.extend(hunk.side("to").touched) # Skip commits that don't add any line. if len(lines) == 0: @@ -583,22 +657,21 @@ def check_file(top_level, commit, filename): # Split the diff in hunks, recording line number ranges for each hunk, and # filter out hunks that are not touched by the commit. formatted_diff = parse_diff(diff) - formatted_diff = [ - hunk for hunk in formatted_diff if hunk.intersects(lines)] + formatted_diff = [hunk for hunk in formatted_diff if hunk.intersects(lines)] # Check for code issues not related to formatting. issues = [] for checker in StyleChecker.checkers(filename): checker = checker(after) for hunk in commit_diff: - issues += checker.check(hunk.side('to').touched) + issues += checker.check(hunk.side("to").touched) # Print the detected issues. if len(issues) == 0 and len(formatted_diff) == 0: return 0 - print('%s---' % Colours.fg(Colours.Red), filename) - print('%s+++' % Colours.fg(Colours.Green), filename) + print("%s---" % Colours.fg(Colours.Red), filename) + print("%s+++" % Colours.fg(Colours.Green), filename) if len(formatted_diff): for hunk in formatted_diff: @@ -607,16 +680,17 @@ def check_file(top_level, commit, filename): if len(issues): issues = sorted(issues, key=lambda i: i.line_number) for issue in issues: - print('%s#%u: %s' % - (Colours.fg(Colours.Yellow), issue.line_number, issue.msg)) + print( + "%s#%u: %s" % (Colours.fg(Colours.Yellow), issue.line_number, issue.msg) + ) if issue.line is not None: - print('+%s%s' % (issue.line.rstrip(), Colours.reset())) + print("+%s%s" % (issue.line.rstrip(), Colours.reset())) return len(formatted_diff) + len(issues) def check_style(top_level, commit): - separator = '-' * len(commit.title) + separator = "-" * len(commit.title) print(separator) print(commit.title) print(separator) @@ -626,16 +700,18 @@ def check_style(top_level, commit): # Apply the commit checkers first. for checker in CommitChecker.checkers(): for issue in checker.check(commit, top_level): - print('%s%s%s' % - (Colours.fg(Colours.Yellow), issue.msg, Colours.reset())) + print("%s%s%s" % (Colours.fg(Colours.Yellow), issue.msg, Colours.reset())) issues += 1 # Filter out files we have no checker for. patterns = set() patterns.update(StyleChecker.all_patterns()) patterns.update(Formatter.all_patterns()) - files = [f for f in commit.files() if len( - [p for p in patterns if fnmatch.fnmatch(os.path.basename(f), p)])] + files = [ + f + for f in commit.files() + if len([p for p in patterns if fnmatch.fnmatch(os.path.basename(f), p)]) + ] for f in files: issues += check_file(top_level, commit, f) @@ -643,9 +719,11 @@ def check_style(top_level, commit): if issues == 0: print("No issue detected") else: - print('---') - print("%u potential %s detected, please review" % - (issues, 'issue' if issues == 1 else 'issues')) + print("---") + print( + "%u potential %s detected, please review" + % (issues, "issue" if issues == 1 else "issues") + ) return issues @@ -654,20 +732,20 @@ def extract_commits(revs): """Extract a list of commits on which to operate from a revision or revision range. """ - ret = subprocess.run(['git', 'rev-parse', revs], stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + ret = subprocess.run( + ["git", "rev-parse", revs], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) if ret.returncode != 0: - print(ret.stderr.decode('utf-8').splitlines()[0]) + print(ret.stderr.decode("utf-8").splitlines()[0]) return [] - revlist = ret.stdout.decode('utf-8').splitlines() + revlist = ret.stdout.decode("utf-8").splitlines() # If the revlist contains more than one item, pass it to git rev-list to list # each commit individually. if len(revlist) > 1: - ret = subprocess.run(['git', 'rev-list', *revlist], - stdout=subprocess.PIPE) - revlist = ret.stdout.decode('utf-8').splitlines() + ret = subprocess.run(["git", "rev-list", *revlist], stdout=subprocess.PIPE) + revlist = ret.stdout.decode("utf-8").splitlines() revlist.reverse() return [Commit(x) for x in revlist] @@ -675,26 +753,41 @@ def extract_commits(revs): def git_top_level(): """Get the absolute path of the git top-level directory.""" - ret = subprocess.run(['git', 'rev-parse', '--show-toplevel'], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + ret = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) if ret.returncode != 0: - print(ret.stderr.decode('utf-8').splitlines()[0]) + print(ret.stderr.decode("utf-8").splitlines()[0]) return None - return ret.stdout.decode('utf-8').strip() + return ret.stdout.decode("utf-8").strip() def main(argv): # Parse command line arguments parser = argparse.ArgumentParser() - parser.add_argument('--staged', '-s', action='store_true', - help='Include the changes in the index. Defaults to False') - parser.add_argument('--amend', '-a', action='store_true', - help='Include changes in the index and the previous patch combined. Defaults to False') - parser.add_argument('revision_range', type=str, default=None, nargs='?', - help='Revision range (as defined by git rev-parse). Defaults to HEAD if not specified.') + parser.add_argument( + "--staged", + "-s", + action="store_true", + help="Include the changes in the index. Defaults to False", + ) + parser.add_argument( + "--amend", + "-a", + action="store_true", + help="Include changes in the index and the previous patch combined. Defaults to False", + ) + parser.add_argument( + "revision_range", + type=str, + default=None, + nargs="?", + help="Revision range (as defined by git rev-parse). Defaults to HEAD if not specified.", + ) args = parser.parse_args(argv[1:]) # Check for required dependencies. @@ -723,7 +816,7 @@ def main(argv): if len(commits) == 0: # And no revisions were passed, then default to HEAD if not args.revision_range: - args.revision_range = 'HEAD' + args.revision_range = "HEAD" if args.revision_range: commits += extract_commits(args.revision_range) @@ -731,7 +824,7 @@ def main(argv): issues = 0 for commit in commits: issues += check_style(top_level, commit) - print('') + print("") if issues: return 1 @@ -739,5 +832,5 @@ def main(argv): return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv))