Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 175 additions & 1 deletion test/pytests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
import shutil
from tempfile import NamedTemporaryFile
from textwrap import dedent
from types import SimpleNamespace
from typing import Any, cast

import click
from click.testing import CliRunner
from pymysql.err import OperationalError
import pytest

from mycli import main
from mycli.constants import (
DEFAULT_DATABASE,
DEFAULT_HOST,
Expand All @@ -26,7 +30,20 @@
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
from mycli.packages.sqlresult import SQLResult
from mycli.sqlexecute import ServerInfo, SQLExecute
from test.utils import DATABASE, HOST, PASSWORD, PORT, TEMPFILE_PREFIX, USER, dbtest, run
from test.utils import (
DATABASE,
HOST,
PASSWORD,
PORT,
TEMPFILE_PREFIX,
USER,
ReusableLock,
call_click_entrypoint_direct,
dbtest,
make_bare_mycli,
make_dummy_mycli_class,
run,
)

pytests_dir = os.path.abspath(os.path.dirname(__file__))
project_root_dir = os.path.abspath(os.path.join(pytests_dir, '..', '..'))
Expand Down Expand Up @@ -2150,3 +2167,160 @@ def test_null_string_config(monkeypatch):
os.remove(myclirc.name)
except Exception as e:
print(f'An error occurred while attempting to delete the file: {e}')


def test_change_prompt_format_requires_argument() -> None:
cli = make_bare_mycli()
assert main.MyCli.change_prompt_format(cli, '')[0].status == 'Missing required argument, format.'


def test_change_prompt_format_updates_prompt() -> None:
cli = make_bare_mycli()
assert main.MyCli.change_prompt_format(cli, '\\u@\\h> ')[0].status == 'Changed prompt format to \\u@\\h> '


def test_output_timing_logs_and_prints_with_warning_style(monkeypatch: pytest.MonkeyPatch) -> None:
cli = make_bare_mycli()
timings_logged: list[str] = []
cli.log_output = lambda text: timings_logged.append(text) # type: ignore[assignment]
printed: list[tuple[Any, Any]] = []
monkeypatch.setattr(main, 'print_formatted_text', lambda text, style=None: printed.append((text, style)))
main.MyCli.output_timing(cli, 'Time: 1.000s', is_warnings_style=True)
assert timings_logged == ['Time: 1.000s']
assert printed[-1][1] == cli.ptoolkit_style


def test_run_cli_delegates_to_main_repl(monkeypatch: pytest.MonkeyPatch) -> None:
cli = make_bare_mycli()
run_cli_calls: list[Any] = []
monkeypatch.setattr(main, 'main_repl', lambda target: run_cli_calls.append(target))
main.MyCli.run_cli(cli)
assert run_cli_calls == [cli]


def test_get_output_margin_uses_prompt_session_render_counter(monkeypatch: pytest.MonkeyPatch) -> None:
cli = make_bare_mycli()
render_counters: list[int] = []
cli.prompt_lines = 0
cli.get_reserved_space = lambda: 2 # type: ignore[assignment]
cli.prompt_session = cast(
Any,
SimpleNamespace(app=SimpleNamespace(render_counter=7)),
)

def fake_get_prompt(mycli: Any, string: str, render_counter: int) -> str:
render_counters.append(render_counter)
return 'line1\nline2'

monkeypatch.setattr(main, 'get_prompt', fake_get_prompt)
monkeypatch.setattr(main.special, 'is_timing_enabled', lambda: False)
assert main.MyCli.get_output_margin(cli, 'ok') == 5
assert render_counters == [7]


def test_on_completions_refreshed_updates_completer_and_invalidates_prompt() -> None:
cli = make_bare_mycli()
entered_lock = {'count': 0}
invalidated: list[bool] = []
cli._completer_lock = cast(Any, ReusableLock(lambda: entered_lock.__setitem__('count', entered_lock['count'] + 1)))
cli.prompt_session = cast(Any, SimpleNamespace(app=SimpleNamespace(invalidate=lambda: invalidated.append(True))))
new_completer = cast(Any, SimpleNamespace(get_completions=lambda document, event: ['done']))
main.MyCli._on_completions_refreshed(cli, new_completer)
assert cli.completer is new_completer
assert invalidated == [True]
assert entered_lock['count'] == 1


def test_get_completions_uses_current_completer() -> None:
cli = make_bare_mycli()
entered_lock = {'count': 0}
cli._completer_lock = cast(Any, ReusableLock(lambda: entered_lock.__setitem__('count', entered_lock['count'] + 1)))
cli.completer = cast(Any, SimpleNamespace(get_completions=lambda document, event: ['done']))
assert list(main.MyCli.get_completions(cli, 'select', 6)) == ['done']
assert entered_lock['count'] == 1


def test_click_entrypoint_callback_covers_dsn_list_init_commands(monkeypatch: pytest.MonkeyPatch) -> None:
dummy_class = make_dummy_mycli_class(
config={
'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'},
'connection': {'default_keepalive_ticks': 0},
'alias_dsn': {'prod': 'mysql://u:p@h/db'},
'alias_dsn.init-commands': {'prod': ['set a=1', 'set b=2']},
}
)
monkeypatch.setattr(main, 'MyCli', dummy_class)
monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True))
monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True)

