diff --git a/ably/realtime/annotations.py b/ably/realtime/annotations.py new file mode 100644 index 00000000..13f9a17d --- /dev/null +++ b/ably/realtime/annotations.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ably.rest.annotations import RestAnnotations, construct_validate_annotation +from ably.transport.websockettransport import ProtocolMessageAction +from ably.types.annotation import AnnotationAction +from ably.types.channelstate import ChannelState +from ably.types.flags import Flag +from ably.util.eventemitter import EventEmitter +from ably.util.exceptions import AblyException +from ably.util.helper import is_callable_or_coroutine + +if TYPE_CHECKING: + from ably.realtime.channel import RealtimeChannel + from ably.realtime.connectionmanager import ConnectionManager + +log = logging.getLogger(__name__) + + +class RealtimeAnnotations: + """ + Provides realtime methods for managing annotations on messages, + including publishing annotations and subscribing to annotation events. + """ + + __connection_manager: ConnectionManager + __channel: RealtimeChannel + + def __init__(self, channel: RealtimeChannel, connection_manager: ConnectionManager): + """ + Initialize RealtimeAnnotations. + + Args: + channel: The Realtime Channel this annotations instance belongs to + """ + self.__channel = channel + self.__connection_manager = connection_manager + self.__subscriptions = EventEmitter() + self.__rest_annotations = RestAnnotations(channel) + + async def publish(self, msg_or_serial, annotation: dict, params: dict | None = None): + """ + Publish an annotation on a message via the realtime connection. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Dict containing annotation properties (type, name, data, etc.) + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails, inputs are invalid, or channel is in unpublishable state + """ + annotation = construct_validate_annotation(msg_or_serial, annotation) + + # Check if channel and connection are in publishable state + self.__channel._throw_if_unpublishable_state() + + log.info( + f'RealtimeAnnotations.publish(), channelName = {self.__channel.name}, ' + f'sending annotation with messageSerial = {annotation.message_serial}, ' + f'type = {annotation.type}' + ) + + # Convert to wire format (array of annotations) + wire_annotation = annotation.as_dict(binary=self.__channel.ably.options.use_binary_protocol) + + # Build protocol message + protocol_message = { + "action": ProtocolMessageAction.ANNOTATION, + "channel": self.__channel.name, + "annotations": [wire_annotation], + } + + if params: + # Stringify boolean params + stringified_params = {k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()} + protocol_message["params"] = stringified_params + + # Send via WebSocket + await self.__connection_manager.send_protocol_message(protocol_message) + + async def delete( + self, + msg_or_serial, + annotation: dict, + params: dict | None = None, + ): + """ + Delete an annotation on a message. + + This is a convenience method that sets the action to 'annotation.delete' + and calls publish(). + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Dict containing annotation properties + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + annotation_values = annotation.copy() + annotation_values['action'] = AnnotationAction.ANNOTATION_DELETE + return await self.publish(msg_or_serial, annotation_values, params) + + async def subscribe(self, *args): + """ + Subscribe to annotation events on this channel. + + Parameters + ---------- + *args: type, listener + Subscribe type and listener + + arg1(type): str, optional + Subscribe to annotations of the given type + + arg2(listener): callable + Subscribe to all annotations on the channel + + When no type is provided, arg1 is used as the listener. + + Raises + ------ + AblyException + If unable to subscribe due to invalid channel state or missing ANNOTATION_SUBSCRIBE mode + ValueError + If no valid subscribe arguments are passed + """ + # Parse arguments similar to channel.subscribe + if len(args) == 0: + raise ValueError("annotations.subscribe called without arguments") + + if len(args) >= 2 and isinstance(args[0], str): + annotation_type = args[0] + if not args[1]: + raise ValueError("annotations.subscribe called without listener") + if not is_callable_or_coroutine(args[1]): + raise ValueError("subscribe listener must be function or coroutine function") + listener = args[1] + elif is_callable_or_coroutine(args[0]): + listener = args[0] + annotation_type = None + else: + raise ValueError('invalid subscribe arguments') + + # Register subscription + if annotation_type is not None: + self.__subscriptions.on(annotation_type, listener) + else: + self.__subscriptions.on(listener) + + await self.__channel.attach() + + # Check if ANNOTATION_SUBSCRIBE mode is enabled + if self.__channel.state == ChannelState.ATTACHED: + if Flag.ANNOTATION_SUBSCRIBE not in self.__channel.modes: + raise AblyException( + message="You are trying to add an annotation listener, but you haven't requested the " + "annotation_subscribe channel mode in ChannelOptions, so this won't do anything " + "(we only deliver annotations to clients who have explicitly requested them)", + code=93001, + status_code=400, + ) + + def unsubscribe(self, *args): + """ + Unsubscribe from annotation events on this channel. + + Parameters + ---------- + *args: type, listener + Unsubscribe type and listener + + arg1(type): str, optional + Unsubscribe from annotations of the given type + + arg2(listener): callable + Unsubscribe from all annotations on the channel + + When no type is provided, arg1 is used as the listener. + + Raises + ------ + ValueError + If no valid unsubscribe arguments are passed + """ + if len(args) == 0: + raise ValueError("annotations.unsubscribe called without arguments") + + if len(args) >= 2 and isinstance(args[0], str): + annotation_type = args[0] + listener = args[1] + self.__subscriptions.off(annotation_type, listener) + elif is_callable_or_coroutine(args[0]): + listener = args[0] + self.__subscriptions.off(listener) + else: + raise ValueError('invalid unsubscribe arguments') + + def _process_incoming(self, incoming_annotations): + """ + Process incoming annotations from the server. + + This is called internally when ANNOTATION protocol messages are received. + + Args: + incoming_annotations: List of Annotation objects received from the server + """ + for annotation in incoming_annotations: + # Emit to type-specific listeners and catch-all listeners + annotation_type = annotation.type or '' + self.__subscriptions._emit(annotation_type, annotation) + + async def get(self, msg_or_serial, params: dict | None = None): + """ + Retrieve annotations for a message with pagination support. + + This delegates to the REST implementation. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + params: Optional dict of query parameters (limit, start, end, direction) + + Returns: + PaginatedResult: A paginated result containing Annotation objects + + Raises: + AblyException: If the request fails or serial is invalid + """ + # Delegate to REST implementation + return await self.__rest_annotations.get(msg_or_serial, params) diff --git a/ably/realtime/channel.py b/ably/realtime/channel.py index e0fd6251..801f4c6a 100644 --- a/ably/realtime/channel.py +++ b/ably/realtime/channel.py @@ -4,10 +4,14 @@ import logging from typing import TYPE_CHECKING +from ably.realtime.annotations import RealtimeAnnotations from ably.realtime.connection import ConnectionState +from ably.realtime.presence import RealtimePresence from ably.rest.channel import Channel from ably.rest.channel import Channels as RestChannels from ably.transport.websockettransport import ProtocolMessageAction +from ably.types.annotation import Annotation +from ably.types.channelmode import ChannelMode, decode_channel_mode, encode_channel_mode from ably.types.channeloptions import ChannelOptions from ably.types.channelstate import ChannelState, ChannelStateChange from ably.types.flags import Flag, has_flag @@ -64,6 +68,7 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp self.__error_reason: AblyException | None = None self.__channel_options = channel_options or ChannelOptions() self.__params: dict[str, str] | None = None + self.__modes: list[ChannelMode] = [] # Channel mode flags from ATTACHED message # Delta-specific fields for RTL19/RTL20 compliance vcdiff_decoder = self.__realtime.options.vcdiff_decoder if self.__realtime.options.vcdiff_decoder else None @@ -74,12 +79,15 @@ def __init__(self, realtime: AblyRealtime, name: str, channel_options: ChannelOp # will be disrupted if the user called .off() to remove all listeners self.__internal_state_emitter = EventEmitter() + # Pass channel options as dictionary to parent Channel class + Channel.__init__(self, realtime, name, self.__channel_options.to_dict()) + # Initialize presence for this channel - from ably.realtime.presence import RealtimePresence + self.__presence = RealtimePresence(self) - # Pass channel options as dictionary to parent Channel class - Channel.__init__(self, realtime, name, self.__channel_options.to_dict()) + # Initialize realtime annotations for this channel (override REST annotations) + self._Channel__annotations = RealtimeAnnotations(self, realtime.connection.connection_manager) async def set_options(self, channel_options: ChannelOptions) -> None: """Set channel options""" @@ -149,8 +157,10 @@ def _attach_impl(self): "channel": self.name, } - if self.__attach_resume: - attach_msg["flags"] = Flag.ATTACH_RESUME + flags = self._encode_flags() + + if flags: + attach_msg["flags"] = flags if self.__channel_serial: attach_msg["channelSerial"] = self.__channel_serial @@ -491,8 +501,8 @@ async def _send_update( if not message.serial: raise AblyException( "Message serial is required for update/delete/append operations", - 400, - 40003 + status_code=400, + code=40003, ) # Check connection and channel state @@ -702,6 +712,8 @@ def _on_message(self, proto_msg: dict) -> None: resumed = has_flag(flags, Flag.RESUMED) # RTP1: Check for HAS_PRESENCE flag has_presence = has_flag(flags, Flag.HAS_PRESENCE) + # Store channel attach flags + self.__modes = decode_channel_mode(flags) # RTL12 if self.state == ChannelState.ATTACHED: @@ -744,6 +756,15 @@ def _on_message(self, proto_msg: dict) -> None: decoded_presence = PresenceMessage.from_encoded_array(presence_messages, cipher=self.cipher) sync_channel_serial = proto_msg.get('channelSerial') self.__presence.set_presence(decoded_presence, is_sync=True, sync_channel_serial=sync_channel_serial) + elif action == ProtocolMessageAction.ANNOTATION: + # Handle ANNOTATION messages + annotation_data = proto_msg.get('annotations', []) + try: + annotations = Annotation.from_encoded_array(annotation_data, cipher=self.cipher) + # Process annotations through the annotations handler + self.annotations._process_incoming(annotations) + except Exception as e: + log.error(f"Annotation processing error {e}. Skip annotations {annotation_data}") elif action == ProtocolMessageAction.ERROR: error = AblyException.from_dict(proto_msg.get('error')) self._notify_state(ChannelState.FAILED, reason=error) @@ -890,6 +911,15 @@ def presence(self): """Get the RealtimePresence object for this channel""" return self.__presence + @property + def annotations(self) -> RealtimeAnnotations: + return self._Channel__annotations + + @property + def modes(self): + """Get the list of channel modes""" + return self.__modes + def _start_decode_failure_recovery(self, error: AblyException) -> None: """Start RTL18 decode failure recovery procedure""" @@ -908,6 +938,20 @@ def _start_decode_failure_recovery(self, error: AblyException) -> None: self._notify_state(ChannelState.ATTACHING, reason=error) self._check_pending_state() + def _encode_flags(self) -> int | None: + if not self.__channel_options.modes and not self.__attach_resume: + return None + + flags = 0 + + if self.__attach_resume: + flags |= Flag.ATTACH_RESUME + + if self.__channel_options.modes: + flags |= encode_channel_mode(self.__channel_options.modes) + + return flags + class Channels(RestChannels): """Creates and destroys RealtimeChannel objects. diff --git a/ably/rest/annotations.py b/ably/rest/annotations.py new file mode 100644 index 00000000..7f97cf7c --- /dev/null +++ b/ably/rest/annotations.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import json +import logging +from urllib import parse + +import msgpack + +from ably.http.paginatedresult import PaginatedResult, format_params +from ably.types.annotation import ( + Annotation, + AnnotationAction, + make_annotation_response_handler, +) +from ably.types.message import Message +from ably.util.exceptions import AblyException + +log = logging.getLogger(__name__) + + +def serial_from_msg_or_serial(msg_or_serial): + """ + Extract the message serial from either a string serial or a Message object. + + Args: + msg_or_serial: Either a string serial or a Message object with a serial property + + Returns: + str: The message serial + + Raises: + AblyException: If the input is invalid or serial is missing + """ + if isinstance(msg_or_serial, str): + message_serial = msg_or_serial + elif isinstance(msg_or_serial, Message): + message_serial = msg_or_serial.serial + else: + message_serial = None + + if not message_serial or not isinstance(message_serial, str): + raise AblyException( + message='First argument of annotations.publish() must be either a Message ' + '(or at least an object with a string `serial` property) or a message serial (string)', + status_code=400, + code=40003, + ) + + return message_serial + + +def construct_validate_annotation(msg_or_serial, annotation: dict): + """ + Construct and validate an Annotation from input values. + + Args: + msg_or_serial: Either a string serial or a Message object + annotation: Dict of annotation properties or Annotation object + + Returns: + Annotation: The constructed annotation + + Raises: + AblyException: If the inputs are invalid + """ + message_serial = serial_from_msg_or_serial(msg_or_serial) + + if not annotation or (not isinstance(annotation, dict) and not isinstance(annotation, Annotation)): + raise AblyException( + message='Second argument of annotations.publish() must be a dict or Annotation ' + '(the intended annotation to publish)', + status_code=400, + code=40003, + ) + + annotation_values = annotation.copy() + annotation_values['message_serial'] = message_serial + + return Annotation.from_values(annotation_values) + + +class RestAnnotations: + """ + Provides REST API methods for managing annotations on messages. + """ + + def __init__(self, channel): + """ + Initialize RestAnnotations. + + Args: + channel: The REST Channel this annotations instance belongs to + """ + self.__channel = channel + + def __base_path_for_serial(self, serial): + """ + Build the base API path for a message serial's annotations. + + Args: + serial: The message serial + + Returns: + str: The API path + """ + channel_path = '/channels/{}/'.format(parse.quote_plus(self.__channel.name, safe=':')) + return channel_path + 'messages/' + parse.quote_plus(serial, safe=':') + '/annotations' + + async def publish( + self, + msg_or_serial, + annotation: dict | Annotation, + params: dict | None = None, + ): + """ + Publish an annotation on a message. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Dict containing annotation properties (type, name, data, etc.) or Annotation object + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + annotation = construct_validate_annotation(msg_or_serial, annotation) + + # Convert to wire format + request_body = annotation.as_dict(binary=self.__channel.ably.options.use_binary_protocol) + + # Wrap in array as API expects array of annotations + request_body = [request_body] + + # Encode based on protocol + if not self.__channel.ably.options.use_binary_protocol: + request_body = json.dumps(request_body, separators=(',', ':')) + else: + request_body = msgpack.packb(request_body, use_bin_type=True) + + # Build path + path = self.__base_path_for_serial(annotation.message_serial) + if params: + params = {k: str(v).lower() if type(v) is bool else v for k, v in params.items()} + path += '?' + parse.urlencode(params) + + # Send request + await self.__channel.ably.http.post(path, body=request_body) + + async def delete( + self, + msg_or_serial, + annotation: dict | Annotation, + params: dict | None = None, + ): + """ + Delete an annotation on a message. + + This is a convenience method that sets the action to 'annotation.delete' + and calls publish(). + + Args: + msg_or_serial: Either a message serial (string) or a Message object + annotation: Dict containing annotation properties or Annotation object + params: Optional dict of query parameters + + Returns: + None + + Raises: + AblyException: If the request fails or inputs are invalid + """ + # Set action to delete + if isinstance(annotation, Annotation): + annotation_values = annotation.as_dict() + else: + annotation_values = annotation.copy() + annotation_values['action'] = AnnotationAction.ANNOTATION_DELETE + return await self.publish(msg_or_serial, annotation_values, params) + + async def get(self, msg_or_serial, params: dict | None = None): + """ + Retrieve annotations for a message with pagination support. + + Args: + msg_or_serial: Either a message serial (string) or a Message object + params: Optional dict of query parameters (limit, start, end, direction) + + Returns: + PaginatedResult: A paginated result containing Annotation objects + + Raises: + AblyException: If the request fails or serial is invalid + """ + message_serial = serial_from_msg_or_serial(msg_or_serial) + + # Build path + params_str = format_params({}, **params) if params else '' + path = self.__base_path_for_serial(message_serial) + params_str + + # Create annotation response handler + annotation_handler = make_annotation_response_handler(cipher=None) + + # Return paginated result + return await PaginatedResult.paginated_query( + self.__channel.ably.http, + url=path, + response_processor=annotation_handler + ) diff --git a/ably/rest/auth.py b/ably/rest/auth.py index 2aaa4b12..2dc5d497 100644 --- a/ably/rest/auth.py +++ b/ably/rest/auth.py @@ -90,7 +90,7 @@ def __init__(self, ably: AblyRest | AblyRealtime, options: Options): async def get_auth_transport_param(self): auth_credentials = {} if self.auth_options.client_id: - auth_credentials["client_id"] = self.auth_options.client_id + auth_credentials["clientId"] = self.auth_options.client_id if self.__auth_mechanism == Auth.Method.BASIC: key_name = self.__auth_options.key_name key_secret = self.__auth_options.key_secret diff --git a/ably/rest/channel.py b/ably/rest/channel.py index 2c1c0246..e16f209d 100644 --- a/ably/rest/channel.py +++ b/ably/rest/channel.py @@ -9,6 +9,7 @@ import msgpack from ably.http.paginatedresult import PaginatedResult, format_params +from ably.rest.annotations import RestAnnotations from ably.types.channeldetails import ChannelDetails from ably.types.message import ( Message, @@ -30,6 +31,8 @@ class Channel: + __annotations: RestAnnotations + def __init__(self, ably, name, options): self.__ably = ably self.__name = name @@ -37,6 +40,7 @@ def __init__(self, ably, name, options): self.__cipher = None self.options = options self.__presence = Presence(self) + self.__annotations = RestAnnotations(self) @catch_all async def history(self, direction=None, limit: int = None, start=None, end=None): @@ -169,8 +173,8 @@ async def _send_update( if not message.serial: raise AblyException( "Message serial is required for update/delete/append operations", - 400, - 40003 + status_code=400, + code=40003, ) if not operation: @@ -282,8 +286,8 @@ async def get_message(self, serial_or_message, timeout=None): raise AblyException( 'This message lacks a serial. Make sure you have enabled "Message annotations, ' 'updates, and deletes" in channel settings on your dashboard.', - 400, - 40003 + status_code=400, + code=40003, ) # Build the path @@ -321,8 +325,8 @@ async def get_message_versions(self, serial_or_message, params=None): raise AblyException( 'This message lacks a serial. Make sure you have enabled "Message annotations, ' 'updates, and deletes" in channel settings on your dashboard.', - 400, - 40003 + status_code=400, + code=40003, ) # Build the path @@ -363,6 +367,10 @@ def options(self): def presence(self): return self.__presence + @property + def annotations(self) -> RestAnnotations: + return self.__annotations + @options.setter def options(self, options): self.__options = options diff --git a/ably/transport/websockettransport.py b/ably/transport/websockettransport.py index 4f6f9fe0..be13d096 100644 --- a/ably/transport/websockettransport.py +++ b/ably/transport/websockettransport.py @@ -189,6 +189,7 @@ async def on_protocol_message(self, msg): ProtocolMessageAction.DETACHED, ProtocolMessageAction.MESSAGE, ProtocolMessageAction.PRESENCE, + ProtocolMessageAction.ANNOTATION, ProtocolMessageAction.SYNC ): self.connection_manager.on_channel_message(msg) diff --git a/ably/types/annotation.py b/ably/types/annotation.py new file mode 100644 index 00000000..e099d00d --- /dev/null +++ b/ably/types/annotation.py @@ -0,0 +1,221 @@ +import logging +from enum import IntEnum + +from ably.types.mixins import EncodeDataMixin +from ably.util.encoding import encode_data +from ably.util.helper import to_text + +log = logging.getLogger(__name__) + + +class AnnotationAction(IntEnum): + """Annotation action types""" + ANNOTATION_CREATE = 0 + ANNOTATION_DELETE = 1 + + +class Annotation(EncodeDataMixin): + """ + Represents an annotation on a message, such as a reaction or other metadata. + + Annotations are not encrypted as they need to be parsed by the server for summarization. + """ + + def __init__(self, + action=None, + serial=None, + message_serial=None, + type=None, + name=None, + count=None, + data=None, + encoding='', + client_id=None, + timestamp=None, + extras=None): + """ + Args: + action: The action type - either 'annotation.create' or 'annotation.delete' + serial: A unique identifier for the annotation + message_serial: The serial of the message this annotation is for + type: The type of annotation (e.g., 'reaction', 'like', etc.) + name: The name/value of the annotation (e.g., specific emoji) + count: Count associated with the annotation + data: Optional data payload for the annotation + encoding: Encoding format for the data + client_id: The client ID that created this annotation + timestamp: Timestamp of the annotation + extras: Additional metadata + """ + super().__init__(encoding) + + self.__serial = to_text(serial) if serial is not None else None + self.__message_serial = to_text(message_serial) if message_serial is not None else None + self.__type = to_text(type) if type is not None else None + self.__name = to_text(name) if name is not None else None + self.__action = action if action is not None else AnnotationAction.ANNOTATION_CREATE + self.__count = count + self.__data = data + self.__client_id = to_text(client_id) if client_id is not None else None + self.__timestamp = timestamp + self.__extras = extras + + def __eq__(self, other): + if isinstance(other, Annotation): + return (self.serial == other.serial + and self.message_serial == other.message_serial + and self.type == other.type + and self.name == other.name + and self.action == other.action) + return NotImplemented + + def __ne__(self, other): + if isinstance(other, Annotation): + result = self.__eq__(other) + if result != NotImplemented: + return not result + return NotImplemented + + @property + def action(self): + return self.__action + + @property + def serial(self): + return self.__serial + + @property + def message_serial(self): + return self.__message_serial + + @property + def type(self): + return self.__type + + @property + def name(self): + return self.__name + + @property + def count(self): + return self.__count + + @property + def data(self): + return self.__data + + @property + def client_id(self): + return self.__client_id + + @property + def timestamp(self): + return self.__timestamp + + @property + def extras(self): + return self.__extras + + def as_dict(self, binary=False): + """ + Convert annotation to dictionary format for API communication. + + Note: Annotations are not encrypted as they need to be parsed by the server. + """ + request_body = { + 'action': int(self.action) if self.action is not None else None, + 'serial': self.serial, + 'messageSerial': self.message_serial, + 'type': self.type, # Annotation type (not data type) + 'name': self.name, + 'count': self.count, + 'clientId': self.client_id or None, + 'timestamp': self.timestamp or None, + 'extras': self.extras, + **encode_data(self.data, self._encoding_array, binary) + } + + # None values aren't included + request_body = {k: v for k, v in request_body.items() if v is not None} + + return request_body + + @staticmethod + def from_encoded(obj, cipher=None, context=None): + """ + Create an Annotation from an encoded object received from the API. + + Note: cipher parameter is accepted for consistency but annotations are not encrypted. + """ + action = obj.get('action') + serial = obj.get('serial') + message_serial = obj.get('messageSerial') + type_val = obj.get('type') + name = obj.get('name') + count = obj.get('count') + data = obj.get('data') + encoding = obj.get('encoding', '') + client_id = obj.get('clientId') + timestamp = obj.get('timestamp') + extras = obj.get('extras', None) + + # Decode data if present + decoded_data = Annotation.decode(data, encoding, cipher, context) if data is not None else {} + + # Convert action from int to enum + if action is not None: + try: + action = AnnotationAction(action) + except ValueError: + # If it's not a valid action value, store as None + action = None + else: + action = None + + return Annotation( + action=action, + serial=serial, + message_serial=message_serial, + type=type_val, + name=name, + count=count, + client_id=client_id, + timestamp=timestamp, + extras=extras, + **decoded_data + ) + + @staticmethod + def from_encoded_array(obj_array, cipher=None, context=None): + """Create an array of Annotations from encoded objects""" + return [Annotation.from_encoded(obj, cipher, context) for obj in obj_array] + + @staticmethod + def from_values(values): + """Create an Annotation from a dict of values""" + return Annotation(**values) + + def __str__(self): + return ( + f"Annotation(action={self.action}, messageSerial={self.message_serial}, " + f"type={self.type}, name={self.name})" + ) + + def __repr__(self): + return self.__str__() + + +def make_annotation_response_handler(cipher=None): + """Create a response handler for annotation API responses""" + def annotation_response_handler(response): + annotations = response.to_native() + return Annotation.from_encoded_array(annotations, cipher=cipher) + return annotation_response_handler + + +def make_single_annotation_response_handler(cipher=None): + """Create a response handler for single annotation API responses""" + def single_annotation_response_handler(response): + annotation = response.to_native() + return Annotation.from_encoded(annotation, cipher=cipher) + return single_annotation_response_handler diff --git a/ably/types/channelmode.py b/ably/types/channelmode.py new file mode 100644 index 00000000..23ed735c --- /dev/null +++ b/ably/types/channelmode.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from enum import Enum + +from ably.types.flags import Flag + + +class ChannelMode(int, Enum): + PRESENCE = Flag.PRESENCE + PUBLISH = Flag.PUBLISH + SUBSCRIBE = Flag.SUBSCRIBE + PRESENCE_SUBSCRIBE = Flag.PRESENCE_SUBSCRIBE + ANNOTATION_PUBLISH = Flag.ANNOTATION_PUBLISH + ANNOTATION_SUBSCRIBE = Flag.ANNOTATION_SUBSCRIBE + + +def encode_channel_mode(modes: list[ChannelMode]) -> int: + """ + Encode a list of ChannelMode values into a bitmask. + + Args: + modes: List of ChannelMode values to encode + + Returns: + Integer bitmask with the corresponding flags set + """ + flags = 0 + + for mode in modes: + flags |= mode.value + + return flags + + +def decode_channel_mode(flags: int) -> list[ChannelMode]: + """ + Decode channel mode flags from a bitmask into a list of ChannelMode values. + + Args: + flags: Integer bitmask containing channel mode flags + + Returns: + List of ChannelMode values that are set in the flags + """ + modes = [] + + # Check each channel mode flag + for mode in ChannelMode: + if flags & mode.value: + modes.append(mode) + + return modes diff --git a/ably/types/channeloptions.py b/ably/types/channeloptions.py index 48e34dfe..02f2bd5d 100644 --- a/ably/types/channeloptions.py +++ b/ably/types/channeloptions.py @@ -2,6 +2,7 @@ from typing import Any +from ably.types.channelmode import ChannelMode from ably.util.crypto import CipherParams from ably.util.exceptions import AblyException @@ -17,36 +18,48 @@ class ChannelOptions: Channel parameters that configure the behavior of the channel. """ - def __init__(self, cipher: CipherParams | None = None, params: dict | None = None): + def __init__( + self, + cipher: CipherParams | None = None, + params: dict | None = None, + modes: list[ChannelMode] | None = None + ): self.__cipher = cipher self.__params = params + self.__modes = modes # Validate params if self.__params and not isinstance(self.__params, dict): raise AblyException("params must be a dictionary", 40000, 400) @property - def cipher(self): + def cipher(self) -> CipherParams | None: """Get cipher configuration""" return self.__cipher @property - def params(self) -> dict[str, str]: + def params(self) -> dict[str, str] | None: """Get channel parameters""" return self.__params + @property + def modes(self) -> list[ChannelMode] | None: + """Get channel parameters""" + return self.__modes + def __eq__(self, other): """Check equality with another ChannelOptions instance""" if not isinstance(other, ChannelOptions): return False return (self.__cipher == other.__cipher and - self.__params == other.__params) + self.__params == other.__params and self.__modes == other.__modes) def __hash__(self): """Make ChannelOptions hashable""" return hash(( self.__cipher, tuple(sorted(self.__params.items())) if self.__params else None, + tuple(sorted(self.__modes)) if self.__modes else None )) def to_dict(self) -> dict[str, Any]: @@ -56,6 +69,8 @@ def to_dict(self) -> dict[str, Any]: result['cipher'] = self.__cipher if self.__params: result['params'] = self.__params + if self.__modes: + result['modes'] = self.__modes return result @classmethod @@ -67,4 +82,5 @@ def from_dict(cls, options_dict: dict[str, Any]) -> ChannelOptions: return cls( cipher=options_dict.get('cipher'), params=options_dict.get('params'), + modes=options_dict.get('modes'), ) diff --git a/ably/types/flags.py b/ably/types/flags.py index 1666434c..86666019 100644 --- a/ably/types/flags.py +++ b/ably/types/flags.py @@ -13,6 +13,8 @@ class Flag(int, Enum): PUBLISH = 1 << 17 SUBSCRIBE = 1 << 18 PRESENCE_SUBSCRIBE = 1 << 19 + ANNOTATION_PUBLISH = 1 << 21 + ANNOTATION_SUBSCRIBE = 1 << 22 def has_flag(message_flags: int, flag: Flag): diff --git a/ably/types/message.py b/ably/types/message.py index 11caba57..81043608 100644 --- a/ably/types/message.py +++ b/ably/types/message.py @@ -1,27 +1,16 @@ -import base64 -import json import logging from enum import IntEnum from ably.types.mixins import DeltaExtras, EncodeDataMixin from ably.types.typedbuffer import TypedBuffer from ably.util.crypto import CipherData +from ably.util.encoding import encode_data from ably.util.exceptions import AblyException +from ably.util.helper import to_text log = logging.getLogger(__name__) -def to_text(value): - if value is None: - return value - elif isinstance(value, str): - return value - elif isinstance(value, bytes): - return value.decode() - else: - raise TypeError(f"expected string or bytes, not {type(value)}") - - class MessageVersion: """ Contains the details regarding the current version of the message - including when it was updated and by whom. @@ -234,38 +223,9 @@ def decrypt(self, channel_cipher): self.__data = decrypted_data def as_dict(self, binary=False): - data = self.data - data_type = None - encoding = self._encoding_array[:] - - if isinstance(data, (dict, list)): - encoding.append('json') - data = json.dumps(data) - data = str(data) - elif isinstance(data, str) and not binary: - pass - elif not binary and isinstance(data, (bytearray, bytes)): - data = base64.b64encode(data).decode('ascii') - encoding.append('base64') - elif isinstance(data, CipherData): - encoding.append(data.encoding_str) - data_type = data.type - if not binary: - data = base64.b64encode(data.buffer).decode('ascii') - encoding.append('base64') - else: - data = data.buffer - elif binary and isinstance(data, bytearray): - data = bytes(data) - - if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): - raise AblyException("Invalid data payload", 400, 40011) - request_body = { 'name': self.name, - 'data': data, 'timestamp': self.timestamp or None, - 'type': data_type or None, 'clientId': self.client_id or None, 'id': self.id or None, 'connectionId': self.connection_id or None, @@ -274,11 +234,9 @@ def as_dict(self, binary=False): 'version': self.version.as_dict() if self.version else None, 'serial': self.serial, 'action': int(self.action) if self.action is not None else None, + **encode_data(self.data, self._encoding_array, binary), } - if encoding: - request_body['encoding'] = '/'.join(encoding).strip('/') - # None values aren't included request_body = {k: v for k, v in request_body.items() if v is not None} diff --git a/ably/types/presence.py b/ably/types/presence.py index 723ceacc..7d1a3c05 100644 --- a/ably/types/presence.py +++ b/ably/types/presence.py @@ -1,5 +1,3 @@ -import base64 -import json from datetime import datetime, timedelta from urllib import parse @@ -7,7 +5,7 @@ from ably.types.mixins import EncodeDataMixin from ably.types.typedbuffer import TypedBuffer from ably.util.crypto import CipherData -from ably.util.exceptions import AblyException +from ably.util.encoding import encode_data def _ms_since_epoch(dt): @@ -151,36 +149,10 @@ def to_encoded(self, binary=False): Handles proper encoding of data including JSON serialization, base64 encoding for binary data, and encryption support. """ - data = self.data - data_type = None - encoding = self._encoding_array[:] - - # Handle different data types and build encoding string - if isinstance(data, (dict, list)): - encoding.append('json') - data = json.dumps(data) - data = str(data) - elif isinstance(data, str) and not binary: - pass - elif not binary and isinstance(data, (bytearray, bytes)): - data = base64.b64encode(data).decode('ascii') - encoding.append('base64') - elif isinstance(data, CipherData): - encoding.append(data.encoding_str) - data_type = data.type - if not binary: - data = base64.b64encode(data.buffer).decode('ascii') - encoding.append('base64') - else: - data = data.buffer - elif binary and isinstance(data, bytearray): - data = bytes(data) - - if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None): - raise AblyException("Invalid data payload", 400, 40011) result = { 'action': self.action, + **encode_data(self.data, self._encoding_array, binary), } if self.id: @@ -189,12 +161,6 @@ def to_encoded(self, binary=False): result['clientId'] = self.client_id if self.connection_id: result['connectionId'] = self.connection_id - if data is not None: - result['data'] = data - if data_type: - result['type'] = data_type - if encoding: - result['encoding'] = '/'.join(encoding).strip('/') if self.extras: result['extras'] = self.extras if self.timestamp: diff --git a/ably/util/encoding.py b/ably/util/encoding.py new file mode 100644 index 00000000..3b3858b4 --- /dev/null +++ b/ably/util/encoding.py @@ -0,0 +1,35 @@ +import base64 +import json +from typing import Any + +from ably.util.crypto import CipherData + + +def encode_data(data: Any, encoding_array: list, binary: bool = False): + encoding = encoding_array[:] + + if isinstance(data, (dict, list)): + encoding.append('json') + data = json.dumps(data) + data = str(data) + elif isinstance(data, str) and not binary: + pass + elif not binary and isinstance(data, (bytearray, bytes)): + data = base64.b64encode(data).decode('ascii') + encoding.append('base64') + elif isinstance(data, CipherData): + encoding.append(data.encoding_str) + if not binary: + data = base64.b64encode(data.buffer).decode('ascii') + encoding.append('base64') + else: + data = data.buffer + elif binary and isinstance(data, bytearray): + data = bytes(data) + + result = { 'data': data } + + if encoding: + result['encoding'] = '/'.join(encoding).strip('/') + + return result diff --git a/ably/util/helper.py b/ably/util/helper.py index 53226f27..a35ebe6e 100644 --- a/ably/util/helper.py +++ b/ably/util/helper.py @@ -98,3 +98,13 @@ def validate_message_size(encoded_messages: list, use_binary_protocol: bool, max 400, 40009, ) + +def to_text(value): + if value is None: + return value + elif isinstance(value, str): + return value + elif isinstance(value, bytes): + return value.decode() + else: + raise TypeError(f"expected string or bytes, not {type(value)}") diff --git a/test/ably/realtime/realtimeannotations_test.py b/test/ably/realtime/realtimeannotations_test.py new file mode 100644 index 00000000..6852adaa --- /dev/null +++ b/test/ably/realtime/realtimeannotations_test.py @@ -0,0 +1,334 @@ +import asyncio +import logging +import random +import string + +import pytest + +from ably import AblyException +from ably.types.annotation import AnnotationAction +from ably.types.channelmode import ChannelMode +from ably.types.channeloptions import ChannelOptions +from ably.types.message import MessageAction +from test.ably.testapp import TestApp +from test.ably.utils import BaseAsyncTestCase, ReusableFuture, assert_waiter + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize("transport", ["json", "msgpack"], ids=["JSON", "MsgPack"]) +class TestRealtimeAnnotations(BaseAsyncTestCase): + + @pytest.fixture(autouse=True) + async def setup(self, transport): + self.test_vars = await TestApp.get_test_vars() + + client_id = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + self.realtime_client = await TestApp.get_ably_realtime( + use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, + ) + self.rest_client = await TestApp.get_ably_rest( + use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, + ) + + async def test_publish_and_subscribe_annotations(self): + """Test publishing and subscribing to annotations""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel_name = self.get_channel_name('mutable:publish_and_subscribe_annotations') + channel = self.realtime_client.channels.get( + channel_name, + channel_options, + ) + rest_channel = self.rest_client.channels.get(channel_name) + await channel.attach() + + # Setup annotation listener + annotation_future = asyncio.Future() + + async def on_annotation(annotation): + if not annotation_future.done(): + annotation_future.set_result(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish a message + publish_result = await channel.publish('message', 'foobar') + + # Reset for next message (summary) + message_summary = asyncio.Future() + + def on_message(msg): + if not message_summary.done(): + message_summary.set_result(msg) + + await channel.subscribe('message', on_message) + + # Publish annotation using realtime + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', + 'name': '👍' + }) + + # Wait for annotation + annotation = await annotation_future + assert annotation.action == AnnotationAction.ANNOTATION_CREATE + assert annotation.message_serial == publish_result.serials[0] + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '👍' + assert annotation.serial > annotation.message_serial + + # Wait for summary message + summary = await message_summary + assert summary.action == MessageAction.MESSAGE_SUMMARY + assert summary.serial == publish_result.serials[0] + + # Try again but with REST publish + annotation_future2 = asyncio.Future() + + async def on_annotation2(annotation): + if not annotation_future2.done(): + annotation_future2.set_result(annotation) + + await channel.annotations.subscribe(on_annotation2) + + await rest_channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', + 'name': '😕' + }) + + annotation = await annotation_future2 + assert annotation.action == AnnotationAction.ANNOTATION_CREATE + assert annotation.message_serial == publish_result.serials[0] + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '😕' + assert annotation.serial > annotation.message_serial + + async def test_get_all_annotations_for_a_message(self): + """Test retrieving all annotations with pagination""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:get_all_annotations_for_a_message'), + channel_options + ) + await channel.attach() + + # Publish a message + publish_result = await channel.publish('message', 'foobar') + + # Publish multiple annotations + emojis = ['👍', '😕', '👎'] + for emoji in emojis: + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', + 'name': emoji + }) + + # Wait for all annotations to appear + annotations = [] + + async def check_annotations(): + nonlocal annotations + res = await channel.annotations.get(publish_result.serials[0], {}) + annotations = res.items + return len(annotations) == 3 + + await assert_waiter(check_annotations, timeout=10) + + # Verify annotations + assert annotations[0].action == AnnotationAction.ANNOTATION_CREATE + assert annotations[0].message_serial == publish_result.serials[0] + assert annotations[0].type == 'reaction:distinct.v1' + assert annotations[0].name == '👍' + assert annotations[1].name == '😕' + assert annotations[2].name == '👎' + assert annotations[1].serial > annotations[0].serial + assert annotations[2].serial > annotations[1].serial + + async def test_subscribe_by_annotation_type(self): + """Test subscribing to specific annotation types""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:subscribe_by_type'), + channel_options + ) + await channel.attach() + + # Setup message listener + message_future = asyncio.Future() + + def on_message(msg): + if not message_future.done(): + message_future.set_result(msg) + + await channel.subscribe('message', on_message) + + # Subscribe to specific annotation type + reaction_future = asyncio.Future() + + async def on_reaction(annotation): + if not reaction_future.done(): + reaction_future.set_result(annotation) + + await channel.annotations.subscribe('reaction:distinct.v1', on_reaction) + + # Publish message and annotation + publish_result = await channel.publish('message', 'test') + + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', + 'name': '👍' + }) + + # Should receive the annotation + annotation = await reaction_future + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '👍' + + async def test_unsubscribe_annotations(self): + """Test unsubscribing from annotations""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:unsubscribe_annotations'), + channel_options + ) + await channel.attach() + + annotations_received = [] + annotation_future = ReusableFuture() + + async def on_annotation(annotation): + annotations_received.append(annotation) + annotation_future.set_result(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish message and first annotation + publish_result = await channel.publish('message', 'test') + + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', + 'name': '👍' + }) + + # Wait for the first annotation to appear + await annotation_future.get() + assert len(annotations_received) == 1 + + # Unsubscribe + channel.annotations.unsubscribe(on_annotation) + + await channel.annotations.subscribe(lambda annotation: annotation_future.set_result(annotation)) + + # Publish another annotation + await channel.annotations.publish(publish_result.serials[0], { + 'type': 'reaction:distinct.v1', + 'name': '😕' + }) + + # Wait for the second annotation to appear in another listener + await annotation_future.get() + + assert len(annotations_received) == 1 + + async def test_delete_annotation(self): + """Test deleting annotations""" + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE, + ChannelMode.ANNOTATION_PUBLISH, + ChannelMode.ANNOTATION_SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:delete_annotation'), + channel_options + ) + await channel.attach() + + # Setup message listener + message_future = asyncio.Future() + + def on_message(msg): + if not message_future.done(): + message_future.set_result(msg) + + await channel.subscribe('message', on_message) + + annotations_received = [] + annotation_future = ReusableFuture() + async def on_annotation(annotation): + annotations_received.append(annotation) + annotation_future.set_result(annotation) + + await channel.annotations.subscribe(on_annotation) + + # Publish message and annotation + await channel.publish('message', 'test') + message = await message_future + + await channel.annotations.publish(message.serial, { + 'type': 'reaction:distinct.v1', + 'name': '👍' + }) + + await annotation_future.get() + + # Wait for create annotation + assert len(annotations_received) == 1 + assert annotations_received[0].action == AnnotationAction.ANNOTATION_CREATE + + # Delete the annotation + await channel.annotations.delete(message.serial, { + 'type': 'reaction:distinct.v1', + 'name': '👍' + }) + + # Wait for delete annotation + await annotation_future.get() + + assert len(annotations_received) == 2 + assert annotations_received[1].action == AnnotationAction.ANNOTATION_DELETE + + async def test_subscribe_without_annotation_mode_fails(self): + """Test that subscribing without annotation_subscribe mode raises an error""" + # Create channel without annotation_subscribe mode + channel_options = ChannelOptions(modes=[ + ChannelMode.PUBLISH, + ChannelMode.SUBSCRIBE + ]) + channel = self.realtime_client.channels.get( + self.get_channel_name('mutable:no_annotation_mode'), + channel_options + ) + await channel.attach() + + async def on_annotation(annotation): + pass + + # Should raise error about missing annotation_subscribe mode + with pytest.raises(AblyException) as exc_info: + await channel.annotations.subscribe(on_annotation) + + assert exc_info.value.status_code == 400 + assert 'annotation_subscribe' in str(exc_info.value).lower() diff --git a/test/ably/realtime/realtimeconnection_test.py b/test/ably/realtime/realtimeconnection_test.py index b38c5aaf..f1eb9003 100644 --- a/test/ably/realtime/realtimeconnection_test.py +++ b/test/ably/realtime/realtimeconnection_test.py @@ -369,7 +369,7 @@ async def test_connection_client_id_query_params(self): ably = await TestApp.get_ably_realtime(client_id=client_id) await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5) - assert ably.connection.connection_manager.transport.params["client_id"] == client_id + assert ably.connection.connection_manager.transport.params["clientId"] == client_id assert ably.auth.client_id == client_id await ably.close() diff --git a/test/ably/rest/restannotations_test.py b/test/ably/rest/restannotations_test.py new file mode 100644 index 00000000..8969e84d --- /dev/null +++ b/test/ably/rest/restannotations_test.py @@ -0,0 +1,203 @@ +import logging +import random +import string + +import pytest + +from ably import AblyException +from ably.types.annotation import AnnotationAction +from ably.types.message import Message +from test.ably.testapp import TestApp +from test.ably.utils import BaseAsyncTestCase, assert_waiter + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize("transport", ["json", "msgpack"], ids=["JSON", "MsgPack"]) +class TestRestAnnotations(BaseAsyncTestCase): + + @pytest.fixture(autouse=True) + async def setup(self, transport): + self.test_vars = await TestApp.get_test_vars() + client_id = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + self.ably = await TestApp.get_ably_rest( + use_binary_protocol=True if transport == 'msgpack' else False, + client_id=client_id, + ) + + async def test_publish_annotation_success(self): + """Test successfully publishing an annotation on a message""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_publish_test')] + + # First publish a message + result = await channel.publish('test-event', 'test data') + assert result.serials is not None + assert len(result.serials) > 0 + serial = result.serials[0] + + # Publish an annotation + await channel.annotations.publish(serial, { + 'type': 'reaction:distinct.v1', + 'name': '👍' + }) + + annotations_result = None + + # Wait for annotations to appear + async def check_annotations(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) == 1 + + await assert_waiter(check_annotations, timeout=10) + + # Get annotations to verify + annotations = annotations_result.items + assert len(annotations) >= 1 + assert annotations[0].message_serial == serial + assert annotations[0].type == 'reaction:distinct.v1' + assert annotations[0].name == '👍' + + async def test_publish_annotation_with_message_object(self): + """Test publishing an annotation using a Message object""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_publish_msg_obj')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Create a message object + message = Message(serial=serial) + + # Publish annotation with message object + await channel.annotations.publish(message, { + 'type': 'reaction:distinct.v1', + 'name': '😕' + }) + + annotations_result = None + + # Wait for annotations to appear + async def check_annotations(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) == 1 + + await assert_waiter(check_annotations, timeout=10) + + # Verify + annotations_result = await channel.annotations.get(serial) + annotations = annotations_result.items + assert len(annotations) >= 1 + assert annotations[0].name == '😕' + + async def test_publish_annotation_without_serial_fails(self): + """Test that publishing without a serial raises an exception""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_no_serial')] + + with pytest.raises(AblyException) as exc_info: + await channel.annotations.publish(None, {'type': 'reaction', 'name': '👍'}) + + assert exc_info.value.status_code == 400 + assert exc_info.value.code == 40003 + + async def test_delete_annotation_success(self): + """Test successfully deleting an annotation""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_delete_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish an annotation + await channel.annotations.publish(serial, { + 'type': 'reaction:distinct.v1', + 'name': '👍' + }) + + annotations_result = None + + # Wait for annotation to appear + async def check_annotation(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) >= 1 + + await assert_waiter(check_annotation, timeout=10) + + # Delete the annotation + await channel.annotations.delete(serial, { + 'type': 'reaction:distinct.v1', + 'name': '👍' + }) + + # Wait for annotation to appear + async def check_deleted_annotation(): + nonlocal annotations_result + annotations_result = await channel.annotations.get(serial) + return len(annotations_result.items) >= 2 + + await assert_waiter(check_deleted_annotation, timeout=10) + assert annotations_result.items[-1].type == 'reaction:distinct.v1' + assert annotations_result.items[-1].action == AnnotationAction.ANNOTATION_DELETE + + async def test_get_all_annotations(self): + """Test retrieving all annotations for a message""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_get_all_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish annotations + await channel.annotations.publish(serial, {'type': 'reaction:distinct.v1', 'name': '👍'}) + await channel.annotations.publish(serial, {'type': 'reaction:distinct.v1', 'name': '😕'}) + await channel.annotations.publish(serial, {'type': 'reaction:distinct.v1', 'name': '👎'}) + + # Wait and get all annotations + async def check_annotations(): + res = await channel.annotations.get(serial) + return len(res.items) >= 3 + + await assert_waiter(check_annotations, timeout=10) + + annotations_result = await channel.annotations.get(serial) + annotations = annotations_result.items + assert len(annotations) >= 3 + assert annotations[0].type == 'reaction:distinct.v1' + assert annotations[0].message_serial == serial + # Verify serials are in order + if len(annotations) > 1: + assert annotations[1].serial > annotations[0].serial + if len(annotations) > 2: + assert annotations[2].serial > annotations[1].serial + + async def test_annotation_properties(self): + """Test that annotation properties are correctly set""" + channel = self.ably.channels[self.get_channel_name('mutable:annotation_properties_test')] + + # Publish a message + result = await channel.publish('test-event', 'test data') + serial = result.serials[0] + + # Publish annotation with various properties + await channel.annotations.publish(serial, { + 'type': 'reaction:distinct.v1', + 'name': '❤️', + 'data': {'count': 5} + }) + + # Retrieve and verify + async def check_annotation(): + res = await channel.annotations.get(serial) + return len(res.items) > 0 + + await assert_waiter(check_annotation, timeout=10) + + annotations_result = await channel.annotations.get(serial) + annotation = annotations_result.items[0] + assert annotation.message_serial == serial + assert annotation.type == 'reaction:distinct.v1' + assert annotation.name == '❤️' + assert annotation.serial is not None + assert annotation.serial > serial diff --git a/test/ably/utils.py b/test/ably/utils.py index 09658fc0..eb75d3e6 100644 --- a/test/ably/utils.py +++ b/test/ably/utils.py @@ -229,6 +229,9 @@ def assert_waiter_sync(block: Callable[[], bool], timeout: float = 10) -> None: class WaitableEvent: + """ + Replacement for asyncio.Future that will work with autogenerated sync tests. + """ def __init__(self): self._finished = False @@ -243,3 +246,20 @@ async def wait(self, timeout=10): def finish(self): self._finished = True + +class ReusableFuture: + """ + A reusable future that after each wait() resets itself and wait for the next value. + """ + def __init__(self): + self.__future = asyncio.Future() + + async def get(self, timeout=10): + await asyncio.wait_for(self.__future, timeout=timeout) + self.__future = asyncio.Future() + + def set_result(self, result): + self.__future.set_result(result) + + def set_exception(self, exception): + self.__future.set_exception(exception) diff --git a/uv.lock b/uv.lock index 1b196ab7..5b48323d 100644 --- a/uv.lock +++ b/uv.lock @@ -10,7 +10,7 @@ resolution-markers = [ [[package]] name = "ably" -version = "2.1.3" +version = "3.0.0" source = { editable = "." } dependencies = [ { name = "h2", version = "4.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },