Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 224 additions & 2 deletions astrbot/core/platform/sources/telegram/tg_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@
)
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata

# sendMessageDraft 的 draft_id 模块级递增计数器(溢出时归 1)
_TELEGRAM_DRAFT_ID_MAX = 2_147_483_647
_next_draft_id = 0


def _allocate_draft_id() -> int:
"""分配一个全局递增的 draft_id,溢出时归 1。"""
global _next_draft_id
_next_draft_id = (
1 if _next_draft_id >= _TELEGRAM_DRAFT_ID_MAX else _next_draft_id + 1
)
return _next_draft_id


class TelegramPlatformEvent(AstrMessageEvent):
# Telegram 的最大消息长度限制
Expand Down Expand Up @@ -339,6 +352,44 @@ async def react(self, emoji: str | None, big: bool = False) -> None:
except Exception as e:
logger.error(f"[Telegram] 添加反应失败: {e}")

async def _send_message_draft(
    self,
    chat_id: str,
    draft_id: int,
    text: str,
    message_thread_id: str | None = None,
    parse_mode: str | None = None,
) -> None:
    """Push a draft message via ``Bot.send_message_draft`` (streamed partial output).

    This API is only available in private chats.

    Args:
        chat_id: Target private chat id.
        draft_id: Non-zero draft identifier; updates sharing a draft_id are
            rendered as an animated edit on the client.
        text: Message text, 1-4096 characters.
        message_thread_id: Optional target message thread id.
        parse_mode: Optional parse mode for the message text.
    """
    # Collect only the optional arguments the caller actually supplied.
    optional_args: dict[str, Any] = {}
    if message_thread_id:
        optional_args["message_thread_id"] = int(message_thread_id)
    if parse_mode:
        optional_args["parse_mode"] = parse_mode

    try:
        logger.debug(
            f"[Telegram] sendMessageDraft: chat_id={chat_id}, draft_id={draft_id}, text_len={len(text)}"
        )
        await self.client.send_message_draft(
            chat_id=int(chat_id),
            draft_id=draft_id,
            text=text,
            **optional_args,
        )
    except Exception as e:
        # Draft delivery is best-effort; failure must not abort streaming.
        logger.warning(f"[Telegram] sendMessageDraft 失败: {e!s}")

async def send_streaming(self, generator, use_fallback: bool = False):
message_thread_id = None

Expand All @@ -356,6 +407,179 @@ async def send_streaming(self, generator, use_fallback: bool = False):
if message_thread_id:
payload["message_thread_id"] = message_thread_id

# sendMessageDraft 仅支持私聊
is_private = self.get_message_type() != MessageType.GROUP_MESSAGE

if is_private:
logger.info("[Telegram] 流式输出: 使用 sendMessageDraft (私聊)")
await self._send_streaming_draft(
user_name, message_thread_id, payload, generator
)
else:
logger.info("[Telegram] 流式输出: 使用 edit_message_text fallback (群聊)")
await self._send_streaming_edit(
user_name, message_thread_id, payload, generator
)

return await super().send_streaming(generator, use_fallback)

async def _send_streaming_draft(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): 请考虑抽取共享的 MessageChain 条目处理逻辑、简化草稿发送循环并本地化草稿 ID 分配,以减少重复,让流式逻辑更容易理解和维护。

你可以在不改变行为的前提下降低复杂度和重复度,方式包括:

1. 抽取共享的 MessageChain 处理

内部的 for i in chain.chain 循环在 _send_streaming_draft 和 _send_streaming_edit 之间基本是重复的。你可以把它提取到一个小的辅助函数中,用于追加文本和发送媒体;调用方只需要传入如何累积文本以及任何额外上下文。

