From fb0dd85959164360ee8243c64c5c213f3936e1c2 Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Mon, 6 Apr 2026 17:09:52 -0400 Subject: [PATCH] restore full main.py test coverage * move useful frameworks out of test_main_regression.py to test/utils.py * add tests to test_main.py covering the missing paths --- test/pytests/test_main.py | 176 ++++++++++++++++++++++++++- test/pytests/test_main_regression.py | 162 ++---------------------- test/utils.py | 157 ++++++++++++++++++++++++ 3 files changed, 340 insertions(+), 155 deletions(-) diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 6f80b0f4..d6ee27d0 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -8,11 +8,15 @@ import shutil from tempfile import NamedTemporaryFile from textwrap import dedent +from types import SimpleNamespace +from typing import Any, cast import click from click.testing import CliRunner from pymysql.err import OperationalError +import pytest +from mycli import main from mycli.constants import ( DEFAULT_DATABASE, DEFAULT_HOST, @@ -26,7 +30,20 @@ from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.packages.sqlresult import SQLResult from mycli.sqlexecute import ServerInfo, SQLExecute -from test.utils import DATABASE, HOST, PASSWORD, PORT, TEMPFILE_PREFIX, USER, dbtest, run +from test.utils import ( + DATABASE, + HOST, + PASSWORD, + PORT, + TEMPFILE_PREFIX, + USER, + ReusableLock, + call_click_entrypoint_direct, + dbtest, + make_bare_mycli, + make_dummy_mycli_class, + run, +) pytests_dir = os.path.abspath(os.path.dirname(__file__)) project_root_dir = os.path.abspath(os.path.join(pytests_dir, '..', '..')) @@ -2150,3 +2167,160 @@ def test_null_string_config(monkeypatch): os.remove(myclirc.name) except Exception as e: print(f'An error occurred while attempting to delete the file: {e}') + + +def test_change_prompt_format_requires_argument() -> None: + cli = make_bare_mycli() + assert main.MyCli.change_prompt_format(cli, '')[0].status == 'Missing required argument, format.' + + +def test_change_prompt_format_updates_prompt() -> None: + cli = make_bare_mycli() + assert main.MyCli.change_prompt_format(cli, '\\u@\\h> ')[0].status == 'Changed prompt format to \\u@\\h> ' + + +def test_output_timing_logs_and_prints_with_warning_style(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + timings_logged: list[str] = [] + cli.log_output = lambda text: timings_logged.append(text) # type: ignore[assignment] + printed: list[tuple[Any, Any]] = [] + monkeypatch.setattr(main, 'print_formatted_text', lambda text, style=None: printed.append((text, style))) + main.MyCli.output_timing(cli, 'Time: 1.000s', is_warnings_style=True) + assert timings_logged == ['Time: 1.000s'] + assert printed[-1][1] == cli.ptoolkit_style + + +def test_run_cli_delegates_to_main_repl(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + run_cli_calls: list[Any] = [] + monkeypatch.setattr(main, 'main_repl', lambda target: run_cli_calls.append(target)) + main.MyCli.run_cli(cli) + assert run_cli_calls == [cli] + + +def test_get_output_margin_uses_prompt_session_render_counter(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + render_counters: list[int] = [] + cli.prompt_lines = 0 + cli.get_reserved_space = lambda: 2 # type: ignore[assignment] + cli.prompt_session = cast( + Any, + SimpleNamespace(app=SimpleNamespace(render_counter=7)), + ) + + def fake_get_prompt(mycli: Any, string: str, render_counter: int) -> str: + render_counters.append(render_counter) + return 'line1\nline2' + + monkeypatch.setattr(main, 'get_prompt', fake_get_prompt) + monkeypatch.setattr(main.special, 'is_timing_enabled', lambda: False) + assert main.MyCli.get_output_margin(cli, 'ok') == 5 + assert render_counters == [7] + + +def test_on_completions_refreshed_updates_completer_and_invalidates_prompt() -> None: + cli = make_bare_mycli() + entered_lock = {'count': 0} + invalidated: list[bool] = [] + cli._completer_lock = cast(Any, ReusableLock(lambda: entered_lock.__setitem__('count', entered_lock['count'] + 1))) + cli.prompt_session = cast(Any, SimpleNamespace(app=SimpleNamespace(invalidate=lambda: invalidated.append(True)))) + new_completer = cast(Any, SimpleNamespace(get_completions=lambda document, event: ['done'])) + main.MyCli._on_completions_refreshed(cli, new_completer) + assert cli.completer is new_completer + assert invalidated == [True] + assert entered_lock['count'] == 1 + + +def test_get_completions_uses_current_completer() -> None: + cli = make_bare_mycli() + entered_lock = {'count': 0} + cli._completer_lock = cast(Any, ReusableLock(lambda: entered_lock.__setitem__('count', entered_lock['count'] + 1))) + cli.completer = cast(Any, SimpleNamespace(get_completions=lambda document, event: ['done'])) + assert list(main.MyCli.get_completions(cli, 'select', 6)) == ['done'] + assert entered_lock['count'] == 1 + + +def test_click_entrypoint_callback_covers_dsn_list_init_commands(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {'prod': 'mysql://u:p@h/db'}, + 'alias_dsn.init-commands': {'prod': ['set a=1', 'set b=2']}, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True) + + cli_args = main.CliArgs() + cli_args.dsn = 'prod' + cli_args.init_command = 'set c=3' + call_click_entrypoint_direct(cli_args) + + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.connect_calls[-1]['init_command'] == 'set a=1; set b=2; set c=3' + + +def test_click_entrypoint_callback_uses_batch_with_progress_path(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True) + monkeypatch.setattr(main, 'main_batch_with_progress_bar', lambda mycli, cli_args: 12) + + cli_args = main.CliArgs() + cli_args.batch = 'queries.sql' + cli_args.progress = True + with pytest.raises(SystemExit) as excinfo: + call_click_entrypoint_direct(cli_args) + assert excinfo.value.code == 12 + + +def test_click_entrypoint_callback_uses_batch_without_progress_path(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True) + monkeypatch.setattr(main, 'main_batch_without_progress_bar', lambda mycli, cli_args: 13) + + cli_args = main.CliArgs() + cli_args.batch = 'queries.sql' + cli_args.progress = False + with pytest.raises(SystemExit) as excinfo: + call_click_entrypoint_direct(cli_args) + assert excinfo.value.code == 13 + + +def test_click_entrypoint_callback_covers_mycnf_underscore_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + click_lines: list[str] = [] + monkeypatch.setattr(click, 'secho', lambda message='', **kwargs: click_lines.append(str(message))) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) + + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'false'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + }, + my_cnf={'client': {'ssl_ca': '/tmp/ca.pem'}, 'mysqld': {}}, + config_without_package_defaults={'main': {}}, + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + + call_click_entrypoint_direct(main.CliArgs()) + assert any('ssl-ca = /tmp/ca.pem' in line for line in click_lines) diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index c813530c..5946a58a 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -23,7 +23,7 @@ from pathlib import Path import sys from types import ModuleType, SimpleNamespace -from typing import Any, Callable, Literal, cast +from typing import Any, cast import click from click.testing import CliRunner @@ -34,42 +34,13 @@ from mycli import main import mycli.key_bindings from mycli.packages.sqlresult import SQLResult - - -class DummyLogger: - def __init__(self) -> None: - self.debug_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] - self.error_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] - self.warning_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] - - def debug(self, *args: Any, **kwargs: Any) -> None: - self.debug_calls.append((args, kwargs)) - - def error(self, *args: Any, **kwargs: Any) -> None: - self.error_calls.append((args, kwargs)) - - def warning(self, *args: Any, **kwargs: Any) -> None: - self.warning_calls.append((args, kwargs)) - - -class DummyFormatter: - def __init__(self, format_name: str = 'ascii') -> None: - self.format_name = format_name - self.query = '' - self.supported_formats = ['ascii', 'csv', 'tsv', 'vertical'] - self._output_formats = { - 'ascii': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), - 'csv': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), - 'tsv': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), - 'vertical': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), - } - self.calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] - - def format_output(self, rows: Any, header: Any, format_name: str | None = None, **kwargs: Any) -> list[str] | str: - self.calls.append(((rows, header, format_name), kwargs)) - if format_name == 'vertical': - return ['vertical output'] - return ['plain output'] +from test.utils import ( # type: ignore[attr-defined] + DummyFormatter, + DummyLogger, + call_click_entrypoint_direct, + make_bare_mycli, + make_dummy_mycli_class, +) class FakeCursorBase: @@ -100,19 +71,6 @@ def ping(self, reconnect: bool = False) -> None: raise self.ping_exc -class ReusableLock: - def __init__(self, on_enter: Callable[[], Any] | None = None) -> None: - self.on_enter = on_enter - - def __enter__(self) -> 'ReusableLock': - if self.on_enter is not None: - self.on_enter() - return self - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: - return False - - class BoolSection(dict[str, Any]): def as_bool(self, key: str) -> bool: return str(self[key]).lower() == 'true' @@ -154,63 +112,6 @@ def __int__(self) -> int: raise ValueError('bad int') -def make_bare_mycli() -> Any: - cli = object.__new__(main.MyCli) - cli.logger = cast(Any, DummyLogger()) - cli.main_formatter = DummyFormatter() - cli.redirect_formatter = DummyFormatter() - cli.helpers_style = 'helpers-style' - cli.helpers_warnings_style = 'helpers-warnings-style' - cli.ptoolkit_style = cast(Any, 'pt-style') - cli.syntax_style = 'native' - cli.cli_style = {} - cli.null_string = '' - cli.numeric_alignment = 'right' - cli.binary_display = None - cli.show_warnings = False - cli.query_history = [] - cli.toolbar_error_message = None - cli.prompt_session = None - cli.last_prompt_message = main.ANSI('') - cli.last_custom_toolbar_message = main.ANSI('') - cli.prompt_lines = 0 - cli.prompt_format = main.MyCli.default_prompt - cli.multiline_continuation_char = '>' - cli.toolbar_format = 'default' - cli.destructive_warning = False - cli.destructive_keywords = ['drop'] - cli.keepalive_ticks = None - cli._keepalive_counter = 0 - cli.less_chatty = True - cli.smart_completion = False - cli.key_bindings = 'emacs' - cli.auto_vertical_output = False - cli.wider_completion_menu = False - cli.explicit_pager = False - cli._completer_lock = cast(Any, ReusableLock()) - cli.terminal_tab_title_format = '' - cli.terminal_window_title_format = '' - cli.multiplex_window_title_format = '' - cli.multiplex_pane_title_format = '' - cli.dsn_alias = None - cli.login_path = None - cli.login_path_as_host = False - cli.post_redirect_command = None - cli.logfile = None - cli.emacs_ttimeoutlen = 1.0 - cli.vi_ttimeoutlen = 1.0 - cli.beep_after_seconds = 0.0 - cli.config = {'history_file': '~/.mycli-history-testing'} - cli.output = lambda *args, **kwargs: None # type: ignore[assignment] - cli.echo = lambda *args, **kwargs: None # type: ignore[assignment] - cli.log_query = lambda *args, **kwargs: None # type: ignore[assignment] - cli.log_output = lambda *args, **kwargs: None # type: ignore[assignment] - cli.configure_pager = lambda: None # type: ignore[assignment] - cli.refresh_completions = lambda reset=False: [SQLResult(status='refresh')] # type: ignore[assignment] - cli.reconnect = lambda database='': False # type: ignore[assignment] - return cli - - def load_main_variant(monkeypatch: pytest.MonkeyPatch, *, fail_pwd: bool = False) -> ModuleType: import builtins @@ -232,53 +133,6 @@ def fake_import(name: str, globals: Any = None, locals: Any = None, fromlist: An return module -def make_dummy_mycli_class( - *, - config: dict[str, Any] | None = None, - my_cnf: dict[str, Any] | None = None, - config_without_package_defaults: dict[str, Any] | None = None, -) -> Any: - class DummyMyCli: - last_instance: Any = None - - def __init__(self, **kwargs: Any) -> None: - type(self).last_instance = self - self.init_kwargs = dict(kwargs) - self.config = config or {'main': {}, 'alias_dsn': {}} - self.my_cnf = my_cnf or {'client': {}, 'mysqld': {}} - self.config_without_package_defaults = config_without_package_defaults or {} - self.default_keepalive_ticks = 5 - self.ssl_mode = None - self.logger = DummyLogger() - self.main_formatter = SimpleNamespace(format_name=None) - self.destructive_warning = False - self.destructive_keywords = ['drop'] - self.dsn_alias = None - self.connect_calls: list[dict[str, Any]] = [] - self.run_query_calls: list[tuple[str, Any, bool]] = [] - self.run_cli_called = False - self.close_called = False - - def connect(self, **kwargs: Any) -> None: - self.connect_calls.append(dict(kwargs)) - - def run_query(self, query: str, checkpoint: Any = None, new_line: bool = True) -> None: - self.run_query_calls.append((query, checkpoint, new_line)) - - def run_cli(self) -> None: - self.run_cli_called = True - - def close(self) -> None: - self.close_called = True - - return DummyMyCli - - -def call_click_entrypoint_direct(cli_args: main.CliArgs) -> None: - assert main.click_entrypoint.callback is not None - cast(Any, main.click_entrypoint.callback).__wrapped__(cli_args) - - def test_import_fallbacks_for_pwd(monkeypatch: pytest.MonkeyPatch) -> None: module = load_main_variant(monkeypatch, fail_pwd=True) diff --git a/test/utils.py b/test/utils.py index 7d278f4c..427fc117 100644 --- a/test/utils.py +++ b/test/utils.py @@ -5,10 +5,13 @@ import platform import signal import time +from types import SimpleNamespace +from typing import Any, Callable, Literal, cast import pymysql import pytest +from mycli import main from mycli.constants import ( DEFAULT_CHARSET, DEFAULT_HOST, @@ -17,6 +20,7 @@ TEST_DATABASE, ) from mycli.main import special +from mycli.packages.sqlresult import SQLResult DATABASE = TEST_DATABASE PASSWORD = os.getenv("PYTEST_PASSWORD") @@ -30,6 +34,159 @@ TEMPFILE_PREFIX = 'mycli_test_suite_' +class DummyLogger: + def __init__(self) -> None: + self.debug_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + self.error_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + self.warning_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def debug(self, *args: Any, **kwargs: Any) -> None: + self.debug_calls.append((args, kwargs)) + + def error(self, *args: Any, **kwargs: Any) -> None: + self.error_calls.append((args, kwargs)) + + def warning(self, *args: Any, **kwargs: Any) -> None: + self.warning_calls.append((args, kwargs)) + + +class DummyFormatter: + def __init__(self, format_name: str = 'ascii') -> None: + self.format_name = format_name + self.query = '' + self.supported_formats = ['ascii', 'csv', 'tsv', 'vertical'] + self._output_formats = { + 'ascii': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), + 'csv': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), + 'tsv': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), + 'vertical': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), + } + self.calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def format_output(self, rows: Any, header: Any, format_name: str | None = None, **kwargs: Any) -> list[str] | str: + self.calls.append(((rows, header, format_name), kwargs)) + if format_name == 'vertical': + return ['vertical output'] + return ['plain output'] + + +class ReusableLock: + def __init__(self, on_enter: Callable[[], Any] | None = None) -> None: + self.on_enter = on_enter + + def __enter__(self) -> 'ReusableLock': + if self.on_enter is not None: + self.on_enter() + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + return False + + +def make_bare_mycli() -> Any: + cli = object.__new__(main.MyCli) + cli.logger = cast(Any, DummyLogger()) + cli.main_formatter = DummyFormatter() + cli.redirect_formatter = DummyFormatter() + cli.helpers_style = 'helpers-style' + cli.helpers_warnings_style = 'helpers-warnings-style' + cli.ptoolkit_style = cast(Any, 'pt-style') + cli.syntax_style = 'native' + cli.cli_style = {} + cli.null_string = '' + cli.numeric_alignment = 'right' + cli.binary_display = None + cli.show_warnings = False + cli.query_history = [] + cli.toolbar_error_message = None + cli.prompt_session = None + cli.last_prompt_message = main.ANSI('') + cli.last_custom_toolbar_message = main.ANSI('') + cli.prompt_lines = 0 + cli.prompt_format = main.MyCli.default_prompt + cli.multiline_continuation_char = '>' + cli.toolbar_format = 'default' + cli.destructive_warning = False + cli.destructive_keywords = ['drop'] + cli.keepalive_ticks = None + cli._keepalive_counter = 0 + cli.less_chatty = True + cli.smart_completion = False + cli.key_bindings = 'emacs' + cli.auto_vertical_output = False + cli.wider_completion_menu = False + cli.explicit_pager = False + cli._completer_lock = cast(Any, ReusableLock()) + cli.terminal_tab_title_format = '' + cli.terminal_window_title_format = '' + cli.multiplex_window_title_format = '' + cli.multiplex_pane_title_format = '' + cli.dsn_alias = None + cli.login_path = None + cli.login_path_as_host = False + cli.post_redirect_command = None + cli.logfile = None + cli.emacs_ttimeoutlen = 1.0 + cli.vi_ttimeoutlen = 1.0 + cli.beep_after_seconds = 0.0 + cli.config = {'history_file': '~/.mycli-history-testing'} + cli.output = lambda *args, **kwargs: None # type: ignore[assignment] + cli.echo = lambda *args, **kwargs: None # type: ignore[assignment] + cli.log_query = lambda *args, **kwargs: None # type: ignore[assignment] + cli.log_output = lambda *args, **kwargs: None # type: ignore[assignment] + cli.configure_pager = lambda: None # type: ignore[assignment] + cli.refresh_completions = lambda reset=False: [SQLResult(status='refresh')] # type: ignore[assignment] + cli.reconnect = lambda database='': False # type: ignore[assignment] + return cli + + +def make_dummy_mycli_class( + *, + config: dict[str, Any] | None = None, + my_cnf: dict[str, Any] | None = None, + config_without_package_defaults: dict[str, Any] | None = None, +) -> Any: + class DummyMyCli: + last_instance: Any = None + + def __init__(self, **kwargs: Any) -> None: + type(self).last_instance = self + self.init_kwargs = dict(kwargs) + self.config = config or {'main': {}, 'alias_dsn': {}} + self.my_cnf = my_cnf or {'client': {}, 'mysqld': {}} + self.config_without_package_defaults = config_without_package_defaults or {} + self.default_keepalive_ticks = 5 + self.ssl_mode = None + self.logger = DummyLogger() + self.main_formatter = SimpleNamespace(format_name=None) + self.destructive_warning = False + self.destructive_keywords = ['drop'] + self.dsn_alias = None + self.connect_calls: list[dict[str, Any]] = [] + self.run_query_calls: list[tuple[str, Any, bool]] = [] + self.run_cli_called = False + self.close_called = False + + def connect(self, **kwargs: Any) -> None: + self.connect_calls.append(dict(kwargs)) + + def run_query(self, query: str, checkpoint: Any = None, new_line: bool = True) -> None: + self.run_query_calls.append((query, checkpoint, new_line)) + + def run_cli(self) -> None: + self.run_cli_called = True + + def close(self) -> None: + self.close_called = True + + return DummyMyCli + + +def call_click_entrypoint_direct(cli_args: main.CliArgs) -> None: + assert main.click_entrypoint.callback is not None + cast(Any, main.click_entrypoint.callback).__wrapped__(cli_args) + + def db_connection(dbname=None): conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, charset=CHARACTER_SET, local_infile=False) conn.autocommit = True