diff --git a/changelog.md b/changelog.md index f8bcafe4..d2fdf3c8 100644 --- a/changelog.md +++ b/changelog.md @@ -7,6 +7,7 @@ Features * Make `--progress` and `--checkpoint` strictly by statement. * Allow more characters in passwords read from a file. * Show sponsors and contributors separately in startup messages. +* Add support for expired password (sandbox) mode (#440) Bug Fixes diff --git a/mycli/constants.py b/mycli/constants.py index 2d278ae4..f6ef1900 100644 --- a/mycli/constants.py +++ b/mycli/constants.py @@ -13,3 +13,7 @@ DEFAULT_WIDTH = 80 DEFAULT_HEIGHT = 25 + +# MySQL error codes not available in pymysql.constants.ER +ER_MUST_CHANGE_PASSWORD_LOGIN = 1862 +ER_MUST_CHANGE_PASSWORD = 1820 diff --git a/mycli/main.py b/mycli/main.py index 57ec4068..0c780d47 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -58,6 +58,7 @@ DEFAULT_HOST, DEFAULT_PORT, DEFAULT_WIDTH, + ER_MUST_CHANGE_PASSWORD_LOGIN, ISSUES_URL, REPO_URL, ) @@ -152,6 +153,7 @@ def __init__( self.prompt_session: PromptSession | None = None self._keepalive_counter = 0 self.keepalive_ticks: int | None = 0 + self.sandbox_mode: bool = False # self.cnf_files is a class variable that stores the list of mysql # config files to read in at launch. @@ -750,6 +752,13 @@ def _connect( keyring_retrieved_cleanly=keyring_retrieved_cleanly, keyring_save_eligible=keyring_save_eligible, ) + elif e1.args[0] == ER_MUST_CHANGE_PASSWORD_LOGIN: + self.echo( + "Your password has expired and the server rejected the connection.", + err=True, + fg='red', + ) + raise e1 elif e1.args[0] == CR_SERVER_LOST: self.echo( ( @@ -803,6 +812,15 @@ def _connect( sys.exit(1) _connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly) + + # Check if SQLExecute detected sandbox mode during connection + if self.sqlexecute and self.sqlexecute.sandbox_mode: + self.sandbox_mode = True + self.echo( + "Your password has expired. Use ALTER USER to set a new password, or quit.", + err=True, + fg='yellow', + ) except Exception as e: # Connecting to a database could fail. self.logger.debug("Database connection failed: %r.", e) self.logger.error("traceback: %r", traceback.format_exc()) diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py index 2bd6b0a2..090cb785 100644 --- a/mycli/main_modes/repl.py +++ b/mycli/main_modes/repl.py @@ -39,6 +39,7 @@ from mycli.constants import ( DEFAULT_HOST, DEFAULT_WIDTH, + ER_MUST_CHANGE_PASSWORD, HOME_URL, ISSUES_URL, ) @@ -132,7 +133,8 @@ def _show_startup_banner( if mycli.less_chatty: return - print(sqlexecute.server_info) + if sqlexecute.server_info is not None: + print(sqlexecute.server_info) print('mycli', mycli_package.__version__) print(SUPPORT_INFO) if random.random() <= 0.25: @@ -232,8 +234,6 @@ def get_prompt( ) -> str: sqlexecute = mycli.sqlexecute assert sqlexecute is not None - assert sqlexecute.server_info is not None - assert sqlexecute.server_info.species is not None if mycli.login_path and mycli.login_path_as_host: prompt_host = mycli.login_path elif sqlexecute.host is not None: @@ -250,7 +250,8 @@ def get_prompt( string = string.replace('\\h', prompt_host or '(none)') string = string.replace('\\H', short_prompt_host or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)') - string = string.replace('\\t', sqlexecute.server_info.species.name) + species_name = sqlexecute.server_info.species.name if sqlexecute.server_info and sqlexecute.server_info.species else 'MySQL' + string = string.replace('\\t', species_name) string = string.replace('\\n', '\n') string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\m', now.strftime('%M')) @@ -518,6 +519,52 @@ def _build_prompt_session( mycli.prompt_session.app.ttimeoutlen = mycli.emacs_ttimeoutlen +_SANDBOX_ALLOWED_RE = re.compile( + r'^\s*(ALTER\s+USER|SET\s+PASSWORD|QUIT|EXIT|\\q)\b', + re.IGNORECASE, +) + +_PASSWORD_CHANGE_RE = re.compile( + r'^\s*(ALTER\s+USER|SET\s+PASSWORD)\b', + re.IGNORECASE, +) + + +def _is_sandbox_allowed(text: str) -> bool: + """Return True if the command is allowed in expired-password sandbox mode.""" + stripped = text.strip() + if not stripped: + return True + return bool(_SANDBOX_ALLOWED_RE.match(stripped)) + + +def _is_password_change(text: str) -> bool: + """Return True if the command is a password change statement.""" + return bool(_PASSWORD_CHANGE_RE.match(text.strip())) + + +_IDENTIFIED_BY_RE = re.compile( + r"IDENTIFIED\s+BY\s+'([^']*)'", + re.IGNORECASE, +) + +_SET_PASSWORD_RE = re.compile( + r"SET\s+PASSWORD\s*=\s*'([^']*)'", + re.IGNORECASE, +) + + +def _extract_new_password(text: str) -> str | None: + """Extract the new password from an ALTER USER or SET PASSWORD statement.""" + m = _IDENTIFIED_BY_RE.search(text) + if m: + return m.group(1) + m = _SET_PASSWORD_RE.search(text) + if m: + return m.group(1) + return None + + def _one_iteration( mycli: 'MyCli', state: ReplState, @@ -615,6 +662,14 @@ def _one_iteration( mycli.echo(str(e), err=True, fg='red') return + if mycli.sandbox_mode and not _is_sandbox_allowed(text): + mycli.echo( + "ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.", + err=True, + fg='red', + ) + return + if mycli.destructive_warning: destroy = confirm_destructive_query(mycli.destructive_keywords, text) if destroy is None: @@ -674,20 +729,44 @@ def _one_iteration( mycli.echo('Not Yet Implemented.', fg='yellow') except pymysql.OperationalError as e1: mycli.logger.debug('Exception: %r', e1) - if e1.args[0] in (2003, 2006, 2013): + if e1.args[0] == ER_MUST_CHANGE_PASSWORD: + mycli.sandbox_mode = True + mycli.echo( + "ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.", + err=True, + fg='red', + ) + elif e1.args[0] in (2003, 2006, 2013): if not mycli.reconnect(): return _one_iteration(mycli, state, text) return - - mycli.logger.error('sql: %r, error: %r', text, e1) - mycli.logger.error('traceback: %r', traceback.format_exc()) - mycli.echo(str(e1), err=True, fg='red') + else: + mycli.logger.error('sql: %r, error: %r', text, e1) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e1), err=True, fg='red') except Exception as e: mycli.logger.error('sql: %r, error: %r', text, e) mycli.logger.error('traceback: %r', traceback.format_exc()) mycli.echo(str(e), err=True, fg='red') else: + if mycli.sandbox_mode and _is_password_change(text): + new_password = _extract_new_password(text) + if new_password is not None: + sqlexecute.password = new_password + try: + sqlexecute.connect() + mycli.sandbox_mode = False + mycli.echo("Password changed successfully. Reconnected.", err=True, fg='green') + mycli.refresh_completions() + except Exception as e: + mycli.sandbox_mode = False + mycli.echo( + f"Password changed but reconnection failed: {e}\nPlease restart mycli with your new password.", + err=True, + fg='yellow', + ) + if is_dropping_database(text, sqlexecute.dbname): sqlexecute.dbname = None sqlexecute.connect() @@ -756,7 +835,7 @@ def main_repl(mycli: 'MyCli') -> None: state = ReplState() mycli.configure_pager() - if mycli.smart_completion: + if mycli.smart_completion and not mycli.sandbox_mode: mycli.refresh_completions() history = _create_history(mycli) diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 40b933a5..b045a4c6 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -14,6 +14,7 @@ from pymysql.converters import conversions, convert_date, convert_datetime, convert_time, decoders from pymysql.cursors import Cursor +from mycli.constants import ER_MUST_CHANGE_PASSWORD from mycli.packages.special import iocommands from mycli.packages.special.main import CommandNotFound, execute from mycli.packages.sqlresult import SQLResult @@ -280,32 +281,50 @@ def connect( client_flag = pymysql.constants.CLIENT.INTERACTIVE if init_command and len(list(iocommands.split_queries(init_command))) > 1: client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS + client_flag |= pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS ssl_context = None if ssl: ssl_context = self._create_ssl_ctx(ssl) - conn = pymysql.connect( - database=db, - user=user, - password=password or '', - host=host, - port=port or 0, - unix_socket=socket, - use_unicode=True, - charset=character_set or '', - autocommit=True, - client_flag=client_flag, - local_infile=local_infile or False, - conv=conv, - ssl=ssl_context, # type: ignore[arg-type] - program_name="mycli", - defer_connect=defer_connect, - init_command=init_command or None, - cursorclass=pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor, - ) # type: ignore[misc] + connect_kwargs: dict[str, Any] = { + "database": db, + "user": user, + "password": password or '', + "host": host, + "port": port or 0, + "unix_socket": socket, + "use_unicode": True, + "charset": character_set or '', + "autocommit": True, + "client_flag": client_flag, + "local_infile": local_infile or False, + "conv": conv, + "ssl": ssl_context, # type: ignore[arg-type] + "program_name": "mycli", + "defer_connect": defer_connect, + "init_command": init_command or None, + "cursorclass": pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor, + } + + self.sandbox_mode = False + try: + conn = pymysql.connect(**connect_kwargs) # type: ignore[misc] + except pymysql.OperationalError as e: + if e.args[0] == ER_MUST_CHANGE_PASSWORD: + # Post-handshake queries (SET NAMES, SET AUTOCOMMIT, init_command) + # fail with ER_MUST_CHANGE_PASSWORD in sandbox mode. + # Reconnect with only the raw handshake. + connect_kwargs['defer_connect'] = True + connect_kwargs['autocommit'] = None + connect_kwargs['init_command'] = None + conn = pymysql.connect(**connect_kwargs) # type: ignore[misc] + self._connect_sandbox(conn) + self.sandbox_mode = True + else: + raise - if ssh_host: + if ssh_host and not self.sandbox_mode: ##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel ##### # instead let's open a tunnel and rewrite host:port to local bind @@ -343,9 +362,10 @@ def connect( self.ssl = ssl self.init_command = init_command self.unbuffered = unbuffered - # retrieve connection id - self.reset_connection_id() - self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined] + # retrieve connection id (skip in sandbox mode as queries will fail) + if not self.sandbox_mode: + self.reset_connection_id() + self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined] def run(self, statement: str) -> Generator[SQLResult, None, None]: """Execute the sql in the database and return the results.""" @@ -576,6 +596,24 @@ def change_db(self, db: str) -> None: self.conn.select_db(db) self.dbname = db + @staticmethod + def _connect_sandbox(conn: Connection) -> None: + """Connect in sandbox mode, performing only the handshake. + + pymysql's normal connect() runs post-handshake queries (SET NAMES, + SET AUTOCOMMIT, init_command) that all fail with ER_MUST_CHANGE_PASSWORD + in sandbox mode. This method performs the raw socket connection and + authentication handshake only. + """ + # Reuse pymysql internals for the handshake + auth, but + # temporarily stub out set_character_set so it becomes a no-op. + original_set_charset = conn.set_character_set + conn.set_character_set = lambda *_args, **_kwargs: None # type: ignore[assignment] + try: + conn.connect() + finally: + conn.set_character_set = original_set_charset # type: ignore[assignment] + def _create_ssl_ctx(self, sslp: dict) -> ssl.SSLContext: ca = sslp.get("ca") capath = sslp.get("capath") diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py index f67867cc..ba3ea163 100644 --- a/test/pytests/test_main_modes_repl.py +++ b/test/pytests/test_main_modes_repl.py @@ -174,6 +174,7 @@ def make_repl_cli(sqlexecute: Any | None = None) -> Any: cli.cli_style = {} cli.emacs_ttimeoutlen = 1.0 cli.vi_ttimeoutlen = 2.0 + cli.sandbox_mode = False cli.destructive_warning = False cli.destructive_keywords = ['drop'] cli.llm_prompt_field_truncate = 0 @@ -796,6 +797,127 @@ def run(self, text: str) -> Iterator[SQLResult]: assert cli_quiet.output_calls[0][0] == ['None', 'ran:select 2'] +@pytest.mark.parametrize( + 'text, expected', + [ + ('', True), + (' ', True), + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), + ('alter user root identified by "pw"', True), + ("SET PASSWORD = 'newpass'", True), + ("set password = 'newpass'", True), + ('quit', True), + ('exit', True), + ('\\q', True), + ('SELECT 1', False), + ('DROP TABLE t', False), + ('USE mydb', False), + ('SHOW DATABASES', False), + ], +) +def test_is_sandbox_allowed(text: str, expected: bool) -> None: + assert repl_mode._is_sandbox_allowed(text) is expected + + +@pytest.mark.parametrize( + 'text, expected', + [ + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), + ("SET PASSWORD = 'newpass'", True), + ('SELECT 1', False), + ('quit', False), + ], +) +def test_is_password_change(text: str, expected: bool) -> None: + assert repl_mode._is_password_change(text) is expected + + +@pytest.mark.parametrize( + 'text, expected', + [ + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'", 'newpass'), + ("SET PASSWORD = 'secret123'", 'secret123'), + ("ALTER USER root IDENTIFIED BY 'p@ss w0rd!'", 'p@ss w0rd!'), + ('ALTER USER root IDENTIFIED WITH mysql_native_password', None), + ('SELECT 1', None), + ], +) +def test_extract_new_password(text: str, expected: str | None) -> None: + assert repl_mode._extract_new_password(text) == expected + + +def test_one_iteration_blocks_disallowed_in_sandbox_mode(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status=f'ran:{text}')]) + + cli = make_repl_cli(FakeSQLExecute()) + cli.sandbox_mode = True + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'SELECT 1') + assert any('ERROR 1820' in msg for msg in cli.echo_calls) + assert not cli.query_history + + +def test_one_iteration_allows_alter_user_in_sandbox_mode(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + self.password = 'old' + self.connect_calls: list[bool] = [] + + def connect(self) -> None: + self.connect_calls.append(True) + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status='OK')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.sandbox_mode = True + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + repl_mode._one_iteration(cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'") + assert cli.sandbox_mode is False + assert sqlexecute.password == 'newpass' + assert sqlexecute.connect_calls == [True] + assert any('Reconnected' in msg for msg in cli.echo_calls) + + +def test_one_iteration_sandbox_reconnect_failure(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + self.password = 'old' + + def connect(self) -> None: + raise RuntimeError('connection refused') + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status='OK')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.sandbox_mode = True + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + repl_mode._one_iteration(cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'") + assert cli.sandbox_mode is False + assert any('reconnection failed' in msg for msg in cli.echo_calls) + + def test_one_iteration_covers_redirect_destructive_success_refresh_and_logfile(monkeypatch: pytest.MonkeyPatch) -> None: patch_repl_runtime_defaults(monkeypatch) diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index 5946a58a..25c3ebf3 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -92,6 +92,7 @@ def __init__(self, **kwargs: Any) -> None: self.dbname = kwargs.get('database') self.user = kwargs.get('user') self.conn = kwargs.get('conn') + self.sandbox_mode = False class ToggleBool: diff --git a/test/pytests/test_sqlexecute.py b/test/pytests/test_sqlexecute.py index 5155cb9a..e250b154 100644 --- a/test/pytests/test_sqlexecute.py +++ b/test/pytests/test_sqlexecute.py @@ -673,6 +673,7 @@ def make_executor_for_connect_tests() -> SQLExecute: executor.ssh_key_filename = '/stored/key.pem' executor.init_command = 'select 1' executor.unbuffered = False + executor.sandbox_mode = False executor.conn = None return executor @@ -762,6 +763,56 @@ def fake_reset_connection_id(self) -> None: assert executor.server_info.version == 80036 +def test_connect_sets_expired_password_flag(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.ssl = None + + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + connect_kwargs = {} + + def fake_connect(**kwargs): + connect_kwargs.update(kwargs) + return new_conn + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + monkeypatch.setattr(SQLExecute, 'reset_connection_id', lambda self: None) + + executor.connect() + + assert connect_kwargs['client_flag'] & sqlexecute.pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS + assert executor.sandbox_mode is False + + +def test_connect_falls_back_to_sandbox_on_1820(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.ssl = None + + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + call_count = 0 + sandbox_calls = [] + + def fake_connect(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise pymysql.OperationalError(1820, 'must change password') + return new_conn + + def fake_connect_sandbox(self, conn): + sandbox_calls.append(conn) + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + monkeypatch.setattr(SQLExecute, '_connect_sandbox', fake_connect_sandbox) + + executor.connect() + + assert call_count == 2 + assert len(sandbox_calls) == 1 + assert executor.sandbox_mode is True + assert executor.server_info is None + assert executor.connection_id is None + + def test_connect_uses_ssh_tunnel_when_ssh_host_is_set(monkeypatch) -> None: executor = make_executor_for_connect_tests() executor.ssl = None