|
11 | 11 | import json |
12 | 12 | import logging |
13 | 13 | import typing |
14 | | -from dataclasses import dataclass, field |
| 14 | +from dataclasses import dataclass |
15 | 15 |
|
16 | 16 | try: |
17 | 17 | import websockets |
@@ -55,6 +55,22 @@ class PendingCommand: |
55 | 55 | params: T_JSON_DICT |
56 | 56 |
|
57 | 57 |
|
| 58 | +@dataclass |
| 59 | +class EventWaiter: |
| 60 | + """Represents a consumer waiting for a matching event.""" |
| 61 | + future: asyncio.Future |
| 62 | + event_type: typing.Optional[typing.Union[type, typing.Tuple[type, ...]]] = None |
| 63 | + predicate: typing.Optional[typing.Callable[[typing.Any], bool]] = None |
| 64 | + |
| 65 | + def matches(self, event: typing.Any) -> bool: |
| 66 | + """Return True if this waiter should receive the event.""" |
| 67 | + if self.event_type is not None and not isinstance(event, self.event_type): |
| 68 | + return False |
| 69 | + if self.predicate is not None and not self.predicate(event): |
| 70 | + return False |
| 71 | + return True |
| 72 | + |
| 73 | + |
58 | 74 | class CDPConnection: |
59 | 75 | """ |
60 | 76 | Manages a WebSocket connection to Chrome DevTools Protocol. |
@@ -96,6 +112,7 @@ def __init__(self, url: str, timeout: float = 30.0): |
96 | 112 | self._next_command_id = 1 |
97 | 113 | self._pending_commands: typing.Dict[int, PendingCommand] = {} |
98 | 114 | self._event_queue: asyncio.Queue = asyncio.Queue() |
| 115 | + self._event_waiters: typing.List[EventWaiter] = [] |
99 | 116 | self._recv_task: typing.Optional[asyncio.Task] = None |
100 | 117 | self._closed = False |
101 | 118 |
|
@@ -131,6 +148,12 @@ async def close(self) -> None: |
131 | 148 | if not pending.future.done(): |
132 | 149 | pending.future.cancel() |
133 | 150 | self._pending_commands.clear() |
| 151 | + |
| 152 | + # Cancel all event waiters |
| 153 | + for waiter in self._event_waiters: |
| 154 | + if not waiter.future.done(): |
| 155 | + waiter.future.set_exception(CDPConnectionError("Connection closed")) |
| 156 | + self._event_waiters.clear() |
134 | 157 |
|
135 | 158 | # Close the WebSocket |
136 | 159 | if self._ws: |
@@ -213,9 +236,28 @@ async def _handle_event(self, data: T_JSON_DICT) -> None: |
213 | 236 | """Handle an event notification.""" |
214 | 237 | try: |
215 | 238 | event = parse_json_event(data) |
| 239 | + self._notify_event_waiters(event) |
216 | 240 | await self._event_queue.put(event) |
217 | 241 | except Exception as e: |
218 | 242 | logger.error(f"Failed to parse event: {e}") |
| 243 | + |
| 244 | + def _notify_event_waiters(self, event: typing.Any) -> None: |
| 245 | + """Resolve any pending waiters that match the event.""" |
| 246 | + if not self._event_waiters: |
| 247 | + return |
| 248 | + |
| 249 | + remaining_waiters: typing.List[EventWaiter] = [] |
| 250 | + for waiter in self._event_waiters: |
| 251 | + if waiter.future.done(): |
| 252 | + continue |
| 253 | + try: |
| 254 | + if waiter.matches(event): |
| 255 | + waiter.future.set_result(event) |
| 256 | + else: |
| 257 | + remaining_waiters.append(waiter) |
| 258 | + except Exception as e: |
| 259 | + waiter.future.set_exception(e) |
| 260 | + self._event_waiters = remaining_waiters |
219 | 261 |
|
220 | 262 | async def execute( |
221 | 263 | self, |
@@ -330,6 +372,48 @@ def get_event_nowait(self) -> typing.Optional[typing.Any]: |
330 | 372 | return self._event_queue.get_nowait() |
331 | 373 | except asyncio.QueueEmpty: |
332 | 374 | return None |
| 375 | + |
| 376 | + async def wait_for_event( |
| 377 | + self, |
| 378 | + event_type: typing.Optional[typing.Union[type, typing.Tuple[type, ...]]] = None, |
| 379 | + predicate: typing.Optional[typing.Callable[[typing.Any], bool]] = None, |
| 380 | + timeout: typing.Optional[float] = None, |
| 381 | + ) -> typing.Any: |
| 382 | + """ |
| 383 | + Wait for the next event matching the provided filters. |
| 384 | +
|
| 385 | + Args: |
| 386 | + event_type: Optional event class (or tuple of classes) to match. |
| 387 | + predicate: Optional callable that must return True for a match. |
| 388 | + timeout: Optional timeout override in seconds. |
| 389 | +
|
| 390 | + Returns: |
| 391 | + The matching CDP event object. |
| 392 | +
|
| 393 | + Raises: |
| 394 | + CDPConnectionError: If the connection is closed. |
| 395 | + asyncio.TimeoutError: If no matching event arrives in time. |
| 396 | + """ |
| 397 | + if self._closed: |
| 398 | + raise CDPConnectionError("Connection closed") |
| 399 | + if self._ws is None: |
| 400 | + raise CDPConnectionError("Not connected") |
| 401 | + |
| 402 | + future: asyncio.Future = asyncio.Future() |
| 403 | + waiter = EventWaiter(future=future, event_type=event_type, predicate=predicate) |
| 404 | + self._event_waiters.append(waiter) |
| 405 | + |
| 406 | + timeout_val = timeout if timeout is not None else self.timeout |
| 407 | + try: |
| 408 | + return await asyncio.wait_for(future, timeout=timeout_val) |
| 409 | + except asyncio.TimeoutError: |
| 410 | + if waiter in self._event_waiters: |
| 411 | + self._event_waiters.remove(waiter) |
| 412 | + raise asyncio.TimeoutError("Timed out waiting for matching CDP event") |
| 413 | + except Exception: |
| 414 | + if waiter in self._event_waiters: |
| 415 | + self._event_waiters.remove(waiter) |
| 416 | + raise |
333 | 417 |
|
334 | 418 | @property |
335 | 419 | def is_connected(self) -> bool: |
|
0 commit comments