Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Features
* Make `--progress` and `--checkpoint` strictly by statement.
* Allow more characters in passwords read from a file.
* Show sponsors and contributors separately in startup messages.
* Add support for expired password (sandbox) mode (#440)


Bug Fixes
Expand Down
4 changes: 4 additions & 0 deletions mycli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@

DEFAULT_WIDTH = 80
DEFAULT_HEIGHT = 25

# MySQL error codes not available in pymysql.constants.ER
ER_MUST_CHANGE_PASSWORD_LOGIN = 1862
ER_MUST_CHANGE_PASSWORD = 1820
18 changes: 18 additions & 0 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
DEFAULT_HOST,
DEFAULT_PORT,
DEFAULT_WIDTH,
ER_MUST_CHANGE_PASSWORD_LOGIN,
ISSUES_URL,
REPO_URL,
)
Expand Down Expand Up @@ -152,6 +153,7 @@ def __init__(
self.prompt_session: PromptSession | None = None
self._keepalive_counter = 0
self.keepalive_ticks: int | None = 0
self.sandbox_mode: bool = False

# self.cnf_files is a class variable that stores the list of mysql
# config files to read in at launch.
Expand Down Expand Up @@ -750,6 +752,13 @@ def _connect(
keyring_retrieved_cleanly=keyring_retrieved_cleanly,
keyring_save_eligible=keyring_save_eligible,
)
elif e1.args[0] == ER_MUST_CHANGE_PASSWORD_LOGIN:
self.echo(
"Your password has expired and the server rejected the connection.",
err=True,
fg='red',
)
raise e1
elif e1.args[0] == CR_SERVER_LOST:
self.echo(
(
Expand Down Expand Up @@ -803,6 +812,15 @@ def _connect(
sys.exit(1)

_connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly)

# Check if SQLExecute detected sandbox mode during connection
if self.sqlexecute and self.sqlexecute.sandbox_mode:
self.sandbox_mode = True
self.echo(
"Your password has expired. Use ALTER USER to set a new password, or quit.",
err=True,
fg='yellow',
)
except Exception as e: # Connecting to a database could fail.
self.logger.debug("Database connection failed: %r.", e)
self.logger.error("traceback: %r", traceback.format_exc())
Expand Down
99 changes: 89 additions & 10 deletions mycli/main_modes/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from mycli.constants import (
DEFAULT_HOST,
DEFAULT_WIDTH,
ER_MUST_CHANGE_PASSWORD,
HOME_URL,
ISSUES_URL,
)
Expand Down Expand Up @@ -132,7 +133,8 @@ def _show_startup_banner(
if mycli.less_chatty:
return

print(sqlexecute.server_info)
if sqlexecute.server_info is not None:
print(sqlexecute.server_info)
print('mycli', mycli_package.__version__)
print(SUPPORT_INFO)
if random.random() <= 0.25:
Expand Down Expand Up @@ -232,8 +234,6 @@ def get_prompt(
) -> str:
sqlexecute = mycli.sqlexecute
assert sqlexecute is not None
assert sqlexecute.server_info is not None
assert sqlexecute.server_info.species is not None
if mycli.login_path and mycli.login_path_as_host:
prompt_host = mycli.login_path
elif sqlexecute.host is not None:
Expand All @@ -250,7 +250,8 @@ def get_prompt(
string = string.replace('\\h', prompt_host or '(none)')
string = string.replace('\\H', short_prompt_host or '(none)')
string = string.replace('\\d', sqlexecute.dbname or '(none)')
string = string.replace('\\t', sqlexecute.server_info.species.name)
species_name = sqlexecute.server_info.species.name if sqlexecute.server_info and sqlexecute.server_info.species else 'MySQL'
string = string.replace('\\t', species_name)
string = string.replace('\\n', '\n')
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
string = string.replace('\\m', now.strftime('%M'))
Expand Down Expand Up @@ -518,6 +519,52 @@ def _build_prompt_session(
mycli.prompt_session.app.ttimeoutlen = mycli.emacs_ttimeoutlen


_SANDBOX_ALLOWED_RE = re.compile(
r'^\s*(ALTER\s+USER|SET\s+PASSWORD|QUIT|EXIT|\\q)\b',
re.IGNORECASE,
)

_PASSWORD_CHANGE_RE = re.compile(
r'^\s*(ALTER\s+USER|SET\s+PASSWORD)\b',
re.IGNORECASE,
)


def _is_sandbox_allowed(text: str) -> bool:
"""Return True if the command is allowed in expired-password sandbox mode."""
stripped = text.strip()
if not stripped:
return True
return bool(_SANDBOX_ALLOWED_RE.match(stripped))


def _is_password_change(text: str) -> bool:
"""Return True if the command is a password change statement."""
return bool(_PASSWORD_CHANGE_RE.match(text.strip()))


_IDENTIFIED_BY_RE = re.compile(
r"IDENTIFIED\s+BY\s+'([^']*)'",
re.IGNORECASE,
)

_SET_PASSWORD_RE = re.compile(
r"SET\s+PASSWORD\s*=\s*'([^']*)'",
re.IGNORECASE,
)


def _extract_new_password(text: str) -> str | None:
"""Extract the new password from an ALTER USER or SET PASSWORD statement."""
m = _IDENTIFIED_BY_RE.search(text)
if m:
return m.group(1)
m = _SET_PASSWORD_RE.search(text)
if m:
return m.group(1)
return None


def _one_iteration(
mycli: 'MyCli',
state: ReplState,
Expand Down Expand Up @@ -615,6 +662,14 @@ def _one_iteration(
mycli.echo(str(e), err=True, fg='red')
return

if mycli.sandbox_mode and not _is_sandbox_allowed(text):
mycli.echo(
"ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.",
err=True,
fg='red',
)
return

if mycli.destructive_warning:
destroy = confirm_destructive_query(mycli.destructive_keywords, text)
if destroy is None:
Expand Down Expand Up @@ -674,20 +729,44 @@ def _one_iteration(
mycli.echo('Not Yet Implemented.', fg='yellow')
except pymysql.OperationalError as e1:
mycli.logger.debug('Exception: %r', e1)
if e1.args[0] in (2003, 2006, 2013):
if e1.args[0] == ER_MUST_CHANGE_PASSWORD:
mycli.sandbox_mode = True
mycli.echo(
"ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.",
err=True,
fg='red',
)
elif e1.args[0] in (2003, 2006, 2013):
if not mycli.reconnect():
return
_one_iteration(mycli, state, text)
return

mycli.logger.error('sql: %r, error: %r', text, e1)
mycli.logger.error('traceback: %r', traceback.format_exc())
mycli.echo(str(e1), err=True, fg='red')
else:
mycli.logger.error('sql: %r, error: %r', text, e1)
mycli.logger.error('traceback: %r', traceback.format_exc())
mycli.echo(str(e1), err=True, fg='red')
except Exception as e:
mycli.logger.error('sql: %r, error: %r', text, e)
mycli.logger.error('traceback: %r', traceback.format_exc())
mycli.echo(str(e), err=True, fg='red')
else:
if mycli.sandbox_mode and _is_password_change(text):
new_password = _extract_new_password(text)
if new_password is not None:
sqlexecute.password = new_password
try:
sqlexecute.connect()
mycli.sandbox_mode = False
mycli.echo("Password changed successfully. Reconnected.", err=True, fg='green')
mycli.refresh_completions()
except Exception as e:
mycli.sandbox_mode = False
mycli.echo(
f"Password changed but reconnection failed: {e}\nPlease restart mycli with your new password.",
err=True,
fg='yellow',
)

if is_dropping_database(text, sqlexecute.dbname):
sqlexecute.dbname = None
sqlexecute.connect()
Expand Down Expand Up @@ -756,7 +835,7 @@ def main_repl(mycli: 'MyCli') -> None:
state = ReplState()

mycli.configure_pager()
if mycli.smart_completion:
if mycli.smart_completion and not mycli.sandbox_mode:
mycli.refresh_completions()

history = _create_history(mycli)
Expand Down
84 changes: 61 additions & 23 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pymysql.converters import conversions, convert_date, convert_datetime, convert_time, decoders
from pymysql.cursors import Cursor

from mycli.constants import ER_MUST_CHANGE_PASSWORD
from mycli.packages.special import iocommands
from mycli.packages.special.main import CommandNotFound, execute
from mycli.packages.sqlresult import SQLResult
Expand Down Expand Up @@ -280,32 +281,50 @@ def connect(
client_flag = pymysql.constants.CLIENT.INTERACTIVE
if init_command and len(list(iocommands.split_queries(init_command))) > 1:
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
client_flag |= pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS

ssl_context = None
if ssl:
ssl_context = self._create_ssl_ctx(ssl)

conn = pymysql.connect(
database=db,
user=user,
password=password or '',
host=host,
port=port or 0,
unix_socket=socket,
use_unicode=True,
charset=character_set or '',
autocommit=True,
client_flag=client_flag,
local_infile=local_infile or False,
conv=conv,
ssl=ssl_context, # type: ignore[arg-type]
program_name="mycli",
defer_connect=defer_connect,
init_command=init_command or None,
cursorclass=pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor,
) # type: ignore[misc]
connect_kwargs: dict[str, Any] = {
"database": db,
"user": user,
"password": password or '',
"host": host,
"port": port or 0,
"unix_socket": socket,
"use_unicode": True,
"charset": character_set or '',
"autocommit": True,
"client_flag": client_flag,
"local_infile": local_infile or False,
"conv": conv,
"ssl": ssl_context, # type: ignore[arg-type]
"program_name": "mycli",
"defer_connect": defer_connect,
"init_command": init_command or None,
"cursorclass": pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor,
}

self.sandbox_mode = False
try:
conn = pymysql.connect(**connect_kwargs) # type: ignore[misc]
except pymysql.OperationalError as e:
if e.args[0] == ER_MUST_CHANGE_PASSWORD:
# Post-handshake queries (SET NAMES, SET AUTOCOMMIT, init_command)
# fail with ER_MUST_CHANGE_PASSWORD in sandbox mode.
# Reconnect with only the raw handshake.
connect_kwargs['defer_connect'] = True
connect_kwargs['autocommit'] = None
connect_kwargs['init_command'] = None
conn = pymysql.connect(**connect_kwargs) # type: ignore[misc]
self._connect_sandbox(conn)
self.sandbox_mode = True
else:
raise

if ssh_host:
if ssh_host and not self.sandbox_mode:
##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel
#####
# instead let's open a tunnel and rewrite host:port to local bind
Expand Down Expand Up @@ -343,9 +362,10 @@ def connect(
self.ssl = ssl
self.init_command = init_command
self.unbuffered = unbuffered
# retrieve connection id
self.reset_connection_id()
self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined]
# retrieve connection id (skip in sandbox mode as queries will fail)
if not self.sandbox_mode:
self.reset_connection_id()
self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined]

def run(self, statement: str) -> Generator[SQLResult, None, None]:
"""Execute the sql in the database and return the results."""
Expand Down Expand Up @@ -576,6 +596,24 @@ def change_db(self, db: str) -> None:
self.conn.select_db(db)
self.dbname = db

@staticmethod
def _connect_sandbox(conn: Connection) -> None:
"""Connect in sandbox mode, performing only the handshake.

pymysql's normal connect() runs post-handshake queries (SET NAMES,
SET AUTOCOMMIT, init_command) that all fail with ER_MUST_CHANGE_PASSWORD
in sandbox mode. This method performs the raw socket connection and
authentication handshake only.
"""
# Reuse pymysql internals for the handshake + auth, but
# temporarily stub out set_character_set so it becomes a no-op.
original_set_charset = conn.set_character_set
conn.set_character_set = lambda *_args, **_kwargs: None # type: ignore[assignment]
try:
conn.connect()
finally:
conn.set_character_set = original_set_charset # type: ignore[assignment]

def _create_ssl_ctx(self, sslp: dict) -> ssl.SSLContext:
ca = sslp.get("ca")
capath = sslp.get("capath")
Expand Down
Loading
Loading