diff --git a/slack_bolt/__init__.py b/slack_bolt/__init__.py index 4e43252fd..dfe950bf2 100644 --- a/slack_bolt/__init__.py +++ b/slack_bolt/__init__.py @@ -14,6 +14,7 @@ from .context.fail import Fail from .context.respond import Respond from .context.say import Say +from .context.say_stream import SayStream from .kwargs_injection import Args from .listener import Listener from .listener_matcher import CustomListenerMatcher @@ -42,6 +43,7 @@ "Fail", "Respond", "Say", + "SayStream", "Args", "Listener", "CustomListenerMatcher", diff --git a/slack_bolt/async_app.py b/slack_bolt/async_app.py index fdf724d4c..f95d952aa 100644 --- a/slack_bolt/async_app.py +++ b/slack_bolt/async_app.py @@ -59,6 +59,7 @@ async def command(ack, body, respond): from .context.set_suggested_prompts.async_set_suggested_prompts import AsyncSetSuggestedPrompts from .context.get_thread_context.async_get_thread_context import AsyncGetThreadContext from .context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext +from .context.say_stream.async_say_stream import AsyncSayStream __all__ = [ "AsyncApp", @@ -66,6 +67,7 @@ async def command(ack, body, respond): "AsyncBoltContext", "AsyncRespond", "AsyncSay", + "AsyncSayStream", "AsyncListener", "AsyncCustomListenerMatcher", "AsyncBoltRequest", diff --git a/slack_bolt/context/async_context.py b/slack_bolt/context/async_context.py index 631f74a82..33f260d38 100644 --- a/slack_bolt/context/async_context.py +++ b/slack_bolt/context/async_context.py @@ -10,6 +10,7 @@ from slack_bolt.context.get_thread_context.async_get_thread_context import AsyncGetThreadContext from slack_bolt.context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext from slack_bolt.context.say.async_say import AsyncSay +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream from slack_bolt.context.set_status.async_set_status import AsyncSetStatus from slack_bolt.context.set_suggested_prompts.async_set_suggested_prompts import AsyncSetSuggestedPrompts from slack_bolt.context.set_title.async_set_title import AsyncSetTitle @@ -203,6 +204,10 @@ def set_suggested_prompts(self) -> Optional[AsyncSetSuggestedPrompts]: def get_thread_context(self) -> Optional[AsyncGetThreadContext]: return self.get("get_thread_context") + @property + def say_stream(self) -> Optional[AsyncSayStream]: + return self.get("say_stream") + @property def save_thread_context(self) -> Optional[AsyncSaveThreadContext]: return self.get("save_thread_context") diff --git a/slack_bolt/context/base_context.py b/slack_bolt/context/base_context.py index 843d5ef60..502febcb8 100644 --- a/slack_bolt/context/base_context.py +++ b/slack_bolt/context/base_context.py @@ -38,6 +38,7 @@ class BaseContext(dict): "set_status", "set_title", "set_suggested_prompts", + "say_stream", ] # Note that these items are not copyable, so when you add new items to this list, # you must modify ThreadListenerRunner/AsyncioListenerRunner's _build_lazy_request method to pass the values. diff --git a/slack_bolt/context/context.py b/slack_bolt/context/context.py index 48df4ad32..6184d5083 100644 --- a/slack_bolt/context/context.py +++ b/slack_bolt/context/context.py @@ -10,6 +10,7 @@ from slack_bolt.context.respond import Respond from slack_bolt.context.save_thread_context import SaveThreadContext from slack_bolt.context.say import Say +from slack_bolt.context.say_stream import SayStream from slack_bolt.context.set_status import SetStatus from slack_bolt.context.set_suggested_prompts import SetSuggestedPrompts from slack_bolt.context.set_title import SetTitle @@ -204,6 +205,10 @@ def set_suggested_prompts(self) -> Optional[SetSuggestedPrompts]: def get_thread_context(self) -> Optional[GetThreadContext]: return self.get("get_thread_context") + @property + def say_stream(self) -> Optional[SayStream]: + return self.get("say_stream") + @property def save_thread_context(self) -> Optional[SaveThreadContext]: return self.get("save_thread_context") diff --git a/slack_bolt/context/say_stream/__init__.py b/slack_bolt/context/say_stream/__init__.py new file mode 100644 index 000000000..86db7b1cc --- /dev/null +++ b/slack_bolt/context/say_stream/__init__.py @@ -0,0 +1,6 @@ +# Don't add async module imports here +from .say_stream import SayStream + +__all__ = [ + "SayStream", +] diff --git a/slack_bolt/context/say_stream/async_say_stream.py b/slack_bolt/context/say_stream/async_say_stream.py new file mode 100644 index 000000000..dc752d02a --- /dev/null +++ b/slack_bolt/context/say_stream/async_say_stream.py @@ -0,0 +1,74 @@ +import warnings +from typing import Optional + +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.async_chat_stream import AsyncChatStream + +from slack_bolt.warning import ExperimentalWarning + + +class AsyncSayStream: + client: AsyncWebClient + channel: Optional[str] + recipient_team_id: Optional[str] + recipient_user_id: Optional[str] + thread_ts: Optional[str] + + def __init__( + self, + *, + client: AsyncWebClient, + channel: Optional[str] = None, + recipient_team_id: Optional[str] = None, + recipient_user_id: Optional[str] = None, + thread_ts: Optional[str] = None, + ): + self.client = client + self.channel = channel + self.recipient_team_id = recipient_team_id + self.recipient_user_id = recipient_user_id + self.thread_ts = thread_ts + + async def __call__( + self, + *, + buffer_size: Optional[int] = None, + channel: Optional[str] = None, + recipient_team_id: Optional[str] = None, + recipient_user_id: Optional[str] = None, + thread_ts: Optional[str] = None, + **kwargs, + ) -> AsyncChatStream: + """Starts a new chat stream with context. + + Warning: This is an experimental feature and may change in future versions. + """ + warnings.warn( + "say_stream is experimental and may change in future versions.", + category=ExperimentalWarning, + stacklevel=2, + ) + + channel = channel or self.channel + thread_ts = thread_ts or self.thread_ts + if channel is None: + raise ValueError("say_stream without channel here is unsupported") + if thread_ts is None: + raise ValueError("say_stream without thread_ts here is unsupported") + + if buffer_size is not None: + return await self.client.chat_stream( + buffer_size=buffer_size, + channel=channel, + recipient_team_id=recipient_team_id or self.recipient_team_id, + recipient_user_id=recipient_user_id or self.recipient_user_id, + thread_ts=thread_ts, + **kwargs, + ) + return await self.client.chat_stream( + channel=channel, + recipient_team_id=recipient_team_id or self.recipient_team_id, + recipient_user_id=recipient_user_id or self.recipient_user_id, + thread_ts=thread_ts, + **kwargs, + ) diff --git a/slack_bolt/context/say_stream/say_stream.py b/slack_bolt/context/say_stream/say_stream.py new file mode 100644 index 000000000..1e1d7985f --- /dev/null +++ b/slack_bolt/context/say_stream/say_stream.py @@ -0,0 +1,74 @@ +import warnings +from typing import Optional + +from slack_sdk import WebClient +from slack_sdk.web.chat_stream import ChatStream + +from slack_bolt.warning import ExperimentalWarning + + +class SayStream: + client: WebClient + channel: Optional[str] + recipient_team_id: Optional[str] + recipient_user_id: Optional[str] + thread_ts: Optional[str] + + def __init__( + self, + *, + client: WebClient, + channel: Optional[str] = None, + recipient_team_id: Optional[str] = None, + recipient_user_id: Optional[str] = None, + thread_ts: Optional[str] = None, + ): + self.client = client + self.channel = channel + self.recipient_team_id = recipient_team_id + self.recipient_user_id = recipient_user_id + self.thread_ts = thread_ts + + def __call__( + self, + *, + buffer_size: Optional[int] = None, + channel: Optional[str] = None, + recipient_team_id: Optional[str] = None, + recipient_user_id: Optional[str] = None, + thread_ts: Optional[str] = None, + **kwargs, + ) -> ChatStream: + """Starts a new chat stream with context. + + Warning: This is an experimental feature and may change in future versions. + """ + warnings.warn( + "say_stream is experimental and may change in future versions.", + category=ExperimentalWarning, + stacklevel=2, + ) + + channel = channel or self.channel + thread_ts = thread_ts or self.thread_ts + if channel is None: + raise ValueError("say_stream without channel here is unsupported") + if thread_ts is None: + raise ValueError("say_stream without thread_ts here is unsupported") + + if buffer_size is not None: + return self.client.chat_stream( + buffer_size=buffer_size, + channel=channel, + recipient_team_id=recipient_team_id or self.recipient_team_id, + recipient_user_id=recipient_user_id or self.recipient_user_id, + thread_ts=thread_ts, + **kwargs, + ) + return self.client.chat_stream( + channel=channel, + recipient_team_id=recipient_team_id or self.recipient_team_id, + recipient_user_id=recipient_user_id or self.recipient_user_id, + thread_ts=thread_ts, + **kwargs, + ) diff --git a/slack_bolt/kwargs_injection/args.py b/slack_bolt/kwargs_injection/args.py index 113e39c08..dfb242fd1 100644 --- a/slack_bolt/kwargs_injection/args.py +++ b/slack_bolt/kwargs_injection/args.py @@ -11,6 +11,7 @@ from slack_bolt.agent.agent import BoltAgent from slack_bolt.context.save_thread_context import SaveThreadContext from slack_bolt.context.say import Say +from slack_bolt.context.say_stream import SayStream from slack_bolt.context.set_status import SetStatus from slack_bolt.context.set_suggested_prompts import SetSuggestedPrompts from slack_bolt.context.set_title import SetTitle @@ -105,6 +106,8 @@ def handle_buttons(args): """`save_thread_context()` utility function for AI Agents & Assistants""" agent: Optional[BoltAgent] """`agent` listener argument for AI Agents & Assistants""" + say_stream: Optional[SayStream] + """`say_stream()` utility function for AI Agents & Assistants""" # middleware next: Callable[[], None] """`next()` utility function, which tells the middleware chain that it can continue with the next one""" @@ -139,6 +142,7 @@ def __init__( get_thread_context: Optional[GetThreadContext] = None, save_thread_context: Optional[SaveThreadContext] = None, agent: Optional[BoltAgent] = None, + say_stream: Optional[SayStream] = None, # As this method is not supposed to be invoked by bolt-python users, # the naming conflict with the built-in one affects # only the internals of this method @@ -173,6 +177,7 @@ def __init__( self.get_thread_context = get_thread_context self.save_thread_context = save_thread_context self.agent = agent + self.say_stream = say_stream self.next: Callable[[], None] = next self.next_: Callable[[], None] = next diff --git a/slack_bolt/kwargs_injection/async_args.py b/slack_bolt/kwargs_injection/async_args.py index 1f1dde024..19719e900 100644 --- a/slack_bolt/kwargs_injection/async_args.py +++ b/slack_bolt/kwargs_injection/async_args.py @@ -10,6 +10,7 @@ from slack_bolt.context.get_thread_context.async_get_thread_context import AsyncGetThreadContext from slack_bolt.context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext from slack_bolt.context.say.async_say import AsyncSay +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream from slack_bolt.context.set_status.async_set_status import AsyncSetStatus from slack_bolt.context.set_suggested_prompts.async_set_suggested_prompts import AsyncSetSuggestedPrompts from slack_bolt.context.set_title.async_set_title import AsyncSetTitle @@ -104,6 +105,8 @@ async def handle_buttons(args): """`save_thread_context()` utility function for AI Agents & Assistants""" agent: Optional[AsyncBoltAgent] """`agent` listener argument for AI Agents & Assistants""" + say_stream: Optional[AsyncSayStream] + """`say_stream()` utility function for AI Agents & Assistants""" # middleware next: Callable[[], Awaitable[None]] """`next()` utility function, which tells the middleware chain that it can continue with the next one""" @@ -138,6 +141,7 @@ def __init__( get_thread_context: Optional[AsyncGetThreadContext] = None, save_thread_context: Optional[AsyncSaveThreadContext] = None, agent: Optional[AsyncBoltAgent] = None, + say_stream: Optional[AsyncSayStream] = None, next: Callable[[], Awaitable[None]], **kwargs, # noqa ): @@ -169,6 +173,7 @@ def __init__( self.get_thread_context = get_thread_context self.save_thread_context = save_thread_context self.agent = agent + self.say_stream = say_stream self.next: Callable[[], Awaitable[None]] = next self.next_: Callable[[], Awaitable[None]] = next diff --git a/slack_bolt/kwargs_injection/async_utils.py b/slack_bolt/kwargs_injection/async_utils.py index aa84b2d11..534fb6133 100644 --- a/slack_bolt/kwargs_injection/async_utils.py +++ b/slack_bolt/kwargs_injection/async_utils.py @@ -60,6 +60,7 @@ def build_async_required_kwargs( "set_suggested_prompts": request.context.set_suggested_prompts, "get_thread_context": request.context.get_thread_context, "save_thread_context": request.context.save_thread_context, + "say_stream": request.context.say_stream, # middleware "next": next_func, "next_": next_func, # for the middleware using Python's built-in `next()` function diff --git a/slack_bolt/kwargs_injection/utils.py b/slack_bolt/kwargs_injection/utils.py index 5cd410a07..101e00099 100644 --- a/slack_bolt/kwargs_injection/utils.py +++ b/slack_bolt/kwargs_injection/utils.py @@ -59,6 +59,7 @@ def build_required_kwargs( "set_title": request.context.set_title, "set_suggested_prompts": request.context.set_suggested_prompts, "save_thread_context": request.context.save_thread_context, + "say_stream": request.context.say_stream, # middleware "next": next_func, "next_": next_func, # for the middleware using Python's built-in `next()` function diff --git a/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py b/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py index 0b43c21ce..08851c1eb 100644 --- a/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py +++ b/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py @@ -2,6 +2,7 @@ from slack_bolt.context.assistant.async_assistant_utilities import AsyncAssistantUtilities from slack_bolt.context.assistant.thread_context_store.async_store import AsyncAssistantThreadContextStore +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream from slack_bolt.middleware.async_middleware import AsyncMiddleware from slack_bolt.request.async_request import AsyncBoltRequest from slack_bolt.request.payload_utils import is_assistant_event, to_event @@ -36,4 +37,15 @@ async def async_process( req.context["set_suggested_prompts"] = assistant.set_suggested_prompts req.context["get_thread_context"] = assistant.get_thread_context req.context["save_thread_context"] = assistant.save_thread_context + + # TODO: in the future we might want to introduce a "proper" extract_ts utility + thread_ts = req.context.thread_ts or event.get("ts") + if req.context.channel_id and thread_ts: + req.context["say_stream"] = AsyncSayStream( + client=req.context.client, + channel=req.context.channel_id, + recipient_team_id=req.context.team_id or req.context.enterprise_id, + recipient_user_id=req.context.user_id, + thread_ts=thread_ts, + ) return await next() diff --git a/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py b/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py index 4963ea67d..38a62c0c8 100644 --- a/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py +++ b/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py @@ -2,6 +2,7 @@ from slack_bolt.context.assistant.assistant_utilities import AssistantUtilities from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore +from slack_bolt.context.say_stream.say_stream import SayStream from slack_bolt.middleware import Middleware from slack_bolt.request.payload_utils import is_assistant_event, to_event from slack_bolt.request.request import BoltRequest @@ -30,4 +31,15 @@ def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], Bo req.context["set_suggested_prompts"] = assistant.set_suggested_prompts req.context["get_thread_context"] = assistant.get_thread_context req.context["save_thread_context"] = assistant.save_thread_context + + # TODO: in the future we might want to introduce a "proper" extract_ts utility + thread_ts = req.context.thread_ts or event.get("ts") + if req.context.channel_id and thread_ts: + req.context["say_stream"] = SayStream( + client=req.context.client, + channel=req.context.channel_id, + recipient_team_id=req.context.team_id or req.context.enterprise_id, + recipient_user_id=req.context.user_id, + thread_ts=thread_ts, + ) return next() diff --git a/tests/scenario_tests/test_events_say_stream.py b/tests/scenario_tests/test_events_say_stream.py new file mode 100644 index 000000000..75b0c612c --- /dev/null +++ b/tests/scenario_tests/test_events_say_stream.py @@ -0,0 +1,238 @@ +import json +import time +from urllib.parse import quote + +from slack_sdk.web import WebClient + +from slack_bolt import App, BoltRequest, BoltContext +from slack_bolt.context.say_stream.say_stream import SayStream +from slack_bolt.middleware.assistant import Assistant +from tests.mock_web_api_server import ( + setup_mock_web_api_server, + cleanup_mock_web_api_server, +) +from tests.scenario_tests.test_app import app_mention_event_body +from tests.scenario_tests.test_events_assistant import ( + thread_started_event_body, + user_message_event_body as threaded_user_message_event_body, +) +from tests.scenario_tests.test_message_bot import bot_message_event_payload, user_message_event_payload +from tests.scenario_tests.test_view_submission import body as view_submission_body +from tests.utils import remove_os_env_temporarily, restore_os_env + + +def assert_target_called(called: dict, timeout: float = 1.0): + deadline = time.time() + timeout + while called["value"] is not True and time.time() < deadline: + time.sleep(0.1) + assert called["value"] is True + + +class TestEventsSayStream: + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + web_client = WebClient( + token=valid_token, + base_url=mock_api_server_base_url, + ) + + def setup_method(self): + self.old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server(self) + + def teardown_method(self): + cleanup_mock_web_api_server(self) + restore_os_env(self.old_os_env) + + def test_say_stream_injected_for_app_mention(self): + app = App(client=self.web_client) + called = {"value": False} + + @app.event("app_mention") + def handle_mention(say_stream: SayStream, context: BoltContext): + assert say_stream is not None + assert isinstance(say_stream, SayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "C111" + assert say_stream.thread_ts == "1595926230.009600" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + request = BoltRequest(body=app_mention_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called(called) + + def test_say_stream_with_org_level_install(self): + app = App(client=self.web_client) + called = {"value": False} + + @app.event("app_mention") + def handle_mention(say_stream: SayStream, context: BoltContext): + assert context.team_id is None + assert context.enterprise_id == "E111" + assert say_stream is not None + assert isinstance(say_stream, SayStream) + assert say_stream.recipient_team_id == "E111" + called["value"] = True + + request = BoltRequest(body=org_app_mention_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called(called) + + def test_say_stream_injected_for_threaded_message(self): + app = App(client=self.web_client) + called = {"value": False} + + @app.event("message") + def handle_message(say_stream: SayStream, context: BoltContext): + assert say_stream is not None + assert isinstance(say_stream, SayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "D111" + assert say_stream.thread_ts == "1726133698.626339" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + request = BoltRequest(body=threaded_user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called(called) + + def test_say_stream_in_user_message(self): + app = App(client=self.web_client) + called = {"value": False} + + @app.message("") + def handle_user_message(say_stream: SayStream, context: BoltContext): + assert say_stream is not None + assert isinstance(say_stream, SayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "C111" + assert say_stream.thread_ts == "1610261659.001400" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + request = BoltRequest(body=user_message_event_payload, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called(called) + + def test_say_stream_in_bot_message(self): + app = App(client=self.web_client) + called = {"value": False} + + @app.message("") + def handle_bot_message(say_stream: SayStream, context: BoltContext): + assert say_stream is not None + assert isinstance(say_stream, SayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "C111" + assert say_stream.thread_ts == "1610261539.000900" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + request = BoltRequest(body=bot_message_event_payload, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called(called) + + def test_say_stream_in_assistant_thread_started(self): + app = App(client=self.web_client) + assistant = Assistant() + called = {"value": False} + + @assistant.thread_started + def start_thread(say_stream: SayStream, context: BoltContext): + assert say_stream is not None + assert isinstance(say_stream, SayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "D111" + assert say_stream.thread_ts == "1726133698.626339" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + app.assistant(assistant) + + request = BoltRequest(body=thread_started_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called(called) + + def test_say_stream_in_assistant_user_message(self): + app = App(client=self.web_client) + assistant = Assistant() + called = {"value": False} + + @assistant.user_message + def handle_user_message(say_stream: SayStream, context: BoltContext): + assert say_stream is not None + assert isinstance(say_stream, SayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "D111" + assert say_stream.thread_ts == "1726133698.626339" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + app.assistant(assistant) + + request = BoltRequest(body=threaded_user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called(called) + + def test_say_stream_is_none_for_view_submission(self): + app = App(client=self.web_client, request_verification_enabled=False) + called = {"value": False} + + @app.view("view-id") + def handle_view(ack, say_stream, context: BoltContext): + ack() + assert say_stream is None + assert context.say_stream is None + called["value"] = True + + request = BoltRequest( + body=f"payload={quote(json.dumps(view_submission_body))}", + ) + response = app.dispatch(request) + assert response.status == 200 + assert_target_called(called) + + +org_app_mention_event_body = { + "token": "verification_token", + "team_id": "T111", + "enterprise_id": "E111", + "api_app_id": "A111", + "event": { + "client_msg_id": "9cbd4c5b-7ddf-4ede-b479-ad21fca66d63", + "type": "app_mention", + "text": "<@W111> Hi there!", + "user": "W222", + "ts": "1595926230.009600", + "team": "T111", + "channel": "C111", + "event_ts": "1595926230.009600", + }, + "type": "event_callback", + "event_id": "Ev111", + "event_time": 1595926230, + "authorizations": [ + { + "enterprise_id": "E111", + "team_id": None, + "user_id": "W111", + "is_bot": True, + "is_enterprise_install": True, + } + ], + "is_ext_shared_channel": False, +} diff --git a/tests/scenario_tests_async/test_events_say_stream.py b/tests/scenario_tests_async/test_events_say_stream.py new file mode 100644 index 000000000..c24bc7bfc --- /dev/null +++ b/tests/scenario_tests_async/test_events_say_stream.py @@ -0,0 +1,250 @@ +import asyncio +import json +import time +from urllib.parse import quote + +import pytest +from slack_sdk.web.async_client import AsyncWebClient + +from slack_bolt.app.async_app import AsyncApp +from slack_bolt.async_app import AsyncAssistant +from slack_bolt.context.async_context import AsyncBoltContext +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream +from slack_bolt.request.async_request import AsyncBoltRequest +from tests.mock_web_api_server import ( + cleanup_mock_web_api_server_async, + setup_mock_web_api_server_async, +) +from tests.scenario_tests_async.test_app import app_mention_event_body +from tests.scenario_tests_async.test_events_assistant import user_message_event_body as threaded_user_message_event_body +from tests.scenario_tests_async.test_events_assistant import thread_started_event_body, user_message_event_body +from tests.scenario_tests_async.test_message_bot import bot_message_event_payload, user_message_event_payload +from tests.scenario_tests_async.test_view_submission import body as view_submission_body +from tests.utils import remove_os_env_temporarily, restore_os_env + + +async def assert_target_called(called: dict, timeout: float = 0.5): + deadline = time.time() + timeout + while called["value"] is not True and time.time() < deadline: + await asyncio.sleep(0.1) + assert called["value"] is True + + +class TestAsyncEventsSayStream: + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + web_client = AsyncWebClient( + token=valid_token, + base_url=mock_api_server_base_url, + ) + + @pytest.fixture(scope="function", autouse=True) + def setup_teardown(self): + old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server_async(self) + try: + yield + finally: + cleanup_mock_web_api_server_async(self) + restore_os_env(old_os_env) + + @pytest.mark.asyncio + async def test_say_stream_injected_for_app_mention(self): + app = AsyncApp(client=self.web_client) + called = {"value": False} + + @app.event("app_mention") + async def handle_mention(say_stream: AsyncSayStream, context: AsyncBoltContext): + assert say_stream is not None + assert isinstance(say_stream, AsyncSayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "C111" + assert say_stream.thread_ts == "1595926230.009600" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + request = AsyncBoltRequest(body=app_mention_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called(called) + + @pytest.mark.asyncio + async def test_say_stream_with_org_level_install(self): + app = AsyncApp(client=self.web_client) + called = {"value": False} + + @app.event("app_mention") + async def handle_mention(say_stream: AsyncSayStream, context: AsyncBoltContext): + assert context.team_id is None + assert context.enterprise_id == "E111" + assert say_stream is not None + assert isinstance(say_stream, AsyncSayStream) + assert say_stream.recipient_team_id == "E111" + called["value"] = True + + request = AsyncBoltRequest(body=org_app_mention_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called(called) + + @pytest.mark.asyncio + async def test_say_stream_injected_for_threaded_message(self): + app = AsyncApp(client=self.web_client) + called = {"value": False} + + @app.event("message") + async def handle_message(say_stream: AsyncSayStream, context: AsyncBoltContext): + assert say_stream is not None + assert isinstance(say_stream, AsyncSayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "D111" + assert say_stream.thread_ts == "1726133698.626339" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + request = AsyncBoltRequest(body=threaded_user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called(called) + + @pytest.mark.asyncio + async def test_say_stream_in_user_message(self): + app = AsyncApp(client=self.web_client) + called = {"value": False} + + @app.message("") + async def handle_user_message(say_stream: AsyncSayStream, context: AsyncBoltContext): + assert say_stream is not None + assert isinstance(say_stream, AsyncSayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "C111" + assert say_stream.thread_ts == "1610261659.001400" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + request = AsyncBoltRequest(body=user_message_event_payload, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called(called) + + @pytest.mark.asyncio + async def test_say_stream_in_bot_message(self): + app = AsyncApp(client=self.web_client) + called = {"value": False} + + @app.message("") + async def handle_user_message(say_stream: AsyncSayStream, context: AsyncBoltContext): + assert say_stream is not None + assert isinstance(say_stream, AsyncSayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "C111" + assert say_stream.thread_ts == "1610261539.000900" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + request = AsyncBoltRequest(body=bot_message_event_payload, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called(called) + + @pytest.mark.asyncio + async def test_say_stream_in_assistant_thread_started(self): + app = AsyncApp(client=self.web_client) + assistant = AsyncAssistant() + called = {"value": False} + + @assistant.thread_started + async def start_thread(say_stream: AsyncSayStream, context: AsyncBoltContext): + assert say_stream is not None + assert isinstance(say_stream, AsyncSayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "D111" + assert say_stream.thread_ts == "1726133698.626339" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + app.assistant(assistant) + + request = AsyncBoltRequest(body=thread_started_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called(called) + + @pytest.mark.asyncio + async def test_say_stream_in_assistant_user_message(self): + app = AsyncApp(client=self.web_client) + assistant = AsyncAssistant() + called = {"value": False} + + @assistant.user_message + async def handle_user_message(say_stream: AsyncSayStream, context: AsyncBoltContext): + assert say_stream is not None + assert isinstance(say_stream, AsyncSayStream) + assert say_stream == context.say_stream + assert say_stream.channel == "D111" + assert say_stream.thread_ts == "1726133698.626339" + assert say_stream.recipient_team_id == context.team_id + assert say_stream.recipient_user_id == context.user_id + called["value"] = True + + app.assistant(assistant) + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called(called) + + @pytest.mark.asyncio + async def test_say_stream_is_none_for_view_submission(self): + app = AsyncApp(client=self.web_client, request_verification_enabled=False) + called = {"value": False} + + @app.view("view-id") + async def handle_view(ack, say_stream, context: AsyncBoltContext): + await ack() + assert say_stream is None + assert context.say_stream is None + called["value"] = True + + request = AsyncBoltRequest( + body=f"payload={quote(json.dumps(view_submission_body))}", + ) + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called(called) + + +org_app_mention_event_body = { + "token": "verification_token", + "team_id": "T111", + "enterprise_id": "E111", + "api_app_id": "A111", + "event": { + "client_msg_id": "9cbd4c5b-7ddf-4ede-b479-ad21fca66d63", + "type": "app_mention", + "text": "<@W111> Hi there!", + "user": "W222", + "ts": "1595926230.009600", + "team": "T111", + "channel": "C111", + "event_ts": "1595926230.009600", + }, + "type": "event_callback", + "event_id": "Ev111", + "event_time": 1595926230, + "authorizations": [ + { + "enterprise_id": "E111", + "team_id": None, + "user_id": "W111", + "is_bot": True, + "is_enterprise_install": True, + } + ], + "is_ext_shared_channel": False, +} diff --git a/tests/slack_bolt/context/test_say_stream.py b/tests/slack_bolt/context/test_say_stream.py new file mode 100644 index 000000000..c8f4c3a31 --- /dev/null +++ b/tests/slack_bolt/context/test_say_stream.py @@ -0,0 +1,103 @@ +import pytest +from slack_sdk import WebClient + +from slack_bolt.context.say_stream.say_stream import SayStream +from slack_bolt.warning import ExperimentalWarning +from tests.mock_web_api_server import cleanup_mock_web_api_server, setup_mock_web_api_server + + +class TestSayStream: + default_chat_stream_buffer_size = WebClient.chat_stream.__kwdefaults__["buffer_size"] + + def setup_method(self): + setup_mock_web_api_server(self) + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + self.web_client = WebClient(token=valid_token, base_url=mock_api_server_base_url) + + def teardown_method(self): + cleanup_mock_web_api_server(self) + + def test_missing_channel_raises(self): + say_stream = SayStream(client=self.web_client, channel=None, thread_ts="111.222") + with pytest.warns(ExperimentalWarning): + with pytest.raises(ValueError, match="channel"): + say_stream() + + def test_missing_thread_ts_raises(self): + say_stream = SayStream(client=self.web_client, channel="C111", thread_ts=None) + with pytest.warns(ExperimentalWarning): + with pytest.raises(ValueError, match="thread_ts"): + say_stream() + + def test_default_params(self): + say_stream = SayStream( + client=self.web_client, + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + ) + stream = say_stream() + + assert stream._buffer_size == self.default_chat_stream_buffer_size + assert stream._stream_args == { + "channel": "C111", + "thread_ts": "111.222", + "recipient_team_id": "T111", + "recipient_user_id": "U111", + "task_display_mode": None, + } + + def test_parameter_overrides(self): + say_stream = SayStream( + client=self.web_client, + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + ) + stream = say_stream(channel="C222", thread_ts="333.444", recipient_team_id="T222", recipient_user_id="U222") + + assert stream._buffer_size == self.default_chat_stream_buffer_size + assert stream._stream_args == { + "channel": "C222", + "thread_ts": "333.444", + "recipient_team_id": "T222", + "recipient_user_id": "U222", + "task_display_mode": None, + } + + def test_buffer_size_overrides(self): + say_stream = SayStream( + client=self.web_client, + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + ) + stream = say_stream( + buffer_size=100, + channel="C222", + thread_ts="333.444", + recipient_team_id="T222", + recipient_user_id="U222", + ) + + assert stream._buffer_size == 100 + assert stream._stream_args == { + "channel": "C222", + "thread_ts": "333.444", + "recipient_team_id": "T222", + "recipient_user_id": "U222", + "task_display_mode": None, + } + + def test_experimental_warning(self): + say_stream = SayStream( + client=self.web_client, + channel="C111", + thread_ts="111.222", + ) + with pytest.warns(ExperimentalWarning, match="say_stream is experimental"): + say_stream() diff --git a/tests/slack_bolt_async/context/test_async_say_stream.py b/tests/slack_bolt_async/context/test_async_say_stream.py new file mode 100644 index 000000000..fbc4c5c7e --- /dev/null +++ b/tests/slack_bolt_async/context/test_async_say_stream.py @@ -0,0 +1,117 @@ +import pytest +from slack_sdk.web.async_client import AsyncWebClient + +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream +from slack_bolt.warning import ExperimentalWarning +from tests.mock_web_api_server import ( + cleanup_mock_web_api_server, + setup_mock_web_api_server, +) +from tests.utils import remove_os_env_temporarily, restore_os_env + + +class TestAsyncSayStream: + default_chat_stream_buffer_size = AsyncWebClient.chat_stream.__kwdefaults__["buffer_size"] + + @pytest.fixture(scope="function", autouse=True) + def setup_teardown(self): + old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server(self) + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + try: + self.web_client = AsyncWebClient(token=valid_token, base_url=mock_api_server_base_url) + yield # run the test here + finally: + cleanup_mock_web_api_server(self) + restore_os_env(old_os_env) + + @pytest.mark.asyncio + async def test_missing_channel_raises(self): + say_stream = AsyncSayStream(client=self.web_client, channel=None, thread_ts="111.222") + with pytest.warns(ExperimentalWarning): + with pytest.raises(ValueError, match="channel"): + await say_stream() + + @pytest.mark.asyncio + async def test_missing_thread_ts_raises(self): + say_stream = AsyncSayStream(client=self.web_client, channel="C111", thread_ts=None) + with pytest.warns(ExperimentalWarning): + with pytest.raises(ValueError, match="thread_ts"): + await say_stream() + + @pytest.mark.asyncio + async def test_default_params(self): + say_stream = AsyncSayStream( + client=self.web_client, + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + ) + stream = await say_stream() + + assert stream._buffer_size == self.default_chat_stream_buffer_size + assert stream._stream_args == { + "channel": "C111", + "thread_ts": "111.222", + "recipient_team_id": "T111", + "recipient_user_id": "U111", + "task_display_mode": None, + } + + @pytest.mark.asyncio + async def test_parameter_overrides(self): + say_stream = AsyncSayStream( + client=self.web_client, + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + ) + stream = await say_stream(channel="C222", thread_ts="333.444", recipient_team_id="T222", recipient_user_id="U222") + + assert stream._buffer_size == self.default_chat_stream_buffer_size + assert stream._stream_args == { + "channel": "C222", + "thread_ts": "333.444", + "recipient_team_id": "T222", + "recipient_user_id": "U222", + "task_display_mode": None, + } + + @pytest.mark.asyncio + async def test_buffer_size_overrides(self): + say_stream = AsyncSayStream( + client=self.web_client, + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + ) + stream = await say_stream( + buffer_size=100, + channel="C222", + thread_ts="333.444", + recipient_team_id="T222", + recipient_user_id="U222", + ) + + assert stream._buffer_size == 100 + assert stream._stream_args == { + "channel": "C222", + "thread_ts": "333.444", + "recipient_team_id": "T222", + "recipient_user_id": "U222", + "task_display_mode": None, + } + + @pytest.mark.asyncio + async def test_experimental_warning(self): + say_stream = AsyncSayStream( + client=self.web_client, + channel="C111", + thread_ts="111.222", + ) + with pytest.warns(ExperimentalWarning, match="say_stream is experimental"): + await say_stream()