From e208c49d679f1af9811a770efb71e573ddcba872 Mon Sep 17 00:00:00 2001 From: Arsen Ohanyan Date: Mon, 16 Feb 2026 16:43:07 -0800 Subject: [PATCH 1/5] fix: skip Path.resolve() for cloud storage db_uri in vector store config --- .../graphrag/graphrag/config/models/graph_rag_config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index dc28da97ca..5e0d42b97a 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -268,9 +268,11 @@ def _validate_vector_store_db_uri(self) -> None: """Validate the vector store configuration.""" store = self.vector_store if store.type == VectorStoreType.LanceDB: - if not store.db_uri or store.db_uri.strip == "": + if not store.db_uri or store.db_uri.strip() == "": store.db_uri = graphrag_config_defaults.vector_store.db_uri - store.db_uri = str(Path(store.db_uri).resolve()) + # Don't resolve cloud storage URIs as local paths + if not store.db_uri.startswith(("gs://", "s3://", "az://", "abfs://")): + store.db_uri = str(Path(store.db_uri).resolve()) def get_completion_model_config(self, model_id: str) -> ModelConfig: """Get a completion model configuration by ID. 
From c51ceaaadc6f37ffded16d50510713bf7ea8a42a Mon Sep 17 00:00:00 2001 From: Arsen Ohanyan Date: Tue, 17 Feb 2026 17:21:28 -0800 Subject: [PATCH 2/5] feat: add LLM-based entity resolution workflow step --- packages/graphrag/graphrag/config/defaults.py | 14 ++ .../graphrag/graphrag/config/init_content.py | 5 + .../config/models/entity_resolution_config.py | 52 +++++ .../config/models/graph_rag_config.py | 13 +- .../operations/resolve_entities/__init__.py | 4 + .../resolve_entities/resolve_entities.py | 157 +++++++++++++ .../graphrag/index/workflows/__init__.py | 4 + .../graphrag/index/workflows/factory.py | 1 + .../index/workflows/resolve_entities.py | 74 ++++++ .../prompts/index/entity_resolution.py | 29 +++ .../resolve_entities/test_resolve_entities.py | 217 ++++++++++++++++++ 11 files changed, 566 insertions(+), 4 deletions(-) create mode 100644 packages/graphrag/graphrag/config/models/entity_resolution_config.py create mode 100644 packages/graphrag/graphrag/index/operations/resolve_entities/__init__.py create mode 100644 packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py create mode 100644 packages/graphrag/graphrag/index/workflows/resolve_entities.py create mode 100644 packages/graphrag/graphrag/prompts/index/entity_resolution.py create mode 100644 tests/unit/indexing/operations/resolve_entities/test_resolve_entities.py diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index 640933581a..c1b880e1fe 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -311,6 +311,17 @@ class SnapshotsDefaults: raw_graph: bool = False +@dataclass +class EntityResolutionDefaults: + """Default values for entity resolution.""" + + enabled: bool = False + prompt: None = None + batch_size: int = 200 + completion_model_id: str = DEFAULT_COMPLETION_MODEL_ID + model_instance_name: str = "entity_resolution" + + @dataclass class 
SummarizeDescriptionsDefaults: """Default values for summarizing descriptions.""" @@ -359,6 +370,9 @@ class GraphRagConfigDefaults: chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults) snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults) extract_graph: ExtractGraphDefaults = field(default_factory=ExtractGraphDefaults) + entity_resolution: EntityResolutionDefaults = field( + default_factory=EntityResolutionDefaults + ) extract_graph_nlp: ExtractGraphNLPDefaults = field( default_factory=ExtractGraphNLPDefaults ) diff --git a/packages/graphrag/graphrag/config/init_content.py b/packages/graphrag/graphrag/config/init_content.py index 9973d1920f..b651425f4b 100644 --- a/packages/graphrag/graphrag/config/init_content.py +++ b/packages/graphrag/graphrag/config/init_content.py @@ -82,6 +82,11 @@ entity_types: [{",".join(graphrag_config_defaults.extract_graph.entity_types)}] max_gleanings: {graphrag_config_defaults.extract_graph.max_gleanings} +entity_resolution: + enabled: {graphrag_config_defaults.entity_resolution.enabled} + completion_model_id: {graphrag_config_defaults.entity_resolution.completion_model_id} + batch_size: {graphrag_config_defaults.entity_resolution.batch_size} + summarize_descriptions: completion_model_id: {graphrag_config_defaults.summarize_descriptions.completion_model_id} prompt: "prompts/summarize_descriptions.txt" diff --git a/packages/graphrag/graphrag/config/models/entity_resolution_config.py b/packages/graphrag/graphrag/config/models/entity_resolution_config.py new file mode 100644 index 0000000000..9150169f27 --- /dev/null +++ b/packages/graphrag/graphrag/config/models/entity_resolution_config.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Parameterization settings for entity resolution.""" + +from dataclasses import dataclass +from pathlib import Path + +from pydantic import BaseModel, Field + +from graphrag.config.defaults import graphrag_config_defaults +from graphrag.prompts.index.entity_resolution import ENTITY_RESOLUTION_PROMPT + + +@dataclass +class EntityResolutionPrompts: + """Entity resolution prompt templates.""" + + resolution_prompt: str + + +class EntityResolutionConfig(BaseModel): + """Configuration section for entity resolution.""" + + enabled: bool = Field( + description="Whether to enable LLM-based entity resolution.", + default=graphrag_config_defaults.entity_resolution.enabled, + ) + completion_model_id: str = Field( + description="The model ID to use for entity resolution.", + default=graphrag_config_defaults.entity_resolution.completion_model_id, + ) + model_instance_name: str = Field( + description="The model singleton instance name. This primarily affects the cache storage partitioning.", + default=graphrag_config_defaults.entity_resolution.model_instance_name, + ) + prompt: str | None = Field( + description="The entity resolution prompt to use.", + default=graphrag_config_defaults.entity_resolution.prompt, + ) + batch_size: int = Field( + description="Maximum number of entity names to send to the LLM in each batch.", + default=graphrag_config_defaults.entity_resolution.batch_size, + ) + + def resolved_prompts(self) -> EntityResolutionPrompts: + """Get the resolved entity resolution prompts.""" + return EntityResolutionPrompts( + resolution_prompt=Path(self.prompt).read_text(encoding="utf-8") + if self.prompt + else ENTITY_RESOLUTION_PROMPT, + ) diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index 5e0d42b97a..27e692671f 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ 
b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -24,6 +24,7 @@ from graphrag.config.models.community_reports_config import CommunityReportsConfig from graphrag.config.models.drift_search_config import DRIFTSearchConfig from graphrag.config.models.embed_text_config import EmbedTextConfig +from graphrag.config.models.entity_resolution_config import EntityResolutionConfig from graphrag.config.models.extract_claims_config import ExtractClaimsConfig from graphrag.config.models.extract_graph_config import ExtractGraphConfig from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig @@ -186,6 +187,12 @@ def _validate_reporting_base_dir(self) -> None: ) """The entity extraction configuration to use.""" + entity_resolution: EntityResolutionConfig = Field( + description="The entity resolution configuration to use.", + default=EntityResolutionConfig(), + ) + """The entity resolution configuration to use.""" + summarize_descriptions: SummarizeDescriptionsConfig = Field( description="The description summarization configuration to use.", default=SummarizeDescriptionsConfig(),
diff --git a/packages/graphrag/graphrag/index/operations/resolve_entities/__init__.py b/packages/graphrag/graphrag/index/operations/resolve_entities/__init__.py new file mode 100644 index 0000000000..ae85b818dd --- /dev/null +++ b/packages/graphrag/graphrag/index/operations/resolve_entities/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Entity resolution operation package.""" diff --git a/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py new file mode 100644 index 0000000000..1966aaba2a --- /dev/null +++ b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py @@ -0,0 +1,157 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM-based entity resolution operation. + +Identifies entities with different surface forms that refer to the same +real-world entity (e.g. "Ahab" and "Captain Ahab") and unifies their titles. +""" + +import asyncio +import logging +from typing import TYPE_CHECKING + +import pandas as pd + +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.logger.progress import progress_ticker + +if TYPE_CHECKING: + from graphrag_llm.completion import LLMCompletion + +logger = logging.getLogger(__name__) + + +async def resolve_entities( + entities: pd.DataFrame, + relationships: pd.DataFrame, + callbacks: WorkflowCallbacks, + model: "LLMCompletion", + prompt: str, + batch_size: int, + num_threads: int, +) -> tuple[pd.DataFrame, pd.DataFrame]: + """Identify and merge duplicate entities with different surface forms. + + Sends entity names in batches to the LLM, parses the response to build + a rename mapping, then applies it to both entity titles and relationship + source/target columns. + + Parameters + ---------- + entities : pd.DataFrame + Entity DataFrame with at least a "title" column. 
+ relationships : pd.DataFrame + Relationship DataFrame with "source" and "target" columns. + callbacks : WorkflowCallbacks + Progress callbacks. + model : LLMCompletion + The LLM completion model to use. + prompt : str + The entity resolution prompt template (must contain {entity_list}). + batch_size : int + Maximum number of entity names per LLM batch. + num_threads : int + Concurrency limit for LLM calls. + + Returns + ------- + tuple[pd.DataFrame, pd.DataFrame] + Updated (entities, relationships) with unified titles. + """ + if "title" not in entities.columns: + return entities, relationships + + titles = entities["title"].dropna().unique().tolist() + if len(titles) < 2: + return entities, relationships + + logger.info( + "Running LLM entity resolution on %d unique entity names...", len(titles) + ) + + # Build batches + batches = [ + titles[i : i + batch_size] for i in range(0, len(titles), batch_size) + ] + + ticker = progress_ticker( + callbacks.progress, + len(batches), + description="Entity resolution batch progress: ", + ) + + semaphore = asyncio.Semaphore(num_threads) + rename_map: dict[str, str] = {} # alias → canonical + + async def process_batch(batch: list[str], batch_idx: int) -> None: + entity_list = "\n".join(f"{i+1}. 
{name}" for i, name in enumerate(batch)) + formatted_prompt = prompt.format(entity_list=entity_list) + + async with semaphore: + try: + response = await model(formatted_prompt) + raw = (response or "").strip() + except Exception: + logger.warning( + "Entity resolution LLM call failed for batch %d, skipping", + batch_idx + 1, + ) + ticker(1) + return + + if "NO_DUPLICATES" in raw: + logger.info(" Batch %d: no duplicates found", batch_idx + 1) + ticker(1) + return + + # Parse response lines like "3, 17" or "5, 12, 28" + for line in raw.splitlines(): + line = line.strip() + if not line or line.startswith("#") or line.startswith("Where"): + continue + parts = [p.strip() for p in line.split(",")] + indices: list[int] = [] + for p in parts: + digits = "".join(c for c in p if c.isdigit()) + if digits: + idx = int(digits) - 1 # 1-indexed → 0-indexed + if 0 <= idx < len(batch): + indices.append(idx) + if len(indices) >= 2: + canonical = batch[indices[0]] + for alias_idx in indices[1:]: + alias = batch[alias_idx] + rename_map[alias] = canonical + logger.info( + " Entity resolution: '%s' → '%s'", alias, canonical + ) + + ticker(1) + + futures = [process_batch(batch, i) for i, batch in enumerate(batches)] + await asyncio.gather(*futures) + + if not rename_map: + logger.info("Entity resolution complete: no duplicates found") + return entities, relationships + + logger.info("Entity resolution: merging %d duplicate names", len(rename_map)) + + # Apply renames to entity titles + entities = entities.copy() + entities["title"] = entities["title"].map(lambda t: rename_map.get(t, t)) + + # Apply renames to relationship source/target + if not relationships.empty: + relationships = relationships.copy() + if "source" in relationships.columns: + relationships["source"] = relationships["source"].map( + lambda s: rename_map.get(s, s) + ) + if "target" in relationships.columns: + relationships["target"] = relationships["target"].map( + lambda t: rename_map.get(t, t) + ) + + return entities, 
relationships diff --git a/packages/graphrag/graphrag/index/workflows/__init__.py b/packages/graphrag/graphrag/index/workflows/__init__.py index 6dee90c097..766b1c979f 100644 --- a/packages/graphrag/graphrag/index/workflows/__init__.py +++ b/packages/graphrag/graphrag/index/workflows/__init__.py @@ -36,6 +36,9 @@ from .finalize_graph import ( run_workflow as run_finalize_graph, ) +from .resolve_entities import ( + run_workflow as run_resolve_entities, +) from .generate_text_embeddings import ( run_workflow as run_generate_text_embeddings, ) @@ -86,6 +89,7 @@ "create_final_text_units": run_create_final_text_units, "extract_graph_nlp": run_extract_graph_nlp, "extract_graph": run_extract_graph, + "resolve_entities": run_resolve_entities, "finalize_graph": run_finalize_graph, "generate_text_embeddings": run_generate_text_embeddings, "prune_graph": run_prune_graph, diff --git a/packages/graphrag/graphrag/index/workflows/factory.py b/packages/graphrag/graphrag/index/workflows/factory.py index 585ecfa8a1..1f621079de 100644 --- a/packages/graphrag/graphrag/index/workflows/factory.py +++ b/packages/graphrag/graphrag/index/workflows/factory.py @@ -53,6 +53,7 @@ def create_pipeline( "create_base_text_units", "create_final_documents", "extract_graph", + "resolve_entities", "finalize_graph", "extract_covariates", "create_communities", diff --git a/packages/graphrag/graphrag/index/workflows/resolve_entities.py b/packages/graphrag/graphrag/index/workflows/resolve_entities.py new file mode 100644 index 0000000000..3ac034fd7a --- /dev/null +++ b/packages/graphrag/graphrag/index/workflows/resolve_entities.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing the resolve_entities workflow definition.""" + +import logging +from typing import TYPE_CHECKING + +from graphrag_llm.completion import create_completion + +from graphrag.cache.cache_key_creator import cache_key_creator +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader +from graphrag.index.operations.resolve_entities.resolve_entities import ( + resolve_entities, +) +from graphrag.index.typing.context import PipelineRunContext +from graphrag.index.typing.workflow import WorkflowFunctionOutput + +if TYPE_CHECKING: + from graphrag_llm.completion import LLMCompletion + +logger = logging.getLogger(__name__) + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, +) -> WorkflowFunctionOutput: + """Resolve duplicate entities with different surface forms.""" + logger.info("Workflow started: resolve_entities") + + if not config.entity_resolution.enabled: + logger.info( + "Entity resolution is disabled (entity_resolution.enabled=false), skipping" + ) + return WorkflowFunctionOutput(result={}) + + reader = DataReader(context.output_table_provider) + entities = await reader.entities() + relationships = await reader.relationships() + + resolution_model_config = config.get_completion_model_config( + config.entity_resolution.completion_model_id + ) + resolution_prompts = config.entity_resolution.resolved_prompts() + resolution_model = create_completion( + resolution_model_config, + cache=context.cache.child(config.entity_resolution.model_instance_name), + cache_key_creator=cache_key_creator, + ) + + resolved_entities, resolved_relationships = await resolve_entities( + entities=entities, + relationships=relationships, + callbacks=context.callbacks, + model=resolution_model, + prompt=resolution_prompts.resolution_prompt, + batch_size=config.entity_resolution.batch_size, + num_threads=config.concurrent_requests, + ) + + 
await context.output_table_provider.write_dataframe("entities", resolved_entities) + await context.output_table_provider.write_dataframe( + "relationships", resolved_relationships + ) + + logger.info("Workflow completed: resolve_entities") + return WorkflowFunctionOutput( + result={ + "entities": resolved_entities, + "relationships": resolved_relationships, + } + ) diff --git a/packages/graphrag/graphrag/prompts/index/entity_resolution.py b/packages/graphrag/graphrag/prompts/index/entity_resolution.py new file mode 100644 index 0000000000..071b6674bc --- /dev/null +++ b/packages/graphrag/graphrag/prompts/index/entity_resolution.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A file containing prompts definition.""" + +ENTITY_RESOLUTION_PROMPT = """ +You are an entity resolution expert. Below is a numbered list of entity names +extracted from a knowledge graph. Identify which names refer to the SAME +real-world entity and choose the best canonical name for each group of duplicates. + +Rules: +- Only merge names that clearly refer to the same entity (e.g., "Ahab" and +"Captain Ahab", "USA" and "United States of America") +- Do NOT merge entities that are merely related (e.g., "Ahab" and "Moby Dick") +- Choose the most complete and commonly used name as the canonical form +- Reference entities by their number + +Output format — one group per line, canonical number first, then duplicate numbers: +3, 17 +5, 12, 28 + +Where each line means: all listed numbers refer to the same entity, and the +first number's name is the canonical form. 
+ +If no duplicates are found, respond with exactly: NO_DUPLICATES + +Entity list: +{entity_list} +""" diff --git a/tests/unit/indexing/operations/resolve_entities/test_resolve_entities.py b/tests/unit/indexing/operations/resolve_entities/test_resolve_entities.py new file mode 100644 index 0000000000..79a11cdd22 --- /dev/null +++ b/tests/unit/indexing/operations/resolve_entities/test_resolve_entities.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Unit tests for the resolve_entities operation.""" + +from unittest.mock import AsyncMock, MagicMock + +import pandas as pd +import pytest + +from graphrag.index.operations.resolve_entities.resolve_entities import ( + resolve_entities, +) + + +@pytest.fixture +def sample_entities(): + """Create sample entity DataFrame with known duplicates.""" + return pd.DataFrame({ + "title": [ + "Captain Ahab", + "Moby Dick", + "Ahab", + "The Pequod", + "Ishmael", + "Pequod", + ], + "description": [ + "Captain of the Pequod", + "The great white whale", + "The obsessed captain", + "The whaling ship", + "The narrator", + "A whaling vessel", + ], + }) + + +@pytest.fixture +def sample_relationships(): + """Create sample relationship DataFrame.""" + return pd.DataFrame({ + "source": ["Ahab", "Captain Ahab", "Ishmael", "Pequod"], + "target": ["Moby Dick", "The Pequod", "Pequod", "Ahab"], + "description": [ + "hunts", + "commands", + "boards", + "carries", + ], + }) + + +@pytest.fixture +def mock_callbacks(): + """Create mock workflow callbacks.""" + callbacks = MagicMock() + callbacks.progress = MagicMock() + return callbacks + + +def _make_mock_model(response_text: str) -> AsyncMock: + """Create a mock LLM model that returns the given text.""" + model = AsyncMock() + model.return_value = response_text + return model + + +@pytest.mark.asyncio +async def test_no_duplicates(sample_entities, sample_relationships, mock_callbacks): + """When LLM finds no duplicates, entities remain 
unchanged.""" + model = _make_mock_model("NO_DUPLICATES") + + result_entities, result_relationships = await resolve_entities( + entities=sample_entities.copy(), + relationships=sample_relationships.copy(), + callbacks=mock_callbacks, + model=model, + prompt="{entity_list}", + batch_size=200, + num_threads=1, + ) + + # Titles should be unchanged + assert list(result_entities["title"]) == list(sample_entities["title"]) + assert list(result_relationships["source"]) == list( + sample_relationships["source"] + ) + + +@pytest.mark.asyncio +async def test_simple_duplicates( + sample_entities, sample_relationships, mock_callbacks +): + """Ahab → Captain Ahab, Pequod → The Pequod.""" + # LLM response: entity 1 (Captain Ahab) and 3 (Ahab) are the same; + # entity 4 (The Pequod) and 6 (Pequod) are the same. + model = _make_mock_model("1, 3\n4, 6") + + result_entities, result_relationships = await resolve_entities( + entities=sample_entities.copy(), + relationships=sample_relationships.copy(), + callbacks=mock_callbacks, + model=model, + prompt="{entity_list}", + batch_size=200, + num_threads=1, + ) + + # "Ahab" should become "Captain Ahab" + titles = list(result_entities["title"]) + assert "Ahab" not in titles + assert titles.count("Captain Ahab") == 2 # both rows unified + + # "Pequod" should become "The Pequod" + assert "Pequod" not in titles + assert titles.count("The Pequod") == 2 + + # Relationships should also be renamed + sources = list(result_relationships["source"]) + targets = list(result_relationships["target"]) + assert "Ahab" not in sources + assert "Ahab" not in targets + assert "Pequod" not in sources + assert "Pequod" not in targets + + +@pytest.mark.asyncio +async def test_llm_failure_graceful( + sample_entities, sample_relationships, mock_callbacks +): + """If LLM call fails, entities are returned unchanged.""" + model = AsyncMock(side_effect=Exception("LLM unavailable")) + + result_entities, result_relationships = await resolve_entities( + 
entities=sample_entities.copy(), + relationships=sample_relationships.copy(), + callbacks=mock_callbacks, + model=model, + prompt="{entity_list}", + batch_size=200, + num_threads=1, + ) + + # Should fall back to no changes + assert list(result_entities["title"]) == list(sample_entities["title"]) + + +@pytest.mark.asyncio +async def test_single_entity_skips(): + """With fewer than 2 entities, resolution is skipped entirely.""" + entities = pd.DataFrame({"title": ["Only One"]}) + relationships = pd.DataFrame({"source": [], "target": []}) + callbacks = MagicMock() + callbacks.progress = MagicMock() + model = _make_mock_model("should not be called") + + result_entities, _ = await resolve_entities( + entities=entities, + relationships=relationships, + callbacks=callbacks, + model=model, + prompt="{entity_list}", + batch_size=200, + num_threads=1, + ) + + # Model should not have been called + model.assert_not_called() + assert list(result_entities["title"]) == ["Only One"] + + +@pytest.mark.asyncio +async def test_batch_splitting(mock_callbacks): + """Entities are split into batches of the configured size.""" + # 5 entities, batch_size=2 → 3 batches + entities = pd.DataFrame({ + "title": ["A", "B", "C", "D", "E"], + "description": [""] * 5, + }) + relationships = pd.DataFrame({"source": ["A"], "target": ["B"]}) + + model = _make_mock_model("NO_DUPLICATES") + + await resolve_entities( + entities=entities.copy(), + relationships=relationships.copy(), + callbacks=mock_callbacks, + model=model, + prompt="{entity_list}", + batch_size=2, + num_threads=1, + ) + + # Model should have been called 3 times (ceil(5/2)) + assert model.call_count == 3 + + +@pytest.mark.asyncio +async def test_missing_title_column(mock_callbacks): + """If there's no title column, skip resolution.""" + entities = pd.DataFrame({"name": ["A", "B"]}) + relationships = pd.DataFrame({"source": ["A"], "target": ["B"]}) + model = _make_mock_model("should not be called") + + result_entities, _ = await 
resolve_entities( + entities=entities, + relationships=relationships, + callbacks=mock_callbacks, + model=model, + prompt="{entity_list}", + batch_size=200, + num_threads=1, + ) + + model.assert_not_called() + assert list(result_entities.columns) == ["name"] From 5c0949ba928531576e394732d17d0fd7a8d87cbd Mon Sep 17 00:00:00 2001 From: Arsen Ohanyan Date: Thu, 19 Feb 2026 20:13:51 -0800 Subject: [PATCH 3/5] Early resolve and drop batches --- packages/graphrag/graphrag/config/defaults.py | 1 - .../graphrag/graphrag/config/init_content.py | 1 - .../config/models/entity_resolution_config.py | 4 - .../graphrag/graphrag/data_model/schemas.py | 2 + .../index/operations/finalize_entities.py | 3 + .../resolve_entities/resolve_entities.py | 123 ++++++++---------- .../graphrag/index/update/entities.py | 8 ++ .../graphrag/index/workflows/__init__.py | 4 - .../graphrag/index/workflows/extract_graph.py | 42 +++++- .../graphrag/index/workflows/factory.py | 1 - .../index/workflows/resolve_entities.py | 74 ----------- 11 files changed, 106 insertions(+), 157 deletions(-) delete mode 100644 packages/graphrag/graphrag/index/workflows/resolve_entities.py diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index c1b880e1fe..0fe4e3831a 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -317,7 +317,6 @@ class EntityResolutionDefaults: enabled: bool = False prompt: None = None - batch_size: int = 200 completion_model_id: str = DEFAULT_COMPLETION_MODEL_ID model_instance_name: str = "entity_resolution" diff --git a/packages/graphrag/graphrag/config/init_content.py b/packages/graphrag/graphrag/config/init_content.py index b651425f4b..3dc15d7e3e 100644 --- a/packages/graphrag/graphrag/config/init_content.py +++ b/packages/graphrag/graphrag/config/init_content.py @@ -85,7 +85,6 @@ entity_resolution: enabled: {graphrag_config_defaults.entity_resolution.enabled} 
completion_model_id: {graphrag_config_defaults.entity_resolution.completion_model_id} - batch_size: {graphrag_config_defaults.entity_resolution.batch_size} summarize_descriptions: completion_model_id: {graphrag_config_defaults.summarize_descriptions.completion_model_id} diff --git a/packages/graphrag/graphrag/config/models/entity_resolution_config.py b/packages/graphrag/graphrag/config/models/entity_resolution_config.py index 9150169f27..5b3ef03be6 100644 --- a/packages/graphrag/graphrag/config/models/entity_resolution_config.py +++ b/packages/graphrag/graphrag/config/models/entity_resolution_config.py @@ -38,10 +38,6 @@ class EntityResolutionConfig(BaseModel): description="The entity resolution prompt to use.", default=graphrag_config_defaults.entity_resolution.prompt, ) - batch_size: int = Field( - description="Maximum number of entity names to send to the LLM in each batch.", - default=graphrag_config_defaults.entity_resolution.batch_size, - ) def resolved_prompts(self) -> EntityResolutionPrompts: """Get the resolved entity resolution prompts.""" diff --git a/packages/graphrag/graphrag/data_model/schemas.py b/packages/graphrag/graphrag/data_model/schemas.py index c0926b9bb7..644649e122 100644 --- a/packages/graphrag/graphrag/data_model/schemas.py +++ b/packages/graphrag/graphrag/data_model/schemas.py @@ -13,6 +13,7 @@ NODE_DEGREE = "degree" NODE_FREQUENCY = "frequency" NODE_DETAILS = "node_details" +ALTERNATIVE_NAMES = "alternative_names" # POST-PREP EDGE TABLE SCHEMA EDGE_SOURCE = "source" @@ -73,6 +74,7 @@ TITLE, TYPE, DESCRIPTION, + ALTERNATIVE_NAMES, TEXT_UNIT_IDS, NODE_FREQUENCY, NODE_DEGREE, diff --git a/packages/graphrag/graphrag/index/operations/finalize_entities.py b/packages/graphrag/graphrag/index/operations/finalize_entities.py index 71d6acc536..b75134cb4e 100644 --- a/packages/graphrag/graphrag/index/operations/finalize_entities.py +++ b/packages/graphrag/graphrag/index/operations/finalize_entities.py @@ -28,6 +28,9 @@ def finalize_entities( 
final_entities["id"] = final_entities["human_readable_id"].apply( lambda _x: str(uuid4()) ) + # Ensure alternative_names column exists (empty when resolution is disabled) + if "alternative_names" not in final_entities.columns: + final_entities["alternative_names"] = [[] for _ in range(len(final_entities))] return final_entities.loc[ :, ENTITIES_FINAL_COLUMNS, diff --git a/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py index 1966aaba2a..19e8c61c9b 100644 --- a/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py +++ b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py @@ -7,14 +7,12 @@ real-world entity (e.g. "Ahab" and "Captain Ahab") and unifies their titles. """ -import asyncio import logging from typing import TYPE_CHECKING import pandas as pd from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.logger.progress import progress_ticker if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -28,36 +26,35 @@ async def resolve_entities( callbacks: WorkflowCallbacks, model: "LLMCompletion", prompt: str, - batch_size: int, num_threads: int, ) -> tuple[pd.DataFrame, pd.DataFrame]: """Identify and merge duplicate entities with different surface forms. - Sends entity names in batches to the LLM, parses the response to build - a rename mapping, then applies it to both entity titles and relationship - source/target columns. + Sends all unique entity titles to the LLM in a single call, parses the + response to build a rename mapping, then applies it to entity titles and + relationship source/target columns. Each canonical entity receives an + ``alternative_names`` column listing all of its aliases. Parameters ---------- entities : pd.DataFrame - Entity DataFrame with at least a "title" column. + Entity DataFrame with at least a ``title`` column. 
relationships : pd.DataFrame - Relationship DataFrame with "source" and "target" columns. + Relationship DataFrame with ``source`` and ``target`` columns. callbacks : WorkflowCallbacks Progress callbacks. model : LLMCompletion The LLM completion model to use. prompt : str - The entity resolution prompt template (must contain {entity_list}). - batch_size : int - Maximum number of entity names per LLM batch. + The entity resolution prompt template (must contain ``{entity_list}``). num_threads : int - Concurrency limit for LLM calls. + Concurrency limit for LLM calls (reserved for future use). Returns ------- tuple[pd.DataFrame, pd.DataFrame] - Updated (entities, relationships) with unified titles. + Updated ``(entities, relationships)`` with unified titles and an + ``alternative_names`` column on entities. """ if "title" not in entities.columns: return entities, relationships @@ -70,67 +67,46 @@ async def resolve_entities( "Running LLM entity resolution on %d unique entity names...", len(titles) ) - # Build batches - batches = [ - titles[i : i + batch_size] for i in range(0, len(titles), batch_size) - ] + # Build numbered entity list for the prompt + entity_list = "\n".join(f"{i+1}. {name}" for i, name in enumerate(titles)) + formatted_prompt = prompt.format(entity_list=entity_list) - ticker = progress_ticker( - callbacks.progress, - len(batches), - description="Entity resolution batch progress: ", - ) + try: + response = await model(formatted_prompt) + raw = (response or "").strip() + except Exception: + logger.warning("Entity resolution LLM call failed, skipping resolution") + return entities, relationships - semaphore = asyncio.Semaphore(num_threads) - rename_map: dict[str, str] = {} # alias → canonical + if "NO_DUPLICATES" in raw: + logger.info("Entity resolution: no duplicates found") + return entities, relationships - async def process_batch(batch: list[str], batch_idx: int) -> None: - entity_list = "\n".join(f"{i+1}. 
{name}" for i, name in enumerate(batch)) - formatted_prompt = prompt.format(entity_list=entity_list) - - async with semaphore: - try: - response = await model(formatted_prompt) - raw = (response or "").strip() - except Exception: - logger.warning( - "Entity resolution LLM call failed for batch %d, skipping", - batch_idx + 1, - ) - ticker(1) - return - - if "NO_DUPLICATES" in raw: - logger.info(" Batch %d: no duplicates found", batch_idx + 1) - ticker(1) - return - - # Parse response lines like "3, 17" or "5, 12, 28" - for line in raw.splitlines(): - line = line.strip() - if not line or line.startswith("#") or line.startswith("Where"): - continue - parts = [p.strip() for p in line.split(",")] - indices: list[int] = [] - for p in parts: - digits = "".join(c for c in p if c.isdigit()) - if digits: - idx = int(digits) - 1 # 1-indexed → 0-indexed - if 0 <= idx < len(batch): - indices.append(idx) - if len(indices) >= 2: - canonical = batch[indices[0]] - for alias_idx in indices[1:]: - alias = batch[alias_idx] - rename_map[alias] = canonical - logger.info( - " Entity resolution: '%s' → '%s'", alias, canonical - ) - - ticker(1) - - futures = [process_batch(batch, i) for i, batch in enumerate(batches)] - await asyncio.gather(*futures) + # Parse response and build rename mapping + rename_map: dict[str, str] = {} # alias → canonical + alternatives: dict[str, set[str]] = {} # canonical → {aliases} + + for line in raw.splitlines(): + line = line.strip() + if not line or line.startswith("#") or line.startswith("Where"): + continue + parts = [p.strip() for p in line.split(",")] + indices: list[int] = [] + for p in parts: + digits = "".join(c for c in p if c.isdigit()) + if digits: + idx = int(digits) - 1 # 1-indexed → 0-indexed + if 0 <= idx < len(titles): + indices.append(idx) + if len(indices) >= 2: + canonical = titles[indices[0]] + if canonical not in alternatives: + alternatives[canonical] = set() + for alias_idx in indices[1:]: + alias = titles[alias_idx] + 
rename_map[alias] = canonical + alternatives[canonical].add(alias) + logger.info(" Entity resolution: '%s' → '%s'", alias, canonical) if not rename_map: logger.info("Entity resolution complete: no duplicates found") @@ -142,6 +118,11 @@ async def process_batch(batch: list[str], batch_idx: int) -> None: entities = entities.copy() entities["title"] = entities["title"].map(lambda t: rename_map.get(t, t)) + # Add alternative_names column + entities["alternative_names"] = entities["title"].map( + lambda t: sorted(alternatives.get(t, set())) + ) + # Apply renames to relationship source/target if not relationships.empty: relationships = relationships.copy() diff --git a/packages/graphrag/graphrag/index/update/entities.py b/packages/graphrag/graphrag/index/update/entities.py index fe9bb2347b..09badda23e 100644 --- a/packages/graphrag/graphrag/index/update/entities.py +++ b/packages/graphrag/graphrag/index/update/entities.py @@ -44,6 +44,11 @@ def _group_and_resolve_entities( delta_entities_df["human_readable_id"] = np.arange( initial_id, initial_id + len(delta_entities_df) ) + # Ensure alternative_names column exists (may be absent in older indexes) + for df in [old_entities_df, delta_entities_df]: + if "alternative_names" not in df.columns: + df["alternative_names"] = [[] for _ in range(len(df))] + # Concat A and B combined = pd.concat( [old_entities_df, delta_entities_df], ignore_index=True, copy=False @@ -60,6 +65,9 @@ def _group_and_resolve_entities( "description": lambda x: list(x.astype(str)), # Ensure str # Concatenate nd.array into a single list "text_unit_ids": lambda x: list(itertools.chain(*x.tolist())), + "alternative_names": lambda x: sorted( + set(itertools.chain(*x.tolist())) + ), "degree": "first", # todo: we could probably re-compute this with the entire new graph }) .reset_index() diff --git a/packages/graphrag/graphrag/index/workflows/__init__.py b/packages/graphrag/graphrag/index/workflows/__init__.py index 766b1c979f..6dee90c097 100644 --- 
a/packages/graphrag/graphrag/index/workflows/__init__.py +++ b/packages/graphrag/graphrag/index/workflows/__init__.py @@ -36,9 +36,6 @@ from .finalize_graph import ( run_workflow as run_finalize_graph, ) -from .resolve_entities import ( - run_workflow as run_resolve_entities, -) from .generate_text_embeddings import ( run_workflow as run_generate_text_embeddings, ) @@ -89,7 +86,6 @@ "create_final_text_units": run_create_final_text_units, "extract_graph_nlp": run_extract_graph_nlp, "extract_graph": run_extract_graph, - "resolve_entities": run_resolve_entities, "finalize_graph": run_finalize_graph, "generate_text_embeddings": run_generate_text_embeddings, "prune_graph": run_prune_graph, diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph.py b/packages/graphrag/graphrag/index/workflows/extract_graph.py index dc86b180fc..2d902fabc2 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph.py @@ -17,6 +17,9 @@ from graphrag.index.operations.extract_graph.extract_graph import ( extract_graph as extractor, ) +from graphrag.index.operations.resolve_entities.resolve_entities import ( + resolve_entities, +) from graphrag.index.operations.summarize_descriptions.summarize_descriptions import ( summarize_descriptions, ) @@ -58,6 +61,24 @@ async def run_workflow( cache_key_creator=cache_key_creator, ) + # Entity resolution model (optional) + resolution_enabled = config.entity_resolution.enabled + resolution_model = None + resolution_prompt = "" + if resolution_enabled: + resolution_model_config = config.get_completion_model_config( + config.entity_resolution.completion_model_id + ) + resolution_prompts = config.entity_resolution.resolved_prompts() + resolution_prompt = resolution_prompts.resolution_prompt + resolution_model = create_completion( + resolution_model_config, + cache=context.cache.child( + config.entity_resolution.model_instance_name + ), + 
cache_key_creator=cache_key_creator, + ) + entities, relationships, raw_entities, raw_relationships = await extract_graph( text_units=text_units, callbacks=context.callbacks, @@ -72,6 +93,10 @@ async def run_workflow( max_input_tokens=config.summarize_descriptions.max_input_tokens, summarization_prompt=summarization_prompts.summarize_prompt, summarization_num_threads=config.concurrent_requests, + resolution_enabled=resolution_enabled, + resolution_model=resolution_model, + resolution_prompt=resolution_prompt, + resolution_num_threads=config.concurrent_requests, ) await context.output_table_provider.write_dataframe("entities", entities) @@ -108,6 +133,10 @@ async def extract_graph( max_input_tokens: int, summarization_prompt: str, summarization_num_threads: int, + resolution_enabled: bool = False, + resolution_model: "LLMCompletion | None" = None, + resolution_prompt: str = "", + resolution_num_threads: int = 1, ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: """All the steps to create the base entity graph.""" # this returns a graph for each text unit, to be merged later @@ -136,10 +165,21 @@ async def extract_graph( logger.error(error_msg) raise ValueError(error_msg) - # copy these as is before any summarization + # copy these as is before any resolution or summarization raw_entities = extracted_entities.copy() raw_relationships = extracted_relationships.copy() + # Resolve duplicate entity names before grouping by title + if resolution_enabled and resolution_model is not None: + extracted_entities, extracted_relationships = await resolve_entities( + entities=extracted_entities, + relationships=extracted_relationships, + callbacks=callbacks, + model=resolution_model, + prompt=resolution_prompt, + num_threads=resolution_num_threads, + ) + entities, relationships = await get_summarized_entities_relationships( extracted_entities=extracted_entities, extracted_relationships=extracted_relationships, diff --git 
a/packages/graphrag/graphrag/index/workflows/factory.py b/packages/graphrag/graphrag/index/workflows/factory.py index 1f621079de..585ecfa8a1 100644 --- a/packages/graphrag/graphrag/index/workflows/factory.py +++ b/packages/graphrag/graphrag/index/workflows/factory.py @@ -53,7 +53,6 @@ def create_pipeline( "create_base_text_units", "create_final_documents", "extract_graph", - "resolve_entities", "finalize_graph", "extract_covariates", "create_communities", diff --git a/packages/graphrag/graphrag/index/workflows/resolve_entities.py b/packages/graphrag/graphrag/index/workflows/resolve_entities.py deleted file mode 100644 index 3ac034fd7a..0000000000 --- a/packages/graphrag/graphrag/index/workflows/resolve_entities.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing the resolve_entities workflow definition.""" - -import logging -from typing import TYPE_CHECKING - -from graphrag_llm.completion import create_completion - -from graphrag.cache.cache_key_creator import cache_key_creator -from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.data_model.data_reader import DataReader -from graphrag.index.operations.resolve_entities.resolve_entities import ( - resolve_entities, -) -from graphrag.index.typing.context import PipelineRunContext -from graphrag.index.typing.workflow import WorkflowFunctionOutput - -if TYPE_CHECKING: - from graphrag_llm.completion import LLMCompletion - -logger = logging.getLogger(__name__) - - -async def run_workflow( - config: GraphRagConfig, - context: PipelineRunContext, -) -> WorkflowFunctionOutput: - """Resolve duplicate entities with different surface forms.""" - logger.info("Workflow started: resolve_entities") - - if not config.entity_resolution.enabled: - logger.info( - "Entity resolution is disabled (entity_resolution.enabled=false), skipping" - ) - return WorkflowFunctionOutput(result={}) - - reader = 
DataReader(context.output_table_provider) - entities = await reader.entities() - relationships = await reader.relationships() - - resolution_model_config = config.get_completion_model_config( - config.entity_resolution.completion_model_id - ) - resolution_prompts = config.entity_resolution.resolved_prompts() - resolution_model = create_completion( - resolution_model_config, - cache=context.cache.child(config.entity_resolution.model_instance_name), - cache_key_creator=cache_key_creator, - ) - - resolved_entities, resolved_relationships = await resolve_entities( - entities=entities, - relationships=relationships, - callbacks=context.callbacks, - model=resolution_model, - prompt=resolution_prompts.resolution_prompt, - batch_size=config.entity_resolution.batch_size, - num_threads=config.concurrent_requests, - ) - - await context.output_table_provider.write_dataframe("entities", resolved_entities) - await context.output_table_provider.write_dataframe( - "relationships", resolved_relationships - ) - - logger.info("Workflow completed: resolve_entities") - return WorkflowFunctionOutput( - result={ - "entities": resolved_entities, - "relationships": resolved_relationships, - } - ) From 877b4a90c86c87b674860d3938495db6d2e61755 Mon Sep 17 00:00:00 2001 From: Arsen Ohanyan Date: Thu, 19 Feb 2026 21:09:52 -0800 Subject: [PATCH 4/5] Better logging --- .../index/operations/resolve_entities/resolve_entities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py index 19e8c61c9b..8f482becac 100644 --- a/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py +++ b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py @@ -74,8 +74,8 @@ async def resolve_entities( try: response = await model(formatted_prompt) raw = (response or "").strip() - except 
Exception: - logger.warning("Entity resolution LLM call failed, skipping resolution") + except Exception as e: + logger.warning("Entity resolution LLM call failed, skipping resolution: %s", e, exc_info=True) return entities, relationships if "NO_DUPLICATES" in raw: From dad1cea33b1d2475241d7435e76acf6d65f9e153 Mon Sep 17 00:00:00 2001 From: Arsen Ohanyan Date: Thu, 19 Feb 2026 21:31:59 -0800 Subject: [PATCH 5/5] Fix --- .../index/operations/resolve_entities/resolve_entities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py index 8f482becac..458f939b2f 100644 --- a/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py +++ b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py @@ -72,8 +72,8 @@ async def resolve_entities( formatted_prompt = prompt.format(entity_list=entity_list) try: - response = await model(formatted_prompt) - raw = (response or "").strip() + response = await model.completion_async(messages=formatted_prompt) + raw = (response.content or "").strip() except Exception as e: logger.warning("Entity resolution LLM call failed, skipping resolution: %s", e, exc_info=True) return entities, relationships