From 8f0aaf7077f39528bd634f7d793c005eae3c2783 Mon Sep 17 00:00:00 2001 From: Taiwo-Sh Date: Thu, 5 Feb 2026 14:31:25 +0100 Subject: [PATCH] feat: Adds support for passing premade websocket-rpc endpoint to pubsub endpoint --- .gitignore | 4 +- fastapi_websocket_pubsub/pub_sub_server.py | 128 ++++++++--- tests/premade_endpoint_test.py | 255 +++++++++++++++++++++ 3 files changed, 358 insertions(+), 29 deletions(-) create mode 100644 tests/premade_endpoint_test.py diff --git a/.gitignore b/.gitignore index 8d02e6c..e8a604f 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,6 @@ dmypy.json # editors .vscode/ -.idea/ \ No newline at end of file +.idea/ + +venv/ diff --git a/fastapi_websocket_pubsub/pub_sub_server.py b/fastapi_websocket_pubsub/pub_sub_server.py index 755268a..fdcd710 100644 --- a/fastapi_websocket_pubsub/pub_sub_server.py +++ b/fastapi_websocket_pubsub/pub_sub_server.py @@ -1,9 +1,10 @@ import asyncio -from typing import Coroutine, List, Union +from typing import Coroutine, List, Union, Optional from fastapi import WebSocket from fastapi_websocket_rpc import WebsocketRPCEndpoint from fastapi_websocket_rpc.rpc_channel import RpcChannel +from fastapi_websocket_rpc.rpc_methods import RpcMethodsBase from .logger import get_logger from .event_broadcaster import EventBroadcaster @@ -34,7 +35,8 @@ def __init__( on_connect: List[Coroutine] = None, on_disconnect: List[Coroutine] = None, rpc_channel_get_remote_id: bool = False, - ignore_broadcaster_disconnected = True, + ignore_broadcaster_disconnected=True, + endpoint: Optional[WebsocketRPCEndpoint] = None, ): """ The PubSub endpoint recives subscriptions from clients and publishes data back to them upon receiving relevant publications. @@ -57,28 +59,92 @@ def __init__( on_connect (List[Coroutine]): callbacks on connection being established (each callback is called with the channel) on_disconnect (List[Coroutine]): callbacks on connection termination (each callback is called with the channel) ignore_broadcaster_disconnected: Don't end main loop if broadcaster's reader task ends (due to underlying disconnection) + endpoint (Optional[WebsocketRPCEndpoint]): An optional pre-configured `WebsocketRPCEndpoint` instance to use instead of creating a new one. + `met """ - self.notifier = ( - notifier if notifier is not None else WebSocketRpcEventNotifier() - ) - self.broadcaster = ( - broadcaster - if isinstance(broadcaster, EventBroadcaster) or broadcaster is None - else EventBroadcaster(broadcaster, self.notifier) - ) - self.methods = ( - methods_class(self.notifier) - if methods_class is not None - else RpcEventServerMethods(self.notifier) - ) if on_disconnect is None: on_disconnect = [] - self.endpoint = WebsocketRPCEndpoint( - self.methods, - on_disconnect=[self.on_disconnect, *on_disconnect], - on_connect=on_connect, - rpc_channel_get_remote_id=rpc_channel_get_remote_id, - ) + if on_connect is None: + on_connect = [] + + if endpoint is None: + self.notifier = ( + notifier if notifier is not None else WebSocketRpcEventNotifier() + ) + self.broadcaster = ( + broadcaster + if isinstance(broadcaster, EventBroadcaster) or broadcaster is None + else EventBroadcaster(broadcaster, self.notifier) + ) + self.methods = ( + methods_class(self.notifier) + if methods_class is not None + else RpcEventServerMethods(self.notifier) + ) + self.endpoint = WebsocketRPCEndpoint( + self.methods, + on_disconnect=[self.on_disconnect, *on_disconnect], + on_connect=on_connect, + rpc_channel_get_remote_id=rpc_channel_get_remote_id, + ) + else: + # Extract notifier from endpoint's methods if they exist + if endpoint.methods is not None and isinstance( + endpoint.methods, RpcEventServerMethods + ): + # Use the notifier from the existing methods + self.notifier = endpoint.methods.event_notifier + else: + # No valid methods, use provided notifier or create new one + self.notifier = ( + notifier if notifier is not None else WebSocketRpcEventNotifier() + ) + + # Setup broadcaster with the correct notifier + self.broadcaster = ( + broadcaster + if isinstance(broadcaster, EventBroadcaster) or broadcaster is None + else EventBroadcaster(broadcaster, self.notifier) + ) + + # If the endpoint has no methods, or methods is `RpcMethodsBase`, we set up the correct methods + if endpoint.methods is None or type(endpoint.methods) is RpcMethodsBase: + # No methods yet so create them with our notifier + self.methods = ( + methods_class(self.notifier) + if methods_class is not None + else RpcEventServerMethods(self.notifier) + ) + endpoint.methods = self.methods + elif isinstance(endpoint.methods, RpcEventServerMethods): + # The endpoint has valid methods so we keep them + self.methods = endpoint.methods + + # If `methods_class` was provided, it's incompatible with premade endpoint + if methods_class is not None: + raise ValueError( + "Cannot specify `methods_class` when using a premade endpoint " + "that already has methods configured. Either pass `methods=None` " + "to the endpoint or don't specify `methods_class`." + ) + else: + # Invalid methods type + raise ValueError( + "Premade endpoint must have methods that derive from `RpcEventServerMethods` or None. " + f"Got {type(endpoint.methods).__name__} instead." + ) + + if endpoint._on_disconnect is None: + endpoint._on_disconnect = [] + if endpoint._on_connect is None: + endpoint._on_connect = [] + + # Ensure we register this pubsub endpoint's `on_disconnect` handler as the first to run + endpoint._on_disconnect.insert(0, self.on_disconnect) + endpoint._on_disconnect.extend(on_disconnect) + endpoint._on_connect.extend(on_connect) + self.endpoint = endpoint + self._rpc_channel_get_remote_id = rpc_channel_get_remote_id # server id used to publish events for clients self._id = self.notifier.gen_subscriber_id() @@ -92,7 +158,8 @@ async def subscribe( return await self.notifier.subscribe(self._subscriber_id, topics, callback) async def unsubscribe( - self, topics: Union[TopicList, ALL_TOPICS]) -> List[Subscription]: + self, topics: Union[TopicList, ALL_TOPICS] + ) -> List[Subscription]: return await self.notifier.unsubscribe(self._subscriber_id, topics) async def publish(self, topics: Union[TopicList, Topic], data=None): @@ -107,7 +174,7 @@ async def publish(self, topics: Union[TopicList, Topic], data=None): # sharing here means - the broadcaster listens in to the notifier as well logger.debug(f"Publishing message to topics: {topics}") if self.broadcaster is not None: - logger.debug(f"Acquiring broadcaster sharing context") + logger.debug("Acquiring broadcaster sharing context") async with self.broadcaster.get_context(listen=False, share=True): await self.notifier.notify(topics, data, notifier_id=self._id) # otherwise just notify @@ -136,14 +203,19 @@ async def main_loop(self, websocket: WebSocket, client_id: str = None, **kwargs) async with self.broadcaster: logger.debug("Entering endpoint's main loop with broadcaster") if self._ignore_broadcaster_disconnected: - await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs) + await self.endpoint.main_loop( + websocket, client_id=client_id, **kwargs + ) else: main_loop_task = asyncio.create_task( - self.endpoint.main_loop(websocket, client_id=client_id, **kwargs) + self.endpoint.main_loop( + websocket, client_id=client_id, **kwargs + ) + ) + done, pending = await asyncio.wait( + [main_loop_task, self.broadcaster.get_reader_task()], + return_when=asyncio.FIRST_COMPLETED, ) - done, pending = await asyncio.wait([main_loop_task, - self.broadcaster.get_reader_task()], - return_when=asyncio.FIRST_COMPLETED) logger.debug(f"task is done: {done}") # broadcaster's reader task is used by other endpoints and shouldn't be cancelled if main_loop_task in pending: diff --git a/tests/premade_endpoint_test.py b/tests/premade_endpoint_test.py new file mode 100644 index 0000000..e148248 --- /dev/null +++ b/tests/premade_endpoint_test.py @@ -0,0 +1,255 @@ +import os +import sys +import pytest +import asyncio +import uvicorn +import requests +from fastapi import FastAPI +from multiprocessing import Process + +from fastapi_websocket_rpc.utils import gen_uid +from fastapi_websocket_rpc.logger import get_logger +from fastapi_websocket_rpc import WebsocketRPCEndpoint, RpcMethodsBase + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) +) +from fastapi_websocket_pubsub import PubSubEndpoint, PubSubClient +from fastapi_websocket_pubsub.rpc_event_methods import RpcEventServerMethods +from fastapi_websocket_pubsub.websocket_rpc_event_notifier import ( + WebSocketRpcEventNotifier, +) + +logger = get_logger("Test") + +PORT = int(os.environ.get("PORT") or "7990") +uri = f"ws://localhost:{PORT}/pubsub" +trigger_url = f"http://localhost:{PORT}/trigger" + +DATA = "MAGIC" +EVENT_TOPIC = "event/has-happened" + + +class CustomPubSubMethods(RpcEventServerMethods): + async def custom_method(self) -> str: + return "custom_response" + + +class InvalidMethods(RpcMethodsBase): + async def some_method(self) -> str: + return "hello" + + +def setup_server_rest_route(app, endpoint: PubSubEndpoint): + @app.get("/trigger") + async def trigger_events(): + asyncio.create_task(endpoint.publish([EVENT_TOPIC], data=DATA)) + return "triggered" + + +def setup_server_with_premade_endpoint(): + app = FastAPI() + premade_endpoint = WebsocketRPCEndpoint() + endpoint = PubSubEndpoint(endpoint=premade_endpoint) + endpoint.register_route(app, path="/pubsub") + setup_server_rest_route(app, endpoint) + uvicorn.run(app, port=PORT) + + +def setup_server_with_no_methods_endpoint(): + app = FastAPI() + premade_endpoint = WebsocketRPCEndpoint() + endpoint = PubSubEndpoint(endpoint=premade_endpoint) + endpoint.register_route(app, path="/pubsub") + setup_server_rest_route(app, endpoint) + uvicorn.run(app, port=PORT) + + +def setup_server_with_custom_methods(): + app = FastAPI() + endpoint = PubSubEndpoint(methods_class=CustomPubSubMethods) + endpoint.register_route(app, path="/pubsub") + setup_server_rest_route(app, endpoint) + uvicorn.run(app, port=PORT) + + +def setup_server_with_premade_methods_and_notifier(): + app = FastAPI() + custom_notifier = WebSocketRpcEventNotifier() + custom_methods = RpcEventServerMethods(custom_notifier) + premade_endpoint = WebsocketRPCEndpoint(custom_methods) + endpoint = PubSubEndpoint(endpoint=premade_endpoint) + endpoint.register_route(app, path="/pubsub") + setup_server_rest_route(app, endpoint) + uvicorn.run(app, port=PORT) + + +@pytest.fixture() +def server(): + proc = Process(target=setup_server_with_premade_endpoint, args=(), daemon=True) + proc.start() + yield proc + proc.kill() + + +@pytest.fixture() +def server_no_methods(): + proc = Process(target=setup_server_with_no_methods_endpoint, args=(), daemon=True) + proc.start() + yield proc + proc.kill() + + +@pytest.fixture() +def server_custom_methods(): + proc = Process(target=setup_server_with_custom_methods, args=(), daemon=True) + proc.start() + yield proc + proc.kill() + + +@pytest.fixture() +def server_premade_methods_notifier(): + proc = Process( + target=setup_server_with_premade_methods_and_notifier, args=(), daemon=True + ) + proc.start() + yield proc + proc.kill() + + +@pytest.mark.asyncio +async def test_premade_endpoint_subscribe_http_trigger(server): + # finish trigger + finish = asyncio.Event() + async with PubSubClient() as client: + + async def on_event(data, topic): + assert data == DATA + finish.set() + + # subscribe for the event + client.subscribe(EVENT_TOPIC, on_event) + # start listening + client.start_client(uri) + # wait for the client to be ready to receive events + await client.wait_until_ready() + # trigger the server via an HTTP route + requests.get(trigger_url) + # wait for finish trigger + await asyncio.wait_for(finish.wait(), 5) + + +@pytest.mark.asyncio +async def test_premade_endpoint_pub_sub(server): + # finish trigger + finish = asyncio.Event() + async with PubSubClient() as client: + + async def on_event(data, topic): + assert data == DATA + finish.set() + + # subscribe for the event + client.subscribe(EVENT_TOPIC, on_event) + # start listening + client.start_client(uri) + # wait for the client to be ready to receive events + await client.wait_until_ready() + # publish events (with sync=False to avoid deadlocks waiting on the publish to ourselves) + published = await client.publish( + [EVENT_TOPIC], data=DATA, sync=False, notifier_id=gen_uid() + ) + assert published.result + # wait for finish trigger + await asyncio.wait_for(finish.wait(), 5) + + +@pytest.mark.asyncio +async def test_premade_endpoint_no_methods(server_no_methods): + # finish trigger + finish = asyncio.Event() + async with PubSubClient() as client: + + async def on_event(data, topic): + assert data == DATA + finish.set() + + # subscribe for the event + client.subscribe(EVENT_TOPIC, on_event) + # start listening + client.start_client(uri) + # wait for the client to be ready to receive events + await client.wait_until_ready() + # trigger the server via an HTTP route + requests.get(trigger_url) + # wait for finish trigger + await asyncio.wait_for(finish.wait(), 5) + + +@pytest.mark.asyncio +async def test_premade_endpoint_custom_methods_preserved(server_custom_methods): + # finish trigger + finish = asyncio.Event() + async with PubSubClient() as client: + + async def on_event(data, topic): + assert data == DATA + finish.set() + + # subscribe for the event + client.subscribe(EVENT_TOPIC, on_event) + # start listening + client.start_client(uri) + # wait for the client to be ready to receive events + await client.wait_until_ready() + # trigger the server via an HTTP route + requests.get(trigger_url) + # wait for finish trigger + await asyncio.wait_for(finish.wait(), 5) + + +@pytest.mark.asyncio +async def test_premade_endpoint_with_methods_and_notifier( + server_premade_methods_notifier, +): + # finish trigger + finish = asyncio.Event() + async with PubSubClient() as client: + + async def on_event(data, topic): + assert data == DATA + finish.set() + + # subscribe for the event + client.subscribe(EVENT_TOPIC, on_event) + # start listening + client.start_client(uri) + # wait for the client to be ready to receive events + await client.wait_until_ready() + # trigger the server via an HTTP route + requests.get(trigger_url) + # wait for finish trigger + await asyncio.wait_for(finish.wait(), 5) + + +def test_premade_endpoint_invalid_methods_raises_error(): + invalid_methods = InvalidMethods() + premade_endpoint = WebsocketRPCEndpoint(invalid_methods) + + with pytest.raises(ValueError) as exc_info: + PubSubEndpoint(endpoint=premade_endpoint) + + assert "RpcEventServerMethods" in str(exc_info.value) + assert "InvalidMethods" in str(exc_info.value) + + +def test_premade_endpoint_with_methods_and_methods_class_raises_error(): + custom_notifier = WebSocketRpcEventNotifier() + custom_methods = RpcEventServerMethods(custom_notifier) + premade_endpoint = WebsocketRPCEndpoint(custom_methods) + + with pytest.raises(ValueError) as exc_info: + PubSubEndpoint(methods_class=RpcEventServerMethods, endpoint=premade_endpoint) + + assert "methods_class" in str(exc_info.value)