async def _process_message_chain_items(
    self,
    chain: MessageChain,
    payload: dict[str, Any],
    user_name: str,
    message_thread_id: str | None,
    append_text: Callable[[str], None],
) -> None:
    for i in chain.chain:
        if isinstance(i, Plain):
            append_text(i.text)
        elif isinstance(i, Image):
            image_path = await i.convert_to_file_path()
            await self._send_media_with_action(
                self.client,
                ChatAction.UPLOAD_PHOTO,
                self.client.send_photo,
                user_name=user_name,
                photo=image_path,
                **cast(Any, payload),
            )
        elif isinstance(i, File):
            path = await i.get_file()
            name = i.name or os.path.basename(path)
            await self._send_media_with_action(
                self.client,
                ChatAction.UPLOAD_DOCUMENT,
                self.client.send_document,
                user_name=user_name,
                document=path,
                filename=name,
                **cast(Any, payload),
            )
        elif isinstance(i, Record):
            path = await i.convert_to_file_path()
            await self._send_voice_with_fallback(
                self.client,
                path,
                payload,
                caption=i.text or None,
                user_name=user_name,
                message_thread_id=message_thread_id,
                use_media_action=True,
            )
        elif isinstance(i, Video):
            path = await i.convert_to_file_path()
            await self._send_media_with_action(
                self.client,
                ChatAction.UPLOAD_VIDEO,
                self.client.send_video,
                user_name=user_name,
                video=path,
                **cast(Any, payload),
            )
        else:
            logger.warning(f"不支持的消息类型: {type(i)}")

然后在两个流式函数中:

# in _send_streaming_draft
async for chain in generator:
    if isinstance(chain, MessageChain):
        if chain.type == "break":
            # existing break handling...
            ...
            continue

        await self._process_message_chain_items(
            chain,
            payload,
            user_name,
            message_thread_id,
            append_text=lambda t: delta.__iadd__(t),  # or a small wrapper
        )
# in _send_streaming_edit
async for chain in generator:
    if isinstance(chain, MessageChain):
        if chain.type == "break":
            # existing break handling...
            ...
            continue

        await self._process_message_chain_items(
            chain,
            payload,
            user_name,
            message_thread_id,
            append_text=lambda t: delta.__iadd__(t),
        )

这样可以保持所有行为不变,但去掉大量重复的 Plain/Image/File/Record/Video 分支,让两个方法都聚焦于文本是如何流式发送的(草稿 vs 编辑)。

2. 简化草稿发送循环(每个分段无需重启)

你可以在整个 _send_streaming_draft 调用期间保持一个单一的发送循环,避免在每次遇到 break 时取消/重启任务。让循环:

  • 观察一个 current_draft_id
  • 按固定时间间隔发送 delta 中的内容
  • 只在 generator 结束时停止一次

break 时,你只需要发送最终的真实消息、清空 delta 并更新 current_draft_id;循环会自动使用新的状态继续运行。

async def _send_streaming_draft(...):
    draft_id = _allocate_draft_id()
    delta = ""
    last_sent_text = ""
    send_interval = 0.5
    done = False  # generator finished

    async def _draft_sender_loop() -> None:
        nonlocal last_sent_text, draft_id
        while not done:
            await asyncio.sleep(send_interval)
            if delta and delta != last_sent_text:
                draft_text = delta[: self.MAX_MESSAGE_LENGTH]
                if draft_text != last_sent_text:
                    try:
                        await self._send_message_draft(
                            user_name,
                            draft_id,           # always use latest draft_id
                            draft_text,
                            message_thread_id,
                        )
                        last_sent_text = draft_text
                    except Exception:
                        pass

    sender_task = asyncio.create_task(_draft_sender_loop())
    try:
        async for chain in generator:
            if not isinstance(chain, MessageChain):
                continue

            if chain.type == "break":
                # flush current segment as real message
                if delta:
                    await self._send_final_segment(delta, payload)
                # reset state for next segment; loop keeps running
                delta = ""
                last_sent_text = ""
                draft_id = _allocate_draft_id()
                continue

            await self._process_message_chain_items(
                chain,
                payload,
                user_name,
                message_thread_id,
                append_text=lambda t: delta.__iadd__(t),
            )
    finally:
        done = True
        await sender_task

    if delta:
        await self._send_final_segment(delta, payload)

async def _send_final_segment(self, delta: str, payload: dict[str, Any]) -> None:
    try:
        markdown_text = telegramify_markdown.markdownify(
            delta,
            normalize_whitespace=False,
        )
        await self.client.send_message(
            text=markdown_text,
            parse_mode="MarkdownV2",
            **cast(Any, payload),
        )
    except Exception as e:
        logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
        await self.client.send_message(text=delta, **cast(Any, payload))

