diff --git a/workers/proxy_worker/dispatcher.py b/workers/proxy_worker/dispatcher.py index 55aaeb3b..0276a044 100644 --- a/workers/proxy_worker/dispatcher.py +++ b/workers/proxy_worker/dispatcher.py @@ -35,8 +35,22 @@ from proxy_worker.version import VERSION from .utils.dependency import DependencyManager +# Cache protobuf constants to avoid repeated lookups in hot paths +_RpcLog = protos.RpcLog +_LOG_LEVEL_CRITICAL = _RpcLog.Critical +_LOG_LEVEL_ERROR = _RpcLog.Error +_LOG_LEVEL_WARNING = _RpcLog.Warning +_LOG_LEVEL_INFO = _RpcLog.Information +_LOG_LEVEL_DEBUG = _RpcLog.Debug +_LOG_LEVEL_NONE = getattr(_RpcLog, 'None') + +_RpcLogCategory = _RpcLog.RpcLogCategory +_LOG_CATEGORY_SYSTEM = _RpcLogCategory.Value('System') +_LOG_CATEGORY_USER = _RpcLogCategory.Value('User') + # Library worker import reloaded in init and reload request _library_worker = None +_library_worker_has_cv = False # Thread-local invocation ID registry for efficient lookup _thread_invocation_registry: typing.Dict[int, str] = {} @@ -99,9 +113,9 @@ def get_global_current_invocation_id() -> Optional[str]: def get_current_invocation_id() -> Optional[Any]: - global _library_worker + global _library_worker, _library_worker_has_cv # Check global current invocation first (most up-to-date) - if _library_worker and not hasattr(_library_worker, 'invocation_id_cv'): + if _library_worker and not _library_worker_has_cv: global_invocation_id = get_global_current_invocation_id() if global_invocation_id is not None: return global_invocation_id @@ -121,23 +135,22 @@ def get_current_invocation_id() -> Optional[Any]: # No event loop running pass + # Check contextvar from library worker + if _library_worker and _library_worker_has_cv: + try: + cv = _library_worker.invocation_id_cv + val = cv.get() + if val is not None: + return val + except (AttributeError, LookupError): + pass + # Check the thread-local invocation ID registry current_thread_id = threading.get_ident() thread_invocation_id = get_thread_invocation_id(current_thread_id) if thread_invocation_id is not None: return thread_invocation_id - # Check contextvar from library worker - if _library_worker: - try: - cv = getattr(_library_worker, 'invocation_id_cv', None) - if cv: - val = cv.get() - if val is not None: - return val - except (AttributeError, LookupError): - pass - return getattr(_invocation_id_local, 'invocation_id', None) @@ -204,22 +217,22 @@ def __init__(self, loop: AbstractEventLoop, host: str, port: int, def on_logging(self, record: logging.LogRecord, formatted_msg: str) -> None: if record.levelno >= logging.CRITICAL: - log_level = protos.RpcLog.Critical + log_level = _LOG_LEVEL_CRITICAL elif record.levelno >= logging.ERROR: - log_level = protos.RpcLog.Error + log_level = _LOG_LEVEL_ERROR elif record.levelno >= logging.WARNING: - log_level = protos.RpcLog.Warning + log_level = _LOG_LEVEL_WARNING elif record.levelno >= logging.INFO: - log_level = protos.RpcLog.Information + log_level = _LOG_LEVEL_INFO elif record.levelno >= logging.DEBUG: - log_level = protos.RpcLog.Debug + log_level = _LOG_LEVEL_DEBUG else: - log_level = getattr(protos.RpcLog, 'None') + log_level = _LOG_LEVEL_NONE if is_system_log_category(record.name): - log_category = protos.RpcLog.RpcLogCategory.Value('System') + log_category = _LOG_CATEGORY_SYSTEM else: # customers using logging will yield 'root' in record.name - log_category = protos.RpcLog.RpcLogCategory.Value('User') + log_category = _LOG_CATEGORY_USER log = dict( level=log_level, @@ -404,12 +417,13 @@ def stop(self) -> None: @staticmethod def reload_library_worker(directory: str): - global _library_worker + global _library_worker, _library_worker_has_cv v2_scriptfile = os.path.join(directory, get_script_file_name()) if os.path.exists(v2_scriptfile): try: import azure_functions_runtime # NoQA _library_worker = azure_functions_runtime + _library_worker_has_cv = hasattr(_library_worker, 'invocation_id_cv') logger.debug("azure_functions_runtime import succeeded: %s", _library_worker.__file__) except ImportError: @@ -419,6 +433,7 @@ def reload_library_worker(directory: str): try: import azure_functions_runtime_v1 # NoQA _library_worker = azure_functions_runtime_v1 + _library_worker_has_cv = hasattr(_library_worker, 'invocation_id_cv') logger.debug("azure_functions_runtime_v1 import succeeded: %s", _library_worker.__file__) # type: ignore[union-attr] except ImportError: diff --git a/workers/tests/unittest_proxy/test_dispatcher.py b/workers/tests/unittest_proxy/test_dispatcher.py index 976b09a9..eff94de0 100644 --- a/workers/tests/unittest_proxy/test_dispatcher.py +++ b/workers/tests/unittest_proxy/test_dispatcher.py @@ -60,6 +60,9 @@ def test_dispatcher_initialization(self, mock_thread, mock_queue): @patch("proxy_worker.dispatcher.is_system_log_category") def test_on_logging_levels_and_categories(self, mock_is_system, mock_rpc_log, mock_streaming_message): + # Import module to access cached constants + import proxy_worker.dispatcher as dispatcher_module + loop = Mock() dispatcher = Dispatcher(loop, "localhost", 5000, "worker", "req", 5.0) @@ -68,23 +71,34 @@ def test_on_logging_levels_and_categories(self, mock_is_system, mock_rpc_log, mock_streaming_message.return_value = Mock() levels = [ - (logging.CRITICAL, mock_rpc_log.Critical), - (logging.ERROR, mock_rpc_log.Error), - (logging.WARNING, mock_rpc_log.Warning), - (logging.INFO, mock_rpc_log.Information), - (logging.DEBUG, mock_rpc_log.Debug), - (5, getattr(mock_rpc_log, 'None')), + (logging.CRITICAL, dispatcher_module._LOG_LEVEL_CRITICAL), + (logging.ERROR, dispatcher_module._LOG_LEVEL_ERROR), + (logging.WARNING, dispatcher_module._LOG_LEVEL_WARNING), + (logging.INFO, dispatcher_module._LOG_LEVEL_INFO), + (logging.DEBUG, dispatcher_module._LOG_LEVEL_DEBUG), + (5, dispatcher_module._LOG_LEVEL_NONE), ] for level, expected in levels: - record = Mock(levelno=level, name="custom.logger") + record = Mock(levelno=level) + record.name = "custom.logger" mock_is_system.return_value = level % 2 == 0 # alternate True/False dispatcher.on_logging(record, "Test message") + # Determine expected category from cached constants if mock_is_system.return_value: - mock_rpc_log.RpcLogCategory.Value.assert_called_with("System") + expected_category = dispatcher_module._LOG_CATEGORY_SYSTEM else: - mock_rpc_log.RpcLogCategory.Value.assert_called_with("User") + expected_category = dispatcher_module._LOG_CATEGORY_USER + + # Verify RpcLog was initialized with correct mapped values + # We use call_args to verify kwargs, ignoring any extra kwargs + # like invocation_id if present + args, kwargs = mock_rpc_log.call_args + self.assertEqual(kwargs['level'], expected) + self.assertEqual(kwargs['log_category'], expected_category) + self.assertEqual(kwargs['message'], "Test message") + self.assertEqual(kwargs['category'], "custom.logger") def fake_import(name, globals=None, locals=None, fromlist=(), level=0): @@ -96,7 +110,7 @@ def fake_import(name, globals=None, locals=None, fromlist=(), level=0): mock_module.version.VERSION = AsyncMock(return_value="1.0.0") if name in ["azure_functions_runtime", "azure_functions_runtime_v1"]: return mock_module - return builtins.__import__(name, globals, locals, fromlist, level) + return _real_import(name, globals, locals, fromlist, level) @patch("proxy_worker.dispatcher.DependencyManager.should_load_cx_dependencies",