cli_args = main.CliArgs()
cli_args.dsn = 'prod'
cli_args.init_command = 'set c=3'
call_click_entrypoint_direct(cli_args)

dummy = dummy_class.last_instance
assert dummy is not None
assert dummy.connect_calls[-1]['init_command'] == 'set a=1; set b=2; set c=3'


def test_click_entrypoint_callback_uses_batch_with_progress_path(monkeypatch: pytest.MonkeyPatch) -> None:
dummy_class = make_dummy_mycli_class(
config={
'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'},
'connection': {'default_keepalive_ticks': 0},
'alias_dsn': {},
}
)
monkeypatch.setattr(main, 'MyCli', dummy_class)
monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True))
monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True)
monkeypatch.setattr(main, 'main_batch_with_progress_bar', lambda mycli, cli_args: 12)

cli_args = main.CliArgs()
cli_args.batch = 'queries.sql'
cli_args.progress = True
with pytest.raises(SystemExit) as excinfo:
call_click_entrypoint_direct(cli_args)
assert excinfo.value.code == 12


def test_click_entrypoint_callback_uses_batch_without_progress_path(monkeypatch: pytest.MonkeyPatch) -> None:
dummy_class = make_dummy_mycli_class(
config={
'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'},
'connection': {'default_keepalive_ticks': 0},
'alias_dsn': {},
}
)
monkeypatch.setattr(main, 'MyCli', dummy_class)
monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True))
monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True)
monkeypatch.setattr(main, 'main_batch_without_progress_bar', lambda mycli, cli_args: 13)

cli_args = main.CliArgs()
cli_args.batch = 'queries.sql'
cli_args.progress = False
with pytest.raises(SystemExit) as excinfo:
call_click_entrypoint_direct(cli_args)
assert excinfo.value.code == 13


def test_click_entrypoint_callback_covers_mycnf_underscore_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
click_lines: list[str] = []
monkeypatch.setattr(click, 'secho', lambda message='', **kwargs: click_lines.append(str(message)))
monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True))
monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False)

dummy_class = make_dummy_mycli_class(
config={
'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'false'},
'connection': {'default_keepalive_ticks': 0},
'alias_dsn': {},
},
my_cnf={'client': {'ssl_ca': '/tmp/ca.pem'}, 'mysqld': {}},
config_without_package_defaults={'main': {}},
)
monkeypatch.setattr(main, 'MyCli', dummy_class)

call_click_entrypoint_direct(main.CliArgs())
assert any('ssl-ca = /tmp/ca.pem' in line for line in click_lines)
162 changes: 8 additions & 154 deletions test/pytests/test_main_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pathlib import Path
import sys
from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Literal, cast
from typing import Any, cast

import click
from click.testing import CliRunner
Expand All @@ -34,42 +34,13 @@
from mycli import main
import mycli.key_bindings
from mycli.packages.sqlresult import SQLResult


class DummyLogger:
def __init__(self) -> None:
self.debug_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = []
self.error_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = []
self.warning_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = []

def debug(self, *args: Any, **kwargs: Any) -> None:
self.debug_calls.append((args, kwargs))

def error(self, *args: Any, **kwargs: Any) -> None:
self.error_calls.append((args, kwargs))

def warning(self, *args: Any, **kwargs: Any) -> None:
self.warning_calls.append((args, kwargs))


class DummyFormatter:
def __init__(self, format_name: str = 'ascii') -> None:
self.format_name = format_name
self.query = ''
self.supported_formats = ['ascii', 'csv', 'tsv', 'vertical']
self._output_formats = {
'ascii': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}),
'csv': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}),
'tsv': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}),
'vertical': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}),
}
self.calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = []

def format_output(self, rows: Any, header: Any, format_name: str | None = None, **kwargs: Any) -> list[str] | str:
self.calls.append(((rows, header, format_name), kwargs))
if format_name == 'vertical':
return ['vertical output']
return ['plain output']
from test.utils import ( # type: ignore[attr-defined]
DummyFormatter,
DummyLogger,
call_click_entrypoint_direct,
make_bare_mycli,
make_dummy_mycli_class,
)


class FakeCursorBase:
Expand Down Expand Up @@ -100,19 +71,6 @@ def ping(self, reconnect: bool = False) -> None:
raise self.ping_exc


class ReusableLock:
def __init__(self, on_enter: Callable[[], Any] | None = None) -> None:
self.on_enter = on_enter