这保持了“周期性发送最新缓冲区”的语义以及每个分段的 draft_id 更新,但去掉了:

  • streaming_done 标记的切换
  • 在每次 break 上的任务取消/重建
  • 重复的“发送最终 markdown vs 纯文本”逻辑(移到了 _send_final_segment 中)。

3. 本地化草稿 ID 状态

如果你的对象模型允许,可以通过把计数器挂到类/实例上来避免模块级的 global 计数器(同时保留回绕行为):

class TelegramPlatformEvent(AstrMessageEvent):
    _TELEGRAM_DRAFT_ID_MAX = 2_147_483_647
    _next_draft_id: int = 0  # class-level or move to __init__ as self._next_draft_id

    @classmethod
    def _allocate_draft_id(cls) -> int:
        cls._next_draft_id = (
            1 if cls._next_draft_id >= cls._TELEGRAM_DRAFT_ID_MAX else cls._next_draft_id + 1
        )
        return cls._next_draft_id

然后在 _send_streaming_draft 中:

draft_id = self._allocate_draft_id()
...
draft_id = self._allocate_draft_id()

这样就移除了 global,并让草稿 ID 的变化更容易推理(也更容易在测试中覆盖/模拟),同时保持相同的整数行为。

Original comment in English

issue (complexity): Consider extracting the shared MessageChain item handling, simplifying the draft sender loop, and localizing draft ID allocation to reduce duplication and make the streaming logic easier to follow and maintain.

You can reduce complexity and duplication without changing behavior by:

1. Extracting shared MessageChain processing

The inner for i in chain.chain loop is essentially duplicated between _send_streaming_draft and _send_streaming_edit. You can factor it out into a small helper that appends text and sends media; the caller just passes how to accumulate text and any extra context.

async def _process_message_chain_items(
    self,
    chain: MessageChain,
    payload: dict[str, Any],
    user_name: str,
    message_thread_id: str | None,
    append_text: Callable[[str], None],
) -> None:
    for i in chain.chain:
        if isinstance(i, Plain):
            append_text(i.text)
        elif isinstance(i, Image):
            image_path = await i.convert_to_file_path()
            await self._send_media_with_action(
                self.client,
                ChatAction.UPLOAD_PHOTO,
                self.client.send_photo,
                user_name=user_name,
                photo=image_path,
                **cast(Any, payload),
            )
        elif isinstance(i, File):
            path = await i.get_file()
            name = i.name or os.path.basename(path)
            await self._send_media_with_action(
                self.client,
                ChatAction.UPLOAD_DOCUMENT,
                self.client.send_document,
                user_name=user_name,
                document=path,
                filename=name,
                **cast(Any, payload),
            )
        elif isinstance(i, Record):
            path = await i.convert_to_file_path()
            await self._send_voice_with_fallback(
                self.client,
                path,
                payload,
                caption=i.text or None,
                user_name=user_name,
                message_thread_id=message_thread_id,
                use_media_action=True,
            )
        elif isinstance(i, Video):
            path = await i.convert_to_file_path()
            await self._send_media_with_action(
                self.client,
                ChatAction.UPLOAD_VIDEO,
                self.client.send_video,
                user_name=user_name,
                video=path,
                **cast(Any, payload),
            )
        else:
            logger.warning(f"不支持的消息类型: {type(i)}")

Then in both streaming functions:

# in _send_streaming_draft
async for chain in generator:
    if isinstance(chain, MessageChain):
        if chain.type == "break":
            # existing break handling...
            ...
            continue

        await self._process_message_chain_items(
            chain,
            payload,
            user_name,
            message_thread_id,
            append_text=lambda t: delta.__iadd__(t),  # or a small wrapper
        )
# in _send_streaming_edit
async for chain in generator:
    if isinstance(chain, MessageChain):
        if chain.type == "break":
            # existing break handling...
            ...
            continue

        await self._process_message_chain_items(
            chain,
            payload,
            user_name,
            message_thread_id,
            append_text=lambda t: delta.__iadd__(t),
        )

This keeps all behavior but removes the large duplicated Plain/Image/File/Record/Video branches, making both methods focused on how text is streamed (draft vs edit).

2. Simplifying the draft sender loop (no restart per segment)

You can keep a single sender loop for the entire _send_streaming_draft call and avoid cancelling/restarting the task on each break. Let the loop:

  • watch a current_draft_id
  • send whatever delta contains at fixed intervals
  • stop only once at the end of the generator

