From 44b37f207242653f2b69f0043fe04e6d78173b11 Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Wed, 8 Apr 2026 15:13:11 -0700 Subject: [PATCH 1/4] Add supported for expired password (sandbox) mode (#440) --- mycli/constants.py | 4 + mycli/main.py | 35 ++++++++ mycli/main_modes/repl.py | 100 ++++++++++++++++++--- mycli/myclirc | 3 + mycli/sqlexecute.py | 86 +++++++++++++----- test/myclirc | 3 + test/pytests/test_main_modes_repl.py | 129 +++++++++++++++++++++++++++ test/pytests/test_main_regression.py | 1 + test/pytests/test_sqlexecute.py | 68 ++++++++++++++ 9 files changed, 396 insertions(+), 33 deletions(-) diff --git a/mycli/constants.py b/mycli/constants.py index 2d278ae4..f6ef1900 100644 --- a/mycli/constants.py +++ b/mycli/constants.py @@ -13,3 +13,7 @@ DEFAULT_WIDTH = 80 DEFAULT_HEIGHT = 25 + +# MySQL error codes not available in pymysql.constants.ER +ER_MUST_CHANGE_PASSWORD_LOGIN = 1862 +ER_MUST_CHANGE_PASSWORD = 1820 diff --git a/mycli/main.py b/mycli/main.py index bbc2fb55..aa2c51ee 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -58,6 +58,7 @@ DEFAULT_HOST, DEFAULT_PORT, DEFAULT_WIDTH, + ER_MUST_CHANGE_PASSWORD_LOGIN, ISSUES_URL, REPO_URL, ) @@ -152,6 +153,7 @@ def __init__( self.prompt_session: PromptSession | None = None self._keepalive_counter = 0 self.keepalive_ticks: int | None = 0 + self.sandbox_mode: bool = False # self.cnf_files is a class variable that stores the list of mysql # config files to read in at launch. @@ -556,6 +558,7 @@ def connect( reset_keyring: bool | None = None, keepalive_ticks: int | None = None, show_warnings: bool | None = None, + connect_expired_password: bool = False, ) -> None: cnf = { "database": None, @@ -684,6 +687,13 @@ def connect( # should not fail, but will help the typechecker assert not isinstance(passwd, int) + # CLI flag takes precedence; fall back to config file setting + if not connect_expired_password: + try: + connect_expired_password = str_to_bool(user_connection_config.get('connect_expired_password', '')) + except (TypeError, ValueError): + connect_expired_password = False + connection_info: dict[Any, Any] = { "database": database, "user": user, @@ -701,6 +711,7 @@ def connect( "ssh_key_filename": ssh_key_filename, "init_command": init_command, "unbuffered": unbuffered, + "connect_expired_password": connect_expired_password, } def _update_keyring(password: str | None, keyring_retrieved_cleanly: bool): @@ -750,6 +761,16 @@ def _connect( keyring_retrieved_cleanly=keyring_retrieved_cleanly, keyring_save_eligible=keyring_save_eligible, ) + elif e1.args[0] == ER_MUST_CHANGE_PASSWORD_LOGIN: + self.echo( + ( + "Your password has expired. Use the --connect-expired-password flag or " + "connect_expired_password config option to enter sandbox mode." + ), + err=True, + fg='red', + ) + raise e1 elif e1.args[0] == CR_SERVER_LOST: self.echo( ( @@ -803,6 +824,15 @@ def _connect( sys.exit(1) _connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly) + + # Check if SQLExecute detected sandbox mode during connection + if self.sqlexecute and self.sqlexecute.sandbox_mode: + self.sandbox_mode = True + self.echo( + "Your password has expired. Use ALTER USER to set a new password, or quit.", + err=True, + fg='yellow', + ) except Exception as e: # Connecting to a database could fail. self.logger.debug("Database connection failed: %r.", e) self.logger.error("traceback: %r", traceback.format_exc()) @@ -1406,6 +1436,10 @@ class CliArgs: default=None, help='Enable/disable LOAD DATA LOCAL INFILE.', ) + connect_expired_password: bool = clickdc.option( + is_flag=True, + help='Notify the server that this client is prepared to handle expired password sandbox mode.', + ) login_path: str | None = clickdc.option( '-g', type=str, @@ -1899,6 +1933,7 @@ def get_password_from_file(password_file: str | None) -> str | None: reset_keyring=reset_keyring, keepalive_ticks=keepalive_ticks, show_warnings=cli_args.show_warnings, + connect_expired_password=cli_args.connect_expired_password, ) if combined_init_cmd: diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py index 17edcd19..9112ec8b 100644 --- a/mycli/main_modes/repl.py +++ b/mycli/main_modes/repl.py @@ -39,6 +39,7 @@ from mycli.constants import ( DEFAULT_HOST, DEFAULT_WIDTH, + ER_MUST_CHANGE_PASSWORD, HOME_URL, ISSUES_URL, ) @@ -132,7 +133,8 @@ def _show_startup_banner( if mycli.less_chatty: return - print(sqlexecute.server_info) + if sqlexecute.server_info is not None: + print(sqlexecute.server_info) print('mycli', mycli_package.__version__) print(SUPPORT_INFO) if random.random() <= 0.5: @@ -230,8 +232,6 @@ def get_prompt( ) -> str: sqlexecute = mycli.sqlexecute assert sqlexecute is not None - assert sqlexecute.server_info is not None - assert sqlexecute.server_info.species is not None if mycli.login_path and mycli.login_path_as_host: prompt_host = mycli.login_path elif sqlexecute.host is not None: @@ -248,7 +248,8 @@ def get_prompt( string = string.replace('\\h', prompt_host or '(none)') string = string.replace('\\H', short_prompt_host or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)') - string = string.replace('\\t', sqlexecute.server_info.species.name) + species_name = sqlexecute.server_info.species.name if sqlexecute.server_info and sqlexecute.server_info.species else 'MySQL' + string = string.replace('\\t', species_name) string = string.replace('\\n', '\n') string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\m', now.strftime('%M')) @@ -516,6 +517,52 @@ def _build_prompt_session( mycli.prompt_session.app.ttimeoutlen = mycli.emacs_ttimeoutlen +_SANDBOX_ALLOWED_RE = re.compile( + r'^\s*(ALTER\s+USER|SET\s+PASSWORD|QUIT|EXIT|\\q)\b', + re.IGNORECASE, +) + +_PASSWORD_CHANGE_RE = re.compile( + r'^\s*(ALTER\s+USER|SET\s+PASSWORD)\b', + re.IGNORECASE, +) + + +def _is_sandbox_allowed(text: str) -> bool: + """Return True if the command is allowed in expired-password sandbox mode.""" + stripped = text.strip() + if not stripped: + return True + return bool(_SANDBOX_ALLOWED_RE.match(stripped)) + + +def _is_password_change(text: str) -> bool: + """Return True if the command is a password change statement.""" + return bool(_PASSWORD_CHANGE_RE.match(text.strip())) + + +_IDENTIFIED_BY_RE = re.compile( + r"IDENTIFIED\s+BY\s+'([^']*)'", + re.IGNORECASE, +) + +_SET_PASSWORD_RE = re.compile( + r"SET\s+PASSWORD\s*=\s*'([^']*)'", + re.IGNORECASE, +) + + +def _extract_new_password(text: str) -> str | None: + """Extract the new password from an ALTER USER or SET PASSWORD statement.""" + m = _IDENTIFIED_BY_RE.search(text) + if m: + return m.group(1) + m = _SET_PASSWORD_RE.search(text) + if m: + return m.group(1) + return None + + def _one_iteration( mycli: 'MyCli', state: ReplState, @@ -613,6 +660,14 @@ def _one_iteration( mycli.echo(str(e), err=True, fg='red') return + if mycli.sandbox_mode and not _is_sandbox_allowed(text): + mycli.echo( + "ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.", + err=True, + fg='red', + ) + return + if mycli.destructive_warning: destroy = confirm_destructive_query(mycli.destructive_keywords, text) if destroy is None: @@ -672,20 +727,45 @@ def _one_iteration( mycli.echo('Not Yet Implemented.', fg='yellow') except pymysql.OperationalError as e1: mycli.logger.debug('Exception: %r', e1) - if e1.args[0] in (2003, 2006, 2013): + if e1.args[0] == ER_MUST_CHANGE_PASSWORD: + mycli.sandbox_mode = True + mycli.echo( + "ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.", + err=True, + fg='red', + ) + elif e1.args[0] in (2003, 2006, 2013): if not mycli.reconnect(): return _one_iteration(mycli, state, text) return - - mycli.logger.error('sql: %r, error: %r', text, e1) - mycli.logger.error('traceback: %r', traceback.format_exc()) - mycli.echo(str(e1), err=True, fg='red') + else: + mycli.logger.error('sql: %r, error: %r', text, e1) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e1), err=True, fg='red') except Exception as e: mycli.logger.error('sql: %r, error: %r', text, e) mycli.logger.error('traceback: %r', traceback.format_exc()) mycli.echo(str(e), err=True, fg='red') else: + if mycli.sandbox_mode and _is_password_change(text): + new_password = _extract_new_password(text) + if new_password is not None: + sqlexecute.password = new_password + sqlexecute.connect_expired_password = False + try: + sqlexecute.connect() + mycli.sandbox_mode = False + mycli.echo("Password changed successfully. Reconnected.", err=True, fg='green') + mycli.refresh_completions() + except Exception as e: + mycli.sandbox_mode = False + mycli.echo( + f"Password changed but reconnection failed: {e}\nPlease restart mycli with your new password.", + err=True, + fg='yellow', + ) + if is_dropping_database(text, sqlexecute.dbname): sqlexecute.dbname = None sqlexecute.connect() @@ -744,7 +824,7 @@ def main_repl(mycli: 'MyCli') -> None: state = ReplState() mycli.configure_pager() - if mycli.smart_completion: + if mycli.smart_completion and not mycli.sandbox_mode: mycli.refresh_completions() history = _create_history(mycli) diff --git a/mycli/myclirc b/mycli/myclirc index ff44a15e..177cff05 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -209,6 +209,9 @@ default_character_set = utf8mb4 # whether to enable LOAD DATA LOCAL INFILE for connections without --local-infile being set default_local_infile = False +# whether to allow connecting with an expired password (enters sandbox mode) +connect_expired_password = True + # How often to send periodic background pings to the server when input is idle. Ticks are # roughly in seconds, but may be faster. Set to zero to disable. Suggestion: 300. default_keepalive_ticks = 0 diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 40b933a5..dc811933 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -175,6 +175,7 @@ def __init__( ssh_key_filename: str | None, init_command: str | None = None, unbuffered: bool | None = None, + connect_expired_password: bool = False, ) -> None: self.dbname = database self.user = user @@ -194,6 +195,7 @@ def __init__( self.ssh_key_filename = ssh_key_filename self.init_command = init_command self.unbuffered = unbuffered + self.connect_expired_password = connect_expired_password self.conn: Connection | None = None self.connect() @@ -280,32 +282,51 @@ def connect( client_flag = pymysql.constants.CLIENT.INTERACTIVE if init_command and len(list(iocommands.split_queries(init_command))) > 1: client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS + if self.connect_expired_password: + client_flag |= pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS ssl_context = None if ssl: ssl_context = self._create_ssl_ctx(ssl) - conn = pymysql.connect( - database=db, - user=user, - password=password or '', - host=host, - port=port or 0, - unix_socket=socket, - use_unicode=True, - charset=character_set or '', - autocommit=True, - client_flag=client_flag, - local_infile=local_infile or False, - conv=conv, - ssl=ssl_context, # type: ignore[arg-type] - program_name="mycli", - defer_connect=defer_connect, - init_command=init_command or None, - cursorclass=pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor, - ) # type: ignore[misc] + connect_kwargs: dict[str, Any] = { + "database": db, + "user": user, + "password": password or '', + "host": host, + "port": port or 0, + "unix_socket": socket, + "use_unicode": True, + "charset": character_set or '', + "autocommit": True, + "client_flag": client_flag, + "local_infile": local_infile or False, + "conv": conv, + "ssl": ssl_context, # type: ignore[arg-type] + "program_name": "mycli", + "defer_connect": defer_connect, + "init_command": init_command or None, + "cursorclass": pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor, + } + + self.sandbox_mode = False + try: + conn = pymysql.connect(**connect_kwargs) # type: ignore[misc] + except pymysql.OperationalError as e: + if e.args[0] == 1820 and self.connect_expired_password: + # Post-handshake queries (SET NAMES, SET AUTOCOMMIT, init_command) + # fail with ER_MUST_CHANGE_PASSWORD in sandbox mode. + # Reconnect with only the raw handshake. + connect_kwargs['defer_connect'] = True + connect_kwargs['autocommit'] = None + connect_kwargs['init_command'] = None + conn = pymysql.connect(**connect_kwargs) # type: ignore[misc] + self._connect_sandbox(conn) + self.sandbox_mode = True + else: + raise - if ssh_host: + if ssh_host and not self.sandbox_mode: ##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel ##### # instead let's open a tunnel and rewrite host:port to local bind @@ -343,9 +364,10 @@ def connect( self.ssl = ssl self.init_command = init_command self.unbuffered = unbuffered - # retrieve connection id - self.reset_connection_id() - self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined] + # retrieve connection id (skip in sandbox mode as queries will fail) + if not self.sandbox_mode: + self.reset_connection_id() + self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined] def run(self, statement: str) -> Generator[SQLResult, None, None]: """Execute the sql in the database and return the results.""" @@ -576,6 +598,24 @@ def change_db(self, db: str) -> None: self.conn.select_db(db) self.dbname = db + @staticmethod + def _connect_sandbox(conn: Connection) -> None: + """Connect in sandbox mode, performing only the handshake. + + pymysql's normal connect() runs post-handshake queries (SET NAMES, + SET AUTOCOMMIT, init_command) that all fail with ER_MUST_CHANGE_PASSWORD + in sandbox mode. This method performs the raw socket connection and + authentication handshake only. + """ + # Reuse pymysql internals for the handshake + auth, but + # temporarily stub out set_character_set so it becomes a no-op. + original_set_charset = conn.set_character_set + conn.set_character_set = lambda *_args, **_kwargs: None # type: ignore[assignment] + try: + conn.connect() + finally: + conn.set_character_set = original_set_charset # type: ignore[assignment] + def _create_ssl_ctx(self, sslp: dict) -> ssl.SSLContext: ca = sslp.get("ca") capath = sslp.get("capath") diff --git a/test/myclirc b/test/myclirc index fa10eabf..f688a0e9 100644 --- a/test/myclirc +++ b/test/myclirc @@ -207,6 +207,9 @@ default_character_set = utf8mb4 # whether to enable LOAD DATA LOCAL INFILE for connections without --local-infile being set default_local_infile = False +# whether to allow connecting with an expired password (enters sandbox mode) +connect_expired_password = True + # How often to send periodic background pings to the server when input is idle. Ticks are # roughly in seconds, but may be faster. Set to zero to disable. Suggestion: 300. default_keepalive_ticks = 0 diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py index 919aa575..3e9fb995 100644 --- a/test/pytests/test_main_modes_repl.py +++ b/test/pytests/test_main_modes_repl.py @@ -174,6 +174,7 @@ def make_repl_cli(sqlexecute: Any | None = None) -> Any: cli.cli_style = {} cli.emacs_ttimeoutlen = 1.0 cli.vi_ttimeoutlen = 2.0 + cli.sandbox_mode = False cli.destructive_warning = False cli.destructive_keywords = ['drop'] cli.llm_prompt_field_truncate = 0 @@ -793,6 +794,134 @@ def run(self, text: str) -> Iterator[SQLResult]: assert cli_quiet.output_calls[0][0] == ['None', 'ran:select 2'] +@pytest.mark.parametrize( + 'text, expected', + [ + ('', True), + (' ', True), + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), + ('alter user root identified by "pw"', True), + ("SET PASSWORD = 'newpass'", True), + ("set password = 'newpass'", True), + ('quit', True), + ('exit', True), + ('\\q', True), + ('SELECT 1', False), + ('DROP TABLE t', False), + ('USE mydb', False), + ('SHOW DATABASES', False), + ], +) +def test_is_sandbox_allowed(text: str, expected: bool) -> None: + assert repl_mode._is_sandbox_allowed(text) is expected + + +@pytest.mark.parametrize( + 'text, expected', + [ + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), + ("SET PASSWORD = 'newpass'", True), + ('SELECT 1', False), + ('quit', False), + ], +) +def test_is_password_change(text: str, expected: bool) -> None: + assert repl_mode._is_password_change(text) is expected + + +@pytest.mark.parametrize( + 'text, expected', + [ + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'", 'newpass'), + ("SET PASSWORD = 'secret123'", 'secret123'), + ("ALTER USER root IDENTIFIED BY 'p@ss w0rd!'", 'p@ss w0rd!'), + ('ALTER USER root IDENTIFIED WITH mysql_native_password', None), + ('SELECT 1', None), + ], +) +def test_extract_new_password(text: str, expected: str | None) -> None: + assert repl_mode._extract_new_password(text) == expected + + +def test_one_iteration_blocks_disallowed_in_sandbox_mode(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status=f'ran:{text}')]) + + cli = make_repl_cli(FakeSQLExecute()) + cli.sandbox_mode = True + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'SELECT 1') + assert any('ERROR 1820' in msg for msg in cli.echo_calls) + assert not cli.query_history + + +def test_one_iteration_allows_alter_user_in_sandbox_mode(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + self.password = 'old' + self.connect_expired_password = True + self.connect_calls: list[bool] = [] + + def connect(self) -> None: + self.connect_calls.append(True) + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status='OK')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.sandbox_mode = True + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + repl_mode._one_iteration( + cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'" + ) + assert cli.sandbox_mode is False + assert sqlexecute.password == 'newpass' + assert sqlexecute.connect_expired_password is False + assert sqlexecute.connect_calls == [True] + assert any('Reconnected' in msg for msg in cli.echo_calls) + + +def test_one_iteration_sandbox_reconnect_failure(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + self.password = 'old' + self.connect_expired_password = True + + def connect(self) -> None: + raise RuntimeError('connection refused') + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status='OK')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.sandbox_mode = True + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + repl_mode._one_iteration( + cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'" + ) + assert cli.sandbox_mode is False + assert any('reconnection failed' in msg for msg in cli.echo_calls) + + def test_one_iteration_covers_redirect_destructive_success_refresh_and_logfile(monkeypatch: pytest.MonkeyPatch) -> None: patch_repl_runtime_defaults(monkeypatch) diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index c813530c..49ae590c 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -134,6 +134,7 @@ def __init__(self, **kwargs: Any) -> None: self.dbname = kwargs.get('database') self.user = kwargs.get('user') self.conn = kwargs.get('conn') + self.sandbox_mode = False class ToggleBool: diff --git a/test/pytests/test_sqlexecute.py b/test/pytests/test_sqlexecute.py index 5155cb9a..69461e41 100644 --- a/test/pytests/test_sqlexecute.py +++ b/test/pytests/test_sqlexecute.py @@ -673,6 +673,8 @@ def make_executor_for_connect_tests() -> SQLExecute: executor.ssh_key_filename = '/stored/key.pem' executor.init_command = 'select 1' executor.unbuffered = False + executor.connect_expired_password = False + executor.sandbox_mode = False executor.conn = None return executor @@ -762,6 +764,72 @@ def fake_reset_connection_id(self) -> None: assert executor.server_info.version == 80036 +def test_connect_sets_expired_password_flag(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.connect_expired_password = True + executor.ssl = None + + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + connect_kwargs = {} + + def fake_connect(**kwargs): + connect_kwargs.update(kwargs) + return new_conn + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + monkeypatch.setattr(SQLExecute, 'reset_connection_id', lambda self: None) + + executor.connect() + + assert connect_kwargs['client_flag'] & sqlexecute.pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS + assert executor.sandbox_mode is False + + +def test_connect_falls_back_to_sandbox_on_1820(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.connect_expired_password = True + executor.ssl = None + + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + call_count = 0 + sandbox_calls = [] + + def fake_connect(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise pymysql.OperationalError(1820, 'must change password') + return new_conn + + def fake_connect_sandbox(self, conn): + sandbox_calls.append(conn) + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + monkeypatch.setattr(SQLExecute, '_connect_sandbox', fake_connect_sandbox) + + executor.connect() + + assert call_count == 2 + assert len(sandbox_calls) == 1 + assert executor.sandbox_mode is True + assert executor.server_info is None + assert executor.connection_id is None + + +def test_connect_1820_without_expired_flag_reraises(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.connect_expired_password = False + executor.ssl = None + + def fake_connect(**kwargs): + raise pymysql.OperationalError(1820, 'must change password') + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + + with pytest.raises(pymysql.OperationalError, match='must change password'): + executor.connect() + + def test_connect_uses_ssh_tunnel_when_ssh_host_is_set(monkeypatch) -> None: executor = make_executor_for_connect_tests() executor.ssl = None From b61ff4394fb0d3ab34ccf58aa0a327bd8d1d9ff7 Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Wed, 8 Apr 2026 15:17:18 -0700 Subject: [PATCH 2/4] Updated changelog --- changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/changelog.md b/changelog.md index f8bcafe4..d2fdf3c8 100644 --- a/changelog.md +++ b/changelog.md @@ -7,6 +7,7 @@ Features * Make `--progress` and `--checkpoint` strictly by statement. * Allow more characters in passwords read from a file. * Show sponsors and contributors separately in startup messages. +* Add support for expired password (sandbox) mode (#440) Bug Fixes From 992f50036b161a4d05ee96066b41928394a72ac3 Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Wed, 8 Apr 2026 15:28:43 -0700 Subject: [PATCH 3/4] Forgot to checkin the formatted file --- test/pytests/test_main_modes_repl.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py index 04e78e13..2cd6f0cc 100644 --- a/test/pytests/test_main_modes_repl.py +++ b/test/pytests/test_main_modes_repl.py @@ -887,9 +887,7 @@ def run(self, text: str) -> Iterator[SQLResult]: cli.sandbox_mode = True monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) - repl_mode._one_iteration( - cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'" - ) + repl_mode._one_iteration(cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'") assert cli.sandbox_mode is False assert sqlexecute.password == 'newpass' assert sqlexecute.connect_expired_password is False @@ -918,9 +916,7 @@ def run(self, text: str) -> Iterator[SQLResult]: cli.sandbox_mode = True monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) - repl_mode._one_iteration( - cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'" - ) + repl_mode._one_iteration(cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'") assert cli.sandbox_mode is False assert any('reconnection failed' in msg for msg in cli.echo_calls) From 88e609a042483449d6785b60610afd107e5b4e4d Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Thu, 9 Apr 2026 10:45:07 -0700 Subject: [PATCH 4/4] Removed CLI and option file options; sandbox mode will be used any time a user with an expired password attempts to connect --- mycli/main.py | 19 +------------------ mycli/main_modes/repl.py | 1 - mycli/myclirc | 3 --- mycli/sqlexecute.py | 8 +++----- test/myclirc | 3 --- test/pytests/test_main_modes_repl.py | 3 --- test/pytests/test_sqlexecute.py | 17 ----------------- 7 files changed, 4 insertions(+), 50 deletions(-) diff --git a/mycli/main.py b/mycli/main.py index da4dcf0e..0c780d47 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -558,7 +558,6 @@ def connect( reset_keyring: bool | None = None, keepalive_ticks: int | None = None, show_warnings: bool | None = None, - connect_expired_password: bool = False, ) -> None: cnf = { "database": None, @@ -687,13 +686,6 @@ def connect( # should not fail, but will help the typechecker assert not isinstance(passwd, int) - # CLI flag takes precedence; fall back to config file setting - if not connect_expired_password: - try: - connect_expired_password = str_to_bool(user_connection_config.get('connect_expired_password', '')) - except (TypeError, ValueError): - connect_expired_password = False - connection_info: dict[Any, Any] = { "database": database, "user": user, @@ -711,7 +703,6 @@ def connect( "ssh_key_filename": ssh_key_filename, "init_command": init_command, "unbuffered": unbuffered, - "connect_expired_password": connect_expired_password, } def _update_keyring(password: str | None, keyring_retrieved_cleanly: bool): @@ -763,10 +754,7 @@ def _connect( ) elif e1.args[0] == ER_MUST_CHANGE_PASSWORD_LOGIN: self.echo( - ( - "Your password has expired. Use the --connect-expired-password flag or " - "connect_expired_password config option to enter sandbox mode." - ), + "Your password has expired and the server rejected the connection.", err=True, fg='red', ) @@ -1436,10 +1424,6 @@ class CliArgs: default=None, help='Enable/disable LOAD DATA LOCAL INFILE.', ) - connect_expired_password: bool = clickdc.option( - is_flag=True, - help='Notify the server that this client is prepared to handle expired password sandbox mode.', - ) login_path: str | None = clickdc.option( '-g', type=str, @@ -1933,7 +1917,6 @@ def get_password_from_file(password_file: str | None) -> str | None: reset_keyring=reset_keyring, keepalive_ticks=keepalive_ticks, show_warnings=cli_args.show_warnings, - connect_expired_password=cli_args.connect_expired_password, ) if combined_init_cmd: diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py index 70631973..090cb785 100644 --- a/mycli/main_modes/repl.py +++ b/mycli/main_modes/repl.py @@ -754,7 +754,6 @@ def _one_iteration( new_password = _extract_new_password(text) if new_password is not None: sqlexecute.password = new_password - sqlexecute.connect_expired_password = False try: sqlexecute.connect() mycli.sandbox_mode = False diff --git a/mycli/myclirc b/mycli/myclirc index 177cff05..ff44a15e 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -209,9 +209,6 @@ default_character_set = utf8mb4 # whether to enable LOAD DATA LOCAL INFILE for connections without --local-infile being set default_local_infile = False -# whether to allow connecting with an expired password (enters sandbox mode) -connect_expired_password = True - # How often to send periodic background pings to the server when input is idle. Ticks are # roughly in seconds, but may be faster. Set to zero to disable. Suggestion: 300. default_keepalive_ticks = 0 diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index dc811933..b045a4c6 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -14,6 +14,7 @@ from pymysql.converters import conversions, convert_date, convert_datetime, convert_time, decoders from pymysql.cursors import Cursor +from mycli.constants import ER_MUST_CHANGE_PASSWORD from mycli.packages.special import iocommands from mycli.packages.special.main import CommandNotFound, execute from mycli.packages.sqlresult import SQLResult @@ -175,7 +176,6 @@ def __init__( ssh_key_filename: str | None, init_command: str | None = None, unbuffered: bool | None = None, - connect_expired_password: bool = False, ) -> None: self.dbname = database self.user = user @@ -195,7 +195,6 @@ def __init__( self.ssh_key_filename = ssh_key_filename self.init_command = init_command self.unbuffered = unbuffered - self.connect_expired_password = connect_expired_password self.conn: Connection | None = None self.connect() @@ -282,8 +281,7 @@ def connect( client_flag = pymysql.constants.CLIENT.INTERACTIVE if init_command and len(list(iocommands.split_queries(init_command))) > 1: client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS - if self.connect_expired_password: - client_flag |= pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS + client_flag |= pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS ssl_context = None if ssl: @@ -313,7 +311,7 @@ def connect( try: conn = pymysql.connect(**connect_kwargs) # type: ignore[misc] except pymysql.OperationalError as e: - if e.args[0] == 1820 and self.connect_expired_password: + if e.args[0] == ER_MUST_CHANGE_PASSWORD: # Post-handshake queries (SET NAMES, SET AUTOCOMMIT, init_command) # fail with ER_MUST_CHANGE_PASSWORD in sandbox mode. # Reconnect with only the raw handshake. diff --git a/test/myclirc b/test/myclirc index f688a0e9..fa10eabf 100644 --- a/test/myclirc +++ b/test/myclirc @@ -207,9 +207,6 @@ default_character_set = utf8mb4 # whether to enable LOAD DATA LOCAL INFILE for connections without --local-infile being set default_local_infile = False -# whether to allow connecting with an expired password (enters sandbox mode) -connect_expired_password = True - # How often to send periodic background pings to the server when input is idle. Ticks are # roughly in seconds, but may be faster. Set to zero to disable. Suggestion: 300. default_keepalive_ticks = 0 diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py index 2cd6f0cc..ba3ea163 100644 --- a/test/pytests/test_main_modes_repl.py +++ b/test/pytests/test_main_modes_repl.py @@ -873,7 +873,6 @@ def __init__(self) -> None: self.dbname = 'db' self.connection_id = 0 self.password = 'old' - self.connect_expired_password = True self.connect_calls: list[bool] = [] def connect(self) -> None: @@ -890,7 +889,6 @@ def run(self, text: str) -> Iterator[SQLResult]: repl_mode._one_iteration(cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'") assert cli.sandbox_mode is False assert sqlexecute.password == 'newpass' - assert sqlexecute.connect_expired_password is False assert sqlexecute.connect_calls == [True] assert any('Reconnected' in msg for msg in cli.echo_calls) @@ -903,7 +901,6 @@ def __init__(self) -> None: self.dbname = 'db' self.connection_id = 0 self.password = 'old' - self.connect_expired_password = True def connect(self) -> None: raise RuntimeError('connection refused') diff --git a/test/pytests/test_sqlexecute.py b/test/pytests/test_sqlexecute.py index 69461e41..e250b154 100644 --- a/test/pytests/test_sqlexecute.py +++ b/test/pytests/test_sqlexecute.py @@ -673,7 +673,6 @@ def make_executor_for_connect_tests() -> SQLExecute: executor.ssh_key_filename = '/stored/key.pem' executor.init_command = 'select 1' executor.unbuffered = False - executor.connect_expired_password = False executor.sandbox_mode = False executor.conn = None return executor @@ -766,7 +765,6 @@ def fake_reset_connection_id(self) -> None: def test_connect_sets_expired_password_flag(monkeypatch) -> None: executor = make_executor_for_connect_tests() - executor.connect_expired_password = True executor.ssl = None new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') @@ -787,7 +785,6 @@ def fake_connect(**kwargs): def test_connect_falls_back_to_sandbox_on_1820(monkeypatch) -> None: executor = make_executor_for_connect_tests() - executor.connect_expired_password = True executor.ssl = None new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') @@ -816,20 +813,6 @@ def fake_connect_sandbox(self, conn): assert executor.connection_id is None -def test_connect_1820_without_expired_flag_reraises(monkeypatch) -> None: - executor = make_executor_for_connect_tests() - executor.connect_expired_password = False - executor.ssl = None - - def fake_connect(**kwargs): - raise pymysql.OperationalError(1820, 'must change password') - - monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) - - with pytest.raises(pymysql.OperationalError, match='must change password'): - executor.connect() - - def test_connect_uses_ssh_tunnel_when_ssh_host_is_set(monkeypatch) -> None: executor = make_executor_for_connect_tests() executor.ssl = None