Skip to content
2 changes: 2 additions & 0 deletions nonebot/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
description: nonebot.drivers 模块
"""

from nonebot.internal.driver import UNSET as UNSET
from nonebot.internal.driver import URL as URL
from nonebot.internal.driver import ASGIMixin as ASGIMixin
from nonebot.internal.driver import Cookies as Cookies
Expand All @@ -25,6 +26,7 @@
from nonebot.internal.driver import ReverseDriver as ReverseDriver
from nonebot.internal.driver import ReverseMixin as ReverseMixin
from nonebot.internal.driver import Timeout as Timeout
from nonebot.internal.driver import Unset as Unset
from nonebot.internal.driver import WebSocket as WebSocket
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
Expand Down
53 changes: 33 additions & 20 deletions nonebot/drivers/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from typing_extensions import override

from multidict import CIMultiDict
Expand All @@ -44,6 +44,7 @@
QueryTypes,
Timeout,
TimeoutTypes,
Unset,
)

try:
Expand Down Expand Up @@ -86,11 +87,14 @@ def __init__(
raise RuntimeError(f"Unsupported HTTP version: {version}")

if isinstance(timeout, Timeout):
self._timeout = aiohttp.ClientTimeout(
total=timeout.total,
connect=timeout.connect,
sock_read=timeout.read,
)
timeout_kwargs: dict[str, Any] = {}
if not isinstance(timeout.total, Unset):
timeout_kwargs["total"] = timeout.total
if not isinstance(timeout.connect, Unset):
timeout_kwargs["connect"] = timeout.connect
if not isinstance(timeout.read, Unset):
timeout_kwargs["sock_read"] = timeout.read
self._timeout = aiohttp.ClientTimeout(**timeout_kwargs)
else:
self._timeout = aiohttp.ClientTimeout(timeout)

Expand Down Expand Up @@ -122,11 +126,14 @@ async def request(self, setup: Request) -> Response:
)

if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
timeout_kwargs: dict[str, Any] = {}
if not isinstance(setup.timeout.total, Unset):
timeout_kwargs["total"] = setup.timeout.total
if not isinstance(setup.timeout.connect, Unset):
timeout_kwargs["connect"] = setup.timeout.connect
if not isinstance(setup.timeout.read, Unset):
timeout_kwargs["sock_read"] = setup.timeout.read
timeout = aiohttp.ClientTimeout(**timeout_kwargs)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)

Expand Down Expand Up @@ -172,11 +179,14 @@ async def stream_request(
)

if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
timeout_kwargs: dict[str, Any] = {}
if not isinstance(setup.timeout.total, Unset):
timeout_kwargs["total"] = setup.timeout.total
if not isinstance(setup.timeout.connect, Unset):
timeout_kwargs["connect"] = setup.timeout.connect
if not isinstance(setup.timeout.read, Unset):
timeout_kwargs["sock_read"] = setup.timeout.read
timeout = aiohttp.ClientTimeout(**timeout_kwargs)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)

Expand Down Expand Up @@ -271,10 +281,13 @@ async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientWSTimeout(
ws_receive=setup.timeout.read, # type: ignore
ws_close=setup.timeout.total, # type: ignore
)
timeout_kwargs: dict[str, Any] = {}
if not isinstance(setup.timeout.read, Unset):
timeout_kwargs["ws_receive"] = setup.timeout.read
ws_close = setup.timeout.close or setup.timeout.total
if not isinstance(ws_close, Unset):
timeout_kwargs["ws_close"] = ws_close
timeout = aiohttp.ClientWSTimeout(**timeout_kwargs) # type: ignore
else:
timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore

Expand Down
42 changes: 26 additions & 16 deletions nonebot/drivers/httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from typing_extensions import override

from multidict import CIMultiDict
Expand All @@ -40,6 +40,7 @@
QueryTypes,
Timeout,
TimeoutTypes,
Unset,
)

try:
Expand Down Expand Up @@ -74,11 +75,14 @@ def __init__(
self._version = HTTPVersion(version)

if isinstance(timeout, Timeout):
self._timeout = httpx.Timeout(
timeout=timeout.total,
connect=timeout.connect,
read=timeout.read,
)
timeout_kwargs: dict[str, Any] = {}
if not isinstance(timeout.total, Unset):
timeout_kwargs["timeout"] = timeout.total
if not isinstance(timeout.connect, Unset):
timeout_kwargs["connect"] = timeout.connect
if not isinstance(timeout.read, Unset):
timeout_kwargs["read"] = timeout.read
self._timeout = httpx.Timeout(**timeout_kwargs) if timeout_kwargs else None
else:
self._timeout = httpx.Timeout(timeout)

Expand All @@ -93,11 +97,14 @@ def client(self) -> httpx.AsyncClient:
@override
async def request(self, setup: Request) -> Response:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
timeout_kwargs: dict[str, Any] = {}
if not isinstance(setup.timeout.total, Unset):
timeout_kwargs["timeout"] = setup.timeout.total
if not isinstance(setup.timeout.connect, Unset):
timeout_kwargs["connect"] = setup.timeout.connect
if not isinstance(setup.timeout.read, Unset):
timeout_kwargs["read"] = setup.timeout.read
timeout = httpx.Timeout(**timeout_kwargs) if timeout_kwargs else None
else:
timeout = httpx.Timeout(setup.timeout)

Expand Down Expand Up @@ -129,11 +136,14 @@ async def stream_request(
chunk_size: int = 1024,
) -> AsyncGenerator[Response, None]:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
timeout_kwargs: dict[str, Any] = {}
if not isinstance(setup.timeout.total, Unset):
timeout_kwargs["timeout"] = setup.timeout.total
if not isinstance(setup.timeout.connect, Unset):
timeout_kwargs["connect"] = setup.timeout.connect
if not isinstance(setup.timeout.read, Unset):
timeout_kwargs["read"] = setup.timeout.read
timeout = httpx.Timeout(**timeout_kwargs) if timeout_kwargs else None
else:
timeout = httpx.Timeout(setup.timeout)

Expand Down
24 changes: 20 additions & 4 deletions nonebot/drivers/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
from typing import TYPE_CHECKING, Any, TypeVar
from typing_extensions import ParamSpec, override

from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver
from nonebot.drivers import (
Request,
Timeout,
Unset,
WebSocketClientMixin,
combine_driver,
)
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed
Expand Down Expand Up @@ -71,15 +77,25 @@ def type(self) -> str:
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
if isinstance(setup.timeout, Timeout):
timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read
timeout_kwargs: dict[str, Any] = {}
open_timeout = (
setup.timeout.total or setup.timeout.connect or setup.timeout.read
)
if not isinstance(open_timeout, Unset):
timeout_kwargs["open_timeout"] = open_timeout
if not isinstance(setup.timeout.close, Unset):
timeout_kwargs["close_timeout"] = setup.timeout.close
else:
timeout = setup.timeout
timeout_kwargs = {
"open_timeout": setup.timeout,
"close_timeout": setup.timeout or 10.0,
}

connection = connect(
str(setup.url),
additional_headers={**setup.headers, **setup.cookies.as_header(setup)},
proxy=setup.proxy if setup.proxy is not None else True,
open_timeout=timeout,
**timeout_kwargs,
)
async with connection as ws:
yield WebSocket(request=setup, websocket=ws)
Expand Down
2 changes: 2 additions & 0 deletions nonebot/internal/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .abstract import ReverseMixin as ReverseMixin
from .abstract import WebSocketClientMixin as WebSocketClientMixin
from .combine import combine_driver as combine_driver
from .model import UNSET as UNSET
from .model import URL as URL
from .model import ContentTypes as ContentTypes
from .model import Cookies as Cookies
Expand All @@ -29,5 +30,6 @@
from .model import SimpleQuery as SimpleQuery
from .model import Timeout as Timeout
from .model import TimeoutTypes as TimeoutTypes
from .model import Unset as Unset
from .model import WebSocket as WebSocket
from .model import WebSocketServerSetup as WebSocketServerSetup
29 changes: 26 additions & 3 deletions nonebot/internal/driver/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,42 @@
from enum import Enum
from http.cookiejar import Cookie, CookieJar
from typing import IO, Any, TypeAlias
from typing_extensions import Self
import urllib.request

from multidict import CIMultiDict
from yarl import URL as URL


class Unset:
"""Sentinel for unset fields."""

__slots__ = ()
_instance: Self | None = None

def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __repr__(self) -> str:
return "UNSET"

def __bool__(self) -> bool:
return False


UNSET = Unset()


@dataclass
class Timeout:
"""Request 超时配置。"""

total: float | None = None
connect: float | None = None
read: float | None = None
total: float | None | Unset = UNSET
connect: float | None | Unset = UNSET
read: float | None | Unset = UNSET
close: float | None | Unset = UNSET


RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes]
Expand Down
Loading
Loading