On break, you only need to send the final real message, clear delta, and update current_draft_id; the loop picks up the new state automatically.

async def _send_streaming_draft(...):
    draft_id = _allocate_draft_id()
    delta = ""
    last_sent_text = ""
    send_interval = 0.5
    done = False  # generator finished

    async def _draft_sender_loop() -> None:
        nonlocal last_sent_text, draft_id
        while not done:
            await asyncio.sleep(send_interval)
            if delta and delta != last_sent_text:
                draft_text = delta[: self.MAX_MESSAGE_LENGTH]
                if draft_text != last_sent_text:
                    try:
                        await self._send_message_draft(
                            user_name,
                            draft_id,           # always use latest draft_id
                            draft_text,
                            message_thread_id,
                        )
                        last_sent_text = draft_text
                    except Exception:
                        pass

    sender_task = asyncio.create_task(_draft_sender_loop())
    try:
        async for chain in generator:
            if not isinstance(chain, MessageChain):
                continue

            if chain.type == "break":
                # flush current segment as real message
                if delta:
                    await self._send_final_segment(delta, payload)
                # reset state for next segment; loop keeps running
                delta = ""
                last_sent_text = ""
                draft_id = _allocate_draft_id()
                continue

            await self._process_message_chain_items(
                chain,
                payload,
                user_name,
                message_thread_id,
                append_text=lambda t: delta.__iadd__(t),
            )
    finally:
        done = True
        await sender_task

    if delta:
        await self._send_final_segment(delta, payload)

async def _send_final_segment(self, delta: str, payload: dict[str, Any]) -> None:
    try:
        markdown_text = telegramify_markdown.markdownify(
            delta,
            normalize_whitespace=False,
        )
        await self.client.send_message(
            text=markdown_text,
            parse_mode="MarkdownV2",
            **cast(Any, payload),
        )
    except Exception as e:
        logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
        await self.client.send_message(text=delta, **cast(Any, payload))

This keeps the “periodically send latest buffer” semantics and segment-by-segment draft_id updates, but removes:

  • streaming_done toggling
  • task cancellation / recreation on each break
  • duplicated “send final markdown vs plain text” logic (moved to _send_final_segment).

3. Localizing draft ID state

If possible in your object model, you can avoid the module‑level global counter by attaching it to the class/instance (still preserves wraparound behavior):

class TelegramPlatformEvent(AstrMessageEvent):
    _TELEGRAM_DRAFT_ID_MAX = 2_147_483_647
    _next_draft_id: int = 0  # class-level or move to __init__ as self._next_draft_id

    @classmethod
    def _allocate_draft_id(cls) -> int:
        cls._next_draft_id = (
            1 if cls._next_draft_id >= cls._TELEGRAM_DRAFT_ID_MAX else cls._next_draft_id + 1
        )
        return cls._next_draft_id

Then in _send_streaming_draft:

draft_id = self._allocate_draft_id()
...
draft_id = self._allocate_draft_id()

This removes global and makes the draft ID evolution easier to reason about (and to override/mock in tests) while keeping the same integer behavior.

self,
user_name: str,
message_thread_id: str | None,
payload: dict[str, Any],
generator,
) -> None:
"""使用 sendMessageDraft API 进行流式推送(私聊专用)。

流式过程中使用 sendMessageDraft 推送草稿动画,
流式结束后发送一条真实消息保留最终内容(draft 是临时的,会消失)。
使用独立的异步发送循环,按固定间隔发送最新缓冲区内容,
完全解耦 token 到达速度与 API 网络延迟。
"""
draft_id = _allocate_draft_id()
delta = ""
last_sent_text = ""
send_interval = 0.5 # 独立发送循环间隔 (秒)
streaming_done = False # 信号:生成器已结束

async def _draft_sender_loop() -> None:
"""独立的草稿发送循环,按固定间隔发送最新内容。"""
nonlocal last_sent_text
while not streaming_done:
await asyncio.sleep(send_interval)
if delta and delta != last_sent_text:
draft_text = delta[: self.MAX_MESSAGE_LENGTH]
if draft_text != last_sent_text:
try:
await self._send_message_draft(
user_name,
draft_id,
draft_text,
message_thread_id,
)
last_sent_text = draft_text
except Exception:
pass # 草稿发送失败不影响流式
Comment on lines +462 to +463
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

