diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 3ba68fa898..cbadb5c18f 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -449,6 +449,20 @@ class ChatProviderTemplate(TypedDict): "satori_heartbeat_interval": 10, "satori_reconnect_delay": 5, }, + "kook": { + "id": "kook", + "type": "kook", + "enable": False, + "kook_bot_token": "", + "kook_bot_nickname": "", + "kook_reconnect_delay": 1, + "kook_max_reconnect_delay": 60, + "kook_max_retry_delay": 60, + "kook_heartbeat_interval": 30, + "kook_heartbeat_timeout": 6, + "kook_max_heartbeat_failures": 3, + "kook_max_consecutive_failures": 5, + }, # "WebChat": { # "id": "webchat", # "type": "webchat", @@ -790,6 +804,51 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。", }, + "kook_bot_token": { + "description": "机器人 Token", + "type": "string", + "hint": "必填项。从 KOOK 开发者平台获取的机器人 Token。", + }, + "kook_bot_nickname": { + "description": "Bot Nickname", + "type": "string", + "hint": "可选项。若发送者昵称与此值一致,将忽略该消息以避免广播风暴。", + }, + "kook_reconnect_delay": { + "description": "重连延迟", + "type": "int", + "hint": "重连延迟时间(秒),使用指数退避策略。", + }, + "kook_max_reconnect_delay": { + "description": "最大重连延迟", + "type": "int", + "hint": "重连延迟的最大值(秒)。", + }, + "kook_max_retry_delay": { + "description": "最大重试延迟", + "type": "int", + "hint": "重试的最大延迟时间(秒)。", + }, + "kook_heartbeat_interval": { + "description": "心跳间隔", + "type": "int", + "hint": "心跳检测间隔时间(秒)。", + }, + "kook_heartbeat_timeout": { + "description": "心跳超时时间", + "type": "int", + "hint": "心跳检测超时时间(秒)。", + }, + "kook_max_heartbeat_failures": { + "description": "最大心跳失败次数", + "type": "int", + "hint": "允许的最大心跳失败次数,超过后断开连接。", + }, + "kook_max_consecutive_failures": { + "description": "最大连续失败次数", + "type": "int", + "hint": "允许的最大连续失败次数,超过后停止重试。", + }, }, }, "platform_settings": { diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 0238779dad..68737b2bcf 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -180,6 +180,10 @@ async def load_platform(self, platform_config: dict) -> None: from .sources.line.line_adapter import ( LinePlatformAdapter, # noqa: F401 ) + case "kook": + from .sources.kook.kook_adapter import ( + KookPlatformAdapter, # noqa: F401 + ) except (ImportError, ModuleNotFoundError) as e: logger.error( f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。", diff --git a/astrbot/core/platform/sources/kook/kook_adapter.py b/astrbot/core/platform/sources/kook/kook_adapter.py new file mode 100644 index 0000000000..1124c6841d --- /dev/null +++ b/astrbot/core/platform/sources/kook/kook_adapter.py @@ -0,0 +1,371 @@ +import asyncio +import json +import re + +from astrbot import logger +from astrbot.api.event import MessageChain +from astrbot.api.message_components import At, AtAll, Image, Plain +from astrbot.api.platform import ( + AstrBotMessage, + MessageMember, + MessageType, + Platform, + PlatformMetadata, + register_platform_adapter, +) +from astrbot.core.platform.astr_message_event import MessageSesion + +from .kook_client import KookClient +from .kook_config import KookConfig +from .kook_event import KookEvent + + +@register_platform_adapter( + "kook", + "KOOK 适配器", +) +class KookPlatformAdapter(Platform): + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: + super().__init__(platform_config, event_queue) + self.kook_config = KookConfig.from_dict(platform_config) + logger.debug(f"[KOOK] 配置: {self.kook_config.pretty_jsons()}") + self.settings = platform_settings + self.client = KookClient(self.kook_config, self._on_received) + self._reconnect_task = None + self.running = False + self._main_task = None + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): + inner_message = AstrBotMessage() + inner_message.session_id = session.session_id + inner_message.type = session.message_type + message_event = KookEvent( + message_str=message_chain.get_plain_text(), + message_obj=inner_message, + platform_meta=self.meta(), + session_id=session.session_id, + client=self.client, + ) + await message_event.send(message_chain) + + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + name="kook", description="KOOK 适配器", id=self.kook_config.id + ) + + def _should_ignore_event_by_bot_nickname(self, payload: dict) -> bool: + bot_nickname = self.kook_config.bot_nickname.strip() + if not bot_nickname: + return False + + author = payload.get("extra", {}).get("author", {}) + if not isinstance(author, dict): + return False + + author_nickname = author.get("nickname") or author.get("username") or "" + if not isinstance(author_nickname, str): + author_nickname = str(author_nickname) + + return author_nickname.strip().casefold() == bot_nickname.casefold() + + async def _on_received(self, data: dict): + logger.debug(f"KOOK 收到数据: {data}") + if "d" in data and data["s"] == 0: + payload = data["d"] + event_type = payload.get("type") + # 支持type=9(文本)和type=10(卡片) + if event_type in (9, 10): + if self._should_ignore_event_by_bot_nickname(payload): + return + try: + abm = await self.convert_message(payload) + await self.handle_msg(abm) + except Exception as e: + logger.error(f"[KOOK] 消息处理异常: {e}") + + async def run(self): + """主运行循环""" + self.running = True + logger.info("[KOOK] 启动KOOK适配器") + + # 启动主循环 + self._main_task = asyncio.create_task(self._main_loop()) + + try: + await self._main_task + except asyncio.CancelledError: + logger.info("[KOOK] 适配器被取消") + except Exception as e: + logger.error(f"[KOOK] 适配器运行异常: {e}") + finally: + self.running = False + await self._cleanup() + + async def _main_loop(self): + """主循环,处理连接和重连""" + consecutive_failures = 0 + max_consecutive_failures = self.kook_config.max_consecutive_failures + max_retry_delay = self.kook_config.max_retry_delay + + while self.running: + try: + logger.info("[KOOK] 尝试连接KOOK服务器...") + + # 尝试连接 + success = await self.client.connect() + + if success: + logger.info("[KOOK] 连接成功,开始监听消息") + consecutive_failures = 0 # 重置失败计数 + + # 等待连接结束(可能是正常关闭或异常) + while self.client.running and self.running: + try: + # 等待 client 内部触发 _stop_event,或者超时 1 秒后重试 + # 使用 wait_for 配合 timeout 是为了防止极端情况下 self.running 变化没被察觉 + await asyncio.wait_for( + self.client.wait_until_closed(), timeout=1.0 + ) + except asyncio.TimeoutError: + # 正常超时,继续下一轮 while 检查 + continue + + if self.running: + logger.warning("[KOOK] 连接断开,准备重连") + + else: + consecutive_failures += 1 + logger.error( + f"[KOOK] 连接失败,连续失败次数: {consecutive_failures}" + ) + + if consecutive_failures >= max_consecutive_failures: + logger.error("[KOOK] 连续失败次数过多,停止重连") + break + + # 等待一段时间后重试 + wait_time = min( + 2**consecutive_failures, max_retry_delay + ) # 指数退避 + logger.info(f"[KOOK] 等待 {wait_time} 秒后重试...") + await asyncio.sleep(wait_time) + + except Exception as e: + consecutive_failures += 1 + logger.error(f"[KOOK] 主循环异常: {e}") + + if consecutive_failures >= max_consecutive_failures: + logger.error("[KOOK] 连续异常次数过多,停止重连") + break + + await asyncio.sleep(5) + + async def _cleanup(self): + """清理资源""" + logger.info("[KOOK] 开始清理资源") + + if self.client: + try: + await self.client.close() + except Exception as e: + logger.error(f"[KOOK] 关闭客户端异常: {e}") + + if self._main_task and not self._main_task.done(): + self._main_task.cancel() + try: + await self._main_task + except asyncio.CancelledError: + pass + + logger.info("[KOOK] 资源清理完成") + + def _parse_kmarkdown_text_message( + self, data: dict, self_id: str + ) -> tuple[list, str]: + kmarkdown = data.get("extra", {}).get("kmarkdown", {}) + content = data.get("content") or "" + raw_content = kmarkdown.get("raw_content") or content + if not isinstance(content, str): + content = str(content) + if not isinstance(raw_content, str): + raw_content = str(raw_content) + + mention_name_map: dict[str, str] = {} + mention_part = kmarkdown.get("mention_part", []) + if isinstance(mention_part, list): + for item in mention_part: + if not isinstance(item, dict): + continue + mention_id = item.get("id") + if mention_id is None: + continue + mention_name_map[str(mention_id)] = str(item.get("username", "")) + + components = [] + cursor = 0 + for match in re.finditer(r"\(met\)([^()]+)\(met\)", content): + if match.start() > cursor: + plain_text = content[cursor : match.start()] + if plain_text: + components.append(Plain(text=plain_text)) + + mention_target = match.group(1).strip() + if mention_target == "all": + components.append(AtAll()) + elif mention_target: + components.append( + At( + qq=mention_target, + name=mention_name_map.get(mention_target, ""), + ) + ) + cursor = match.end() + + if cursor < len(content): + tail_text = content[cursor:] + if tail_text: + components.append(Plain(text=tail_text)) + + message_str = raw_content + if components: + for comp in components: + if isinstance(comp, Plain): + if not comp.text.strip(): + continue + break + if isinstance(comp, At): + if str(comp.qq) == str(self_id): + message_str = re.sub( + r"^@[^\s]+(\s*-\s*[^\s]+)?\s*", + "", + message_str, + count=1, + ).strip() + break + if not components: + if message_str: + components = [Plain(text=message_str)] + else: + components = [] + + return components, message_str + + def _parse_card_message(self, data: dict) -> tuple[list, str]: + content = data.get("content", "[]") + if not isinstance(content, str): + content = str(content) + card_list = json.loads(content) + + text_parts: list[str] = [] + images: list[str] = [] + + for card in card_list: + if not isinstance(card, dict): + continue + for module in card.get("modules", []): + if not isinstance(module, dict): + continue + + module_type = module.get("type") + if module_type == "section": + section_text = module.get("text", {}).get("content", "") + if section_text: + text_parts.append(str(section_text)) + continue + + if module_type != "container": + continue + + for element in module.get("elements", []): + if not isinstance(element, dict): + continue + if element.get("type") != "image": + continue + + image_src = element.get("src") + if not isinstance(image_src, str): + logger.warning( + f'[KOOK] 处理卡片中的图片时发生错误,图片url "{image_src}" 应该为str类型, 而不是 "{type(image_src)}" ' + ) + continue + if not image_src.startswith(("http://", "https://")): + logger.warning(f"[KOOK] 屏蔽非http图片url: {image_src}") + continue + images.append(image_src) + + text = "".join(text_parts) + message = [] + if text: + message.append(Plain(text=text)) + for img_url in images: + message.append(Image(file=img_url)) + return message, text + + async def convert_message(self, data: dict) -> AstrBotMessage: + abm = AstrBotMessage() + abm.raw_message = data + abm.self_id = self.client.bot_id + + channel_type = data.get("channel_type") + author_id = data.get("author_id", "unknown") + # channel_type定义: https://developer.kookapp.cn/doc/event/event-introduction + match channel_type: + case "GROUP": + session_id = data.get("target_id") or "unknown" + abm.type = MessageType.GROUP_MESSAGE + abm.group_id = session_id + abm.session_id = session_id + case "PERSON": + abm.type = MessageType.FRIEND_MESSAGE + abm.group_id = "" + abm.session_id = data.get("author_id", "unknown") + case "BROADCAST": + session_id = data.get("target_id") or "unknown" + abm.type = MessageType.OTHER_MESSAGE + abm.group_id = session_id + abm.session_id = session_id + case _: + raise ValueError(f"不支持的频道类型: {channel_type}") + + abm.sender = MessageMember( + user_id=author_id, + nickname=data.get("extra", {}).get("author", {}).get("username", ""), + ) + + abm.message_id = data.get("msg_id", "unknown") + + # 普通文本消息 + if data.get("type") == 9: + message, message_str = self._parse_kmarkdown_text_message( + data, str(abm.self_id) + ) + abm.message = message + abm.message_str = message_str + # 卡片消息 + elif data.get("type") == 10: + try: + abm.message, abm.message_str = self._parse_card_message(data) + except Exception as exp: + logger.error(f"[KOOK] 卡片消息解析失败: {exp}") + abm.message_str = "[卡片消息解析失败]" + abm.message = [Plain(text="[卡片消息解析失败]")] + else: + logger.warning(f'[KOOK] 不支持的kook消息类型: "{data.get("type")}"') + abm.message_str = "[不支持的消息类型]" + abm.message = [Plain(text="[不支持的消息类型]")] + + return abm + + async def handle_msg(self, message: AstrBotMessage): + message_event = KookEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + client=self.client, + ) + self.commit_event(message_event) diff --git a/astrbot/core/platform/sources/kook/kook_client.py b/astrbot/core/platform/sources/kook/kook_client.py new file mode 100644 index 0000000000..a48a6fb658 --- /dev/null +++ b/astrbot/core/platform/sources/kook/kook_client.py @@ -0,0 +1,440 @@ +import asyncio +import base64 +import json +import os +import random +import time +import zlib +from pathlib import Path + +import aiofiles +import aiohttp +import websockets + +from astrbot import logger +from astrbot.core.platform.message_type import MessageType + +from .kook_config import KookConfig +from .kook_types import KookApiPaths, KookMessageType + + +class KookClient: + def __init__(self, config: KookConfig, event_callback): + # 数据字段 + self.config = config + self._bot_id = "" + self._bot_name = "" + + # 资源字段 + self._http_client = aiohttp.ClientSession( + headers={ + "Authorization": f"Bot {self.config.token}", + } + ) + self.event_callback = event_callback # 回调函数,用于处理接收到的事件 + self.ws = None + self.heartbeat_task = None + self._stop_event = asyncio.Event() # 用于通知连接结束 + + # 状态/计算字段 + self.running = False + self.session_id = None + self.last_sn = 0 # 记录最后处理的消息序号 + self.last_heartbeat_time = 0 + self.heartbeat_failed_count = 0 + + @property + def bot_id(self): + return self._bot_id + + @property + def bot_name(self): + return self._bot_name + + async def get_bot_info(self) -> str: + """获取机器人账号ID""" + url = KookApiPaths.USER_ME + + try: + async with self._http_client.get(url) as resp: + if resp.status != 200: + logger.error(f"[KOOK] 获取机器人账号ID失败,状态码: {resp.status}") + return "" + + data = await resp.json() + if data.get("code") != 0: + logger.error(f"[KOOK] 获取机器人账号ID失败: {data}") + return "" + + bot_id: str = data["data"]["id"] + self._bot_id = bot_id + logger.info(f"[KOOK] 获取机器人账号ID成功: {bot_id}") + bot_name: str = data["data"]["nickname"] or data["data"]["username"] + self._bot_name = bot_name + logger.info(f"[KOOK] 获取机器人名称成功: {self._bot_name}") + + return bot_id + except Exception as e: + logger.error(f"[KOOK] 获取机器人账号ID异常: {e}") + return "" + + async def get_gateway_url(self, resume=False, sn=0, session_id=None): + """获取网关连接地址""" + url = KookApiPaths.GATEWAY_INDEX + + # 构建连接参数 + params = {} + if resume: + params["resume"] = 1 + params["sn"] = sn + if session_id: + params["session_id"] = session_id + + try: + async with self._http_client.get(url, params=params) as resp: + if resp.status != 200: + logger.error(f"[KOOK] 获取gateway失败,状态码: {resp.status}") + return None + + data = await resp.json() + if data.get("code") != 0: + logger.error(f"[KOOK] 获取gateway失败: {data}") + return None + + gateway_url: str = data["data"]["url"] + logger.info(f"[KOOK] 获取gateway成功: {gateway_url.split('?')[0]}") + return gateway_url + except Exception as e: + logger.error(f"[KOOK] 获取gateway异常: {e}") + return None + + async def connect(self, resume=False): + """连接WebSocket""" + if self.ws: + try: + await self.ws.close() + except Exception: + pass + self.ws = None + self._stop_event.clear() + try: + # 获取gateway地址 + gateway_url = await self.get_gateway_url( + resume=resume, sn=self.last_sn, session_id=self.session_id + ) + await self.get_bot_info() + + if not gateway_url: + return False + + # 连接WebSocket + self.ws = await websockets.connect(gateway_url) + self.running = True + logger.info("[KOOK] WebSocket 连接成功") + + # 启动心跳任务 + if self.heartbeat_task: + self.heartbeat_task.cancel() + self.heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + # 开始监听消息 + await self.listen() + return True + + except Exception as e: + logger.error(f"[KOOK] WebSocket 连接失败: {e}") + if self.ws: + try: + await self.ws.close() + except Exception: + pass + self.ws = None + return False + + async def listen(self): + """监听WebSocket消息""" + try: + while self.running: + try: + msg = await asyncio.wait_for(self.ws.recv(), timeout=10) # type: ignore + + if isinstance(msg, bytes): + try: + msg = zlib.decompress(msg) + except Exception as e: + logger.error(f"[KOOK] 解压消息失败: {e}") + continue + msg = msg.decode("utf-8") + + logger.debug(f"[KOOK] 收到原始消息: {msg}") + data = json.loads(msg) + + # 处理不同类型的信令 + await self._handle_signal(data) + + except asyncio.TimeoutError: + # 超时检查,继续循环 + continue + except websockets.exceptions.ConnectionClosed: + logger.warning("[KOOK] WebSocket连接已关闭") + break + except Exception as e: + logger.error(f"[KOOK] 消息处理异常: {e}") + break + + except Exception as e: + logger.error(f"[KOOK] WebSocket 监听异常: {e}") + finally: + self.running = False + self._stop_event.set() + + async def _handle_signal(self, data): + """处理不同类型的信令""" + signal_type = data.get("s") + + if signal_type == 0: # 事件消息 + # 更新消息序号 + if "sn" in data: + self.last_sn = data["sn"] + await self.event_callback(data) + + elif signal_type == 1: # HELLO握手 + await self._handle_hello(data) + + elif signal_type == 3: # PONG心跳响应 + await self._handle_pong(data) + + elif signal_type == 5: # RECONNECT重连指令 + await self._handle_reconnect(data) + + elif signal_type == 6: # RESUME ACK + await self._handle_resume_ack(data) + + else: + logger.debug(f"[KOOK] 未处理的信令类型: {signal_type}") + + async def _handle_hello(self, data): + """处理HELLO握手""" + hello_data = data.get("d", {}) + code = hello_data.get("code", 0) + + if code == 0: + self.session_id = hello_data.get("session_id") + logger.info(f"[KOOK] 握手成功,session_id: {self.session_id}") + # TODO 重置重连延迟 + # self.reconnect_delay = 1 + else: + logger.error(f"[KOOK] 握手失败,错误码: {code}") + if code == 40103: # token过期 + logger.error("[KOOK] Token已过期,需要重新获取") + self.running = False + + async def _handle_pong(self, data): + """处理PONG心跳响应""" + self.last_heartbeat_time = time.time() + self.heartbeat_failed_count = 0 + logger.debug("[KOOK] 收到心跳响应") + + async def _handle_reconnect(self, data): + """处理重连指令""" + logger.warning("[KOOK] 收到重连指令") + # 清空本地状态 + self.last_sn = 0 + self.session_id = None + self.running = False + + async def _handle_resume_ack(self, data): + """处理RESUME确认""" + resume_data = data.get("d", {}) + self.session_id = resume_data.get("session_id") + logger.info(f"[KOOK] Resume成功,session_id: {self.session_id}") + + async def _heartbeat_loop(self): + """心跳循环""" + while self.running: + try: + # 随机化心跳间隔 (±5秒) + interval = max( + 1, self.config.heartbeat_interval + random.randint(-5, 5) + ) + await asyncio.sleep(interval) + + if not self.running: + break + + # 发送心跳 + await self._send_ping() + + # 等待PONG响应 + await asyncio.sleep(self.config.heartbeat_timeout) + + # 检查是否收到PONG响应 + if ( + time.time() - self.last_heartbeat_time + > self.config.heartbeat_timeout + ): + self.heartbeat_failed_count += 1 + logger.warning( + f"[KOOK] 心跳超时,失败次数: {self.heartbeat_failed_count}" + ) + + if ( + self.heartbeat_failed_count + >= self.config.max_heartbeat_failures + ): + logger.error("[KOOK] 心跳失败次数过多,准备重连") + self.running = False + break + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"[KOOK] 心跳异常: {e}") + self.heartbeat_failed_count += 1 + + async def _send_ping(self): + """发送心跳PING""" + try: + ping_data = {"s": 2, "sn": self.last_sn} + await self.ws.send(json.dumps(ping_data)) # type: ignore + logger.debug(f"[KOOK] 发送心跳,sn: {self.last_sn}") + except Exception as e: + logger.error(f"[KOOK] 发送心跳失败: {e}") + + async def send_text( + self, + target_id: str, + content: str, + astrbot_message_type: MessageType, + kook_message_type: KookMessageType, + reply_message_id: str | int = "", + ): + """发送文本消息 + 消息发送接口文档参见: https://developer.kookapp.cn/doc/http/message#%E5%8F%91%E9%80%81%E9%A2%91%E9%81%93%E8%81%8A%E5%A4%A9%E6%B6%88%E6%81%AF + KMarkdown格式参见: https://developer.kookapp.cn/doc/kmarkdown-desc + """ + url = KookApiPaths.CHANNEL_MESSAGE_CREATE + if astrbot_message_type == MessageType.FRIEND_MESSAGE: + url = KookApiPaths.DIRECT_MESSAGE_CREATE + + payload = { + "target_id": target_id, + "content": content, + "type": kook_message_type, + } + if reply_message_id: + payload["quote"] = reply_message_id + payload["reply_msg_id"] = reply_message_id + + try: + async with self._http_client.post(url, json=payload) as resp: + if resp.status == 200: + result = await resp.json() + if result.get("code") != 0: + raise RuntimeError( + f'发送kook消息类型 "{kook_message_type.name}" 失败: {result}' + ) + # else: + # logger.info("[KOOK] 发送消息成功") + else: + raise RuntimeError( + f'发送kook消息类型 "{kook_message_type.name}" HTTP错误: {resp.status} , 响应内容 : {await resp.text()}' + ) + except RuntimeError: + raise + except Exception as e: + logger.error( + f'[KOOK] 发送kook消息类型 "{kook_message_type.name}" 异常: {e}' + ) + + async def upload_asset(self, file_url: str | None) -> str: + """上传文件到kook,获得远端资源url + 接口定义参见: https://developer.kookapp.cn/doc/http/asset + """ + if not file_url: + return "" + + bytes_data: bytes | None = None + filename = "unknown" + if file_url.startswith(("http://", "https://")): + filename = file_url.split("/")[-1] + return file_url + + if file_url.startswith("base64:///"): + # b64decode的时候得开头留一个'/'的, 不然会报错 + b64_str = file_url.removeprefix("base64://") + bytes_data = base64.b64decode(b64_str) + + elif file_url.startswith("file://") or os.path.exists(file_url): + file_url = file_url.removeprefix("file:///") + file_url = file_url.removeprefix("file://") + + try: + target_path = Path(file_url).resolve() + except Exception as exp: + logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"') + raise FileNotFoundError( + f'获取文件 "{file_url}" 绝对路径失败: "{exp}"' + ) from exp + + if not target_path.is_file(): + raise FileNotFoundError(f"文件不存在: {target_path.name}") + + filename = target_path.name + async with aiofiles.open(target_path, "rb") as f: + bytes_data = await f.read() + + else: + raise ValueError(f'[KOOK] 不支持的文件资源类型: "{file_url}"') + + data = aiohttp.FormData() + data.add_field("file", bytes_data, filename=filename) + + url = KookApiPaths.ASSET_CREATE + try: + async with self._http_client.post(url, data=data) as resp: + if resp.status == 200: + result: dict = await resp.json() + logger.debug(f"[KOOK] 上传文件响应: {result}") + if result.get("code") == 0: + logger.info("[KOOK] 上传文件到kook服务器成功") + remote_url = result["data"]["url"] + logger.debug(f"[KOOK] 文件远端URL: {remote_url}") + return remote_url + else: + raise RuntimeError(f"上传文件到kook服务器失败: {result}") + else: + raise RuntimeError( + f"上传文件到kook服务器 HTTP错误: {resp.status} , {await resp.text()}" + ) + except RuntimeError: + raise + except Exception as e: + raise RuntimeError(f"上传文件到kook服务器异常: {e}") from e + + async def wait_until_closed(self): + """提供给外部调用的等待方法""" + await self._stop_event.wait() + + async def close(self): + """关闭连接""" + self.running = False + self._stop_event.set() + + if self.heartbeat_task: + self.heartbeat_task.cancel() + try: + await self.heartbeat_task + except asyncio.CancelledError: + pass + + if self.ws: + try: + await self.ws.close() + except Exception as e: + logger.error(f"[KOOK] 关闭WebSocket异常: {e}") + + if self._http_client: + await self._http_client.close() + + logger.info("[KOOK] 连接已关闭") diff --git a/astrbot/core/platform/sources/kook/kook_config.py b/astrbot/core/platform/sources/kook/kook_config.py new file mode 100644 index 0000000000..21f2547b03 --- /dev/null +++ b/astrbot/core/platform/sources/kook/kook_config.py @@ -0,0 +1,133 @@ +import json +from dataclasses import asdict, dataclass +from typing import Any + + +@dataclass +class KookConfig: + """KOOK 适配器配置类""" + + # 基础配置 + token: str + bot_nickname: str = "" + enable: bool = False + id: str = "kook" + + # 重连配置 + reconnect_delay: int = 1 + """重连延迟基数(秒),指数退避""" + max_reconnect_delay: int = 60 + """最大重连延迟(秒)""" + max_retry_delay: int = 60 + """最大重试延迟(秒)""" + + # 心跳配置 + heartbeat_interval: int = 30 + """心跳间隔(秒)""" + heartbeat_timeout: int = 6 + """心跳超时时间(秒)""" + max_heartbeat_failures: int = 3 + """最大心跳失败次数""" + + # 失败处理 + max_consecutive_failures: int = 5 + """最大连续失败次数""" + + @classmethod + def from_dict(cls, config_dict: dict) -> "KookConfig": + """从字典创建配置对象""" + return cls( + # 适配器id 应该是不能改的 + # id=config_dict.get("id", "kook"), + enable=config_dict.get("enable", False), + token=config_dict.get("kook_bot_token", ""), + bot_nickname=config_dict.get("kook_bot_nickname", ""), + reconnect_delay=config_dict.get( + "kook_reconnect_delay", + KookConfig.reconnect_delay, + ), + max_reconnect_delay=config_dict.get( + "kook_max_reconnect_delay", + KookConfig.max_reconnect_delay, + ), + max_retry_delay=config_dict.get( + "kook_max_retry_delay", + KookConfig.max_retry_delay, + ), + heartbeat_interval=config_dict.get( + "kook_heartbeat_interval", + KookConfig.heartbeat_interval, + ), + heartbeat_timeout=config_dict.get( + "kook_heartbeat_timeout", + KookConfig.heartbeat_timeout, + ), + max_heartbeat_failures=config_dict.get( + "kook_max_heartbeat_failures", + KookConfig.max_heartbeat_failures, + ), + max_consecutive_failures=config_dict.get( + "kook_max_consecutive_failures", + KookConfig.max_consecutive_failures, + ), + ) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + def pretty_jsons(self, indent=2) -> str: + dict_config = self.to_dict() + dict_config["token"] = "*" * len(self.token) if self.token else "MISSING" + return json.dumps(dict_config, indent=indent, ensure_ascii=False) + + +# TODO 没用上的config配置,未来有空会实现这些配置描述的功能? +# # 连接配置 +# CONNECTION_CONFIG = { +# # 心跳配置 +# "heartbeat_interval": 30, # 心跳间隔(秒) +# "heartbeat_timeout": 6, # 心跳超时时间(秒) +# "max_heartbeat_failures": 3, # 最大心跳失败次数 +# # 重连配置 +# "initial_reconnect_delay": 1, # 初始重连延迟(秒) +# "max_reconnect_delay": 60, # 最大重连延迟(秒) +# "max_consecutive_failures": 5, # 最大连续失败次数 +# # WebSocket配置 +# "websocket_timeout": 10, # WebSocket接收超时(秒) +# "connection_timeout": 30, # 连接超时(秒) +# # 消息处理配置 +# "enable_compression": True, # 是否启用消息压缩 +# "max_message_size": 1024 * 1024, # 最大消息大小(字节) +# } + +# # 日志配置 +# LOGGING_CONFIG = { +# "level": "INFO", # 日志级别:DEBUG, INFO, WARNING, ERROR +# "format": "[KOOK] %(message)s", +# "enable_heartbeat_logs": False, # 是否启用心跳日志 +# "enable_message_logs": False, # 是否启用消息日志 +# } + +# # 错误处理配置 +# ERROR_HANDLING_CONFIG = { +# "retry_on_network_error": True, # 网络错误时是否重试 +# "retry_on_token_expired": True, # Token过期时是否重试 +# "max_retry_attempts": 3, # 最大重试次数 +# "retry_delay_base": 2, # 重试延迟基数(秒) +# } + +# # 性能配置 +# PERFORMANCE_CONFIG = { +# "enable_message_buffering": True, # 是否启用消息缓冲 +# "buffer_size": 100, # 缓冲区大小 +# "enable_connection_pooling": True, # 是否启用连接池 +# "max_concurrent_requests": 10, # 最大并发请求数 +# } + +# # 安全配置 +# SECURITY_CONFIG = { +# "verify_ssl": True, # 是否验证SSL证书 +# "enable_rate_limiting": True, # 是否启用速率限制 +# "rate_limit_requests": 100, # 速率限制请求数 +# "rate_limit_window": 60, # 速率限制窗口(秒) +# } diff --git a/astrbot/core/platform/sources/kook/kook_event.py b/astrbot/core/platform/sources/kook/kook_event.py new file mode 100644 index 0000000000..12f72a9790 --- /dev/null +++ b/astrbot/core/platform/sources/kook/kook_event.py @@ -0,0 +1,209 @@ +import asyncio +import json +from collections.abc import Coroutine +from pathlib import Path +from typing import Any + +from astrbot import logger +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import AstrBotMessage, PlatformMetadata +from astrbot.core.message.components import ( + At, + AtAll, + BaseMessageComponent, + File, + Image, + Json, + Plain, + Record, + Reply, + Video, +) +from astrbot.core.platform import MessageType + +from .kook_client import KookClient +from .kook_types import ( + FileModule, + KookCardMessage, + KookCardMessageContainer, + KookMessageType, + OrderMessage, +) + + +class KookEvent(AstrMessageEvent): + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + client: KookClient, + ): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.client = client + self.channel_id = message_obj.group_id or message_obj.session_id + self.astrbot_message_type: MessageType = message_obj.type + self._file_message_counter = 0 + + def _wrap_message( + self, index: int, message_component: BaseMessageComponent + ) -> Coroutine[Any, Any, OrderMessage]: + async def wrap_upload( + index: int, message_type: KookMessageType, upload_coro + ) -> OrderMessage: + url = await upload_coro + return OrderMessage(index=index, text=url, type=message_type) + + async def handle_plain( + index: int, + text: str | None, + reply_id: str | int = "", + type: KookMessageType = KookMessageType.KMARKDOWN, + ): + if not text: + text = "" + return OrderMessage( + index=index, + text=text, + type=type, + reply_id=reply_id, + ) + + match message_component: + case Image(): + self._file_message_counter += 1 + return wrap_upload( + index, + KookMessageType.IMAGE, + self.client.upload_asset(message_component.file), + ) + + case Video(): + self._file_message_counter += 1 + return wrap_upload( + index, + KookMessageType.VIDEO, + self.client.upload_asset(message_component.file), + ) + case File(): + + async def handle_file(index: int, f_item: File): + f_data = await f_item.get_file() + url = await self.client.upload_asset(f_data) + return OrderMessage( + index=index, text=url, type=KookMessageType.FILE + ) + + self._file_message_counter += 1 + return handle_file(index, message_component) + + case Record(): + + async def handle_audio(index: int, f_item: Record): + file_path = await f_item.convert_to_file_path() + url = await self.client.upload_asset(file_path) + title = f_item.text or Path(file_path).name + return OrderMessage( + index=index, + text=KookCardMessageContainer( + [ + KookCardMessage( + modules=[ + FileModule( + type="audio", + title=title, + src=url, + ) + ] + ) + ] + ).to_json(), + type=KookMessageType.CARD, + ) + + return handle_audio(index, message_component) + case Plain(): + return handle_plain(index, message_component.text) + case At(): + return handle_plain(index, f"(met){message_component.qq}(met)") + case AtAll(): + return handle_plain(index, "(met)all(met)") + case Reply(): + return handle_plain(index, "", reply_id=message_component.id) + case Json(): + json_data = message_component.data + # kook卡片json外层得是一个列表 + if isinstance(json_data, dict): + json_data = [json_data] + return handle_plain( + index, + # 考虑到kook可能会更改消息结构,为了能让插件开发者 + # 自行根据kook文档描述填卡片json内容,故不做模型校验 + # KookCardMessage().model_validate(message_component.data).to_json(), + text=json.dumps(json_data), + type=KookMessageType.CARD, + ) + case _: + raise NotImplementedError( + f'kook适配器尚未实现对 "{message_component.type}" 消息类型的支持' + ) + + async def send(self, message: MessageChain): + file_upload_tasks: list[Coroutine[Any, Any, OrderMessage]] = [] + for index, item in enumerate(message.chain): + file_upload_tasks.append(self._wrap_message(index, item)) + + if self._file_message_counter > 0: + logger.debug("[Kook] 正在向kook服务器上传文件") + + tasks_result = await asyncio.gather(*file_upload_tasks, return_exceptions=True) + order_messages: list[OrderMessage] = [] + + for index, result in enumerate(tasks_result): + if isinstance(result, BaseException): + logger.error(f"[Kook] {result}") + # 构造一个虚假的 OrderMessage,让用户知道这里本来有张图但坏了 + # 这样后面的 for 循环就能把它当成普通文本发出去 + err_node = OrderMessage( + index=index, + text=str(result), + type=KookMessageType.TEXT, + ) + order_messages.append(err_node) + else: + order_messages.append(result) + + order_messages.sort(key=lambda x: x.index) + + reply_id: str | int = "" + errors: list[Exception] = [] + for item in order_messages: + if item.reply_id: + reply_id = item.reply_id + if not item.text: + logger.debug(f'[Kook] 跳过空消息,类型为"{item.type}"') + continue + try: + await self.client.send_text( + self.channel_id, + item.text, + self.astrbot_message_type, + item.type, + reply_id, + ) + except RuntimeError as exp: + await self.client.send_text( + self.channel_id, + str(exp), + self.astrbot_message_type, + KookMessageType.TEXT, + reply_id, + ) + errors.append(exp) + + if errors: + err_msg = "\n".join([str(err) for err in errors]) + logger.error(f"[kook] {err_msg}") + + await super().send(message) diff --git a/astrbot/core/platform/sources/kook/kook_types.py b/astrbot/core/platform/sources/kook/kook_types.py new file mode 100644 index 0000000000..dd18ac00f1 --- /dev/null +++ b/astrbot/core/platform/sources/kook/kook_types.py @@ -0,0 +1,241 @@ +import json +from dataclasses import field +from enum import IntEnum +from typing import Literal + +from pydantic import BaseModel, ConfigDict +from pydantic.dataclasses import dataclass + + +class KookApiPaths: + """Kook Api 路径""" + + BASE_URL = "https://www.kookapp.cn" + API_VERSION_PATH = "/api/v3" + + # 初始化相关 + USER_ME = f"{BASE_URL}{API_VERSION_PATH}/user/me" + GATEWAY_INDEX = f"{BASE_URL}{API_VERSION_PATH}/gateway/index" + + # 消息相关 + ASSET_CREATE = f"{BASE_URL}{API_VERSION_PATH}/asset/create" + ## 频道消息 + CHANNEL_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/message/create" + ## 私聊消息 + DIRECT_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/direct-message/create" + + +# 定义参见kook事件结构文档: https://developer.kookapp.cn/doc/event/event-introduction +class KookMessageType(IntEnum): + TEXT = 1 + IMAGE = 2 + VIDEO = 3 + FILE = 4 + AUDIO = 8 + KMARKDOWN = 9 + CARD = 10 + SYSTEM = 255 + + +ThemeType = Literal[ + "primary", "success", "danger", "warning", "info", "secondary", "none", "invisible" +] +"""主题,可选的值为:primary, success, danger, warning, info, secondary, none.默认为 primary,为 none 时不显示侧边框。""" +SizeType = Literal["xs", "sm", "md", "lg"] +"""大小,可选值为:xs, sm, md, lg, 一般默认为 lg""" + +SectionMode = Literal["left", "right"] +CountdownMode = Literal["day", "hour", "second"] + + +class KookCardColor(str): + """16 进制色值""" + + +class KookCardModelBase: + """卡片模块基类""" + + type: str + + +@dataclass +class PlainTextElement(KookCardModelBase): + content: str + type: str = "plain-text" + emoji: bool = True + + +@dataclass +class KmarkdownElement(KookCardModelBase): + content: str + type: str = "kmarkdown" + + +@dataclass +class ImageElement(KookCardModelBase): + src: str + type: str = "image" + alt: str = "" + size: SizeType = "lg" + circle: bool = False + fallbackUrl: str | None = None + + +@dataclass +class ButtonElement(KookCardModelBase): + text: str + type: str = "button" + theme: ThemeType = "primary" + value: str = "" + """当为 link 时,会跳转到 value 代表的链接; +当为 return-val 时,系统会通过系统消息将消息 id,点击用户 id 和 value 发回给发送者,发送者可以根据自己的需求进行处理,消息事件参见button 点击事件。私聊和频道内均可使用按钮点击事件。""" + click: Literal["", "link", "return-val"] = "" + """click 代表用户点击的事件,默认为"",代表无任何事件。""" + + +AnyElement = PlainTextElement | KmarkdownElement | ImageElement | ButtonElement | str + + +@dataclass +class ParagraphStructure(KookCardModelBase): + fields: list[PlainTextElement | KmarkdownElement] + type: str = "paragraph" + cols: int = 1 + """范围是 1-3 , 移动端忽略此参数""" + + +@dataclass +class HeaderModule(KookCardModelBase): + text: PlainTextElement + type: str = "header" + + +@dataclass +class SectionModule(KookCardModelBase): + text: PlainTextElement | KmarkdownElement | ParagraphStructure + type: str = "section" + mode: SectionMode = "left" + accessory: ImageElement | ButtonElement | None = None + + +@dataclass +class ImageGroupModule(KookCardModelBase): + """1 到多张图片的组合""" + + elements: list[ImageElement] + type: str = "image-group" + + +@dataclass +class ContainerModule(KookCardModelBase): + """1 到多张图片的组合,与图片组模块(ImageGroupModule)不同,图片并不会裁切为正方形。多张图片会纵向排列。""" + + elements: list[ImageElement] + type: str = "container" + + +@dataclass +class ActionGroupModule(KookCardModelBase): + elements: list[ButtonElement] + type: str = "action-group" + + +@dataclass +class ContextModule(KookCardModelBase): + elements: list[PlainTextElement | KmarkdownElement | ImageElement] + """最多包含10个元素""" + type: str = "context" + + +@dataclass +class DividerModule(KookCardModelBase): + type: str = "divider" + + +@dataclass +class FileModule(KookCardModelBase): + src: str + title: str = "" + type: Literal["file", "audio", "video"] = "file" + cover: str | None = None + """cover 仅音频有效, 是音频的封面图""" + + +@dataclass +class CountdownModule(KookCardModelBase): + """startTime 和 endTime 为毫秒时间戳,startTime 和 endTime 不能小于服务器当前时间戳。""" + + endTime: int + """毫秒时间戳""" + type: str = "countdown" + startTime: int | None = None + """毫秒时间戳, 仅当mode为second才有这个字段""" + mode: CountdownMode = "day" + """mode 主要是倒计时的样式""" + + +@dataclass +class InviteModule(KookCardModelBase): + code: str + """邀请链接或者邀请码""" + type: str = "invite" + + +# 所有模块的联合类型 +AnyModule = ( + HeaderModule + | SectionModule + | ImageGroupModule + | ContainerModule + | ActionGroupModule + | ContextModule + | DividerModule + | FileModule + | CountdownModule + | InviteModule +) + + +class KookCardMessage(BaseModel): + """卡片定义文档详见 : https://developer.kookapp.cn/doc/cardmessage + 此类型不能直接to_json后发送,因为kook要求卡片容器json顶层必须是**列表** + 若要发送卡片消息,请使用KookCardMessageContainer + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + type: str = "card" + theme: ThemeType | None = None + size: SizeType | None = None + color: KookCardColor | None = None + modules: list[AnyModule] = field(default_factory=list) + """单个 card 模块数量不限制,但是一条消息中所有卡片的模块数量之和最多是 50""" + + def add_module(self, module: AnyModule): + self.modules.append(module) + + def to_dict(self, exclude_none: bool = True): + """exclude_none:去掉值为 None 字段,保留结构""" + return self.model_dump(exclude_none=exclude_none) + + def to_json(self, indent: int | None = None, ensure_ascii: bool = True): + return json.dumps(self.to_dict(), indent=indent, ensure_ascii=ensure_ascii) + + +class KookCardMessageContainer(list[KookCardMessage]): + """卡片消息容器(列表),此类型可以直接to_json后发送出去""" + + def append(self, object: KookCardMessage) -> None: + return super().append(object) + + def to_json(self, indent: int | None = None, ensure_ascii: bool = True) -> str: + return json.dumps( + [i.to_dict() for i in self], indent=indent, ensure_ascii=ensure_ascii + ) + + +@dataclass +class OrderMessage: + index: int + text: str + type: KookMessageType + reply_id: str | int = "" diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index b8473dae63..a143678c23 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -584,6 +584,51 @@ "only_use_webhook_url_to_send": { "description": "Send Replies via Webhook Only", "hint": "When enabled, all WeCom AI Bot replies are sent through msg_push_webhook_url. The message push webhook supports more message types (such as images, files, etc.). If you do not need the typing effect, it is strongly recommended to use this option. " + }, + "kook_bot_token": { + "description": "Bot Token", + "type": "string", + "hint": "Required. The Bot Token obtained from the KOOK Developer Platform." + }, + "kook_bot_nickname": { + "description": "Bot Nickname", + "type": "string", + "hint": "Optional. If the sender nickname matches this value, the message will be ignored to prevent broadcast storms." + }, + "kook_reconnect_delay": { + "description": "Reconnect Delay", + "type": "int", + "hint": "Delay time for reconnection (seconds), using an exponential backoff strategy." + }, + "kook_max_reconnect_delay": { + "description": "Max Reconnect Delay", + "type": "int", + "hint": "The maximum value for reconnection delay (seconds)." + }, + "kook_max_retry_delay": { + "description": "Max Retry Delay", + "type": "int", + "hint": "The maximum delay time for retries (seconds)." + }, + "kook_heartbeat_interval": { + "description": "Heartbeat Interval", + "type": "int", + "hint": "The interval time for heartbeat detection (seconds)." + }, + "kook_heartbeat_timeout": { + "description": "Heartbeat Timeout", + "type": "int", + "hint": "The timeout duration for heartbeat detection (seconds)." + }, + "kook_max_heartbeat_failures": { + "description": "Max Heartbeat Failures", + "type": "int", + "hint": "Maximum allowed heartbeat failures; the connection will be dropped if exceeded." + }, + "kook_max_consecutive_failures": { + "description": "Max Consecutive Failures", + "type": "int", + "hint": "Maximum allowed consecutive failures; retries will stop if exceeded." } }, "general": { diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index e3a52258f3..015ce3082c 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -587,6 +587,51 @@ "only_use_webhook_url_to_send": { "description": "仅使用 Webhook 发送消息", "hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。如果不需要打字机效果,强烈建议使用此选项。" + }, + "kook_bot_token": { + "description": "机器人 Token", + "type": "string", + "hint": "必填项。从 KOOK 开发者平台获取的机器人 Token" + }, + "kook_bot_nickname": { + "description": "Bot Nickname", + "type": "string", + "hint": "可选项。若发送者昵称与此值一致,将忽略该消息。" + }, + "kook_reconnect_delay": { + "description": "重连延迟", + "type": "int", + "hint": "重连延迟时间(秒),使用指数退避策略" + }, + "kook_max_reconnect_delay": { + "description": "最大重连延迟", + "type": "int", + "hint": "重连延迟的最大值(秒)" + }, + "kook_max_retry_delay": { + "description": "最大重试延迟", + "type": "int", + "hint": "重试的最大延迟时间(秒)" + }, + "kook_heartbeat_interval": { + "description": "心跳间隔", + "type": "int", + "hint": "心跳检测间隔时间(秒)" + }, + "kook_heartbeat_timeout": { + "description": "心跳超时时间", + "type": "int", + "hint": "心跳检测超时时间(秒)" + }, + "kook_max_heartbeat_failures": { + "description": "最大心跳失败次数", + "type": "int", + "hint": "允许的最大心跳失败次数,超过后断开连接" + }, + "kook_max_consecutive_failures": { + "description": "最大连续失败次数", + "type": "int", + "hint": "允许的最大连续失败次数,超过后停止重试" } }, "general": { diff --git a/tests/test_kook/shared.py b/tests/test_kook/shared.py new file mode 100644 index 0000000000..5c5c9da86c --- /dev/null +++ b/tests/test_kook/shared.py @@ -0,0 +1,4 @@ +from pathlib import Path + + +TEST_DATA_DIR = Path(__file__).parent / "data" diff --git a/tests/test_kook/test_kook_event.py b/tests/test_kook/test_kook_event.py new file mode 100644 index 0000000000..253839506e --- /dev/null +++ b/tests/test_kook/test_kook_event.py @@ -0,0 +1,223 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata, Unknown +from astrbot.api.event import MessageChain +from astrbot.core.message.components import ( + File, + Image, + Plain, + Video, + At, + AtAll, + BaseMessageComponent, + Json, + Record, + Reply, +) + + +from astrbot.core.platform.sources.kook.kook_event import KookEvent +from astrbot.core.platform.sources.kook.kook_types import KookMessageType, OrderMessage + + +async def mock_kook_client(upload_asset_return: str, send_text_return: str): + # 1. Mock 掉整个 KookClient 类 + client = MagicMock() + + client.upload_asset = AsyncMock(return_value=upload_asset_return) + client.send_text = AsyncMock(return_value=send_text_return) + return client + + +def mock_file_message(input: str): + message = MagicMock(spec=File) + message.get_file = AsyncMock(return_value=input) + return message + + +def mock_record_message(input: str): + message = MagicMock(spec=Record) + message.text = input + message.convert_to_file_path = AsyncMock(return_value=input) + return message + + +def mock_astrbot_message(): + message = AstrBotMessage() + message.type = MessageType.OTHER_MESSAGE + message.group_id = "test" + message.session_id = "test" + message.message_id = "test" + return message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input_message,upload_asset_return, expected_output, expected_error", + [ + ( + Image("test image"), + "test image", + OrderMessage( + 1, + text="test image", + type=KookMessageType.IMAGE, + ), + None, + ), + ( + Video("test video"), + "test video", + OrderMessage( + 1, + text="test video", + type=KookMessageType.VIDEO, + ), + None, + ), + ( + mock_file_message("test file"), + "test file", + OrderMessage( + 1, + text="test file", + type=KookMessageType.FILE, + ), + None, + ), + ( + mock_record_message("./tests/file.wav"), + "./tests/file.wav", + OrderMessage( + 1, + text='[{"type": "card", "modules": [{"src": "./tests/file.wav", "title": "./tests/file.wav", "type": "audio"}]}]', + type=KookMessageType.CARD, + ), + None, + ), + ( + Plain("test plain"), + "test plain", + OrderMessage( + 1, + text="test plain", + type=KookMessageType.KMARKDOWN, + ), + None, + ), + ( + At(qq="test at"), + "test at", + OrderMessage( + 1, + text="(met)test at(met)", + type=KookMessageType.KMARKDOWN, + ), + None, + ), + ( + AtAll(qq="all"), + "test atAll", + OrderMessage( + 1, + text="(met)all(met)", + type=KookMessageType.KMARKDOWN, + ), + None, + ), + ( + Reply(id="test reply"), + "test reply", + OrderMessage( + 1, + text="", + type=KookMessageType.KMARKDOWN, + reply_id="test reply", + ), + None, + ), + ( + Json(data={"test": "json"}), + "test json", + OrderMessage( + 1, + text='[{"test": "json"}]', + type=KookMessageType.CARD, + ), + None, + ), + ( + Unknown(text="test unknown"), + "test unknown", + None, + NotImplementedError, + ), + ], +) +async def test_kook_event_warp_message( + input_message: BaseMessageComponent, + upload_asset_return: str, + expected_output: OrderMessage, + expected_error: type[Exception] | None, +): + client = await mock_kook_client( + upload_asset_return, + "", + ) + + event = KookEvent( + "", + mock_astrbot_message(), + PlatformMetadata( + name="test", + id="test", + description="test", + ), + "", + client, + ) + + if expected_error: + with pytest.raises(expected_error): + await event._wrap_message(1, input_message) + return + + result = await event._wrap_message(1, input_message) + assert result == expected_output + + +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "message_chain,send_text_expected_output,expected_error", +# [ +# ( +# MessageChain( +# chain=[ +# Image(file="test image"), +# Plain(text="test plain"), +# ], +# ), +# "" +# ), +# ], +# ) +# async def test_kook_event_send(): +# client = await mock_kook_client( +# "", +# "", +# ) + +# event = KookEvent( +# "", +# mock_astrbot_message(), +# PlatformMetadata( +# name="test", +# id="test", +# description="test", +# ), +# "", +# client, +# ) + +# await event.send(message=mock_astrbot_message()) diff --git a/tests/test_kook/test_kook_types.py b/tests/test_kook/test_kook_types.py new file mode 100644 index 0000000000..760e36c596 --- /dev/null +++ b/tests/test_kook/test_kook_types.py @@ -0,0 +1,107 @@ +import json +from pathlib import Path + +import pytest + +from astrbot.core.platform.sources.kook.kook_types import ( + ActionGroupModule, + ButtonElement, + ContextModule, + CountdownModule, + DividerModule, + FileModule, + HeaderModule, + ImageElement, + ImageGroupModule, + InviteModule, + KmarkdownElement, + KookCardMessage, + ParagraphStructure, + PlainTextElement, + SectionModule, + KookCardMessageContainer, +) +from tests.test_kook.shared import TEST_DATA_DIR + + +def test_kook_card_message_container_append(): + container = KookCardMessageContainer() + container.append(KookCardMessage()) + assert len(container) == 1 + + +@pytest.mark.parametrize( + "input, expect_container_length", + [ + ([KookCardMessage()], 1), + ([KookCardMessage()] * 2, 2), + ], +) +def test_kook_card_message_container_to_json( + input: list[KookCardMessage], expect_container_length: int +): + container = KookCardMessageContainer(input) + json_output = container.to_json() + output = json.loads(json_output) + assert isinstance(output, list) + assert len(output) == expect_container_length + + +def test_all_kook_card_type(): + expect_json_data = Path(TEST_DATA_DIR / "kook_card_data.json").read_text( + encoding="utf-8" + ) + json_output = KookCardMessage( + theme="info", + size="lg", + modules=[ + HeaderModule(text=PlainTextElement(content="test1")), + SectionModule(text=KmarkdownElement(content="test2")), + DividerModule(), + SectionModule( + text=ParagraphStructure( + cols=2, + fields=[ + KmarkdownElement(content="test3"), + KmarkdownElement(content="**test4**"), + ], + ) + ), + ImageGroupModule( + elements=[ + ImageElement( + src="https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg" + ) + ] + ), + FileModule( + src="https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", + title="test5", + type="file", + ), + CountdownModule( + endTime=1772343427360, + startTime=1772343378259, + mode="second", + ), + ActionGroupModule( + elements=[ + ButtonElement( + value="btn_clicked", + text="点我测试回调", + click="return-val", + theme="primary", + ), + ButtonElement( + value="https://www.kookapp.cn", + text="访问官网", + click="link", + theme="danger", + ), + ] + ), + ContextModule(elements=[PlainTextElement(content="test6")]), + InviteModule(code="test7"), + ], + ).to_json(indent=4, ensure_ascii=False) + assert json_output == expect_json_data