diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 9fc682d6d0..097e616db1 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -50,6 +50,9 @@ logger = logging.getLogger(__name__) +# ref: https://www.sqlite.org/limits.html +# The lowest default maximum (SQLITE_MAX_VARIABLE_NUMBER) is 999; use roughly half as a safety margin +_SQLITE_MAX_BIND_VARS = 500 Model = TypeVar("Model") @@ -202,6 +205,73 @@ def _query_entries( List of model instances representing the rows fetched from the table. """ + def _execute_batched_query( + self, + model_class: type[Model], + *, + batch_column: InstrumentedAttribute[Any], + batch_values: Sequence[Any], + other_conditions: Optional[list[Any]] = None, + distinct: bool = False, + join_scores: bool = False, + ) -> MutableSequence[Model]: + """ + Execute queries in batches to avoid exceeding database bind variable limits. + + SQLite and other databases have per-statement parameter limits. This method + executes separate queries for each batch of values and merges the results. + + Args: + model_class: The SQLAlchemy model class to query. + batch_column: The column to batch the IN condition on. + batch_values: The values to filter by (will be batched). + other_conditions: Additional SQLAlchemy conditions to include in each query. + distinct: Whether to return distinct rows only. + join_scores: Whether to join the scores table. + + Returns: + MutableSequence[Model]: Merged and deduplicated results from all batched queries. + """ + if other_conditions is None: + other_conditions = [] + + # If values fit in one batch, execute a single query + if len(batch_values) <= _SQLITE_MAX_BIND_VARS: + conditions = other_conditions + [batch_column.in_(batch_values)] + return self._query_entries( + model_class, + conditions=and_(*conditions) if conditions else None, + distinct=distinct, + join_scores=join_scores, + ) + + # Execute multiple separate queries and merge results + all_results: MutableSequence[Model] = [] + seen_ids: set[str] = set() + + for i in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS): + batch = batch_values[i : i + _SQLITE_MAX_BIND_VARS] + conditions = other_conditions + [batch_column.in_(batch)] + + results = self._query_entries( + model_class, + conditions=and_(*conditions) if conditions else None, + distinct=distinct, + join_scores=join_scores, + ) + + # Deduplicate by primary key (id); entries without an id are appended as-is + for result in results: + result_id = getattr(result, "id", None) + if result_id is None: + # No id attribute, just append + all_results.append(result) + elif str(result_id) not in seen_ids: + seen_ids.add(str(result_id)) + all_results.append(result) + + return all_results + @abc.abstractmethod def _insert_entry(self, entry: Base) -> None: """ @@ -363,8 +433,6 @@ def get_scores( """ conditions: list[Any] = [] - if score_ids: - conditions.append(ScoreEntry.id.in_(score_ids)) if score_type: conditions.append(ScoreEntry.score_type == score_type) if score_category: conditions.append(ScoreEntry.score_category == score_category) if sent_after: conditions.append(ScoreEntry.timestamp >= sent_after) if sent_before: conditions.append(ScoreEntry.timestamp <= sent_before) + # Handle score_ids with batched queries if needed + if score_ids: + entries = self._execute_batched_query( + ScoreEntry, + batch_column=ScoreEntry.id, + batch_values=list(score_ids), + other_conditions=conditions, + ) + return [entry.get_score() for entry in entries] + + # No score_ids specified - use regular query if not conditions: return [] - entries: Sequence[ScoreEntry] = self._query_entries(ScoreEntry, conditions=and_(*conditions)) - return [entry.get_score() 
for entry in entries] + score_entries: Sequence[ScoreEntry] = self._query_entries(ScoreEntry, conditions=and_(*conditions)) + return [entry.get_score() for entry in score_entries] def get_prompt_scores( self, @@ -532,39 +611,87 @@ def get_message_pieces( Exception: If there is an error retrieving the prompts, an exception is logged and an empty list is returned. """ - conditions = [] - if attack_id: - conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) - if role: - conditions.append(PromptMemoryEntry.role == role) - if conversation_id: - conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id)) - if prompt_ids: - prompt_ids = [str(pi) for pi in prompt_ids] - conditions.append(PromptMemoryEntry.id.in_(prompt_ids)) - if labels: - conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels)) - if prompt_metadata: - conditions.extend(self._get_message_pieces_prompt_metadata_conditions(prompt_metadata=prompt_metadata)) - if sent_after: - conditions.append(PromptMemoryEntry.timestamp >= sent_after) - if sent_before: - conditions.append(PromptMemoryEntry.timestamp <= sent_before) - if original_values: - conditions.append(PromptMemoryEntry.original_value.in_(original_values)) - if converted_values: - conditions.append(PromptMemoryEntry.converted_value.in_(converted_values)) - if data_type: - conditions.append(PromptMemoryEntry.converted_value_data_type == data_type) - if not_data_type: - conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) - if converted_value_sha256: - conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) - try: - memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( - PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True + conditions: list[Any] = [] + if attack_id: + conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) + if role: + conditions.append(PromptMemoryEntry.role == role) + if conversation_id: + conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id)) + if labels: + conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels)) + if prompt_metadata: + conditions.extend(self._get_message_pieces_prompt_metadata_conditions(prompt_metadata=prompt_metadata)) + if sent_after: + conditions.append(PromptMemoryEntry.timestamp >= sent_after) + if sent_before: + conditions.append(PromptMemoryEntry.timestamp <= sent_before) + if data_type: + conditions.append(PromptMemoryEntry.converted_value_data_type == data_type) + if not_data_type: + conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) + + try: + # Identify list parameters and whether they need batching + list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = [] + if prompt_ids: + list_params.append((PromptMemoryEntry.id, [str(pi) for pi in prompt_ids], "id")) + if original_values: + list_params.append((PromptMemoryEntry.original_value, list(original_values), "original_value")) + if converted_values: + list_params.append((PromptMemoryEntry.converted_value, list(converted_values), "converted_value")) + if converted_value_sha256: + list_params.append( + (PromptMemoryEntry.converted_value_sha256, list(converted_value_sha256), "converted_value_sha256") + ) + + # If no list params, execute simple query + if not list_params: + memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( + PromptMemoryEntry, + 
conditions=and_(*conditions) if conditions else None, + join_scores=True, + ) + message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries] + return sort_message_pieces(message_pieces=message_pieces) + + # Find which list params need batching (exceed limit) + large_params = [(col, vals, name) for col, vals, name in list_params if len(vals) > _SQLITE_MAX_BIND_VARS] + small_params = [(col, vals, name) for col, vals, name in list_params if len(vals) <= _SQLITE_MAX_BIND_VARS] + + # Add small list params to base conditions + for col, vals, _ in small_params: + conditions.append(col.in_(vals)) + + # If no large params, execute simple query + if not large_params: + memory_entries = self._query_entries( + PromptMemoryEntry, + conditions=and_(*conditions) if conditions else None, + join_scores=True, + ) + message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries] + return sort_message_pieces(message_pieces=message_pieces) + + # Batch on the first large parameter + batch_col, batch_vals, _ = large_params[0] + other_large_params = large_params[1:] + + # Execute batched query + memory_entries = self._execute_batched_query( + PromptMemoryEntry, + batch_column=batch_col, + batch_values=batch_vals, + other_conditions=conditions, + join_scores=True, ) + + # If there are additional large params, filter results in Python + for col, vals, attr_name in other_large_params: + vals_set = set(vals) + memory_entries = [e for e in memory_entries if getattr(e, attr_name, None) in vals_set] + message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries] return sort_message_pieces(message_pieces=message_pieces) except Exception as e: @@ -1238,37 +1364,73 @@ def get_attack_results( Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. 
""" - conditions: list[ColumnElement[bool]] = [] + # Handle empty list cases + if attack_result_ids is not None and len(attack_result_ids) == 0: + return [] + if objective_sha256 is not None and len(objective_sha256) == 0: + return [] - if attack_result_ids is not None: - if len(attack_result_ids) == 0: - # Empty list means no results - return [] - conditions.append(AttackResultEntry.id.in_(attack_result_ids)) + # Build non-list conditions + conditions: list[ColumnElement[bool]] = [] if conversation_id: conditions.append(AttackResultEntry.conversation_id == conversation_id) if objective: conditions.append(AttackResultEntry.objective.contains(objective)) - - if objective_sha256: - conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256)) if outcome: conditions.append(AttackResultEntry.outcome == outcome) if targeted_harm_categories: - # Use database-specific JSON query method conditions.append( self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories) ) - if labels: - # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) try: - entries: Sequence[AttackResultEntry] = self._query_entries( - AttackResultEntry, conditions=and_(*conditions) if conditions else None + # Identify list parameters and whether they need batching + list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = [] + if attack_result_ids: + list_params.append((AttackResultEntry.id, list(attack_result_ids), "id")) + if objective_sha256: + list_params.append((AttackResultEntry.objective_sha256, list(objective_sha256), "objective_sha256")) + + # If no list params, execute simple query + if not list_params: + entries: Sequence[AttackResultEntry] = self._query_entries( + AttackResultEntry, conditions=and_(*conditions) if conditions else None + ) + return [entry.get_attack_result() for entry in entries] + + # Find which list params need batching + large_params = [(col, vals, name) for col, vals, name in list_params if len(vals) > _SQLITE_MAX_BIND_VARS] + small_params = [(col, vals, name) for col, vals, name in list_params if len(vals) <= _SQLITE_MAX_BIND_VARS] + + # Add small list params to conditions + for col, vals, _ in small_params: + conditions.append(col.in_(vals)) + + # If no large params, execute simple query + if not large_params: + entries = self._query_entries(AttackResultEntry, conditions=and_(*conditions) if conditions else None) + return [entry.get_attack_result() for entry in entries] + + # Batch on the first large parameter + batch_col, batch_vals, _ = large_params[0] + other_large_params = large_params[1:] + + # Execute batched query + entries = self._execute_batched_query( + AttackResultEntry, + batch_column=batch_col, + batch_values=batch_vals, + other_conditions=conditions, ) + + # If there are additional large params, filter results in Python + for col, vals, attr_name in other_large_params: + vals_set = set(vals) + entries = [e for e in entries if getattr(e, attr_name, None) in vals_set] + return [entry.get_attack_result() for entry in entries] except Exception as e: logger.exception(f"Failed to retrieve attack results with error {e}") @@ -1426,18 +1588,12 @@ def get_scenario_results( Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. 
""" - conditions: list[ColumnElement[bool]] = [] + if scenario_result_ids is not None and len(scenario_result_ids) == 0: + return [] - if scenario_result_ids is not None: - if len(scenario_result_ids) == 0: - # Empty list means no results - return [] - conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids)) + conditions: list[ColumnElement[bool]] = [] if scenario_name: - # Normalize CLI snake_case names (e.g., "foundry" or "content_harms") - # to class names (e.g., "Foundry" or "ContentHarms") - # This allows users to query with either format normalized_name = ScenarioResult.normalize_scenario_name(scenario_name) conditions.append(ScenarioResultEntry.scenario_name.contains(normalized_name)) @@ -1454,21 +1610,28 @@ def get_scenario_results( conditions.append(ScenarioResultEntry.completion_time <= added_before) if labels: - # Use database-specific JSON query method conditions.append(self._get_scenario_result_label_condition(labels=labels)) if objective_target_endpoint: - # Use database-specific JSON query method conditions.append(self._get_scenario_result_target_endpoint_condition(endpoint=objective_target_endpoint)) if objective_target_model_name: - # Use database-specific JSON query method conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name)) try: - entries: Sequence[ScenarioResultEntry] = self._query_entries( - ScenarioResultEntry, conditions=and_(*conditions) if conditions else None - ) + # Handle scenario_result_ids with batched queries if needed + if scenario_result_ids and len(scenario_result_ids) > _SQLITE_MAX_BIND_VARS: + entries = self._execute_batched_query( + ScenarioResultEntry, + batch_column=ScenarioResultEntry.id, + batch_values=list(scenario_result_ids), + other_conditions=conditions, + ) + elif scenario_result_ids: + conditions = conditions + [ScenarioResultEntry.id.in_(list(scenario_result_ids))] + entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None) + else: + entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None) # Convert entries to ScenarioResults and populate attack_results efficiently scenario_results = [] @@ -1483,13 +1646,19 @@ def get_scenario_results( for conv_ids in conversation_ids_by_attack.values(): all_conversation_ids.extend(conv_ids) - # Query all AttackResults in a single batch if there are any + # Query all AttackResults using batched queries if needed if all_conversation_ids: - # Build condition to query multiple conversation IDs at once - attack_conditions = [AttackResultEntry.conversation_id.in_(all_conversation_ids)] - attack_entries: Sequence[AttackResultEntry] = self._query_entries( - AttackResultEntry, conditions=and_(*attack_conditions) - ) + if len(all_conversation_ids) > _SQLITE_MAX_BIND_VARS: + attack_entries = self._execute_batched_query( + AttackResultEntry, + batch_column=AttackResultEntry.conversation_id, + batch_values=all_conversation_ids, + ) + else: + attack_entries = self._query_entries( + AttackResultEntry, + conditions=AttackResultEntry.conversation_id.in_(all_conversation_ids), + ) # Build a dict for quick lookup attack_results_dict = {entry.conversation_id: entry.get_attack_result() for entry in attack_entries} diff --git a/tests/unit/memory/memory_interface/test_batching_scale.py b/tests/unit/memory/memory_interface/test_batching_scale.py new file mode 100644 index 0000000000..13c7aa24c5 --- /dev/null +++ b/tests/unit/memory/memory_interface/test_batching_scale.py @@ 
-0,0 +1,449 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for batching functionality to handle large numbers of IDs. +This addresses the scaling bug where methods like get_scores and get_message_pieces +fail when querying with many IDs due to SQLite bind variable limits. +""" + +import hashlib +import uuid +from unittest.mock import patch + +from pyrit.memory import MemoryInterface +from pyrit.memory.memory_interface import _SQLITE_MAX_BIND_VARS +from pyrit.models import MessagePiece, Score + + +def _create_message_piece(conversation_id: str | None = None, original_value: str = "test message") -> MessagePiece: + """Create a sample message piece for testing.""" + converted_value = original_value + # Compute SHA256 for converted_value so filtering by sha256 works + sha256 = hashlib.sha256(converted_value.encode("utf-8")).hexdigest() + return MessagePiece( + id=str(uuid.uuid4()), + role="user", + original_value=original_value, + converted_value=converted_value, + converted_value_sha256=sha256, + sequence=0, + conversation_id=conversation_id or str(uuid.uuid4()), + labels={"test": "label"}, + attack_identifier={"id": str(uuid.uuid4())}, + ) + + +def _create_score(message_piece_id: str) -> Score: + """Create a sample score for testing.""" + return Score( + score_value="0.5", + score_value_description="test score", + score_type="float_scale", + score_category=["test"], + score_rationale="test rationale", + score_metadata={}, + scorer_class_identifier={"__type__": "TestScorer"}, + message_piece_id=message_piece_id, + ) + + +class TestBatchingScale: + """Tests for batching when querying with many IDs.""" + + def test_get_message_pieces_with_many_prompt_ids(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with more IDs than the batch limit.""" + # Create more message pieces than the batch limit + num_pieces = _SQLITE_MAX_BIND_VARS + 100 + pieces = [_create_message_piece() for _ in range(num_pieces)] + + # Add to memory + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Query with all IDs - this should work with batching + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces, f"Expected {num_pieces} results, got {len(results)}" + + def test_get_message_pieces_with_exact_batch_size(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with exactly the batch limit.""" + num_pieces = _SQLITE_MAX_BIND_VARS + pieces = [_create_message_piece() for _ in range(num_pieces)] + + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_message_pieces_with_double_batch_size(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with double the batch limit.""" + num_pieces = _SQLITE_MAX_BIND_VARS * 2 + pieces = [_create_message_piece() for _ in range(num_pieces)] + + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_scores_with_many_score_ids(self, sqlite_instance: MemoryInterface): + """Test that get_scores works with more IDs than the batch limit.""" + # Create message pieces first (scores need to reference them) + num_scores 
= _SQLITE_MAX_BIND_VARS + 100 + pieces = [_create_message_piece() for _ in range(num_scores)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Create and add scores + scores = [_create_score(str(piece.id)) for piece in pieces] + sqlite_instance.add_scores_to_memory(scores=scores) + + # Query with all score IDs - this should work with batching + all_score_ids = [str(score.id) for score in scores] + results = sqlite_instance.get_scores(score_ids=all_score_ids) + + assert len(results) == num_scores, f"Expected {num_scores} results, got {len(results)}" + + def test_get_prompt_scores_with_many_prompt_ids(self, sqlite_instance: MemoryInterface): + """Test that get_prompt_scores works with more prompt IDs than the batch limit.""" + # Create message pieces + num_pieces = _SQLITE_MAX_BIND_VARS + 50 + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Create and add scores for half of them + num_scores = num_pieces // 2 + scores = [_create_score(str(pieces[i].id)) for i in range(num_scores)] + sqlite_instance.add_scores_to_memory(scores=scores) + + # Query with all prompt IDs - should return scores for pieces that have them + all_prompt_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_prompt_scores(prompt_ids=all_prompt_ids) + + assert len(results) == num_scores, f"Expected {num_scores} results, got {len(results)}" + + def test_get_message_pieces_batching_preserves_other_filters(self, sqlite_instance: MemoryInterface): + """Test that batching still applies other filter conditions correctly.""" + # Create pieces with different roles + num_pieces = _SQLITE_MAX_BIND_VARS + 50 + user_pieces = [_create_message_piece() for _ in range(num_pieces)] + for piece in user_pieces: + piece.role = "user" + + assistant_pieces = [_create_message_piece() for _ in range(50)] + for piece in assistant_pieces: + piece.role = "assistant" + + all_pieces = user_pieces + assistant_pieces + sqlite_instance.add_message_pieces_to_memory(message_pieces=all_pieces) + + # Query with all IDs but filter by role + all_ids = [piece.id for piece in all_pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids, role="user") + + assert len(results) == num_pieces, f"Expected {num_pieces} user pieces, got {len(results)}" + + def test_get_message_pieces_small_list_still_works(self, sqlite_instance: MemoryInterface): + """Test that small ID lists (under batch limit) still work correctly.""" + num_pieces = 10 + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_message_pieces_with_many_original_values(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with many original_values exceeding batch limit.""" + num_pieces = _SQLITE_MAX_BIND_VARS + 100 + # Create pieces with unique original values + pieces = [_create_message_piece(original_value=f"unique_value_{i}") for i in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Query with all original values + all_values = [piece.original_value for piece in pieces] + results = sqlite_instance.get_message_pieces(original_values=all_values) + + assert len(results) == num_pieces, f"Expected {num_pieces} results, got {len(results)}" + + def 
test_get_message_pieces_with_many_converted_value_sha256(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with many converted_value_sha256 exceeding batch limit.""" + num_pieces = _SQLITE_MAX_BIND_VARS + 100 + pieces = [_create_message_piece(original_value=f"unique_value_{i}") for i in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get SHA256 hashes from stored pieces + stored_pieces = sqlite_instance.get_message_pieces() + all_hashes = [piece.converted_value_sha256 for piece in stored_pieces if piece.converted_value_sha256] + + assert len(all_hashes) > _SQLITE_MAX_BIND_VARS, "Precondition: enough unique hashes to trigger batching" + results = sqlite_instance.get_message_pieces(converted_value_sha256=all_hashes) + assert len(results) == len(all_hashes) + + def test_get_message_pieces_combines_filters_correctly(self, sqlite_instance: MemoryInterface): + """Test that multiple filters can be combined (e.g., prompt_ids AND role).""" + # Create message pieces with different roles + num_pieces = 50 + user_pieces = [_create_message_piece() for _ in range(num_pieces)] + for piece in user_pieces: + piece.role = "user" + + assistant_pieces = [_create_message_piece() for _ in range(num_pieces)] + for piece in assistant_pieces: + piece.role = "assistant" + + all_pieces = user_pieces + assistant_pieces + sqlite_instance.add_message_pieces_to_memory(message_pieces=all_pieces) + + # Query with both prompt_ids AND role filter + user_ids = [piece.id for piece in user_pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=user_ids, role="user") + + # Should return only user pieces (intersection of both filters) + assert len(results) == num_pieces + assert all(r.role == "user" for r in results) + + # Query with role filter and a subset of IDs + subset_ids = user_ids[:10] + results = sqlite_instance.get_message_pieces(prompt_ids=subset_ids, role="user") + assert len(results) == 10 + + def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_instance: MemoryInterface): + """Test batching with multiple parameters exceeding batch limit simultaneously.""" + # Create enough pieces to exceed batch limit with unique values + num_pieces = _SQLITE_MAX_BIND_VARS + 200 + pieces = [_create_message_piece(original_value=f"original_value_{i}") for i in range(num_pieces)] + + # Add to memory + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get all stored pieces to extract their IDs and SHA256 hashes + stored_pieces = sqlite_instance.get_message_pieces() + assert len(stored_pieces) >= num_pieces + + # Extract multiple large parameter lists + all_ids = [piece.id for piece in stored_pieces[:num_pieces]] + all_original_values = [piece.original_value for piece in stored_pieces[:num_pieces]] + all_sha256 = [piece.converted_value_sha256 for piece in stored_pieces[:num_pieces]] + + # Query with multiple large parameters simultaneously + # This tests that ALL parameters are batched correctly, not just one + results = sqlite_instance.get_message_pieces( + prompt_ids=all_ids, + original_values=all_original_values, + converted_value_sha256=all_sha256, + ) + + # Should return all pieces that match ALL conditions (intersection) + assert len(results) == num_pieces, ( + f"Expected {num_pieces} results when filtering with multiple large parameters, got {len(results)}" + ) + + # Verify all returned pieces match all filter criteria + result_ids = {r.id for r in results} + result_original_values = {r.original_value for r in results} + result_sha256 = 
{r.converted_value_sha256 for r in results} + + assert result_ids == set(all_ids), "Returned IDs don't match filter" + assert result_original_values == set(all_original_values), "Returned original_values don't match filter" + assert result_sha256 == set(all_sha256), "Returned SHA256 hashes don't match filter" + + def test_get_message_pieces_multiple_batched_params_with_query_spy(self, sqlite_instance: MemoryInterface): + """Test that batching executes multiple separate queries and merges results correctly.""" + # Create pieces exceeding batch limit + num_pieces = _SQLITE_MAX_BIND_VARS + 100 + pieces = [_create_message_piece(original_value=f"value_{i}") for i in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get stored pieces + stored_pieces = sqlite_instance.get_message_pieces() + all_ids = [piece.id for piece in stored_pieces[:num_pieces]] + all_original_values = [piece.original_value for piece in stored_pieces[:num_pieces]] + + # Mock _query_entries to track how it's called + original_query = sqlite_instance._query_entries + call_count = 0 + + def spy_query(*args, **kwargs): + nonlocal call_count + call_count += 1 + return original_query(*args, **kwargs) + + with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query): + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids, original_values=all_original_values) + + # Should get all results despite batching + assert len(results) == num_pieces + + # With the new batching approach, multiple separate queries should be executed + # when the primary batch parameter exceeds _SQLITE_MAX_BIND_VARS + # Expected: ceil(num_pieces / _SQLITE_MAX_BIND_VARS) = 2 queries + expected_min_calls = (num_pieces + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS + assert call_count >= expected_min_calls, ( + f"Expected at least {expected_min_calls} separate queries for {num_pieces} items, " + f"but only got {call_count} calls" + ) + + def test_get_message_pieces_triple_large_params_preserves_intersection(self, sqlite_instance: MemoryInterface): + """Test that filtering with 3 large parameter lists returns correct intersection.""" + # Create a large set of pieces + total_pieces = _SQLITE_MAX_BIND_VARS + 150 + pieces = [ + _create_message_piece(conversation_id=str(uuid.uuid4()), original_value=f"content_{i}") + for i in range(total_pieces) + ] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get stored pieces + stored_pieces = sqlite_instance.get_message_pieces() + + # Create three overlapping large filter lists + # List 1: All IDs + filter_ids = [p.id for p in stored_pieces[:total_pieces]] + + # List 2: All original values + filter_original_values = [p.original_value for p in stored_pieces[:total_pieces]] + + # List 3: Subset of SHA256 hashes (to test intersection) + subset_size = _SQLITE_MAX_BIND_VARS + 50 + filter_sha256 = [p.converted_value_sha256 for p in stored_pieces[:subset_size]] + + # Query with all three large parameters + results = sqlite_instance.get_message_pieces( + prompt_ids=filter_ids, + original_values=filter_original_values, + converted_value_sha256=filter_sha256, + ) + + # Should return only the intersection (subset_size items) + assert len(results) == subset_size, f"Expected {subset_size} results from intersection, got {len(results)}" + + # Verify all results have SHA256 in the filter list + result_sha256 = {r.converted_value_sha256 for r in results} + assert result_sha256.issubset(set(filter_sha256)), "Results contain unexpected SHA256 
values" + + +class TestExecuteBatchedQuery: + """Tests for the _execute_batched_query helper method.""" + + def test_execute_batched_query_small_list_single_query(self, sqlite_instance: MemoryInterface): + """Test that small lists execute a single query.""" + # Create a small number of pieces (under batch limit) + num_pieces = 10 + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Track query calls + original_query = sqlite_instance._query_entries + call_count = 0 + + def spy_query(*args, **kwargs): + nonlocal call_count + call_count += 1 + return original_query(*args, **kwargs) + + with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query): + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + # Should be a single query for small lists + assert call_count == 1 + assert len(results) == num_pieces + + def test_execute_batched_query_large_list_multiple_queries(self, sqlite_instance: MemoryInterface): + """Test that large lists execute multiple separate queries.""" + # Create pieces exceeding batch limit + num_pieces = _SQLITE_MAX_BIND_VARS * 3 # 3 batches needed + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Track query calls + original_query = sqlite_instance._query_entries + call_count = 0 + + def spy_query(*args, **kwargs): + nonlocal call_count + call_count += 1 + return original_query(*args, **kwargs) + + with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query): + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + # Should execute 3 separate queries (one per batch) + assert call_count == 3, f"Expected 3 queries for 3 batches, got {call_count}" + assert len(results) == num_pieces + + def test_execute_batched_query_deduplicates_results(self, sqlite_instance: MemoryInterface): + """Test that batched queries properly deduplicate results.""" + # Create pieces + num_pieces = _SQLITE_MAX_BIND_VARS + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Repeat every ID so the same entities appear in more than one batch + all_ids = [piece.id for piece in pieces] + # Two batches are executed; the merged results must be deduplicated + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids + all_ids) + + assert len(results) == num_pieces + # Verify no duplicates + result_ids = [r.id for r in results] + assert len(result_ids) == len(set(result_ids)), "Results contain duplicate entries" + + def test_execute_batched_query_exact_batch_boundary(self, sqlite_instance: MemoryInterface): + """Test querying with exactly the batch limit (edge case).""" + num_pieces = _SQLITE_MAX_BIND_VARS + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Track query calls + original_query = sqlite_instance._query_entries + call_count = 0 + + def spy_query(*args, **kwargs): + nonlocal call_count + call_count += 1 + return original_query(*args, **kwargs) + + with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query): + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + # Exactly at the limit should still be a single query + assert call_count == 1, f"Expected 1 query at 
exact batch limit, got {call_count}" + assert len(results) == num_pieces + + def test_batching_with_scores_exceeds_limit(self, sqlite_instance: MemoryInterface): + """Test that get_scores handles large numbers of score IDs correctly.""" + # Create message pieces and scores exceeding the limit + num_items = _SQLITE_MAX_BIND_VARS * 2 + 50 + pieces = [_create_message_piece() for _ in range(num_items)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + scores = [_create_score(str(piece.id)) for piece in pieces] + sqlite_instance.add_scores_to_memory(scores=scores) + + # Query with all score IDs + all_score_ids = [str(score.id) for score in scores] + + # Track query calls + original_query = sqlite_instance._query_entries + call_count = 0 + + def spy_query(*args, **kwargs): + nonlocal call_count + call_count += 1 + return original_query(*args, **kwargs) + + with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query): + results = sqlite_instance.get_scores(score_ids=all_score_ids) + + # Should execute multiple queries + expected_calls = (num_items + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS + assert call_count == expected_calls, f"Expected {expected_calls} queries, got {call_count}" + assert len(results) == num_items
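
Note for reviewers: the sketch below is a minimal, SQLAlchemy-free illustration of the chunk-and-merge pattern that `_execute_batched_query` implements. `batched_fetch`, `fetch`, and `key` are hypothetical stand-ins (for `_execute_batched_query`, `_query_entries`, and the primary-key lookup, respectively) and are not part of the PyRIT API.

```python
from typing import Callable, Iterable, Sequence, TypeVar

T = TypeVar("T")

_SQLITE_MAX_BIND_VARS = 500  # mirrors the constant added in this PR


def batched_fetch(
    values: Sequence[str],
    fetch: Callable[[Sequence[str]], Iterable[T]],
    key: Callable[[T], str],
) -> list[T]:
    """Run `fetch` once per chunk of `values` and merge deduplicated results."""
    merged: list[T] = []
    seen: set[str] = set()
    for start in range(0, len(values), _SQLITE_MAX_BIND_VARS):
        # Each chunk stays within the bind-variable budget of one statement
        chunk = values[start : start + _SQLITE_MAX_BIND_VARS]
        for row in fetch(chunk):
            # Skip rows already returned by an earlier chunk
            row_key = key(row)
            if row_key not in seen:
                seen.add(row_key)
                merged.append(row)
    return merged


# Example: 1,200 IDs are served by ceil(1200 / 500) = 3 queries of sizes 500, 500, 200.
ids = [f"id-{i}" for i in range(1200)]
calls: list[int] = []
rows = batched_fetch(ids, fetch=lambda chunk: (calls.append(len(chunk)) or chunk), key=str)
assert calls == [500, 500, 200] and len(rows) == 1200
```

With distinct filter values, each row can only satisfy one chunk's IN condition, so the dedup step matters mainly when callers pass duplicate values or when the same entity id appears across batches, which is what `test_execute_batched_query_deduplicates_results` exercises above.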