diff --git a/.gitignore b/.gitignore index 94be64f3..b4069973 100644 --- a/.gitignore +++ b/.gitignore @@ -87,5 +87,6 @@ screenshots/*.html # Trae 规格和文档 .trae/spec/ +.trae/specs/ .trae/documents/ .trae/data/ \ No newline at end of file diff --git a/CURRENT_TASK.md b/CURRENT_TASK.md new file mode 100644 index 00000000..26b99d0a --- /dev/null +++ b/CURRENT_TASK.md @@ -0,0 +1,135 @@ +# TASK: Dashboard API 集成 + +> 分支: `feature/dashboard-api` +> 并行组: 第一组 +> 优先级: 🔴 最高 +> 预计时间: 3-4 天 +> 依赖: 无 + +*** + +## 一、目标 + +利用已验证可用的 Dashboard API,增强积分检测能力,替代现有的 HTML 解析方案。 + +*** + +## 二、背景 + +### 2.1 API 验证结果 + +| API | 状态 | HTTP 状态码 | 备注 | +|-----|------|-------------|------| +| Dashboard API | ✅ 可用 | 200 | 返回完整用户数据 | + +### 2.2 API 端点 + +``` +GET https://rewards.bing.com/api/getuserinfo?type=1 +Headers: + Cookie: {session_cookies} + Referer: https://rewards.bing.com/ +``` + +### 2.3 响应示例 + +```json +{ + "dashboard": { + "userStatus": { + "levelInfo": { + "activeLevel": "newLevel3", + "activeLevelName": "Gold Member", + "progress": 1790, + "progressMax": 750 + }, + "availablePoints": 12345, + "counters": { + "pcSearch": [...], + "mobileSearch": [...] + } + }, + "dailySetPromotions": {...}, + "morePromotions": [...], + "punchCards": [...] + } +} +``` + +*** + +## 三、任务清单 + +### 3.1 数据结构定义 + +- [ ] 创建 `src/api/__init__.py` +- [ ] 创建 `src/api/models.py` + - [ ] `DashboardData` dataclass + - [ ] `UserStatus` dataclass + - [ ] `LevelInfo` dataclass + - [ ] `Counters` dataclass + - [ ] `Promotion` dataclass + - [ ] `PunchCard` dataclass + +### 3.2 DashboardClient 实现 + +- [ ] 创建 `src/api/dashboard_client.py` + - [ ] `get_dashboard_data()` - 获取完整 Dashboard 数据 + - [ ] `get_search_counters()` - 获取搜索计数器 + - [ ] `get_level_info()` - 获取会员等级信息 + - [ ] `get_promotions()` - 获取推广任务列表 + - [ ] `get_current_points()` - 获取当前积分 + +### 3.3 HTML Fallback 机制 + +- [ ] 实现 API 失败时的 HTML 解析 fallback +- [ ] 从页面脚本提取 `var dashboard = {...}` + +### 3.4 集成与测试 + +- [ ] 更新 `PointsDetector` 使用新 API +- [ ] 创建 `tests/unit/test_dashboard_client.py` +- [ ] 验证积分检测准确性 + +*** + +## 四、参考资源 + +### 4.1 TS 项目参考 + +| 文件 | 路径 | +|------|------| +| Dashboard API 实现 | `Microsoft-Rewards-Script/src/browser/BrowserFunc.ts` | +| 数据结构定义 | `Microsoft-Rewards-Script/src/interface/DashboardData.ts` | + +### 4.2 关键代码参考 + +```python +async def get_dashboard_data(self) -> DashboardData: + try: + response = await self._call_api() + if response.data and response.data.get('dashboard'): + return self._parse_dashboard(response.data['dashboard']) + except Exception as e: + self.logger.warn(f"API failed: {e}, trying HTML fallback") + return await self._html_fallback() + raise DashboardError("Failed to get dashboard data") +``` + +*** + +## 五、验收标准 + +- [ ] DashboardClient 可成功调用 API +- [ ] 返回完整的用户数据(积分、等级、任务) +- [ ] HTML fallback 机制正常工作 +- [ ] 单元测试覆盖率 > 80% +- [ ] 无 mypy 类型错误 + +*** + +## 六、合并条件 + +- [ ] 所有测试通过 +- [ ] Code Review 通过 +- [ ] 文档更新完成 diff --git a/pyproject.toml b/pyproject.toml index e534fb97..28f3f292 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "lxml>=5.3.0", "psutil>=6.1.0", "pyotp>=2.9.0", + "httpx>=0.28.0", ] [project.optional-dependencies] @@ -23,6 +24,7 @@ dev = [ "mypy>=1.14.0", "pydantic>=2.9.0", "httpx>=0.28.0", + "respx>=0.21.0", "tinydb>=4.8.0", "filelock>=3.15.0", "rich>=13.0.0", @@ -46,6 +48,8 @@ test = [ "pytest-xdist>=3.5.0", "hypothesis>=6.125.0", "faker>=35.0.0", + "httpx>=0.28.0", + "respx>=0.21.0", ] viz = [ "streamlit>=1.41.0", @@ -84,7 +88,7 @@ ignore = [ ] [tool.ruff.lint.isort] -known-first-party = ["src", "infrastructure", "browser", "login", "search", "account", "tasks", "ui"] +known-first-party = ["src", "infrastructure", "browser", "login", "search", "account", "tasks", "ui", "api", "constants"] [tool.ruff.format] quote-style = "double" diff --git a/src/account/points_detector.py b/src/account/points_detector.py index d15ec926..eefcff05 100644 --- a/src/account/points_detector.py +++ b/src/account/points_detector.py @@ -3,11 +3,13 @@ 从 Microsoft Rewards Dashboard 抓取积分信息 """ +import asyncio import logging import re from playwright.async_api import Page +from api.dashboard_client import DashboardClient, DashboardError from constants import REWARDS_URLS logger = logging.getLogger(__name__) @@ -90,6 +92,22 @@ async def get_current_points(self, page: Page, skip_navigation: bool = False) -> logger.debug("跳过导航,使用当前页面") await page.wait_for_timeout(1000) + # 优先使用 Dashboard API + try: + logger.debug("尝试使用 Dashboard API 获取积分...") + async with DashboardClient(page) as client: + api_points: int | None = await asyncio.wait_for( + client.get_current_points(), timeout=35.0 + ) + if api_points is not None and api_points >= 0: + logger.info("✓ 从 API 获取积分成功") + return int(api_points) + except asyncio.TimeoutError: + logger.warning("Dashboard API 超时,使用 HTML 解析作为备用") + except DashboardError as e: + logger.warning(f"Dashboard API 失败: {e},使用 HTML 解析作为备用") + + # 备用:HTML 解析 logger.debug("尝试从页面源码提取积分...") points = await self._extract_points_from_source(page) @@ -107,13 +125,14 @@ async def get_current_points(self, page: Page, skip_navigation: bool = False) -> points_text = await element.text_content() logger.debug(f"找到积分文本: {points_text}") - points = self._parse_points(points_text) + if points_text and points_text.strip(): + points = self._parse_points(points_text.strip()) - if points is not None and points >= 100: - logger.info(f"✓ 当前积分: {points:,}") - return points - elif points is not None: - logger.debug(f"积分值太小,可能是误识别: {points}") + if points is not None and points >= 100: + logger.info("✓ 当前积分获取成功") + return points + elif points is not None: + logger.debug(f"积分值太小,可能是误识别: {points}") except Exception as e: logger.debug(f"选择器 {selector} 失败: {e}") @@ -310,7 +329,12 @@ async def _check_task_status(self, page: Page, selectors: list, task_name: str) Returns: 任务状态字典 """ - status = {"found": False, "completed": False, "progress": None, "max_progress": None} + status: dict[str, bool | int | None] = { + "found": False, + "completed": False, + "progress": None, + "max_progress": None, + } try: for selector in selectors: @@ -338,10 +362,12 @@ async def _check_task_status(self, page: Page, selectors: list, task_name: str) # 查找类似 "15/30" 的进度 progress_match = re.search(r"(\d+)\s*/\s*(\d+)", text) if progress_match: - status["progress"] = int(progress_match.group(1)) - status["max_progress"] = int(progress_match.group(2)) + progress_val = int(progress_match.group(1)) + max_progress_val = int(progress_match.group(2)) + status["progress"] = progress_val + status["max_progress"] = max_progress_val - if status["progress"] >= status["max_progress"]: + if progress_val >= max_progress_val: status["completed"] = True logger.debug(f"{task_name} 状态: {status}") diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 00000000..3bf4c986 --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,39 @@ +"""API 模块 + +提供面向业务层的统一 API 访问入口,包括仪表盘相关的客户端封装以及数据模型。 +该模块聚合了对外公开的主要类型,方便调用方通过 `api` 包进行导入和使用。 + +主要组件 +------- +- ``DashboardClient``: 仪表盘 API 客户端,对外提供高层封装的请求接口 +- ``DashboardError``: 仪表盘相关错误类型,用于封装请求或解析过程中的异常 +- ``DashboardData``: 仪表盘整体数据模型 +- ``UserStatus``: 用户当前状态信息模型 +- ``LevelInfo``: 用户等级与经验值等信息模型 +- ``SearchCounter`` / ``SearchCounters``: 搜索计数与统计信息模型 +- ``Promotion``: 活动与促销信息模型 +- ``PunchCard``: 打卡与活跃度相关的数据模型 +""" + +from .dashboard_client import DashboardClient, DashboardError +from .models import ( + DashboardData, + LevelInfo, + Promotion, + PunchCard, + SearchCounter, + SearchCounters, + UserStatus, +) + +__all__ = [ + "DashboardClient", + "DashboardError", + "DashboardData", + "UserStatus", + "LevelInfo", + "SearchCounter", + "SearchCounters", + "Promotion", + "PunchCard", +] diff --git a/src/api/dashboard_client.py b/src/api/dashboard_client.py new file mode 100644 index 00000000..b6952f98 --- /dev/null +++ b/src/api/dashboard_client.py @@ -0,0 +1,279 @@ +"""Dashboard API 客户端""" + +import asyncio +import json +import logging +import re + +import httpx +from playwright.async_api import Page + +from constants import API_ENDPOINTS, API_PARAMS, REWARDS_URLS + +from .models import DashboardData, SearchCounters + +logger = logging.getLogger(__name__) + + +class DashboardError(Exception): + """Dashboard API 错误""" + + def __init__(self, message: str, status_code: int | None = None): + super().__init__(message) + self.status_code = status_code + + def is_auth_error(self) -> bool: + """检查是否为认证错误 (401/403)""" + return self.status_code in (401, 403) + + +class DashboardClient: + """Dashboard API 客户端""" + + DEFAULT_MAX_RETRIES = 2 + DEFAULT_RETRY_DELAY = 1.0 + DEFAULT_TIMEOUT = 10.0 + + def __init__( + self, + page: Page, + max_retries: int | None = None, + retry_delay: float | None = None, + timeout: float | None = None, + ): + """ + 初始化 DashboardClient + + Args: + page: Playwright Page 对象 + max_retries: 最大重试次数,默认 2 + retry_delay: 重试间隔(秒),默认 1.0 + timeout: 请求超时(秒),默认 10.0 + + Raises: + ValueError: 缺少必要的 API 端点配置 + """ + self._page = page + self._max_retries = max_retries if max_retries is not None else self.DEFAULT_MAX_RETRIES + self._retry_delay = retry_delay if retry_delay is not None else self.DEFAULT_RETRY_DELAY + self._timeout = timeout if timeout is not None else self.DEFAULT_TIMEOUT + + if "dashboard" not in API_ENDPOINTS: + raise ValueError("Missing 'dashboard' endpoint in API_ENDPOINTS") + if "dashboard_type" not in API_PARAMS: + raise ValueError("Missing 'dashboard_type' in API_PARAMS") + + base_url = API_ENDPOINTS["dashboard"] + if not base_url.startswith(("http://", "https://")): + raise ValueError(f"Invalid API endpoint URL: {base_url}") + + query_param = API_PARAMS["dashboard_type"] + if not query_param.startswith("?"): + raise ValueError(f"Invalid API query parameter: {query_param}") + + self._api_url = base_url + query_param + self._client = httpx.AsyncClient( + timeout=self._timeout, + limits=httpx.Limits( + max_keepalive_connections=5, + max_connections=10, + ), + ) + + async def close(self): + """关闭 HTTP 客户端""" + if self._client is not None: + await self._client.aclose() + self._client = None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def _get_cookies_header(self) -> str: + """ + 从 Page context 获取 cookies 字符串 + + 使用 Playwright 的 URL 作用域 cookie 选择,让 Playwright 按浏览器规则 + 返回对该 URL 生效的 cookies(包括父域 cookies)。 + + Returns: + cookies 字符串 + """ + cookies = await self._page.context.cookies([self._api_url]) + return "; ".join(f"{c['name']}={c['value']}" for c in cookies) + + async def _call_api(self) -> DashboardData: + """ + 调用 Dashboard API + + Returns: + DashboardData 对象 + + Raises: + DashboardError: API 调用失败 + """ + headers = { + "Referer": REWARDS_URLS["dashboard"], + "Cookie": await self._get_cookies_header(), + "Accept": "application/json", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + } + + for attempt in range(self._max_retries + 1): + try: + if self._client is None or self._client.is_closed: + raise DashboardError("HTTP client has been closed") + response = await self._client.get(self._api_url, headers=headers) + response.raise_for_status() + data = response.json() + + if not isinstance(data, dict): + raise DashboardError( + f"Invalid API response from {self._api_url}: not a dict (attempt {attempt + 1})" + ) + if "dashboard" not in data: + raise DashboardError( + f"Invalid API response from {self._api_url}: missing 'dashboard' field (attempt {attempt + 1})" + ) + if not isinstance(data["dashboard"], dict): + raise DashboardError( + f"Invalid API response from {self._api_url}: 'dashboard' is not a dict (attempt {attempt + 1})" + ) + + return self._parse_dashboard(data["dashboard"]) + + except httpx.HTTPStatusError as e: + raise DashboardError( + f"HTTP error {e.response.status_code} from {self._api_url}", + status_code=e.response.status_code, + ) from e + except httpx.TimeoutException as e: + if attempt < self._max_retries: + logger.debug( + f"Request timeout, retrying ({attempt + 1}/{self._max_retries})..." + ) + await asyncio.sleep(self._retry_delay) + continue + raise DashboardError( + f"API timeout after {attempt + 1} attempts: {self._api_url}" + ) from e + except httpx.RequestError as e: + if attempt < self._max_retries: + logger.debug( + f"Network error, retrying ({attempt + 1}/{self._max_retries}): {e}" + ) + await asyncio.sleep(self._retry_delay) + continue + raise DashboardError( + f"Network error after {attempt + 1} attempts: {self._api_url} - {e}" + ) from e + except (json.JSONDecodeError, TypeError, KeyError, ValueError) as e: + raise DashboardError(f"Parse error from {self._api_url}: {e}") from e + + def _parse_dashboard(self, data: dict[str, object]) -> DashboardData: + """ + 解析 dashboard 数据 + + Args: + data: dashboard 数据字典 + + Returns: + DashboardData 对象 + """ + return DashboardData.from_dict(data) + + async def _html_fallback(self) -> DashboardData | None: + """ + HTML fallback,从页面源码提取 dashboard 数据 + + Returns: + DashboardData 对象,失败返回 None + """ + try: + html = await self._page.content() + match = re.search(r"var\s+dashboard\s*=\s*({.*?});", html, re.DOTALL) + + if match: + json_str = match.group(1) + data = json.loads(json_str) + return self._parse_dashboard(data) + + logger.warning("HTML fallback: dashboard variable not found in page") + return None + + except json.JSONDecodeError as e: + logger.warning(f"HTML fallback JSON parse error: {e}") + return None + except re.error as e: + logger.warning(f"HTML fallback regex error: {e}") + return None + except (TypeError, KeyError, ValueError) as e: + logger.warning(f"HTML fallback data parse error: {e}") + return None + + async def get_dashboard_data(self) -> DashboardData: + """ + 获取完整 Dashboard 数据 + + Returns: + DashboardData 对象 + + Raises: + DashboardError: 所有数据源都失败 + """ + for attempt in range(self._max_retries + 1): + try: + return await self._call_api() + except DashboardError as e: + if e.is_auth_error(): + logger.warning(f"Auth error ({e.status_code}), attempting HTML fallback") + fallback_data = await self._html_fallback() + if fallback_data: + return fallback_data + raise + + if e.status_code and 500 <= e.status_code < 600 and attempt < self._max_retries: + logger.warning( + f"Server error ({e.status_code}), attempt {attempt + 1} failed, retrying..." + ) + await asyncio.sleep(self._retry_delay) + continue + + logger.warning("All API attempts failed, attempting HTML fallback") + fallback_data = await self._html_fallback() + if fallback_data: + return fallback_data + raise + + raise DashboardError("Failed to get dashboard data") + + async def get_current_points(self) -> int | None: + """ + 获取当前积分 + + Returns: + 当前积分,失败返回 None + """ + try: + data = await self.get_dashboard_data() + return data.user_status.available_points + except DashboardError as e: + logger.warning(f"get_current_points failed: {e}") + return None + + async def get_search_counters(self) -> SearchCounters | None: + """ + 获取搜索计数器 + + Returns: + SearchCounters 对象,失败返回 None + """ + try: + data = await self.get_dashboard_data() + return data.user_status.counters + except DashboardError as e: + logger.warning(f"get_search_counters failed: {e}") + return None diff --git a/src/api/models.py b/src/api/models.py new file mode 100644 index 00000000..474e3494 --- /dev/null +++ b/src/api/models.py @@ -0,0 +1,264 @@ +"""Dashboard API 数据模型""" + +import re +from dataclasses import dataclass, field +from typing import Any, TypeVar + +_T = TypeVar("_T") + +_CAMEL_TO_SNAKE_PATTERN = re.compile(r"([a-z0-9])([A-Z])") + + +def _camel_to_snake(name: str) -> str: + """将 camelCase 转换为 snake_case""" + return _CAMEL_TO_SNAKE_PATTERN.sub(r"\1_\2", name).lower() + + +def _transform_dict(data: Any) -> Any: + """递归转换字典键为 snake_case""" + if isinstance(data, dict): + return {_camel_to_snake(k): _transform_dict(v) for k, v in data.items()} + elif isinstance(data, list): + return [_transform_dict(item) for item in data] + else: + return data + + +def _filter_dataclass_fields(data: dict[str, Any], cls: type[_T]) -> dict[str, Any]: + """过滤 dataclass 字段,只保留已声明的字段""" + allowed = cls.__dataclass_fields__.keys() # type: ignore[attr-defined] + return {k: v for k, v in data.items() if k in allowed} + + +@dataclass +class SearchCounter: + """搜索计数器""" + + progress: int = 0 + max_progress: int = 0 + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "SearchCounter": + """从字典创建实例""" + data = _transform_dict(data) + return cls(**_filter_dataclass_fields(data, cls)) + + +@dataclass +class SearchCounters: + """搜索计数器集合""" + + pc_search: list[SearchCounter] = field(default_factory=list) + mobile_search: list[SearchCounter] = field(default_factory=list) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "SearchCounters": + """从字典创建实例""" + data = _transform_dict(data) + + pc_raw = data.get("pc_search") or [] + if not isinstance(pc_raw, list): + pc_raw = [] + + mobile_raw = data.get("mobile_search") or [] + if not isinstance(mobile_raw, list): + mobile_raw = [] + + pc_search = [SearchCounter.from_dict(item) for item in pc_raw if isinstance(item, dict)] + mobile_search = [ + SearchCounter.from_dict(item) for item in mobile_raw if isinstance(item, dict) + ] + return cls(pc_search=pc_search, mobile_search=mobile_search) + + +@dataclass +class LevelInfo: + """会员等级信息""" + + active_level: str = "" + active_level_name: str = "" + progress: int = 0 + progress_max: int = 0 + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "LevelInfo": + """从字典创建实例""" + data = _transform_dict(data) + return cls(**_filter_dataclass_fields(data, cls)) + + +@dataclass +class Promotion: + """推广任务""" + + promotion_type: str = "" + title: str = "" + points: int = 0 + status: str = "" + url: str | None = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Promotion": + """从字典创建实例""" + data = _transform_dict(data) + return cls(**_filter_dataclass_fields(data, cls)) + + +@dataclass +class PunchCard: + """打卡任务""" + + name: str = "" + progress: int = 0 + max_progress: int = 0 + completed: bool = False + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "PunchCard": + """从字典创建实例""" + data = _transform_dict(data) + return cls(**_filter_dataclass_fields(data, cls)) + + +@dataclass +class StreakPromotion: + """连胜推广任务""" + + promotion_type: str = "" + title: str = "" + points: int = 0 + status: str = "" + url: str | None = None + streak_count: int = 0 + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "StreakPromotion": + """从字典创建实例""" + data = _transform_dict(data) + return cls(**_filter_dataclass_fields(data, cls)) + + +@dataclass +class StreakBonusPromotion: + """连胜奖励推广""" + + promotion_type: str = "" + title: str = "" + points: int = 0 + status: str = "" + url: str | None = None + streak_day: int = 0 + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "StreakBonusPromotion": + """从字典创建实例""" + data = _transform_dict(data) + return cls(**_filter_dataclass_fields(data, cls)) + + +@dataclass +class UserStatus: + """用户状态""" + + available_points: int = 0 + level_info: LevelInfo = field(default_factory=LevelInfo) + counters: SearchCounters = field(default_factory=SearchCounters) + bing_star_monthly_bonus_progress: int = 0 + bing_star_monthly_bonus_maximum: int = 0 + default_search_engine_monthly_bonus_progress: int = 0 + default_search_engine_monthly_bonus_maximum: int = 0 + default_search_engine_monthly_bonus_state: str = "" + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "UserStatus": + """从字典创建实例""" + data = _transform_dict(data) + level_info_raw = data.get("level_info") + level_info = ( + LevelInfo.from_dict(level_info_raw) if isinstance(level_info_raw, dict) else LevelInfo() + ) + counters_raw = data.get("counters") + counters = ( + SearchCounters.from_dict(counters_raw) + if isinstance(counters_raw, dict) + else SearchCounters() + ) + return cls( + available_points=data.get("available_points", 0), + level_info=level_info, + counters=counters, + bing_star_monthly_bonus_progress=data.get("bing_star_monthly_bonus_progress", 0), + bing_star_monthly_bonus_maximum=data.get("bing_star_monthly_bonus_maximum", 0), + default_search_engine_monthly_bonus_progress=data.get( + "default_search_engine_monthly_bonus_progress", 0 + ), + default_search_engine_monthly_bonus_maximum=data.get( + "default_search_engine_monthly_bonus_maximum", 0 + ), + default_search_engine_monthly_bonus_state=data.get( + "default_search_engine_monthly_bonus_state", "" + ), + ) + + +@dataclass +class DashboardData: + """Dashboard 数据""" + + user_status: UserStatus = field(default_factory=UserStatus) + daily_set_promotions: dict[str, list[Promotion]] = field(default_factory=dict) + more_promotions: list[Promotion] = field(default_factory=list) + punch_cards: list[PunchCard] = field(default_factory=list) + streak_promotion: StreakPromotion | None = None + streak_bonus_promotions: list[StreakBonusPromotion] = field(default_factory=list) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DashboardData": + """从字典创建实例""" + data = _transform_dict(data) + user_status = UserStatus.from_dict(data.get("user_status", {})) + + daily_set_promotions_data = data.get("daily_set_promotions") or {} + if not isinstance(daily_set_promotions_data, dict): + daily_set_promotions_data = {} + daily_set: dict[str, list[Promotion]] = {} + for key, items in daily_set_promotions_data.items(): + if isinstance(items, list): + daily_set[key] = [Promotion.from_dict(item) for item in items] + + more_promotions_data = data.get("more_promotions") or [] + if not isinstance(more_promotions_data, list): + more_promotions_data = [] + more_promotions = [ + Promotion.from_dict(item) for item in more_promotions_data if isinstance(item, dict) + ] + + punch_cards_data = data.get("punch_cards") or [] + if not isinstance(punch_cards_data, list): + punch_cards_data = [] + punch_cards = [ + PunchCard.from_dict(item) for item in punch_cards_data if isinstance(item, dict) + ] + + streak_promotion = None + streak_promotion_raw = data.get("streak_promotion") + if isinstance(streak_promotion_raw, dict): + streak_promotion = StreakPromotion.from_dict(streak_promotion_raw) + + streak_bonus_data = data.get("streak_bonus_promotions") or [] + if not isinstance(streak_bonus_data, list): + streak_bonus_data = [] + streak_bonus_promotions = [ + StreakBonusPromotion.from_dict(item) + for item in streak_bonus_data + if isinstance(item, dict) + ] + + return cls( + user_status=user_status, + daily_set_promotions=daily_set, + more_promotions=more_promotions, + punch_cards=punch_cards, + streak_promotion=streak_promotion, + streak_bonus_promotions=streak_bonus_promotions, + ) diff --git a/tests/unit/test_dashboard_client.py b/tests/unit/test_dashboard_client.py new file mode 100644 index 00000000..b105c3c4 --- /dev/null +++ b/tests/unit/test_dashboard_client.py @@ -0,0 +1,552 @@ +"""Dashboard Client 单元测试""" + +import sys +from pathlib import Path +from unittest.mock import AsyncMock, Mock + +import httpx +import pytest +import respx + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + +from api.dashboard_client import DashboardClient, DashboardError +from api.models import DashboardData, SearchCounters + + +@pytest.fixture +def mock_page(): + """Mock Playwright Page 对象""" + page = Mock() + page.context = Mock() + + async def mock_cookies(urls=None): + return [ + {"name": "cookie1", "value": "value1", "domain": "rewards.bing.com"}, + {"name": "cookie2", "value": "value2", "domain": ".rewards.bing.com"}, + ] + + page.context.cookies = mock_cookies + page.content = AsyncMock( + return_value=""" + + + + """ + ) + return page + + +@pytest.fixture +def mock_page_no_dashboard(): + """没有 dashboard 变量的 Mock Page""" + page = Mock() + page.context = Mock() + + async def mock_cookies(urls=None): + return [{"name": "cookie1", "value": "value1", "domain": "rewards.bing.com"}] + + page.context.cookies = mock_cookies + page.content = AsyncMock(return_value="no dashboard") + return page + + +@pytest.fixture +def mock_api_response(): + """Mock API 成功响应数据""" + return { + "dashboard": { + "userStatus": { + "availablePoints": 12345, + "levelInfo": { + "activeLevel": "newLevel3", + "activeLevelName": "Gold Member", + "progress": 1790, + "progressMax": 750, + }, + "counters": { + "pcSearch": [{"progress": 15, "maxProgress": 30}], + "mobileSearch": [{"progress": 10, "maxProgress": 20}], + }, + }, + "dailySetPromotions": {}, + "morePromotions": [], + "punchCards": [], + } + } + + +async def test_dashboard_data_from_dict(): + """测试 DashboardData.from_dict 方法""" + data = { + "userStatus": { + "availablePoints": 12345, + "levelInfo": { + "activeLevel": "newLevel3", + "activeLevelName": "Gold Member", + "progress": 1790, + "progressMax": 750, + }, + "counters": { + "pcSearch": [{"progress": 15, "maxProgress": 30}], + "mobileSearch": [{"progress": 10, "maxProgress": 20}], + }, + } + } + + dashboard_data = DashboardData.from_dict(data) + assert isinstance(dashboard_data, DashboardData) + assert dashboard_data.user_status.available_points == 12345 + assert dashboard_data.user_status.level_info.active_level == "newLevel3" + assert len(dashboard_data.user_status.counters.pc_search) == 1 + assert dashboard_data.user_status.counters.pc_search[0].progress == 15 + + +async def test_dashboard_data_from_dict_missing_fields(): + """测试 DashboardData.from_dict 方法(缺失字段)""" + data = {} + + dashboard_data = DashboardData.from_dict(data) + assert isinstance(dashboard_data, DashboardData) + assert dashboard_data.user_status.available_points == 0 + assert dashboard_data.user_status.level_info.active_level == "" + assert len(dashboard_data.user_status.counters.pc_search) == 0 + + +async def test_get_cookies_header(mock_page): + """测试 _get_cookies_header 方法""" + client = DashboardClient(mock_page) + cookies = await client._get_cookies_header() + assert "cookie1=value1" in cookies + assert "cookie2=value2" in cookies + + +@respx.mock +async def test_get_dashboard_data_success(mock_page, mock_api_response): + """测试 API 调用成功场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=200, json=mock_api_response + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +@respx.mock +async def test_get_dashboard_data_api_error_fallback(mock_page): + """测试 API 错误 + HTML fallback 场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=500, text="Internal Server Error" + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +@respx.mock +async def test_get_dashboard_data_unauthorized_fallback(mock_page): + """测试 401 错误 + HTML fallback 场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=401, text="Unauthorized" + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +@respx.mock +async def test_get_dashboard_data_forbidden_fallback(mock_page): + """测试 403 错误 + HTML fallback 场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=403, text="Forbidden" + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +@respx.mock +async def test_get_dashboard_data_html_fallback_fails(mock_page_no_dashboard): + """测试 HTML fallback 失败场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=500, text="Internal Server Error" + ) + + client = DashboardClient(mock_page_no_dashboard) + with pytest.raises(DashboardError): + await client.get_dashboard_data() + + +@respx.mock +async def test_get_current_points(mock_page, mock_api_response): + """测试 get_current_points 方法""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=200, json=mock_api_response + ) + + client = DashboardClient(mock_page) + points = await client.get_current_points() + assert points == 12345 + + +@respx.mock +async def test_get_search_counters(mock_page, mock_api_response): + """测试 get_search_counters 方法""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=200, json=mock_api_response + ) + + client = DashboardClient(mock_page) + counters = await client.get_search_counters() + assert isinstance(counters, SearchCounters) + assert len(counters.pc_search) == 1 + assert len(counters.mobile_search) == 1 + assert counters.pc_search[0].progress == 15 + assert counters.mobile_search[0].progress == 10 + + +@respx.mock +async def test_retry_logic(mock_page): + """测试重试逻辑""" + call_count = 0 + + def side_effect(request): + nonlocal call_count + call_count += 1 + if call_count < 3: + return httpx.Response(500, text="Internal Server Error") + else: + return httpx.Response( + 200, json={"dashboard": {"userStatus": {"availablePoints": 9999}}} + ) + + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").mock(side_effect=side_effect) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert data.user_status.available_points == 9999 + assert call_count == 3 + + +@respx.mock +async def test_timeout_fallback(mock_page): + """测试超时 + HTML fallback 场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").mock( + side_effect=httpx.TimeoutException("Timeout") + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +@respx.mock +async def test_network_error_fallback(mock_page): + """测试网络错误 + HTML fallback 场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").mock( + side_effect=httpx.ConnectError("Connection failed") + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +def test_dashboard_error_is_auth_error(): + """测试 DashboardError.is_auth_error 方法""" + error_401 = DashboardError("Unauthorized", status_code=401) + error_403 = DashboardError("Forbidden", status_code=403) + error_500 = DashboardError("Server Error", status_code=500) + error_no_code = DashboardError("Unknown error") + + assert error_401.is_auth_error() is True + assert error_403.is_auth_error() is True + assert error_500.is_auth_error() is False + assert error_no_code.is_auth_error() is False + + +@respx.mock +async def test_json_parse_error_fallback(mock_page): + """测试 JSON 解析错误 + HTML fallback 场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=200, text="invalid json {{{" + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +@respx.mock +async def test_response_not_dict_fallback(mock_page): + """测试响应不是 dict + HTML fallback 场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=200, json=[1, 2, 3] + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +@respx.mock +async def test_missing_dashboard_field_fallback(mock_page): + """测试缺少 dashboard 字段 + HTML fallback 场景""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=200, json={"otherField": "value"} + ) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert data.user_status.available_points == 12345 + + +@respx.mock +async def test_get_current_points_returns_none_on_error(mock_page_no_dashboard): + """测试 get_current_points 失败返回 None""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=500, text="Internal Server Error" + ) + + client = DashboardClient(mock_page_no_dashboard) + points = await client.get_current_points() + assert points is None + + +@respx.mock +async def test_get_search_counters_returns_none_on_error(mock_page_no_dashboard): + """测试 get_search_counters 失败返回 None""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=500, text="Internal Server Error" + ) + + client = DashboardClient(mock_page_no_dashboard) + counters = await client.get_search_counters() + assert counters is None + + +async def test_camel_to_snake_conversion(): + """测试 camelCase 到 snake_case 的转换""" + data = { + "userStatus": { + "availablePoints": 9999, + "bingStarMonthlyBonusProgress": 100, + "bingStarMonthlyBonusMaximum": 500, + "defaultSearchEngineMonthlyBonusState": "active", + } + } + + dashboard_data = DashboardData.from_dict(data) + assert dashboard_data.user_status.available_points == 9999 + assert dashboard_data.user_status.bing_star_monthly_bonus_progress == 100 + assert dashboard_data.user_status.bing_star_monthly_bonus_maximum == 500 + assert dashboard_data.user_status.default_search_engine_monthly_bonus_state == "active" + + +async def test_extra_fields_ignored(): + """测试额外字段被忽略""" + data = { + "userStatus": { + "availablePoints": 12345, + "unknownField1": "value1", + "unknownField2": 123, + }, + "unknownTopLevel": "ignored", + } + + dashboard_data = DashboardData.from_dict(data) + assert dashboard_data.user_status.available_points == 12345 + assert not hasattr(dashboard_data.user_status, "unknownField1") + assert not hasattr(dashboard_data, "unknownTopLevel") + + +def test_missing_api_endpoint_raises_error(mock_page): + """测试缺少 API 端点配置时抛出 ValueError""" + from unittest.mock import patch + + with patch("api.dashboard_client.API_ENDPOINTS", {}): + with pytest.raises(ValueError, match="dashboard"): + DashboardClient(mock_page) + + +def test_missing_api_params_raises_error(mock_page): + """测试缺少 API 参数配置时抛出 ValueError""" + from unittest.mock import patch + + with patch("api.dashboard_client.API_PARAMS", {}): + with pytest.raises(ValueError, match="dashboard_type"): + DashboardClient(mock_page) + + +@respx.mock +async def test_auth_error_no_retry(mock_page): + """测试 401/403 错误不重试,立即触发 fallback""" + call_count = 0 + + def side_effect(request): + nonlocal call_count + call_count += 1 + return httpx.Response(401, text="Unauthorized") + + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").mock(side_effect=side_effect) + + client = DashboardClient(mock_page) + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + assert call_count == 1 + + +@respx.mock +async def test_client_close(mock_page, mock_api_response): + """测试客户端关闭方法""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=200, json=mock_api_response + ) + + client = DashboardClient(mock_page) + await client.close() + assert client._client is None + + +@respx.mock +async def test_client_context_manager(mock_page, mock_api_response): + """测试客户端上下文管理器""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=200, json=mock_api_response + ) + + async with DashboardClient(mock_page) as client: + data = await client.get_dashboard_data() + assert isinstance(data, DashboardData) + + assert client._client is None + + +@pytest.mark.parametrize("status_code", (401, 403)) +@respx.mock +async def test_get_dashboard_data_auth_error_with_failing_html_fallback( + status_code, + mock_page_no_dashboard, +): + """ + 当 API 返回认证错误 (401/403) 且 HTML fallback 也失败时, + 应该抛出 DashboardError。 + """ + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=status_code, text="Unauthorized" + ) + + client = DashboardClient(mock_page_no_dashboard) + with pytest.raises(DashboardError) as exc_info: + await client.get_dashboard_data() + + assert exc_info.value.status_code == status_code + assert exc_info.value.is_auth_error() + + +@respx.mock +async def test_get_current_points_api_error_html_fallback(mock_page): + """API 5xx 时应从 HTML fallback 中获取当前积分""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").respond( + status_code=500, text="Internal Server Error" + ) + + client = DashboardClient(mock_page) + points = await client.get_current_points() + assert points == 12345 + + +@respx.mock +async def test_get_search_counters_timeout_html_fallback(mock_page): + """API 超时时应从 HTML fallback 中构建搜索计数器""" + respx.get("https://rewards.bing.com/api/getuserinfo?type=1").mock( + side_effect=httpx.TimeoutException("Request timed out") + ) + + client = DashboardClient(mock_page) + counters = await client.get_search_counters() + + assert counters is not None + assert len(counters.pc_search) == 1 + assert len(counters.mobile_search) == 1 + assert counters.pc_search[0].progress == 15 + assert counters.mobile_search[0].progress == 10 + + +async def test_search_counters_handles_null_values(): + """测试 SearchCounters 处理 null/非列表值""" + data = { + "pc_search": None, + "mobile_search": "invalid", + } + + counters = SearchCounters.from_dict(data) + assert counters.pc_search == [] + assert counters.mobile_search == [] + + +async def test_search_counters_handles_scalar_values(): + """测试 SearchCounters 处理标量值""" + data = { + "pc_search": 123, + "mobile_search": "string", + } + + counters = SearchCounters.from_dict(data) + assert counters.pc_search == [] + assert counters.mobile_search == [] + + +async def test_cookie_filtering_by_domain(mock_page): + """测试 Cookie 使用 Playwright URL 作用域选择""" + + async def mock_cookies(urls=None): + return [ + {"name": "bing_cookie", "value": "bing_value", "domain": "bing.com"}, + {"name": "rewards_cookie", "value": "rewards_value", "domain": "rewards.bing.com"}, + {"name": "rewards_sub_cookie", "value": "sub_value", "domain": ".rewards.bing.com"}, + ] + + mock_page.context.cookies = mock_cookies + + client = DashboardClient(mock_page) + cookies = await client._get_cookies_header() + + assert "bing_cookie=bing_value" in cookies + assert "rewards_cookie=rewards_value" in cookies + assert "rewards_sub_cookie=sub_value" in cookies