diff --git a/cookbook/client/tinker/self_host/sample.py b/cookbook/client/tinker/self_host/sample.py index 132eb63a..5a2307ec 100644 --- a/cookbook/client/tinker/self_host/sample.py +++ b/cookbook/client/tinker/self_host/sample.py @@ -52,7 +52,6 @@ params = types.SamplingParams( max_tokens=128, # Maximum number of tokens to generate temperature=0.7, - stop=['\n'] # Stop generation when a newline character is produced ) # Step 6: Send the sampling request to the server. diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index 6bf6afce..e31daaba 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -21,11 +21,20 @@ logger = get_logger() +base_model = 'Qwen/Qwen3.5-4B' +base_url = 'http://localhost:8000' +api_key = 'EMPTY_API_KEY' + # Step 2: Initialize the Twinkle client to communicate with the remote server. # - base_url: the address of the running Twinkle server # - api_key: authentication token (loaded from environment variable) -client = init_twinkle_client(base_url='http://127.0.0.1:8000', api_key='EMPTY_TOKEN') +client = init_twinkle_client(base_url=base_url, api_key=api_key) + +# List available models of the server +print('Available models:') +for item in client.get_server_capabilities().supported_models: + print('- ' + item.model_name) # Step 3: Query the server for existing training runs and their checkpoints. # This is useful for resuming a previous training session. 
@@ -50,7 +59,7 @@ def train(): dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500))) # Apply a chat template so the data matches the model's expected input format - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B', max_length=512) + dataset.set_template('Template', model_id=f'ms://{base_model}', max_length=512) # Replace placeholder names in the dataset with custom model/author names dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'}) @@ -64,7 +73,7 @@ def train(): # Step 5: Configure the model # Create a multi-LoRA Transformers model pointing to the base model on ModelScope - model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model = MultiLoraTransformersModel(model_id=f'ms://{base_model}') # Define LoRA configuration: apply low-rank adapters to all linear layers lora_config = LoraConfig(target_modules='all-linear') diff --git a/src/twinkle/server/common/__init__.py b/src/twinkle/server/common/__init__.py index bb00e2bd..4c290eb2 100644 --- a/src/twinkle/server/common/__init__.py +++ b/src/twinkle/server/common/__init__.py @@ -2,7 +2,6 @@ from .checkpoint_factory import create_checkpoint_manager, create_training_run_manager from .datum import datum_to_input_feature, extract_rl_feature, input_feature_to_datum from .router import StickyLoraRequestRouter -from .serialize import deserialize_object, serialize_object __all__ = [ 'datum_to_input_feature', @@ -11,6 +10,4 @@ 'create_checkpoint_manager', 'create_training_run_manager', 'StickyLoraRequestRouter', - 'deserialize_object', - 'serialize_object', ] diff --git a/src/twinkle/server/common/serialize.py b/src/twinkle/server/common/serialize.py deleted file mode 100644 index f1b3f6dd..00000000 --- a/src/twinkle/server/common/serialize.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-# Moved from twinkle/common/serialize.py — logic unchanged. -import json -from numbers import Number -from peft import LoraConfig -from typing import Any, Mapping - -from twinkle.dataset import DatasetMeta - -supported_types = { - DatasetMeta, - LoraConfig, -} - -primitive_types = (str, Number, bool, bytes, type(None)) -container_types = (Mapping, list, tuple, set, frozenset) -basic_types = (*primitive_types, *container_types) - - -def _serialize_data_slice(data_slice): - """Serialize data_slice (Iterable) into a JSON-compatible dict.""" - if data_slice is None: - return None - if isinstance(data_slice, range): - return {'_slice_type_': 'range', 'start': data_slice.start, 'stop': data_slice.stop, 'step': data_slice.step} - if isinstance(data_slice, (list, tuple)): - return {'_slice_type_': 'list', 'values': list(data_slice)} - raise ValueError(f'Http mode does not support data_slice of type {type(data_slice).__name__}. ' - 'Supported types: range, list, tuple.') - - -def _deserialize_data_slice(data_slice): - """Deserialize a dict back into the original data_slice object.""" - if data_slice is None: - return None - if not isinstance(data_slice, dict) or '_slice_type_' not in data_slice: - return data_slice - slice_type = data_slice['_slice_type_'] - if slice_type == 'range': - return range(data_slice['start'], data_slice['stop'], data_slice['step']) - if slice_type == 'list': - return data_slice['values'] - raise ValueError(f'Unsupported data_slice type: {slice_type}') - - -def serialize_object(obj) -> str: - if isinstance(obj, DatasetMeta): - data = obj.__dict__.copy() - data['data_slice'] = _serialize_data_slice(data.get('data_slice')) - data['_TWINKLE_TYPE_'] = 'DatasetMeta' - return json.dumps(data, ensure_ascii=False) - elif isinstance(obj, LoraConfig): - filtered_dict = { - _subkey: _subvalue - for _subkey, _subvalue in obj.__dict__.items() - if isinstance(_subvalue, basic_types) and not _subkey.startswith('_') - } - filtered_dict['_TWINKLE_TYPE_'] = 
'LoraConfig' - return json.dumps(filtered_dict, ensure_ascii=False) - elif isinstance(obj, Mapping): - return json.dumps(obj, ensure_ascii=False) - elif isinstance(obj, basic_types): - return obj - else: - raise ValueError(f'Unsupported object: {obj}') - - -def deserialize_object(data: str) -> Any: - try: - data = json.loads(data) - except Exception: # noqa - return data - - if '_TWINKLE_TYPE_' in data: - _type = data.pop('_TWINKLE_TYPE_') - if _type == 'DatasetMeta': - data['data_slice'] = _deserialize_data_slice(data.get('data_slice')) - return DatasetMeta(**data) - elif _type == 'LoraConfig': - return LoraConfig(**data) - else: - raise ValueError(f'Unsupported type: {_type}') - else: - return data diff --git a/src/twinkle/server/gateway/proxy.py b/src/twinkle/server/gateway/proxy.py index e8346d6c..5ed9b7bf 100644 --- a/src/twinkle/server/gateway/proxy.py +++ b/src/twinkle/server/gateway/proxy.py @@ -67,8 +67,8 @@ def _prepare_headers(self, request_headers) -> dict[str, str]: headers.pop('host', None) headers.pop('content-length', None) request_id = request_headers.get('X-Ray-Serve-Request-Id') - if request_id is not None and not request_headers.get('serve_multiplexed_model_id'): - headers['serve_multiplexed_model_id'] = request_id + if request_id is not None and not request_headers.get('Serve-Multiplexed-Model-Id'): + headers['Serve-Multiplexed-Model-Id'] = request_id return headers async def proxy_request( @@ -99,7 +99,7 @@ async def proxy_request( service_type, endpoint, target_url, - headers.get('serve_multiplexed_model_id'), + headers.get('Serve-Multiplexed-Model-Id'), ) response = await self.client.request( diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index dd591ccf..713ba5c9 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -10,9 +10,9 @@ import asyncio from fastapi import FastAPI, HTTPException, Request from ray import serve -from tinker import types as tinker_types 
from typing import Any +import twinkle_client.types as types from twinkle.server.utils.state import get_server_state from twinkle.server.utils.validation import verify_request_token from twinkle.utils.logger import get_logger @@ -36,7 +36,7 @@ def __init__(self, self.http_options = http_options or {} self.proxy = ServiceProxy(http_options=http_options, route_prefix=self.route_prefix) self.supported_models = self._normalize_models(supported_models) or [ - tinker_types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), + types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), ] self._modelscope_config_lock = asyncio.Lock() @@ -45,12 +45,12 @@ def _normalize_models(self, supported_models): return [] normalized = [] for item in supported_models: - if isinstance(item, tinker_types.SupportedModel): + if isinstance(item, types.SupportedModel): normalized.append(item) elif isinstance(item, dict): - normalized.append(tinker_types.SupportedModel(**item)) + normalized.append(types.SupportedModel(**item)) elif isinstance(item, str): - normalized.append(tinker_types.SupportedModel(model_name=item)) + normalized.append(types.SupportedModel(model_name=item)) return normalized def _validate_base_model(self, base_model: str) -> None: @@ -61,8 +61,8 @@ def _validate_base_model(self, base_model: str) -> None: detail=f"Base model '{base_model}' is not supported. 
" f"Supported models: {', '.join(supported_model_names)}") - def _get_base_model(self, model_id: str) -> str: - metadata = self.state.get_model_metadata(model_id) + async def _get_base_model(self, model_id: str) -> str: + metadata = await self.state.get_model_metadata(model_id) if metadata and metadata.get('base_model'): return metadata['base_model'] raise HTTPException(status_code=404, detail=f'Model {model_id} not found') diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index 71c0654f..516da528 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -42,7 +42,9 @@ async def get_server_capabilities( request: Request, self: GatewayServer = Depends(self_fn), ) -> types.GetServerCapabilitiesResponse: - return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) + # Convert twinkle_client.types.SupportedModel to tinker.types.SupportedModel + tinker_supported_models = [types.SupportedModel(model_name=m.model_name) for m in self.supported_models] + return types.GetServerCapabilitiesResponse(supported_models=tinker_supported_models) @app.post('/telemetry') async def telemetry(request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: @@ -54,7 +56,7 @@ async def create_session( body: types.CreateSessionRequest, self: GatewayServer = Depends(self_fn), ) -> types.CreateSessionResponse: - session_id = self.state.create_session(body.model_dump()) + session_id = await self.state.create_session(body.model_dump()) return types.CreateSessionResponse(session_id=session_id) @app.post('/session_heartbeat') @@ -70,7 +72,7 @@ async def session_heartbeat( async def create_sampling_session( request: Request, body: types.CreateSamplingSessionRequest, self: GatewayServer = Depends(self_fn) ) -> types.CreateSamplingSessionResponse: # noqa: E125 - sampling_session_id = 
self.state.create_sampling_session(body.model_dump()) + sampling_session_id = await self.state.create_sampling_session(body.model_dump()) return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) @app.post('/retrieve_future') @@ -221,36 +223,36 @@ async def create_model(request: Request, body: types.CreateModelRequest, @app.post('/get_info') async def get_info(request: Request, body: types.GetInfoRequest, self: GatewayServer = Depends(self_fn)) -> Any: - return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'get_info', await self._get_base_model(body.model_id)) @app.post('/unload_model') async def unload_model(request: Request, body: types.UnloadModelRequest, self: GatewayServer = Depends(self_fn)) -> Any: - return await self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'unload_model', await self._get_base_model(body.model_id)) @app.post('/forward') async def forward(request: Request, body: types.ForwardRequest, self: GatewayServer = Depends(self_fn)) -> Any: - return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'forward', await self._get_base_model(body.model_id)) @app.post('/forward_backward') async def forward_backward(request: Request, body: types.ForwardBackwardRequest, self: GatewayServer = Depends(self_fn)) -> Any: - return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'forward_backward', await self._get_base_model(body.model_id)) @app.post('/optim_step') async def optim_step(request: Request, body: types.OptimStepRequest, self: GatewayServer = Depends(self_fn)) -> Any: - return await self.proxy.proxy_to_model(request, 'optim_step', 
self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'optim_step', await self._get_base_model(body.model_id)) @app.post('/save_weights') async def save_weights(request: Request, body: types.SaveWeightsRequest, self: GatewayServer = Depends(self_fn)) -> Any: - return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'save_weights', await self._get_base_model(body.model_id)) @app.post('/load_weights') async def load_weights(request: Request, body: types.LoadWeightsRequest, self: GatewayServer = Depends(self_fn)) -> Any: - return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'load_weights', await self._get_base_model(body.model_id)) # --- Sampler Proxy Endpoints --- @@ -258,7 +260,7 @@ async def load_weights(request: Request, body: types.LoadWeightsRequest, async def asample(request: Request, body: types.SampleRequest, self: GatewayServer = Depends(self_fn)) -> Any: base_model = body.base_model if not base_model and body.sampling_session_id: - session = self.state.get_sampling_session(body.sampling_session_id) + session = await self.state.get_sampling_session(body.sampling_session_id) if session: base_model = session.get('base_model') return await self.proxy.proxy_to_sampler(request, 'asample', base_model) @@ -269,4 +271,5 @@ async def save_weights_for_sampler( body: types.SaveWeightsForSamplerRequest, self: GatewayServer = Depends(self_fn), ) -> Any: - return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', await + self._get_base_model(body.model_id)) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 9c0a3ba7..9271b681 100644 --- 
a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -28,13 +28,20 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') + @app.get('/twinkle/get_server_capabilities', response_model=types.GetServerCapabilitiesResponse) + async def get_server_capabilities( + request: Request, + self: GatewayServer = Depends(self_fn), + ) -> types.GetServerCapabilitiesResponse: + return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) + @app.post('/twinkle/create_session', response_model=types.CreateSessionResponse) async def create_session( request: Request, body: types.CreateSessionRequest, self: GatewayServer = Depends(self_fn), ) -> types.CreateSessionResponse: - session_id = self.state.create_session(body.model_dump()) + session_id = await self.state.create_session(body.model_dump()) return types.CreateSessionResponse(session_id=session_id) @app.post('/twinkle/session_heartbeat', response_model=types.SessionHeartbeatResponse) diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 53b88350..cf084606 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -26,6 +26,7 @@ from typing import Any, Callable, Dict, Optional, Union from twinkle import get_logger +from twinkle.server.utils.ray_serve_patch import apply_ray_serve_patches, get_runtime_env_for_patches logger = get_logger() @@ -124,7 +125,10 @@ def _init_ray(self) -> None: namespace = self.ray_namespace or self.config.get('ray_namespace') or 'twinkle_cluster' if not ray.is_initialized(): - ray.init(namespace=namespace) + # Use runtime_env to apply patches in worker processes + # This is required because Ray Serve's ProxyActor runs in separate processes + runtime_env = get_runtime_env_for_patches() + ray.init(namespace=namespace, 
runtime_env=runtime_env) logger.info(f'Ray initialized with namespace={namespace}') self._ray_initialized = True @@ -189,6 +193,9 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: def launch(self) -> None: """Launch the server with all configured applications.""" + # Apply Ray Serve patches before initializing Ray + apply_ray_serve_patches() + self._init_ray() self._start_serve() diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 49692e7f..c8a90686 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -14,7 +14,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh -from twinkle.server.utils.adapter_manager import AdapterManagerMixin +from twinkle.server.utils.lifecycle import AdapterManagerMixin from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token @@ -78,12 +78,18 @@ def __init__(self, **kwargs) self.state: ServerStateProxy = get_server_state() - self.state.register_replica(self.replica_id, self.max_loras) + self._replica_registered = False # Initialize mixins self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) self._init_adapter_manager(**adapter_config) - self.start_adapter_countdown() + # Note: countdown task is started lazily in _ensure_sticky() + + async def _ensure_replica_registered(self): + """Lazily register replica on first async request.""" + if not self._replica_registered: + await self.state.register_replica(self.replica_id, self.max_loras) + self._replica_registered = True @serve.multiplexed(max_num_models_per_replica=5) async def _sticky_entry(self, sticky_key: str): @@ -92,25 +98,35 @@ async def _sticky_entry(self, sticky_key: str): async def _ensure_sticky(self): sticky_key = serve.get_multiplexed_model_id() await self._sticky_entry(sticky_key) + # Lazy-start 
countdown task on first request (requires running event loop) + self._ensure_countdown_started() async def _on_request_start(self, request: Request) -> str: await self._ensure_sticky() + await self._ensure_replica_registered() token = get_token_from_request(request) return token def __del__(self): - self.state.unregister_replica(self.replica_id) - - def _cleanup_adapter(self, adapter_name: str) -> None: + try: + # Best-effort cleanup; event loop may already be closed + import asyncio + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(self.state.unregister_replica(self.replica_id)) + except Exception: + pass + + async def _cleanup_adapter(self, adapter_name: str) -> None: if self.get_adapter_info(adapter_name): self.clear_adapter_state(adapter_name) self.model.remove_adapter(adapter_name) self.unregister_adapter(adapter_name) - self.state.unload_model(adapter_name) + await self.state.unload_model(adapter_name) - def _on_adapter_expired(self, adapter_name: str) -> None: + async def _on_adapter_expired(self, adapter_name: str) -> None: self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') - self._cleanup_adapter(adapter_name) + await self._cleanup_adapter(adapter_name) def build_model_app(model_id: str, diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index 2a895e02..fe30f616 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -35,7 +35,9 @@ def tinker_forward_only(self, *, inputs: List[types.Datum], **kwargs): template = self.get_template(**kwargs) input_features = datum_to_input_feature(inputs, template) outputs = super().forward_only(inputs=input_features, **kwargs) - logits = outputs['logits'].detach().cpu() + logits = outputs.get('logits') + if logits is not None: + logits = logits.detach().cpu() logps = outputs.get('logps', None) if logps is not None: logps 
= logps.detach().cpu() @@ -58,7 +60,9 @@ def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: st loss_kwargs.update(loss_values) loss = super().calculate_loss(adapter_name=adapter_name, **loss_kwargs) super().backward(adapter_name=adapter_name, **kwargs) - logits = outputs['logits'].detach() + logits = outputs.get('logits') + if logits is not None: + logits = logits.detach() logps = outputs.get('logps', None) if logps is not None: logps = logps.detach().cpu() diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index 6f458d8f..34d9a5f3 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -40,8 +40,7 @@ async def create_model( async def _create_adapter(): _model_id = None try: - - _model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) + _model_id = await self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) if body.lora_config: lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') adapter_name = self.get_adapter_name(adapter_name=_model_id) @@ -57,7 +56,7 @@ async def _create_adapter(): except Exception: if _model_id: adapter_name = self.get_adapter_name(adapter_name=_model_id) - self._cleanup_adapter(adapter_name) + await self._cleanup_adapter(adapter_name) logger.error(traceback.format_exc()) return types.RequestFailedResponse( error=traceback.format_exc(), @@ -96,7 +95,7 @@ async def unload_model( async def _do_unload(): adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self._cleanup_adapter(adapter_name) + await self._cleanup_adapter(adapter_name) return types.UnloadModelResponse(model_id=body.model_id) return await self.schedule_task(_do_unload, model_id=body.model_id, token=token, task_type='unload_model') @@ -261,10 +260,10 @@ async def _do_save_for_sampler(): name=checkpoint_name, output_dir=save_dir, 
adapter_name=adapter_name, save_optimizer=False) payload = body.model_dump() payload['model_path'] = tinker_path - metadata = self.state.get_model_metadata(body.model_id) or {} + metadata = await self.state.get_model_metadata(body.model_id) or {} if metadata.get('base_model'): payload['base_model'] = metadata['base_model'] - sampling_session_id = self.state.create_sampling_session(payload) + sampling_session_id = await self.state.create_sampling_session(payload) return types.SaveWeightsForSamplerResponseInternal(path=None, sampling_session_id=sampling_session_id) except Exception: logger.error(traceback.format_exc()) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 35c87441..171a17b5 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -19,9 +19,9 @@ import twinkle_client.types as types from twinkle.data_format import InputFeature, Trajectory from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager -from twinkle.server.common.serialize import deserialize_object from twinkle.server.utils.validation import get_session_id_from_request from twinkle.utils.logger import get_logger +from twinkle_client.common.serialize import deserialize_object logger = get_logger() @@ -59,9 +59,14 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], ModelManagement async def run_task(coro): """Await a schedule_task_and_wait coroutine and surface any exception as a structured HTTP 500 response so the client receives the full traceback instead - of an opaque connection-level error.""" + of an opaque connection-level error. + + Note: HTTPException is re-raised directly to preserve its status code and detail. 
+ """ try: return await coro + except HTTPException: + raise # Re-raise HTTPException directly to preserve status code except Exception: logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=traceback.format_exc()) @@ -69,11 +74,13 @@ async def run_task(coro): @app.post('/twinkle/create', response_model=types.CreateResponse) async def create(request: Request, body: types.CreateRequest, self: ModelManagement = Depends(self_fn)) -> types.CreateResponse: + await self._on_request_start(request) return types.CreateResponse() @app.post('/twinkle/forward', response_model=types.ForwardResponse) async def forward(request: Request, body: types.ForwardRequest, self: ModelManagement = Depends(self_fn)) -> types.ForwardResponse: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -91,6 +98,7 @@ async def forward_only( body: types.ForwardOnlyRequest, self: ModelManagement = Depends(self_fn), ) -> types.ForwardResponse: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -108,6 +116,7 @@ async def calculate_loss( body: types.AdapterRequest, self: ModelManagement = Depends(self_fn), ) -> types.CalculateLossResponse: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -120,6 +129,7 @@ async def _task(): @app.post('/twinkle/backward') async def backward(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -135,6 +145,7 @@ async def forward_backward( body: types.ForwardRequest, self: ModelManagement = Depends(self_fn), ) -> types.ForwardBackwardResponse: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, 
body.adapter_name) async def _task(): @@ -152,6 +163,7 @@ async def clip_grad_norm( body: types.AdapterRequest, self: ModelManagement = Depends(self_fn), ) -> types.ClipGradNormResponse: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -164,6 +176,7 @@ async def _task(): @app.post('/twinkle/step') async def step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -175,6 +188,7 @@ async def _task(): @app.post('/twinkle/zero_grad') async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -186,6 +200,7 @@ async def _task(): @app.post('/twinkle/lr_step') async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -201,6 +216,7 @@ async def clip_grad_and_step( body: types.ClipGradAndStepRequest, self: ModelManagement = Depends(self_fn), ) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -221,6 +237,7 @@ async def get_train_configs( body: types.AdapterRequest, self: ModelManagement = Depends(self_fn), ) -> types.GetTrainConfigsResponse: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -233,6 +250,7 @@ async def _task(): @app.post('/twinkle/set_loss') async def set_loss(request: Request, body: types.SetLossRequest, self: ModelManagement = Depends(self_fn)) -> None: + await 
self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -248,6 +266,7 @@ async def set_optimizer( body: types.SetOptimizerRequest, self: ModelManagement = Depends(self_fn), ) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -263,6 +282,7 @@ async def set_lr_scheduler( body: types.SetLrSchedulerRequest, self: ModelManagement = Depends(self_fn), ) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -380,6 +400,7 @@ async def apply_patch( body: types.ApplyPatchRequest, self: ModelManagement = Depends(self_fn), ) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -396,6 +417,7 @@ async def add_metric( body: types.AddMetricRequest, self: ModelManagement = Depends(self_fn), ) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -412,6 +434,7 @@ async def set_template( body: types.SetTemplateRequest, self: ModelManagement = Depends(self_fn), ) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -427,6 +450,7 @@ async def set_processor( body: types.SetProcessorRequest, self: ModelManagement = Depends(self_fn), ) -> None: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -442,6 +466,7 @@ async def calculate_metric( body: types.CalculateMetricRequest, self: ModelManagement = Depends(self_fn), ) -> types.CalculateMetricResponse: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -458,6 +483,7 @@ async def 
get_state_dict( body: types.GetStateDictRequest, self: ModelManagement = Depends(self_fn), ) -> types.GetStateDictResponse: + await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 4b03af86..7dd4bc7c 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -20,7 +20,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_logger -from twinkle.server.utils.processor_manager import ProcessorManagerMixin +from twinkle.server.utils.lifecycle import ProcessorManagerMixin from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token from .twinkle_handlers import _register_processor_routes @@ -69,7 +69,7 @@ def __init__(self, processor_timeout=float(_cfg.get('processor_timeout', 1800.0)), per_token_processor_limit=int(_cfg.get('per_token_processor_limit', _env_limit)), ) - self.start_processor_countdown() + # Note: countdown task is started lazily in _ensure_sticky() @serve.multiplexed(max_num_models_per_replica=100) async def _sticky_entry(self, sticky_key: str): @@ -78,6 +78,8 @@ async def _sticky_entry(self, sticky_key: str): async def _ensure_sticky(self): sticky_key = serve.get_multiplexed_model_id() await self._sticky_entry(sticky_key) + # Lazy-start countdown task on first request (requires running event loop) + self._ensure_countdown_started() def _on_processor_expired(self, processor_id: str) -> None: """Called by the countdown thread when a processor's session expires.""" diff --git a/src/twinkle/server/processor/twinkle_handlers.py b/src/twinkle/server/processor/twinkle_handlers.py index 86e35f86..09424044 100644 --- a/src/twinkle/server/processor/twinkle_handlers.py +++ b/src/twinkle/server/processor/twinkle_handlers.py @@ -18,9 +18,9 @@ from .app import ProcessorManagement 
import twinkle_client.types as types -from twinkle.server.common.serialize import deserialize_object from twinkle.server.utils.validation import get_session_id_from_request, get_token_from_request from twinkle.utils.logger import get_logger +from twinkle_client.common.serialize import deserialize_object logger = get_logger() diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index c69a6956..dc54e4f6 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -13,7 +13,6 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh -from twinkle.server.utils.adapter_manager import AdapterManagerMixin from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token @@ -25,14 +24,13 @@ logger = get_logger() -class SamplerManagement(TaskQueueMixin, AdapterManagerMixin): +class SamplerManagement(TaskQueueMixin): """Unified sampler management service. 
Manages: - vLLM or Torch sampler initialization and lifecycle - Tinker inference requests (/tinker/asample) with rate limiting via TaskQueueMixin - Twinkle inference requests (/twinkle/*) calling sampler directly - - Adapter lifecycle via AdapterManagerMixin - Template configuration for trajectory encoding """ @@ -43,7 +41,6 @@ def __init__(self, device_mesh: dict[str, Any], sampler_type: str = 'vllm', engine_args: dict[str, Any] | None = None, - adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): self.device_group = DeviceGroup(**device_group) @@ -82,11 +79,8 @@ def __init__(self, self.sampler.set_template('Template', model_id=model_id) self.state: ServerStateProxy = get_server_state() - # Initialize both mixins + # Initialize task queue mixin self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) - _adapter_config = adapter_config or {} - self._init_adapter_manager(**_adapter_config) - self.start_adapter_countdown() @serve.multiplexed(max_num_models_per_replica=5) async def _sticky_entry(self, sticky_key: str): @@ -101,14 +95,6 @@ async def _on_request_start(self, request: Request) -> str: token = get_token_from_request(request) return token - def _on_adapter_expired(self, adapter_name: str, token: str = None) -> None: - """Handle expired adapters by removing them from the sampler.""" - try: - self.sampler.remove_adapter(adapter_name) - logger.info(f'Removed expired adapter {adapter_name}') - except Exception as e: - logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') - def build_sampler_app(model_id: str, nproc_per_node: int, @@ -117,7 +103,6 @@ def build_sampler_app(model_id: str, deploy_options: dict[str, Any], sampler_type: str = 'vllm', engine_args: dict[str, Any] | None = None, - adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): """Build a unified sampler application for text generation inference. 
@@ -133,7 +118,6 @@ def build_sampler_app(model_id: str, deploy_options: Ray Serve deployment options sampler_type: Type of sampler to use ('vllm' or 'torch') engine_args: Additional engine arguments for the sampler - adapter_config: Adapter lifecycle config (timeout, per-token limits) queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.) **kwargs: Additional arguments passed to the sampler @@ -161,8 +145,7 @@ def get_self() -> SamplerManagement: SamplerManagementWithIngress = serve.ingress(app)(SamplerManagement) DeploymentClass = serve.deployment(name='SamplerManagement')(SamplerManagementWithIngress) return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, device_mesh, - sampler_type, engine_args, adapter_config, queue_config, - **kwargs) + sampler_type, engine_args, queue_config, **kwargs) build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app) diff --git a/src/twinkle/server/sampler/tinker_handlers.py b/src/twinkle/server/sampler/tinker_handlers.py index 16b75040..f0106024 100644 --- a/src/twinkle/server/sampler/tinker_handlers.py +++ b/src/twinkle/server/sampler/tinker_handlers.py @@ -51,7 +51,7 @@ async def _do_sample(): # Get model_path from body or sampling session model_path = body.model_path if not model_path and body.sampling_session_id: - session = self.state.get_sampling_session(body.sampling_session_id) + session = await self.state.get_sampling_session(body.sampling_session_id) if session: model_path = session.get('model_path') diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index 9b981d31..f20d1738 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -10,7 +10,7 @@ from fastapi import Depends, FastAPI, HTTPException, Request from typing import TYPE_CHECKING, Callable -from twinkle.server.common.serialize import deserialize_object +from 
twinkle_client.common.serialize import deserialize_object if TYPE_CHECKING: from .app import SamplerManagement @@ -154,13 +154,10 @@ def add_adapter_to_sampler( """Add a LoRA adapter to the sampler.""" assert body.adapter_name, 'You need to specify a valid `adapter_name`' full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) - from twinkle.server.utils.validation import get_token_from_request - token = get_token_from_request(request) from peft import LoraConfig config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - self.register_adapter(full_adapter_name, token) self.sampler.add_adapter_to_sampler(full_adapter_name, config) return types.AddAdapterResponse(adapter_name=full_adapter_name) diff --git a/src/twinkle/server/utils/__init__.py b/src/twinkle/server/utils/__init__.py index d19d34d0..fdab4278 100644 --- a/src/twinkle/server/utils/__init__.py +++ b/src/twinkle/server/utils/__init__.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .adapter_manager import AdapterManagerMixin from .checkpoint_base import (TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, BaseCheckpointManager, BaseFileManager, BaseTrainingRunManager) from .device_utils import auto_fill_device_group_visible_devices, wrap_builder_with_device_group_env -from .processor_manager import ProcessorManagerMixin +from .lifecycle import AdapterManagerMixin, ProcessorManagerMixin, SessionResourceMixin from .rate_limiter import RateLimiter from .task_queue import QueueState, TaskQueueConfig, TaskQueueMixin, TaskStatus diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py deleted file mode 100644 index 844ccfd1..00000000 --- a/src/twinkle/server/utils/adapter_manager.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Adapter Lifecycle Manager Mixin for Twinkle Server. 
- -This module provides adapter lifecycle management as a mixin class that can be -inherited directly by services. It tracks adapter activity and provides interfaces -for registration, heartbeat updates, and expiration handling. - -By inheriting this mixin, services can override the _on_adapter_expired() method -to handle expired adapters without using callbacks or polling. -""" -from __future__ import annotations - -import threading -import time -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from twinkle.server.utils.state import ServerStateProxy - -from twinkle.utils.logger import get_logger - -logger = get_logger() - - -class AdapterManagerMixin: - """Mixin for adapter lifecycle management with session-based expiration. - - This mixin tracks adapter activity and automatically expires adapters - when their associated session expires. - - Inheriting classes should: - 1. Call _init_adapter_manager() in __init__ - 2. Override _on_adapter_expired() to customize expiration handling - - Attributes: - _adapter_timeout: Session inactivity timeout in seconds used to determine if a session is alive. - _adapter_max_lifetime: Maximum lifetime in seconds for any adapter, regardless of session liveness. - """ - - # Type hint for state attribute that inheriting classes must provide - state: ServerStateProxy - - def _init_adapter_manager( - self, - adapter_timeout: float = 1800.0, - adapter_max_lifetime: float = 36000.0, - ) -> None: - """Initialize the adapter manager. - - This should be called in the __init__ of the inheriting class. - - Args: - adapter_timeout: Timeout in seconds used to check whether a session is still alive. - Default is 1800.0 (30 minutes). - adapter_max_lifetime: Maximum lifetime in seconds for an adapter regardless of session - liveness. Adapters older than this are treated as expired. Default is 36000.0 (10 hours). 
- """ - self._adapter_timeout = adapter_timeout - self._adapter_max_lifetime = adapter_max_lifetime - - # Adapter lifecycle tracking - # Dict mapping adapter_name -> - # {'token': str, 'session_id': str, 'created_at': float, 'state': dict, 'expiring': bool} - self._adapter_records: dict[str, dict[str, Any]] = {} - - # Countdown thread - self._adapter_countdown_thread: threading.Thread | None = None - self._adapter_countdown_running = False - - def register_adapter(self, adapter_name: str, token: str, session_id: str) -> None: - """Register a new adapter for lifecycle tracking. - - The adapter will expire when its associated session expires. - - Args: - adapter_name: Name of the adapter to register. - token: User token that owns this adapter. - session_id: Session ID to associate with this adapter. Must be a non-empty string. - - Raises: - ValueError: If session_id is None or empty. - """ - if not session_id: - raise ValueError(f'session_id must be provided when registering adapter {adapter_name}') - current_time = time.time() - self._adapter_records[adapter_name] = { - 'token': token, - 'session_id': session_id, - 'created_at': current_time, - 'state': {}, - 'expiring': False, - } - logger.debug( - f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}... (session: {session_id})') - - def _is_session_alive(self, session_id: str) -> bool: - """Check if a session is still alive via state proxy. 
- - Args: - session_id: Session ID to check - - Returns: - True if session is alive, False if expired or not found - """ - if not session_id: - return True # No session association means always alive - - # Get session last heartbeat through proxy - last_heartbeat = self.state.get_session_last_heartbeat(session_id) - if last_heartbeat is None: - return False # Session doesn't exist - - # Check if session has timed out using adapter_timeout - return (time.time() - last_heartbeat) < self._adapter_timeout - - def unregister_adapter(self, adapter_name: str) -> bool: - """Unregister an adapter from lifecycle tracking. - - Args: - adapter_name: Name of the adapter to unregister. - - Returns: - True if adapter was found and removed, False otherwise. - """ - if adapter_name in self._adapter_records: - adapter_info = self._adapter_records.pop(adapter_name) - token = adapter_info.get('token') - logger.debug( - f"[AdapterManager] Unregistered adapter {adapter_name} for token {token[:8] if token else 'unknown'}..." - ) - return True - return False - - def set_adapter_state(self, adapter_name: str, key: str, value: Any) -> None: - """Set a per-adapter state value. - - This is intentionally generic so higher-level services can store - adapter-scoped state (e.g., training readiness) without maintaining - separate side maps. 
- """ - info = self._adapter_records.get(adapter_name) - if info is None: - return - state = info.setdefault('state', {}) - state[key] = value - - def get_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: - """Get a per-adapter state value.""" - info = self._adapter_records.get(adapter_name) - if info is None: - return default - state = info.get('state') or {} - return state.get(key, default) - - def pop_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: - """Pop a per-adapter state value.""" - info = self._adapter_records.get(adapter_name) - if info is None: - return default - state = info.get('state') - if not isinstance(state, dict): - return default - return state.pop(key, default) - - def clear_adapter_state(self, adapter_name: str) -> None: - """Clear all per-adapter state values.""" - info = self._adapter_records.get(adapter_name) - if info is None: - return - info['state'] = {} - - def get_adapter_info(self, adapter_name: str) -> dict[str, Any] | None: - """Get information about a registered adapter. - - Args: - adapter_name: Name of the adapter to query. - - Returns: - Dict with adapter information or None if not found. - """ - return self._adapter_records.get(adapter_name) - - def _on_adapter_expired(self, adapter_name: str) -> None: - """Hook method called when an adapter expires. - - This method must be overridden by inheriting classes to handle - adapter expiration logic. The base implementation raises NotImplementedError. - - Args: - adapter_name: Name of the expired adapter. - - Raises: - NotImplementedError: If not overridden by inheriting class. - """ - raise NotImplementedError(f'_on_adapter_expired must be implemented by {self.__class__.__name__}') - - @staticmethod - def get_adapter_name(adapter_name: str) -> str: - """Get the adapter name for a request. - - This is a passthrough method for consistency with the original API. 
- - Args: - adapter_name: The adapter name (typically model_id) - - Returns: - The adapter name to use - """ - return adapter_name - - def assert_adapter_exists(self, adapter_name: str) -> None: - """Validate that an adapter exists and is not expiring.""" - info = self._adapter_records.get(adapter_name) - assert adapter_name and info is not None and not info.get('expiring'), \ - f'Adapter {adapter_name} not found' - - def _adapter_countdown_loop(self) -> None: - """Background thread that monitors and handles adapters whose session has expired or exceeded max lifetime. - - This thread runs continuously and: - 1. Checks whether an adapter has exceeded `_adapter_max_lifetime` (sync, no async call) - 2. Checks session liveness for remaining adapters every second - 3. Calls _on_adapter_expired() for adapters that have expired - 4. Removes expired adapters from tracking - """ - logger.debug(f'[AdapterManager] Countdown thread started (session_timeout={self._adapter_timeout}s)') - while self._adapter_countdown_running: - try: - time.sleep(10) - - expired_adapters: list[tuple[str, str | None]] = [] - # Create snapshot to avoid modification during iteration - adapter_snapshot = list(self._adapter_records.items()) - for adapter_name, info in adapter_snapshot: - if info.get('expiring'): - continue - - session_id = info.get('session_id') - created_at = info.get('created_at', 0.0) - now = time.time() - - # Check max lifetime first (no async call needed) - if now - created_at >= self._adapter_max_lifetime: - logger.debug(f'[AdapterManager] Adapter {adapter_name} exceeded max lifetime ' - f'({self._adapter_max_lifetime}s), marking as expired') - info['expiring'] = True - info['state'] = {} - token = info.get('token') - expired_adapters.append((adapter_name, token, session_id)) - continue - - try: - session_alive = self._is_session_alive(session_id) - except Exception as e: - logger.warning(f'[AdapterManager] Failed to check session liveness for {adapter_name}: ' - 
f'{type(e).__name__}: {e}') - continue - session_expired = not session_alive - logger.debug(f'[AdapterManager] Adapter {adapter_name} session check ' - f'(session_id={session_id}, session_alive={not session_expired})') - - if session_expired: - info['expiring'] = True - info['state'] = {} # best-effort clear - token = info.get('token') - expired_adapters.append((adapter_name, token, session_id)) - - for adapter_name, _token, session_id in expired_adapters: - success = False - try: - self._on_adapter_expired(adapter_name) - logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' - f'(reason=session_expired, session={session_id})') - success = True - except Exception as e: - logger.warning(f'[AdapterManager] Error while expiring adapter {adapter_name}: {e}') - finally: - if success: - self._adapter_records.pop(adapter_name, None) - else: - info = self._adapter_records.get(adapter_name) - if info is not None: - info['expiring'] = False - - except Exception as e: - logger.warning(f'[AdapterManager] Error in countdown loop: {e}') - continue - - logger.debug('[AdapterManager] Countdown thread stopped') - - def start_adapter_countdown(self) -> None: - """Start the background adapter countdown thread. - - This should be called once when the mixin is initialized. - It's safe to call multiple times - subsequent calls are ignored. - """ - if not self._adapter_countdown_running: - self._adapter_countdown_running = True - self._adapter_countdown_thread = threading.Thread(target=self._adapter_countdown_loop, daemon=True) - self._adapter_countdown_thread.start() - logger.debug('[AdapterManager] Countdown thread started') - - def stop_adapter_countdown(self) -> None: - """Stop the background adapter countdown thread. - - This should be called when shutting down the server. 
- """ - if self._adapter_countdown_running: - self._adapter_countdown_running = False - if self._adapter_countdown_thread: - # Wait for thread to finish (it checks the flag every second) - self._adapter_countdown_thread.join(timeout=2.0) - logger.debug('[AdapterManager] Countdown thread stopped') diff --git a/src/twinkle/server/utils/lifecycle/__init__.py b/src/twinkle/server/utils/lifecycle/__init__.py new file mode 100644 index 00000000..ea574000 --- /dev/null +++ b/src/twinkle/server/utils/lifecycle/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Lifecycle management utilities for session-bound resources.""" + +from .adapter import AdapterManagerMixin +from .base import SessionResourceMixin +from .processor import ProcessorManagerMixin + +__all__ = ['AdapterManagerMixin', 'ProcessorManagerMixin', 'SessionResourceMixin'] diff --git a/src/twinkle/server/utils/lifecycle/adapter.py b/src/twinkle/server/utils/lifecycle/adapter.py new file mode 100644 index 00000000..6b45228b --- /dev/null +++ b/src/twinkle/server/utils/lifecycle/adapter.py @@ -0,0 +1,185 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Adapter Lifecycle Manager Mixin for Twinkle Server. + +This module provides adapter lifecycle management as a mixin class that can be +inherited directly by services. It tracks adapter activity and provides interfaces +for registration, heartbeat updates, and expiration handling. + +By inheriting this mixin, services can override the _on_adapter_expired() method +to handle expired adapters without using callbacks or polling. +""" +from __future__ import annotations + +from fastapi import HTTPException +from typing import Any + +from twinkle.utils.logger import get_logger +from .base import SessionResourceMixin + +logger = get_logger() + + +class AdapterManagerMixin(SessionResourceMixin): + """Mixin for adapter lifecycle management with session-based expiration. 
+ + This mixin tracks adapter activity and automatically expires adapters + when their associated session expires. + + Inheriting classes should: + 1. Call _init_adapter_manager() in __init__ + 2. Override _on_adapter_expired() to customize expiration handling + + Attributes: + _adapter_timeout: Session inactivity timeout in seconds used to determine if a session is alive. + _adapter_max_lifetime: Maximum lifetime in seconds for any adapter, regardless of session liveness. + """ + + # Set resource type for logging + _resource_type = 'Adapter' + + def _init_adapter_manager( + self, + adapter_timeout: float = 1800.0, + adapter_max_lifetime: float = 36000.0, + ) -> None: + """Initialize the adapter manager. + + This should be called in the __init__ of the inheriting class. + + Args: + adapter_timeout: Timeout in seconds used to check whether a session is still alive. + Default is 1800.0 (30 minutes). + adapter_max_lifetime: Maximum lifetime in seconds for an adapter regardless of session + liveness. Adapters older than this are treated as expired. Default is 36000.0 (10 hours). + """ + self._init_resource_manager( + resource_timeout=adapter_timeout, + resource_max_lifetime=adapter_max_lifetime, + ) + + @property + def _adapter_timeout(self) -> float: + """Adapter timeout for backward compatibility.""" + return self._resource_timeout + + @property + def _adapter_max_lifetime(self) -> float | None: + """Adapter max lifetime for backward compatibility.""" + return self._resource_max_lifetime + + @property + def _adapter_records(self) -> dict[str, dict[str, Any]]: + """Adapter records for backward compatibility.""" + return self._resource_records + + def register_adapter(self, adapter_name: str, token: str, session_id: str) -> None: + """Register a new adapter for lifecycle tracking. + + The adapter will expire when its associated session expires. + + Args: + adapter_name: Name of the adapter to register. + token: User token that owns this adapter. 
+ session_id: Session ID to associate with this adapter. Must be a non-empty string. + + Raises: + ValueError: If session_id is None or empty. + """ + self.register_resource(adapter_name, token, session_id) + + def unregister_adapter(self, adapter_name: str) -> bool: + """Unregister an adapter from lifecycle tracking. + + Args: + adapter_name: Name of the adapter to unregister. + + Returns: + True if adapter was found and removed, False otherwise. + """ + return self.unregister_resource(adapter_name) + + def set_adapter_state(self, adapter_name: str, key: str, value: Any) -> None: + """Set a per-adapter state value. + + This is intentionally generic so higher-level services can store + adapter-scoped state (e.g., training readiness) without maintaining + separate side maps. + """ + self.set_resource_state(adapter_name, key, value) + + def get_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: + """Get a per-adapter state value.""" + return self.get_resource_state(adapter_name, key, default) + + def pop_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: + """Pop a per-adapter state value.""" + return self.pop_resource_state(adapter_name, key, default) + + def clear_adapter_state(self, adapter_name: str) -> None: + """Clear all per-adapter state values.""" + self.clear_resource_state(adapter_name) + + def get_adapter_info(self, adapter_name: str) -> dict[str, Any] | None: + """Get information about a registered adapter. + + Args: + adapter_name: Name of the adapter to query. + + Returns: + Dict with adapter information or None if not found. + """ + return self.get_resource_info(adapter_name) + + async def _on_resource_expired(self, resource_id: str) -> None: + """Internal hook called by base class. Delegates to _on_adapter_expired.""" + await self._on_adapter_expired(resource_id) + + async def _on_adapter_expired(self, adapter_name: str) -> None: + """Hook method called when an adapter expires. 
+ + This method must be overridden by inheriting classes to handle + adapter expiration logic. The base implementation raises NotImplementedError. + + Args: + adapter_name: Name of the expired adapter. + + Raises: + NotImplementedError: If not overridden by inheriting class. + """ + raise NotImplementedError(f'_on_adapter_expired must be implemented by {self.__class__.__name__}') + + @staticmethod + def get_adapter_name(adapter_name: str) -> str: + """Get the adapter name for a request. + + This is a passthrough method for consistency with the original API. + + Args: + adapter_name: The adapter name (typically model_id) + + Returns: + The adapter name to use + """ + return adapter_name + + def assert_adapter_exists(self, adapter_name: str) -> None: + """Validate that an adapter exists and is not expiring. + + Raises: + HTTPException: 400 if adapter not found or expiring, with clear error message. + """ + info = self._resource_records.get(adapter_name) + if not adapter_name or info is None or info.get('expiring'): + raise HTTPException( + status_code=400, + detail=f"Adapter '{adapter_name}' not found. " + f'Please call add_adapter_to_model() first to create an adapter.') + + def _ensure_countdown_started(self) -> None: + """Ensure the countdown task is started. Call from async context.""" + super()._ensure_countdown_started() + + def stop_adapter_countdown(self) -> None: + """Stop the background countdown task.""" + self.stop_resource_countdown() diff --git a/src/twinkle/server/utils/lifecycle/base.py b/src/twinkle/server/utils/lifecycle/base.py new file mode 100644 index 00000000..e719555b --- /dev/null +++ b/src/twinkle/server/utils/lifecycle/base.py @@ -0,0 +1,328 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Base class for session-bound resource lifecycle management. + +This module provides a generic mixin for managing resources (adapters, processors, etc.) +that are bound to user sessions and should expire when their session expires. 
+""" +from __future__ import annotations + +import asyncio +import time +from abc import abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from twinkle.server.utils.state import ServerStateProxy + +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +class SessionResourceMixin: + """Base mixin for managing session-bound resources with automatic expiration. + + This mixin tracks resources and automatically expires them when their + associated session expires or when they exceed their maximum lifetime. + + Inheriting classes should: + 1. Call _init_resource_manager() in __init__ + 2. Override _on_resource_expired() to handle resource-specific cleanup + 3. Optionally override _validate_registration() for custom validation + + Attributes: + _resource_timeout: Session inactivity timeout in seconds. + _resource_max_lifetime: Maximum lifetime in seconds for any resource. + _resource_records: Dict mapping resource_id -> resource info dict. + """ + + # Type hint for state attribute that inheriting classes must provide + state: ServerStateProxy + + # Resource type name for logging (override in subclass) + _resource_type: str = 'resource' + + def _init_resource_manager( + self, + resource_timeout: float = 1800.0, + resource_max_lifetime: float | None = None, + ) -> None: + """Initialize the resource manager. + + This should be called in the __init__ of the inheriting class. + + Args: + resource_timeout: Timeout in seconds to determine if a session is alive. + Default is 1800.0 (30 minutes). + resource_max_lifetime: Maximum lifetime in seconds for a resource regardless + of session liveness. None means no max lifetime limit. 
+ """ + self._resource_timeout = resource_timeout + self._resource_max_lifetime = resource_max_lifetime + + # Resource lifecycle tracking + # Dict mapping resource_id -> + # {'token': str, 'session_id': str, 'created_at': float, 'state': dict, 'expiring': bool} + self._resource_records: dict[str, dict[str, Any]] = {} + + # Countdown task + self._resource_countdown_running = False + self._countdown_task: asyncio.Task | None = None + + async def _is_session_alive(self, session_id: str) -> bool: + """Check if a session is still alive via state proxy. + + Args: + session_id: Session ID to check + + Returns: + True if session is alive, False if expired or not found + """ + if not session_id: + return True # No session association means always alive + + try: + last_heartbeat = await self.state.get_session_last_heartbeat(session_id) + except Exception as e: + logger.warning(f'[{self._resource_type}Manager] Failed to check session liveness: {e}') + return True # Assume alive on error + + if last_heartbeat is None: + return False # Session doesn't exist + + # Check if session has timed out + return (time.time() - last_heartbeat) < self._resource_timeout + + def _validate_registration(self, resource_id: str, token: str, session_id: str) -> None: + """Validate before registering a resource. Override for custom validation. + + Args: + resource_id: Resource identifier + token: User token + session_id: Session ID + + Raises: + ValueError: If validation fails + RuntimeError: If resource limit is reached + """ + if not session_id: + raise ValueError(f'session_id must be provided when registering {self._resource_type} {resource_id}') + + def _create_resource_record(self, token: str, session_id: str) -> dict[str, Any]: + """Create a new resource record. Override to add custom fields. 
+ + Args: + token: User token + session_id: Session ID + + Returns: + Resource record dict + """ + return { + 'token': token, + 'session_id': session_id, + 'created_at': time.time(), + 'state': {}, + 'expiring': False, + } + + def register_resource(self, resource_id: str, token: str, session_id: str) -> None: + """Register a new resource for lifecycle tracking. + + Args: + resource_id: Unique identifier of the resource. + token: User token that owns this resource. + session_id: Session ID to associate with this resource. + + Raises: + ValueError: If session_id is None or empty. + RuntimeError: If custom validation fails (e.g., limit reached). + """ + self._validate_registration(resource_id, token, session_id) + + self._resource_records[resource_id] = self._create_resource_record(token, session_id) + logger.debug(f'[{self._resource_type}Manager] Registered {self._resource_type} {resource_id} ' + f'for token {token[:8]}... (session: {session_id})') + + def unregister_resource(self, resource_id: str) -> bool: + """Unregister a resource from lifecycle tracking. + + Args: + resource_id: Resource identifier to unregister. + + Returns: + True if resource was found and removed, False otherwise. + """ + if resource_id in self._resource_records: + info = self._resource_records.pop(resource_id) + token = info.get('token') + logger.debug(f'[{self._resource_type}Manager] Unregistered {self._resource_type} {resource_id} ' + f"for token {token[:8] if token else 'unknown'}...") + return True + return False + + def get_resource_info(self, resource_id: str) -> dict[str, Any] | None: + """Get information about a registered resource. + + Args: + resource_id: Resource identifier to query. + + Returns: + Dict with resource information or None if not found. + """ + return self._resource_records.get(resource_id) + + def set_resource_state(self, resource_id: str, key: str, value: Any) -> None: + """Set a per-resource state value. 
+ + This is intentionally generic so higher-level services can store + resource-scoped state without maintaining separate side maps. + """ + info = self._resource_records.get(resource_id) + if info is None: + return + state = info.setdefault('state', {}) + state[key] = value + + def get_resource_state(self, resource_id: str, key: str, default: Any = None) -> Any: + """Get a per-resource state value.""" + info = self._resource_records.get(resource_id) + if info is None: + return default + state = info.get('state') or {} + return state.get(key, default) + + def pop_resource_state(self, resource_id: str, key: str, default: Any = None) -> Any: + """Pop a per-resource state value.""" + info = self._resource_records.get(resource_id) + if info is None: + return default + state = info.get('state') + if not isinstance(state, dict): + return default + return state.pop(key, default) + + def clear_resource_state(self, resource_id: str) -> None: + """Clear all per-resource state values.""" + info = self._resource_records.get(resource_id) + if info is None: + return + info['state'] = {} + + def assert_resource_exists(self, resource_id: str) -> None: + """Validate that a resource exists and is not expiring. + + Raises: + AssertionError: If resource not found or expiring. + """ + info = self._resource_records.get(resource_id) + assert resource_id and info is not None and not info.get('expiring'), \ + f'{self._resource_type} {resource_id} not found' + + @abstractmethod + async def _on_resource_expired(self, resource_id: str) -> None: + """Hook method called when a resource expires. + + This method must be implemented by inheriting classes to handle + resource-specific expiration logic. + + Args: + resource_id: Identifier of the expired resource. + """ + raise NotImplementedError(f'_on_resource_expired must be implemented by {self.__class__.__name__}') + + async def _resource_countdown_loop(self) -> None: + """Background task that monitors and handles expired resources. 
+ + This task runs continuously and: + 1. Checks whether a resource has exceeded `_resource_max_lifetime` (if set) + 2. Checks session liveness for remaining resources + 3. Calls _on_resource_expired() for resources that have expired + 4. Removes expired resources from tracking + """ + logger.debug(f'[{self._resource_type}Manager] Countdown task started ' + f'(session_timeout={self._resource_timeout}s)') + while self._resource_countdown_running: + try: + await asyncio.sleep(10) + + expired_resources: list[tuple[str, str | None]] = [] + # Create snapshot to avoid modification during iteration + resource_snapshot = list(self._resource_records.items()) + for resource_id, info in resource_snapshot: + if info.get('expiring'): + continue + + session_id = info.get('session_id') + created_at = info.get('created_at', 0.0) + now = time.time() + + # Check max lifetime first (no async call needed) + if self._resource_max_lifetime and now - created_at >= self._resource_max_lifetime: + logger.debug(f'[{self._resource_type}Manager] {self._resource_type} {resource_id} ' + f'exceeded max lifetime ({self._resource_max_lifetime}s), marking as expired') + info['expiring'] = True + info['state'] = {} + token = info.get('token') + expired_resources.append((resource_id, token, session_id)) + continue + + try: + session_alive = await self._is_session_alive(session_id) + except Exception as e: + logger.warning(f'[{self._resource_type}Manager] Failed to check session liveness ' + f'for {resource_id}: {type(e).__name__}: {e}') + continue + session_expired = not session_alive + logger.debug(f'[{self._resource_type}Manager] {self._resource_type} {resource_id} session check ' + f'(session_id={session_id}, session_alive={not session_expired})') + + if session_expired: + info['expiring'] = True + info['state'] = {} # best-effort clear + token = info.get('token') + expired_resources.append((resource_id, token, session_id)) + + for resource_id, _token, session_id in expired_resources: + success = 
False + try: + await self._on_resource_expired(resource_id) + logger.info(f'[{self._resource_type}Manager] {self._resource_type} {resource_id} expired ' + f'(reason=session_expired, session={session_id})') + success = True + except Exception as e: + logger.warning(f'[{self._resource_type}Manager] Error while expiring {self._resource_type} ' + f'{resource_id}: {e}') + finally: + if success: + self._resource_records.pop(resource_id, None) + else: + info = self._resource_records.get(resource_id) + if info is not None: + info['expiring'] = False + + except Exception as e: + logger.warning(f'[{self._resource_type}Manager] Error in countdown loop: {e}') + continue + + logger.debug(f'[{self._resource_type}Manager] Countdown task stopped') + + def _ensure_countdown_started(self) -> None: + """Ensure the countdown task is started. Call from async context.""" + if not self._resource_countdown_running: + self._resource_countdown_running = True + self._countdown_task = asyncio.create_task(self._resource_countdown_loop()) + logger.debug(f'[{self._resource_type}Manager] Countdown task started') + + async def _async_ensure_countdown_started(self) -> None: + """Async version for convenience.""" + self._ensure_countdown_started() + + def stop_resource_countdown(self) -> None: + """Stop the background countdown task.""" + if self._resource_countdown_running: + self._resource_countdown_running = False + if self._countdown_task: + self._countdown_task.cancel() + logger.debug(f'[{self._resource_type}Manager] Countdown task stopped') diff --git a/src/twinkle/server/utils/lifecycle/processor.py b/src/twinkle/server/utils/lifecycle/processor.py new file mode 100644 index 00000000..2a764882 --- /dev/null +++ b/src/twinkle/server/utils/lifecycle/processor.py @@ -0,0 +1,143 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Processor Lifecycle Manager Mixin for Twinkle Server. + +Mirrors AdapterManagerMixin but adds a global per-token processor limit. 
+Sessions are tracked via session ID; processors expire when their session expires. +""" +from __future__ import annotations + +import time +from typing import Any + +from twinkle.utils.logger import get_logger +from .base import SessionResourceMixin + +logger = get_logger() + + +class ProcessorManagerMixin(SessionResourceMixin): + """Mixin for processor lifecycle management with session-based expiration. + + Mirrors AdapterManagerMixin with an additional per-token processor limit. + + Inheriting classes should: + 1. Call _init_processor_manager() in __init__ + 2. Override _on_processor_expired() to handle cleanup + + Attributes: + _processor_timeout: Session inactivity timeout in seconds. + _per_token_processor_limit: Maximum active processors per user token. + """ + + # Set resource type for logging + _resource_type = 'Processor' + + def _init_processor_manager( + self, + processor_timeout: float = 1800.0, + per_token_processor_limit: int = 20, + ) -> None: + """Initialize the processor manager. + + Args: + processor_timeout: Timeout in seconds to determine if a session is alive. + Default is 1800.0 (30 minutes). + per_token_processor_limit: Maximum active processors per user token. + Default is 20. + """ + self._init_resource_manager( + resource_timeout=processor_timeout, + resource_max_lifetime=None, # No max lifetime for processors + ) + self._per_token_processor_limit = per_token_processor_limit + + @property + def _processor_timeout(self) -> float: + """Processor timeout for backward compatibility.""" + return self._resource_timeout + + @property + def _processor_records(self) -> dict[str, dict[str, Any]]: + """Processor records for backward compatibility.""" + return self._resource_records + + def _validate_registration(self, resource_id: str, token: str, session_id: str) -> None: + """Validate before registering a processor. Checks per-token limit. 
+ + Args: + resource_id: Processor identifier + token: User token + session_id: Session ID + + Raises: + ValueError: If session_id is empty. + RuntimeError: If per-token limit is reached. + """ + super()._validate_registration(resource_id, token, session_id) + + current_count = sum(1 for info in self._resource_records.values() if info.get('token') == token) + if current_count >= self._per_token_processor_limit: + raise RuntimeError(f'Per-user processor limit ({self._per_token_processor_limit}) reached ' + f'for token {token[:8]}...') + + def _create_resource_record(self, token: str, session_id: str) -> dict[str, Any]: + """Create a new processor record without state field.""" + return { + 'token': token, + 'session_id': session_id, + 'created_at': time.time(), + 'expiring': False, + } + + def register_processor(self, processor_id: str, token: str, session_id: str) -> None: + """Register a new processor for lifecycle tracking. + + Args: + processor_id: Unique identifier of the processor. + token: User token that owns this processor. + session_id: Session ID to associate with this processor. Must be non-empty. + + Raises: + ValueError: If session_id is None or empty. + RuntimeError: If the per-token processor limit has been reached. + """ + self.register_resource(processor_id, token, session_id) + + def unregister_processor(self, processor_id: str) -> bool: + """Unregister a processor from lifecycle tracking. + + Returns: + True if found and removed, False otherwise. 
+        """
+        return self.unregister_resource(processor_id)
+
+    def get_processor_info(self, processor_id: str) -> dict[str, Any] | None:
+        """Get tracking info for a registered processor, or None if not found."""
+        return self.get_resource_info(processor_id)
+
+    def assert_processor_exists(self, processor_id: str) -> None:
+        """Assert a processor exists and is not expiring."""
+        self.assert_resource_exists(processor_id)
+
+    async def _on_resource_expired(self, resource_id: str) -> None:
+        """Internal hook called by base class. Delegates to _on_processor_expired."""
+        await self._on_processor_expired(resource_id)
+
+    async def _on_processor_expired(self, processor_id: str) -> None:
+        """Hook called when a processor's session expires.
+
+        Must be overridden by inheriting classes as a coroutine.
+
+        Raises:
+            NotImplementedError: If not overridden.
+        """
+        raise NotImplementedError(f'_on_processor_expired must be implemented by {self.__class__.__name__}')
+
+    def _ensure_countdown_started(self) -> None:
+        """Ensure the countdown task is started. Call from async context."""
+        super()._ensure_countdown_started()
+
+    def stop_processor_countdown(self) -> None:
+        """Stop the background countdown task."""
+        self.stop_resource_countdown()
diff --git a/src/twinkle/server/utils/processor_manager.py b/src/twinkle/server/utils/processor_manager.py
deleted file mode 100644
index df289b39..00000000
--- a/src/twinkle/server/utils/processor_manager.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-"""
-Processor Lifecycle Manager Mixin for Twinkle Server.
-
-Mirrors AdapterManagerMixin but adds a global per-token processor limit.
-Sessions are tracked via session ID; processors expire when their session expires.
-""" -from __future__ import annotations - -import threading -import time -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from twinkle.server.utils.state import ServerStateProxy - -from twinkle.utils.logger import get_logger - -logger = get_logger() - - -class ProcessorManagerMixin: - """Mixin for processor lifecycle management with session-based expiration. - - Mirrors AdapterManagerMixin with an additional per-token processor limit. - - Inheriting classes should: - 1. Call _init_processor_manager() in __init__ - 2. Override _on_processor_expired() to handle cleanup - - Attributes: - _processor_timeout: Session inactivity timeout in seconds. - _per_token_processor_limit: Maximum active processors per user token. - """ - - # Type hint for state attribute that inheriting classes must provide - state: ServerStateProxy - - def _init_processor_manager( - self, - processor_timeout: float = 1800.0, - per_token_processor_limit: int = 20, - ) -> None: - """Initialize the processor manager. - - Args: - processor_timeout: Timeout in seconds to determine if a session is alive. - Default is 1800.0 (30 minutes). - per_token_processor_limit: Maximum active processors per user token. - Default is 20. - """ - self._processor_timeout = processor_timeout - self._per_token_processor_limit = per_token_processor_limit - - # processor_id -> {'token': str, 'session_id': str, 'created_at': float, 'expiring': bool} - self._processor_records: dict[str, dict[str, Any]] = {} - - self._processor_countdown_thread: threading.Thread | None = None - self._processor_countdown_running = False - - def register_processor(self, processor_id: str, token: str, session_id: str) -> None: - """Register a new processor for lifecycle tracking. - - Args: - processor_id: Unique identifier of the processor. - token: User token that owns this processor. - session_id: Session ID to associate with this processor. Must be non-empty. - - Raises: - ValueError: If session_id is None or empty. 
- RuntimeError: If the per-token processor limit has been reached. - """ - if not session_id: - raise ValueError(f'session_id must be provided when registering processor {processor_id}') - - current_count = sum(1 for info in self._processor_records.values() if info.get('token') == token) - if current_count >= self._per_token_processor_limit: - raise RuntimeError(f'Per-user processor limit ({self._per_token_processor_limit}) reached ' - f'for token {token[:8]}...') - - self._processor_records[processor_id] = { - 'token': token, - 'session_id': session_id, - 'created_at': time.time(), - 'expiring': False, - } - logger.debug(f'[ProcessorManager] Registered processor {processor_id} ' - f'for token {token[:8]}... (session: {session_id})') - - def unregister_processor(self, processor_id: str) -> bool: - """Unregister a processor from lifecycle tracking. - - Returns: - True if found and removed, False otherwise. - """ - if processor_id in self._processor_records: - info = self._processor_records.pop(processor_id) - token = info.get('token', '') - logger.debug(f'[ProcessorManager] Unregistered processor {processor_id} ' - f'for token {token[:8] if token else "unknown"}...') - return True - return False - - def get_processor_info(self, processor_id: str) -> dict[str, Any] | None: - """Get tracking info for a registered processor, or None if not found.""" - return self._processor_records.get(processor_id) - - def assert_processor_exists(self, processor_id: str) -> None: - """Assert a processor exists and is not expiring.""" - info = self._processor_records.get(processor_id) - assert processor_id and info is not None and not info.get('expiring'), \ - f'Processor {processor_id} not found' - - def _on_processor_expired(self, processor_id: str) -> None: - """Hook called when a processor's session expires. - - Must be overridden by inheriting classes. - - Raises: - NotImplementedError: If not overridden. 
- """ - raise NotImplementedError(f'_on_processor_expired must be implemented by {self.__class__.__name__}') - - def _is_session_alive(self, session_id: str) -> bool: - """Check if a session is still alive via state proxy.""" - if not session_id: - return True - last_heartbeat = self.state.get_session_last_heartbeat(session_id) - if last_heartbeat is None: - return False - return (time.time() - last_heartbeat) < self._processor_timeout - - def _processor_countdown_loop(self) -> None: - """Background thread: checks session liveness and expires stale processors.""" - logger.debug(f'[ProcessorManager] Countdown thread started (session_timeout={self._processor_timeout}s)') - while self._processor_countdown_running: - try: - time.sleep(1) - - expired: list[tuple[str, str | None]] = [] - for processor_id, info in list(self._processor_records.items()): - if info.get('expiring'): - continue - session_id = info.get('session_id') - try: - session_alive = self._is_session_alive(session_id) - except Exception as e: - logger.warning(f'[ProcessorManager] Failed to check session liveness ' - f'for {processor_id}: {type(e).__name__}: {e}') - continue - - logger.debug(f'[ProcessorManager] Processor {processor_id} session check ' - f'(session_id={session_id}, session_alive={session_alive})') - if not session_alive: - info['expiring'] = True - expired.append((processor_id, session_id)) - - for processor_id, session_id in expired: - success = False - try: - self._on_processor_expired(processor_id) - logger.info(f'[ProcessorManager] Processor {processor_id} expired ' - f'(reason=session_expired, session={session_id})') - success = True - except Exception as e: - logger.warning(f'[ProcessorManager] Error while expiring processor {processor_id}: {e}') - finally: - if success: - self._processor_records.pop(processor_id, None) - else: - info = self._processor_records.get(processor_id) - if info is not None: - info['expiring'] = False - - except Exception as e: - 
logger.warning(f'[ProcessorManager] Error in countdown loop: {e}') - continue - - logger.debug('[ProcessorManager] Countdown thread stopped') - - def start_processor_countdown(self) -> None: - """Start the background countdown thread. Safe to call multiple times.""" - if not self._processor_countdown_running: - self._processor_countdown_running = True - self._processor_countdown_thread = threading.Thread(target=self._processor_countdown_loop, daemon=True) - self._processor_countdown_thread.start() - logger.debug('[ProcessorManager] Countdown thread started') - - def stop_processor_countdown(self) -> None: - """Stop the background countdown thread.""" - if self._processor_countdown_running: - self._processor_countdown_running = False - if self._processor_countdown_thread: - self._processor_countdown_thread.join(timeout=2.0) - logger.debug('[ProcessorManager] Countdown thread stopped') diff --git a/src/twinkle/server/utils/ray_serve_patch.py b/src/twinkle/server/utils/ray_serve_patch.py new file mode 100644 index 00000000..a6fcbc33 --- /dev/null +++ b/src/twinkle/server/utils/ray_serve_patch.py @@ -0,0 +1,141 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Patches for Ray Serve to handle HTTP header normalization. + +This module patches Ray Serve's HTTPProxy.setup_request_context_and_handle +to handle HTTP header normalization by proxies (AWS/nginx). + +The problem: + curl sends: serve_multiplexed_model_id: 123 + proxy converts: Serve-Multiplexed-Model-Id: 123 (_ → -, title-cased) + uvicorn lowercases: serve-multiplexed-model-id: 123 + Ray Serve compares: "serve_multiplexed_model_id" (with underscores) + RESULT: NO MATCH → multiplexed_model_id is never set + +The fix normalizes header names by converting hyphens to underscores for comparison. + +IMPORTANT: Ray Serve's ProxyActor runs in a separate worker process. +Use get_runtime_env_for_patches() with ray.init() to ensure the patch +is applied in all worker processes. 
+""" +from __future__ import annotations + +from typing import Tuple + +from twinkle.utils.logger import get_logger + +logger = get_logger() + +# Track if patch has been applied +_patch_applied = False + + +def _patched_setup_request_context_and_handle( + self, + app_name: str, + handle, + route: str, + proxy_request, + internal_request_id: str, +) -> tuple: + """Patched version of HTTPProxy.setup_request_context_and_handle. + + This version handles HTTP header normalization by proxies: + - Converts hyphens to underscores for SERVE_MULTIPLEXED_MODEL_ID comparison + """ + from ray.serve._private.constants import SERVE_MULTIPLEXED_MODEL_ID + + request_context_info = { + 'route': route, + 'app_name': app_name, + '_internal_request_id': internal_request_id, + 'is_http_request': True, + } + + for key, value in proxy_request.headers: + decoded_key = key.decode() + + # Check for multiplexed model ID header + # Normalize: convert hyphens to underscores for comparison + # HTTP proxies convert underscores to hyphens: serve_multiplexed_model_id → serve-multiplexed-model-id + normalized_key = decoded_key.replace('-', '_') + if normalized_key == SERVE_MULTIPLEXED_MODEL_ID: + multiplexed_model_id = value.decode() + handle = handle.options(multiplexed_model_id=multiplexed_model_id) + request_context_info['multiplexed_model_id'] = multiplexed_model_id + logger.debug(f'[Ray Serve Patch] Matched multiplexed_model_id: {multiplexed_model_id}') + + # Original logic for other headers (unchanged) + if decoded_key == 'x-request-id': + request_context_info['request_id'] = value.decode() + + import ray.serve.context as serve_context + serve_context._serve_request_context.set(serve_context._RequestContext(**request_context_info)) + + return handle, request_context_info.get('request_id') + + +def _apply_patch_in_worker_process(): + """Apply patch in Ray worker process. + + This function is called by Ray's worker_process_setup_hook in each worker process. 
+ """ + global _patch_applied + + if _patch_applied: + return + + try: + from ray.serve._private.proxy import HTTPProxy + + HTTPProxy.setup_request_context_and_handle = _patched_setup_request_context_and_handle + _patch_applied = True + + logger.debug('[Ray Serve Patch] Applied in worker process: ' + 'HTTPProxy.setup_request_context_and_handle patched') + except ImportError: + # Ray Serve not available in this worker + pass + except Exception as e: + logger.warning(f'[Ray Serve Patch] Failed to apply in worker process: {e}') + + +def apply_ray_serve_patches(): + """Apply patches to Ray Serve in the main process. + + Note: This only patches the main process. For Ray Serve's ProxyActor, + use get_runtime_env_for_patches() with ray.init() to ensure the patch + is applied in worker processes. + """ + global _patch_applied + + if _patch_applied: + return + + try: + from ray.serve._private.proxy import HTTPProxy + + HTTPProxy.setup_request_context_and_handle = _patched_setup_request_context_and_handle + _patch_applied = True + + logger.info('Applied Ray Serve patch: HTTPProxy.setup_request_context_and_handle ' + 'now handles header normalization (hyphens → underscores)') + except ImportError: + logger.warning('Ray Serve not available, skipping patch') + except Exception as e: + logger.warning(f'Failed to apply Ray Serve patch: {e}') + + +def get_runtime_env_for_patches() -> dict: + """Get Ray runtime_env to apply patches in worker processes. + + Ray actors run in separate processes. This returns a runtime_env dict + that configures Ray to run the patch function in each worker process. 
+ + Usage: + ray.init(runtime_env=get_runtime_env_for_patches()) + + Returns: + dict: Ray runtime_env configuration + """ + return {'worker_process_setup_hook': ('twinkle.server.utils.ray_serve_patch._apply_patch_in_worker_process')} diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index a70fdac5..a42aa9a4 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -53,7 +53,7 @@ def __init__( # ----- Session Management ----- - def create_session(self, payload: dict[str, Any]) -> str: + async def create_session(self, payload: dict[str, Any]) -> str: """Create a new session with the given payload. Args: @@ -79,7 +79,7 @@ async def touch_session(self, session_id: str) -> bool: """ return self._session_mgr.touch(session_id) - def get_session_last_heartbeat(self, session_id: str) -> float | None: + async def get_session_last_heartbeat(self, session_id: str) -> float | None: """Get the last heartbeat timestamp for a session. Returns: @@ -89,11 +89,11 @@ def get_session_last_heartbeat(self, session_id: str) -> float | None: # ----- Model Registration ----- - def register_model(self, - payload: dict[str, Any], - token: str, - model_id: str | None = None, - replica_id: str | None = None) -> str: + async def register_model(self, + payload: dict[str, Any], + token: str, + model_id: str | None = None, + replica_id: str | None = None) -> str: """Register a new model with the server state. Args: @@ -122,7 +122,7 @@ def register_model(self, self._model_mgr.add(_model_id, record) return _model_id - def unload_model(self, model_id: str) -> bool: + async def unload_model(self, model_id: str) -> bool: """Remove a model from the registry. 
Returns: @@ -130,14 +130,14 @@ def unload_model(self, model_id: str) -> bool: """ return self._model_mgr.remove(model_id) - def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: + async def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: """Get metadata for a registered model as a plain dict.""" record = self._model_mgr.get(model_id) return record.model_dump() if record is not None else None # ----- Replica Management ----- - def register_replica(self, replica_id: str, max_loras: int) -> None: + async def register_replica(self, replica_id: str, max_loras: int) -> None: """Register a replica and its LoRA capacity. Args: @@ -146,7 +146,7 @@ def register_replica(self, replica_id: str, max_loras: int) -> None: """ self._model_mgr.register_replica(replica_id, max_loras) - def unregister_replica(self, replica_id: str) -> None: + async def unregister_replica(self, replica_id: str) -> None: """Remove a replica from the registry. Args: @@ -167,7 +167,7 @@ async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str] # ----- Sampling Session Management ----- - def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: + async def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: """Create a new sampling session. 
Args: @@ -188,7 +188,7 @@ def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: self._sampling_mgr.add(_sampling_session_id, record) return _sampling_session_id - def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: + async def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: """Get a sampling session by ID as a plain dict.""" record = self._sampling_mgr.get(sampling_session_id) return record.model_dump() if record is not None else None @@ -241,7 +241,7 @@ async def store_future_status( # ----- Resource Cleanup ----- - def cleanup_expired_resources(self) -> dict[str, int]: + async def cleanup_expired_resources(self) -> dict[str, int]: """Clean up expired sessions, models, sampling_sessions, and futures. Sessions expire based on last_heartbeat (or created_at). Models and @@ -275,7 +275,7 @@ async def _cleanup_loop(self) -> None: while self._cleanup_running: try: await asyncio.sleep(self.cleanup_interval) - stats = self.cleanup_expired_resources() + stats = await self.cleanup_expired_resources() if any(stats.values()): logger.debug(f'[ServerState Cleanup] Removed expired resources: {stats}') except asyncio.CancelledError: @@ -284,7 +284,7 @@ async def _cleanup_loop(self) -> None: logger.warning(f'[ServerState Cleanup] Error during cleanup: {e}') continue - def start_cleanup_task(self) -> bool: + async def start_cleanup_task(self) -> bool: """Start the background cleanup task. Returns: @@ -296,7 +296,7 @@ def start_cleanup_task(self) -> bool: self._cleanup_task = asyncio.create_task(self._cleanup_loop()) return True - def stop_cleanup_task(self) -> bool: + async def stop_cleanup_task(self) -> bool: """Stop the background cleanup task. 
Returns: @@ -310,7 +310,7 @@ def stop_cleanup_task(self) -> bool: self._cleanup_task = None return True - def get_cleanup_stats(self) -> dict[str, Any]: + async def get_cleanup_stats(self) -> dict[str, Any]: """Get current cleanup configuration and resource counts. Returns: @@ -347,48 +347,48 @@ def __init__(self, actor_handle) -> None: # ----- Session Management ----- - def create_session(self, payload: dict[str, Any]) -> str: - return ray.get(self._actor.create_session.remote(payload)) + async def create_session(self, payload: dict[str, Any]) -> str: + return await self._actor.create_session.remote(payload) async def touch_session(self, session_id: str) -> bool: return await self._actor.touch_session.remote(session_id) - def get_session_last_heartbeat(self, session_id: str) -> float | None: - return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) + async def get_session_last_heartbeat(self, session_id: str) -> float | None: + return await self._actor.get_session_last_heartbeat.remote(session_id) # ----- Model Registration ----- - def register_model(self, - payload: dict[str, Any], - token: str, - model_id: str | None = None, - replica_id: str | None = None) -> str: - return ray.get(self._actor.register_model.remote(payload, token, model_id, replica_id)) + async def register_model(self, + payload: dict[str, Any], + token: str, + model_id: str | None = None, + replica_id: str | None = None) -> str: + return await self._actor.register_model.remote(payload, token, model_id, replica_id) - def unload_model(self, model_id: str) -> bool: - return ray.get(self._actor.unload_model.remote(model_id)) + async def unload_model(self, model_id: str) -> bool: + return await self._actor.unload_model.remote(model_id) - def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: - return ray.get(self._actor.get_model_metadata.remote(model_id)) + async def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: + return await 
self._actor.get_model_metadata.remote(model_id) # ----- Replica Management ----- - def register_replica(self, replica_id: str, max_loras: int) -> None: - ray.get(self._actor.register_replica.remote(replica_id, max_loras)) + async def register_replica(self, replica_id: str, max_loras: int) -> None: + await self._actor.register_replica.remote(replica_id, max_loras) - def unregister_replica(self, replica_id: str) -> None: - ray.get(self._actor.unregister_replica.remote(replica_id)) + async def unregister_replica(self, replica_id: str) -> None: + await self._actor.unregister_replica.remote(replica_id) async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: return await self._actor.get_available_replica_ids.remote(candidate_ids) # ----- Sampling Session Management ----- - def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: - return ray.get(self._actor.create_sampling_session.remote(payload, sampling_session_id)) + async def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: + return await self._actor.create_sampling_session.remote(payload, sampling_session_id) - def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: - return ray.get(self._actor.get_sampling_session.remote(sampling_session_id)) + async def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: + return await self._actor.get_sampling_session.remote(sampling_session_id) # ----- Future Management ----- @@ -405,23 +405,23 @@ async def store_future_status( queue_state: str | None = None, queue_state_reason: str | None = None, ) -> None: - """Store task status with optional result (synchronous).""" + """Store task status with optional result.""" await self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, queue_state_reason) # ----- Resource Cleanup ----- - def 
cleanup_expired_resources(self) -> dict[str, int]: - return ray.get(self._actor.cleanup_expired_resources.remote()) + async def cleanup_expired_resources(self) -> dict[str, int]: + return await self._actor.cleanup_expired_resources.remote() - def start_cleanup_task(self) -> bool: - return ray.get(self._actor.start_cleanup_task.remote()) + async def start_cleanup_task(self) -> bool: + return await self._actor.start_cleanup_task.remote() - def stop_cleanup_task(self) -> bool: - return ray.get(self._actor.stop_cleanup_task.remote()) + async def stop_cleanup_task(self) -> bool: + return await self._actor.stop_cleanup_task.remote() - def get_cleanup_stats(self) -> dict[str, Any]: - return ray.get(self._actor.get_cleanup_stats.remote()) + async def get_cleanup_stats(self) -> dict[str, Any]: + return await self._actor.get_cleanup_stats.remote() # --------------------------------------------------------------------------- diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index f41a83ce..a5105d49 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -1,6 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from __future__ import annotations -from typing import Optional +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .manager import TwinkleClient def init_tinker_client(**kwargs) -> None: diff --git a/src/twinkle_client/common/serialize.py b/src/twinkle_client/common/serialize.py index 513e7fd5..de3ca4bb 100644 --- a/src/twinkle_client/common/serialize.py +++ b/src/twinkle_client/common/serialize.py @@ -2,16 +2,22 @@ import json from numbers import Number from peft import LoraConfig -from typing import Mapping +from typing import Any, Mapping from twinkle.dataset import DatasetMeta +supported_types = { + DatasetMeta, + LoraConfig, +} + primitive_types = (str, Number, bool, bytes, type(None)) container_types = (Mapping, list, tuple, set, frozenset) basic_types = (*primitive_types, *container_types) def _serialize_data_slice(data_slice): + """Serialize data_slice (Iterable) into a JSON-compatible dict.""" if data_slice is None: return None if isinstance(data_slice, range): @@ -21,6 +27,21 @@ def _serialize_data_slice(data_slice): raise ValueError(f'Http mode does not support data_slice of type {type(data_slice).__name__}. 
' 'Supported types: range, list, tuple.') + +def _deserialize_data_slice(data_slice): + """Deserialize a dict back into the original data_slice object.""" + if data_slice is None: + return None + if not isinstance(data_slice, dict) or '_slice_type_' not in data_slice: + return data_slice + slice_type = data_slice['_slice_type_'] + if slice_type == 'range': + return range(data_slice['start'], data_slice['stop'], data_slice['step']) + if slice_type == 'list': + return data_slice['values'] + raise ValueError(f'Unsupported data_slice type: {slice_type}') + + def serialize_object(obj) -> str: if isinstance(obj, DatasetMeta): data = obj.__dict__.copy() @@ -40,4 +61,23 @@ def serialize_object(obj) -> str: elif isinstance(obj, basic_types): return obj else: - raise ValueError(f'Unsupported object: {obj}') \ No newline at end of file + raise ValueError(f'Unsupported object: {obj}') + + +def deserialize_object(data: str) -> Any: + try: + data = json.loads(data) + except Exception: # noqa + return data + + if '_TWINKLE_TYPE_' in data: + _type = data.pop('_TWINKLE_TYPE_') + if _type == 'DatasetMeta': + data['data_slice'] = _deserialize_data_slice(data.get('data_slice')) + return DatasetMeta(**data) + elif _type == 'LoraConfig': + return LoraConfig(**data) + else: + raise ValueError(f'Unsupported type: {_type}') + else: + return data diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 1373a66d..ddfe6c59 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -16,7 +16,7 @@ def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[ """ headers = { 'X-Ray-Serve-Request-Id': get_request_id(), - 'serve_multiplexed_model_id': get_request_id(), # For model multiplexing + 'Serve-Multiplexed-Model-Id': get_request_id(), # For model multiplexing 'Authorization': 'Bearer ' + get_api_key(), 'Twinkle-Authorization': 'Bearer ' + get_api_key(), # For server compatibility } diff --git 
a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index 108465ec..b9398997 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -5,7 +5,7 @@ import threading from typing import Any, Dict, List, Optional, Tuple from twinkle import get_logger -from twinkle_client.types.server import DeleteCheckpointResponse +from twinkle_client.types.server import (DeleteCheckpointResponse, GetServerCapabilitiesResponse) from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse) from twinkle_client.types.training import (Checkpoint, Cursor, ParsedCheckpointTwinklePath, TrainingRun, @@ -167,6 +167,21 @@ def health_check(self) -> bool: except Exception: return False + def get_server_capabilities(self) -> GetServerCapabilitiesResponse: + """ + Get the server's supported models and capabilities. + + Returns: + :class:`~twinkle_client.types.server.GetServerCapabilitiesResponse` with + ``supported_models`` field containing a list of supported model names. + + Raises: + TwinkleClientError: If the request fails. 
+ """ + response = http_get(self._get_url('/get_server_capabilities')) + data = self._handle_response(response) + return GetServerCapabilitiesResponse(**data) + # ------------------------------------------------------------------ # Training Runs # ------------------------------------------------------------------ diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index 59c88597..00b1f967 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -68,7 +68,9 @@ CheckpointPathResponse, DeleteCheckpointResponse, ErrorResponse, + GetServerCapabilitiesResponse, HealthResponse, + SupportedModel, WeightsInfoRequest, WeightsInfoResponse as ServerWeightsInfoResponse, ) diff --git a/src/twinkle_client/types/server.py b/src/twinkle_client/types/server.py index df7ed58a..2d9233b4 100644 --- a/src/twinkle_client/types/server.py +++ b/src/twinkle_client/types/server.py @@ -1,7 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """Shared Pydantic response models for the twinkle server health/error endpoints.""" from pydantic import BaseModel -from typing import Any +from typing import Any, List, Optional + + +class SupportedModel(BaseModel): + """Information about a supported model.""" + model_name: str + + +class GetServerCapabilitiesResponse(BaseModel): + """Response body for the /get_server_capabilities endpoint.""" + supported_models: List[SupportedModel] class HealthResponse(BaseModel):