def __enter__(self) -> 'ReusableLock':
if self.on_enter is not None:
self.on_enter()
return self

def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]:
return False


class BoolSection(dict[str, Any]):
def as_bool(self, key: str) -> bool:
return str(self[key]).lower() == 'true'
Expand Down Expand Up @@ -154,63 +112,6 @@ def __int__(self) -> int:
raise ValueError('bad int')


def make_bare_mycli() -> Any:
cli = object.__new__(main.MyCli)
cli.logger = cast(Any, DummyLogger())
cli.main_formatter = DummyFormatter()
cli.redirect_formatter = DummyFormatter()
cli.helpers_style = 'helpers-style'
cli.helpers_warnings_style = 'helpers-warnings-style'
cli.ptoolkit_style = cast(Any, 'pt-style')
cli.syntax_style = 'native'
cli.cli_style = {}
cli.null_string = '<null>'
cli.numeric_alignment = 'right'
cli.binary_display = None
cli.show_warnings = False
cli.query_history = []
cli.toolbar_error_message = None
cli.prompt_session = None
cli.last_prompt_message = main.ANSI('')
cli.last_custom_toolbar_message = main.ANSI('')
cli.prompt_lines = 0
cli.prompt_format = main.MyCli.default_prompt
cli.multiline_continuation_char = '>'
cli.toolbar_format = 'default'
cli.destructive_warning = False
cli.destructive_keywords = ['drop']
cli.keepalive_ticks = None
cli._keepalive_counter = 0
cli.less_chatty = True
cli.smart_completion = False
cli.key_bindings = 'emacs'
cli.auto_vertical_output = False
cli.wider_completion_menu = False
cli.explicit_pager = False
cli._completer_lock = cast(Any, ReusableLock())
cli.terminal_tab_title_format = ''
cli.terminal_window_title_format = ''
cli.multiplex_window_title_format = ''
cli.multiplex_pane_title_format = ''
cli.dsn_alias = None
cli.login_path = None
cli.login_path_as_host = False
cli.post_redirect_command = None
cli.logfile = None
cli.emacs_ttimeoutlen = 1.0
cli.vi_ttimeoutlen = 1.0
cli.beep_after_seconds = 0.0
cli.config = {'history_file': '~/.mycli-history-testing'}
cli.output = lambda *args, **kwargs: None # type: ignore[assignment]
cli.echo = lambda *args, **kwargs: None # type: ignore[assignment]
cli.log_query = lambda *args, **kwargs: None # type: ignore[assignment]
cli.log_output = lambda *args, **kwargs: None # type: ignore[assignment]
cli.configure_pager = lambda: None # type: ignore[assignment]
cli.refresh_completions = lambda reset=False: [SQLResult(status='refresh')] # type: ignore[assignment]
cli.reconnect = lambda database='': False # type: ignore[assignment]
return cli


def load_main_variant(monkeypatch: pytest.MonkeyPatch, *, fail_pwd: bool = False) -> ModuleType:
import builtins

Expand All @@ -232,53 +133,6 @@ def fake_import(name: str, globals: Any = None, locals: Any = None, fromlist: An
return module


def make_dummy_mycli_class(
*,
config: dict[str, Any] | None = None,
my_cnf: dict[str, Any] | None = None,
config_without_package_defaults: dict[str, Any] | None = None,
) -> Any:
class DummyMyCli:
last_instance: Any = None

def __init__(self, **kwargs: Any) -> None:
type(self).last_instance = self
self.init_kwargs = dict(kwargs)
self.config = config or {'main': {}, 'alias_dsn': {}}
self.my_cnf = my_cnf or {'client': {}, 'mysqld': {}}
self.config_without_package_defaults = config_without_package_defaults or {}
self.default_keepalive_ticks = 5
self.ssl_mode = None
self.logger = DummyLogger()
self.main_formatter = SimpleNamespace(format_name=None)
self.destructive_warning = False
self.destructive_keywords = ['drop']
self.dsn_alias = None
self.connect_calls: list[dict[str, Any]] = []
self.run_query_calls: list[tuple[str, Any, bool]] = []
self.run_cli_called = False
self.close_called = False

def connect(self, **kwargs: Any) -> None:
self.connect_calls.append(dict(kwargs))

def run_query(self, query: str, checkpoint: Any = None, new_line: bool = True) -> None:
self.run_query_calls.append((query, checkpoint, new_line))

def run_cli(self) -> None:
self.run_cli_called = True

def close(self) -> None:
self.close_called = True

return DummyMyCli


def call_click_entrypoint_direct(cli_args: main.CliArgs) -> None:
assert main.click_entrypoint.callback is not None
cast(Any, main.click_entrypoint.callback).__wrapped__(cli_args)


def test_import_fallbacks_for_pwd(monkeypatch: pytest.MonkeyPatch) -> None:
module = load_main_variant(monkeypatch, fail_pwd=True)

Expand Down
Loading
Loading