From 3e5e50102e77b11d00eeacd8d0ff40aa6ac2a109 Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:09:01 -0800 Subject: [PATCH 01/13] Add triage-aware PR comment filtering Fetch triage entries from Socket API after scan submission, remove alerts with ignore/monitor state from results, regenerate connector notifications with filtered components, and inject a triage count summary into GitHub PR comments. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 195 +++++++++++++++++ socket_basics/socket_basics.py | 216 ++++++++++++++++++- tests/__init__.py | 0 tests/test_triage.py | 384 +++++++++++++++++++++++++++++++++ 4 files changed, 794 insertions(+), 1 deletion(-) create mode 100644 socket_basics/core/triage.py create mode 100644 tests/__init__.py create mode 100644 tests/test_triage.py diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py new file mode 100644 index 0000000..0738f95 --- /dev/null +++ b/socket_basics/core/triage.py @@ -0,0 +1,195 @@ +"""Triage filtering for Socket Security Basics. + +Fetches triage entries from the Socket API and filters scan components +whose alerts have been triaged (state: ignore or monitor). +""" + +import fnmatch +import logging +from typing import Any, Dict, List, Tuple + +logger = logging.getLogger(__name__) + +# Triage states that cause a finding to be removed from reports +_SUPPRESSED_STATES = {"ignore", "monitor"} + + +def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: + """Fetch all triage alert entries from the Socket API, handling pagination. + + Args: + sdk: Initialized socketdev SDK instance. + org_slug: Organization slug for the API call. + + Returns: + List of triage entry dicts. 
+ """ + all_entries: List[Dict[str, Any]] = [] + page = 1 + per_page = 100 + + while True: + try: + response = sdk.triage.list_alert_triage( + org_slug, + {"per_page": per_page, "page": page}, + ) + except Exception: + logger.exception("Failed to fetch triage data (page %d)", page) + break + + if not isinstance(response, dict): + logger.warning("Unexpected triage API response type: %s", type(response)) + break + + results = response.get("results") or [] + all_entries.extend(results) + + next_page = response.get("nextPage") + if next_page is None: + break + page = int(next_page) + + logger.debug("Fetched %d triage entries for org %s", len(all_entries), org_slug) + return all_entries + + +class TriageFilter: + """Matches local scan alerts against triage entries and filters them out.""" + + def __init__(self, triage_entries: List[Dict[str, Any]]) -> None: + # Only keep entries whose state suppresses findings + self.entries = [ + e for e in triage_entries + if (e.get("state") or "").lower() in _SUPPRESSED_STATES + ] + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def is_alert_triaged(self, component: Dict[str, Any], alert: Dict[str, Any]) -> bool: + """Return True if the alert on the given component matches a suppressed triage entry.""" + alert_keys = self._extract_alert_keys(alert) + if not alert_keys: + return False + + for entry in self.entries: + entry_key = entry.get("alert_key") + if not entry_key: + continue + + if entry_key not in alert_keys: + continue + + # alert_key matched; now check package scope + if self._is_broad_match(entry): + return True + + if self._package_matches(entry, component): + return True + + return False + + def filter_components( + self, components: List[Dict[str, Any]] + ) -> Tuple[List[Dict[str, Any]], int]: + """Remove triaged alerts from components. 
+ + Returns: + (filtered_components, triaged_count) where triaged_count is the + total number of individual alerts removed. + """ + if not self.entries: + return components, 0 + + filtered: List[Dict[str, Any]] = [] + triaged_count = 0 + + for comp in components: + remaining_alerts: List[Dict[str, Any]] = [] + for alert in comp.get("alerts", []): + if self.is_alert_triaged(comp, alert): + triaged_count += 1 + else: + remaining_alerts.append(alert) + + if remaining_alerts: + new_comp = dict(comp) + new_comp["alerts"] = remaining_alerts + filtered.append(new_comp) + + return filtered, triaged_count + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_alert_keys(alert: Dict[str, Any]) -> set: + """Build the set of candidate keys that could match a triage entry's alert_key.""" + keys: set = set() + props = alert.get("props") or {} + + for field in ( + alert.get("title"), + alert.get("type"), + props.get("ruleId"), + props.get("detectorName"), + props.get("vulnerabilityId"), + props.get("cveId"), + ): + if field: + keys.add(str(field)) + + return keys + + @staticmethod + def _is_broad_match(entry: Dict[str, Any]) -> bool: + """Return True when the triage entry has no package scope (applies globally).""" + return ( + entry.get("package_name") is None + and entry.get("package_type") is None + and entry.get("package_version") is None + and entry.get("package_namespace") is None + ) + + @staticmethod + def _version_matches(entry_version: str, component_version: str) -> bool: + """Check version match, supporting wildcard suffix patterns like '1.2.*'.""" + if not entry_version or entry_version == "*": + return True + if not component_version: + return False + # fnmatch handles '*' and '?' 
glob patterns + return fnmatch.fnmatch(component_version, entry_version) + + @classmethod + def _package_matches(cls, entry: Dict[str, Any], component: Dict[str, Any]) -> bool: + """Return True if the triage entry's package scope matches the component.""" + qualifiers = component.get("qualifiers") or {} + comp_name = component.get("name") or "" + comp_type = ( + qualifiers.get("ecosystem") + or qualifiers.get("type") + or component.get("type") + or "" + ) + comp_version = component.get("version") or qualifiers.get("version") or "" + comp_namespace = qualifiers.get("namespace") or "" + + entry_name = entry.get("package_name") + entry_type = entry.get("package_type") + entry_version = entry.get("package_version") + entry_namespace = entry.get("package_namespace") + + if entry_name is not None and entry_name != comp_name: + return False + if entry_type is not None and entry_type.lower() != comp_type.lower(): + return False + if entry_namespace is not None and entry_namespace != comp_namespace: + return False + if entry_version is not None and not cls._version_matches(entry_version, comp_version): + return False + + return True diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index a7f7f04..6311611 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -17,7 +17,7 @@ import sys import os from pathlib import Path -from typing import Dict, Any, Optional +from typing import Dict, Any, List, Optional import hashlib try: # Python 3.11+ @@ -378,6 +378,214 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) return results + def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: + """Filter out triaged alerts and regenerate notifications. 
+ + Fetches triage entries from the Socket API, removes alerts with + state ``ignore`` or ``monitor``, regenerates connector notifications + for the remaining components, and injects a triage summary line into + github_pr notification content. + + Args: + results: Current scan results dict (components + notifications). + + Returns: + Updated results dict with triaged findings removed. + """ + socket_api_key = self.config.get('socket_api_key') + socket_org = self.config.get('socket_org') + + if not socket_api_key or not socket_org: + logger.debug("Skipping triage filter: missing socket_api_key or socket_org") + return results + + # Import SDK and triage helpers + try: + from socketdev import socketdev + except ImportError: + logger.debug("socketdev SDK not available; skipping triage filter") + return results + + try: + from .core.triage import TriageFilter, fetch_triage_data + except ImportError: + from socket_basics.core.triage import TriageFilter, fetch_triage_data + + sdk = socketdev(token=socket_api_key, timeout=100) + triage_entries = fetch_triage_data(sdk, socket_org) + + if not triage_entries: + logger.debug("No triage entries found; skipping filter") + return results + + triage_filter = TriageFilter(triage_entries) + filtered_components, triaged_count = triage_filter.filter_components( + results.get('components', []) + ) + + if triaged_count == 0: + logger.debug("No findings matched triage entries") + return results + + logger.info("Filtered %d triaged finding(s) from results", triaged_count) + results['components'] = filtered_components + results['triaged_count'] = triaged_count + + # Regenerate notifications from the filtered components + self._regenerate_notifications(results, filtered_components, triaged_count) + + return results + + def _regenerate_notifications( + self, + results: Dict[str, Any], + filtered_components: List[Dict[str, Any]], + triaged_count: int, + ) -> None: + """Regenerate connector notifications from filtered components. 
+ + Groups components by their connector origin (via the ``generatedBy`` + field on alerts), calls each connector's ``generate_notifications``, + merges the results, and injects a triage summary into github_pr + content. + """ + connector_components: Dict[str, List[Dict[str, Any]]] = {} + for comp in filtered_components: + for alert in comp.get('alerts', []): + gen = alert.get('generatedBy') or '' + connector_name = self._connector_name_from_generated_by(gen) + if connector_name: + connector_components.setdefault(connector_name, []).append(comp) + break # one mapping per component is enough + + merged_notifications: Dict[str, list] = {} + + for connector_name, comps in connector_components.items(): + connector = self.connector_manager.loaded_connectors.get(connector_name) + if connector is None: + logger.debug("Connector %s not loaded; skipping notification regen", connector_name) + continue + + if not hasattr(connector, 'generate_notifications'): + logger.debug("Connector %s has no generate_notifications", connector_name) + continue + + try: + if connector_name == 'trivy': + item_name, scan_type = self._derive_trivy_params(comps) + notifs = connector.generate_notifications(comps, item_name, scan_type) + else: + notifs = connector.generate_notifications(comps) + except Exception: + logger.exception("Failed to regenerate notifications for %s", connector_name) + continue + + if not isinstance(notifs, dict): + continue + + for notifier_key, payload in notifs.items(): + if notifier_key not in merged_notifications: + merged_notifications[notifier_key] = payload + elif isinstance(merged_notifications[notifier_key], list) and isinstance(payload, list): + merged_notifications[notifier_key].extend(payload) + + # Inject triage summary into github_pr notification content + full_scan_url = results.get('full_scan_html_url', '') + self._inject_triage_summary(merged_notifications, triaged_count, full_scan_url) + + if merged_notifications: + results['notifications'] = 
merged_notifications + + @staticmethod + def _connector_name_from_generated_by(generated_by: str) -> str | None: + """Map a generatedBy value back to its connector name.""" + gb = generated_by.lower() + if gb.startswith('opengrep') or gb.startswith('sast'): + return 'opengrep' + if gb == 'trufflehog': + return 'trufflehog' + if gb.startswith('trivy'): + return 'trivy' + if gb == 'socket-tier1': + return 'socket_tier1' + return None + + def _derive_trivy_params( + self, components: List[Dict[str, Any]] + ) -> tuple: + """Derive item_name and scan_type for Trivy notification regeneration.""" + scan_type = 'image' + for comp in components: + for alert in comp.get('alerts', []): + props = alert.get('props') or {} + st = props.get('scanType', '') + if st: + scan_type = st + break + if scan_type != 'image': + break + + item_name = "Unknown" + images_str = ( + self.config.get('container_images', '') + or self.config.get('container_images_to_scan', '') + or self.config.get('docker_images', '') + ) + if images_str: + if isinstance(images_str, list): + item_name = images_str[0] if images_str else "Unknown" + else: + images = [img.strip() for img in str(images_str).split(',') if img.strip()] + item_name = images[0] if images else "Unknown" + else: + dockerfiles = self.config.get('dockerfiles', '') + if dockerfiles: + if isinstance(dockerfiles, list): + item_name = dockerfiles[0] if dockerfiles else "Unknown" + else: + docker_list = [df.strip() for df in str(dockerfiles).split(',') if df.strip()] + item_name = docker_list[0] if docker_list else "Unknown" + + if scan_type == 'vuln' and item_name == "Unknown": + try: + item_name = os.path.basename(str(self.config.workspace)) + except Exception: + item_name = "Workspace" + + return item_name, scan_type + + @staticmethod + def _inject_triage_summary( + notifications: Dict[str, list], + triaged_count: int, + full_scan_url: str, + ) -> None: + """Insert a triage summary line into github_pr notification content.""" + gh_items = 
notifications.get('github_pr') + if not gh_items or not isinstance(gh_items, list): + return + + dashboard_link = full_scan_url or "https://socket.dev/dashboard" + summary_line = ( + f"\n> :white_check_mark: **{triaged_count} finding(s) triaged** " + f"via [Socket Dashboard]({dashboard_link}) and removed from this report.\n" + ) + + for item in gh_items: + if not isinstance(item, dict) or 'content' not in item: + continue + content = item['content'] + # Insert after the first markdown heading line (# Title) + lines = content.split('\n') + insert_idx = 0 + for i, line in enumerate(lines): + if line.strip().startswith('# '): + insert_idx = i + 1 + break + lines.insert(insert_idx, summary_line) + item['content'] = '\n'.join(lines) + + def main(): """Main entry point""" parser = parse_cli_args() @@ -429,6 +637,12 @@ def main(): except Exception: logger.exception("Failed to submit socket facts file") + # Filter out triaged alerts before notifying + try: + results = scanner.apply_triage_filter(results) + except Exception: + logger.exception("Failed to apply triage filter") + # Optionally upload to S3 if requested try: enable_s3 = getattr(args, 'enable_s3_upload', False) or config.get('enable_s3_upload', False) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_triage.py b/tests/test_triage.py new file mode 100644 index 0000000..b1a472e --- /dev/null +++ b/tests/test_triage.py @@ -0,0 +1,384 @@ +"""Tests for socket_basics.core.triage module.""" + +import pytest +from socket_basics.core.triage import TriageFilter, fetch_triage_data + + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + +def _make_component( + name: str = "lodash", + comp_type: str = "npm", + version: str = "4.17.21", + alerts: list | None = None, +) -> dict: + return { + "id": 
f"pkg:{comp_type}/{name}@{version}", + "name": name, + "version": version, + "type": comp_type, + "qualifiers": {"ecosystem": comp_type, "version": version}, + "alerts": alerts or [], + } + + +def _make_alert( + title: str = "badEncoding", + alert_type: str = "supplyChainRisk", + severity: str = "high", + rule_id: str | None = None, + detector_name: str | None = None, + cve_id: str | None = None, + generated_by: str = "opengrep-python", +) -> dict: + props: dict = {} + if rule_id: + props["ruleId"] = rule_id + if detector_name: + props["detectorName"] = detector_name + if cve_id: + props["cveId"] = cve_id + return { + "title": title, + "type": alert_type, + "severity": severity, + "generatedBy": generated_by, + "props": props, + } + + +def _make_triage_entry( + alert_key: str, + state: str = "ignore", + package_name: str | None = None, + package_type: str | None = None, + package_version: str | None = None, + package_namespace: str | None = None, +) -> dict: + return { + "uuid": "test-uuid", + "alert_key": alert_key, + "state": state, + "package_name": package_name, + "package_type": package_type, + "package_version": package_version, + "package_namespace": package_namespace, + "note": "", + "organization_id": "test-org", + } + + +# --------------------------------------------------------------------------- +# TriageFilter.is_alert_triaged +# --------------------------------------------------------------------------- + +class TestIsAlertTriaged: + """Tests for the alert matching logic.""" + + def test_broad_match_by_title(self): + """Triage entry with no package info matches any component with matching alert_key.""" + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_rule_id(self): + entry = _make_triage_entry(alert_key="python.lang.security.audit.xss") + tf = TriageFilter([entry]) + comp 
= _make_component() + alert = _make_alert(title="XSS Vulnerability", rule_id="python.lang.security.audit.xss") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_detector_name(self): + entry = _make_triage_entry(alert_key="AWS") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="AWS Key Detected", detector_name="AWS") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_cve(self): + entry = _make_triage_entry(alert_key="CVE-2024-1234") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="Some Vuln", cve_id="CVE-2024-1234") + assert tf.is_alert_triaged(comp, alert) is True + + def test_no_match_different_key(self): + entry = _make_triage_entry(alert_key="differentRule") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is False + + def test_package_scoped_match(self): + """Triage entry with package info only matches the specific package.""" + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + ) + tf = TriageFilter([entry]) + + comp_match = _make_component(name="lodash", comp_type="npm") + comp_no_match = _make_component(name="express", comp_type="npm") + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged(comp_match, alert) is True + assert tf.is_alert_triaged(comp_no_match, alert) is False + + def test_package_version_exact_match(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="4.17.21", + ) + tf = TriageFilter([entry]) + + comp_match = _make_component(name="lodash", comp_type="npm", version="4.17.21") + comp_no_match = _make_component(name="lodash", comp_type="npm", version="4.17.20") + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged(comp_match, alert) is True + assert 
tf.is_alert_triaged(comp_no_match, alert) is False + + def test_version_wildcard(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="4.17.*", + ) + tf = TriageFilter([entry]) + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.17.21"), alert + ) is True + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.17.0"), alert + ) is True + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.18.0"), alert + ) is False + + def test_version_star_matches_all(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="*", + ) + tf = TriageFilter([entry]) + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="99.0.0"), alert + ) is True + + def test_states_block_and_warn_not_suppressed(self): + """Triage entries with block/warn/inherit states should not filter findings.""" + for state in ("block", "warn", "inherit"): + entry = _make_triage_entry(alert_key="badEncoding", state=state) + tf = TriageFilter([entry]) + assert tf.entries == [], f"state={state} should be excluded from filter entries" + + def test_state_monitor_suppressed(self): + entry = _make_triage_entry(alert_key="badEncoding", state="monitor") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is True + + def test_alert_with_no_matchable_keys(self): + """Alert with no title, type, or relevant props should not match.""" + entry = _make_triage_entry(alert_key="something") + tf = TriageFilter([entry]) + comp = _make_component() + alert = {"severity": "high", "props": {}} + assert tf.is_alert_triaged(comp, alert) is False + + +# 
--------------------------------------------------------------------------- +# TriageFilter.filter_components +# --------------------------------------------------------------------------- + +class TestFilterComponents: + def test_removes_triaged_alerts(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + alert_triaged = _make_alert(title="badEncoding") + alert_kept = _make_alert(title="otherIssue") + comp = _make_component(alerts=[alert_triaged, alert_kept]) + + filtered, count = tf.filter_components([comp]) + assert count == 1 + assert len(filtered) == 1 + assert len(filtered[0]["alerts"]) == 1 + assert filtered[0]["alerts"][0]["title"] == "otherIssue" + + def test_removes_component_when_all_alerts_triaged(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + comp = _make_component(alerts=[_make_alert(title="badEncoding")]) + filtered, count = tf.filter_components([comp]) + assert count == 1 + assert len(filtered) == 0 + + def test_no_triage_entries_returns_original(self): + tf = TriageFilter([]) + comp = _make_component(alerts=[_make_alert()]) + filtered, count = tf.filter_components([comp]) + assert count == 0 + assert filtered is [comp] or filtered == [comp] + + def test_multiple_components_mixed(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + comp1 = _make_component(name="a", alerts=[_make_alert(title="badEncoding")]) + comp2 = _make_component(name="b", alerts=[_make_alert(title="otherIssue")]) + comp3 = _make_component( + name="c", + alerts=[ + _make_alert(title="badEncoding"), + _make_alert(title="keepMe"), + ], + ) + + filtered, count = tf.filter_components([comp1, comp2, comp3]) + assert count == 2 + assert len(filtered) == 2 + names = [c["name"] for c in filtered] + assert "a" not in names + assert "b" in names + assert "c" in names + + +# --------------------------------------------------------------------------- +# 
fetch_triage_data +# --------------------------------------------------------------------------- + +class TestFetchTriageData: + def test_single_page(self): + class FakeTriageAPI: + def list_alert_triage(self, org, params): + return {"results": [{"alert_key": "a", "state": "ignore"}], "nextPage": None} + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 1 + assert entries[0]["alert_key"] == "a" + + def test_pagination(self): + class FakeTriageAPI: + def __init__(self): + self.call_count = 0 + + def list_alert_triage(self, org, params): + self.call_count += 1 + if params.get("page") == 1: + return {"results": [{"alert_key": "a"}], "nextPage": 2} + return {"results": [{"alert_key": "b"}], "nextPage": None} + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 2 + + def test_api_error_returns_partial(self): + class FakeTriageAPI: + def __init__(self): + self.calls = 0 + + def list_alert_triage(self, org, params): + self.calls += 1 + if self.calls == 1: + return {"results": [{"alert_key": "a"}], "nextPage": 2} + raise RuntimeError("API error") + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 1 + + +# --------------------------------------------------------------------------- +# SecurityScanner._connector_name_from_generated_by +# --------------------------------------------------------------------------- + +class TestConnectorNameMapping: + def test_opengrep_variants(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("opengrep-python") == "opengrep" + assert SecurityScanner._connector_name_from_generated_by("sast-generic") == "opengrep" + + def test_trufflehog(self): + from socket_basics.socket_basics import SecurityScanner + assert 
SecurityScanner._connector_name_from_generated_by("trufflehog") == "trufflehog" + + def test_trivy_variants(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("trivy-dockerfile") == "trivy" + assert SecurityScanner._connector_name_from_generated_by("trivy-image") == "trivy" + assert SecurityScanner._connector_name_from_generated_by("trivy-npm") == "trivy" + + def test_socket_tier1(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("socket-tier1") == "socket_tier1" + + def test_unknown_returns_none(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("unknown-tool") is None + + +# --------------------------------------------------------------------------- +# SecurityScanner._inject_triage_summary +# --------------------------------------------------------------------------- + +class TestInjectTriageSummary: + def test_injects_after_heading(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = { + "github_pr": [ + { + "title": "SAST Findings", + "content": "\n# SAST Python Findings\n### Summary\nSome content\n", + } + ] + } + SecurityScanner._inject_triage_summary(notifications, 3, "https://socket.dev/scan/123") + + content = notifications["github_pr"][0]["content"] + assert "3 finding(s) triaged" in content + assert "Socket Dashboard" in content + # Summary line should appear after the # heading + lines = content.split("\n") + heading_idx = next(i for i, l in enumerate(lines) if l.strip().startswith("# ")) + summary_idx = next(i for i, l in enumerate(lines) if "triaged" in l) + assert summary_idx > heading_idx + + def test_no_github_pr_key_is_noop(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = {"slack": [{"title": "t", "content": "c"}]} + 
SecurityScanner._inject_triage_summary(notifications, 5, "") + assert "github_pr" not in notifications + + def test_uses_default_dashboard_link(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = { + "github_pr": [{"title": "t", "content": "# Title\nBody"}] + } + SecurityScanner._inject_triage_summary(notifications, 1, "") + assert "https://socket.dev/dashboard" in notifications["github_pr"][0]["content"] From 65f5f6a6835157915e865f34d176828a3d2a3738 Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:36:21 -0800 Subject: [PATCH 02/13] Handle triage API access denied gracefully Log an info-level message instead of an error traceback when the Socket API token lacks triage permissions, and skip filtering so the scan completes normally with all findings intact. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 13 +++++++++++-- tests/test_triage.py | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py index 0738f95..f9b9796 100644 --- a/socket_basics/core/triage.py +++ b/socket_basics/core/triage.py @@ -34,8 +34,17 @@ def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: org_slug, {"per_page": per_page, "page": page}, ) - except Exception: - logger.exception("Failed to fetch triage data (page %d)", page) + except Exception as exc: + # Handle insufficient permissions gracefully so the scan + # continues without triage filtering. + exc_name = type(exc).__name__ + if "AccessDenied" in exc_name or "Forbidden" in exc_name: + logger.info( + "Triage API access denied (insufficient permissions). " + "Skipping triage filtering for this run." 
+ ) + else: + logger.warning("Failed to fetch triage data (page %d): %s", page, exc) break if not isinstance(response, dict): diff --git a/tests/test_triage.py b/tests/test_triage.py index b1a472e..c0c4954 100644 --- a/tests/test_triage.py +++ b/tests/test_triage.py @@ -310,6 +310,30 @@ class FakeSDK: entries = fetch_triage_data(FakeSDK(), "my-org") assert len(entries) == 1 + def test_access_denied_returns_empty_and_logs_info(self, caplog): + """Insufficient permissions should log an info message (not an error) and return empty.""" + + class APIAccessDenied(Exception): + pass + + class FakeTriageAPI: + def list_alert_triage(self, org, params): + raise APIAccessDenied("Insufficient permissions.") + + class FakeSDK: + triage = FakeTriageAPI() + + import logging + with caplog.at_level(logging.DEBUG): + entries = fetch_triage_data(FakeSDK(), "my-org") + + assert entries == [] + info_messages = [r for r in caplog.records if r.levelno == logging.INFO] + assert any("access denied" in m.message.lower() for m in info_messages) + # Should NOT produce an ERROR-level record + error_messages = [r for r in caplog.records if r.levelno >= logging.ERROR] + assert not error_messages + # --------------------------------------------------------------------------- # SecurityScanner._connector_name_from_generated_by From cac5de77778088acd8a4b20de4c1ccf559778c77 Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:46:17 -0800 Subject: [PATCH 03/13] Fix stale notifications after triage and improve logging Always replace results['notifications'] after triage filtering so pre-filter content is never forwarded to notifiers. Skip PR comment API calls when content is unchanged. Add info-level logging for triaged/remaining finding counts and connector regeneration details. 
Co-Authored-By: Claude Opus 4.6 --- .../core/notification/github_pr_notifier.py | 12 ++++ socket_basics/socket_basics.py | 60 +++++++++++++++++-- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/socket_basics/core/notification/github_pr_notifier.py b/socket_basics/core/notification/github_pr_notifier.py index 555d03e..8c258bc 100644 --- a/socket_basics/core/notification/github_pr_notifier.py +++ b/socket_basics/core/notification/github_pr_notifier.py @@ -100,6 +100,18 @@ def notify(self, facts: Dict[str, Any]) -> None: # Update existing comments with new section content for comment_id, updated_body in comment_updates.items(): + # Detect whether content actually changed before making the API call + original_body = next( + (c.get('body', '') for c in existing_comments if c.get('id') == comment_id), + '', + ) + if original_body == updated_body: + logger.info( + 'GithubPRNotifier: comment %s content unchanged; skipping update', + comment_id, + ) + continue + success = self._update_comment(pr_number, comment_id, updated_body) if success: logger.info('GithubPRNotifier: updated existing comment %s', comment_id) diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index 6311611..984133e 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -419,15 +419,29 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: return results triage_filter = TriageFilter(triage_entries) + original_components = results.get('components', []) + original_alert_count = sum( + len(c.get('alerts', [])) for c in original_components + ) filtered_components, triaged_count = triage_filter.filter_components( - results.get('components', []) + original_components ) if triaged_count == 0: - logger.debug("No findings matched triage entries") + logger.info( + "Triage filter matched 0 of %d finding(s); no changes applied", + original_alert_count, + ) return results - logger.info("Filtered %d triaged finding(s) from 
results", triaged_count) + remaining_alert_count = sum( + len(c.get('alerts', [])) for c in filtered_components + ) + logger.info( + "Triage filter removed %d finding(s); %d finding(s) remain", + triaged_count, + remaining_alert_count, + ) results['components'] = filtered_components results['triaged_count'] = triaged_count @@ -448,22 +462,47 @@ def _regenerate_notifications( field on alerts), calls each connector's ``generate_notifications``, merges the results, and injects a triage summary into github_pr content. + + Always replaces ``results['notifications']`` so stale pre-filter + notifications are never forwarded to notifiers. """ connector_components: Dict[str, List[Dict[str, Any]]] = {} + unmapped_count = 0 for comp in filtered_components: + mapped = False for alert in comp.get('alerts', []): gen = alert.get('generatedBy') or '' connector_name = self._connector_name_from_generated_by(gen) if connector_name: connector_components.setdefault(connector_name, []).append(comp) + mapped = True break # one mapping per component is enough + if not mapped: + unmapped_count += 1 + + if unmapped_count: + logger.debug( + "Triage regen: %d component(s) could not be mapped to a connector", + unmapped_count, + ) + + logger.info( + "Regenerating notifications for %d connector(s): %s", + len(connector_components), + ", ".join(connector_components.keys()) or "(none)", + ) merged_notifications: Dict[str, list] = {} for connector_name, comps in connector_components.items(): connector = self.connector_manager.loaded_connectors.get(connector_name) if connector is None: - logger.debug("Connector %s not loaded; skipping notification regen", connector_name) + logger.warning( + "Connector %s not in loaded_connectors (available: %s); " + "cannot regenerate its notifications", + connector_name, + ", ".join(self.connector_manager.loaded_connectors.keys()), + ) continue if not hasattr(connector, 'generate_notifications'): @@ -483,6 +522,13 @@ def _regenerate_notifications( if not 
isinstance(notifs, dict): continue + notifier_keys = [k for k, v in notifs.items() if v] + logger.debug( + "Connector %s produced notifications for: %s", + connector_name, + ", ".join(notifier_keys) or "(empty)", + ) + for notifier_key, payload in notifs.items(): if notifier_key not in merged_notifications: merged_notifications[notifier_key] = payload @@ -493,8 +539,10 @@ def _regenerate_notifications( full_scan_url = results.get('full_scan_html_url', '') self._inject_triage_summary(merged_notifications, triaged_count, full_scan_url) - if merged_notifications: - results['notifications'] = merged_notifications + # Always replace notifications so stale pre-filter content is never + # forwarded to notifiers. An empty dict is valid and means every + # finding was triaged. + results['notifications'] = merged_notifications @staticmethod def _connector_name_from_generated_by(generated_by: str) -> str | None: From 5824e9e6f5463b08adedbfd9fc86d17c21356ee4 Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 16:43:48 -0800 Subject: [PATCH 04/13] Rework triage matching to use stream-based alert key lookup The triage API returns opaque alert_key hashes, not human-readable identifiers. This rewrites the matching logic to stream the full scan via sdk.fullscans.stream(), cross-reference Socket alert keys against triage entries, and map back to local components by artifact ID. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 273 +++++++++++------- socket_basics/socket_basics.py | 39 ++- tests/test_triage.py | 490 +++++++++++++++++++++++---------- 3 files changed, 536 insertions(+), 266 deletions(-) diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py index f9b9796..59bf6ce 100644 --- a/socket_basics/core/triage.py +++ b/socket_basics/core/triage.py @@ -1,12 +1,12 @@ """Triage filtering for Socket Security Basics. 
-Fetches triage entries from the Socket API and filters scan components -whose alerts have been triaged (state: ignore or monitor). +Streams the full scan from the Socket API to obtain alert keys, fetches +triage entries, and filters local scan components whose alerts have been +triaged (state: ignore or monitor). """ -import fnmatch import logging -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Set, Tuple logger = logging.getLogger(__name__) @@ -14,6 +14,10 @@ _SUPPRESSED_STATES = {"ignore", "monitor"} +# ------------------------------------------------------------------ +# API helpers +# ------------------------------------------------------------------ + def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: """Fetch all triage alert entries from the Socket API, handling pagination. @@ -63,62 +67,145 @@ def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: return all_entries -class TriageFilter: - """Matches local scan alerts against triage entries and filters them out.""" +def stream_full_scan_alerts( + sdk: Any, org_slug: str, full_scan_id: str +) -> Dict[str, List[Dict[str, Any]]]: + """Stream a full scan and extract alert keys grouped by artifact. + + Returns: + Mapping of artifact ID to list of alert dicts. Each alert dict + contains at minimum ``key`` and ``type``. The artifact metadata + (name, version, type, etc.) is included under a ``_artifact`` key + in every alert dict for downstream matching. + """ + try: + # use_types=False returns a plain dict keyed by artifact ID + response = sdk.fullscans.stream(org_slug, full_scan_id, use_types=False) + except Exception as exc: + exc_name = type(exc).__name__ + if "AccessDenied" in exc_name or "Forbidden" in exc_name: + logger.info( + "Full scan stream access denied (insufficient permissions). " + "Skipping triage filtering for this run." 
+ ) + else: + logger.warning("Failed to stream full scan %s: %s", full_scan_id, exc) + return {} + + if not isinstance(response, dict): + logger.warning("Unexpected full scan stream response type: %s", type(response)) + return {} + + artifact_alerts: Dict[str, List[Dict[str, Any]]] = {} + for artifact_id, artifact in response.items(): + if not isinstance(artifact, dict): + continue + alerts = artifact.get("alerts") or [] + if not alerts: + continue + meta = { + "artifact_id": artifact_id, + "artifact_name": artifact.get("name"), + "artifact_version": artifact.get("version"), + "artifact_type": artifact.get("type"), + "artifact_namespace": artifact.get("namespace"), + "artifact_subpath": artifact.get("subPath") or artifact.get("subpath"), + } + enriched = [] + for a in alerts: + if isinstance(a, dict) and a.get("key"): + enriched.append({**a, "_artifact": meta}) + if enriched: + artifact_alerts[artifact_id] = enriched + + total_alerts = sum(len(v) for v in artifact_alerts.values()) + logger.debug( + "Streamed full scan %s: %d artifact(s), %d alert(s) with keys", + full_scan_id, + len(artifact_alerts), + total_alerts, + ) + return artifact_alerts + + +# ------------------------------------------------------------------ +# TriageFilter +# ------------------------------------------------------------------ - def __init__(self, triage_entries: List[Dict[str, Any]]) -> None: - # Only keep entries whose state suppresses findings - self.entries = [ - e for e in triage_entries - if (e.get("state") or "").lower() in _SUPPRESSED_STATES - ] +class TriageFilter: + """Cross-references Socket alert keys against triage entries and + maps triaged alerts back to local scan components.""" + + def __init__( + self, + triage_entries: List[Dict[str, Any]], + artifact_alerts: Dict[str, List[Dict[str, Any]]], + ) -> None: + # Build set of suppressed alert keys + self.triaged_keys: Set[str] = set() + for entry in triage_entries: + state = (entry.get("state") or "").lower() + key = 
entry.get("alert_key") + if state in _SUPPRESSED_STATES and key: + self.triaged_keys.add(key) + + # Flatten all Socket alerts for lookup + self._socket_alerts: List[Dict[str, Any]] = [] + for alerts in artifact_alerts.values(): + self._socket_alerts.extend(alerts) + + # Build a mapping from (artifact_id, alert_type) to triaged status + # for fast lookups when matching against local components + self._triaged_by_artifact: Dict[str, Set[str]] = {} + for alert in self._socket_alerts: + if alert.get("key") in self.triaged_keys: + art_id = alert.get("_artifact", {}).get("artifact_id", "") + alert_type = alert.get("type") or "" + self._triaged_by_artifact.setdefault(art_id, set()).add(alert_type) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ - def is_alert_triaged(self, component: Dict[str, Any], alert: Dict[str, Any]) -> bool: - """Return True if the alert on the given component matches a suppressed triage entry.""" - alert_keys = self._extract_alert_keys(alert) - if not alert_keys: - return False - - for entry in self.entries: - entry_key = entry.get("alert_key") - if not entry_key: - continue - - if entry_key not in alert_keys: - continue - - # alert_key matched; now check package scope - if self._is_broad_match(entry): - return True - - if self._package_matches(entry, component): - return True - - return False - def filter_components( self, components: List[Dict[str, Any]] ) -> Tuple[List[Dict[str, Any]], int]: - """Remove triaged alerts from components. + """Remove triaged alerts from local components. + + Matches local components to Socket artifacts by component ID, then + checks each local alert against the set of triaged alert types for + that artifact. Returns: - (filtered_components, triaged_count) where triaged_count is the - total number of individual alerts removed. 
+ (filtered_components, triaged_count) """ - if not self.entries: + if not self.triaged_keys: + return components, 0 + + # Build lookup: component id -> set of triaged Socket alert types + triaged_types_by_component = self._map_components_to_triaged_types(components) + + if not triaged_types_by_component: + logger.debug( + "No local components matched Socket artifacts with triaged alerts" + ) return components, 0 filtered: List[Dict[str, Any]] = [] triaged_count = 0 for comp in components: + comp_id = comp.get("id") or "" + triaged_types = triaged_types_by_component.get(comp_id) + + if triaged_types is None: + # Component had no triaged alerts; keep as-is + filtered.append(comp) + continue + remaining_alerts: List[Dict[str, Any]] = [] for alert in comp.get("alerts", []): - if self.is_alert_triaged(comp, alert): + if self._local_alert_is_triaged(alert, triaged_types): triaged_count += 1 else: remaining_alerts.append(alert) @@ -134,71 +221,49 @@ def filter_components( # Internal helpers # ------------------------------------------------------------------ - @staticmethod - def _extract_alert_keys(alert: Dict[str, Any]) -> set: - """Build the set of candidate keys that could match a triage entry's alert_key.""" - keys: set = set() - props = alert.get("props") or {} - - for field in ( - alert.get("title"), - alert.get("type"), - props.get("ruleId"), - props.get("detectorName"), - props.get("vulnerabilityId"), - props.get("cveId"), - ): - if field: - keys.add(str(field)) + def _map_components_to_triaged_types( + self, components: List[Dict[str, Any]] + ) -> Dict[str, Set[str]]: + """Map local component IDs to the set of triaged Socket alert types. - return keys + Matches by component ``id`` (which is typically a hash that Socket + also uses as the artifact ID). 
+ """ + local_ids = {comp.get("id") for comp in components if comp.get("id")} + result: Dict[str, Set[str]] = {} + for comp_id in local_ids: + triaged = self._triaged_by_artifact.get(comp_id) + if triaged: + result[comp_id] = triaged + return result @staticmethod - def _is_broad_match(entry: Dict[str, Any]) -> bool: - """Return True when the triage entry has no package scope (applies globally).""" - return ( - entry.get("package_name") is None - and entry.get("package_type") is None - and entry.get("package_version") is None - and entry.get("package_namespace") is None - ) + def _local_alert_is_triaged( + alert: Dict[str, Any], triaged_types: Set[str] + ) -> bool: + """Check if a local alert matches any of the triaged Socket alert types. + + Socket alert ``type`` values (e.g. ``badEncoding``, ``cve``) are + compared against the local alert's ``type`` field. When the local + alert type is too generic (``"generic"`` or ``"vulnerability"``), + we fall back to matching on ``title``, ``props.ruleId``, or + ``props.vulnerabilityId``. + """ + # Direct type match + local_type = alert.get("type") or "" + if local_type and local_type not in ("generic", "vulnerability"): + return local_type in triaged_types - @staticmethod - def _version_matches(entry_version: str, component_version: str) -> bool: - """Check version match, supporting wildcard suffix patterns like '1.2.*'.""" - if not entry_version or entry_version == "*": - return True - if not component_version: - return False - # fnmatch handles '*' and '?' 
glob patterns - return fnmatch.fnmatch(component_version, entry_version) - - @classmethod - def _package_matches(cls, entry: Dict[str, Any], component: Dict[str, Any]) -> bool: - """Return True if the triage entry's package scope matches the component.""" - qualifiers = component.get("qualifiers") or {} - comp_name = component.get("name") or "" - comp_type = ( - qualifiers.get("ecosystem") - or qualifiers.get("type") - or component.get("type") - or "" - ) - comp_version = component.get("version") or qualifiers.get("version") or "" - comp_namespace = qualifiers.get("namespace") or "" - - entry_name = entry.get("package_name") - entry_type = entry.get("package_type") - entry_version = entry.get("package_version") - entry_namespace = entry.get("package_namespace") - - if entry_name is not None and entry_name != comp_name: - return False - if entry_type is not None and entry_type.lower() != comp_type.lower(): - return False - if entry_namespace is not None and entry_namespace != comp_namespace: - return False - if entry_version is not None and not cls._version_matches(entry_version, comp_version): - return False - - return True + # Fallback: match candidate fields against triaged types + props = alert.get("props") or {} + candidates = { + v for v in ( + alert.get("title"), + props.get("ruleId"), + props.get("detectorName"), + props.get("vulnerabilityId"), + props.get("cveId"), + ) + if v + } + return bool(candidates & triaged_types) diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index 984133e..c5683b2 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -381,10 +381,10 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: """Filter out triaged alerts and regenerate notifications. 
- Fetches triage entries from the Socket API, removes alerts with - state ``ignore`` or ``monitor``, regenerates connector notifications - for the remaining components, and injects a triage summary line into - github_pr notification content. + Streams the full scan from the Socket API to obtain alert keys, + cross-references them with triage entries, removes suppressed + alerts from local components, regenerates connector notifications, + and injects a triage summary into github_pr content. Args: results: Current scan results dict (components + notifications). @@ -394,11 +394,16 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: """ socket_api_key = self.config.get('socket_api_key') socket_org = self.config.get('socket_org') + full_scan_id = results.get('full_scan_id') if not socket_api_key or not socket_org: logger.debug("Skipping triage filter: missing socket_api_key or socket_org") return results + if not full_scan_id: + logger.debug("Skipping triage filter: no full_scan_id in results") + return results + # Import SDK and triage helpers try: from socketdev import socketdev @@ -407,18 +412,34 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: return results try: - from .core.triage import TriageFilter, fetch_triage_data + from .core.triage import TriageFilter, fetch_triage_data, stream_full_scan_alerts except ImportError: - from socket_basics.core.triage import TriageFilter, fetch_triage_data + from socket_basics.core.triage import TriageFilter, fetch_triage_data, stream_full_scan_alerts sdk = socketdev(token=socket_api_key, timeout=100) - triage_entries = fetch_triage_data(sdk, socket_org) + # Fetch triage entries and stream full scan alert keys in sequence + triage_entries = fetch_triage_data(sdk, socket_org) if not triage_entries: - logger.debug("No triage entries found; skipping filter") + logger.info("No triage entries found; skipping filter") + return results + + suppressed_count = sum( + 1 for e in 
triage_entries + if (e.get("state") or "").lower() in ("ignore", "monitor") + ) + logger.info( + "Fetched %d triage entries (%d with suppressed state)", + len(triage_entries), + suppressed_count, + ) + + artifact_alerts = stream_full_scan_alerts(sdk, socket_org, full_scan_id) + if not artifact_alerts: + logger.info("No alert keys returned from full scan stream; skipping filter") return results - triage_filter = TriageFilter(triage_entries) + triage_filter = TriageFilter(triage_entries, artifact_alerts) original_components = results.get('components', []) original_alert_count = sum( len(c.get('alerts', [])) for c in original_components diff --git a/tests/test_triage.py b/tests/test_triage.py index c0c4954..a05e6e5 100644 --- a/tests/test_triage.py +++ b/tests/test_triage.py @@ -1,21 +1,30 @@ """Tests for socket_basics.core.triage module.""" +import logging import pytest -from socket_basics.core.triage import TriageFilter, fetch_triage_data +from socket_basics.core.triage import ( + TriageFilter, + fetch_triage_data, + stream_full_scan_alerts, +) # --------------------------------------------------------------------------- # Fixtures / helpers # --------------------------------------------------------------------------- +ARTIFACT_ID = "abc123" + + def _make_component( + comp_id: str = ARTIFACT_ID, name: str = "lodash", comp_type: str = "npm", version: str = "4.17.21", alerts: list | None = None, ) -> dict: return { - "id": f"pkg:{comp_type}/{name}@{version}", + "id": comp_id, "name": name, "version": version, "type": comp_type, @@ -24,9 +33,9 @@ def _make_component( } -def _make_alert( +def _make_local_alert( title: str = "badEncoding", - alert_type: str = "supplyChainRisk", + alert_type: str = "badEncoding", severity: str = "high", rule_id: str | None = None, detector_name: str | None = None, @@ -52,153 +61,144 @@ def _make_alert( def _make_triage_entry( alert_key: str, state: str = "ignore", - package_name: str | None = None, - package_type: str | None = None, - 
package_version: str | None = None, - package_namespace: str | None = None, ) -> dict: return { "uuid": "test-uuid", "alert_key": alert_key, "state": state, - "package_name": package_name, - "package_type": package_type, - "package_version": package_version, - "package_namespace": package_namespace, "note": "", "organization_id": "test-org", } +def _make_artifact_alerts( + artifact_id: str = ARTIFACT_ID, + alerts: list[dict] | None = None, + name: str = "lodash", + version: str = "4.17.21", + pkg_type: str = "npm", +) -> dict[str, list[dict]]: + """Build an artifact_alerts mapping with enriched _artifact metadata.""" + meta = { + "artifact_id": artifact_id, + "artifact_name": name, + "artifact_version": version, + "artifact_type": pkg_type, + "artifact_namespace": None, + "artifact_subpath": None, + } + enriched = [{**a, "_artifact": meta} for a in (alerts or [])] + return {artifact_id: enriched} + + +def _socket_alert(key: str, alert_type: str) -> dict: + """Create a minimal Socket alert dict (as returned by the full scan stream).""" + return {"key": key, "type": alert_type} + + # --------------------------------------------------------------------------- -# TriageFilter.is_alert_triaged +# TriageFilter construction # --------------------------------------------------------------------------- -class TestIsAlertTriaged: - """Tests for the alert matching logic.""" - - def test_broad_match_by_title(self): - """Triage entry with no package info matches any component with matching alert_key.""" - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_rule_id(self): - entry = _make_triage_entry(alert_key="python.lang.security.audit.xss") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="XSS Vulnerability", rule_id="python.lang.security.audit.xss") - assert 
tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_detector_name(self): - entry = _make_triage_entry(alert_key="AWS") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="AWS Key Detected", detector_name="AWS") - assert tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_cve(self): - entry = _make_triage_entry(alert_key="CVE-2024-1234") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="Some Vuln", cve_id="CVE-2024-1234") - assert tf.is_alert_triaged(comp, alert) is True - - def test_no_match_different_key(self): - entry = _make_triage_entry(alert_key="differentRule") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is False - - def test_package_scoped_match(self): - """Triage entry with package info only matches the specific package.""" - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", +class TestTriageFilterInit: + def test_builds_triaged_keys_for_ignore(self): + entries = [_make_triage_entry("hash-1", state="ignore")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + assert "hash-1" in tf.triaged_keys + + def test_builds_triaged_keys_for_monitor(self): + entries = [_make_triage_entry("hash-2", state="monitor")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-2", "cve")] + ) + tf = TriageFilter(entries, artifact_alerts) + assert "hash-2" in tf.triaged_keys + + def test_excludes_block_warn_inherit_states(self): + entries = [ + _make_triage_entry("h1", state="block"), + _make_triage_entry("h2", state="warn"), + _make_triage_entry("h3", state="inherit"), + ] + artifact_alerts = _make_artifact_alerts( + alerts=[ + _socket_alert("h1", "a"), + _socket_alert("h2", "b"), + _socket_alert("h3", "c"), + ] + ) + tf 
= TriageFilter(entries, artifact_alerts) + assert tf.triaged_keys == set() + + def test_builds_triaged_by_artifact_mapping(self): + entries = [_make_triage_entry("hash-1", state="ignore")] + artifact_alerts = _make_artifact_alerts( + artifact_id="art-1", + alerts=[_socket_alert("hash-1", "badEncoding")], ) - tf = TriageFilter([entry]) + tf = TriageFilter(entries, artifact_alerts) + assert "art-1" in tf._triaged_by_artifact + assert "badEncoding" in tf._triaged_by_artifact["art-1"] - comp_match = _make_component(name="lodash", comp_type="npm") - comp_no_match = _make_component(name="express", comp_type="npm") - alert = _make_alert(title="badEncoding") + def test_no_entries_means_empty_triaged_keys(self): + tf = TriageFilter([], {}) + assert tf.triaged_keys == set() - assert tf.is_alert_triaged(comp_match, alert) is True - assert tf.is_alert_triaged(comp_no_match, alert) is False + def test_entry_without_alert_key_ignored(self): + entries = [{"state": "ignore", "alert_key": None}] + tf = TriageFilter(entries, {}) + assert tf.triaged_keys == set() - def test_package_version_exact_match(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="4.17.21", - ) - tf = TriageFilter([entry]) - comp_match = _make_component(name="lodash", comp_type="npm", version="4.17.21") - comp_no_match = _make_component(name="lodash", comp_type="npm", version="4.17.20") - alert = _make_alert(title="badEncoding") +# --------------------------------------------------------------------------- +# TriageFilter._local_alert_is_triaged +# --------------------------------------------------------------------------- - assert tf.is_alert_triaged(comp_match, alert) is True - assert tf.is_alert_triaged(comp_no_match, alert) is False +class TestLocalAlertIsTriaged: + def test_direct_type_match(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(alert_type="badEncoding") + assert 
TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_direct_type_no_match(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(alert_type="cve") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is False + + def test_generic_type_falls_back_to_title(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(title="badEncoding", alert_type="generic") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_vulnerability_type_falls_back_to_cve(self): + triaged_types = {"CVE-2024-1234"} + alert = _make_local_alert( + title="Some Vuln", alert_type="vulnerability", cve_id="CVE-2024-1234" + ) + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True - def test_version_wildcard(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="4.17.*", + def test_generic_type_falls_back_to_rule_id(self): + triaged_types = {"python.lang.security.audit.xss"} + alert = _make_local_alert( + title="XSS", alert_type="generic", + rule_id="python.lang.security.audit.xss", ) - tf = TriageFilter([entry]) - alert = _make_alert(title="badEncoding") - - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.17.21"), alert - ) is True - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.17.0"), alert - ) is True - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.18.0"), alert - ) is False - - def test_version_star_matches_all(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="*", + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_generic_type_falls_back_to_detector_name(self): + triaged_types = {"AWS"} + alert = _make_local_alert( + title="AWS Key", alert_type="generic", 
detector_name="AWS" ) - tf = TriageFilter([entry]) - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="99.0.0"), alert - ) is True - - def test_states_block_and_warn_not_suppressed(self): - """Triage entries with block/warn/inherit states should not filter findings.""" - for state in ("block", "warn", "inherit"): - entry = _make_triage_entry(alert_key="badEncoding", state=state) - tf = TriageFilter([entry]) - assert tf.entries == [], f"state={state} should be excluded from filter entries" - - def test_state_monitor_suppressed(self): - entry = _make_triage_entry(alert_key="badEncoding", state="monitor") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is True - - def test_alert_with_no_matchable_keys(self): - """Alert with no title, type, or relevant props should not match.""" - entry = _make_triage_entry(alert_key="something") - tf = TriageFilter([entry]) - comp = _make_component() - alert = {"severity": "high", "props": {}} - assert tf.is_alert_triaged(comp, alert) is False + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_no_fallback_candidates_returns_false(self): + triaged_types = {"something"} + alert = {"type": "generic", "props": {}} + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is False # --------------------------------------------------------------------------- @@ -206,47 +206,87 @@ def test_alert_with_no_matchable_keys(self): # --------------------------------------------------------------------------- class TestFilterComponents: - def test_removes_triaged_alerts(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) - - alert_triaged = _make_alert(title="badEncoding") - alert_kept = _make_alert(title="otherIssue") - comp = _make_component(alerts=[alert_triaged, alert_kept]) + def 
test_removes_triaged_alert_by_type(self): + """Component ID matches artifact, triaged alert type matches local alert type.""" + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(title="kept", alert_type="otherIssue"), + ], + ) filtered, count = tf.filter_components([comp]) assert count == 1 assert len(filtered) == 1 assert len(filtered[0]["alerts"]) == 1 - assert filtered[0]["alerts"][0]["title"] == "otherIssue" + assert filtered[0]["alerts"][0]["title"] == "kept" def test_removes_component_when_all_alerts_triaged(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) - comp = _make_component(alerts=[_make_alert(title="badEncoding")]) + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[_make_local_alert(alert_type="badEncoding")], + ) filtered, count = tf.filter_components([comp]) assert count == 1 assert len(filtered) == 0 def test_no_triage_entries_returns_original(self): - tf = TriageFilter([]) - comp = _make_component(alerts=[_make_alert()]) + tf = TriageFilter([], {}) + comp = _make_component(alerts=[_make_local_alert()]) + filtered, count = tf.filter_components([comp]) + assert count == 0 + assert filtered == [comp] + + def test_component_id_mismatch_keeps_all_alerts(self): + """When local component ID doesn't match any artifact, nothing is filtered.""" + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + artifact_id="different-artifact", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) + + 
comp = _make_component( + comp_id="unrelated-comp-id", + alerts=[_make_local_alert(alert_type="badEncoding")], + ) filtered, count = tf.filter_components([comp]) assert count == 0 - assert filtered is [comp] or filtered == [comp] + assert len(filtered) == 1 def test_multiple_components_mixed(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + artifact_id="art-a", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) - comp1 = _make_component(name="a", alerts=[_make_alert(title="badEncoding")]) - comp2 = _make_component(name="b", alerts=[_make_alert(title="otherIssue")]) + comp1 = _make_component( + comp_id="art-a", name="a", + alerts=[_make_local_alert(alert_type="badEncoding")], + ) + comp2 = _make_component( + comp_id="art-b", name="b", + alerts=[_make_local_alert(alert_type="otherIssue")], + ) comp3 = _make_component( - name="c", + comp_id="art-a", name="c", alerts=[ - _make_alert(title="badEncoding"), - _make_alert(title="keepMe"), + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(title="keepMe", alert_type="keepMe"), ], ) @@ -258,6 +298,153 @@ def test_multiple_components_mixed(self): assert "b" in names assert "c" in names + def test_multiple_triaged_alert_types_on_same_artifact(self): + entries = [ + _make_triage_entry("hash-1", state="ignore"), + _make_triage_entry("hash-2", state="monitor"), + ] + artifact_alerts = _make_artifact_alerts( + alerts=[ + _socket_alert("hash-1", "badEncoding"), + _socket_alert("hash-2", "cve"), + ], + ) + tf = TriageFilter(entries, artifact_alerts) + + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(alert_type="cve"), + _make_local_alert(title="safe", alert_type="safe"), + ], + ) + filtered, count = tf.filter_components([comp]) + assert count == 2 + assert 
len(filtered[0]["alerts"]) == 1 + assert filtered[0]["alerts"][0]["type"] == "safe" + + +# --------------------------------------------------------------------------- +# stream_full_scan_alerts +# --------------------------------------------------------------------------- + +class TestStreamFullScanAlerts: + def test_parses_artifacts_and_alerts(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "artifact-1": { + "name": "lodash", + "version": "4.17.21", + "type": "npm", + "namespace": None, + "alerts": [ + {"key": "hash-a", "type": "badEncoding"}, + {"key": "hash-b", "type": "cve"}, + ], + }, + "artifact-2": { + "name": "express", + "version": "4.18.0", + "type": "npm", + "namespace": None, + "alerts": [], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "my-org", "scan-123") + assert "artifact-1" in result + assert "artifact-2" not in result # empty alerts filtered out + assert len(result["artifact-1"]) == 2 + assert result["artifact-1"][0]["key"] == "hash-a" + assert result["artifact-1"][0]["_artifact"]["artifact_name"] == "lodash" + + def test_skips_alerts_without_key(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "art-1": { + "name": "pkg", + "version": "1.0.0", + "type": "npm", + "alerts": [ + {"key": "hash-a", "type": "badEncoding"}, + {"type": "noKey"}, # missing key + {"key": "", "type": "emptyKey"}, # empty key + ], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert len(result["art-1"]) == 1 + + def test_access_denied_returns_empty(self, caplog): + class APIAccessDenied(Exception): + pass + + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + raise APIAccessDenied("Forbidden") + + class FakeSDK: + fullscans = FakeFullscansAPI() + + with caplog.at_level(logging.DEBUG): + result = 
stream_full_scan_alerts(FakeSDK(), "org", "scan") + + assert result == {} + info_msgs = [r for r in caplog.records if r.levelno == logging.INFO] + assert any("access denied" in m.message.lower() for m in info_msgs) + + def test_api_error_returns_empty(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + raise RuntimeError("Network failure") + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result == {} + + def test_non_dict_response_returns_empty(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return "unexpected string" + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result == {} + + def test_subpath_handling(self): + """Supports both camelCase and lowercase subpath field names.""" + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "art-1": { + "name": "pkg", + "version": "1.0", + "type": "npm", + "subPath": "src/lib", + "alerts": [{"key": "k1", "type": "t1"}], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result["art-1"][0]["_artifact"]["artifact_subpath"] == "src/lib" + # --------------------------------------------------------------------------- # fetch_triage_data @@ -323,14 +510,12 @@ def list_alert_triage(self, org, params): class FakeSDK: triage = FakeTriageAPI() - import logging with caplog.at_level(logging.DEBUG): entries = fetch_triage_data(FakeSDK(), "my-org") assert entries == [] info_messages = [r for r in caplog.records if r.levelno == logging.INFO] assert any("access denied" in m.message.lower() for m in info_messages) - # Should NOT produce an ERROR-level record error_messages = [r for r in caplog.records if r.levelno >= logging.ERROR] assert not error_messages @@ -385,7 +570,6 @@ def 
test_injects_after_heading(self): content = notifications["github_pr"][0]["content"] assert "3 finding(s) triaged" in content assert "Socket Dashboard" in content - # Summary line should appear after the # heading lines = content.split("\n") heading_idx = next(i for i, l in enumerate(lines) if l.strip().startswith("# ")) summary_idx = next(i for i, l in enumerate(lines) if "triaged" in l) From aeaa515a782d5d10ec2c4bb79799d03130400804 Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:09:01 -0800 Subject: [PATCH 05/13] Add triage-aware PR comment filtering Fetch triage entries from Socket API after scan submission, remove alerts with ignore/monitor state from results, regenerate connector notifications with filtered components, and inject a triage count summary into GitHub PR comments. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 195 +++++++++++++++++ socket_basics/socket_basics.py | 216 ++++++++++++++++++- tests/__init__.py | 0 tests/test_triage.py | 384 +++++++++++++++++++++++++++++++++ 4 files changed, 794 insertions(+), 1 deletion(-) create mode 100644 socket_basics/core/triage.py create mode 100644 tests/__init__.py create mode 100644 tests/test_triage.py diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py new file mode 100644 index 0000000..0738f95 --- /dev/null +++ b/socket_basics/core/triage.py @@ -0,0 +1,195 @@ +"""Triage filtering for Socket Security Basics. + +Fetches triage entries from the Socket API and filters scan components +whose alerts have been triaged (state: ignore or monitor). +""" + +import fnmatch +import logging +from typing import Any, Dict, List, Tuple + +logger = logging.getLogger(__name__) + +# Triage states that cause a finding to be removed from reports +_SUPPRESSED_STATES = {"ignore", "monitor"} + + +def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: + """Fetch all triage alert entries from the Socket API, handling pagination. 
+ + Args: + sdk: Initialized socketdev SDK instance. + org_slug: Organization slug for the API call. + + Returns: + List of triage entry dicts. + """ + all_entries: List[Dict[str, Any]] = [] + page = 1 + per_page = 100 + + while True: + try: + response = sdk.triage.list_alert_triage( + org_slug, + {"per_page": per_page, "page": page}, + ) + except Exception: + logger.exception("Failed to fetch triage data (page %d)", page) + break + + if not isinstance(response, dict): + logger.warning("Unexpected triage API response type: %s", type(response)) + break + + results = response.get("results") or [] + all_entries.extend(results) + + next_page = response.get("nextPage") + if next_page is None: + break + page = int(next_page) + + logger.debug("Fetched %d triage entries for org %s", len(all_entries), org_slug) + return all_entries + + +class TriageFilter: + """Matches local scan alerts against triage entries and filters them out.""" + + def __init__(self, triage_entries: List[Dict[str, Any]]) -> None: + # Only keep entries whose state suppresses findings + self.entries = [ + e for e in triage_entries + if (e.get("state") or "").lower() in _SUPPRESSED_STATES + ] + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def is_alert_triaged(self, component: Dict[str, Any], alert: Dict[str, Any]) -> bool: + """Return True if the alert on the given component matches a suppressed triage entry.""" + alert_keys = self._extract_alert_keys(alert) + if not alert_keys: + return False + + for entry in self.entries: + entry_key = entry.get("alert_key") + if not entry_key: + continue + + if entry_key not in alert_keys: + continue + + # alert_key matched; now check package scope + if self._is_broad_match(entry): + return True + + if self._package_matches(entry, component): + return True + + return False + + def filter_components( + self, components: List[Dict[str, Any]] + ) -> 
Tuple[List[Dict[str, Any]], int]: + """Remove triaged alerts from components. + + Returns: + (filtered_components, triaged_count) where triaged_count is the + total number of individual alerts removed. + """ + if not self.entries: + return components, 0 + + filtered: List[Dict[str, Any]] = [] + triaged_count = 0 + + for comp in components: + remaining_alerts: List[Dict[str, Any]] = [] + for alert in comp.get("alerts", []): + if self.is_alert_triaged(comp, alert): + triaged_count += 1 + else: + remaining_alerts.append(alert) + + if remaining_alerts: + new_comp = dict(comp) + new_comp["alerts"] = remaining_alerts + filtered.append(new_comp) + + return filtered, triaged_count + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_alert_keys(alert: Dict[str, Any]) -> set: + """Build the set of candidate keys that could match a triage entry's alert_key.""" + keys: set = set() + props = alert.get("props") or {} + + for field in ( + alert.get("title"), + alert.get("type"), + props.get("ruleId"), + props.get("detectorName"), + props.get("vulnerabilityId"), + props.get("cveId"), + ): + if field: + keys.add(str(field)) + + return keys + + @staticmethod + def _is_broad_match(entry: Dict[str, Any]) -> bool: + """Return True when the triage entry has no package scope (applies globally).""" + return ( + entry.get("package_name") is None + and entry.get("package_type") is None + and entry.get("package_version") is None + and entry.get("package_namespace") is None + ) + + @staticmethod + def _version_matches(entry_version: str, component_version: str) -> bool: + """Check version match, supporting wildcard suffix patterns like '1.2.*'.""" + if not entry_version or entry_version == "*": + return True + if not component_version: + return False + # fnmatch handles '*' and '?' 
glob patterns + return fnmatch.fnmatch(component_version, entry_version) + + @classmethod + def _package_matches(cls, entry: Dict[str, Any], component: Dict[str, Any]) -> bool: + """Return True if the triage entry's package scope matches the component.""" + qualifiers = component.get("qualifiers") or {} + comp_name = component.get("name") or "" + comp_type = ( + qualifiers.get("ecosystem") + or qualifiers.get("type") + or component.get("type") + or "" + ) + comp_version = component.get("version") or qualifiers.get("version") or "" + comp_namespace = qualifiers.get("namespace") or "" + + entry_name = entry.get("package_name") + entry_type = entry.get("package_type") + entry_version = entry.get("package_version") + entry_namespace = entry.get("package_namespace") + + if entry_name is not None and entry_name != comp_name: + return False + if entry_type is not None and entry_type.lower() != comp_type.lower(): + return False + if entry_namespace is not None and entry_namespace != comp_namespace: + return False + if entry_version is not None and not cls._version_matches(entry_version, comp_version): + return False + + return True diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index a7f7f04..6311611 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -17,7 +17,7 @@ import sys import os from pathlib import Path -from typing import Dict, Any, Optional +from typing import Dict, Any, List, Optional import hashlib try: # Python 3.11+ @@ -378,6 +378,214 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) return results + def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: + """Filter out triaged alerts and regenerate notifications. 
+ + Fetches triage entries from the Socket API, removes alerts with + state ``ignore`` or ``monitor``, regenerates connector notifications + for the remaining components, and injects a triage summary line into + github_pr notification content. + + Args: + results: Current scan results dict (components + notifications). + + Returns: + Updated results dict with triaged findings removed. + """ + socket_api_key = self.config.get('socket_api_key') + socket_org = self.config.get('socket_org') + + if not socket_api_key or not socket_org: + logger.debug("Skipping triage filter: missing socket_api_key or socket_org") + return results + + # Import SDK and triage helpers + try: + from socketdev import socketdev + except ImportError: + logger.debug("socketdev SDK not available; skipping triage filter") + return results + + try: + from .core.triage import TriageFilter, fetch_triage_data + except ImportError: + from socket_basics.core.triage import TriageFilter, fetch_triage_data + + sdk = socketdev(token=socket_api_key, timeout=100) + triage_entries = fetch_triage_data(sdk, socket_org) + + if not triage_entries: + logger.debug("No triage entries found; skipping filter") + return results + + triage_filter = TriageFilter(triage_entries) + filtered_components, triaged_count = triage_filter.filter_components( + results.get('components', []) + ) + + if triaged_count == 0: + logger.debug("No findings matched triage entries") + return results + + logger.info("Filtered %d triaged finding(s) from results", triaged_count) + results['components'] = filtered_components + results['triaged_count'] = triaged_count + + # Regenerate notifications from the filtered components + self._regenerate_notifications(results, filtered_components, triaged_count) + + return results + + def _regenerate_notifications( + self, + results: Dict[str, Any], + filtered_components: List[Dict[str, Any]], + triaged_count: int, + ) -> None: + """Regenerate connector notifications from filtered components. 
+ + Groups components by their connector origin (via the ``generatedBy`` + field on alerts), calls each connector's ``generate_notifications``, + merges the results, and injects a triage summary into github_pr + content. + """ + connector_components: Dict[str, List[Dict[str, Any]]] = {} + for comp in filtered_components: + for alert in comp.get('alerts', []): + gen = alert.get('generatedBy') or '' + connector_name = self._connector_name_from_generated_by(gen) + if connector_name: + connector_components.setdefault(connector_name, []).append(comp) + break # one mapping per component is enough + + merged_notifications: Dict[str, list] = {} + + for connector_name, comps in connector_components.items(): + connector = self.connector_manager.loaded_connectors.get(connector_name) + if connector is None: + logger.debug("Connector %s not loaded; skipping notification regen", connector_name) + continue + + if not hasattr(connector, 'generate_notifications'): + logger.debug("Connector %s has no generate_notifications", connector_name) + continue + + try: + if connector_name == 'trivy': + item_name, scan_type = self._derive_trivy_params(comps) + notifs = connector.generate_notifications(comps, item_name, scan_type) + else: + notifs = connector.generate_notifications(comps) + except Exception: + logger.exception("Failed to regenerate notifications for %s", connector_name) + continue + + if not isinstance(notifs, dict): + continue + + for notifier_key, payload in notifs.items(): + if notifier_key not in merged_notifications: + merged_notifications[notifier_key] = payload + elif isinstance(merged_notifications[notifier_key], list) and isinstance(payload, list): + merged_notifications[notifier_key].extend(payload) + + # Inject triage summary into github_pr notification content + full_scan_url = results.get('full_scan_html_url', '') + self._inject_triage_summary(merged_notifications, triaged_count, full_scan_url) + + if merged_notifications: + results['notifications'] = 
merged_notifications + + @staticmethod + def _connector_name_from_generated_by(generated_by: str) -> str | None: + """Map a generatedBy value back to its connector name.""" + gb = generated_by.lower() + if gb.startswith('opengrep') or gb.startswith('sast'): + return 'opengrep' + if gb == 'trufflehog': + return 'trufflehog' + if gb.startswith('trivy'): + return 'trivy' + if gb == 'socket-tier1': + return 'socket_tier1' + return None + + def _derive_trivy_params( + self, components: List[Dict[str, Any]] + ) -> tuple: + """Derive item_name and scan_type for Trivy notification regeneration.""" + scan_type = 'image' + for comp in components: + for alert in comp.get('alerts', []): + props = alert.get('props') or {} + st = props.get('scanType', '') + if st: + scan_type = st + break + if scan_type != 'image': + break + + item_name = "Unknown" + images_str = ( + self.config.get('container_images', '') + or self.config.get('container_images_to_scan', '') + or self.config.get('docker_images', '') + ) + if images_str: + if isinstance(images_str, list): + item_name = images_str[0] if images_str else "Unknown" + else: + images = [img.strip() for img in str(images_str).split(',') if img.strip()] + item_name = images[0] if images else "Unknown" + else: + dockerfiles = self.config.get('dockerfiles', '') + if dockerfiles: + if isinstance(dockerfiles, list): + item_name = dockerfiles[0] if dockerfiles else "Unknown" + else: + docker_list = [df.strip() for df in str(dockerfiles).split(',') if df.strip()] + item_name = docker_list[0] if docker_list else "Unknown" + + if scan_type == 'vuln' and item_name == "Unknown": + try: + item_name = os.path.basename(str(self.config.workspace)) + except Exception: + item_name = "Workspace" + + return item_name, scan_type + + @staticmethod + def _inject_triage_summary( + notifications: Dict[str, list], + triaged_count: int, + full_scan_url: str, + ) -> None: + """Insert a triage summary line into github_pr notification content.""" + gh_items = 
notifications.get('github_pr') + if not gh_items or not isinstance(gh_items, list): + return + + dashboard_link = full_scan_url or "https://socket.dev/dashboard" + summary_line = ( + f"\n> :white_check_mark: **{triaged_count} finding(s) triaged** " + f"via [Socket Dashboard]({dashboard_link}) and removed from this report.\n" + ) + + for item in gh_items: + if not isinstance(item, dict) or 'content' not in item: + continue + content = item['content'] + # Insert after the first markdown heading line (# Title) + lines = content.split('\n') + insert_idx = 0 + for i, line in enumerate(lines): + if line.strip().startswith('# '): + insert_idx = i + 1 + break + lines.insert(insert_idx, summary_line) + item['content'] = '\n'.join(lines) + + def main(): """Main entry point""" parser = parse_cli_args() @@ -429,6 +637,12 @@ def main(): except Exception: logger.exception("Failed to submit socket facts file") + # Filter out triaged alerts before notifying + try: + results = scanner.apply_triage_filter(results) + except Exception: + logger.exception("Failed to apply triage filter") + # Optionally upload to S3 if requested try: enable_s3 = getattr(args, 'enable_s3_upload', False) or config.get('enable_s3_upload', False) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_triage.py b/tests/test_triage.py new file mode 100644 index 0000000..b1a472e --- /dev/null +++ b/tests/test_triage.py @@ -0,0 +1,384 @@ +"""Tests for socket_basics.core.triage module.""" + +import pytest +from socket_basics.core.triage import TriageFilter, fetch_triage_data + + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + +def _make_component( + name: str = "lodash", + comp_type: str = "npm", + version: str = "4.17.21", + alerts: list | None = None, +) -> dict: + return { + "id": 
f"pkg:{comp_type}/{name}@{version}", + "name": name, + "version": version, + "type": comp_type, + "qualifiers": {"ecosystem": comp_type, "version": version}, + "alerts": alerts or [], + } + + +def _make_alert( + title: str = "badEncoding", + alert_type: str = "supplyChainRisk", + severity: str = "high", + rule_id: str | None = None, + detector_name: str | None = None, + cve_id: str | None = None, + generated_by: str = "opengrep-python", +) -> dict: + props: dict = {} + if rule_id: + props["ruleId"] = rule_id + if detector_name: + props["detectorName"] = detector_name + if cve_id: + props["cveId"] = cve_id + return { + "title": title, + "type": alert_type, + "severity": severity, + "generatedBy": generated_by, + "props": props, + } + + +def _make_triage_entry( + alert_key: str, + state: str = "ignore", + package_name: str | None = None, + package_type: str | None = None, + package_version: str | None = None, + package_namespace: str | None = None, +) -> dict: + return { + "uuid": "test-uuid", + "alert_key": alert_key, + "state": state, + "package_name": package_name, + "package_type": package_type, + "package_version": package_version, + "package_namespace": package_namespace, + "note": "", + "organization_id": "test-org", + } + + +# --------------------------------------------------------------------------- +# TriageFilter.is_alert_triaged +# --------------------------------------------------------------------------- + +class TestIsAlertTriaged: + """Tests for the alert matching logic.""" + + def test_broad_match_by_title(self): + """Triage entry with no package info matches any component with matching alert_key.""" + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_rule_id(self): + entry = _make_triage_entry(alert_key="python.lang.security.audit.xss") + tf = TriageFilter([entry]) + comp 
= _make_component() + alert = _make_alert(title="XSS Vulnerability", rule_id="python.lang.security.audit.xss") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_detector_name(self): + entry = _make_triage_entry(alert_key="AWS") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="AWS Key Detected", detector_name="AWS") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_cve(self): + entry = _make_triage_entry(alert_key="CVE-2024-1234") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="Some Vuln", cve_id="CVE-2024-1234") + assert tf.is_alert_triaged(comp, alert) is True + + def test_no_match_different_key(self): + entry = _make_triage_entry(alert_key="differentRule") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is False + + def test_package_scoped_match(self): + """Triage entry with package info only matches the specific package.""" + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + ) + tf = TriageFilter([entry]) + + comp_match = _make_component(name="lodash", comp_type="npm") + comp_no_match = _make_component(name="express", comp_type="npm") + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged(comp_match, alert) is True + assert tf.is_alert_triaged(comp_no_match, alert) is False + + def test_package_version_exact_match(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="4.17.21", + ) + tf = TriageFilter([entry]) + + comp_match = _make_component(name="lodash", comp_type="npm", version="4.17.21") + comp_no_match = _make_component(name="lodash", comp_type="npm", version="4.17.20") + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged(comp_match, alert) is True + assert 
tf.is_alert_triaged(comp_no_match, alert) is False + + def test_version_wildcard(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="4.17.*", + ) + tf = TriageFilter([entry]) + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.17.21"), alert + ) is True + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.17.0"), alert + ) is True + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.18.0"), alert + ) is False + + def test_version_star_matches_all(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="*", + ) + tf = TriageFilter([entry]) + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="99.0.0"), alert + ) is True + + def test_states_block_and_warn_not_suppressed(self): + """Triage entries with block/warn/inherit states should not filter findings.""" + for state in ("block", "warn", "inherit"): + entry = _make_triage_entry(alert_key="badEncoding", state=state) + tf = TriageFilter([entry]) + assert tf.entries == [], f"state={state} should be excluded from filter entries" + + def test_state_monitor_suppressed(self): + entry = _make_triage_entry(alert_key="badEncoding", state="monitor") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is True + + def test_alert_with_no_matchable_keys(self): + """Alert with no title, type, or relevant props should not match.""" + entry = _make_triage_entry(alert_key="something") + tf = TriageFilter([entry]) + comp = _make_component() + alert = {"severity": "high", "props": {}} + assert tf.is_alert_triaged(comp, alert) is False + + +# 
--------------------------------------------------------------------------- +# TriageFilter.filter_components +# --------------------------------------------------------------------------- + +class TestFilterComponents: + def test_removes_triaged_alerts(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + alert_triaged = _make_alert(title="badEncoding") + alert_kept = _make_alert(title="otherIssue") + comp = _make_component(alerts=[alert_triaged, alert_kept]) + + filtered, count = tf.filter_components([comp]) + assert count == 1 + assert len(filtered) == 1 + assert len(filtered[0]["alerts"]) == 1 + assert filtered[0]["alerts"][0]["title"] == "otherIssue" + + def test_removes_component_when_all_alerts_triaged(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + comp = _make_component(alerts=[_make_alert(title="badEncoding")]) + filtered, count = tf.filter_components([comp]) + assert count == 1 + assert len(filtered) == 0 + + def test_no_triage_entries_returns_original(self): + tf = TriageFilter([]) + comp = _make_component(alerts=[_make_alert()]) + filtered, count = tf.filter_components([comp]) + assert count == 0 + assert filtered is [comp] or filtered == [comp] + + def test_multiple_components_mixed(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + comp1 = _make_component(name="a", alerts=[_make_alert(title="badEncoding")]) + comp2 = _make_component(name="b", alerts=[_make_alert(title="otherIssue")]) + comp3 = _make_component( + name="c", + alerts=[ + _make_alert(title="badEncoding"), + _make_alert(title="keepMe"), + ], + ) + + filtered, count = tf.filter_components([comp1, comp2, comp3]) + assert count == 2 + assert len(filtered) == 2 + names = [c["name"] for c in filtered] + assert "a" not in names + assert "b" in names + assert "c" in names + + +# --------------------------------------------------------------------------- +# 
fetch_triage_data +# --------------------------------------------------------------------------- + +class TestFetchTriageData: + def test_single_page(self): + class FakeTriageAPI: + def list_alert_triage(self, org, params): + return {"results": [{"alert_key": "a", "state": "ignore"}], "nextPage": None} + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 1 + assert entries[0]["alert_key"] == "a" + + def test_pagination(self): + class FakeTriageAPI: + def __init__(self): + self.call_count = 0 + + def list_alert_triage(self, org, params): + self.call_count += 1 + if params.get("page") == 1: + return {"results": [{"alert_key": "a"}], "nextPage": 2} + return {"results": [{"alert_key": "b"}], "nextPage": None} + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 2 + + def test_api_error_returns_partial(self): + class FakeTriageAPI: + def __init__(self): + self.calls = 0 + + def list_alert_triage(self, org, params): + self.calls += 1 + if self.calls == 1: + return {"results": [{"alert_key": "a"}], "nextPage": 2} + raise RuntimeError("API error") + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 1 + + +# --------------------------------------------------------------------------- +# SecurityScanner._connector_name_from_generated_by +# --------------------------------------------------------------------------- + +class TestConnectorNameMapping: + def test_opengrep_variants(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("opengrep-python") == "opengrep" + assert SecurityScanner._connector_name_from_generated_by("sast-generic") == "opengrep" + + def test_trufflehog(self): + from socket_basics.socket_basics import SecurityScanner + assert 
SecurityScanner._connector_name_from_generated_by("trufflehog") == "trufflehog" + + def test_trivy_variants(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("trivy-dockerfile") == "trivy" + assert SecurityScanner._connector_name_from_generated_by("trivy-image") == "trivy" + assert SecurityScanner._connector_name_from_generated_by("trivy-npm") == "trivy" + + def test_socket_tier1(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("socket-tier1") == "socket_tier1" + + def test_unknown_returns_none(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("unknown-tool") is None + + +# --------------------------------------------------------------------------- +# SecurityScanner._inject_triage_summary +# --------------------------------------------------------------------------- + +class TestInjectTriageSummary: + def test_injects_after_heading(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = { + "github_pr": [ + { + "title": "SAST Findings", + "content": "\n# SAST Python Findings\n### Summary\nSome content\n", + } + ] + } + SecurityScanner._inject_triage_summary(notifications, 3, "https://socket.dev/scan/123") + + content = notifications["github_pr"][0]["content"] + assert "3 finding(s) triaged" in content + assert "Socket Dashboard" in content + # Summary line should appear after the # heading + lines = content.split("\n") + heading_idx = next(i for i, l in enumerate(lines) if l.strip().startswith("# ")) + summary_idx = next(i for i, l in enumerate(lines) if "triaged" in l) + assert summary_idx > heading_idx + + def test_no_github_pr_key_is_noop(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = {"slack": [{"title": "t", "content": "c"}]} + 
SecurityScanner._inject_triage_summary(notifications, 5, "") + assert "github_pr" not in notifications + + def test_uses_default_dashboard_link(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = { + "github_pr": [{"title": "t", "content": "# Title\nBody"}] + } + SecurityScanner._inject_triage_summary(notifications, 1, "") + assert "https://socket.dev/dashboard" in notifications["github_pr"][0]["content"] From 098ed7291266ad301af83460efd363280193234c Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:36:21 -0800 Subject: [PATCH 06/13] Handle triage API access denied gracefully Log an info-level message instead of an error traceback when the Socket API token lacks triage permissions, and skip filtering so the scan completes normally with all findings intact. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 13 +++++++++++-- tests/test_triage.py | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py index 0738f95..f9b9796 100644 --- a/socket_basics/core/triage.py +++ b/socket_basics/core/triage.py @@ -34,8 +34,17 @@ def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: org_slug, {"per_page": per_page, "page": page}, ) - except Exception: - logger.exception("Failed to fetch triage data (page %d)", page) + except Exception as exc: + # Handle insufficient permissions gracefully so the scan + # continues without triage filtering. + exc_name = type(exc).__name__ + if "AccessDenied" in exc_name or "Forbidden" in exc_name: + logger.info( + "Triage API access denied (insufficient permissions). " + "Skipping triage filtering for this run." 
+ ) + else: + logger.warning("Failed to fetch triage data (page %d): %s", page, exc) break if not isinstance(response, dict): diff --git a/tests/test_triage.py b/tests/test_triage.py index b1a472e..c0c4954 100644 --- a/tests/test_triage.py +++ b/tests/test_triage.py @@ -310,6 +310,30 @@ class FakeSDK: entries = fetch_triage_data(FakeSDK(), "my-org") assert len(entries) == 1 + def test_access_denied_returns_empty_and_logs_info(self, caplog): + """Insufficient permissions should log an info message (not an error) and return empty.""" + + class APIAccessDenied(Exception): + pass + + class FakeTriageAPI: + def list_alert_triage(self, org, params): + raise APIAccessDenied("Insufficient permissions.") + + class FakeSDK: + triage = FakeTriageAPI() + + import logging + with caplog.at_level(logging.DEBUG): + entries = fetch_triage_data(FakeSDK(), "my-org") + + assert entries == [] + info_messages = [r for r in caplog.records if r.levelno == logging.INFO] + assert any("access denied" in m.message.lower() for m in info_messages) + # Should NOT produce an ERROR-level record + error_messages = [r for r in caplog.records if r.levelno >= logging.ERROR] + assert not error_messages + # --------------------------------------------------------------------------- # SecurityScanner._connector_name_from_generated_by From 1e551d285d64532e73418e74c9a4a7587395ca0d Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:46:17 -0800 Subject: [PATCH 07/13] Fix stale notifications after triage and improve logging Always replace results['notifications'] after triage filtering so pre-filter content is never forwarded to notifiers. Skip PR comment API calls when content is unchanged. Add info-level logging for triaged/remaining finding counts and connector regeneration details. 
Co-Authored-By: Claude Opus 4.6 --- .../core/notification/github_pr_notifier.py | 12 ++++ socket_basics/socket_basics.py | 60 +++++++++++++++++-- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/socket_basics/core/notification/github_pr_notifier.py b/socket_basics/core/notification/github_pr_notifier.py index 555d03e..8c258bc 100644 --- a/socket_basics/core/notification/github_pr_notifier.py +++ b/socket_basics/core/notification/github_pr_notifier.py @@ -100,6 +100,18 @@ def notify(self, facts: Dict[str, Any]) -> None: # Update existing comments with new section content for comment_id, updated_body in comment_updates.items(): + # Detect whether content actually changed before making the API call + original_body = next( + (c.get('body', '') for c in existing_comments if c.get('id') == comment_id), + '', + ) + if original_body == updated_body: + logger.info( + 'GithubPRNotifier: comment %s content unchanged; skipping update', + comment_id, + ) + continue + success = self._update_comment(pr_number, comment_id, updated_body) if success: logger.info('GithubPRNotifier: updated existing comment %s', comment_id) diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index 6311611..984133e 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -419,15 +419,29 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: return results triage_filter = TriageFilter(triage_entries) + original_components = results.get('components', []) + original_alert_count = sum( + len(c.get('alerts', [])) for c in original_components + ) filtered_components, triaged_count = triage_filter.filter_components( - results.get('components', []) + original_components ) if triaged_count == 0: - logger.debug("No findings matched triage entries") + logger.info( + "Triage filter matched 0 of %d finding(s); no changes applied", + original_alert_count, + ) return results - logger.info("Filtered %d triaged finding(s) from 
results", triaged_count) + remaining_alert_count = sum( + len(c.get('alerts', [])) for c in filtered_components + ) + logger.info( + "Triage filter removed %d finding(s); %d finding(s) remain", + triaged_count, + remaining_alert_count, + ) results['components'] = filtered_components results['triaged_count'] = triaged_count @@ -448,22 +462,47 @@ def _regenerate_notifications( field on alerts), calls each connector's ``generate_notifications``, merges the results, and injects a triage summary into github_pr content. + + Always replaces ``results['notifications']`` so stale pre-filter + notifications are never forwarded to notifiers. """ connector_components: Dict[str, List[Dict[str, Any]]] = {} + unmapped_count = 0 for comp in filtered_components: + mapped = False for alert in comp.get('alerts', []): gen = alert.get('generatedBy') or '' connector_name = self._connector_name_from_generated_by(gen) if connector_name: connector_components.setdefault(connector_name, []).append(comp) + mapped = True break # one mapping per component is enough + if not mapped: + unmapped_count += 1 + + if unmapped_count: + logger.debug( + "Triage regen: %d component(s) could not be mapped to a connector", + unmapped_count, + ) + + logger.info( + "Regenerating notifications for %d connector(s): %s", + len(connector_components), + ", ".join(connector_components.keys()) or "(none)", + ) merged_notifications: Dict[str, list] = {} for connector_name, comps in connector_components.items(): connector = self.connector_manager.loaded_connectors.get(connector_name) if connector is None: - logger.debug("Connector %s not loaded; skipping notification regen", connector_name) + logger.warning( + "Connector %s not in loaded_connectors (available: %s); " + "cannot regenerate its notifications", + connector_name, + ", ".join(self.connector_manager.loaded_connectors.keys()), + ) continue if not hasattr(connector, 'generate_notifications'): @@ -483,6 +522,13 @@ def _regenerate_notifications( if not 
isinstance(notifs, dict): continue + notifier_keys = [k for k, v in notifs.items() if v] + logger.debug( + "Connector %s produced notifications for: %s", + connector_name, + ", ".join(notifier_keys) or "(empty)", + ) + for notifier_key, payload in notifs.items(): if notifier_key not in merged_notifications: merged_notifications[notifier_key] = payload @@ -493,8 +539,10 @@ def _regenerate_notifications( full_scan_url = results.get('full_scan_html_url', '') self._inject_triage_summary(merged_notifications, triaged_count, full_scan_url) - if merged_notifications: - results['notifications'] = merged_notifications + # Always replace notifications so stale pre-filter content is never + # forwarded to notifiers. An empty dict is valid and means every + # finding was triaged. + results['notifications'] = merged_notifications @staticmethod def _connector_name_from_generated_by(generated_by: str) -> str | None: From a4851703893808a6fe6e003e903331222102d77a Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 16:43:48 -0800 Subject: [PATCH 08/13] Rework triage matching to use stream-based alert key lookup The triage API returns opaque alert_key hashes, not human-readable identifiers. This rewrites the matching logic to stream the full scan via sdk.fullscans.stream(), cross-reference Socket alert keys against triage entries, and map back to local components by artifact ID. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 273 +++++++++++------- socket_basics/socket_basics.py | 39 ++- tests/test_triage.py | 490 +++++++++++++++++++++++---------- 3 files changed, 536 insertions(+), 266 deletions(-) diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py index f9b9796..59bf6ce 100644 --- a/socket_basics/core/triage.py +++ b/socket_basics/core/triage.py @@ -1,12 +1,12 @@ """Triage filtering for Socket Security Basics. 
-Fetches triage entries from the Socket API and filters scan components -whose alerts have been triaged (state: ignore or monitor). +Streams the full scan from the Socket API to obtain alert keys, fetches +triage entries, and filters local scan components whose alerts have been +triaged (state: ignore or monitor). """ -import fnmatch import logging -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Set, Tuple logger = logging.getLogger(__name__) @@ -14,6 +14,10 @@ _SUPPRESSED_STATES = {"ignore", "monitor"} +# ------------------------------------------------------------------ +# API helpers +# ------------------------------------------------------------------ + def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: """Fetch all triage alert entries from the Socket API, handling pagination. @@ -63,62 +67,145 @@ def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: return all_entries -class TriageFilter: - """Matches local scan alerts against triage entries and filters them out.""" +def stream_full_scan_alerts( + sdk: Any, org_slug: str, full_scan_id: str +) -> Dict[str, List[Dict[str, Any]]]: + """Stream a full scan and extract alert keys grouped by artifact. + + Returns: + Mapping of artifact ID to list of alert dicts. Each alert dict + contains at minimum ``key`` and ``type``. The artifact metadata + (name, version, type, etc.) is included under a ``_artifact`` key + in every alert dict for downstream matching. + """ + try: + # use_types=False returns a plain dict keyed by artifact ID + response = sdk.fullscans.stream(org_slug, full_scan_id, use_types=False) + except Exception as exc: + exc_name = type(exc).__name__ + if "AccessDenied" in exc_name or "Forbidden" in exc_name: + logger.info( + "Full scan stream access denied (insufficient permissions). " + "Skipping triage filtering for this run." 
+ ) + else: + logger.warning("Failed to stream full scan %s: %s", full_scan_id, exc) + return {} + + if not isinstance(response, dict): + logger.warning("Unexpected full scan stream response type: %s", type(response)) + return {} + + artifact_alerts: Dict[str, List[Dict[str, Any]]] = {} + for artifact_id, artifact in response.items(): + if not isinstance(artifact, dict): + continue + alerts = artifact.get("alerts") or [] + if not alerts: + continue + meta = { + "artifact_id": artifact_id, + "artifact_name": artifact.get("name"), + "artifact_version": artifact.get("version"), + "artifact_type": artifact.get("type"), + "artifact_namespace": artifact.get("namespace"), + "artifact_subpath": artifact.get("subPath") or artifact.get("subpath"), + } + enriched = [] + for a in alerts: + if isinstance(a, dict) and a.get("key"): + enriched.append({**a, "_artifact": meta}) + if enriched: + artifact_alerts[artifact_id] = enriched + + total_alerts = sum(len(v) for v in artifact_alerts.values()) + logger.debug( + "Streamed full scan %s: %d artifact(s), %d alert(s) with keys", + full_scan_id, + len(artifact_alerts), + total_alerts, + ) + return artifact_alerts + + +# ------------------------------------------------------------------ +# TriageFilter +# ------------------------------------------------------------------ - def __init__(self, triage_entries: List[Dict[str, Any]]) -> None: - # Only keep entries whose state suppresses findings - self.entries = [ - e for e in triage_entries - if (e.get("state") or "").lower() in _SUPPRESSED_STATES - ] +class TriageFilter: + """Cross-references Socket alert keys against triage entries and + maps triaged alerts back to local scan components.""" + + def __init__( + self, + triage_entries: List[Dict[str, Any]], + artifact_alerts: Dict[str, List[Dict[str, Any]]], + ) -> None: + # Build set of suppressed alert keys + self.triaged_keys: Set[str] = set() + for entry in triage_entries: + state = (entry.get("state") or "").lower() + key = 
entry.get("alert_key") + if state in _SUPPRESSED_STATES and key: + self.triaged_keys.add(key) + + # Flatten all Socket alerts for lookup + self._socket_alerts: List[Dict[str, Any]] = [] + for alerts in artifact_alerts.values(): + self._socket_alerts.extend(alerts) + + # Build a mapping from (artifact_id, alert_type) to triaged status + # for fast lookups when matching against local components + self._triaged_by_artifact: Dict[str, Set[str]] = {} + for alert in self._socket_alerts: + if alert.get("key") in self.triaged_keys: + art_id = alert.get("_artifact", {}).get("artifact_id", "") + alert_type = alert.get("type") or "" + self._triaged_by_artifact.setdefault(art_id, set()).add(alert_type) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ - def is_alert_triaged(self, component: Dict[str, Any], alert: Dict[str, Any]) -> bool: - """Return True if the alert on the given component matches a suppressed triage entry.""" - alert_keys = self._extract_alert_keys(alert) - if not alert_keys: - return False - - for entry in self.entries: - entry_key = entry.get("alert_key") - if not entry_key: - continue - - if entry_key not in alert_keys: - continue - - # alert_key matched; now check package scope - if self._is_broad_match(entry): - return True - - if self._package_matches(entry, component): - return True - - return False - def filter_components( self, components: List[Dict[str, Any]] ) -> Tuple[List[Dict[str, Any]], int]: - """Remove triaged alerts from components. + """Remove triaged alerts from local components. + + Matches local components to Socket artifacts by component ID, then + checks each local alert against the set of triaged alert types for + that artifact. Returns: - (filtered_components, triaged_count) where triaged_count is the - total number of individual alerts removed. 
+ (filtered_components, triaged_count) """ - if not self.entries: + if not self.triaged_keys: + return components, 0 + + # Build lookup: component id -> set of triaged Socket alert types + triaged_types_by_component = self._map_components_to_triaged_types(components) + + if not triaged_types_by_component: + logger.debug( + "No local components matched Socket artifacts with triaged alerts" + ) return components, 0 filtered: List[Dict[str, Any]] = [] triaged_count = 0 for comp in components: + comp_id = comp.get("id") or "" + triaged_types = triaged_types_by_component.get(comp_id) + + if triaged_types is None: + # Component had no triaged alerts; keep as-is + filtered.append(comp) + continue + remaining_alerts: List[Dict[str, Any]] = [] for alert in comp.get("alerts", []): - if self.is_alert_triaged(comp, alert): + if self._local_alert_is_triaged(alert, triaged_types): triaged_count += 1 else: remaining_alerts.append(alert) @@ -134,71 +221,49 @@ def filter_components( # Internal helpers # ------------------------------------------------------------------ - @staticmethod - def _extract_alert_keys(alert: Dict[str, Any]) -> set: - """Build the set of candidate keys that could match a triage entry's alert_key.""" - keys: set = set() - props = alert.get("props") or {} - - for field in ( - alert.get("title"), - alert.get("type"), - props.get("ruleId"), - props.get("detectorName"), - props.get("vulnerabilityId"), - props.get("cveId"), - ): - if field: - keys.add(str(field)) + def _map_components_to_triaged_types( + self, components: List[Dict[str, Any]] + ) -> Dict[str, Set[str]]: + """Map local component IDs to the set of triaged Socket alert types. - return keys + Matches by component ``id`` (which is typically a hash that Socket + also uses as the artifact ID). 
+ """ + local_ids = {comp.get("id") for comp in components if comp.get("id")} + result: Dict[str, Set[str]] = {} + for comp_id in local_ids: + triaged = self._triaged_by_artifact.get(comp_id) + if triaged: + result[comp_id] = triaged + return result @staticmethod - def _is_broad_match(entry: Dict[str, Any]) -> bool: - """Return True when the triage entry has no package scope (applies globally).""" - return ( - entry.get("package_name") is None - and entry.get("package_type") is None - and entry.get("package_version") is None - and entry.get("package_namespace") is None - ) + def _local_alert_is_triaged( + alert: Dict[str, Any], triaged_types: Set[str] + ) -> bool: + """Check if a local alert matches any of the triaged Socket alert types. + + Socket alert ``type`` values (e.g. ``badEncoding``, ``cve``) are + compared against the local alert's ``type`` field. When the local + alert type is too generic (``"generic"`` or ``"vulnerability"``), + we fall back to matching on ``title``, ``props.ruleId``, or + ``props.vulnerabilityId``. + """ + # Direct type match + local_type = alert.get("type") or "" + if local_type and local_type not in ("generic", "vulnerability"): + return local_type in triaged_types - @staticmethod - def _version_matches(entry_version: str, component_version: str) -> bool: - """Check version match, supporting wildcard suffix patterns like '1.2.*'.""" - if not entry_version or entry_version == "*": - return True - if not component_version: - return False - # fnmatch handles '*' and '?' 
glob patterns - return fnmatch.fnmatch(component_version, entry_version) - - @classmethod - def _package_matches(cls, entry: Dict[str, Any], component: Dict[str, Any]) -> bool: - """Return True if the triage entry's package scope matches the component.""" - qualifiers = component.get("qualifiers") or {} - comp_name = component.get("name") or "" - comp_type = ( - qualifiers.get("ecosystem") - or qualifiers.get("type") - or component.get("type") - or "" - ) - comp_version = component.get("version") or qualifiers.get("version") or "" - comp_namespace = qualifiers.get("namespace") or "" - - entry_name = entry.get("package_name") - entry_type = entry.get("package_type") - entry_version = entry.get("package_version") - entry_namespace = entry.get("package_namespace") - - if entry_name is not None and entry_name != comp_name: - return False - if entry_type is not None and entry_type.lower() != comp_type.lower(): - return False - if entry_namespace is not None and entry_namespace != comp_namespace: - return False - if entry_version is not None and not cls._version_matches(entry_version, comp_version): - return False - - return True + # Fallback: match candidate fields against triaged types + props = alert.get("props") or {} + candidates = { + v for v in ( + alert.get("title"), + props.get("ruleId"), + props.get("detectorName"), + props.get("vulnerabilityId"), + props.get("cveId"), + ) + if v + } + return bool(candidates & triaged_types) diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index 984133e..c5683b2 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -381,10 +381,10 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: """Filter out triaged alerts and regenerate notifications. 
- Fetches triage entries from the Socket API, removes alerts with - state ``ignore`` or ``monitor``, regenerates connector notifications - for the remaining components, and injects a triage summary line into - github_pr notification content. + Streams the full scan from the Socket API to obtain alert keys, + cross-references them with triage entries, removes suppressed + alerts from local components, regenerates connector notifications, + and injects a triage summary into github_pr content. Args: results: Current scan results dict (components + notifications). @@ -394,11 +394,16 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: """ socket_api_key = self.config.get('socket_api_key') socket_org = self.config.get('socket_org') + full_scan_id = results.get('full_scan_id') if not socket_api_key or not socket_org: logger.debug("Skipping triage filter: missing socket_api_key or socket_org") return results + if not full_scan_id: + logger.debug("Skipping triage filter: no full_scan_id in results") + return results + # Import SDK and triage helpers try: from socketdev import socketdev @@ -407,18 +412,34 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: return results try: - from .core.triage import TriageFilter, fetch_triage_data + from .core.triage import TriageFilter, fetch_triage_data, stream_full_scan_alerts except ImportError: - from socket_basics.core.triage import TriageFilter, fetch_triage_data + from socket_basics.core.triage import TriageFilter, fetch_triage_data, stream_full_scan_alerts sdk = socketdev(token=socket_api_key, timeout=100) - triage_entries = fetch_triage_data(sdk, socket_org) + # Fetch triage entries and stream full scan alert keys in sequence + triage_entries = fetch_triage_data(sdk, socket_org) if not triage_entries: - logger.debug("No triage entries found; skipping filter") + logger.info("No triage entries found; skipping filter") + return results + + suppressed_count = sum( + 1 for e in 
triage_entries + if (e.get("state") or "").lower() in ("ignore", "monitor") + ) + logger.info( + "Fetched %d triage entries (%d with suppressed state)", + len(triage_entries), + suppressed_count, + ) + + artifact_alerts = stream_full_scan_alerts(sdk, socket_org, full_scan_id) + if not artifact_alerts: + logger.info("No alert keys returned from full scan stream; skipping filter") return results - triage_filter = TriageFilter(triage_entries) + triage_filter = TriageFilter(triage_entries, artifact_alerts) original_components = results.get('components', []) original_alert_count = sum( len(c.get('alerts', [])) for c in original_components diff --git a/tests/test_triage.py b/tests/test_triage.py index c0c4954..a05e6e5 100644 --- a/tests/test_triage.py +++ b/tests/test_triage.py @@ -1,21 +1,30 @@ """Tests for socket_basics.core.triage module.""" +import logging import pytest -from socket_basics.core.triage import TriageFilter, fetch_triage_data +from socket_basics.core.triage import ( + TriageFilter, + fetch_triage_data, + stream_full_scan_alerts, +) # --------------------------------------------------------------------------- # Fixtures / helpers # --------------------------------------------------------------------------- +ARTIFACT_ID = "abc123" + + def _make_component( + comp_id: str = ARTIFACT_ID, name: str = "lodash", comp_type: str = "npm", version: str = "4.17.21", alerts: list | None = None, ) -> dict: return { - "id": f"pkg:{comp_type}/{name}@{version}", + "id": comp_id, "name": name, "version": version, "type": comp_type, @@ -24,9 +33,9 @@ def _make_component( } -def _make_alert( +def _make_local_alert( title: str = "badEncoding", - alert_type: str = "supplyChainRisk", + alert_type: str = "badEncoding", severity: str = "high", rule_id: str | None = None, detector_name: str | None = None, @@ -52,153 +61,144 @@ def _make_alert( def _make_triage_entry( alert_key: str, state: str = "ignore", - package_name: str | None = None, - package_type: str | None = None, - 
package_version: str | None = None, - package_namespace: str | None = None, ) -> dict: return { "uuid": "test-uuid", "alert_key": alert_key, "state": state, - "package_name": package_name, - "package_type": package_type, - "package_version": package_version, - "package_namespace": package_namespace, "note": "", "organization_id": "test-org", } +def _make_artifact_alerts( + artifact_id: str = ARTIFACT_ID, + alerts: list[dict] | None = None, + name: str = "lodash", + version: str = "4.17.21", + pkg_type: str = "npm", +) -> dict[str, list[dict]]: + """Build an artifact_alerts mapping with enriched _artifact metadata.""" + meta = { + "artifact_id": artifact_id, + "artifact_name": name, + "artifact_version": version, + "artifact_type": pkg_type, + "artifact_namespace": None, + "artifact_subpath": None, + } + enriched = [{**a, "_artifact": meta} for a in (alerts or [])] + return {artifact_id: enriched} + + +def _socket_alert(key: str, alert_type: str) -> dict: + """Create a minimal Socket alert dict (as returned by the full scan stream).""" + return {"key": key, "type": alert_type} + + # --------------------------------------------------------------------------- -# TriageFilter.is_alert_triaged +# TriageFilter construction # --------------------------------------------------------------------------- -class TestIsAlertTriaged: - """Tests for the alert matching logic.""" - - def test_broad_match_by_title(self): - """Triage entry with no package info matches any component with matching alert_key.""" - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_rule_id(self): - entry = _make_triage_entry(alert_key="python.lang.security.audit.xss") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="XSS Vulnerability", rule_id="python.lang.security.audit.xss") - assert 
tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_detector_name(self): - entry = _make_triage_entry(alert_key="AWS") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="AWS Key Detected", detector_name="AWS") - assert tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_cve(self): - entry = _make_triage_entry(alert_key="CVE-2024-1234") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="Some Vuln", cve_id="CVE-2024-1234") - assert tf.is_alert_triaged(comp, alert) is True - - def test_no_match_different_key(self): - entry = _make_triage_entry(alert_key="differentRule") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is False - - def test_package_scoped_match(self): - """Triage entry with package info only matches the specific package.""" - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", +class TestTriageFilterInit: + def test_builds_triaged_keys_for_ignore(self): + entries = [_make_triage_entry("hash-1", state="ignore")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + assert "hash-1" in tf.triaged_keys + + def test_builds_triaged_keys_for_monitor(self): + entries = [_make_triage_entry("hash-2", state="monitor")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-2", "cve")] + ) + tf = TriageFilter(entries, artifact_alerts) + assert "hash-2" in tf.triaged_keys + + def test_excludes_block_warn_inherit_states(self): + entries = [ + _make_triage_entry("h1", state="block"), + _make_triage_entry("h2", state="warn"), + _make_triage_entry("h3", state="inherit"), + ] + artifact_alerts = _make_artifact_alerts( + alerts=[ + _socket_alert("h1", "a"), + _socket_alert("h2", "b"), + _socket_alert("h3", "c"), + ] + ) + tf 
= TriageFilter(entries, artifact_alerts) + assert tf.triaged_keys == set() + + def test_builds_triaged_by_artifact_mapping(self): + entries = [_make_triage_entry("hash-1", state="ignore")] + artifact_alerts = _make_artifact_alerts( + artifact_id="art-1", + alerts=[_socket_alert("hash-1", "badEncoding")], ) - tf = TriageFilter([entry]) + tf = TriageFilter(entries, artifact_alerts) + assert "art-1" in tf._triaged_by_artifact + assert "badEncoding" in tf._triaged_by_artifact["art-1"] - comp_match = _make_component(name="lodash", comp_type="npm") - comp_no_match = _make_component(name="express", comp_type="npm") - alert = _make_alert(title="badEncoding") + def test_no_entries_means_empty_triaged_keys(self): + tf = TriageFilter([], {}) + assert tf.triaged_keys == set() - assert tf.is_alert_triaged(comp_match, alert) is True - assert tf.is_alert_triaged(comp_no_match, alert) is False + def test_entry_without_alert_key_ignored(self): + entries = [{"state": "ignore", "alert_key": None}] + tf = TriageFilter(entries, {}) + assert tf.triaged_keys == set() - def test_package_version_exact_match(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="4.17.21", - ) - tf = TriageFilter([entry]) - comp_match = _make_component(name="lodash", comp_type="npm", version="4.17.21") - comp_no_match = _make_component(name="lodash", comp_type="npm", version="4.17.20") - alert = _make_alert(title="badEncoding") +# --------------------------------------------------------------------------- +# TriageFilter._local_alert_is_triaged +# --------------------------------------------------------------------------- - assert tf.is_alert_triaged(comp_match, alert) is True - assert tf.is_alert_triaged(comp_no_match, alert) is False +class TestLocalAlertIsTriaged: + def test_direct_type_match(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(alert_type="badEncoding") + assert 
TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_direct_type_no_match(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(alert_type="cve") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is False + + def test_generic_type_falls_back_to_title(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(title="badEncoding", alert_type="generic") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_vulnerability_type_falls_back_to_cve(self): + triaged_types = {"CVE-2024-1234"} + alert = _make_local_alert( + title="Some Vuln", alert_type="vulnerability", cve_id="CVE-2024-1234" + ) + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True - def test_version_wildcard(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="4.17.*", + def test_generic_type_falls_back_to_rule_id(self): + triaged_types = {"python.lang.security.audit.xss"} + alert = _make_local_alert( + title="XSS", alert_type="generic", + rule_id="python.lang.security.audit.xss", ) - tf = TriageFilter([entry]) - alert = _make_alert(title="badEncoding") - - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.17.21"), alert - ) is True - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.17.0"), alert - ) is True - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.18.0"), alert - ) is False - - def test_version_star_matches_all(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="*", + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_generic_type_falls_back_to_detector_name(self): + triaged_types = {"AWS"} + alert = _make_local_alert( + title="AWS Key", alert_type="generic", 
detector_name="AWS" ) - tf = TriageFilter([entry]) - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="99.0.0"), alert - ) is True - - def test_states_block_and_warn_not_suppressed(self): - """Triage entries with block/warn/inherit states should not filter findings.""" - for state in ("block", "warn", "inherit"): - entry = _make_triage_entry(alert_key="badEncoding", state=state) - tf = TriageFilter([entry]) - assert tf.entries == [], f"state={state} should be excluded from filter entries" - - def test_state_monitor_suppressed(self): - entry = _make_triage_entry(alert_key="badEncoding", state="monitor") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is True - - def test_alert_with_no_matchable_keys(self): - """Alert with no title, type, or relevant props should not match.""" - entry = _make_triage_entry(alert_key="something") - tf = TriageFilter([entry]) - comp = _make_component() - alert = {"severity": "high", "props": {}} - assert tf.is_alert_triaged(comp, alert) is False + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_no_fallback_candidates_returns_false(self): + triaged_types = {"something"} + alert = {"type": "generic", "props": {}} + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is False # --------------------------------------------------------------------------- @@ -206,47 +206,87 @@ def test_alert_with_no_matchable_keys(self): # --------------------------------------------------------------------------- class TestFilterComponents: - def test_removes_triaged_alerts(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) - - alert_triaged = _make_alert(title="badEncoding") - alert_kept = _make_alert(title="otherIssue") - comp = _make_component(alerts=[alert_triaged, alert_kept]) + def 
test_removes_triaged_alert_by_type(self): + """Component ID matches artifact, triaged alert type matches local alert type.""" + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(title="kept", alert_type="otherIssue"), + ], + ) filtered, count = tf.filter_components([comp]) assert count == 1 assert len(filtered) == 1 assert len(filtered[0]["alerts"]) == 1 - assert filtered[0]["alerts"][0]["title"] == "otherIssue" + assert filtered[0]["alerts"][0]["title"] == "kept" def test_removes_component_when_all_alerts_triaged(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) - comp = _make_component(alerts=[_make_alert(title="badEncoding")]) + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[_make_local_alert(alert_type="badEncoding")], + ) filtered, count = tf.filter_components([comp]) assert count == 1 assert len(filtered) == 0 def test_no_triage_entries_returns_original(self): - tf = TriageFilter([]) - comp = _make_component(alerts=[_make_alert()]) + tf = TriageFilter([], {}) + comp = _make_component(alerts=[_make_local_alert()]) + filtered, count = tf.filter_components([comp]) + assert count == 0 + assert filtered == [comp] + + def test_component_id_mismatch_keeps_all_alerts(self): + """When local component ID doesn't match any artifact, nothing is filtered.""" + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + artifact_id="different-artifact", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) + + 
comp = _make_component( + comp_id="unrelated-comp-id", + alerts=[_make_local_alert(alert_type="badEncoding")], + ) filtered, count = tf.filter_components([comp]) assert count == 0 - assert filtered is [comp] or filtered == [comp] + assert len(filtered) == 1 def test_multiple_components_mixed(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + artifact_id="art-a", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) - comp1 = _make_component(name="a", alerts=[_make_alert(title="badEncoding")]) - comp2 = _make_component(name="b", alerts=[_make_alert(title="otherIssue")]) + comp1 = _make_component( + comp_id="art-a", name="a", + alerts=[_make_local_alert(alert_type="badEncoding")], + ) + comp2 = _make_component( + comp_id="art-b", name="b", + alerts=[_make_local_alert(alert_type="otherIssue")], + ) comp3 = _make_component( - name="c", + comp_id="art-a", name="c", alerts=[ - _make_alert(title="badEncoding"), - _make_alert(title="keepMe"), + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(title="keepMe", alert_type="keepMe"), ], ) @@ -258,6 +298,153 @@ def test_multiple_components_mixed(self): assert "b" in names assert "c" in names + def test_multiple_triaged_alert_types_on_same_artifact(self): + entries = [ + _make_triage_entry("hash-1", state="ignore"), + _make_triage_entry("hash-2", state="monitor"), + ] + artifact_alerts = _make_artifact_alerts( + alerts=[ + _socket_alert("hash-1", "badEncoding"), + _socket_alert("hash-2", "cve"), + ], + ) + tf = TriageFilter(entries, artifact_alerts) + + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(alert_type="cve"), + _make_local_alert(title="safe", alert_type="safe"), + ], + ) + filtered, count = tf.filter_components([comp]) + assert count == 2 + assert 
len(filtered[0]["alerts"]) == 1 + assert filtered[0]["alerts"][0]["type"] == "safe" + + +# --------------------------------------------------------------------------- +# stream_full_scan_alerts +# --------------------------------------------------------------------------- + +class TestStreamFullScanAlerts: + def test_parses_artifacts_and_alerts(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "artifact-1": { + "name": "lodash", + "version": "4.17.21", + "type": "npm", + "namespace": None, + "alerts": [ + {"key": "hash-a", "type": "badEncoding"}, + {"key": "hash-b", "type": "cve"}, + ], + }, + "artifact-2": { + "name": "express", + "version": "4.18.0", + "type": "npm", + "namespace": None, + "alerts": [], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "my-org", "scan-123") + assert "artifact-1" in result + assert "artifact-2" not in result # empty alerts filtered out + assert len(result["artifact-1"]) == 2 + assert result["artifact-1"][0]["key"] == "hash-a" + assert result["artifact-1"][0]["_artifact"]["artifact_name"] == "lodash" + + def test_skips_alerts_without_key(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "art-1": { + "name": "pkg", + "version": "1.0.0", + "type": "npm", + "alerts": [ + {"key": "hash-a", "type": "badEncoding"}, + {"type": "noKey"}, # missing key + {"key": "", "type": "emptyKey"}, # empty key + ], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert len(result["art-1"]) == 1 + + def test_access_denied_returns_empty(self, caplog): + class APIAccessDenied(Exception): + pass + + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + raise APIAccessDenied("Forbidden") + + class FakeSDK: + fullscans = FakeFullscansAPI() + + with caplog.at_level(logging.DEBUG): + result = 
stream_full_scan_alerts(FakeSDK(), "org", "scan") + + assert result == {} + info_msgs = [r for r in caplog.records if r.levelno == logging.INFO] + assert any("access denied" in m.message.lower() for m in info_msgs) + + def test_api_error_returns_empty(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + raise RuntimeError("Network failure") + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result == {} + + def test_non_dict_response_returns_empty(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return "unexpected string" + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result == {} + + def test_subpath_handling(self): + """Supports both camelCase and lowercase subpath field names.""" + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "art-1": { + "name": "pkg", + "version": "1.0", + "type": "npm", + "subPath": "src/lib", + "alerts": [{"key": "k1", "type": "t1"}], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result["art-1"][0]["_artifact"]["artifact_subpath"] == "src/lib" + # --------------------------------------------------------------------------- # fetch_triage_data @@ -323,14 +510,12 @@ def list_alert_triage(self, org, params): class FakeSDK: triage = FakeTriageAPI() - import logging with caplog.at_level(logging.DEBUG): entries = fetch_triage_data(FakeSDK(), "my-org") assert entries == [] info_messages = [r for r in caplog.records if r.levelno == logging.INFO] assert any("access denied" in m.message.lower() for m in info_messages) - # Should NOT produce an ERROR-level record error_messages = [r for r in caplog.records if r.levelno >= logging.ERROR] assert not error_messages @@ -385,7 +570,6 @@ def 
test_injects_after_heading(self): content = notifications["github_pr"][0]["content"] assert "3 finding(s) triaged" in content assert "Socket Dashboard" in content - # Summary line should appear after the # heading lines = content.split("\n") heading_idx = next(i for i, l in enumerate(lines) if l.strip().startswith("# ")) summary_idx = next(i for i, l in enumerate(lines) if "triaged" in l) From f4b2afdbb8e640566984c0e90e57cdde10fef145 Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:09:01 -0800 Subject: [PATCH 09/13] Add triage-aware PR comment filtering Fetch triage entries from Socket API after scan submission, remove alerts with ignore/monitor state from results, regenerate connector notifications with filtered components, and inject a triage count summary into GitHub PR comments. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 195 +++++++++++++++++ socket_basics/socket_basics.py | 216 ++++++++++++++++++- tests/__init__.py | 0 tests/test_triage.py | 384 +++++++++++++++++++++++++++++++++ 4 files changed, 794 insertions(+), 1 deletion(-) create mode 100644 socket_basics/core/triage.py create mode 100644 tests/__init__.py create mode 100644 tests/test_triage.py diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py new file mode 100644 index 0000000..0738f95 --- /dev/null +++ b/socket_basics/core/triage.py @@ -0,0 +1,195 @@ +"""Triage filtering for Socket Security Basics. + +Fetches triage entries from the Socket API and filters scan components +whose alerts have been triaged (state: ignore or monitor). +""" + +import fnmatch +import logging +from typing import Any, Dict, List, Tuple + +logger = logging.getLogger(__name__) + +# Triage states that cause a finding to be removed from reports +_SUPPRESSED_STATES = {"ignore", "monitor"} + + +def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: + """Fetch all triage alert entries from the Socket API, handling pagination. 
+ + Args: + sdk: Initialized socketdev SDK instance. + org_slug: Organization slug for the API call. + + Returns: + List of triage entry dicts. + """ + all_entries: List[Dict[str, Any]] = [] + page = 1 + per_page = 100 + + while True: + try: + response = sdk.triage.list_alert_triage( + org_slug, + {"per_page": per_page, "page": page}, + ) + except Exception: + logger.exception("Failed to fetch triage data (page %d)", page) + break + + if not isinstance(response, dict): + logger.warning("Unexpected triage API response type: %s", type(response)) + break + + results = response.get("results") or [] + all_entries.extend(results) + + next_page = response.get("nextPage") + if next_page is None: + break + page = int(next_page) + + logger.debug("Fetched %d triage entries for org %s", len(all_entries), org_slug) + return all_entries + + +class TriageFilter: + """Matches local scan alerts against triage entries and filters them out.""" + + def __init__(self, triage_entries: List[Dict[str, Any]]) -> None: + # Only keep entries whose state suppresses findings + self.entries = [ + e for e in triage_entries + if (e.get("state") or "").lower() in _SUPPRESSED_STATES + ] + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def is_alert_triaged(self, component: Dict[str, Any], alert: Dict[str, Any]) -> bool: + """Return True if the alert on the given component matches a suppressed triage entry.""" + alert_keys = self._extract_alert_keys(alert) + if not alert_keys: + return False + + for entry in self.entries: + entry_key = entry.get("alert_key") + if not entry_key: + continue + + if entry_key not in alert_keys: + continue + + # alert_key matched; now check package scope + if self._is_broad_match(entry): + return True + + if self._package_matches(entry, component): + return True + + return False + + def filter_components( + self, components: List[Dict[str, Any]] + ) -> 
Tuple[List[Dict[str, Any]], int]: + """Remove triaged alerts from components. + + Returns: + (filtered_components, triaged_count) where triaged_count is the + total number of individual alerts removed. + """ + if not self.entries: + return components, 0 + + filtered: List[Dict[str, Any]] = [] + triaged_count = 0 + + for comp in components: + remaining_alerts: List[Dict[str, Any]] = [] + for alert in comp.get("alerts", []): + if self.is_alert_triaged(comp, alert): + triaged_count += 1 + else: + remaining_alerts.append(alert) + + if remaining_alerts: + new_comp = dict(comp) + new_comp["alerts"] = remaining_alerts + filtered.append(new_comp) + + return filtered, triaged_count + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_alert_keys(alert: Dict[str, Any]) -> set: + """Build the set of candidate keys that could match a triage entry's alert_key.""" + keys: set = set() + props = alert.get("props") or {} + + for field in ( + alert.get("title"), + alert.get("type"), + props.get("ruleId"), + props.get("detectorName"), + props.get("vulnerabilityId"), + props.get("cveId"), + ): + if field: + keys.add(str(field)) + + return keys + + @staticmethod + def _is_broad_match(entry: Dict[str, Any]) -> bool: + """Return True when the triage entry has no package scope (applies globally).""" + return ( + entry.get("package_name") is None + and entry.get("package_type") is None + and entry.get("package_version") is None + and entry.get("package_namespace") is None + ) + + @staticmethod + def _version_matches(entry_version: str, component_version: str) -> bool: + """Check version match, supporting wildcard suffix patterns like '1.2.*'.""" + if not entry_version or entry_version == "*": + return True + if not component_version: + return False + # fnmatch handles '*' and '?' 
glob patterns + return fnmatch.fnmatch(component_version, entry_version) + + @classmethod + def _package_matches(cls, entry: Dict[str, Any], component: Dict[str, Any]) -> bool: + """Return True if the triage entry's package scope matches the component.""" + qualifiers = component.get("qualifiers") or {} + comp_name = component.get("name") or "" + comp_type = ( + qualifiers.get("ecosystem") + or qualifiers.get("type") + or component.get("type") + or "" + ) + comp_version = component.get("version") or qualifiers.get("version") or "" + comp_namespace = qualifiers.get("namespace") or "" + + entry_name = entry.get("package_name") + entry_type = entry.get("package_type") + entry_version = entry.get("package_version") + entry_namespace = entry.get("package_namespace") + + if entry_name is not None and entry_name != comp_name: + return False + if entry_type is not None and entry_type.lower() != comp_type.lower(): + return False + if entry_namespace is not None and entry_namespace != comp_namespace: + return False + if entry_version is not None and not cls._version_matches(entry_version, comp_version): + return False + + return True diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index a7f7f04..6311611 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -17,7 +17,7 @@ import sys import os from pathlib import Path -from typing import Dict, Any, Optional +from typing import Dict, Any, List, Optional import hashlib try: # Python 3.11+ @@ -378,6 +378,214 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) return results + def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: + """Filter out triaged alerts and regenerate notifications. 
+ + Fetches triage entries from the Socket API, removes alerts with + state ``ignore`` or ``monitor``, regenerates connector notifications + for the remaining components, and injects a triage summary line into + github_pr notification content. + + Args: + results: Current scan results dict (components + notifications). + + Returns: + Updated results dict with triaged findings removed. + """ + socket_api_key = self.config.get('socket_api_key') + socket_org = self.config.get('socket_org') + + if not socket_api_key or not socket_org: + logger.debug("Skipping triage filter: missing socket_api_key or socket_org") + return results + + # Import SDK and triage helpers + try: + from socketdev import socketdev + except ImportError: + logger.debug("socketdev SDK not available; skipping triage filter") + return results + + try: + from .core.triage import TriageFilter, fetch_triage_data + except ImportError: + from socket_basics.core.triage import TriageFilter, fetch_triage_data + + sdk = socketdev(token=socket_api_key, timeout=100) + triage_entries = fetch_triage_data(sdk, socket_org) + + if not triage_entries: + logger.debug("No triage entries found; skipping filter") + return results + + triage_filter = TriageFilter(triage_entries) + filtered_components, triaged_count = triage_filter.filter_components( + results.get('components', []) + ) + + if triaged_count == 0: + logger.debug("No findings matched triage entries") + return results + + logger.info("Filtered %d triaged finding(s) from results", triaged_count) + results['components'] = filtered_components + results['triaged_count'] = triaged_count + + # Regenerate notifications from the filtered components + self._regenerate_notifications(results, filtered_components, triaged_count) + + return results + + def _regenerate_notifications( + self, + results: Dict[str, Any], + filtered_components: List[Dict[str, Any]], + triaged_count: int, + ) -> None: + """Regenerate connector notifications from filtered components. 
+ + Groups components by their connector origin (via the ``generatedBy`` + field on alerts), calls each connector's ``generate_notifications``, + merges the results, and injects a triage summary into github_pr + content. + """ + connector_components: Dict[str, List[Dict[str, Any]]] = {} + for comp in filtered_components: + for alert in comp.get('alerts', []): + gen = alert.get('generatedBy') or '' + connector_name = self._connector_name_from_generated_by(gen) + if connector_name: + connector_components.setdefault(connector_name, []).append(comp) + break # one mapping per component is enough + + merged_notifications: Dict[str, list] = {} + + for connector_name, comps in connector_components.items(): + connector = self.connector_manager.loaded_connectors.get(connector_name) + if connector is None: + logger.debug("Connector %s not loaded; skipping notification regen", connector_name) + continue + + if not hasattr(connector, 'generate_notifications'): + logger.debug("Connector %s has no generate_notifications", connector_name) + continue + + try: + if connector_name == 'trivy': + item_name, scan_type = self._derive_trivy_params(comps) + notifs = connector.generate_notifications(comps, item_name, scan_type) + else: + notifs = connector.generate_notifications(comps) + except Exception: + logger.exception("Failed to regenerate notifications for %s", connector_name) + continue + + if not isinstance(notifs, dict): + continue + + for notifier_key, payload in notifs.items(): + if notifier_key not in merged_notifications: + merged_notifications[notifier_key] = payload + elif isinstance(merged_notifications[notifier_key], list) and isinstance(payload, list): + merged_notifications[notifier_key].extend(payload) + + # Inject triage summary into github_pr notification content + full_scan_url = results.get('full_scan_html_url', '') + self._inject_triage_summary(merged_notifications, triaged_count, full_scan_url) + + if merged_notifications: + results['notifications'] = 
merged_notifications + + @staticmethod + def _connector_name_from_generated_by(generated_by: str) -> str | None: + """Map a generatedBy value back to its connector name.""" + gb = generated_by.lower() + if gb.startswith('opengrep') or gb.startswith('sast'): + return 'opengrep' + if gb == 'trufflehog': + return 'trufflehog' + if gb.startswith('trivy'): + return 'trivy' + if gb == 'socket-tier1': + return 'socket_tier1' + return None + + def _derive_trivy_params( + self, components: List[Dict[str, Any]] + ) -> tuple: + """Derive item_name and scan_type for Trivy notification regeneration.""" + scan_type = 'image' + for comp in components: + for alert in comp.get('alerts', []): + props = alert.get('props') or {} + st = props.get('scanType', '') + if st: + scan_type = st + break + if scan_type != 'image': + break + + item_name = "Unknown" + images_str = ( + self.config.get('container_images', '') + or self.config.get('container_images_to_scan', '') + or self.config.get('docker_images', '') + ) + if images_str: + if isinstance(images_str, list): + item_name = images_str[0] if images_str else "Unknown" + else: + images = [img.strip() for img in str(images_str).split(',') if img.strip()] + item_name = images[0] if images else "Unknown" + else: + dockerfiles = self.config.get('dockerfiles', '') + if dockerfiles: + if isinstance(dockerfiles, list): + item_name = dockerfiles[0] if dockerfiles else "Unknown" + else: + docker_list = [df.strip() for df in str(dockerfiles).split(',') if df.strip()] + item_name = docker_list[0] if docker_list else "Unknown" + + if scan_type == 'vuln' and item_name == "Unknown": + try: + item_name = os.path.basename(str(self.config.workspace)) + except Exception: + item_name = "Workspace" + + return item_name, scan_type + + @staticmethod + def _inject_triage_summary( + notifications: Dict[str, list], + triaged_count: int, + full_scan_url: str, + ) -> None: + """Insert a triage summary line into github_pr notification content.""" + gh_items = 
notifications.get('github_pr') + if not gh_items or not isinstance(gh_items, list): + return + + dashboard_link = full_scan_url or "https://socket.dev/dashboard" + summary_line = ( + f"\n> :white_check_mark: **{triaged_count} finding(s) triaged** " + f"via [Socket Dashboard]({dashboard_link}) and removed from this report.\n" + ) + + for item in gh_items: + if not isinstance(item, dict) or 'content' not in item: + continue + content = item['content'] + # Insert after the first markdown heading line (# Title) + lines = content.split('\n') + insert_idx = 0 + for i, line in enumerate(lines): + if line.strip().startswith('# '): + insert_idx = i + 1 + break + lines.insert(insert_idx, summary_line) + item['content'] = '\n'.join(lines) + + def main(): """Main entry point""" parser = parse_cli_args() @@ -429,6 +637,12 @@ def main(): except Exception: logger.exception("Failed to submit socket facts file") + # Filter out triaged alerts before notifying + try: + results = scanner.apply_triage_filter(results) + except Exception: + logger.exception("Failed to apply triage filter") + # Optionally upload to S3 if requested try: enable_s3 = getattr(args, 'enable_s3_upload', False) or config.get('enable_s3_upload', False) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_triage.py b/tests/test_triage.py new file mode 100644 index 0000000..b1a472e --- /dev/null +++ b/tests/test_triage.py @@ -0,0 +1,384 @@ +"""Tests for socket_basics.core.triage module.""" + +import pytest +from socket_basics.core.triage import TriageFilter, fetch_triage_data + + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + +def _make_component( + name: str = "lodash", + comp_type: str = "npm", + version: str = "4.17.21", + alerts: list | None = None, +) -> dict: + return { + "id": 
f"pkg:{comp_type}/{name}@{version}", + "name": name, + "version": version, + "type": comp_type, + "qualifiers": {"ecosystem": comp_type, "version": version}, + "alerts": alerts or [], + } + + +def _make_alert( + title: str = "badEncoding", + alert_type: str = "supplyChainRisk", + severity: str = "high", + rule_id: str | None = None, + detector_name: str | None = None, + cve_id: str | None = None, + generated_by: str = "opengrep-python", +) -> dict: + props: dict = {} + if rule_id: + props["ruleId"] = rule_id + if detector_name: + props["detectorName"] = detector_name + if cve_id: + props["cveId"] = cve_id + return { + "title": title, + "type": alert_type, + "severity": severity, + "generatedBy": generated_by, + "props": props, + } + + +def _make_triage_entry( + alert_key: str, + state: str = "ignore", + package_name: str | None = None, + package_type: str | None = None, + package_version: str | None = None, + package_namespace: str | None = None, +) -> dict: + return { + "uuid": "test-uuid", + "alert_key": alert_key, + "state": state, + "package_name": package_name, + "package_type": package_type, + "package_version": package_version, + "package_namespace": package_namespace, + "note": "", + "organization_id": "test-org", + } + + +# --------------------------------------------------------------------------- +# TriageFilter.is_alert_triaged +# --------------------------------------------------------------------------- + +class TestIsAlertTriaged: + """Tests for the alert matching logic.""" + + def test_broad_match_by_title(self): + """Triage entry with no package info matches any component with matching alert_key.""" + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_rule_id(self): + entry = _make_triage_entry(alert_key="python.lang.security.audit.xss") + tf = TriageFilter([entry]) + comp 
= _make_component() + alert = _make_alert(title="XSS Vulnerability", rule_id="python.lang.security.audit.xss") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_detector_name(self): + entry = _make_triage_entry(alert_key="AWS") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="AWS Key Detected", detector_name="AWS") + assert tf.is_alert_triaged(comp, alert) is True + + def test_broad_match_by_cve(self): + entry = _make_triage_entry(alert_key="CVE-2024-1234") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="Some Vuln", cve_id="CVE-2024-1234") + assert tf.is_alert_triaged(comp, alert) is True + + def test_no_match_different_key(self): + entry = _make_triage_entry(alert_key="differentRule") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is False + + def test_package_scoped_match(self): + """Triage entry with package info only matches the specific package.""" + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + ) + tf = TriageFilter([entry]) + + comp_match = _make_component(name="lodash", comp_type="npm") + comp_no_match = _make_component(name="express", comp_type="npm") + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged(comp_match, alert) is True + assert tf.is_alert_triaged(comp_no_match, alert) is False + + def test_package_version_exact_match(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="4.17.21", + ) + tf = TriageFilter([entry]) + + comp_match = _make_component(name="lodash", comp_type="npm", version="4.17.21") + comp_no_match = _make_component(name="lodash", comp_type="npm", version="4.17.20") + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged(comp_match, alert) is True + assert 
tf.is_alert_triaged(comp_no_match, alert) is False + + def test_version_wildcard(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="4.17.*", + ) + tf = TriageFilter([entry]) + alert = _make_alert(title="badEncoding") + + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.17.21"), alert + ) is True + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.17.0"), alert + ) is True + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="4.18.0"), alert + ) is False + + def test_version_star_matches_all(self): + entry = _make_triage_entry( + alert_key="badEncoding", + package_name="lodash", + package_type="npm", + package_version="*", + ) + tf = TriageFilter([entry]) + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged( + _make_component(name="lodash", comp_type="npm", version="99.0.0"), alert + ) is True + + def test_states_block_and_warn_not_suppressed(self): + """Triage entries with block/warn/inherit states should not filter findings.""" + for state in ("block", "warn", "inherit"): + entry = _make_triage_entry(alert_key="badEncoding", state=state) + tf = TriageFilter([entry]) + assert tf.entries == [], f"state={state} should be excluded from filter entries" + + def test_state_monitor_suppressed(self): + entry = _make_triage_entry(alert_key="badEncoding", state="monitor") + tf = TriageFilter([entry]) + comp = _make_component() + alert = _make_alert(title="badEncoding") + assert tf.is_alert_triaged(comp, alert) is True + + def test_alert_with_no_matchable_keys(self): + """Alert with no title, type, or relevant props should not match.""" + entry = _make_triage_entry(alert_key="something") + tf = TriageFilter([entry]) + comp = _make_component() + alert = {"severity": "high", "props": {}} + assert tf.is_alert_triaged(comp, alert) is False + + +# 
--------------------------------------------------------------------------- +# TriageFilter.filter_components +# --------------------------------------------------------------------------- + +class TestFilterComponents: + def test_removes_triaged_alerts(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + alert_triaged = _make_alert(title="badEncoding") + alert_kept = _make_alert(title="otherIssue") + comp = _make_component(alerts=[alert_triaged, alert_kept]) + + filtered, count = tf.filter_components([comp]) + assert count == 1 + assert len(filtered) == 1 + assert len(filtered[0]["alerts"]) == 1 + assert filtered[0]["alerts"][0]["title"] == "otherIssue" + + def test_removes_component_when_all_alerts_triaged(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + comp = _make_component(alerts=[_make_alert(title="badEncoding")]) + filtered, count = tf.filter_components([comp]) + assert count == 1 + assert len(filtered) == 0 + + def test_no_triage_entries_returns_original(self): + tf = TriageFilter([]) + comp = _make_component(alerts=[_make_alert()]) + filtered, count = tf.filter_components([comp]) + assert count == 0 + assert filtered is [comp] or filtered == [comp] + + def test_multiple_components_mixed(self): + entry = _make_triage_entry(alert_key="badEncoding") + tf = TriageFilter([entry]) + + comp1 = _make_component(name="a", alerts=[_make_alert(title="badEncoding")]) + comp2 = _make_component(name="b", alerts=[_make_alert(title="otherIssue")]) + comp3 = _make_component( + name="c", + alerts=[ + _make_alert(title="badEncoding"), + _make_alert(title="keepMe"), + ], + ) + + filtered, count = tf.filter_components([comp1, comp2, comp3]) + assert count == 2 + assert len(filtered) == 2 + names = [c["name"] for c in filtered] + assert "a" not in names + assert "b" in names + assert "c" in names + + +# --------------------------------------------------------------------------- +# 
fetch_triage_data +# --------------------------------------------------------------------------- + +class TestFetchTriageData: + def test_single_page(self): + class FakeTriageAPI: + def list_alert_triage(self, org, params): + return {"results": [{"alert_key": "a", "state": "ignore"}], "nextPage": None} + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 1 + assert entries[0]["alert_key"] == "a" + + def test_pagination(self): + class FakeTriageAPI: + def __init__(self): + self.call_count = 0 + + def list_alert_triage(self, org, params): + self.call_count += 1 + if params.get("page") == 1: + return {"results": [{"alert_key": "a"}], "nextPage": 2} + return {"results": [{"alert_key": "b"}], "nextPage": None} + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 2 + + def test_api_error_returns_partial(self): + class FakeTriageAPI: + def __init__(self): + self.calls = 0 + + def list_alert_triage(self, org, params): + self.calls += 1 + if self.calls == 1: + return {"results": [{"alert_key": "a"}], "nextPage": 2} + raise RuntimeError("API error") + + class FakeSDK: + triage = FakeTriageAPI() + + entries = fetch_triage_data(FakeSDK(), "my-org") + assert len(entries) == 1 + + +# --------------------------------------------------------------------------- +# SecurityScanner._connector_name_from_generated_by +# --------------------------------------------------------------------------- + +class TestConnectorNameMapping: + def test_opengrep_variants(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("opengrep-python") == "opengrep" + assert SecurityScanner._connector_name_from_generated_by("sast-generic") == "opengrep" + + def test_trufflehog(self): + from socket_basics.socket_basics import SecurityScanner + assert 
SecurityScanner._connector_name_from_generated_by("trufflehog") == "trufflehog" + + def test_trivy_variants(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("trivy-dockerfile") == "trivy" + assert SecurityScanner._connector_name_from_generated_by("trivy-image") == "trivy" + assert SecurityScanner._connector_name_from_generated_by("trivy-npm") == "trivy" + + def test_socket_tier1(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("socket-tier1") == "socket_tier1" + + def test_unknown_returns_none(self): + from socket_basics.socket_basics import SecurityScanner + assert SecurityScanner._connector_name_from_generated_by("unknown-tool") is None + + +# --------------------------------------------------------------------------- +# SecurityScanner._inject_triage_summary +# --------------------------------------------------------------------------- + +class TestInjectTriageSummary: + def test_injects_after_heading(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = { + "github_pr": [ + { + "title": "SAST Findings", + "content": "\n# SAST Python Findings\n### Summary\nSome content\n", + } + ] + } + SecurityScanner._inject_triage_summary(notifications, 3, "https://socket.dev/scan/123") + + content = notifications["github_pr"][0]["content"] + assert "3 finding(s) triaged" in content + assert "Socket Dashboard" in content + # Summary line should appear after the # heading + lines = content.split("\n") + heading_idx = next(i for i, l in enumerate(lines) if l.strip().startswith("# ")) + summary_idx = next(i for i, l in enumerate(lines) if "triaged" in l) + assert summary_idx > heading_idx + + def test_no_github_pr_key_is_noop(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = {"slack": [{"title": "t", "content": "c"}]} + 
SecurityScanner._inject_triage_summary(notifications, 5, "") + assert "github_pr" not in notifications + + def test_uses_default_dashboard_link(self): + from socket_basics.socket_basics import SecurityScanner + + notifications = { + "github_pr": [{"title": "t", "content": "# Title\nBody"}] + } + SecurityScanner._inject_triage_summary(notifications, 1, "") + assert "https://socket.dev/dashboard" in notifications["github_pr"][0]["content"] From 471bc236a265d71f1bf048fa431f6b799c77e95d Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:36:21 -0800 Subject: [PATCH 10/13] Handle triage API access denied gracefully Log an info-level message instead of an error traceback when the Socket API token lacks triage permissions, and skip filtering so the scan completes normally with all findings intact. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 13 +++++++++++-- tests/test_triage.py | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py index 0738f95..f9b9796 100644 --- a/socket_basics/core/triage.py +++ b/socket_basics/core/triage.py @@ -34,8 +34,17 @@ def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: org_slug, {"per_page": per_page, "page": page}, ) - except Exception: - logger.exception("Failed to fetch triage data (page %d)", page) + except Exception as exc: + # Handle insufficient permissions gracefully so the scan + # continues without triage filtering. + exc_name = type(exc).__name__ + if "AccessDenied" in exc_name or "Forbidden" in exc_name: + logger.info( + "Triage API access denied (insufficient permissions). " + "Skipping triage filtering for this run." 
+ ) + else: + logger.warning("Failed to fetch triage data (page %d): %s", page, exc) break if not isinstance(response, dict): diff --git a/tests/test_triage.py b/tests/test_triage.py index b1a472e..c0c4954 100644 --- a/tests/test_triage.py +++ b/tests/test_triage.py @@ -310,6 +310,30 @@ class FakeSDK: entries = fetch_triage_data(FakeSDK(), "my-org") assert len(entries) == 1 + def test_access_denied_returns_empty_and_logs_info(self, caplog): + """Insufficient permissions should log an info message (not an error) and return empty.""" + + class APIAccessDenied(Exception): + pass + + class FakeTriageAPI: + def list_alert_triage(self, org, params): + raise APIAccessDenied("Insufficient permissions.") + + class FakeSDK: + triage = FakeTriageAPI() + + import logging + with caplog.at_level(logging.DEBUG): + entries = fetch_triage_data(FakeSDK(), "my-org") + + assert entries == [] + info_messages = [r for r in caplog.records if r.levelno == logging.INFO] + assert any("access denied" in m.message.lower() for m in info_messages) + # Should NOT produce an ERROR-level record + error_messages = [r for r in caplog.records if r.levelno >= logging.ERROR] + assert not error_messages + # --------------------------------------------------------------------------- # SecurityScanner._connector_name_from_generated_by From 789106f24350c104705b66f302ed6f6565e448ba Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 15:46:17 -0800 Subject: [PATCH 11/13] Fix stale notifications after triage and improve logging Always replace results['notifications'] after triage filtering so pre-filter content is never forwarded to notifiers. Skip PR comment API calls when content is unchanged. Add info-level logging for triaged/remaining finding counts and connector regeneration details. 
Co-Authored-By: Claude Opus 4.6 --- .../core/notification/github_pr_notifier.py | 12 ++++ socket_basics/socket_basics.py | 60 +++++++++++++++++-- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/socket_basics/core/notification/github_pr_notifier.py b/socket_basics/core/notification/github_pr_notifier.py index 555d03e..8c258bc 100644 --- a/socket_basics/core/notification/github_pr_notifier.py +++ b/socket_basics/core/notification/github_pr_notifier.py @@ -100,6 +100,18 @@ def notify(self, facts: Dict[str, Any]) -> None: # Update existing comments with new section content for comment_id, updated_body in comment_updates.items(): + # Detect whether content actually changed before making the API call + original_body = next( + (c.get('body', '') for c in existing_comments if c.get('id') == comment_id), + '', + ) + if original_body == updated_body: + logger.info( + 'GithubPRNotifier: comment %s content unchanged; skipping update', + comment_id, + ) + continue + success = self._update_comment(pr_number, comment_id, updated_body) if success: logger.info('GithubPRNotifier: updated existing comment %s', comment_id) diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index 6311611..984133e 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -419,15 +419,29 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: return results triage_filter = TriageFilter(triage_entries) + original_components = results.get('components', []) + original_alert_count = sum( + len(c.get('alerts', [])) for c in original_components + ) filtered_components, triaged_count = triage_filter.filter_components( - results.get('components', []) + original_components ) if triaged_count == 0: - logger.debug("No findings matched triage entries") + logger.info( + "Triage filter matched 0 of %d finding(s); no changes applied", + original_alert_count, + ) return results - logger.info("Filtered %d triaged finding(s) from 
results", triaged_count) + remaining_alert_count = sum( + len(c.get('alerts', [])) for c in filtered_components + ) + logger.info( + "Triage filter removed %d finding(s); %d finding(s) remain", + triaged_count, + remaining_alert_count, + ) results['components'] = filtered_components results['triaged_count'] = triaged_count @@ -448,22 +462,47 @@ def _regenerate_notifications( field on alerts), calls each connector's ``generate_notifications``, merges the results, and injects a triage summary into github_pr content. + + Always replaces ``results['notifications']`` so stale pre-filter + notifications are never forwarded to notifiers. """ connector_components: Dict[str, List[Dict[str, Any]]] = {} + unmapped_count = 0 for comp in filtered_components: + mapped = False for alert in comp.get('alerts', []): gen = alert.get('generatedBy') or '' connector_name = self._connector_name_from_generated_by(gen) if connector_name: connector_components.setdefault(connector_name, []).append(comp) + mapped = True break # one mapping per component is enough + if not mapped: + unmapped_count += 1 + + if unmapped_count: + logger.debug( + "Triage regen: %d component(s) could not be mapped to a connector", + unmapped_count, + ) + + logger.info( + "Regenerating notifications for %d connector(s): %s", + len(connector_components), + ", ".join(connector_components.keys()) or "(none)", + ) merged_notifications: Dict[str, list] = {} for connector_name, comps in connector_components.items(): connector = self.connector_manager.loaded_connectors.get(connector_name) if connector is None: - logger.debug("Connector %s not loaded; skipping notification regen", connector_name) + logger.warning( + "Connector %s not in loaded_connectors (available: %s); " + "cannot regenerate its notifications", + connector_name, + ", ".join(self.connector_manager.loaded_connectors.keys()), + ) continue if not hasattr(connector, 'generate_notifications'): @@ -483,6 +522,13 @@ def _regenerate_notifications( if not 
isinstance(notifs, dict): continue + notifier_keys = [k for k, v in notifs.items() if v] + logger.debug( + "Connector %s produced notifications for: %s", + connector_name, + ", ".join(notifier_keys) or "(empty)", + ) + for notifier_key, payload in notifs.items(): if notifier_key not in merged_notifications: merged_notifications[notifier_key] = payload @@ -493,8 +539,10 @@ def _regenerate_notifications( full_scan_url = results.get('full_scan_html_url', '') self._inject_triage_summary(merged_notifications, triaged_count, full_scan_url) - if merged_notifications: - results['notifications'] = merged_notifications + # Always replace notifications so stale pre-filter content is never + # forwarded to notifiers. An empty dict is valid and means every + # finding was triaged. + results['notifications'] = merged_notifications @staticmethod def _connector_name_from_generated_by(generated_by: str) -> str | None: From c8ac9946875c0b24b30e8be43358bc90c71730eb Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Thu, 5 Feb 2026 16:43:48 -0800 Subject: [PATCH 12/13] Rework triage matching to use stream-based alert key lookup The triage API returns opaque alert_key hashes, not human-readable identifiers. This rewrites the matching logic to stream the full scan via sdk.fullscans.stream(), cross-reference Socket alert keys against triage entries, and map back to local components by artifact ID. Co-Authored-By: Claude Opus 4.6 --- socket_basics/core/triage.py | 273 +++++++++++------- socket_basics/socket_basics.py | 39 ++- tests/test_triage.py | 490 +++++++++++++++++++++++---------- 3 files changed, 536 insertions(+), 266 deletions(-) diff --git a/socket_basics/core/triage.py b/socket_basics/core/triage.py index f9b9796..59bf6ce 100644 --- a/socket_basics/core/triage.py +++ b/socket_basics/core/triage.py @@ -1,12 +1,12 @@ """Triage filtering for Socket Security Basics. 
-Fetches triage entries from the Socket API and filters scan components -whose alerts have been triaged (state: ignore or monitor). +Streams the full scan from the Socket API to obtain alert keys, fetches +triage entries, and filters local scan components whose alerts have been +triaged (state: ignore or monitor). """ -import fnmatch import logging -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Set, Tuple logger = logging.getLogger(__name__) @@ -14,6 +14,10 @@ _SUPPRESSED_STATES = {"ignore", "monitor"} +# ------------------------------------------------------------------ +# API helpers +# ------------------------------------------------------------------ + def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: """Fetch all triage alert entries from the Socket API, handling pagination. @@ -63,62 +67,145 @@ def fetch_triage_data(sdk: Any, org_slug: str) -> List[Dict[str, Any]]: return all_entries -class TriageFilter: - """Matches local scan alerts against triage entries and filters them out.""" +def stream_full_scan_alerts( + sdk: Any, org_slug: str, full_scan_id: str +) -> Dict[str, List[Dict[str, Any]]]: + """Stream a full scan and extract alert keys grouped by artifact. + + Returns: + Mapping of artifact ID to list of alert dicts. Each alert dict + contains at minimum ``key`` and ``type``. The artifact metadata + (name, version, type, etc.) is included under a ``_artifact`` key + in every alert dict for downstream matching. + """ + try: + # use_types=False returns a plain dict keyed by artifact ID + response = sdk.fullscans.stream(org_slug, full_scan_id, use_types=False) + except Exception as exc: + exc_name = type(exc).__name__ + if "AccessDenied" in exc_name or "Forbidden" in exc_name: + logger.info( + "Full scan stream access denied (insufficient permissions). " + "Skipping triage filtering for this run." 
+ ) + else: + logger.warning("Failed to stream full scan %s: %s", full_scan_id, exc) + return {} + + if not isinstance(response, dict): + logger.warning("Unexpected full scan stream response type: %s", type(response)) + return {} + + artifact_alerts: Dict[str, List[Dict[str, Any]]] = {} + for artifact_id, artifact in response.items(): + if not isinstance(artifact, dict): + continue + alerts = artifact.get("alerts") or [] + if not alerts: + continue + meta = { + "artifact_id": artifact_id, + "artifact_name": artifact.get("name"), + "artifact_version": artifact.get("version"), + "artifact_type": artifact.get("type"), + "artifact_namespace": artifact.get("namespace"), + "artifact_subpath": artifact.get("subPath") or artifact.get("subpath"), + } + enriched = [] + for a in alerts: + if isinstance(a, dict) and a.get("key"): + enriched.append({**a, "_artifact": meta}) + if enriched: + artifact_alerts[artifact_id] = enriched + + total_alerts = sum(len(v) for v in artifact_alerts.values()) + logger.debug( + "Streamed full scan %s: %d artifact(s), %d alert(s) with keys", + full_scan_id, + len(artifact_alerts), + total_alerts, + ) + return artifact_alerts + + +# ------------------------------------------------------------------ +# TriageFilter +# ------------------------------------------------------------------ - def __init__(self, triage_entries: List[Dict[str, Any]]) -> None: - # Only keep entries whose state suppresses findings - self.entries = [ - e for e in triage_entries - if (e.get("state") or "").lower() in _SUPPRESSED_STATES - ] +class TriageFilter: + """Cross-references Socket alert keys against triage entries and + maps triaged alerts back to local scan components.""" + + def __init__( + self, + triage_entries: List[Dict[str, Any]], + artifact_alerts: Dict[str, List[Dict[str, Any]]], + ) -> None: + # Build set of suppressed alert keys + self.triaged_keys: Set[str] = set() + for entry in triage_entries: + state = (entry.get("state") or "").lower() + key = 
entry.get("alert_key") + if state in _SUPPRESSED_STATES and key: + self.triaged_keys.add(key) + + # Flatten all Socket alerts for lookup + self._socket_alerts: List[Dict[str, Any]] = [] + for alerts in artifact_alerts.values(): + self._socket_alerts.extend(alerts) + + # Build a mapping from (artifact_id, alert_type) to triaged status + # for fast lookups when matching against local components + self._triaged_by_artifact: Dict[str, Set[str]] = {} + for alert in self._socket_alerts: + if alert.get("key") in self.triaged_keys: + art_id = alert.get("_artifact", {}).get("artifact_id", "") + alert_type = alert.get("type") or "" + self._triaged_by_artifact.setdefault(art_id, set()).add(alert_type) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ - def is_alert_triaged(self, component: Dict[str, Any], alert: Dict[str, Any]) -> bool: - """Return True if the alert on the given component matches a suppressed triage entry.""" - alert_keys = self._extract_alert_keys(alert) - if not alert_keys: - return False - - for entry in self.entries: - entry_key = entry.get("alert_key") - if not entry_key: - continue - - if entry_key not in alert_keys: - continue - - # alert_key matched; now check package scope - if self._is_broad_match(entry): - return True - - if self._package_matches(entry, component): - return True - - return False - def filter_components( self, components: List[Dict[str, Any]] ) -> Tuple[List[Dict[str, Any]], int]: - """Remove triaged alerts from components. + """Remove triaged alerts from local components. + + Matches local components to Socket artifacts by component ID, then + checks each local alert against the set of triaged alert types for + that artifact. Returns: - (filtered_components, triaged_count) where triaged_count is the - total number of individual alerts removed. 
+ (filtered_components, triaged_count) """ - if not self.entries: + if not self.triaged_keys: + return components, 0 + + # Build lookup: component id -> set of triaged Socket alert types + triaged_types_by_component = self._map_components_to_triaged_types(components) + + if not triaged_types_by_component: + logger.debug( + "No local components matched Socket artifacts with triaged alerts" + ) return components, 0 filtered: List[Dict[str, Any]] = [] triaged_count = 0 for comp in components: + comp_id = comp.get("id") or "" + triaged_types = triaged_types_by_component.get(comp_id) + + if triaged_types is None: + # Component had no triaged alerts; keep as-is + filtered.append(comp) + continue + remaining_alerts: List[Dict[str, Any]] = [] for alert in comp.get("alerts", []): - if self.is_alert_triaged(comp, alert): + if self._local_alert_is_triaged(alert, triaged_types): triaged_count += 1 else: remaining_alerts.append(alert) @@ -134,71 +221,49 @@ def filter_components( # Internal helpers # ------------------------------------------------------------------ - @staticmethod - def _extract_alert_keys(alert: Dict[str, Any]) -> set: - """Build the set of candidate keys that could match a triage entry's alert_key.""" - keys: set = set() - props = alert.get("props") or {} - - for field in ( - alert.get("title"), - alert.get("type"), - props.get("ruleId"), - props.get("detectorName"), - props.get("vulnerabilityId"), - props.get("cveId"), - ): - if field: - keys.add(str(field)) + def _map_components_to_triaged_types( + self, components: List[Dict[str, Any]] + ) -> Dict[str, Set[str]]: + """Map local component IDs to the set of triaged Socket alert types. - return keys + Matches by component ``id`` (which is typically a hash that Socket + also uses as the artifact ID). 
+ """ + local_ids = {comp.get("id") for comp in components if comp.get("id")} + result: Dict[str, Set[str]] = {} + for comp_id in local_ids: + triaged = self._triaged_by_artifact.get(comp_id) + if triaged: + result[comp_id] = triaged + return result @staticmethod - def _is_broad_match(entry: Dict[str, Any]) -> bool: - """Return True when the triage entry has no package scope (applies globally).""" - return ( - entry.get("package_name") is None - and entry.get("package_type") is None - and entry.get("package_version") is None - and entry.get("package_namespace") is None - ) + def _local_alert_is_triaged( + alert: Dict[str, Any], triaged_types: Set[str] + ) -> bool: + """Check if a local alert matches any of the triaged Socket alert types. + + Socket alert ``type`` values (e.g. ``badEncoding``, ``cve``) are + compared against the local alert's ``type`` field. When the local + alert type is too generic (``"generic"`` or ``"vulnerability"``), + we fall back to matching on ``title``, ``props.ruleId``, or + ``props.vulnerabilityId``. + """ + # Direct type match + local_type = alert.get("type") or "" + if local_type and local_type not in ("generic", "vulnerability"): + return local_type in triaged_types - @staticmethod - def _version_matches(entry_version: str, component_version: str) -> bool: - """Check version match, supporting wildcard suffix patterns like '1.2.*'.""" - if not entry_version or entry_version == "*": - return True - if not component_version: - return False - # fnmatch handles '*' and '?' 
glob patterns - return fnmatch.fnmatch(component_version, entry_version) - - @classmethod - def _package_matches(cls, entry: Dict[str, Any], component: Dict[str, Any]) -> bool: - """Return True if the triage entry's package scope matches the component.""" - qualifiers = component.get("qualifiers") or {} - comp_name = component.get("name") or "" - comp_type = ( - qualifiers.get("ecosystem") - or qualifiers.get("type") - or component.get("type") - or "" - ) - comp_version = component.get("version") or qualifiers.get("version") or "" - comp_namespace = qualifiers.get("namespace") or "" - - entry_name = entry.get("package_name") - entry_type = entry.get("package_type") - entry_version = entry.get("package_version") - entry_namespace = entry.get("package_namespace") - - if entry_name is not None and entry_name != comp_name: - return False - if entry_type is not None and entry_type.lower() != comp_type.lower(): - return False - if entry_namespace is not None and entry_namespace != comp_namespace: - return False - if entry_version is not None and not cls._version_matches(entry_version, comp_version): - return False - - return True + # Fallback: match candidate fields against triaged types + props = alert.get("props") or {} + candidates = { + v for v in ( + alert.get("title"), + props.get("ruleId"), + props.get("detectorName"), + props.get("vulnerabilityId"), + props.get("cveId"), + ) + if v + } + return bool(candidates & triaged_types) diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index 984133e..c5683b2 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -381,10 +381,10 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: """Filter out triaged alerts and regenerate notifications. 
- Fetches triage entries from the Socket API, removes alerts with - state ``ignore`` or ``monitor``, regenerates connector notifications - for the remaining components, and injects a triage summary line into - github_pr notification content. + Streams the full scan from the Socket API to obtain alert keys, + cross-references them with triage entries, removes suppressed + alerts from local components, regenerates connector notifications, + and injects a triage summary into github_pr content. Args: results: Current scan results dict (components + notifications). @@ -394,11 +394,16 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: """ socket_api_key = self.config.get('socket_api_key') socket_org = self.config.get('socket_org') + full_scan_id = results.get('full_scan_id') if not socket_api_key or not socket_org: logger.debug("Skipping triage filter: missing socket_api_key or socket_org") return results + if not full_scan_id: + logger.debug("Skipping triage filter: no full_scan_id in results") + return results + # Import SDK and triage helpers try: from socketdev import socketdev @@ -407,18 +412,34 @@ def apply_triage_filter(self, results: Dict[str, Any]) -> Dict[str, Any]: return results try: - from .core.triage import TriageFilter, fetch_triage_data + from .core.triage import TriageFilter, fetch_triage_data, stream_full_scan_alerts except ImportError: - from socket_basics.core.triage import TriageFilter, fetch_triage_data + from socket_basics.core.triage import TriageFilter, fetch_triage_data, stream_full_scan_alerts sdk = socketdev(token=socket_api_key, timeout=100) - triage_entries = fetch_triage_data(sdk, socket_org) + # Fetch triage entries and stream full scan alert keys in sequence + triage_entries = fetch_triage_data(sdk, socket_org) if not triage_entries: - logger.debug("No triage entries found; skipping filter") + logger.info("No triage entries found; skipping filter") + return results + + suppressed_count = sum( + 1 for e in 
triage_entries + if (e.get("state") or "").lower() in ("ignore", "monitor") + ) + logger.info( + "Fetched %d triage entries (%d with suppressed state)", + len(triage_entries), + suppressed_count, + ) + + artifact_alerts = stream_full_scan_alerts(sdk, socket_org, full_scan_id) + if not artifact_alerts: + logger.info("No alert keys returned from full scan stream; skipping filter") return results - triage_filter = TriageFilter(triage_entries) + triage_filter = TriageFilter(triage_entries, artifact_alerts) original_components = results.get('components', []) original_alert_count = sum( len(c.get('alerts', [])) for c in original_components diff --git a/tests/test_triage.py b/tests/test_triage.py index c0c4954..a05e6e5 100644 --- a/tests/test_triage.py +++ b/tests/test_triage.py @@ -1,21 +1,30 @@ """Tests for socket_basics.core.triage module.""" +import logging import pytest -from socket_basics.core.triage import TriageFilter, fetch_triage_data +from socket_basics.core.triage import ( + TriageFilter, + fetch_triage_data, + stream_full_scan_alerts, +) # --------------------------------------------------------------------------- # Fixtures / helpers # --------------------------------------------------------------------------- +ARTIFACT_ID = "abc123" + + def _make_component( + comp_id: str = ARTIFACT_ID, name: str = "lodash", comp_type: str = "npm", version: str = "4.17.21", alerts: list | None = None, ) -> dict: return { - "id": f"pkg:{comp_type}/{name}@{version}", + "id": comp_id, "name": name, "version": version, "type": comp_type, @@ -24,9 +33,9 @@ def _make_component( } -def _make_alert( +def _make_local_alert( title: str = "badEncoding", - alert_type: str = "supplyChainRisk", + alert_type: str = "badEncoding", severity: str = "high", rule_id: str | None = None, detector_name: str | None = None, @@ -52,153 +61,144 @@ def _make_alert( def _make_triage_entry( alert_key: str, state: str = "ignore", - package_name: str | None = None, - package_type: str | None = None, - 
package_version: str | None = None, - package_namespace: str | None = None, ) -> dict: return { "uuid": "test-uuid", "alert_key": alert_key, "state": state, - "package_name": package_name, - "package_type": package_type, - "package_version": package_version, - "package_namespace": package_namespace, "note": "", "organization_id": "test-org", } +def _make_artifact_alerts( + artifact_id: str = ARTIFACT_ID, + alerts: list[dict] | None = None, + name: str = "lodash", + version: str = "4.17.21", + pkg_type: str = "npm", +) -> dict[str, list[dict]]: + """Build an artifact_alerts mapping with enriched _artifact metadata.""" + meta = { + "artifact_id": artifact_id, + "artifact_name": name, + "artifact_version": version, + "artifact_type": pkg_type, + "artifact_namespace": None, + "artifact_subpath": None, + } + enriched = [{**a, "_artifact": meta} for a in (alerts or [])] + return {artifact_id: enriched} + + +def _socket_alert(key: str, alert_type: str) -> dict: + """Create a minimal Socket alert dict (as returned by the full scan stream).""" + return {"key": key, "type": alert_type} + + # --------------------------------------------------------------------------- -# TriageFilter.is_alert_triaged +# TriageFilter construction # --------------------------------------------------------------------------- -class TestIsAlertTriaged: - """Tests for the alert matching logic.""" - - def test_broad_match_by_title(self): - """Triage entry with no package info matches any component with matching alert_key.""" - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_rule_id(self): - entry = _make_triage_entry(alert_key="python.lang.security.audit.xss") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="XSS Vulnerability", rule_id="python.lang.security.audit.xss") - assert 
tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_detector_name(self): - entry = _make_triage_entry(alert_key="AWS") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="AWS Key Detected", detector_name="AWS") - assert tf.is_alert_triaged(comp, alert) is True - - def test_broad_match_by_cve(self): - entry = _make_triage_entry(alert_key="CVE-2024-1234") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="Some Vuln", cve_id="CVE-2024-1234") - assert tf.is_alert_triaged(comp, alert) is True - - def test_no_match_different_key(self): - entry = _make_triage_entry(alert_key="differentRule") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is False - - def test_package_scoped_match(self): - """Triage entry with package info only matches the specific package.""" - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", +class TestTriageFilterInit: + def test_builds_triaged_keys_for_ignore(self): + entries = [_make_triage_entry("hash-1", state="ignore")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + assert "hash-1" in tf.triaged_keys + + def test_builds_triaged_keys_for_monitor(self): + entries = [_make_triage_entry("hash-2", state="monitor")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-2", "cve")] + ) + tf = TriageFilter(entries, artifact_alerts) + assert "hash-2" in tf.triaged_keys + + def test_excludes_block_warn_inherit_states(self): + entries = [ + _make_triage_entry("h1", state="block"), + _make_triage_entry("h2", state="warn"), + _make_triage_entry("h3", state="inherit"), + ] + artifact_alerts = _make_artifact_alerts( + alerts=[ + _socket_alert("h1", "a"), + _socket_alert("h2", "b"), + _socket_alert("h3", "c"), + ] + ) + tf 
= TriageFilter(entries, artifact_alerts) + assert tf.triaged_keys == set() + + def test_builds_triaged_by_artifact_mapping(self): + entries = [_make_triage_entry("hash-1", state="ignore")] + artifact_alerts = _make_artifact_alerts( + artifact_id="art-1", + alerts=[_socket_alert("hash-1", "badEncoding")], ) - tf = TriageFilter([entry]) + tf = TriageFilter(entries, artifact_alerts) + assert "art-1" in tf._triaged_by_artifact + assert "badEncoding" in tf._triaged_by_artifact["art-1"] - comp_match = _make_component(name="lodash", comp_type="npm") - comp_no_match = _make_component(name="express", comp_type="npm") - alert = _make_alert(title="badEncoding") + def test_no_entries_means_empty_triaged_keys(self): + tf = TriageFilter([], {}) + assert tf.triaged_keys == set() - assert tf.is_alert_triaged(comp_match, alert) is True - assert tf.is_alert_triaged(comp_no_match, alert) is False + def test_entry_without_alert_key_ignored(self): + entries = [{"state": "ignore", "alert_key": None}] + tf = TriageFilter(entries, {}) + assert tf.triaged_keys == set() - def test_package_version_exact_match(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="4.17.21", - ) - tf = TriageFilter([entry]) - comp_match = _make_component(name="lodash", comp_type="npm", version="4.17.21") - comp_no_match = _make_component(name="lodash", comp_type="npm", version="4.17.20") - alert = _make_alert(title="badEncoding") +# --------------------------------------------------------------------------- +# TriageFilter._local_alert_is_triaged +# --------------------------------------------------------------------------- - assert tf.is_alert_triaged(comp_match, alert) is True - assert tf.is_alert_triaged(comp_no_match, alert) is False +class TestLocalAlertIsTriaged: + def test_direct_type_match(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(alert_type="badEncoding") + assert 
TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_direct_type_no_match(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(alert_type="cve") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is False + + def test_generic_type_falls_back_to_title(self): + triaged_types = {"badEncoding"} + alert = _make_local_alert(title="badEncoding", alert_type="generic") + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_vulnerability_type_falls_back_to_cve(self): + triaged_types = {"CVE-2024-1234"} + alert = _make_local_alert( + title="Some Vuln", alert_type="vulnerability", cve_id="CVE-2024-1234" + ) + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True - def test_version_wildcard(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="4.17.*", + def test_generic_type_falls_back_to_rule_id(self): + triaged_types = {"python.lang.security.audit.xss"} + alert = _make_local_alert( + title="XSS", alert_type="generic", + rule_id="python.lang.security.audit.xss", ) - tf = TriageFilter([entry]) - alert = _make_alert(title="badEncoding") - - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.17.21"), alert - ) is True - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.17.0"), alert - ) is True - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="4.18.0"), alert - ) is False - - def test_version_star_matches_all(self): - entry = _make_triage_entry( - alert_key="badEncoding", - package_name="lodash", - package_type="npm", - package_version="*", + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_generic_type_falls_back_to_detector_name(self): + triaged_types = {"AWS"} + alert = _make_local_alert( + title="AWS Key", alert_type="generic", 
detector_name="AWS" ) - tf = TriageFilter([entry]) - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged( - _make_component(name="lodash", comp_type="npm", version="99.0.0"), alert - ) is True - - def test_states_block_and_warn_not_suppressed(self): - """Triage entries with block/warn/inherit states should not filter findings.""" - for state in ("block", "warn", "inherit"): - entry = _make_triage_entry(alert_key="badEncoding", state=state) - tf = TriageFilter([entry]) - assert tf.entries == [], f"state={state} should be excluded from filter entries" - - def test_state_monitor_suppressed(self): - entry = _make_triage_entry(alert_key="badEncoding", state="monitor") - tf = TriageFilter([entry]) - comp = _make_component() - alert = _make_alert(title="badEncoding") - assert tf.is_alert_triaged(comp, alert) is True - - def test_alert_with_no_matchable_keys(self): - """Alert with no title, type, or relevant props should not match.""" - entry = _make_triage_entry(alert_key="something") - tf = TriageFilter([entry]) - comp = _make_component() - alert = {"severity": "high", "props": {}} - assert tf.is_alert_triaged(comp, alert) is False + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is True + + def test_no_fallback_candidates_returns_false(self): + triaged_types = {"something"} + alert = {"type": "generic", "props": {}} + assert TriageFilter._local_alert_is_triaged(alert, triaged_types) is False # --------------------------------------------------------------------------- @@ -206,47 +206,87 @@ def test_alert_with_no_matchable_keys(self): # --------------------------------------------------------------------------- class TestFilterComponents: - def test_removes_triaged_alerts(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) - - alert_triaged = _make_alert(title="badEncoding") - alert_kept = _make_alert(title="otherIssue") - comp = _make_component(alerts=[alert_triaged, alert_kept]) + def 
test_removes_triaged_alert_by_type(self): + """Component ID matches artifact, triaged alert type matches local alert type.""" + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(title="kept", alert_type="otherIssue"), + ], + ) filtered, count = tf.filter_components([comp]) assert count == 1 assert len(filtered) == 1 assert len(filtered[0]["alerts"]) == 1 - assert filtered[0]["alerts"][0]["title"] == "otherIssue" + assert filtered[0]["alerts"][0]["title"] == "kept" def test_removes_component_when_all_alerts_triaged(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + alerts=[_socket_alert("hash-1", "badEncoding")] + ) + tf = TriageFilter(entries, artifact_alerts) - comp = _make_component(alerts=[_make_alert(title="badEncoding")]) + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[_make_local_alert(alert_type="badEncoding")], + ) filtered, count = tf.filter_components([comp]) assert count == 1 assert len(filtered) == 0 def test_no_triage_entries_returns_original(self): - tf = TriageFilter([]) - comp = _make_component(alerts=[_make_alert()]) + tf = TriageFilter([], {}) + comp = _make_component(alerts=[_make_local_alert()]) + filtered, count = tf.filter_components([comp]) + assert count == 0 + assert filtered == [comp] + + def test_component_id_mismatch_keeps_all_alerts(self): + """When local component ID doesn't match any artifact, nothing is filtered.""" + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + artifact_id="different-artifact", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) + + 
comp = _make_component( + comp_id="unrelated-comp-id", + alerts=[_make_local_alert(alert_type="badEncoding")], + ) filtered, count = tf.filter_components([comp]) assert count == 0 - assert filtered is [comp] or filtered == [comp] + assert len(filtered) == 1 def test_multiple_components_mixed(self): - entry = _make_triage_entry(alert_key="badEncoding") - tf = TriageFilter([entry]) + entries = [_make_triage_entry("hash-1")] + artifact_alerts = _make_artifact_alerts( + artifact_id="art-a", + alerts=[_socket_alert("hash-1", "badEncoding")], + ) + tf = TriageFilter(entries, artifact_alerts) - comp1 = _make_component(name="a", alerts=[_make_alert(title="badEncoding")]) - comp2 = _make_component(name="b", alerts=[_make_alert(title="otherIssue")]) + comp1 = _make_component( + comp_id="art-a", name="a", + alerts=[_make_local_alert(alert_type="badEncoding")], + ) + comp2 = _make_component( + comp_id="art-b", name="b", + alerts=[_make_local_alert(alert_type="otherIssue")], + ) comp3 = _make_component( - name="c", + comp_id="art-a", name="c", alerts=[ - _make_alert(title="badEncoding"), - _make_alert(title="keepMe"), + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(title="keepMe", alert_type="keepMe"), ], ) @@ -258,6 +298,153 @@ def test_multiple_components_mixed(self): assert "b" in names assert "c" in names + def test_multiple_triaged_alert_types_on_same_artifact(self): + entries = [ + _make_triage_entry("hash-1", state="ignore"), + _make_triage_entry("hash-2", state="monitor"), + ] + artifact_alerts = _make_artifact_alerts( + alerts=[ + _socket_alert("hash-1", "badEncoding"), + _socket_alert("hash-2", "cve"), + ], + ) + tf = TriageFilter(entries, artifact_alerts) + + comp = _make_component( + comp_id=ARTIFACT_ID, + alerts=[ + _make_local_alert(alert_type="badEncoding"), + _make_local_alert(alert_type="cve"), + _make_local_alert(title="safe", alert_type="safe"), + ], + ) + filtered, count = tf.filter_components([comp]) + assert count == 2 + assert 
len(filtered[0]["alerts"]) == 1 + assert filtered[0]["alerts"][0]["type"] == "safe" + + +# --------------------------------------------------------------------------- +# stream_full_scan_alerts +# --------------------------------------------------------------------------- + +class TestStreamFullScanAlerts: + def test_parses_artifacts_and_alerts(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "artifact-1": { + "name": "lodash", + "version": "4.17.21", + "type": "npm", + "namespace": None, + "alerts": [ + {"key": "hash-a", "type": "badEncoding"}, + {"key": "hash-b", "type": "cve"}, + ], + }, + "artifact-2": { + "name": "express", + "version": "4.18.0", + "type": "npm", + "namespace": None, + "alerts": [], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "my-org", "scan-123") + assert "artifact-1" in result + assert "artifact-2" not in result # empty alerts filtered out + assert len(result["artifact-1"]) == 2 + assert result["artifact-1"][0]["key"] == "hash-a" + assert result["artifact-1"][0]["_artifact"]["artifact_name"] == "lodash" + + def test_skips_alerts_without_key(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "art-1": { + "name": "pkg", + "version": "1.0.0", + "type": "npm", + "alerts": [ + {"key": "hash-a", "type": "badEncoding"}, + {"type": "noKey"}, # missing key + {"key": "", "type": "emptyKey"}, # empty key + ], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert len(result["art-1"]) == 1 + + def test_access_denied_returns_empty(self, caplog): + class APIAccessDenied(Exception): + pass + + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + raise APIAccessDenied("Forbidden") + + class FakeSDK: + fullscans = FakeFullscansAPI() + + with caplog.at_level(logging.DEBUG): + result = 
stream_full_scan_alerts(FakeSDK(), "org", "scan") + + assert result == {} + info_msgs = [r for r in caplog.records if r.levelno == logging.INFO] + assert any("access denied" in m.message.lower() for m in info_msgs) + + def test_api_error_returns_empty(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + raise RuntimeError("Network failure") + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result == {} + + def test_non_dict_response_returns_empty(self): + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return "unexpected string" + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result == {} + + def test_subpath_handling(self): + """Supports both camelCase and lowercase subpath field names.""" + class FakeFullscansAPI: + def stream(self, org, scan_id, use_types=False): + return { + "art-1": { + "name": "pkg", + "version": "1.0", + "type": "npm", + "subPath": "src/lib", + "alerts": [{"key": "k1", "type": "t1"}], + }, + } + + class FakeSDK: + fullscans = FakeFullscansAPI() + + result = stream_full_scan_alerts(FakeSDK(), "org", "scan") + assert result["art-1"][0]["_artifact"]["artifact_subpath"] == "src/lib" + # --------------------------------------------------------------------------- # fetch_triage_data @@ -323,14 +510,12 @@ def list_alert_triage(self, org, params): class FakeSDK: triage = FakeTriageAPI() - import logging with caplog.at_level(logging.DEBUG): entries = fetch_triage_data(FakeSDK(), "my-org") assert entries == [] info_messages = [r for r in caplog.records if r.levelno == logging.INFO] assert any("access denied" in m.message.lower() for m in info_messages) - # Should NOT produce an ERROR-level record error_messages = [r for r in caplog.records if r.levelno >= logging.ERROR] assert not error_messages @@ -385,7 +570,6 @@ def 
test_injects_after_heading(self): content = notifications["github_pr"][0]["content"] assert "3 finding(s) triaged" in content assert "Socket Dashboard" in content - # Summary line should appear after the # heading lines = content.split("\n") heading_idx = next(i for i, l in enumerate(lines) if l.strip().startswith("# ")) summary_idx = next(i for i, l in enumerate(lines) if "triaged" in l) From 0ded415a23ddd2fe2801a77782df11c28ee94167 Mon Sep 17 00:00:00 2001 From: Carl Bergenhem Date: Wed, 4 Mar 2026 12:57:02 -0500 Subject: [PATCH 13/13] Update how scan ID and HTML url are extracted The CreateFullScanResponse object is used here and `id` and `html_url` are actually nested in the `data` field, not on the root of the object. `html_url` should also be `html_report_url` --- socket_basics/socket_basics.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/socket_basics/socket_basics.py b/socket_basics/socket_basics.py index c5683b2..5ca722e 100644 --- a/socket_basics/socket_basics.py +++ b/socket_basics/socket_basics.py @@ -356,9 +356,10 @@ def submit_socket_facts(self, socket_facts_path: Path, results: Dict[str, Any]) logger.error(f"Error creating full scan: {error_msg}") raise Exception(f"Error creating full scan: {error_msg}") - # Extract the scan ID and HTML URL from the response - scan_id = getattr(res, 'id', None) - html_url = getattr(res, 'html_url', None) + # SDK CreateFullScanResponse nests metadata under .data + data = getattr(res, 'data', None) or res + scan_id = getattr(data, 'id', None) + html_url = getattr(data, 'html_report_url', None) or getattr(data, 'html_url', None) logger.debug(f"Extracted from object: scan_id={scan_id}, html_url={html_url}") if scan_id: