Skip to content

Commit cdb471c

Browse files
Copilothyp3ri0n-ng
andcommitted
Add I/O and multiplexing support with CDPConnection class
Co-authored-by: hyp3ri0n-ng <3106718+hyp3ri0n-ng@users.noreply.github.com>
1 parent d50d37e commit cdb471c

File tree

5 files changed

+977
-1
lines changed

5 files changed

+977
-1
lines changed

cdp/connection.py

Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
"""
2+
CDP Connection Module
3+
4+
This module provides I/O and multiplexing support for Chrome DevTools Protocol.
5+
It handles WebSocket connections, JSON-RPC message framing, command multiplexing,
6+
and event dispatching.
7+
"""
8+
9+
from __future__ import annotations
10+
import asyncio
11+
import json
12+
import logging
13+
import typing
14+
from dataclasses import dataclass, field
15+
16+
try:
17+
import websockets
18+
from websockets.client import WebSocketClientProtocol
19+
WEBSOCKETS_AVAILABLE = True
20+
except ImportError:
21+
WEBSOCKETS_AVAILABLE = False
22+
WebSocketClientProtocol = typing.Any # type: ignore
23+
24+
from cdp.util import parse_json_event, T_JSON_DICT
25+
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
class CDPError(Exception):
31+
"""Base exception for CDP errors."""
32+
pass
33+
34+
35+
class CDPConnectionError(CDPError):
36+
"""Raised when there's a connection error."""
37+
pass
38+
39+
40+
class CDPCommandError(CDPError):
41+
"""Raised when a command returns an error."""
42+
43+
def __init__(self, code: int, message: str, data: typing.Optional[typing.Any] = None):
44+
self.code = code
45+
self.message = message
46+
self.data = data
47+
super().__init__(f"CDP Command Error {code}: {message}")
48+
49+
50+
@dataclass
51+
class PendingCommand:
52+
"""Represents a command waiting for a response."""
53+
future: asyncio.Future
54+
method: str
55+
params: T_JSON_DICT
56+
57+
58+
class CDPConnection:
59+
"""
60+
Manages a WebSocket connection to Chrome DevTools Protocol.
61+
62+
This class handles:
63+
- WebSocket connection management
64+
- JSON-RPC message framing (request ID assignment)
65+
- Command multiplexing (tracking multiple concurrent commands)
66+
- Event dispatching
67+
- Error handling
68+
69+
Example:
70+
async with CDPConnection("ws://localhost:9222/devtools/page/...") as conn:
71+
# Send a command
72+
result = await conn.execute(some_command())
73+
74+
# Listen for events
75+
async for event in conn.listen():
76+
print(event)
77+
"""
78+
79+
def __init__(self, url: str, timeout: float = 30.0):
80+
"""
81+
Initialize a CDP connection.
82+
83+
Args:
84+
url: WebSocket URL for the CDP endpoint
85+
timeout: Default timeout for commands in seconds
86+
"""
87+
if not WEBSOCKETS_AVAILABLE:
88+
raise ImportError(
89+
"websockets library is required for CDPConnection. "
90+
"Install it with: pip install websockets"
91+
)
92+
93+
self.url = url
94+
self.timeout = timeout
95+
self._ws: typing.Optional[WebSocketClientProtocol] = None
96+
self._next_command_id = 1
97+
self._pending_commands: typing.Dict[int, PendingCommand] = {}
98+
self._event_queue: asyncio.Queue = asyncio.Queue()
99+
self._recv_task: typing.Optional[asyncio.Task] = None
100+
self._closed = False
101+
102+
async def connect(self) -> None:
103+
"""Establish the WebSocket connection."""
104+
if self._ws is not None:
105+
raise CDPConnectionError("Already connected")
106+
107+
try:
108+
self._ws = await websockets.connect(self.url) # type: ignore
109+
self._recv_task = asyncio.create_task(self._receive_loop())
110+
logger.info(f"Connected to {self.url}")
111+
except Exception as e:
112+
raise CDPConnectionError(f"Failed to connect to {self.url}: {e}")
113+
114+
async def close(self) -> None:
115+
"""Close the WebSocket connection."""
116+
if self._closed:
117+
return
118+
119+
self._closed = True
120+
121+
# Cancel the receive task
122+
if self._recv_task:
123+
self._recv_task.cancel()
124+
try:
125+
await self._recv_task
126+
except asyncio.CancelledError:
127+
pass
128+
129+
# Cancel all pending commands
130+
for cmd_id, pending in self._pending_commands.items():
131+
if not pending.future.done():
132+
pending.future.cancel()
133+
self._pending_commands.clear()
134+
135+
# Close the WebSocket
136+
if self._ws:
137+
await self._ws.close()
138+
self._ws = None
139+
140+
logger.info("Connection closed")
141+
142+
async def __aenter__(self) -> CDPConnection:
143+
"""Async context manager entry."""
144+
await self.connect()
145+
return self
146+
147+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
148+
"""Async context manager exit."""
149+
await self.close()
150+
151+
async def _receive_loop(self) -> None:
152+
"""
153+
Main receive loop that processes incoming WebSocket messages.
154+
155+
This loop:
156+
- Receives messages from the WebSocket
157+
- Parses JSON-RPC responses and matches them to pending commands
158+
- Dispatches events to the event queue
159+
"""
160+
try:
161+
while not self._closed and self._ws:
162+
try:
163+
message = await self._ws.recv()
164+
data = json.loads(message)
165+
166+
if 'id' in data:
167+
# This is a command response
168+
await self._handle_response(data)
169+
elif 'method' in data:
170+
# This is an event
171+
await self._handle_event(data)
172+
else:
173+
logger.warning(f"Received unexpected message: {data}")
174+
175+
except json.JSONDecodeError as e:
176+
logger.error(f"Failed to decode JSON: {e}")
177+
except Exception as e:
178+
logger.error(f"Error in receive loop: {e}")
179+
if not self._closed:
180+
raise
181+
except asyncio.CancelledError:
182+
logger.debug("Receive loop cancelled")
183+
except Exception as e:
184+
logger.error(f"Fatal error in receive loop: {e}")
185+
# Cancel all pending commands with this error
186+
for pending in self._pending_commands.values():
187+
if not pending.future.done():
188+
pending.future.set_exception(CDPConnectionError(f"Connection error: {e}"))
189+
190+
async def _handle_response(self, data: T_JSON_DICT) -> None:
191+
"""Handle a command response."""
192+
cmd_id = data['id']
193+
194+
if cmd_id not in self._pending_commands:
195+
logger.warning(f"Received response for unknown command ID {cmd_id}")
196+
return
197+
198+
pending = self._pending_commands.pop(cmd_id)
199+
200+
if 'error' in data:
201+
error = data['error']
202+
exc = CDPCommandError(
203+
code=error.get('code', -1),
204+
message=error.get('message', 'Unknown error'),
205+
data=error.get('data')
206+
)
207+
pending.future.set_exception(exc)
208+
else:
209+
result = data.get('result', {})
210+
pending.future.set_result(result)
211+
212+
async def _handle_event(self, data: T_JSON_DICT) -> None:
213+
"""Handle an event notification."""
214+
try:
215+
event = parse_json_event(data)
216+
await self._event_queue.put(event)
217+
except Exception as e:
218+
logger.error(f"Failed to parse event: {e}")
219+
220+
async def execute(
221+
self,
222+
cmd: typing.Generator[T_JSON_DICT, T_JSON_DICT, typing.Any],
223+
timeout: typing.Optional[float] = None
224+
) -> typing.Any:
225+
"""
226+
Execute a CDP command.
227+
228+
This method:
229+
- Assigns a unique ID to the command
230+
- Sends it over the WebSocket
231+
- Waits for the response (with multiplexing support)
232+
- Returns the parsed result
233+
234+
Args:
235+
cmd: A CDP command generator (from any CDP domain module)
236+
timeout: Optional timeout override for this command
237+
238+
Returns:
239+
The command result (type depends on the command)
240+
241+
Raises:
242+
CDPCommandError: If the command returns an error
243+
asyncio.TimeoutError: If the command times out
244+
CDPConnectionError: If there's a connection error
245+
246+
Example:
247+
from cdp import page
248+
result = await conn.execute(page.navigate(url="https://example.com"))
249+
"""
250+
if self._ws is None:
251+
raise CDPConnectionError("Not connected")
252+
253+
if self._closed:
254+
raise CDPConnectionError("Connection closed")
255+
256+
# Get the command request from the generator
257+
request = cmd.send(None) # type: ignore[arg-type]
258+
259+
# Assign a unique ID
260+
cmd_id = self._next_command_id
261+
self._next_command_id += 1
262+
request['id'] = cmd_id
263+
264+
# Create a future to track this command
265+
future: asyncio.Future = asyncio.Future()
266+
self._pending_commands[cmd_id] = PendingCommand(
267+
future=future,
268+
method=request['method'],
269+
params=request.get('params', {})
270+
)
271+
272+
try:
273+
# Send the command
274+
await self._ws.send(json.dumps(request))
275+
logger.debug(f"Sent command {cmd_id}: {request['method']}")
276+
277+
# Wait for the response
278+
timeout_val = timeout if timeout is not None else self.timeout
279+
result = await asyncio.wait_for(future, timeout=timeout_val)
280+
281+
# Send the result back to the generator
282+
try:
283+
cmd.send(result)
284+
except StopIteration as e:
285+
return e.value
286+
287+
raise CDPError("Command generator did not stop")
288+
289+
except asyncio.TimeoutError:
290+
# Clean up the pending command
291+
self._pending_commands.pop(cmd_id, None)
292+
raise asyncio.TimeoutError(f"Command {request['method']} timed out")
293+
except Exception:
294+
# Clean up the pending command on error
295+
self._pending_commands.pop(cmd_id, None)
296+
raise
297+
298+
async def listen(self) -> typing.AsyncIterator[typing.Any]:
299+
"""
300+
Listen for events from the browser.
301+
302+
This is an async iterator that yields CDP events as they arrive.
303+
304+
Yields:
305+
CDP event objects (type depends on the event)
306+
307+
Example:
308+
async for event in conn.listen():
309+
if isinstance(event, page.LoadEventFired):
310+
print("Page loaded!")
311+
"""
312+
while not self._closed:
313+
try:
314+
event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0)
315+
yield event
316+
except asyncio.TimeoutError:
317+
# Check if connection is still alive
318+
if self._closed:
319+
break
320+
continue
321+
322+
def get_event_nowait(self) -> typing.Optional[typing.Any]:
323+
"""
324+
Get an event from the queue without waiting.
325+
326+
Returns:
327+
A CDP event object, or None if no events are available
328+
"""
329+
try:
330+
return self._event_queue.get_nowait()
331+
except asyncio.QueueEmpty:
332+
return None
333+
334+
@property
335+
def is_connected(self) -> bool:
336+
"""Check if the connection is open."""
337+
return self._ws is not None and not self._closed
338+
339+
@property
340+
def pending_command_count(self) -> int:
341+
"""Get the number of pending commands (for debugging/monitoring)."""
342+
return len(self._pending_commands)

0 commit comments

Comments
 (0)