Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,6 @@ dmypy.json

# editors
.vscode/
.idea/
.idea/

venv/
128 changes: 100 additions & 28 deletions fastapi_websocket_pubsub/pub_sub_server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
from typing import Coroutine, List, Union
from typing import Coroutine, List, Union, Optional

from fastapi import WebSocket
from fastapi_websocket_rpc import WebsocketRPCEndpoint
from fastapi_websocket_rpc.rpc_channel import RpcChannel
from fastapi_websocket_rpc.rpc_methods import RpcMethodsBase

from .logger import get_logger
from .event_broadcaster import EventBroadcaster
Expand Down Expand Up @@ -34,7 +35,8 @@ def __init__(
on_connect: List[Coroutine] = None,
on_disconnect: List[Coroutine] = None,
rpc_channel_get_remote_id: bool = False,
ignore_broadcaster_disconnected = True,
ignore_broadcaster_disconnected=True,
endpoint: Optional[WebsocketRPCEndpoint] = None,
):
"""
The PubSub endpoint recives subscriptions from clients and publishes data back to them upon receiving relevant publications.
Expand All @@ -57,28 +59,92 @@ def __init__(
on_connect (List[Coroutine]): callbacks on connection being established (each callback is called with the channel)
on_disconnect (List[Coroutine]): callbacks on connection termination (each callback is called with the channel)
ignore_broadcaster_disconnected: Don't end main loop if broadcaster's reader task ends (due to underlying disconnection)
endpoint (Optional[WebsocketRPCEndpoint]): An optional pre-configured `WebsocketRPCEndpoint` instance to use instead of creating a new one.
`met
"""
self.notifier = (
notifier if notifier is not None else WebSocketRpcEventNotifier()
)
self.broadcaster = (
broadcaster
if isinstance(broadcaster, EventBroadcaster) or broadcaster is None
else EventBroadcaster(broadcaster, self.notifier)
)
self.methods = (
methods_class(self.notifier)
if methods_class is not None
else RpcEventServerMethods(self.notifier)
)
if on_disconnect is None:
on_disconnect = []
self.endpoint = WebsocketRPCEndpoint(
self.methods,
on_disconnect=[self.on_disconnect, *on_disconnect],
on_connect=on_connect,
rpc_channel_get_remote_id=rpc_channel_get_remote_id,
)
if on_connect is None:
on_connect = []

if endpoint is None:
self.notifier = (
notifier if notifier is not None else WebSocketRpcEventNotifier()
)
self.broadcaster = (
broadcaster
if isinstance(broadcaster, EventBroadcaster) or broadcaster is None
else EventBroadcaster(broadcaster, self.notifier)
)
self.methods = (
methods_class(self.notifier)
if methods_class is not None
else RpcEventServerMethods(self.notifier)
)
self.endpoint = WebsocketRPCEndpoint(
self.methods,
on_disconnect=[self.on_disconnect, *on_disconnect],
on_connect=on_connect,
rpc_channel_get_remote_id=rpc_channel_get_remote_id,
)
else:
# Extract notifier from endpoint's methods if they exist
if endpoint.methods is not None and isinstance(
endpoint.methods, RpcEventServerMethods
):
# Use the notifier from the existing methods
self.notifier = endpoint.methods.event_notifier
else:
# No valid methods, use provided notifier or create new one
self.notifier = (
notifier if notifier is not None else WebSocketRpcEventNotifier()
)

# Setup broadcaster with the correct notifier
self.broadcaster = (
broadcaster
if isinstance(broadcaster, EventBroadcaster) or broadcaster is None
else EventBroadcaster(broadcaster, self.notifier)
)

# If the endpoint has no methods, or methods is `RpcMethodsBase`, we set up the correct methods
if endpoint.methods is None or type(endpoint.methods) is RpcMethodsBase:
# No methods yet so create them with our notifier
self.methods = (
methods_class(self.notifier)
if methods_class is not None
else RpcEventServerMethods(self.notifier)
)
endpoint.methods = self.methods
elif isinstance(endpoint.methods, RpcEventServerMethods):
# The endpoint has valid methods so we keep them
self.methods = endpoint.methods

# If `methods_class` was provided, it's incompatible with premade endpoint
if methods_class is not None:
raise ValueError(
"Cannot specify `methods_class` when using a premade endpoint "
"that already has methods configured. Either pass `methods=None` "
"to the endpoint or don't specify `methods_class`."
)
else:
# Invalid methods type
raise ValueError(
"Premade endpoint must have methods that derive from `RpcEventServerMethods` or None. "
f"Got {type(endpoint.methods).__name__} instead."
)

if endpoint._on_disconnect is None:
endpoint._on_disconnect = []
if endpoint._on_connect is None:
endpoint._on_connect = []

# Ensure we register this pubsub endpoint's `on_disconnect` handler as the first to run
endpoint._on_disconnect.insert(0, self.on_disconnect)
endpoint._on_disconnect.extend(on_disconnect)
endpoint._on_connect.extend(on_connect)
self.endpoint = endpoint

self._rpc_channel_get_remote_id = rpc_channel_get_remote_id
# server id used to publish events for clients
self._id = self.notifier.gen_subscriber_id()
Expand All @@ -92,7 +158,8 @@ async def subscribe(
return await self.notifier.subscribe(self._subscriber_id, topics, callback)

async def unsubscribe(
self, topics: Union[TopicList, ALL_TOPICS]) -> List[Subscription]:
self, topics: Union[TopicList, ALL_TOPICS]
) -> List[Subscription]:
return await self.notifier.unsubscribe(self._subscriber_id, topics)

async def publish(self, topics: Union[TopicList, Topic], data=None):
Expand All @@ -107,7 +174,7 @@ async def publish(self, topics: Union[TopicList, Topic], data=None):
# sharing here means - the broadcaster listens in to the notifier as well
logger.debug(f"Publishing message to topics: {topics}")
if self.broadcaster is not None:
logger.debug(f"Acquiring broadcaster sharing context")
logger.debug("Acquiring broadcaster sharing context")
async with self.broadcaster.get_context(listen=False, share=True):
await self.notifier.notify(topics, data, notifier_id=self._id)
# otherwise just notify
Expand Down Expand Up @@ -136,14 +203,19 @@ async def main_loop(self, websocket: WebSocket, client_id: str = None, **kwargs)
async with self.broadcaster:
logger.debug("Entering endpoint's main loop with broadcaster")
if self._ignore_broadcaster_disconnected:
await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
await self.endpoint.main_loop(
websocket, client_id=client_id, **kwargs
)
else:
main_loop_task = asyncio.create_task(
self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
self.endpoint.main_loop(
websocket, client_id=client_id, **kwargs
)
)
done, pending = await asyncio.wait(
[main_loop_task, self.broadcaster.get_reader_task()],
return_when=asyncio.FIRST_COMPLETED,
)
done, pending = await asyncio.wait([main_loop_task,
self.broadcaster.get_reader_task()],
return_when=asyncio.FIRST_COMPLETED)
logger.debug(f"task is done: {done}")
# broadcaster's reader task is used by other endpoints and shouldn't be cancelled
if main_loop_task in pending:
Expand Down
Loading