虽然注释说明了草稿发送失败不应影响流式传输,但完全静默地忽略异常 (pass) 可能会隐藏持续存在的问题(例如,认证失败或网络问题)。建议在此处添加一个 debug 级别的日志记录,以便在需要时可以排查问题,同时又不会在正常情况下产生过多日志。

Suggested change
except Exception:
pass # 草稿发送失败不影响流式
except Exception as e:
logger.debug(f"[Telegram] sendMessageDraft failed in loop (ignored): {e!s}")


# 启动独立发送循环
sender_task = asyncio.create_task(_draft_sender_loop())

try:
async for chain in generator:
if isinstance(chain, MessageChain):
if chain.type == "break":
# 分割符:停止发送循环,发送真实消息,重置状态
streaming_done = True
await sender_task
if delta:
try:
markdown_text = telegramify_markdown.markdownify(
delta,
normalize_whitespace=False,
)
await self.client.send_message(
text=markdown_text,
parse_mode="MarkdownV2",
**cast(Any, payload),
)
except Exception as e:
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
await self.client.send_message(
text=delta, **cast(Any, payload)
)
# 重置并启动新的发送循环
delta = ""
last_sent_text = ""
draft_id = _allocate_draft_id()
streaming_done = False
sender_task = asyncio.create_task(_draft_sender_loop())
continue

# 处理消息链中的每个组件
for i in chain.chain:
if isinstance(i, Plain):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_PHOTO,
self.client.send_photo,
user_name=user_name,
photo=image_path,
**cast(Any, payload),
)
continue
elif isinstance(i, File):
path = await i.get_file()
name = i.name or os.path.basename(path)
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_DOCUMENT,
self.client.send_document,
user_name=user_name,
document=path,
filename=name,
**cast(Any, payload),
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self._send_voice_with_fallback(
self.client,
path,
payload,
caption=i.text or delta or None,
user_name=user_name,
message_thread_id=message_thread_id,
use_media_action=True,
)
continue
elif isinstance(i, Video):
path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_VIDEO,
self.client.send_video,
user_name=user_name,
video=path,
**cast(Any, payload),
)
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
continue
finally:
# 停止发送循环
streaming_done = True
if not sender_task.done():
await sender_task

# 流式结束:发送真实消息保留最终内容
if delta:
try:
markdown_text = telegramify_markdown.markdownify(
delta,
normalize_whitespace=False,
)
await self.client.send_message(
text=markdown_text,
parse_mode="MarkdownV2",
**cast(Any, payload),
)
except Exception as e:
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
await self.client.send_message(text=delta, **cast(Any, payload))
Comment on lines +560 to +573
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这部分发送最终消息的逻辑与前面处理 break 分割符时(475-490行)的代码几乎完全相同。为了提高代码的可维护性并减少重复,建议将这部分逻辑提取到一个独立的内部辅助函数中。


async def _send_streaming_edit(
self,
user_name: str,
message_thread_id: str | None,
payload: dict[str, Any],
generator,
) -> None:
"""使用 send_message + edit_message_text 进行流式推送(群聊 fallback)。"""
delta = ""
current_content = ""
message_id = None
Expand Down Expand Up @@ -506,5 +730,3 @@ async def send_streaming(self, generator, use_fallback: bool = False):
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")

return await super().send_streaming(generator, use_fallback)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies = [
"pydantic>=2.12.5",
"pydub>=0.25.1",
"pyjwt>=2.10.1",
"python-telegram-bot>=22.0",
"python-telegram-bot>=22.6",
"qq-botpy>=1.2.1",
"quart>=0.20.0",
"readability-lxml>=0.8.4.1",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ py-cord>=2.6.1
pydantic>=2.12.5
pydub>=0.25.1
pyjwt>=2.10.1
python-telegram-bot>=22.0
python-telegram-bot>=22.6
qq-botpy>=1.2.1
quart>=0.20.0
readability-lxml>=0.8.4.1
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/mocks/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def create_bot():
bot.set_my_commands = AsyncMock()
bot.set_message_reaction = AsyncMock()
bot.edit_message_text = AsyncMock()
bot.send_message_draft = AsyncMock()
return bot

@staticmethod
Expand Down