diff --git a/config.example.yaml b/config.example.yaml index aa38d1a4..5b67c7df 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -57,6 +57,11 @@ query_engine: wikipedia: enabled: true # Wikipedia 热门话题 timeout: 15 # 超时(秒) + wikipedia_top_views: + enabled: true # Wikipedia 热门浏览(真实热搜数据) + timeout: 30 # 超时(秒) + lang: "en" # 语言 + ttl: 21600 # 缓存 TTL(秒,默认 6 小时) bing_suggestions: enabled: true # Bing 建议API bing_api: diff --git a/pyproject.toml b/pyproject.toml index e534fb97..9cfcab5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.0" requires-python = ">=3.10" dependencies = [ "playwright>=1.49.0", - "playwright-stealth>=1.0.6", + "playwright-stealth>=1.0.6,<2.0", "pyyaml>=6.0.1", "aiohttp>=3.11.0", "beautifulsoup4>=4.12.3", @@ -35,17 +35,7 @@ dev = [ "pytest-xdist>=3.5.0", "hypothesis>=6.125.0", "faker>=35.0.0", -] -test = [ - "pytest>=8.0.0", - "pytest-asyncio>=0.24.0", - "pytest-playwright>=0.5.0", - "pytest-benchmark>=5.0.0", - "pytest-cov>=6.0.0", - "pytest-timeout>=2.3.0", - "pytest-xdist>=3.5.0", - "hypothesis>=6.125.0", - "faker>=35.0.0", + "respx>=0.21.0", ] viz = [ "streamlit>=1.41.0", diff --git a/src/account/points_detector.py b/src/account/points_detector.py index d15ec926..104f4b2c 100644 --- a/src/account/points_detector.py +++ b/src/account/points_detector.py @@ -8,7 +8,7 @@ from playwright.async_api import Page -from constants import REWARDS_URLS +from api.dashboard_client import DashboardClient logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ class PointsDetector: """积分检测器类""" - DASHBOARD_URL = REWARDS_URLS["dashboard"] + DASHBOARD_URL = "https://rewards.bing.com/" POINTS_SELECTORS = [ "p.text-title1.font-semibold", @@ -90,11 +90,25 @@ async def get_current_points(self, page: Page, skip_navigation: bool = False) -> logger.debug("跳过导航,使用当前页面") await page.wait_for_timeout(1000) + # 优先使用 Dashboard API + try: + logger.debug("尝试使用 Dashboard API 获取积分...") + client = DashboardClient(page) + api_points: int | None = await client.get_current_points() + if api_points is not None and api_points >= 0: + logger.debug(f"✓ 从 API 获取积分: {api_points:,}") + return int(api_points) + except Exception as e: + logger.warning( + f"API 获取积分失败({type(e).__name__}: {e}),使用 HTML 解析作为备用" + ) + + # 备用:HTML 解析 logger.debug("尝试从页面源码提取积分...") points = await self._extract_points_from_source(page) if points is not None: - logger.info(f"✓ 从源码提取积分: {points:,}") + logger.debug(f"✓ 从源码提取积分: {points:,}") return points logger.debug("源码提取失败,尝试选择器...") @@ -107,13 +121,14 @@ async def get_current_points(self, page: Page, skip_navigation: bool = False) -> points_text = await element.text_content() logger.debug(f"找到积分文本: {points_text}") - points = self._parse_points(points_text) + if points_text: + points = self._parse_points(points_text) - if points is not None and points >= 100: - logger.info(f"✓ 当前积分: {points:,}") - return points - elif points is not None: - logger.debug(f"积分值太小,可能是误识别: {points}") + if points is not None and points >= 100: + logger.debug(f"✓ 当前积分: {points:,}") + return points + elif points is not None: + logger.debug(f"积分值太小,可能是误识别: {points}") except Exception as e: logger.debug(f"选择器 {selector} 失败: {e}") @@ -143,7 +158,7 @@ def _parse_points(self, text: str) -> int | None: Returns: 积分数量,失败返回 None """ - if not text: + if not text or not text.strip(): return None try: @@ -310,7 +325,12 @@ async def _check_task_status(self, page: Page, selectors: list, task_name: str) Returns: 任务状态字典 """ - status = {"found": False, "completed": False, "progress": None, "max_progress": None} + status: dict[str, bool | int | None] = { + "found": False, + "completed": False, + "progress": None, + "max_progress": None, + } try: for selector in selectors: @@ -338,10 +358,12 @@ async def _check_task_status(self, page: Page, selectors: list, task_name: str) # 查找类似 "15/30" 的进度 progress_match = re.search(r"(\d+)\s*/\s*(\d+)", text) if progress_match: - status["progress"] = int(progress_match.group(1)) - status["max_progress"] = int(progress_match.group(2)) + progress_val = int(progress_match.group(1)) + max_progress_val = int(progress_match.group(2)) + status["progress"] = progress_val + status["max_progress"] = max_progress_val - if status["progress"] >= status["max_progress"]: + if progress_val >= max_progress_val: status["completed"] = True logger.debug(f"{task_name} 状态: {status}") diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 00000000..9b64aaec --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,5 @@ +"""API clients module""" + +from .dashboard_client import DashboardClient + +__all__ = ["DashboardClient"] diff --git a/src/api/dashboard_client.py b/src/api/dashboard_client.py new file mode 100644 index 00000000..948f9f7d --- /dev/null +++ b/src/api/dashboard_client.py @@ -0,0 +1,170 @@ +""" +Dashboard API Client + +Fetches points data from Microsoft Rewards Dashboard API. +""" + +import logging +import re +from typing import Any + +from playwright.async_api import Page + +from constants import API_ENDPOINTS, REWARDS_URLS + +logger = logging.getLogger(__name__) + + +class DashboardClient: + """Client for fetching data from Microsoft Rewards Dashboard API""" + + def __init__(self, page: Page): + """ + Initialize Dashboard client + + Args: + page: Playwright Page object + """ + self.page = page + self._cached_points: int | None = None + base = REWARDS_URLS.get("dashboard", "https://rewards.bing.com") + self._base_url = base.rstrip("/") + + async def get_current_points(self) -> int | None: + """ + Get current points from Dashboard API + + Attempts to fetch points via API call first, falls back to + parsing page content if API fails. + + Returns: + Points balance or None if unable to determine + """ + try: + points = await self._fetch_points_via_api() + if points is not None and points >= 0: + self._cached_points = points + return points + except TimeoutError as e: + logger.warning(f"API request timeout: {e}") + except ConnectionError as e: + logger.warning(f"API connection error: {e}") + except Exception as e: + logger.warning(f"API call failed: {e}") + + try: + points = await self._fetch_points_via_page_content() + if points is not None and points >= 0: + self._cached_points = points + return points + except Exception as e: + logger.debug(f"Page content parsing failed: {e}") + + return self._cached_points + + async def _fetch_points_via_api(self) -> int | None: + """ + Fetch points via internal API endpoint + + Returns: + Points balance or None + """ + try: + api_url = f"{self._base_url}{API_ENDPOINTS['dashboard_balance']}" + response = await self.page.evaluate( + f""" + async () => {{ + try {{ + const resp = await fetch('{api_url}', {{ + method: 'GET', + credentials: 'include' + }}); + if (!resp.ok) return null; + return await resp.json(); + }} catch {{ + return null; + }} + }} + """ + ) + + if response and isinstance(response, dict): + available = response.get("availablePoints") + balance = response.get("pointsBalance") + points = available if available is not None else balance + if points is not None: + try: + return int(points) + except (ValueError, TypeError): + pass + + except Exception as e: + logger.debug(f"API fetch error: {e}") + + return None + + async def _fetch_points_via_page_content(self) -> int | None: + """ + Extract points from page content as fallback + + Returns: + Points balance or None + """ + try: + content = await self.page.content() + + patterns = [ + r'"availablePoints"\s*:\s*(\d+)', + r'"pointsBalance"\s*:\s*(\d+)', + r'"totalPoints"\s*:\s*(\d+)', + ] + + for pattern in patterns: + match = re.search(pattern, content) + if match: + points = int(match.group(1)) + if 0 <= points <= 1000000: + return points + + except Exception as e: + logger.debug(f"Page content extraction error: {e}") + + return None + + async def get_dashboard_data(self) -> dict[str, Any] | None: + """ + Fetch full dashboard data + + Returns: + Dashboard data dict or None + """ + try: + api_url = f"{self._base_url}{API_ENDPOINTS['dashboard_data']}" + response = await self.page.evaluate( + f""" + async () => {{ + try {{ + const resp = await fetch('{api_url}', {{ + method: 'GET', + credentials: 'include' + }}); + if (!resp.ok) return null; + return await resp.json(); + }} catch {{ + return null; + }} + }} + """ + ) + + if response is not None and isinstance(response, dict): + return dict(response) + + except TimeoutError as e: + logger.warning(f"Dashboard API timeout: {e}") + except ConnectionError as e: + logger.warning(f"Dashboard API connection error: {e}") + except Exception as e: + logger.warning(f"Dashboard API error: {e}") + + return None diff --git a/src/constants/urls.py b/src/constants/urls.py index 4658cbed..9eb89689 100644 --- a/src/constants/urls.py +++ b/src/constants/urls.py @@ -37,6 +37,8 @@ API_ENDPOINTS = { "dashboard": "https://rewards.bing.com/api/getuserinfo", + "dashboard_balance": "/api/getuserbalance", + "dashboard_data": "/api/dashboard", "report_activity": "https://rewards.bing.com/api/reportactivity", "quiz": "https://www.bing.com/bingqa/ReportActivity", "app_dashboard": "https://prod.rewardsplatform.microsoft.com/dapi/me", diff --git a/src/review/models.py b/src/review/models.py index a385c049..0e1b3b88 100644 --- a/src/review/models.py +++ b/src/review/models.py @@ -118,8 +118,10 @@ class ReviewMetadata(BaseModel): pr_number: int owner: str repo: str + branch: str = Field(default="", description="拉取评论时的分支名称") + head_sha: str = Field(default="", description="拉取评论时的 HEAD commit SHA(前7位)") last_updated: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - version: str = "2.2" + version: str = "2.3" etag_comments: str | None = Field(None, description="GitHub ETag,用于条件请求") etag_reviews: str | None = Field(None, description="Reviews ETag") diff --git a/src/review/resolver.py b/src/review/resolver.py index c738f4a9..0c27e9b2 100644 --- a/src/review/resolver.py +++ b/src/review/resolver.py @@ -1,4 +1,5 @@ import logging +import subprocess from .comment_manager import ReviewManager from .graphql_client import GraphQLClient @@ -15,6 +16,34 @@ logger = logging.getLogger(__name__) +def get_git_branch() -> str: + """获取当前 git 分支名称""" + try: + result = subprocess.run( + ["git", "branch", "--show-current"], + capture_output=True, + text=True, + timeout=10, + ) + return result.stdout.strip() if result.returncode == 0 else "" + except Exception: + return "" + + +def get_git_head_sha() -> str: + """获取当前 HEAD commit SHA(前7位)""" + try: + result = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + capture_output=True, + text=True, + timeout=10, + ) + return result.stdout.strip() if result.returncode == 0 else "" + except Exception: + return "" + + class ReviewResolver: """ 评论解决器 - 整合所有组件 @@ -150,7 +179,16 @@ def fetch_threads(self, pr_number: int) -> dict: threads = self._inject_sourcery_types(threads) - metadata = ReviewMetadata(pr_number=pr_number, owner=self.owner, repo=self.repo) + branch = get_git_branch() + head_sha = get_git_head_sha() + + metadata = ReviewMetadata( + pr_number=pr_number, + owner=self.owner, + repo=self.repo, + branch=branch, + head_sha=head_sha, + ) self.manager.save_threads(threads, metadata) self.manager.save_overviews(overviews, metadata) diff --git a/src/search/query_engine.py b/src/search/query_engine.py index 199e6d30..d3f38d03 100644 --- a/src/search/query_engine.py +++ b/src/search/query_engine.py @@ -96,8 +96,11 @@ def _init_sources(self) -> None: if self.config.get("query_engine.sources.duckduckgo.enabled", True): try: ddg_source = DuckDuckGoSource(self.config) - self.sources.append(ddg_source) - self.logger.info("✓ DuckDuckGoSource enabled") + if ddg_source.is_available(): + self.sources.append(ddg_source) + self.logger.info("✓ DuckDuckGoSource enabled") + else: + self.logger.warning("DuckDuckGoSource not available") except Exception as e: self.logger.error(f"Failed to initialize DuckDuckGoSource: {e}") else: @@ -106,13 +109,33 @@ def _init_sources(self) -> None: if self.config.get("query_engine.sources.wikipedia.enabled", True): try: wiki_source = WikipediaSource(self.config) - self.sources.append(wiki_source) - self.logger.info("✓ WikipediaSource enabled") + if wiki_source.is_available(): + self.sources.append(wiki_source) + self.logger.info("✓ WikipediaSource enabled") + else: + self.logger.warning("WikipediaSource not available") except Exception as e: self.logger.error(f"Failed to initialize WikipediaSource: {e}") else: self.logger.info("WikipediaSource disabled in config") + if self.config.get("query_engine.sources.wikipedia_top_views.enabled", True): + try: + from .query_sources import WikipediaTopViewsSource + + wiki_top_views_source = WikipediaTopViewsSource(self.config) + if wiki_top_views_source.is_available(): + self.sources.append(wiki_top_views_source) + self.logger.info("✓ WikipediaTopViewsSource enabled") + else: + self.logger.warning("WikipediaTopViewsSource not available") + except ImportError as e: + self.logger.error(f"WikipediaTopViewsSource module not found: {e}") + except Exception as e: + self.logger.error(f"Failed to initialize WikipediaTopViewsSource: {e}") + else: + self.logger.info("WikipediaTopViewsSource disabled in config") + if self.config.get("query_engine.sources.bing_suggestions.enabled", True): try: bing_source = BingSuggestionsSource(self.config) @@ -126,6 +149,8 @@ def _init_sources(self) -> None: else: self.logger.info("BingSuggestionsSource disabled in config") + self.sources.sort(key=lambda s: s.get_priority()) + async def generate_queries(self, count: int, expand: bool = True) -> list[str]: """ Generate a list of unique search queries @@ -137,6 +162,8 @@ async def generate_queries(self, count: int, expand: bool = True) -> list[str]: Returns: List of unique query strings """ + self._query_sources.clear() + # Check cache first cache_key = f"queries_{count}_{expand}" cached = self.cache.get(cache_key) @@ -193,7 +220,7 @@ async def _fetch_from_sources(self, count: int) -> list[str]: source_name = self.sources[i].get_source_name() for query in result: normalized = query.lower().strip() - if normalized: + if normalized and normalized not in self._query_sources: self._query_sources[normalized] = source_name all_queries.extend(result) self.logger.debug(f"Source {source_name} returned {len(result)} queries") @@ -230,7 +257,11 @@ async def _expand_queries(self, queries: list[str]) -> list[str]: ]: normalized = suggestion.lower().strip() # Only add if not already in original queries or expanded list - if normalized and normalized not in existing_queries: + if ( + normalized + and normalized not in existing_queries + and normalized not in self._query_sources + ): self._query_sources[normalized] = "bing_suggestions" existing_queries.add(normalized) expanded.append(suggestion) diff --git a/src/search/query_sources/__init__.py b/src/search/query_sources/__init__.py index 088adfc5..f17f0ebc 100644 --- a/src/search/query_sources/__init__.py +++ b/src/search/query_sources/__init__.py @@ -5,5 +5,6 @@ from .bing_suggestions_source import BingSuggestionsSource from .local_file_source import LocalFileSource from .query_source import QuerySource +from .wikipedia_top_views_source import WikipediaTopViewsSource -__all__ = ["QuerySource", "LocalFileSource", "BingSuggestionsSource"] +__all__ = ["QuerySource", "LocalFileSource", "BingSuggestionsSource", "WikipediaTopViewsSource"] diff --git a/src/search/query_sources/bing_suggestions_source.py b/src/search/query_sources/bing_suggestions_source.py index 4e3e59fd..d92f2d45 100644 --- a/src/search/query_sources/bing_suggestions_source.py +++ b/src/search/query_sources/bing_suggestions_source.py @@ -102,6 +102,10 @@ def get_source_name(self) -> str: """Return the name of this source""" return "bing_suggestions" + def get_priority(self) -> int: + """Return priority (lower = higher priority)""" + return 70 + def is_available(self) -> bool: """Check if this source is available""" return self._available diff --git a/src/search/query_sources/duckduckgo_source.py b/src/search/query_sources/duckduckgo_source.py index 7d646378..6c66aaff 100644 --- a/src/search/query_sources/duckduckgo_source.py +++ b/src/search/query_sources/duckduckgo_source.py @@ -120,6 +120,10 @@ def get_source_name(self) -> str: """Return the name of this source""" return "duckduckgo" + def get_priority(self) -> int: + """Return priority (lower = higher priority)""" + return 50 + def is_available(self) -> bool: """Check if this source is available""" return self._available diff --git a/src/search/query_sources/local_file_source.py b/src/search/query_sources/local_file_source.py index 28f635c3..0003ec64 100644 --- a/src/search/query_sources/local_file_source.py +++ b/src/search/query_sources/local_file_source.py @@ -182,6 +182,10 @@ def get_source_name(self) -> str: """Return the name of this source""" return "local_file" + def get_priority(self) -> int: + """Return priority (lower = higher priority)""" + return 100 + def is_available(self) -> bool: """Check if this source is available""" return len(self.base_terms) > 0 diff --git a/src/search/query_sources/query_source.py b/src/search/query_sources/query_source.py index 1a388adf..f5acddd2 100644 --- a/src/search/query_sources/query_source.py +++ b/src/search/query_sources/query_source.py @@ -53,3 +53,15 @@ def is_available(self) -> bool: True if available, False otherwise """ pass + + def get_priority(self) -> int: + """ + Return priority (lower value = higher priority) + + Default priority is 100. Subclasses can override this method + to provide custom priority values. + + Returns: + Priority value + """ + return 100 diff --git a/src/search/query_sources/wikipedia_source.py b/src/search/query_sources/wikipedia_source.py index 6565fbbc..98e80748 100644 --- a/src/search/query_sources/wikipedia_source.py +++ b/src/search/query_sources/wikipedia_source.py @@ -171,6 +171,10 @@ def get_source_name(self) -> str: """Return the name of this source""" return "wikipedia" + def get_priority(self) -> int: + """Return priority (lower = higher priority)""" + return 60 + def is_available(self) -> bool: """Check if this source is available""" return self._available diff --git a/src/search/query_sources/wikipedia_top_views_source.py b/src/search/query_sources/wikipedia_top_views_source.py new file mode 100644 index 00000000..22cb7cde --- /dev/null +++ b/src/search/query_sources/wikipedia_top_views_source.py @@ -0,0 +1,211 @@ +""" +Wikipedia Top Views query source - fetches trending topics from Wikipedia Pageviews API +""" + +import time +from datetime import datetime, timedelta +from typing import Any + +import aiohttp + +from constants import QUERY_SOURCE_URLS + +from .query_source import QuerySource + + +class WikipediaTopViewsSource(QuerySource): + """Query source that fetches trending topics from Wikipedia Pageviews API""" + + EXCLUDED_PREFIXES = [ + "Main_Page", + "Special:", + "File:", + "Wikipedia:", + "Template:", + "Help:", + "Category:", + "Portal:", + "Talk:", + "User:", + "Draft:", + "List_of_", + ] + + def __init__(self, config): + """ + Initialize Wikipedia Top Views source + + Args: + config: ConfigManager instance + """ + super().__init__(config) + self.timeout = config.get("query_engine.sources.wikipedia_top_views.timeout", 30) + lang = config.get("query_engine.sources.wikipedia_top_views.lang", "en") + if not isinstance(lang, str) or not lang.isalpha() or len(lang) > 10: + self.logger.warning(f"Invalid lang '{lang}', using 'en'") + self.lang = "en" + else: + self.lang = lang + self.cache_ttl = config.get("query_engine.sources.wikipedia_top_views.ttl", 6 * 3600) + self._available: bool = True + self._session: aiohttp.ClientSession | None = None + + self._cache_data: list[str] | None = None + self._cache_time: float = 0 + self._cache_hits: int = 0 + self._cache_misses: int = 0 + + self.logger.info("WikipediaTopViewsSource initialized") + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create aiohttp session""" + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector( + limit=10, + limit_per_host=5, + ttl_dns_cache=300, + ) + self._session = aiohttp.ClientSession(connector=connector) + return self._session + + async def close(self) -> None: + """Close the aiohttp session""" + if self._session is not None and not self._session.closed: + await self._session.close() + self._session = None + + def get_source_name(self) -> str: + """Return the name of this source""" + return "wikipedia_top_views" + + def get_priority(self) -> int: + """Return priority (lower = higher priority)""" + return 120 + + def is_available(self) -> bool: + """Check if this source is available""" + return self._available + + def get_cache_stats(self) -> dict: + """ + Get cache statistics + + Returns: + Dictionary with cache stats + """ + total = self._cache_hits + self._cache_misses + return { + "hits": self._cache_hits, + "misses": self._cache_misses, + "hit_rate": self._cache_hits / total if total > 0 else 0, + } + + def _is_cache_valid(self) -> bool: + """Check if cache is valid""" + if self._cache_data is None: + return False + return bool(time.monotonic() - self._cache_time < self.cache_ttl) + + def _get_from_cache(self, count: int) -> list[str]: + """Get queries from cache""" + self._cache_hits += 1 + if self._cache_data is None: + return [] + return self._cache_data[:count] + + def _cache_articles(self, articles: list[str]) -> None: + """Cache articles""" + self._cache_data = articles + self._cache_time = time.monotonic() + + def _get_api_date(self) -> tuple[str, str, str]: + """Get yesterday's date for API call (UTC)""" + yesterday = datetime.utcnow() - timedelta(days=1) + return (str(yesterday.year), f"{yesterday.month:02d}", f"{yesterday.day:02d}") + + def _build_api_url(self) -> str: + """Build API URL using constants""" + base_url = QUERY_SOURCE_URLS["wikipedia_top_views"] + yyyy, mm, dd = self._get_api_date() + return f"{base_url}/{self.lang}.wikipedia/all-access/{yyyy}/{mm}/{dd}" + + async def _fetch_top_articles(self, session: aiohttp.ClientSession) -> list[dict[str, Any]]: + """ + Fetch top articles from Wikipedia API + + Args: + session: aiohttp session + + Returns: + List of article objects + """ + try: + url = self._build_api_url() + self.logger.debug(f"Fetching top articles from: {url}") + + async with session.get( + url, timeout=aiohttp.ClientTimeout(total=self.timeout) + ) as response: + if response.status == 200: + data = await response.json() + if "items" in data and data["items"]: + articles = data["items"][0].get("articles", []) + self._available = True + return list(articles) if articles else [] + else: + self.logger.warning(f"Wikipedia API returned status {response.status}") + except Exception as e: + self.logger.error(f"Error fetching top articles: {e}") + return [] + + def _filter_articles(self, articles: list[dict]) -> list[str]: + """ + Filter out non-article entries + + Args: + articles: List of article objects + + Returns: + List of filtered article titles + """ + filtered = [] + for article in articles: + title = article.get("article", "") + if not any(title.startswith(prefix) for prefix in self.EXCLUDED_PREFIXES): + filtered.append(title.replace("_", " ")) + return filtered + + async def fetch_queries(self, count: int) -> list[str]: + """ + Fetch queries from Wikipedia Pageviews API + + Args: + count: Number of queries to fetch + + Returns: + List of query strings + """ + if self._is_cache_valid(): + self.logger.debug("Cache hit for Wikipedia top views") + return self._get_from_cache(count) + + self._cache_misses += 1 + queries = [] + + try: + session = await self._get_session() + articles = await self._fetch_top_articles(session) + + filtered_articles = self._filter_articles(articles) + queries = filtered_articles[:count] + + if filtered_articles: + self._cache_articles(filtered_articles) + self.logger.debug(f"Cached {len(filtered_articles)} top articles") + + self.logger.debug(f"Fetched {len(queries)} queries from Wikipedia top views") + + except Exception as e: + self.logger.error(f"Failed to fetch queries from Wikipedia top views: {e}") + + return queries diff --git a/tests/unit/test_dashboard_client.py b/tests/unit/test_dashboard_client.py new file mode 100644 index 00000000..96dc63eb --- /dev/null +++ b/tests/unit/test_dashboard_client.py @@ -0,0 +1,139 @@ +""" +Tests for DashboardClient +""" + +from unittest.mock import AsyncMock + +import pytest + + +class TestDashboardClient: + """Test DashboardClient functionality""" + + @pytest.fixture + def mock_page(self): + """Create mock page""" + return AsyncMock() + + @pytest.fixture + def dashboard_client(self, mock_page): + """Create DashboardClient instance""" + from api.dashboard_client import DashboardClient + + return DashboardClient(mock_page) + + def test_initialization(self, dashboard_client, mock_page): + """Test DashboardClient initializes correctly""" + assert dashboard_client.page is mock_page + assert dashboard_client._cached_points is None + + @pytest.mark.asyncio + async def test_get_current_points_api_success(self, dashboard_client, mock_page): + """Test get_current_points returns points via API""" + mock_page.evaluate = AsyncMock(return_value={"availablePoints": 5000}) + + result = await dashboard_client.get_current_points() + assert result == 5000 + + @pytest.mark.asyncio + async def test_get_current_points_api_returns_points_balance(self, dashboard_client, mock_page): + """Test get_current_points uses pointsBalance field""" + mock_page.evaluate = AsyncMock(return_value={"pointsBalance": 3000}) + + result = await dashboard_client.get_current_points() + assert result == 3000 + + @pytest.mark.asyncio + async def test_get_current_points_api_failure_fallback(self, dashboard_client, mock_page): + """Test get_current_points falls back to page content on API failure""" + mock_page.evaluate = AsyncMock(side_effect=[None, '{"availablePoints": 2500}']) + mock_page.content = AsyncMock(return_value='{"availablePoints": 2500}') + + result = await dashboard_client.get_current_points() + assert result == 2500 + + @pytest.mark.asyncio + async def test_get_current_points_timeout_error(self, dashboard_client, mock_page): + """Test get_current_points handles timeout error""" + mock_page.evaluate = AsyncMock(side_effect=TimeoutError("Request timeout")) + mock_page.content = AsyncMock(return_value='{"availablePoints": 1000}') + + result = await dashboard_client.get_current_points() + assert result == 1000 + + @pytest.mark.asyncio + async def test_get_current_points_connection_error(self, dashboard_client, mock_page): + """Test get_current_points handles connection error""" + mock_page.evaluate = AsyncMock(side_effect=ConnectionError("Connection failed")) + mock_page.content = AsyncMock(return_value='{"availablePoints": 1000}') + + result = await dashboard_client.get_current_points() + assert result == 1000 + + @pytest.mark.asyncio + async def test_get_current_points_returns_cached_on_failure(self, dashboard_client, mock_page): + """Test get_current_points returns cached value on complete failure""" + dashboard_client._cached_points = 8000 + mock_page.evaluate = AsyncMock(side_effect=Exception("Error")) + mock_page.content = AsyncMock(side_effect=Exception("Error")) + + result = await dashboard_client.get_current_points() + assert result == 8000 + + @pytest.mark.asyncio + async def test_get_current_points_returns_none_on_no_data(self, dashboard_client, mock_page): + """Test get_current_points returns None when no data available""" + mock_page.evaluate = AsyncMock(return_value=None) + mock_page.content = AsyncMock(return_value="") + + result = await dashboard_client.get_current_points() + assert result is None + + @pytest.mark.asyncio + async def test_fetch_points_via_api_string_value(self, dashboard_client, mock_page): + """Test _fetch_points_via_api handles string points value""" + mock_page.evaluate = AsyncMock(return_value={"availablePoints": "7500"}) + + result = await dashboard_client._fetch_points_via_api() + assert result == 7500 + + @pytest.mark.asyncio + async def test_fetch_points_via_page_content(self, dashboard_client, mock_page): + """Test _fetch_points_via_page_content extracts points""" + mock_page.content = AsyncMock(return_value='{"pointsBalance": 6000}') + + result = await dashboard_client._fetch_points_via_page_content() + assert result == 6000 + + @pytest.mark.asyncio + async def test_fetch_points_via_page_content_invalid_range(self, dashboard_client, mock_page): + """Test _fetch_points_via_page_content rejects invalid range""" + mock_page.content = AsyncMock(return_value='{"availablePoints": 99999999}') + + result = await dashboard_client._fetch_points_via_page_content() + assert result is None + + @pytest.mark.asyncio + async def test_get_dashboard_data_success(self, dashboard_client, mock_page): + """Test get_dashboard_data returns data""" + mock_data = {"user": "test", "points": 1000} + mock_page.evaluate = AsyncMock(return_value=mock_data) + + result = await dashboard_client.get_dashboard_data() + assert result == mock_data + + @pytest.mark.asyncio + async def test_get_dashboard_data_timeout(self, dashboard_client, mock_page): + """Test get_dashboard_data handles timeout""" + mock_page.evaluate = AsyncMock(side_effect=TimeoutError("Timeout")) + + result = await dashboard_client.get_dashboard_data() + assert result is None + + @pytest.mark.asyncio + async def test_get_dashboard_data_connection_error(self, dashboard_client, mock_page): + """Test get_dashboard_data handles connection error""" + mock_page.evaluate = AsyncMock(side_effect=ConnectionError("Connection failed")) + + result = await dashboard_client.get_dashboard_data() + assert result is None diff --git a/tests/unit/test_manage_reviews_cli.py b/tests/unit/test_manage_reviews_cli.py new file mode 100644 index 00000000..522c9d1f --- /dev/null +++ b/tests/unit/test_manage_reviews_cli.py @@ -0,0 +1,439 @@ +""" +Tests for CLI commands in manage_reviews.py +""" + +import subprocess +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestVerifyContextLogic: + """Test verify-context command logic without CLI import""" + + def test_verify_context_no_data(self, tmp_path: Path): + """Test verify-context when no local data exists""" + + manager = MagicMock() + manager.get_metadata.return_value = None + + metadata = manager.get_metadata() + + if metadata is None: + output = { + "success": True, + "valid": None, + "message": "无本地评论数据", + } + else: + output = {"success": True, "valid": True} + + assert output["valid"] is None + assert "无本地评论数据" in output["message"] + + def test_verify_context_branch_match(self, tmp_path: Path): + """Test verify-context when branch matches""" + from src.review.models import ReviewMetadata + + metadata = ReviewMetadata( + pr_number=123, + owner="test-owner", + repo="test-repo", + branch="feature-test", + head_sha="abc1234", + ) + + current_branch = "feature-test" + + if not metadata.branch: + output = {"success": True, "valid": None, "warning": "旧版本数据"} + elif current_branch == metadata.branch: + output = { + "success": True, + "valid": True, + "current_branch": current_branch, + "stored_branch": metadata.branch, + "stored_pr": metadata.pr_number, + "message": "上下文验证通过", + } + else: + output = {"success": True, "valid": False} + + assert output["valid"] is True + assert output["current_branch"] == "feature-test" + + def test_verify_context_branch_mismatch(self, tmp_path: Path): + """Test verify-context when branch does not match""" + from src.review.models import ReviewMetadata + + metadata = ReviewMetadata( + pr_number=123, + owner="test-owner", + repo="test-repo", + branch="feature-old", + head_sha="abc1234", + ) + + current_branch = "feature-new" + + if not metadata.branch: + output = {"success": True, "valid": None} + elif current_branch == metadata.branch: + output = {"success": True, "valid": True} + else: + output = { + "success": True, + "valid": False, + "current_branch": current_branch, + "stored_branch": metadata.branch, + "stored_pr": metadata.pr_number, + "message": f"分支不匹配:当前分支 {current_branch},本地数据属于 {metadata.branch}", + } + + assert output["valid"] is False + assert output["current_branch"] == "feature-new" + assert output["stored_branch"] == "feature-old" + + def test_verify_context_old_version_data(self, tmp_path: Path): + """Test verify-context with old version data (no branch field)""" + from src.review.models import ReviewMetadata + + metadata = ReviewMetadata( + pr_number=123, + owner="test-owner", + repo="test-repo", + ) + + if not metadata.branch: + output = { + "success": True, + "valid": None, + "warning": "旧版本数据,缺少分支信息,跳过验证", + "stored_pr": metadata.pr_number, + } + else: + output = {"success": True, "valid": True} + + assert output["valid"] is None + assert "warning" in output + + +class TestFetchCommandLogic: + """Test fetch command logic without CLI import""" + + def test_fetch_auto_detect_pr_success(self): + """Test fetch command auto-detects PR number""" + pr_data = {"number": 456} + + pr_number = pr_data.get("number") + assert pr_number == 456 + + def test_fetch_auto_detect_pr_failure(self): + """Test fetch command handles auto-detect failure""" + result = MagicMock() + result.returncode = 1 + result.stdout = "" + + if result.returncode != 0: + output = { + "success": False, + "message": "无法自动获取 PR 编号,请使用 --pr 参数指定", + } + else: + output = {"success": True} + + assert output["success"] is False + + def test_fetch_with_explicit_pr(self): + """Test fetch command with explicit PR number""" + args_pr = 789 + assert args_pr == 789 + + def test_gh_cli_not_available(self): + """Test behavior when gh CLI is not available""" + result = MagicMock() + result.returncode = 127 + result.stderr = "'gh' is not recognized as an internal or external command" + + if result.returncode != 0: + output = { + "success": False, + "message": "无法自动获取 PR 编号,请使用 --pr 参数指定", + } + else: + output = {"success": True} + + assert output["success"] is False + + def test_gh_cli_not_authenticated(self): + """Test behavior when gh CLI is not authenticated""" + result = MagicMock() + result.returncode = 1 + result.stderr = "To get started with GitHub CLI, please run: gh auth login" + + if result.returncode != 0: + output = { + "success": False, + "message": "无法自动获取 PR 编号,请使用 --pr 参数指定", + } + else: + output = {"success": True} + + assert output["success"] is False + + def test_gh_cli_no_pr_on_branch(self): + """Test behavior when no PR exists for current branch""" + result = MagicMock() + result.returncode = 1 + result.stderr = "no pull requests found" + + if result.returncode != 0: + output = { + "success": False, + "message": "无法自动获取 PR 编号,请使用 --pr 参数指定", + } + else: + output = {"success": True} + + assert output["success"] is False + + def test_gh_cli_not_installed_file_not_found(self): + """Test behavior when gh CLI is not installed (FileNotFoundError)""" + + try: + raise FileNotFoundError("gh not found") + except FileNotFoundError: + output = { + "success": False, + "message": "gh 命令未安装,请安装 GitHub CLI 或手动指定 --pr 参数", + } + + assert output["success"] is False + assert "gh 命令未安装" in output["message"] + + def test_gh_cli_timeout(self): + """Test behavior when gh CLI times out""" + import subprocess + + try: + raise subprocess.TimeoutExpired("gh", 30) + except subprocess.TimeoutExpired: + output = { + "success": False, + "message": "gh 命令执行超时,请检查网络连接或手动指定 --pr 参数", + } + + assert output["success"] is False + assert "超时" in output["message"] + + def test_gh_cli_permission_error(self): + """Test behavior when gh CLI has permission issues""" + try: + raise PermissionError("Permission denied") + except PermissionError: + output = { + "success": False, + "message": "gh 命令权限不足,请检查 GitHub CLI 认证状态或手动指定 --pr 参数", + } + + assert output["success"] is False + assert "权限不足" in output["message"] + + +class TestPointsDetector: + """Test points detection functionality""" + + @pytest.mark.asyncio + async def test_api_success_returns_points(self): + """Test API success directly returns points""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.url = "https://rewards.bing.com/" + mock_page.wait_for_timeout = AsyncMock() + + mock_client = AsyncMock() + mock_client.get_current_points.return_value = 5000 + + with patch("src.account.points_detector.DashboardClient", return_value=mock_client): + result = await detector.get_current_points(mock_page, skip_navigation=True) + assert result == 5000 + + @pytest.mark.asyncio + async def test_api_failure_fallback_to_html(self): + """Test API failure falls back to HTML parsing""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.url = "https://rewards.bing.com/" + mock_page.wait_for_timeout = AsyncMock() + mock_page.content = AsyncMock(return_value='{"availablePoints": 3000}') + mock_page.query_selector_all = AsyncMock(return_value=[]) + mock_page.evaluate = AsyncMock(return_value="Available Points: 3000") + + mock_client = AsyncMock() + mock_client.get_current_points.side_effect = Exception("API error") + + with patch("src.account.points_detector.DashboardClient", return_value=mock_client): + result = await detector.get_current_points(mock_page, skip_navigation=True) + assert result == 3000 + + @pytest.mark.asyncio + async def test_empty_value_handling(self): + """Test handling of empty/null values""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.url = "https://rewards.bing.com/" + mock_page.wait_for_timeout = AsyncMock() + mock_page.content = AsyncMock(return_value="") + mock_page.query_selector_all = AsyncMock(return_value=[]) + mock_page.evaluate = AsyncMock(return_value="") + + mock_client = AsyncMock() + mock_client.get_current_points.return_value = None + + with patch("src.account.points_detector.DashboardClient", return_value=mock_client): + result = await detector.get_current_points(mock_page, skip_navigation=True) + assert result is None + + @pytest.mark.asyncio + async def test_api_timeout_fallback(self): + """Test API timeout falls back to HTML parsing""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.url = "https://rewards.bing.com/" + mock_page.wait_for_timeout = AsyncMock() + mock_page.content = AsyncMock(return_value='{"availablePoints": 2000}') + mock_page.query_selector_all = AsyncMock(return_value=[]) + mock_page.evaluate = AsyncMock(return_value="Points: 2000") + + mock_client = AsyncMock() + mock_client.get_current_points.side_effect = TimeoutError("API timeout") + + with patch("src.account.points_detector.DashboardClient", return_value=mock_client): + result = await detector.get_current_points(mock_page, skip_navigation=True) + assert result == 2000 + + @pytest.mark.asyncio + async def test_api_connection_error_fallback(self): + """Test API connection error falls back to HTML parsing""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + mock_page = AsyncMock() + mock_page.goto = AsyncMock() + mock_page.url = "https://rewards.bing.com/" + mock_page.wait_for_timeout = AsyncMock() + mock_page.content = AsyncMock(return_value='{"pointsBalance": 1500}') + mock_page.query_selector_all = AsyncMock(return_value=[]) + mock_page.evaluate = AsyncMock(return_value="Balance: 1500") + + mock_client = AsyncMock() + mock_client.get_current_points.side_effect = ConnectionError("Connection failed") + + with patch("src.account.points_detector.DashboardClient", return_value=mock_client): + result = await detector.get_current_points(mock_page, skip_navigation=True) + assert result == 1500 + + +class TestParsePoints: + """Test _parse_points method""" + + def test_parse_points_valid_text(self): + """Test parsing valid points text""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + assert detector._parse_points("1,234 points") == 1234 + assert detector._parse_points("Available: 5000") == 5000 + assert detector._parse_points("12345") == 12345 + + def test_parse_points_empty_string(self): + """Test parsing empty string returns None""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + assert detector._parse_points("") is None + + def test_parse_points_none(self): + """Test parsing None returns None""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + result = detector._parse_points(None) # type: ignore + assert result is None + + def test_parse_points_whitespace_only(self): + """Test parsing whitespace-only string returns None""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + assert detector._parse_points(" ") is None + assert detector._parse_points("\t\n") is None + + def test_parse_points_no_numbers(self): + """Test parsing text without numbers returns None""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + assert detector._parse_points("no numbers here") is None + + def test_parse_points_out_of_range(self): + """Test parsing out of range values returns None""" + from src.account.points_detector import PointsDetector + + detector = PointsDetector() + assert detector._parse_points("99999999") is None + + +class TestNonGitEnvironment: + """Test non-Git environment error handling""" + + def test_get_git_branch_not_git_repo(self): + """Test get_git_branch in non-git directory""" + from src.review.resolver import get_git_branch + + with patch.object(subprocess, "run") as mock_run: + mock_run.return_value = MagicMock( + returncode=128, stdout="", stderr="not a git repository" + ) + result = get_git_branch() + assert result == "" + + def test_get_git_branch_git_not_installed(self): + """Test get_git_branch when git is not installed""" + from src.review.resolver import get_git_branch + + with patch.object(subprocess, "run", side_effect=FileNotFoundError("git not found")): + result = get_git_branch() + assert result == "" + + def test_get_git_head_sha_not_git_repo(self): + """Test get_git_head_sha in non-git directory""" + from src.review.resolver import get_git_head_sha + + with patch.object(subprocess, "run") as mock_run: + mock_run.return_value = MagicMock( + returncode=128, stdout="", stderr="not a git repository" + ) + result = get_git_head_sha() + assert result == "" + + def test_get_git_head_sha_git_not_installed(self): + """Test get_git_head_sha when git is not installed""" + from src.review.resolver import get_git_head_sha + + with patch.object(subprocess, "run", side_effect=FileNotFoundError("git not found")): + result = get_git_head_sha() + assert result == "" diff --git a/tests/unit/test_online_query_sources.py b/tests/unit/test_online_query_sources.py index 4c8f6061..b21863bf 100644 --- a/tests/unit/test_online_query_sources.py +++ b/tests/unit/test_online_query_sources.py @@ -366,3 +366,296 @@ def mock_get(key, default=None): config.get = MagicMock(side_effect=mock_get) source = WikipediaSource(config) assert source.timeout == 15 + + +class TestWikipediaTopViewsSource: + """Test Wikipedia Top Views query source""" + + @pytest.fixture + def mock_config(self): + """Create mock config""" + config = MagicMock() + config.get = MagicMock(return_value=30) + return config + + @pytest.fixture + def wikipedia_top_views_source(self, mock_config): + """Create Wikipedia Top Views source""" + from search.query_sources.wikipedia_top_views_source import WikipediaTopViewsSource + + return WikipediaTopViewsSource(mock_config) + + def test_source_initialization(self, wikipedia_top_views_source): + """Test source initialization""" + assert wikipedia_top_views_source is not None + assert wikipedia_top_views_source.get_source_name() == "wikipedia_top_views" + assert wikipedia_top_views_source.is_available() is True + + def test_get_priority(self, wikipedia_top_views_source): + """Test get_priority returns 120""" + assert wikipedia_top_views_source.get_priority() == 120 + + def test_excluded_prefixes_exist(self, wikipedia_top_views_source): + """Test that excluded prefixes are defined""" + assert len(wikipedia_top_views_source.EXCLUDED_PREFIXES) > 0 + assert "Main_Page" in wikipedia_top_views_source.EXCLUDED_PREFIXES + assert "Special:" in wikipedia_top_views_source.EXCLUDED_PREFIXES + + def test_cache_stats_initial(self, wikipedia_top_views_source): + """Test initial cache stats""" + stats = wikipedia_top_views_source.get_cache_stats() + assert stats["hits"] == 0 + assert stats["misses"] == 0 + assert stats["hit_rate"] == 0 + + @pytest.mark.asyncio + async def test_fetch_queries_returns_list(self, wikipedia_top_views_source): + """Test that fetch_queries returns a list with mocked HTTP""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "items": [ + { + "articles": [ + {"article": "Python_(programming_language)"}, + {"article": "Main_Page"}, + {"article": "Artificial_intelligence"}, + ] + } + ] + } + ) + mock_context = AsyncMock() + mock_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_context.__aexit__ = AsyncMock(return_value=None) + mock_session.get = MagicMock(return_value=mock_context) + mock_session.closed = False + + wikipedia_top_views_source._session = mock_session + + queries = await wikipedia_top_views_source.fetch_queries(5) + assert isinstance(queries, list) + + @pytest.mark.asyncio + async def test_fetch_queries_filters_excluded_articles(self, wikipedia_top_views_source): + """Test that fetch_queries filters out excluded articles""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "items": [ + { + "articles": [ + {"article": "Main_Page"}, + {"article": "Special:Search"}, + {"article": "Valid_Article"}, + ] + } + ] + } + ) + mock_context = AsyncMock() + mock_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_context.__aexit__ = AsyncMock(return_value=None) + mock_session.get = MagicMock(return_value=mock_context) + mock_session.closed = False + + wikipedia_top_views_source._session = mock_session + + queries = await wikipedia_top_views_source.fetch_queries(10) + assert "Main Page" not in queries + assert "Special:Search" not in queries + assert "Valid Article" in queries + + @pytest.mark.asyncio + async def test_fetch_queries_handles_error(self, wikipedia_top_views_source): + """Test that fetch_queries handles HTTP errors gracefully""" + mock_session = AsyncMock() + mock_session.get = AsyncMock(side_effect=Exception("Network error")) + mock_session.closed = False + + wikipedia_top_views_source._session = mock_session + + queries = await wikipedia_top_views_source.fetch_queries(5) + assert isinstance(queries, list) + assert len(queries) == 0 + + @pytest.mark.asyncio + async def test_fetch_queries_handles_non_200_status(self, wikipedia_top_views_source): + """Test that fetch_queries handles non-200 status codes""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 500 + mock_context = AsyncMock() + mock_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_context.__aexit__ = AsyncMock(return_value=None) + mock_session.get = MagicMock(return_value=mock_context) + mock_session.closed = False + + wikipedia_top_views_source._session = mock_session + + queries = await wikipedia_top_views_source.fetch_queries(5) + assert isinstance(queries, list) + + @pytest.mark.asyncio + async def test_cache_hit(self, wikipedia_top_views_source): + """Test that cache is used on second call""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "items": [ + { + "articles": [ + {"article": "Test_Article"}, + ] + } + ] + } + ) + mock_context = AsyncMock() + mock_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_context.__aexit__ = AsyncMock(return_value=None) + mock_session.get = MagicMock(return_value=mock_context) + mock_session.closed = False + + wikipedia_top_views_source._session = mock_session + + await wikipedia_top_views_source.fetch_queries(5) + stats_after_first = wikipedia_top_views_source.get_cache_stats() + assert stats_after_first["misses"] == 1 + + await wikipedia_top_views_source.fetch_queries(5) + stats_after_second = wikipedia_top_views_source.get_cache_stats() + assert stats_after_second["hits"] == 1 + + @pytest.mark.asyncio + async def test_close_session(self, wikipedia_top_views_source): + """Test that close() properly closes the session""" + mock_session = AsyncMock() + mock_session.closed = False + wikipedia_top_views_source._session = mock_session + + await wikipedia_top_views_source.close() + mock_session.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_close_handles_none_session(self, wikipedia_top_views_source): + """Test that close() handles None session gracefully""" + wikipedia_top_views_source._session = None + await wikipedia_top_views_source.close() + + +class TestSourceAvailabilitySkip: + """Test that sources with is_available()=False are skipped""" + + @pytest.fixture + def mock_config(self): + """Create mock config""" + config = MagicMock() + config.get = MagicMock(return_value=15) + return config + + def test_duckduckgo_unavailable_not_added(self, mock_config): + """Test DuckDuckGo source not added when unavailable""" + source = DuckDuckGoSource(mock_config) + source._available = False + + assert source.is_available() is False + + def test_wikipedia_unavailable_not_added(self, mock_config): + """Test Wikipedia source not added when unavailable""" + source = WikipediaSource(mock_config) + source._available = False + + assert source.is_available() is False + + def test_wikipedia_top_views_unavailable_not_added(self, mock_config): + """Test WikipediaTopViewsSource not added when unavailable""" + from search.query_sources.wikipedia_top_views_source import WikipediaTopViewsSource + + source = WikipediaTopViewsSource(mock_config) + source._available = False + + assert source.is_available() is False + + @pytest.mark.asyncio + async def test_source_returns_empty_when_unavailable(self, mock_config): + """Test that fetch_queries returns empty list when unavailable""" + source = DuckDuckGoSource(mock_config) + source._available = False + + mock_session = AsyncMock() + mock_session.closed = False + source._session = mock_session + + result = await source.fetch_queries(5) + assert result == [] + + +class TestQuerySourcePriority: + """Test query source priority ordering""" + + @pytest.fixture + def mock_config(self): + """Create mock config""" + config = MagicMock() + config.get = MagicMock(return_value=15) + return config + + def test_local_file_source_priority(self, mock_config): + """Test LocalFileSource priority is 100""" + from search.query_sources.local_file_source import LocalFileSource + + source = LocalFileSource(mock_config) + assert source.get_priority() == 100 + + def test_wikipedia_top_views_priority(self, mock_config): + """Test WikipediaTopViewsSource priority is 120""" + from search.query_sources.wikipedia_top_views_source import WikipediaTopViewsSource + + source = WikipediaTopViewsSource(mock_config) + assert source.get_priority() == 120 + + def test_duckduckgo_priority(self, mock_config): + """Test DuckDuckGoSource priority is 50""" + source = DuckDuckGoSource(mock_config) + assert source.get_priority() == 50 + + def test_wikipedia_priority(self, mock_config): + """Test WikipediaSource priority is 60""" + source = WikipediaSource(mock_config) + assert source.get_priority() == 60 + + def test_bing_suggestions_priority(self, mock_config): + """Test BingSuggestionsSource priority is 70""" + from search.query_sources.bing_suggestions_source import BingSuggestionsSource + + source = BingSuggestionsSource(mock_config) + assert source.get_priority() == 70 + + def test_sources_sort_by_priority(self, mock_config): + """Test that sources can be sorted by priority""" + from search.query_sources.bing_suggestions_source import BingSuggestionsSource + from search.query_sources.local_file_source import LocalFileSource + from search.query_sources.wikipedia_top_views_source import WikipediaTopViewsSource + + sources = [ + BingSuggestionsSource(mock_config), + WikipediaSource(mock_config), + LocalFileSource(mock_config), + DuckDuckGoSource(mock_config), + WikipediaTopViewsSource(mock_config), + ] + + sources.sort(key=lambda s: s.get_priority()) + + assert sources[0].get_source_name() == "duckduckgo" + assert sources[1].get_source_name() == "wikipedia" + assert sources[2].get_source_name() == "bing_suggestions" + assert sources[3].get_source_name() == "local_file" + assert sources[4].get_source_name() == "wikipedia_top_views" diff --git a/tests/unit/test_review_context.py b/tests/unit/test_review_context.py new file mode 100644 index 00000000..b41b0830 --- /dev/null +++ b/tests/unit/test_review_context.py @@ -0,0 +1,82 @@ +import subprocess +from unittest.mock import MagicMock, patch + +from review.models import ReviewMetadata +from review.resolver import get_git_branch, get_git_head_sha + + +class TestGetGitBranch: + """测试 get_git_branch 函数""" + + def test_get_branch_success(self): + """测试成功获取分支名称""" + with patch.object(subprocess, "run") as mock_run: + mock_run.return_value = MagicMock( + returncode=0, stdout="feature-test-branch\n", stderr="" + ) + result = get_git_branch() + assert result == "feature-test-branch" + + def test_get_branch_failure(self): + """测试获取分支失败""" + with patch.object(subprocess, "run") as mock_run: + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="error") + result = get_git_branch() + assert result == "" + + def test_get_branch_exception(self): + """测试获取分支异常""" + with patch.object(subprocess, "run", side_effect=Exception("test error")): + result = get_git_branch() + assert result == "" + + +class TestGetGitHeadSha: + """测试 get_git_head_sha 函数""" + + def test_get_sha_success(self): + """测试成功获取 SHA""" + with patch.object(subprocess, "run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="abc1234\n", stderr="") + result = get_git_head_sha() + assert result == "abc1234" + + def test_get_sha_failure(self): + """测试获取 SHA 失败""" + with patch.object(subprocess, "run") as mock_run: + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="error") + result = get_git_head_sha() + assert result == "" + + def test_get_sha_exception(self): + """测试获取 SHA 异常""" + with patch.object(subprocess, "run", side_effect=Exception("test error")): + result = get_git_head_sha() + assert result == "" + + +class TestReviewMetadataBranch: + """测试 ReviewMetadata 分支字段""" + + def test_metadata_with_branch(self): + """测试带分支信息的 metadata""" + metadata = ReviewMetadata( + pr_number=123, + owner="test-owner", + repo="test-repo", + branch="feature-test", + head_sha="abc1234", + ) + assert metadata.branch == "feature-test" + assert metadata.head_sha == "abc1234" + assert metadata.version == "2.3" + + def test_metadata_without_branch(self): + """测试不带分支信息的 metadata(兼容旧版本)""" + metadata = ReviewMetadata( + pr_number=123, + owner="test-owner", + repo="test-repo", + ) + assert metadata.branch == "" + assert metadata.head_sha == "" diff --git a/tests/unit/test_review_parsers.py b/tests/unit/test_review_parsers.py index e81cfa11..50425423 100644 --- a/tests/unit/test_review_parsers.py +++ b/tests/unit/test_review_parsers.py @@ -205,7 +205,7 @@ def test_create_metadata(self): assert metadata.pr_number == 123 assert metadata.owner == "test-owner" assert metadata.repo == "test-repo" - assert metadata.version == "2.2" + assert metadata.version == "2.3" assert metadata.etag_comments is None diff --git a/tools/manage_reviews.py b/tools/manage_reviews.py index b177e8e2..a0a99d18 100644 --- a/tools/manage_reviews.py +++ b/tools/manage_reviews.py @@ -3,11 +3,12 @@ AI 审查评论管理工具 CLI 用法: - python tools/manage_reviews.py fetch --owner OWNER --repo REPO --pr PR_NUMBER + python tools/manage_reviews.py fetch --owner OWNER --repo REPO [--pr PR_NUMBER] python tools/manage_reviews.py resolve --thread-id THREAD_ID --type RESOLUTION_TYPE [--reply "回复内容"] python tools/manage_reviews.py list [--status STATUS] [--source SOURCE] [--format FORMAT] python tools/manage_reviews.py overviews python tools/manage_reviews.py stats + python tools/manage_reviews.py verify-context 环境变量: GITHUB_TOKEN: GitHub Personal Access Token (也可通过 .env 文件配置) @@ -15,6 +16,7 @@ import argparse import json +import logging import sys from pathlib import Path @@ -57,7 +59,8 @@ def get_token() -> str: { "success": False, "message": "错误: 未设置 GITHUB_TOKEN 环境变量,请在 .env 文件中配置", - } + }, + ensure_ascii=False, ) ) sys.exit(1) @@ -164,10 +167,90 @@ def print_threads_table(threads: list[ReviewThreadState], title: str = "审查 def cmd_fetch(args: argparse.Namespace) -> None: """执行 fetch 子命令""" + import subprocess + + logger = logging.getLogger(__name__) db_path = get_db_path() + + pr_number = args.pr + if pr_number is None: + try: + result = subprocess.run( + ["gh", "pr", "view", "--json", "number", "-R", f"{args.owner}/{args.repo}"], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode == 0: + import json as json_module + + pr_data = json_module.loads(result.stdout) + pr_number = pr_data.get("number") + else: + logger.error(f"gh pr view 命令失败: {result.stderr}") + print( + json.dumps( + { + "success": False, + "message": "无法自动获取 PR 编号,请使用 --pr 参数指定", + }, + ensure_ascii=False, + ) + ) + return + except FileNotFoundError: + logger.error("gh 命令未安装") + print( + json.dumps( + { + "success": False, + "message": "gh 命令未安装,请安装 GitHub CLI 或手动指定 --pr 参数", + }, + ensure_ascii=False, + ) + ) + return + except subprocess.TimeoutExpired: + logger.error("gh 命令执行超时") + print( + json.dumps( + { + "success": False, + "message": "gh 命令执行超时,请检查网络连接或手动指定 --pr 参数", + }, + ensure_ascii=False, + ) + ) + return + except PermissionError: + logger.error("gh 命令权限不足") + print( + json.dumps( + { + "success": False, + "message": "gh 命令权限不足,请检查 GitHub CLI 认证状态或手动指定 --pr 参数", + }, + ensure_ascii=False, + ) + ) + return + except Exception as e: + logger.error(f"获取 PR 编号失败: {type(e).__name__}") + print( + json.dumps( + {"success": False, "message": "获取 PR 编号失败,请使用 --pr 参数指定"}, + ensure_ascii=False, + ) + ) + return + + if pr_number is None: + print(json.dumps({"success": False, "message": "未指定 PR 编号"}, ensure_ascii=False)) + return + resolver = ReviewResolver(token=get_token(), owner=args.owner, repo=args.repo, db_path=db_path) - result = resolver.fetch_threads(args.pr) + result = resolver.fetch_threads(pr_number) print(json.dumps(result, indent=2, ensure_ascii=False)) @@ -354,7 +437,91 @@ def cmd_stats(args: argparse.Namespace) -> None: print(json.dumps(result, indent=2, ensure_ascii=False)) +def cmd_verify_context(args: argparse.Namespace) -> None: + """执行 verify-context 子命令 - 验证本地评论是否属于当前分支""" + import logging + import subprocess + + logger = logging.getLogger(__name__) + db_path = get_db_path() + manager = ReviewManager(db_path) + git_error = None + + try: + result = subprocess.run( + ["git", "branch", "--show-current"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + current_branch = result.stdout.strip() + else: + current_branch = "" + git_error = f"git 命令执行失败: {result.stderr}" + logger.warning(git_error) + except FileNotFoundError: + current_branch = "" + git_error = "git 命令未找到,请确保已安装 Git" + logger.warning(git_error) + except subprocess.TimeoutExpired: + current_branch = "" + git_error = "git 命令执行超时" + logger.warning(git_error) + except Exception as e: + current_branch = "" + git_error = f"获取分支信息异常: {e}" + logger.warning(git_error) + + metadata = manager.get_metadata() + + if metadata is None: + output = { + "success": True, + "valid": None, + "message": "无本地评论数据", + } + elif not metadata.branch: + output = { + "success": True, + "valid": None, + "warning": "旧版本数据,缺少分支信息,跳过验证", + "stored_pr": metadata.pr_number, + } + if git_error: + output["git_warning"] = git_error + elif current_branch == metadata.branch: + output = { + "success": True, + "valid": True, + "current_branch": current_branch, + "stored_branch": metadata.branch, + "stored_pr": metadata.pr_number, + "message": "上下文验证通过", + } + if git_error: + output["git_warning"] = git_error + else: + output = { + "success": True, + "valid": False, + "current_branch": current_branch, + "stored_branch": metadata.branch, + "stored_pr": metadata.pr_number, + "message": f"分支不匹配:当前分支 {current_branch},本地数据属于 {metadata.branch} (PR #{metadata.pr_number})", + "action": f"请执行 'python tools/manage_reviews.py fetch --owner {metadata.owner} --repo {metadata.repo} --pr <当前分支的PR号>' 重新拉取评论", + } + if git_error: + output["git_warning"] = git_error + + print(json.dumps(output, indent=2, ensure_ascii=False)) + + def main() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) parser = argparse.ArgumentParser( description="AI 审查评论管理工具", formatter_class=argparse.RawDescriptionHelpFormatter ) @@ -364,7 +531,9 @@ def main() -> None: parser_fetch = subparsers.add_parser("fetch", help="获取 PR 的评论线程") parser_fetch.add_argument("--owner", required=True, help="仓库所有者") parser_fetch.add_argument("--repo", required=True, help="仓库名称") - parser_fetch.add_argument("--pr", type=int, required=True, help="PR 编号") + parser_fetch.add_argument( + "--pr", type=int, required=False, help="PR 编号(可选,不指定时自动获取当前分支的 PR)" + ) parser_fetch.set_defaults(func=cmd_fetch) parser_resolve = subparsers.add_parser("resolve", help="解决评论线程") @@ -407,6 +576,9 @@ def main() -> None: ) parser_stats.set_defaults(func=cmd_stats) + parser_verify = subparsers.add_parser("verify-context", help="验证本地评论是否属于当前分支") + parser_verify.set_defaults(func=cmd_verify_context) + args = parser.parse_args() if args.command is None: @@ -416,10 +588,15 @@ def main() -> None: try: args.func(args) except KeyboardInterrupt: - print(json.dumps({"success": False, "message": "操作已取消"})) + print(json.dumps({"success": False, "message": "操作已取消"}, ensure_ascii=False)) sys.exit(130) except Exception: - print(json.dumps({"success": False, "message": "操作失败,请检查日志获取详细信息"})) + print( + json.dumps( + {"success": False, "message": "操作失败,请检查日志获取详细信息"}, + ensure_ascii=False, + ) + ) sys.exit(1)