diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index c7e6e82dd016..8e91d609ce91 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -9,6 +9,7 @@ description: nonebot.drivers 模块 """ +from nonebot.internal.driver import UNSET as UNSET from nonebot.internal.driver import URL as URL from nonebot.internal.driver import ASGIMixin as ASGIMixin from nonebot.internal.driver import Cookies as Cookies @@ -25,6 +26,7 @@ from nonebot.internal.driver import ReverseDriver as ReverseDriver from nonebot.internal.driver import ReverseMixin as ReverseMixin from nonebot.internal.driver import Timeout as Timeout +from nonebot.internal.driver import Unset as Unset from nonebot.internal.driver import WebSocket as WebSocket from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index cb9aa810cd79..979ce16a228e 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -19,7 +19,7 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from typing_extensions import override from multidict import CIMultiDict @@ -44,6 +44,7 @@ QueryTypes, Timeout, TimeoutTypes, + Unset, ) try: @@ -86,11 +87,14 @@ def __init__( raise RuntimeError(f"Unsupported HTTP version: {version}") if isinstance(timeout, Timeout): - self._timeout = aiohttp.ClientTimeout( - total=timeout.total, - connect=timeout.connect, - sock_read=timeout.read, - ) + timeout_kwargs: dict[str, Any] = {} + if not isinstance(timeout.total, Unset): + timeout_kwargs["total"] = timeout.total + if not isinstance(timeout.connect, Unset): + timeout_kwargs["connect"] = timeout.connect + if not isinstance(timeout.read, Unset): + timeout_kwargs["sock_read"] = timeout.read + self._timeout = aiohttp.ClientTimeout(**timeout_kwargs) else: self._timeout = aiohttp.ClientTimeout(timeout) @@ -122,11 +126,14 @@ async def request(self, setup: Request) -> Response: ) if isinstance(setup.timeout, Timeout): - timeout = aiohttp.ClientTimeout( - total=setup.timeout.total, - connect=setup.timeout.connect, - sock_read=setup.timeout.read, - ) + timeout_kwargs: dict[str, Any] = {} + if not isinstance(setup.timeout.total, Unset): + timeout_kwargs["total"] = setup.timeout.total + if not isinstance(setup.timeout.connect, Unset): + timeout_kwargs["connect"] = setup.timeout.connect + if not isinstance(setup.timeout.read, Unset): + timeout_kwargs["sock_read"] = setup.timeout.read + timeout = aiohttp.ClientTimeout(**timeout_kwargs) else: timeout = aiohttp.ClientTimeout(setup.timeout) @@ -172,11 +179,14 @@ async def stream_request( ) if isinstance(setup.timeout, Timeout): - timeout = aiohttp.ClientTimeout( - total=setup.timeout.total, - connect=setup.timeout.connect, - sock_read=setup.timeout.read, - ) + timeout_kwargs: dict[str, Any] = {} + if not isinstance(setup.timeout.total, Unset): + timeout_kwargs["total"] = setup.timeout.total + if not isinstance(setup.timeout.connect, Unset): + timeout_kwargs["connect"] = setup.timeout.connect + if not isinstance(setup.timeout.read, Unset): + timeout_kwargs["sock_read"] = setup.timeout.read + timeout = aiohttp.ClientTimeout(**timeout_kwargs) else: timeout = aiohttp.ClientTimeout(setup.timeout) @@ -271,10 +281,13 @@ async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: raise RuntimeError(f"Unsupported HTTP version: {setup.version}") if isinstance(setup.timeout, Timeout): - timeout = aiohttp.ClientWSTimeout( - ws_receive=setup.timeout.read, # type: ignore - ws_close=setup.timeout.total, # type: ignore - ) + timeout_kwargs: dict[str, Any] = {} + if not isinstance(setup.timeout.read, Unset): + timeout_kwargs["ws_receive"] = setup.timeout.read + ws_close = setup.timeout.close or setup.timeout.total + if not isinstance(ws_close, Unset): + timeout_kwargs["ws_close"] = ws_close + timeout = aiohttp.ClientWSTimeout(**timeout_kwargs) # type: ignore else: timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index 70bec59562f9..cf1f63357444 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -18,7 +18,7 @@ """ from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from typing_extensions import override from multidict import CIMultiDict @@ -40,6 +40,7 @@ QueryTypes, Timeout, TimeoutTypes, + Unset, ) try: @@ -74,11 +75,14 @@ def __init__( self._version = HTTPVersion(version) if isinstance(timeout, Timeout): - self._timeout = httpx.Timeout( - timeout=timeout.total, - connect=timeout.connect, - read=timeout.read, - ) + timeout_kwargs: dict[str, Any] = {} + if not isinstance(timeout.total, Unset): + timeout_kwargs["timeout"] = timeout.total + if not isinstance(timeout.connect, Unset): + timeout_kwargs["connect"] = timeout.connect + if not isinstance(timeout.read, Unset): + timeout_kwargs["read"] = timeout.read + self._timeout = httpx.Timeout(**timeout_kwargs) if timeout_kwargs else None else: self._timeout = httpx.Timeout(timeout) @@ -93,11 +97,14 @@ def client(self) -> httpx.AsyncClient: @override async def request(self, setup: Request) -> Response: if isinstance(setup.timeout, Timeout): - timeout = httpx.Timeout( - timeout=setup.timeout.total, - connect=setup.timeout.connect, - read=setup.timeout.read, - ) + timeout_kwargs: dict[str, Any] = {} + if not isinstance(setup.timeout.total, Unset): + timeout_kwargs["timeout"] = setup.timeout.total + if not isinstance(setup.timeout.connect, Unset): + timeout_kwargs["connect"] = setup.timeout.connect + if not isinstance(setup.timeout.read, Unset): + timeout_kwargs["read"] = setup.timeout.read + timeout = httpx.Timeout(**timeout_kwargs) if timeout_kwargs else None else: timeout = httpx.Timeout(setup.timeout) @@ -129,11 +136,14 @@ async def stream_request( chunk_size: int = 1024, ) -> AsyncGenerator[Response, None]: if isinstance(setup.timeout, Timeout): - timeout = httpx.Timeout( - timeout=setup.timeout.total, - connect=setup.timeout.connect, - read=setup.timeout.read, - ) + timeout_kwargs: dict[str, Any] = {} + if not isinstance(setup.timeout.total, Unset): + timeout_kwargs["timeout"] = setup.timeout.total + if not isinstance(setup.timeout.connect, Unset): + timeout_kwargs["connect"] = setup.timeout.connect + if not isinstance(setup.timeout.read, Unset): + timeout_kwargs["read"] = setup.timeout.read + timeout = httpx.Timeout(**timeout_kwargs) if timeout_kwargs else None else: timeout = httpx.Timeout(setup.timeout) diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 3e6aa07ba6d6..36e89279118b 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -25,7 +25,13 @@ from typing import TYPE_CHECKING, Any, TypeVar from typing_extensions import ParamSpec, override -from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver +from nonebot.drivers import ( + Request, + Timeout, + Unset, + WebSocketClientMixin, + combine_driver, +) from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers.none import Driver as NoneDriver from nonebot.exception import WebSocketClosed @@ -71,15 +77,25 @@ def type(self) -> str: @asynccontextmanager async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: if isinstance(setup.timeout, Timeout): - timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read + timeout_kwargs: dict[str, Any] = {} + open_timeout = ( + setup.timeout.total or setup.timeout.connect or setup.timeout.read + ) + if not isinstance(open_timeout, Unset): + timeout_kwargs["open_timeout"] = open_timeout + if not isinstance(setup.timeout.close, Unset): + timeout_kwargs["close_timeout"] = setup.timeout.close else: - timeout = setup.timeout + timeout_kwargs = { + "open_timeout": setup.timeout, + "close_timeout": setup.timeout or 10.0, + } connection = connect( str(setup.url), additional_headers={**setup.headers, **setup.cookies.as_header(setup)}, proxy=setup.proxy if setup.proxy is not None else True, - open_timeout=timeout, + **timeout_kwargs, ) async with connection as ws: yield WebSocket(request=setup, websocket=ws) diff --git a/nonebot/internal/driver/__init__.py b/nonebot/internal/driver/__init__.py index e4b3f042c3f6..168e6af4b18c 100644 --- a/nonebot/internal/driver/__init__.py +++ b/nonebot/internal/driver/__init__.py @@ -9,6 +9,7 @@ from .abstract import ReverseMixin as ReverseMixin from .abstract import WebSocketClientMixin as WebSocketClientMixin from .combine import combine_driver as combine_driver +from .model import UNSET as UNSET from .model import URL as URL from .model import ContentTypes as ContentTypes from .model import Cookies as Cookies @@ -29,5 +30,6 @@ from .model import SimpleQuery as SimpleQuery from .model import Timeout as Timeout from .model import TimeoutTypes as TimeoutTypes +from .model import Unset as Unset from .model import WebSocket as WebSocket from .model import WebSocketServerSetup as WebSocketServerSetup diff --git a/nonebot/internal/driver/model.py b/nonebot/internal/driver/model.py index 169d589d129d..7bdb8c72437c 100644 --- a/nonebot/internal/driver/model.py +++ b/nonebot/internal/driver/model.py @@ -4,19 +4,42 @@ from enum import Enum from http.cookiejar import Cookie, CookieJar from typing import IO, Any, TypeAlias +from typing_extensions import Self import urllib.request from multidict import CIMultiDict from yarl import URL as URL +class Unset: + """Sentinel for unset fields.""" + + __slots__ = () + _instance: Self | None = None + + def __new__(cls) -> Self: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self) -> str: + return "UNSET" + + def __bool__(self) -> bool: + return False + + +UNSET = Unset() + + @dataclass class Timeout: """Request 超时配置。""" - total: float | None = None - connect: float | None = None - read: float | None = None + total: float | None | Unset = UNSET + connect: float | None | Unset = UNSET + read: float | None | Unset = UNSET + close: float | None | Unset = UNSET RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes] diff --git a/tests/test_driver.py b/tests/test_driver.py index 1b7a2a33b9d7..a3acc3cc0e38 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -10,6 +10,7 @@ from nonebot.adapters import Bot from nonebot.dependencies import Dependent from nonebot.drivers import ( + UNSET, URL, ASGIMixin, Driver, @@ -18,6 +19,7 @@ Request, Response, Timeout, + Unset, WebSocket, WebSocketClientMixin, WebSocketServerSetup, @@ -706,6 +708,143 @@ async def receive(self, timeout: float | None = None) -> WSMessage: # noqa: ASY await ws.receive() +def test_unset_sentinel(): + assert UNSET is Unset() + assert repr(UNSET) == "UNSET" + assert not UNSET + assert bool(UNSET) is False + + +def test_timeout_unset_vs_none(): + # default: all fields are UNSET + t = Timeout() + assert isinstance(t.total, Unset) + assert isinstance(t.connect, Unset) + assert isinstance(t.read, Unset) + assert isinstance(t.close, Unset) + + # explicitly set to None + t = Timeout(close=None) + assert t.close is None + assert not isinstance(t.close, Unset) + + # explicitly set to a value + t = Timeout(total=5.0, close=None) + assert t.total == 5.0 + assert t.close is None + assert isinstance(t.connect, Unset) + assert isinstance(t.read, Unset) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "driver", + [ + pytest.param("nonebot.drivers.httpx:Driver", id="httpx"), + pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"), + ], + indirect=True, +) +async def test_http_client_timeout_unset(driver: Driver, server_url: URL): + """HTTP requests work with fully unset, partial, and None timeout fields.""" + assert isinstance(driver, HTTPClientMixin) + + # all fields unset — library defaults should apply + request = Request("POST", server_url, content="test", timeout=Timeout()) + response = await driver.request(request) + assert response.status_code == 200 + + # only total set + request = Request("POST", server_url, content="test", timeout=Timeout(total=10.0)) + response = await driver.request(request) + assert response.status_code == 200 + + # explicit None (no timeout) + request = Request( + "POST", server_url, content="test", timeout=Timeout(total=None, read=None) + ) + response = await driver.request(request) + assert response.status_code == 200 + + # stream_request with unset timeout + request = Request("POST", server_url, content="test", timeout=Timeout()) + async for resp in driver.stream_request(request, chunk_size=1024): + assert resp.status_code == 200 + + # stream_request with partial timeout + request = Request( + "POST", server_url, content="test", timeout=Timeout(total=10.0, read=None) + ) + async for resp in driver.stream_request(request, chunk_size=1024): + assert resp.status_code == 200 + + # session with Timeout object + session = driver.get_session(timeout=Timeout(total=10.0, connect=5.0, read=5.0)) + async with session: + request = Request("POST", server_url, content="test") + response = await session.request(request) + assert response.status_code == 200 + + # session with fully unset Timeout + session = driver.get_session(timeout=Timeout()) + async with session: + request = Request("POST", server_url, content="test") + response = await session.request(request) + assert response.status_code == 200 + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "driver", + [ + pytest.param("nonebot.drivers.websockets:Driver", id="websockets"), + pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"), + ], + indirect=True, +) +async def test_websocket_client_timeout_unset(driver: Driver, server_url: URL): + """WebSocket connections work with fully unset, partial, and None timeout fields.""" + assert isinstance(driver, WebSocketClientMixin) + + ws_url = server_url.with_scheme("ws") + + # all fields unset + request = Request("GET", ws_url, timeout=Timeout()) + async with driver.websocket(request) as ws: + await ws.send("quit") + with pytest.raises(WebSocketClosed): + await ws.receive() + + await anyio.sleep(1) + + # close explicitly set to None (no close timeout) + request = Request("GET", ws_url, timeout=Timeout(close=None)) + async with driver.websocket(request) as ws: + await ws.send("quit") + with pytest.raises(WebSocketClosed): + await ws.receive() + + await anyio.sleep(1) + + # partial: only total set + request = Request("GET", ws_url, timeout=Timeout(total=10.0)) + async with driver.websocket(request) as ws: + await ws.send("quit") + with pytest.raises(WebSocketClosed): + await ws.receive() + + await anyio.sleep(1) + + # read and close explicitly set + request = Request("GET", ws_url, timeout=Timeout(read=5.0, close=5.0)) + async with driver.websocket(request) as ws: + await ws.send("quit") + with pytest.raises(WebSocketClosed): + await ws.receive() + + await anyio.sleep(1) + + @pytest.mark.parametrize( ("driver", "driver_type"), [