From 6f96e2582d5a88602fad031759a3bfe40abf120d Mon Sep 17 00:00:00 2001 From: Jordan Selig Date: Thu, 26 Mar 2026 11:29:25 -0400 Subject: [PATCH 1/2] [App Service] Fix #8831, #13662, #13008: `az webapp ssh`: tunnel reliability and instance targeting - Add WebSocket retry logic with exponential backoff for tunnel connections (#8831) - Add keepalive pings to prevent idle disconnects (#8831) - Register signal handlers (SIGINT/SIGTERM) and atexit for clean tunnel shutdown (#13662) - Wrap main loops in try/finally to ensure cleanup on KeyboardInterrupt (#13662) - Add close() method to TunnelServer for deterministic resource cleanup (#13662) - Improve --instance and --timeout parameter help text (#13008) - Add --instance examples to help text for webapp ssh and create-remote-connection (#13008) - Add unit tests for tunnel retry, keepalive, close, and signal registration Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cli/command_modules/appservice/_help.py | 5 + .../cli/command_modules/appservice/_params.py | 16 +- .../cli/command_modules/appservice/custom.py | 53 ++++-- .../latest/test_webapp_commands_thru_mock.py | 169 ++++++++++++++++++ .../cli/command_modules/appservice/tunnel.py | 129 +++++++++++-- 5 files changed, 341 insertions(+), 31 deletions(-) diff --git a/src/azure-cli/azure/cli/command_modules/appservice/_help.py b/src/azure-cli/azure/cli/command_modules/appservice/_help.py index e0dad92e98c..cd78906c246 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/_help.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/_help.py @@ -1918,6 +1918,8 @@ examples: - name: Create a remote connection using a tcp tunnel to your web app text: az webapp create-remote-connection --name MyWebApp --resource-group MyResourceGroup + - name: Create a remote connection to a specific instance + text: az webapp create-remote-connection --name MyWebApp --resource-group MyResourceGroup --instance 89c07485c4742abcde3f0e19ea4402a06e3b48145ed81e6468066f10a78074b1 """ helps['webapp delete'] = """ @@ -2394,6 +2396,9 @@ - name: ssh into a web app text: > az webapp ssh -n MyUniqueAppName -g MyResourceGroup + - name: ssh into a specific instance of a web app + text: > + az webapp ssh -n MyUniqueAppName -g MyResourceGroup --instance 89c07485c4742abcde3f0e19ea4402a06e3b48145ed81e6468066f10a78074b1 """ helps['webapp start'] = """ diff --git a/src/azure-cli/azure/cli/command_modules/appservice/_params.py b/src/azure-cli/azure/cli/command_modules/appservice/_params.py index bd4dc0b7ea5..5cab2bf7169 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/_params.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/_params.py @@ -977,14 +977,22 @@ def load_arguments(self, _): with self.argument_context('webapp ssh') as c: c.argument('port', options_list=['--port', '-p'], help='Port for the remote connection. Default: Random available port', type=int) - c.argument('timeout', options_list=['--timeout', '-t'], help='timeout in seconds. Defaults to none', type=int) - c.argument('instance', options_list=['--instance', '-i'], help='Webapp instance to connect to. Defaults to none.') + c.argument('timeout', options_list=['--timeout', '-t'], + help='Timeout in seconds. The tunnel will automatically close after this duration. ' + 'Defaults to none (keep open until manually closed).', type=int) + c.argument('instance', options_list=['--instance', '-i'], + help='Webapp instance to connect to. Use `az webapp list-instances` to get available instances. ' + 'If not specified, connects to an arbitrary instance.') with self.argument_context('webapp create-remote-connection') as c: c.argument('port', options_list=['--port', '-p'], help='Port for the remote connection. Default: Random available port', type=int) - c.argument('timeout', options_list=['--timeout', '-t'], help='timeout in seconds. Defaults to none', type=int) - c.argument('instance', options_list=['--instance', '-i'], help='Webapp instance to connect to. Defaults to none.') + c.argument('timeout', options_list=['--timeout', '-t'], + help='Timeout in seconds. The tunnel will automatically close after this duration. ' + 'Defaults to none (keep open until manually closed).', type=int) + c.argument('instance', options_list=['--instance', '-i'], + help='Webapp instance to connect to. Use `az webapp list-instances` to get available instances. ' + 'If not specified, connects to an arbitrary instance.') with self.argument_context('webapp vnet-integration') as c: c.argument('name', arg_type=webapp_name_arg_type, id_part=None) diff --git a/src/azure-cli/azure/cli/command_modules/appservice/custom.py b/src/azure-cli/azure/cli/command_modules/appservice/custom.py index 386ba088608..be62ec699f0 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/custom.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/custom.py @@ -9407,6 +9407,8 @@ def get_tunnel(cmd, resource_group_name, name, port=None, slot=None, instance=No def create_tunnel(cmd, resource_group_name, name, port=None, slot=None, timeout=None, instance=None): tunnel_server = get_tunnel(cmd, resource_group_name, name, port, slot, instance) + _register_tunnel_cleanup(tunnel_server) + t = threading.Thread(target=_start_tunnel, args=(tunnel_server,)) t.daemon = True t.start() @@ -9425,16 +9427,23 @@ def create_tunnel(cmd, resource_group_name, name, port=None, slot=None, timeout= logger.warning('Ctrl + C to close') - if timeout: - time.sleep(int(timeout)) - else: - while t.is_alive(): - time.sleep(5) + try: + if timeout: + time.sleep(int(timeout)) + else: + while t.is_alive(): + time.sleep(5) + except KeyboardInterrupt: + logger.warning('Shutting down tunnel...') + finally: + tunnel_server.close() def create_tunnel_and_session(cmd, resource_group_name, name, port=None, slot=None, timeout=None, instance=None): tunnel_server = get_tunnel(cmd, resource_group_name, name, port, slot, instance) + _register_tunnel_cleanup(tunnel_server) + t = threading.Thread(target=_start_tunnel, args=(tunnel_server,)) t.daemon = True t.start() @@ -9447,11 +9456,16 @@ def create_tunnel_and_session(cmd, resource_group_name, name, port=None, slot=No s.daemon = True s.start() - if timeout: - time.sleep(int(timeout)) - else: - while s.is_alive() and t.is_alive(): - time.sleep(5) + try: + if timeout: + time.sleep(int(timeout)) + else: + while s.is_alive() and t.is_alive(): + time.sleep(5) + except KeyboardInterrupt: + logger.warning('Shutting down tunnel...') + finally: + tunnel_server.close() def perform_onedeploy_functionapp(cmd, @@ -9918,6 +9932,25 @@ def _start_tunnel(tunnel_server): tunnel_server.start_server() +def _register_tunnel_cleanup(tunnel_server): + """Register signal handlers and atexit to ensure the tunnel is cleaned up.""" + import atexit + import signal + + def _cleanup(): + tunnel_server.close() + + atexit.register(_cleanup) + + def _signal_handler(signum, frame): # pylint: disable=unused-argument + logger.warning('Received signal %s, shutting down tunnel...', signum) + tunnel_server.close() + sys.exit(0) + + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + + def _start_ssh_session(hostname, port, username, password): tries = 0 while True: diff --git a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py index 853eadc1edd..be521127689 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py @@ -34,6 +34,7 @@ update_app_settings, update_application_settings_polling, update_webapp) +from azure.cli.command_modules.appservice.tunnel import TunnelServer # pylint: disable=line-too-long from azure.cli.core.profiles import ResourceType @@ -639,6 +640,174 @@ def test_update_webapp_platform_release_channel_latest(self): self.assertEqual(result.additional_properties["properties"]["platformReleaseChannel"], "Latest") +class TestTunnelServer(unittest.TestCase): + """Tests for TunnelServer reliability and cleanup improvements.""" + + @mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket') + def test_tunnel_server_close_sets_closing_event(self, mock_socket_cls): + mock_sock = mock.MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.getsockname.return_value = ('127.0.0.1', 12345) + server = TunnelServer.__new__(TunnelServer) + server.local_addr = '127.0.0.1' + server.local_port = 0 + server.remote_addr = 'testapp.scm.azurewebsites.net' + server.auth_string = 'Basic dGVzdDp0ZXN0' + server.instance = None + server.client = None + server.ws = None + from threading import Event + server._closing = Event() + server.sock = mock_sock + + server.close() + self.assertTrue(server._closing.is_set()) + mock_sock.close.assert_called_once() + + @mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket') + def test_tunnel_server_close_is_idempotent(self, mock_socket_cls): + mock_sock = mock.MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.getsockname.return_value = ('127.0.0.1', 12345) + server = TunnelServer.__new__(TunnelServer) + server.local_addr = '127.0.0.1' + server.local_port = 0 + server.remote_addr = 'testapp.scm.azurewebsites.net' + server.auth_string = 'Basic dGVzdDp0ZXN0' + server.instance = None + server.client = None + server.ws = None + from threading import Event + server._closing = Event() + server.sock = mock_sock + + server.close() + server.close() + # Socket.close should only be called once + mock_sock.close.assert_called_once() + + @mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket') + def test_tunnel_server_close_handles_ws_and_client(self, mock_socket_cls): + mock_sock = mock.MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.getsockname.return_value = ('127.0.0.1', 12345) + server = TunnelServer.__new__(TunnelServer) + server.local_addr = '127.0.0.1' + server.local_port = 0 + server.remote_addr = 'testapp.scm.azurewebsites.net' + server.auth_string = 'Basic dGVzdDp0ZXN0' + server.instance = None + server.client = mock.MagicMock() + server.ws = mock.MagicMock() + from threading import Event + server._closing = Event() + server.sock = mock_sock + + server.close() + server.ws.close.assert_called_once() + server.client.close.assert_called_once() + mock_sock.close.assert_called_once() + + @mock.patch('azure.cli.command_modules.appservice.tunnel.create_connection') + @mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket') + def test_create_websocket_connection_retries_on_failure(self, mock_socket_cls, mock_create_conn): + mock_sock = mock.MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.getsockname.return_value = ('127.0.0.1', 12345) + server = TunnelServer.__new__(TunnelServer) + server.local_addr = '127.0.0.1' + server.local_port = 0 + server.remote_addr = 'testapp.scm.azurewebsites.net' + server.auth_string = 'Basic dGVzdDp0ZXN0' + server.instance = None + server.client = None + server.ws = None + from threading import Event + server._closing = Event() + server.sock = mock_sock + + mock_ws = mock.MagicMock() + # Fail twice, succeed on third + mock_create_conn.side_effect = [ConnectionError("fail1"), ConnectionError("fail2"), mock_ws] + + with mock.patch('azure.cli.command_modules.appservice.tunnel.time.sleep'): + result = server._create_websocket_connection( + 'wss://test/Tunnel.ashx', ['Authorization: Basic dGVzdDp0ZXN0'], 0) + + self.assertEqual(result, mock_ws) + self.assertEqual(mock_create_conn.call_count, 3) + + @mock.patch('azure.cli.command_modules.appservice.tunnel.create_connection') + @mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket') + def test_create_websocket_connection_raises_after_max_retries(self, mock_socket_cls, mock_create_conn): + mock_sock = mock.MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.getsockname.return_value = ('127.0.0.1', 12345) + server = TunnelServer.__new__(TunnelServer) + server.local_addr = '127.0.0.1' + server.local_port = 0 + server.remote_addr = 'testapp.scm.azurewebsites.net' + server.auth_string = 'Basic dGVzdDp0ZXN0' + server.instance = None + server.client = None + server.ws = None + from threading import Event + server._closing = Event() + server.sock = mock_sock + + mock_create_conn.side_effect = ConnectionError("always fail") + + with mock.patch('azure.cli.command_modules.appservice.tunnel.time.sleep'): + with self.assertRaises(CLIError) as ctx: + server._create_websocket_connection( + 'wss://test/Tunnel.ashx', ['Authorization: Basic dGVzdDp0ZXN0'], 0) + + self.assertIn('Failed to establish WebSocket tunnel connection', str(ctx.exception)) + + @mock.patch('azure.cli.command_modules.appservice.tunnel.socket.socket') + def test_keepalive_ping_stops_on_event(self, mock_socket_cls): + from threading import Event + mock_sock = mock.MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.getsockname.return_value = ('127.0.0.1', 12345) + server = TunnelServer.__new__(TunnelServer) + server.local_addr = '127.0.0.1' + server.local_port = 0 + server.remote_addr = 'testapp.scm.azurewebsites.net' + server.auth_string = 'Basic dGVzdDp0ZXN0' + server.instance = None + server.client = None + server.ws = None + server._closing = Event() + server.sock = mock_sock + + mock_ws = mock.MagicMock() + mock_ws.connected = True + stop_event = Event() + # Signal stop immediately so the keepalive loop runs once at most + stop_event.set() + server._send_keepalive_pings(mock_ws, 1, stop_event) + # Should not crash; ws.ping may or may not have been called depending on timing + + +class TestTunnelSignalCleanup(unittest.TestCase): + """Tests for signal handler registration and cleanup in create_tunnel / create_tunnel_and_session.""" + + @mock.patch('signal.signal') + @mock.patch('atexit.register') + def test_register_tunnel_cleanup_registers_handlers(self, mock_atexit, mock_signal): + from azure.cli.command_modules.appservice.custom import _register_tunnel_cleanup + import signal + + mock_tunnel = mock.MagicMock() + _register_tunnel_cleanup(mock_tunnel) + + mock_atexit.assert_called_once() + signal_calls = {call[0][0] for call in mock_signal.call_args_list} + self.assertIn(signal.SIGINT, signal_calls) + self.assertIn(signal.SIGTERM, signal_calls) + + class FakedResponse: # pylint: disable=too-few-public-methods def __init__(self, status_code): self.status_code = status_code diff --git a/src/azure-cli/azure/cli/command_modules/appservice/tunnel.py b/src/azure-cli/azure/cli/command_modules/appservice/tunnel.py index f4059075416..102f6be441e 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/tunnel.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/tunnel.py @@ -13,7 +13,7 @@ import logging as logs from contextlib import closing from datetime import datetime -from threading import Thread +from threading import Thread, Event import websocket from websocket import create_connection, WebSocket @@ -24,6 +24,12 @@ from knack.log import get_logger logger = get_logger(__name__) +# Retry / keepalive constants +_MAX_RECONNECT_ATTEMPTS = 5 +_INITIAL_RECONNECT_DELAY = 1 # seconds +_MAX_RECONNECT_DELAY = 30 # seconds +_KEEPALIVE_INTERVAL = 30 # seconds between WebSocket pings + class TunnelWebSocket(WebSocket): def recv_frame(self): @@ -52,6 +58,7 @@ def __init__(self, local_addr, local_port, remote_addr, auth_string, instance): self.instance = instance self.client = None self.ws = None + self._closing = Event() logger.info('Creating a socket on port: %s', self.local_port) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) logger.info('Setting socket options') @@ -117,12 +124,48 @@ def is_webapp_up(self): logger.warning('Waiting for app to start up... ') return False + def _create_websocket_connection(self, host, basic_auth_header, verify_mode): + """Create a WebSocket connection with retry logic and exponential backoff.""" + delay = _INITIAL_RECONNECT_DELAY + for attempt in range(1, _MAX_RECONNECT_ATTEMPTS + 1): + if self._closing.is_set(): + return None + try: + ws = create_connection(host, + sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),), + class_=TunnelWebSocket, + header=basic_auth_header, + sslopt={'cert_reqs': verify_mode}, + timeout=60 * 60, + enable_multithread=True) + logger.info('WebSocket connected on attempt %s', attempt) + return ws + except Exception as ex: # pylint: disable=broad-except + logger.info('WebSocket connection attempt %s/%s failed: %s', + attempt, _MAX_RECONNECT_ATTEMPTS, ex) + if attempt == _MAX_RECONNECT_ATTEMPTS: + raise CLIError( + 'Failed to establish WebSocket tunnel connection after {} attempts. ' + 'Last error: {}'.format(_MAX_RECONNECT_ATTEMPTS, ex)) + logger.warning('Retrying WebSocket connection in %s seconds...', delay) + time.sleep(delay) + delay = min(delay * 2, _MAX_RECONNECT_DELAY) + return None # pragma: no cover + def _listen(self): self.sock.listen(100) index = 0 - while True: - self.client, _address = self.sock.accept() - self.client.settimeout(60 * 60) + while not self._closing.is_set(): + try: + self.sock.settimeout(1.0) + try: + self.client, _address = self.sock.accept() + except socket.timeout: + continue + self.client.settimeout(60 * 60) + except OSError: + # Socket closed during shutdown + break host = 'wss://{}{}'.format(self.remote_addr, '/AppServiceTunnel/Tunnel.ashx') basic_auth_header = [f"Authorization: {self.auth_string}"] if self.instance is not None: @@ -136,30 +179,47 @@ def _listen(self): logger.info('Websocket tracing disabled, use --verbose flag to enable') websocket.enableTrace(False) verify_mode = ssl.CERT_NONE if should_disable_connection_verify() else ssl.CERT_REQUIRED - self.ws = create_connection(host, - sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),), - class_=TunnelWebSocket, - header=basic_auth_header, - sslopt={'cert_reqs': verify_mode}, - timeout=60 * 60, - enable_multithread=True) + self.ws = self._create_websocket_connection(host, basic_auth_header, verify_mode) + if self.ws is None: + break logger.info('Websocket, connected status: %s', self.ws.connected) index = index + 1 logger.info('Got debugger connection... index: %s', index) + keepalive_stop = Event() debugger_thread = Thread(target=self._listen_to_client, args=(self.client, self.ws, index)) web_socket_thread = Thread(target=self._listen_to_web_socket, args=(self.client, self.ws, index)) + keepalive_thread = Thread(target=self._send_keepalive_pings, + args=(self.ws, index, keepalive_stop)) + keepalive_thread.daemon = True debugger_thread.start() web_socket_thread.start() + keepalive_thread.start() logger.info('Both debugger and websocket threads started...') logger.info('Successfully connected to local server..') debugger_thread.join() web_socket_thread.join() + keepalive_stop.set() + keepalive_thread.join(timeout=5) logger.info('Both debugger and websocket threads stopped...') logger.info('Stopped local server..') + def _send_keepalive_pings(self, ws_socket, index, stop_event): + """Periodically send WebSocket pings to prevent idle disconnects.""" + while not stop_event.is_set() and not self._closing.is_set(): + try: + if ws_socket.connected: + ws_socket.ping() + logger.info('Sent keepalive ping, index: %s', index) + else: + break + except Exception as ex: # pylint: disable=broad-except + logger.info('Keepalive ping failed (index %s): %s', index, ex) + break + stop_event.wait(_KEEPALIVE_INTERVAL) + def _listen_to_web_socket(self, client, ws_socket, index): try: - while True: + while not self._closing.is_set(): logger.info('Waiting for websocket data, connection status: %s, index: %s', ws_socket.connected, index) data = ws_socket.recv() logger.info('Received websocket data: %s, index: %s', data, index) @@ -175,12 +235,18 @@ def _listen_to_web_socket(self, client, ws_socket, index): logger.info(ex) finally: logger.info('Client disconnected!, index: %s', index) - client.close() - ws_socket.close() + try: + client.close() + except Exception: # pylint: disable=broad-except + pass + try: + ws_socket.close() + except Exception: # pylint: disable=broad-except + pass def _listen_to_client(self, client, ws_socket, index): try: - while True: + while not self._closing.is_set(): logger.info('Waiting for debugger data, index: %s', index) buf = bytearray(4096) nbytes = client.recv_into(buf, 4096) @@ -197,8 +263,37 @@ def _listen_to_client(self, client, ws_socket, index): logger.warning("Connection Timed Out") finally: logger.info('Client disconnected %s', index) - client.close() - ws_socket.close() + try: + client.close() + except Exception: # pylint: disable=broad-except + pass + try: + ws_socket.close() + except Exception: # pylint: disable=broad-except + pass + + def close(self): + """Cleanly shut down the tunnel server and release all resources.""" + if self._closing.is_set(): + return + self._closing.set() + logger.info('Closing tunnel server...') + if self.ws: + try: + self.ws.close() + except Exception: # pylint: disable=broad-except + pass + if self.client: + try: + self.client.close() + except Exception: # pylint: disable=broad-except + pass + if self.sock: + try: + self.sock.close() + except Exception: # pylint: disable=broad-except + pass + logger.info('Tunnel server closed') def start_server(self): self._listen() From 4617d79a1c08a542bb28765c030a90a9312a0a48 Mon Sep 17 00:00:00 2001 From: Jordan Selig Date: Thu, 26 Mar 2026 15:19:43 -0400 Subject: [PATCH 2/2] Address review feedback: interruptible backoff, no sys.exit in signal handler - Replace time.sleep(delay) with self._closing.wait(delay) in retry backoff so shutdown wakes immediately when closing event is set - Remove sys.exit(0) from signal handler; let close() set the event and the main loop exit naturally - Remove unused 'import time' from tunnel.py - Update tests to mock _closing.wait instead of time.sleep Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/azure-cli/azure/cli/command_modules/appservice/custom.py | 1 - .../appservice/tests/latest/test_webapp_commands_thru_mock.py | 4 ++-- src/azure-cli/azure/cli/command_modules/appservice/tunnel.py | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/azure-cli/azure/cli/command_modules/appservice/custom.py b/src/azure-cli/azure/cli/command_modules/appservice/custom.py index be62ec699f0..a3e0b59748d 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/custom.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/custom.py @@ -9945,7 +9945,6 @@ def _cleanup(): def _signal_handler(signum, frame): # pylint: disable=unused-argument logger.warning('Received signal %s, shutting down tunnel...', signum) tunnel_server.close() - sys.exit(0) signal.signal(signal.SIGINT, _signal_handler) signal.signal(signal.SIGTERM, _signal_handler) diff --git a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py index be521127689..ab243ae7ee4 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py @@ -730,7 +730,7 @@ def test_create_websocket_connection_retries_on_failure(self, mock_socket_cls, m # Fail twice, succeed on third mock_create_conn.side_effect = [ConnectionError("fail1"), ConnectionError("fail2"), mock_ws] - with mock.patch('azure.cli.command_modules.appservice.tunnel.time.sleep'): + with mock.patch.object(server._closing, 'wait'): result = server._create_websocket_connection( 'wss://test/Tunnel.ashx', ['Authorization: Basic dGVzdDp0ZXN0'], 0) @@ -757,7 +757,7 @@ def test_create_websocket_connection_raises_after_max_retries(self, mock_socket_ mock_create_conn.side_effect = ConnectionError("always fail") - with mock.patch('azure.cli.command_modules.appservice.tunnel.time.sleep'): + with mock.patch.object(server._closing, 'wait'): with self.assertRaises(CLIError) as ctx: server._create_websocket_connection( 'wss://test/Tunnel.ashx', ['Authorization: Basic dGVzdDp0ZXN0'], 0) diff --git a/src/azure-cli/azure/cli/command_modules/appservice/tunnel.py b/src/azure-cli/azure/cli/command_modules/appservice/tunnel.py index 102f6be441e..591917b2985 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/tunnel.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/tunnel.py @@ -8,7 +8,6 @@ import ssl import json import socket -import time import traceback import logging as logs from contextlib import closing @@ -148,7 +147,7 @@ def _create_websocket_connection(self, host, basic_auth_header, verify_mode): 'Failed to establish WebSocket tunnel connection after {} attempts. ' 'Last error: {}'.format(_MAX_RECONNECT_ATTEMPTS, ex)) logger.warning('Retrying WebSocket connection in %s seconds...', delay) - time.sleep(delay) + self._closing.wait(delay) delay = min(delay * 2, _MAX_RECONNECT_DELAY) return None # pragma: no cover