diff --git a/nitin_docs/index_migrator/99_tickets.md b/nitin_docs/index_migrator/99_tickets.md new file mode 100644 index 00000000..db9c8629 --- /dev/null +++ b/nitin_docs/index_migrator/99_tickets.md @@ -0,0 +1,370 @@ +# Index Migrator Tickets + +--- + +## Milestones + +| Milestone | Theme | Stories | +|-----------|-------|---------| +| M1 | Plan and Execute Single-Index Schema Migrations | IM-01, IM-05 | +| M2 | Interactive Migration Wizard | IM-02 | +| M3 | Rename Indexes, Prefixes, and Fields | IM-06 | +| M4 | Async Execution and Batch Operations | IM-03, IM-04 | +| M5 | Validation Fixes, Integration Tests, and Documentation | IM-07, IM-08, IM-09, IM-10 | + +--- + +## Completed + +### IM-01: Plan, Execute, and Validate Document-Preserving Index Schema Migrations + +**Status:** Done | **Commit:** `a3d534b` | **Milestone:** M1 + +**Story:** As a developer with an existing Redis index, I want to generate a reviewable migration plan, execute a safe drop-and-recreate, and validate the result, so that I can add/remove fields, change vector algorithms (FLAT/HNSW/SVS-VAMANA), change distance metrics (cosine/L2/IP), quantize vectors (float32 to float16/bfloat16/int8/uint8), and tune HNSW parameters (m, ef_construction, ef_runtime, epsilon) — all without losing documents. + +**What This Delivers:** +- **Discovery**: `rvl migrate list` shows all indexes, `rvl migrate helper` explains capabilities +- **Planning**: MigrationPlanner generates a plan from a schema patch or target schema. Captures source snapshot, target schema, classifies changes as supported or blocked. Incompatible changes (dimension, storage type) are rejected at plan time. +- **Execution**: MigrationExecutor drops the index definition (not documents), re-encodes vectors if quantization is needed, and recreates the index with the merged schema. +- **Validation**: MigrationValidator confirms schema match, doc count parity, key sample existence, and functional query correctness post-migration. 
+- **Reporting**: Structured `migration_report.yaml` with per-phase timings, counts, benchmark summary, and warnings. + +**Key Files:** `redisvl/migration/planner.py`, `executor.py`, `validation.py`, `models.py` + +--- + +### IM-02: Build Migration Plans Interactively via Guided Wizard + +**Status:** Done | **Commit:** `b06e949` | **Milestone:** M2 + +**Story:** As a developer unfamiliar with YAML schema syntax, I want a menu-driven wizard that walks me through adding, removing, updating, and renaming fields with real-time validation, so that I can build a correct migration plan without reading documentation. + +**What This Delivers:** +- `rvl migrate wizard --index ` launches an interactive session +- Menus for: add field (text/tag/numeric/geo), remove field (any type, with vector warning), rename field, update field attributes (sortable, weight, no_stem, phonetic_matcher, separator, case_sensitive, index_missing, index_empty), update vector settings (algorithm, datatype, distance metric, all HNSW and SVS-VAMANA params), rename index, change prefix +- Shows current schema and previews changes before generating plan +- Outputs both `schema_patch.yaml` and `migration_plan.yaml` +- Validates choices against what's actually supported + +**Key Files:** `redisvl/migration/wizard.py` + +--- + +### IM-03: Execute Migrations Asynchronously for Large Indexes + +**Status:** Done | **Commit:** `b559215` | **Milestone:** M4 + +**Story:** As a developer with a large index (1M+ vectors) in an async codebase, I want async migration planning, execution, and validation so that my application remains responsive and I don't block the event loop during long-running migrations. 
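As a toy illustration of the responsiveness goal above (stdlib asyncio only, not RedisVL code): a long-running migration coroutine can share the event loop with other work, whereas a sync executor would stall everything until the migration finished.

```python
import asyncio

async def migrate_batches(batches):
    """Stand-in for an async executor: await I/O per batch so the
    event loop can run other coroutines between batches."""
    migrated = 0
    for batch in batches:
        await asyncio.sleep(0)  # placeholder for awaiting Redis commands
        migrated += len(batch)
    return migrated

async def heartbeat(ticks):
    """Stand-in for request handling that must not stall mid-migration."""
    beats = 0
    for _ in range(ticks):
        await asyncio.sleep(0)
        beats += 1
    return beats

async def main():
    # Both coroutines interleave on a single event loop.
    return await asyncio.gather(
        migrate_batches([[1, 2], [3, 4, 5]]),
        heartbeat(3),
    )

migrated, beats = asyncio.run(main())
```

The batch sizes and coroutine names here are invented for illustration; the point is only that each `await` yields control back to the loop.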
+ +**What This Delivers:** +- `AsyncMigrationPlanner`, `AsyncMigrationExecutor`, `AsyncMigrationValidator` classes with full feature parity +- `rvl migrate apply --async` CLI flag +- The same `MigrationPlan` model and plan file format work for both sync and async + +**Key Files:** `redisvl/migration/async_planner.py`, `async_executor.py`, `async_validation.py` + +--- + +### IM-04: Migrate Multiple Indexes in a Single Batch with Failure Isolation and Resume + +**Status:** Done | **Commit:** `61c6e80` | **Milestone:** M4 + +**Story:** As a platform operator with many indexes, I want to apply a shared schema patch to multiple indexes in one operation, choose whether to stop or continue on failure, and resume interrupted batches from a checkpoint, so that I can coordinate migrations during maintenance windows. + +**What This Delivers:** +- `BatchMigrationPlanner` generates per-index plans from a shared patch +- `BatchMigrationExecutor` runs migrations sequentially with state persistence +- Failure policies: `fail_fast` (stop on first error), `continue_on_error` (skip and continue) +- CLI: `batch-plan`, `batch-apply`, `batch-resume`, `batch-status` +- `batch_state.yaml` checkpoint file for resume capability +- `BatchReport` with per-index status and aggregate summary + +**Key Files:** `redisvl/migration/batch_planner.py`, `batch_executor.py` + +--- + +### IM-05: Optimize Document Enumeration Using FT.AGGREGATE Cursors + +**Status:** Done | **Commit:** `9561094` | **Milestone:** M1 + +**Story:** As a developer migrating a large index over a sparse keyspace, I want document enumeration to use the search index directly instead of SCAN, so that migration runs faster and only touches indexed keys. +``` +FT.AGGREGATE idx "*" + LOAD 1 @__key # Get document key + WITHCURSOR COUNT 500 # Cursor-based pagination +``` + +**What This Delivers:** +- Executor uses `FT.AGGREGATE ... 
WITHCURSOR COUNT LOAD 0` for key enumeration +- Falls back to SCAN only when `hash_indexing_failures > 0` (those docs wouldn't appear in aggregate) +- Pre-enumerates all keys before dropping index for reliable re-indexing +- CLI simplified: removed `--allow-downtime` flag (plan review is the safety mechanism) + +**Key Files:** `redisvl/migration/executor.py`, `async_executor.py` + +--- + +### IM-06: Rename Indexes, Change Key Prefixes, and Rename Fields Across Documents + +**Status:** Done | **Commit:** pending | **Milestone:** M3 + +**Story:** As a developer, I want to rename my index, change its key prefix, or rename fields in my schema, so that I can refactor naming conventions without rebuilding from scratch. + +**What This Delivers:** +- Index rename: drop old index, create new with same prefix (no document changes) +- Prefix change: `RENAME` command on every key (single-prefix indexes only, multi-prefix blocked) +- Field rename: `HSET`/`HDEL` for hash, `JSON.SET`/`JSON.DEL` for JSON, on every document +- Execution order: field renames, then key renames, then drop, then recreate +- `RenameOperations` model in migration plan +- Timing fields: `field_rename_duration_seconds`, `key_rename_duration_seconds` +- Warnings issued for expensive operations + +**Key Files:** `redisvl/migration/models.py`, `planner.py`, `executor.py`, `async_executor.py` + +**Spec:** `nitin_docs/index_migrator/30_rename_operations_spec.md` + +--- + +### IM-07: Fix HNSW Parameter Parsing, Weight Normalization, and Algorithm Case Sensitivity + +**Status:** Done | **Commit:** `ab8a017` | **Milestone:** M5 + +**Story:** As a developer, I want post-migration validation to correctly handle HNSW-specific parameters, weight normalization, and algorithm case sensitivity, so that validation doesn't produce false failures. 
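A minimal sketch of the kind of tolerant comparison these fixes require (helper names are illustrative, not the actual `redisvl/migration/utils.py` API):

```python
def weights_match(schema_weight, redis_weight) -> bool:
    """A schema may define weight as int `1` while FT.INFO reports `1.0`;
    compare numerically instead of comparing raw values."""
    return float(schema_weight) == float(redis_weight)

def algorithms_match(schema_algo: str, chosen_algo: str) -> bool:
    """Schemas store algorithms lowercase ('hnsw') while menus display
    uppercase ('HNSW'); compare case-insensitively."""
    return schema_algo.strip().lower() == chosen_algo.strip().lower()
```

With these, `weights_match(1, "1.0")` and `algorithms_match("hnsw", "HNSW")` both hold, where a naive `==` would report a false mismatch.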
+ +**What This Fixes:** +- HNSW-specific parameters (m, ef_construction) were not being parsed from `FT.INFO`, causing validation failures +- Weight int/float normalization mismatch (schema defines `1`, Redis returns `1.0`) +- Algorithm case sensitivity in wizard (schema stores `'hnsw'`, wizard compared to `'HNSW'`) + +**Key Files:** `redisvl/redis/connection.py`, `redisvl/migration/utils.py`, `redisvl/migration/wizard.py` + +--- + +### IM-08: Add Integration Tests for All Supported Migration Routes + +**Status:** Done | **Commit:** `b3d88a0` | **Milestone:** M5 + +**Story:** As a maintainer, I want integration tests covering algorithm changes, quantization, distance metrics, HNSW tuning, and combined migrations, so that regressions are caught before release. + +**What This Delivers:** +- 22 integration tests running full apply+validate against a live Redis instance +- Covers: 9 datatype routes, 4 distance metric routes, 5 HNSW tuning routes, 2 algorithm routes, 2 combined routes +- Tests require Redis 8.0+ for INT8/UINT8 datatypes +- Located in `tests/integration/test_migration_routes.py` + +--- + +### IM-09: Update Migration Documentation to Reflect Rename, Batch, and Redis 8.0 Support + +**Status:** Done | **Commit:** `d452eab` | **Milestone:** M5 + +**Story:** As a user, I want documentation that accurately reflects all supported migration operations, so that I can self-serve without guessing at capabilities. 
+ +**What This Delivers:** +- Updated `docs/concepts/index-migrations.md` to reflect prefix/field rename support +- Updated `docs/user_guide/how_to_guides/migrate-indexes.md` with Redis 8.0 requirements +- Added batch migration commands to CLI reference in `docs/user_guide/cli.ipynb` +- Removed prefix/field rename from "blocked" lists + +--- + +### IM-10: Address PR Review Feedback for Correctness and Consistency + +**Status:** Done | **Commit:** pending | **Milestone:** M5 + +**Story:** As a maintainer, I want code review issues addressed so that the migration engine is correct, consistent, and production-ready. + +**What This Fixes:** +- `merge_patch()` now applies `rename_fields` to merged schema +- `BatchState.success_count` uses correct status string (`"succeeded"`) +- CLI helper text updated to show prefix/rename as supported +- Planner docstring updated to reflect current capabilities +- `batch_plan_path` stored in state for proper resume support +- Fixed `--output` to `--plan-out` in batch migration docs +- Fixed `--indexes` docs to use comma-separated format +- Added validation to block multi-prefix migrations +- Updated migration plan YAML example to match actual model +- Added `skipped_count` property and `[SKIP]` status display + +**Key Files:** `redisvl/migration/planner.py`, `models.py`, `batch_executor.py`, `redisvl/cli/migrate.py`, `docs/user_guide/how_to_guides/migrate-indexes.md` + +--- + +## Pending / Future + +### IM-R1: Add Crash-Safe Quantization with Checkpoint Resume and Pre-Migration Snapshot + +**Status:** Done | **Commit:** `30cc6c1` | **Priority:** High + +**Story:** As a developer running vector quantization on a production index, I want the migration to be resumable if it crashes mid-quantization, so that I don't end up with a partially quantized index and no rollback path. + +**Problem:** +The current quantization flow is: enumerate keys, drop index, quantize vectors in-place, recreate index, validate. 
If the process crashes during quantization, you're left with no index, a mix of float32 and float16 vectors, and no way to recover. + +**What This Delivers:** +A four-layer reliability model: a pre-migration `BGSAVE` (run synchronously and awaited to completion) provides full disaster recovery by restoring the RDB to its pre-migration state. A checkpoint file on disk tracks which keys have been quantized, enabling resume from the exact failure point on retry. Each key conversion detects the vector dtype before converting, making it idempotent so already-converted keys are safely skipped on resume. A bounded undo buffer stores originals for only the current in-flight batch, allowing rollback of the batch that was in progress at crash time. + +**Acceptance Criteria:** +1. Pre-migration `BGSAVE` is triggered and completes before any mutations begin +2. A checkpoint file records progress as each batch of keys is quantized +3. `rvl migrate apply --resume` picks up from the last checkpoint and completes the migration +4. Each key conversion is idempotent -- running the migration twice on the same key produces the correct result +5. If a batch fails mid-write, only that batch's vectors are rolled back using the bounded undo buffer +6. A disk space estimator function calculates projected RDB snapshot size, AOF growth, and total new disk required based on doc count, vector dimensions, source/target dtype, and AOF status. The estimator runs before any mutations and prints a human-readable summary. If available disk is below 80% of the estimate, the CLI prompts for confirmation. The estimator also supports a standalone dry-run mode via `rvl migrate estimate --plan plan.yaml`. See `nitin_docs/index_migrator/40_reliability_brainstorm.md` section "Pre-Migration Disk Space Estimator" for the full specification including inputs, outputs (DiskSpaceEstimate dataclass), calculation logic, CLI output format, integration points, and edge cases. 
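The idempotence criterion hinges on being able to tell whether a stored vector was already converted. A stdlib-only sketch of the byte-length heuristic (the shipped helpers live in `redisvl/migration/reliability.py`; these bodies are illustrative, only the names come from the code):

```python
# Bytes per element for each supported vector datatype.
DTYPE_ITEMSIZE = {
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "int8": 1,
    "uint8": 1,
}

def is_already_quantized(raw: bytes, dims: int, target_dtype: str) -> bool:
    """A vector blob whose length matches dims * itemsize of the target
    dtype was already converted; resume can safely skip it."""
    return len(raw) == dims * DTYPE_ITEMSIZE[target_dtype]

def is_same_width_dtype_conversion(source_dtype: str, target_dtype: str) -> bool:
    """Same-width conversions (e.g. float16 -> bfloat16) cannot be detected
    by length alone, so they need checkpoint state rather than inspection."""
    return DTYPE_ITEMSIZE[source_dtype] == DTYPE_ITEMSIZE[target_dtype]
```

This also explains why the executor treats same-width conversions specially: for them, the length check is inconclusive and only the checkpoint file can say which keys are done.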
+ +**Alternatives Considered:** Undo log (WAL-style), new-field-then-swap (side-write), shadow index (blue-green), streaming with bounded undo buffer. See `nitin_docs/index_migrator/40_reliability_brainstorm.md` for full analysis. + +--- + +### IM-B1: Benchmark Float32 vs Float16 Quantization: Search Quality and Migration Performance at Scale + +**Status:** Planned | **Priority:** High + +**Story:** As a developer considering vector quantization to reduce memory, I want benchmarks measuring search quality degradation (precision, recall, F1) and migration performance (throughput, latency, memory savings) across realistic dataset sizes, so that I can make an informed decision about whether the memory-accuracy tradeoff is acceptable for my use case. + +**Problem:** +We tell users they can quantize float32 vectors to float16 to cut memory in half, but we don't have published data showing what they actually lose in search quality or what they can expect in migration performance at different scales. + +**What This Delivers:** +A benchmark script and published results using a real dataset (AG News with sentence-transformers embeddings) that measures two things across multiple dataset sizes (1K, 10K, 100K). For search quality: precision@K, recall@K, and F1@K comparing float32 (ground truth) vs float16 (post-migration) top-K nearest neighbor results. For migration performance: end-to-end duration, quantization throughput (vectors/second), index downtime, pre/post memory footprint, and query latency before and after (p50, p95, p99). + +**Acceptance Criteria:** +1. Benchmark runs end-to-end against a local Redis instance with a single command +2. Uses a real public dataset with real embeddings (not synthetic random vectors) +3. Reports precision@K, recall@K, and F1@K for float32 vs float16 search results +4. Reports per-query statistics (mean, p50, p95, min, max) not just aggregates +5. 
Runs at multiple dataset sizes (at minimum 1K, 10K, 100K) to show how quality and performance scale +6. Reports memory savings (index size delta in MB) and migration throughput (docs/second) +7. Reports query latency before and after migration +8. Outputs a structured JSON report that can be compared across runs + +**Note:** Benchmark script scaffolded at `tests/benchmarks/index_migrator_real_benchmark.py`. + +--- + +### IM-11: Run Old and New Indexes in Parallel for Incompatible Changes with Operator-Controlled Cutover + +**Status:** Future | **Priority:** Medium + +**Story:** As a developer changing vector dimensions or storage type, I want to run old and new indexes in parallel until I'm confident in the new one, so that I can migrate without downtime and rollback if needed. + +**Context:** +Some migrations cannot use `drop_recreate` because the stored data is incompatible (dimension changes, storage type changes, complex payload restructuring). Shadow migration creates a new index alongside the old one, copies/transforms documents, validates, then hands off cutover to the operator. + +**What This Requires:** +- Capacity estimation (can Redis hold both indexes?) +- Shadow index creation +- Document copy with optional transform +- Progress tracking with resume +- Validation gate before cutover +- Operator handoff for cutover decision +- Cleanup of old index/keys after cutover + +**Spec:** `nitin_docs/index_migrator/20_v2_iterative_shadow_spec.md` + +--- + +### IM-12: Pipeline Vector Reads During Quantization to Reduce Round Trips on Large Datasets + +**Status:** Backlog | **Priority:** Low + +**Story:** As a developer migrating large datasets, I want quantization reads to be pipelined so that migration completes faster. + +**Context:** +Current quantization implementation does O(N) round trips for reads (one `HGET` per key/field) while only pipelining writes. For large datasets this is slow. 
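The read side can be batched the same way the writes already are. A hedged sketch against a redis-py-style client (`pipeline(transaction=False)`, `hget`); the batch size and helper names are invented for illustration:

```python
from itertools import islice

def batched(keys, size):
    """Yield keys in chunks of `size`."""
    it = iter(keys)
    while chunk := list(islice(it, size)):
        yield chunk

def read_vectors_pipelined(client, keys, field, batch_size=500):
    """One network round trip per batch instead of one HGET per key."""
    results = {}
    for chunk in batched(keys, batch_size):
        # transaction=False: plain pipelining, no MULTI/EXEC needed for reads.
        with client.pipeline(transaction=False) as pipe:
            for key in chunk:
                pipe.hget(key, field)
            for key, raw in zip(chunk, pipe.execute()):
                results[key] = raw
    return results
```

For N keys this issues roughly N / batch_size round trips instead of N, which is where the speedup on large datasets comes from.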
+ +**What This Requires:** +- Pipeline all reads in a batch before processing +- Use `transaction=False` for read pipeline +- Add JSON storage support (`JSON.GET`/`JSON.SET`) for JSON indexes + +--- + +### IM-13: Wire ValidationPolicy Enforcement into Validators or Remove the Unused Model + +**Status:** Backlog | **Priority:** Low + +**Story:** As a developer, I want to skip certain validation checks (e.g., doc count) when I know they'll fail due to expected conditions. + +**Context:** +`MigrationPlan.validation` (ValidationPolicy) exists in the model but is not enforced by validators. Schema/doc-count mismatches always produce errors. + +**What This Requires:** +- Wire `ValidationPolicy.require_doc_count_match` into validators +- Add CLI flag to set policy during plan creation +- Or remove unused ValidationPolicy model + +--- + +### IM-14: Clean Up Unused Imports and Linting Across the Codebase + +**Status:** Backlog | **Priority:** Low + +**Story:** As a maintainer, I want clean linting so that CI is reliable and code quality is consistent. + +**Context:** +During development, pyflakes identified unused imports across the codebase. These were fixed in migration files but not committed for non-migration files to keep the PR focused. + +**What This Requires:** +- Fix remaining unused imports (see `nitin_docs/issues/unused_imports_cleanup.md`) +- Update `.pylintrc` to remove deprecated Python 2/3 compat options +- Consider adding `check-lint` to the main `lint` target after cleanup + +--- + +### IM-15: Use RENAMENX for Prefix Migrations to Fail Fast on Key Collisions + +**Status:** Backlog | **Priority:** Low + +**Story:** As a developer changing key prefixes, I want the migration to fail fast if target keys already exist, so I don't end up with a partially renamed keyspace. + +**Context:** +Current implementation uses `RENAME` without checking if destination key exists. 
If a target key exists, RENAME will error and the pipeline may abort, leaving a partially-renamed keyspace. + +**What This Requires:** +- Preflight check for key collisions or use `RENAMENX` +- Surface hard error rather than warning +- Consider rollback strategy + +--- + +### IM-16: Auto-Detect AOF Status for Disk Space Estimation + +**Status:** Backlog | **Priority:** Low + +**Story:** As an operator running `rvl migrate estimate`, I want the disk space estimate to automatically detect whether AOF is enabled on the target Redis instance, so that AOF growth is included in the estimate without me needing to know or pass a flag. + +**Context:** +The disk space estimator (`estimate_disk_space`) is a pure calculation that accepts `aof_enabled` as a parameter (default `False`). In CLI usage, this means AOF growth is never estimated unless the caller explicitly passes `aof_enabled=True`. The summary currently prints "not estimated (pass aof_enabled=True if AOF is on)" which is accurate but requires the operator to know their Redis config. + +**What This Requires:** +- Add `--aof-enabled` flag to `rvl migrate estimate` CLI for offline/pure-calculation use +- During `rvl migrate apply`, read `CONFIG GET appendonly` from the live Redis connection and pass the result to `estimate_disk_space` +- Handle `CONFIG GET` failures gracefully (e.g. 
ACL restrictions) by falling back to the current "not estimated" behavior + +--- + +## Summary + +| Ticket | Title | Status | +|--------|-------|--------| +| IM-01 | Plan, Execute, and Validate Document-Preserving Index Schema Migrations | Done | +| IM-02 | Build Migration Plans Interactively via Guided Wizard | Done | +| IM-03 | Execute Migrations Asynchronously for Large Indexes | Done | +| IM-04 | Migrate Multiple Indexes in a Single Batch with Failure Isolation and Resume | Done | +| IM-05 | Optimize Document Enumeration Using FT.AGGREGATE Cursors | Done | +| IM-06 | Rename Indexes, Change Key Prefixes, and Rename Fields Across Documents | Done | +| IM-07 | Fix HNSW Parameter Parsing, Weight Normalization, and Algorithm Case Sensitivity | Done | +| IM-08 | Add Integration Tests for All Supported Migration Routes | Done | +| IM-09 | Update Migration Documentation to Reflect Rename, Batch, and Redis 8.0 Support | Done | +| IM-10 | Address PR Review Feedback for Correctness and Consistency | Done | +| IM-R1 | Add Crash-Safe Quantization with Checkpoint Resume and Pre-Migration Snapshot | Done | +| IM-B1 | Benchmark Float32 vs Float16 Quantization: Search Quality and Migration Performance at Scale | Planned | +| IM-11 | Run Old and New Indexes in Parallel for Incompatible Changes with Operator-Controlled Cutover | Future | +| IM-12 | Pipeline Vector Reads During Quantization to Reduce Round Trips on Large Datasets | Backlog | +| IM-13 | Wire ValidationPolicy Enforcement into Validators or Remove the Unused Model | Backlog | +| IM-14 | Clean Up Unused Imports and Linting Across the Codebase | Backlog | +| IM-15 | Use RENAMENX for Prefix Migrations to Fail Fast on Key Collisions | Backlog | +| IM-16 | Auto-Detect AOF Status for Disk Space Estimation | Backlog | + diff --git a/redisvl/cli/migrate.py b/redisvl/cli/migrate.py index fdcd8a2f..fa56dd71 100644 --- a/redisvl/cli/migrate.py +++ b/redisvl/cli/migrate.py @@ -14,6 +14,8 @@ MigrationValidator, ) from 
redisvl.migration.utils import ( + detect_aof_enabled, + estimate_disk_space, list_indexes, load_migration_plan, load_yaml, @@ -21,6 +23,7 @@ write_migration_report, ) from redisvl.migration.wizard import MigrationWizard +from redisvl.redis.connection import RedisConnectionFactory from redisvl.utils.log import get_logger logger = get_logger("[RedisVL]") @@ -36,6 +39,7 @@ class Migrate: "\tplan Generate a migration plan for a document-preserving drop/recreate migration", "\twizard Interactively build a migration plan and schema patch", "\tapply Execute a reviewed drop/recreate migration plan (use --async for large migrations)", + "\testimate Estimate disk space required for a migration plan (dry-run, no mutations)", "\tvalidate Validate a completed migration plan against the live index", "", "Batch Commands:", @@ -91,7 +95,7 @@ def helper(self): - Changing field options (sortable, separator, weight) - Changing vector algorithm (FLAT, HNSW, SVS_VAMANA) - Changing distance metric (COSINE, L2, IP) - - Tuning algorithm parameters (M, EF_CONSTRUCTION) + - Tuning algorithm parameters (M, EF_CONSTRUCTION, EF_RUNTIME, EPSILON) - Quantizing vectors (float32 to float16/bfloat16/int8/uint8) - Changing key prefix (renames all keys) - Renaming fields (updates all documents) @@ -212,7 +216,8 @@ def apply(self): parser = argparse.ArgumentParser( usage=( "rvl migrate apply --plan " - "[--async] [--report-out ]" + "[--async] [--resume ] " + "[--report-out ]" ) ) parser.add_argument("--plan", help="Path to migration_plan.yaml", required=True) @@ -222,6 +227,12 @@ def apply(self): help="Use async executor (recommended for large migrations with quantization)", action="store_true", ) + parser.add_argument( + "--resume", + dest="checkpoint_path", + help="Path to quantization checkpoint file for crash-safe resume", + default=None, + ) parser.add_argument( "--report-out", help="Path to write migration_report.yaml", @@ -243,19 +254,62 @@ def apply(self): redis_url = create_redis_url(args) 
plan = load_migration_plan(args.plan) + # Print disk space estimate for quantization migrations + aof_enabled = False + try: + client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url) + try: + aof_enabled = detect_aof_enabled(client) + finally: + client.close() + except Exception as exc: + logger.debug("Could not detect AOF for CLI preflight estimate: %s", exc) + + disk_estimate = estimate_disk_space(plan, aof_enabled=aof_enabled) + if disk_estimate.has_quantization: + print(f"\n{disk_estimate.summary()}\n") + + checkpoint_path = args.checkpoint_path if args.use_async: report = asyncio.run( - self._apply_async(plan, redis_url, args.query_check_file) + self._apply_async( + plan, redis_url, args.query_check_file, checkpoint_path + ) ) else: - report = self._apply_sync(plan, redis_url, args.query_check_file) + report = self._apply_sync( + plan, redis_url, args.query_check_file, checkpoint_path + ) write_migration_report(report, args.report_out) if args.benchmark_out: write_benchmark_report(report, args.benchmark_out) self._print_report_summary(args.report_out, report, args.benchmark_out) - def _apply_sync(self, plan, redis_url: str, query_check_file: Optional[str]): + def estimate(self): + """Estimate disk space required for a migration plan (dry-run).""" + parser = argparse.ArgumentParser( + usage="rvl migrate estimate --plan " + ) + parser.add_argument("--plan", help="Path to migration_plan.yaml", required=True) + parser.add_argument( + "--aof-enabled", + action="store_true", + help="Include AOF growth in the disk space estimate", + ) + args = parser.parse_args(sys.argv[3:]) + + plan = load_migration_plan(args.plan) + disk_estimate = estimate_disk_space(plan, aof_enabled=args.aof_enabled) + print(disk_estimate.summary()) + + def _apply_sync( + self, + plan, + redis_url: str, + query_check_file: Optional[str], + checkpoint_path: Optional[str] = None, + ): """Execute migration synchronously.""" executor = MigrationExecutor() @@ -263,11 +317,13 @@ def 
_apply_sync(self, plan, redis_url: str, query_check_file: Optional[str]): def progress_callback(step: str, detail: Optional[str]) -> None: step_labels = { - "drop": "[1/5] Drop index", - "quantize": "[2/5] Quantize vectors", - "create": "[3/5] Create index", - "index": "[4/5] Re-indexing", - "validate": "[5/5] Validate", + "enumerate": "[1/6] Enumerate keys", + "bgsave": "[2/6] BGSAVE snapshot", + "drop": "[3/6] Drop index", + "quantize": "[4/6] Quantize vectors", + "create": "[5/6] Create index", + "index": "[6/6] Re-indexing", + "validate": "Validate", } label = step_labels.get(step, step) if detail and not detail.startswith("done"): @@ -280,12 +336,19 @@ def progress_callback(step: str, detail: Optional[str]) -> None: redis_url=redis_url, query_check_file=query_check_file, progress_callback=progress_callback, + checkpoint_path=checkpoint_path, ) self._print_apply_result(report) return report - async def _apply_async(self, plan, redis_url: str, query_check_file: Optional[str]): + async def _apply_async( + self, + plan, + redis_url: str, + query_check_file: Optional[str], + checkpoint_path: Optional[str] = None, + ): """Execute migration asynchronously (non-blocking for large quantization jobs).""" executor = AsyncMigrationExecutor() @@ -293,11 +356,13 @@ async def _apply_async(self, plan, redis_url: str, query_check_file: Optional[st def progress_callback(step: str, detail: Optional[str]) -> None: step_labels = { - "drop": "[1/5] Drop index", - "quantize": "[2/5] Quantize vectors", - "create": "[3/5] Create index", - "index": "[4/5] Re-indexing", - "validate": "[5/5] Validate", + "enumerate": "[1/6] Enumerate keys", + "bgsave": "[2/6] BGSAVE snapshot", + "drop": "[3/6] Drop index", + "quantize": "[4/6] Quantize vectors", + "create": "[5/6] Create index", + "index": "[6/6] Re-indexing", + "validate": "Validate", } label = step_labels.get(step, step) if detail and not detail.startswith("done"): @@ -310,6 +375,7 @@ def progress_callback(step: str, detail: 
Optional[str]) -> None: redis_url=redis_url, query_check_file=query_check_file, progress_callback=progress_callback, + checkpoint_path=checkpoint_path, ) self._print_apply_result(report) @@ -487,7 +553,11 @@ def batch_plan(self): args = parser.parse_args(sys.argv[3:]) redis_url = create_redis_url(args) - indexes = args.indexes.split(",") if args.indexes else None + indexes = ( + [idx.strip() for idx in args.indexes.split(",") if idx.strip()] + if args.indexes + else None + ) planner = BatchMigrationPlanner() batch_plan = planner.create_batch_plan( @@ -545,7 +615,7 @@ def batch_apply(self): If you need to preserve original vectors, backup your data first: redis-cli BGSAVE""" ) - return + exit(1) redis_url = create_redis_url(args) executor = BatchMigrationExecutor() diff --git a/redisvl/migration/__init__.py b/redisvl/migration/__init__.py index b6e3b86c..443b7961 100644 --- a/redisvl/migration/__init__.py +++ b/redisvl/migration/__init__.py @@ -13,6 +13,7 @@ BatchPlan, BatchReport, BatchState, + DiskSpaceEstimate, FieldRename, MigrationPlan, MigrationReport, @@ -25,6 +26,7 @@ __all__ = [ # Sync + "DiskSpaceEstimate", "MigrationExecutor", "MigrationPlan", "MigrationPlanner", diff --git a/redisvl/migration/async_executor.py b/redisvl/migration/async_executor.py index 13945180..e3880bd4 100644 --- a/redisvl/migration/async_executor.py +++ b/redisvl/migration/async_executor.py @@ -17,7 +17,20 @@ MigrationTimings, MigrationValidation, ) -from redisvl.migration.utils import timestamp_utc +from redisvl.migration.reliability import ( + BatchUndoBuffer, + QuantizationCheckpoint, + async_trigger_bgsave_and_wait, + is_already_quantized, + is_same_width_dtype_conversion, +) +from redisvl.migration.utils import ( + build_scan_match_patterns, + estimate_disk_space, + get_schema_field_path, + normalize_keys, + timestamp_utc, +) from redisvl.redis.utils import array_to_buffer, buffer_to_array from redisvl.types import AsyncRedisClient @@ -35,11 +48,32 @@ class 
AsyncMigrationExecutor: def __init__(self, validator: Optional[AsyncMigrationValidator] = None): self.validator = validator or AsyncMigrationValidator() + async def _detect_aof_enabled(self, client: Any) -> bool: + """Best-effort detection of whether AOF is enabled on the live Redis.""" + try: + info = await client.info("persistence") + if isinstance(info, dict) and "aof_enabled" in info: + return bool(int(info["aof_enabled"])) + except Exception: + logger.debug("Could not read Redis INFO persistence for AOF detection.") + + try: + config = await client.config_get("appendonly") + if isinstance(config, dict): + value = config.get("appendonly") + if value is not None: + return str(value).lower() in {"yes", "1", "true", "on"} + except Exception: + logger.debug("Could not read Redis CONFIG GET appendonly.") + + return False + async def _enumerate_indexed_keys( self, client: AsyncRedisClient, index_name: str, batch_size: int = 1000, + key_separator: str = ":", ) -> AsyncGenerator[str, None]: """Async version: Enumerate document keys using FT.AGGREGATE with SCAN fallback. @@ -58,13 +92,15 @@ async def _enumerate_indexed_keys( "Using SCAN for complete enumeration." ) async for key in self._enumerate_with_scan( - client, index_name, batch_size + client, index_name, batch_size, key_separator ): yield key return except Exception as e: logger.warning(f"Failed to check index info: {e}. Using SCAN fallback.") - async for key in self._enumerate_with_scan(client, index_name, batch_size): + async for key in self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ): yield key return @@ -78,7 +114,9 @@ async def _enumerate_indexed_keys( logger.warning( f"FT.AGGREGATE failed: {e}. Falling back to SCAN enumeration." 
) - async for key in self._enumerate_with_scan(client, index_name, batch_size): + async for key in self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ): yield key async def _enumerate_with_aggregate( @@ -138,6 +176,7 @@ async def _enumerate_with_scan( client: AsyncRedisClient, index_name: str, batch_size: int = 1000, + key_separator: str = ":", ) -> AsyncGenerator[str, None]: """Async version: Enumerate keys using SCAN with prefix matching.""" # Get prefix from index info @@ -157,25 +196,32 @@ async def _enumerate_with_scan( if d in (b"prefixes", "prefixes") and j + 1 < len(defn): prefixes = defn[j + 1] break - prefix = prefixes[0] if prefixes else "" - if isinstance(prefix, bytes): - prefix = prefix.decode() + normalized_prefixes = [ + p.decode() if isinstance(p, bytes) else str(p) for p in prefixes + ] except Exception as e: logger.warning(f"Failed to get prefix from index info: {e}") - prefix = "" - - cursor: int = 0 - while True: - cursor, keys = await client.scan( - cursor=cursor, - match=f"{prefix}*" if prefix else "*", - count=batch_size, - ) - for key in keys: - yield key.decode() if isinstance(key, bytes) else str(key) + normalized_prefixes = [] - if cursor == 0: - break + seen_keys: set[str] = set() + for match_pattern in build_scan_match_patterns( + normalized_prefixes, key_separator + ): + cursor: int = 0 + while True: + cursor, keys = await client.scan( + cursor=cursor, + match=match_pattern, + count=batch_size, + ) + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else str(key) + if key_str not in seen_keys: + seen_keys.add(key_str) + yield key_str + + if cursor == 0: + break async def _rename_keys( self, @@ -299,6 +345,7 @@ async def apply( redis_client: Optional[AsyncRedisClient] = None, query_check_file: Optional[str] = None, progress_callback: Optional[Callable[[str, Optional[str]], None]] = None, + checkpoint_path: Optional[str] = None, ) -> MigrationReport: """Apply a migration plan asynchronously. 
@@ -308,6 +355,8 @@ async def apply( redis_client: Optional existing async Redis client. query_check_file: Optional file with query checks. progress_callback: Optional callback(step, detail) for progress updates. + checkpoint_path: Optional path for quantization checkpoint file. + When provided, enables crash-safe resume for vector re-encoding. """ started_at = timestamp_utc() started = time.perf_counter() @@ -329,26 +378,63 @@ async def apply( report.finished_at = timestamp_utc() return report - if not await self._async_current_source_matches_snapshot( - plan.source.index_name, - plan.source.schema_snapshot, - redis_url=redis_url, - redis_client=redis_client, - ): - report.validation.errors.append( - "The current live source schema no longer matches the saved source snapshot." + # Check if we are resuming from a checkpoint (post-drop crash). + # If so, the source index may no longer exist in Redis, so we + # skip live schema validation and construct from the plan snapshot. + resuming_from_checkpoint = False + if checkpoint_path: + existing_checkpoint = QuantizationCheckpoint.load(checkpoint_path) + if existing_checkpoint is not None: + # Validate checkpoint belongs to this migration and is incomplete + if existing_checkpoint.index_name != plan.source.index_name: + logger.warning( + "Checkpoint index '%s' does not match plan index '%s', ignoring", + existing_checkpoint.index_name, + plan.source.index_name, + ) + elif existing_checkpoint.status == "completed": + logger.info( + "Checkpoint at %s is already completed, ignoring", + checkpoint_path, + ) + else: + resuming_from_checkpoint = True + logger.info( + "Checkpoint found at %s, skipping source index validation " + "(index may have been dropped before crash)", + checkpoint_path, + ) + + if not resuming_from_checkpoint: + if not await self._async_current_source_matches_snapshot( + plan.source.index_name, + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ): + 
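Both the sync and async `apply` paths above make the same three-way decision before trusting a checkpoint: ignore it if it belongs to a different index, ignore it if it already completed, otherwise resume and skip live-schema validation (the source index may have been dropped before the crash). A condensed sketch of that decision, using a hypothetical stand-in dataclass with the two fields the diff actually checks:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Checkpoint:  # hypothetical stand-in for QuantizationCheckpoint
    index_name: str
    status: str  # e.g. "in_progress" or "completed"


def should_resume(checkpoint: Optional[Checkpoint], plan_index: str) -> bool:
    """Resume only when a checkpoint exists, targets this index, and is unfinished."""
    if checkpoint is None:
        return False
    if checkpoint.index_name != plan_index:
        return False  # stale checkpoint from another migration: ignore it
    if checkpoint.status == "completed":
        return False  # previous run finished cleanly: nothing to resume
    return True
```

Logging-and-ignoring a mismatched or completed checkpoint (rather than raising) lets a leftover file from an earlier migration sit harmlessly on disk.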
report.validation.errors.append( + "The current live source schema no longer matches the saved source snapshot." + ) + report.manual_actions.append( + "Re-run `rvl migrate plan` to refresh the migration plan before applying." + ) + report.finished_at = timestamp_utc() + return report + + source_index = await AsyncSearchIndex.from_existing( + plan.source.index_name, + redis_url=redis_url, + redis_client=redis_client, ) - report.manual_actions.append( - "Re-run `rvl migrate plan` to refresh the migration plan before applying." + else: + # Source index was dropped before crash; reconstruct from snapshot + # to get a valid AsyncSearchIndex with a Redis client attached. + source_index = AsyncSearchIndex.from_dict( + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, ) - report.finished_at = timestamp_utc() - return report - source_index = await AsyncSearchIndex.from_existing( - plan.source.index_name, - redis_url=redis_url, - redis_client=redis_client, - ) target_index = AsyncSearchIndex.from_dict( plan.merged_target_schema, redis_url=redis_url, @@ -365,6 +451,7 @@ async def apply( target_info: Dict[str, Any] = {} docs_quantized = 0 keys_to_process: List[str] = [] + storage_type = plan.source.keyspace.storage_type datatype_changes = AsyncMigrationPlanner.get_vector_datatype_changes( plan.source.schema_snapshot, plan.merged_target_schema @@ -374,7 +461,24 @@ async def apply( rename_ops = plan.rename_operations has_prefix_change = bool(rename_ops.change_prefix) has_field_renames = bool(rename_ops.rename_fields) - needs_enumeration = datatype_changes or has_prefix_change or has_field_renames + needs_quantization = bool(datatype_changes) and storage_type != "json" + needs_enumeration = needs_quantization or has_prefix_change or has_field_renames + has_same_width_quantization = any( + is_same_width_dtype_conversion(change["source"], change["target"]) + for change in datatype_changes.values() + ) + + if checkpoint_path and 
has_same_width_quantization: + report.validation.errors.append( + "Crash-safe resume is not supported for same-width datatype " + "changes (float16<->bfloat16 or int8<->uint8)." + ) + report.manual_actions.append( + "Re-run without --resume for same-width vector conversions, or " + "split the migration to avoid same-width datatype changes." + ) + report.finished_at = timestamp_utc() + return report def _notify(step: str, detail: Optional[str] = None) -> None: if progress_callback: @@ -384,67 +488,140 @@ def _notify(step: str, detail: Optional[str] = None) -> None: client = source_index._redis_client if client is None: raise ValueError("Failed to get Redis client from source index") - storage_type = plan.source.keyspace.storage_type - - # STEP 1: Enumerate keys BEFORE any modifications - if needs_enumeration: - _notify("enumerate", "Enumerating indexed documents...") - enumerate_started = time.perf_counter() - keys_to_process = [ - key - async for key in self._enumerate_indexed_keys( - client, plan.source.index_name, batch_size=1000 - ) - ] - enumerate_duration = round(time.perf_counter() - enumerate_started, 3) - _notify( - "enumerate", - f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + aof_enabled = await self._detect_aof_enabled(client) + disk_estimate = estimate_disk_space(plan, aof_enabled=aof_enabled) + if disk_estimate.has_quantization: + logger.info( + "Disk space estimate: RDB ~%d bytes, AOF ~%d bytes, total ~%d bytes", + disk_estimate.rdb_snapshot_disk_bytes, + disk_estimate.aof_growth_bytes, + disk_estimate.total_new_disk_bytes, ) + report.disk_space_estimate = disk_estimate + + if resuming_from_checkpoint: + # On resume after a post-drop crash, the index no longer + # exists. Enumerate keys via SCAN using the plan prefix, + # and skip BGSAVE / field renames / drop (already done). 
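The guard above rejects `checkpoint_path` for same-width conversions (float16<->bfloat16, int8<->uint8) because the resume path's idempotency check relies on byte length: a same-width blob looks identical before and after conversion, so a resumed run could double-convert vectors. A minimal sketch of the width comparison behind `is_same_width_dtype_conversion` (the table of widths is standard; the function body is an assumption about the helper in `redisvl/migration/reliability.py`):

```python
DTYPE_WIDTH_BYTES = {
    "float64": 8,
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "int8": 1,
    "uint8": 1,
}


def is_same_width_dtype_conversion(source: str, target: str) -> bool:
    """True for conversions that preserve byte width, where a length-based
    'already converted?' check cannot distinguish old bytes from new."""
    if source == target:
        return False  # no conversion at all
    return DTYPE_WIDTH_BYTES.get(source) == DTYPE_WIDTH_BYTES.get(target)
```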
+ if needs_enumeration: + _notify("enumerate", "Enumerating documents via SCAN (resume)...") + enumerate_started = time.perf_counter() + prefixes = list(plan.source.keyspace.prefixes) + if has_prefix_change and rename_ops.change_prefix: + prefixes = [rename_ops.change_prefix] + seen_keys: set[str] = set() + for match_pattern in build_scan_match_patterns( + prefixes, plan.source.keyspace.key_separator + ): + cursor: int = 0 + while True: + cursor, scanned = await client.scan( # type: ignore[misc] + cursor=cursor, + match=match_pattern, + count=1000, + ) + for k in scanned: + key = k.decode() if isinstance(k, bytes) else str(k) + if key not in seen_keys: + seen_keys.add(key) + keys_to_process.append(key) + if cursor == 0: + break + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + ) - # STEP 2: Field renames (before dropping index) - if has_field_renames and keys_to_process: - _notify("field_rename", "Renaming fields in documents...") - field_rename_started = time.perf_counter() - for field_rename in rename_ops.rename_fields: - if storage_type == "json": - old_path = f"$.{field_rename.old_name}" - new_path = f"$.{field_rename.new_name}" - await self._rename_field_in_json( - client, - keys_to_process, - old_path, - new_path, - progress_callback=lambda done, total: _notify( - "field_rename", - f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", - ), - ) - else: - await self._rename_field_in_hash( + _notify("bgsave", "skipped (resume)") + _notify("drop", "skipped (already dropped)") + else: + # Normal (non-resume) path + # STEP 1: Enumerate keys BEFORE any modifications + if needs_enumeration: + _notify("enumerate", "Enumerating indexed documents...") + enumerate_started = time.perf_counter() + keys_to_process = [ + key + async for key in self._enumerate_indexed_keys( 
client, - keys_to_process, - field_rename.old_name, - field_rename.new_name, - progress_callback=lambda done, total: _notify( - "field_rename", - f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", - ), + plan.source.index_name, + batch_size=1000, + key_separator=plan.source.keyspace.key_separator, ) - field_rename_duration = round( - time.perf_counter() - field_rename_started, 3 - ) - _notify("field_rename", f"done ({field_rename_duration}s)") + ] + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + ) + + # BGSAVE safety net: snapshot data before mutations begin + if needs_enumeration and keys_to_process: + _notify("bgsave", "Triggering BGSAVE safety snapshot...") + try: + await async_trigger_bgsave_and_wait(client) + _notify("bgsave", "done") + except Exception as e: + logger.warning("BGSAVE safety snapshot failed: %s", e) + _notify("bgsave", f"skipped ({e})") + + # STEP 2: Field renames (before dropping index) + if has_field_renames and keys_to_process: + _notify("field_rename", "Renaming fields in documents...") + field_rename_started = time.perf_counter() + for field_rename in rename_ops.rename_fields: + if storage_type == "json": + old_path = get_schema_field_path( + plan.source.schema_snapshot, field_rename.old_name + ) + new_path = get_schema_field_path( + plan.merged_target_schema, field_rename.new_name + ) + if not old_path or not new_path or old_path == new_path: + continue + await self._rename_field_in_json( + client, + keys_to_process, + old_path, + new_path, + progress_callback=lambda done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", + ), + ) + else: + await self._rename_field_in_hash( + client, + keys_to_process, + field_rename.old_name, + field_rename.new_name, + progress_callback=lambda 
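The BGSAVE safety net above snapshots data before any mutation and degrades to a warning if the snapshot fails. A synchronous sketch of one way `trigger_bgsave_and_wait` could work, polling the `rdb_bgsave_in_progress` field of `INFO persistence` (the polling mechanism is an assumption; the helper's real implementation may use LASTSAVE or different timeouts):

```python
import time


def trigger_bgsave_and_wait(client, timeout: float = 60.0, poll: float = 0.25) -> None:
    """Kick off BGSAVE and poll INFO persistence until the background save ends.

    Raises TimeoutError if the snapshot does not complete in time.
    """
    try:
        client.bgsave()
    except Exception as e:
        # Redis reports an already-running save as an error; that still
        # gives us a snapshot in flight, so only re-raise other failures.
        if "in progress" not in str(e).lower():
            raise
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        info = client.info("persistence")
        if not int(info.get("rdb_bgsave_in_progress", 0)):
            return
        time.sleep(poll)
    raise TimeoutError("BGSAVE did not complete within the timeout")
```

Swallowing a failed snapshot (as the diff does, with a `bgsave skipped` notification) is a deliberate trade-off: the migration can still proceed, just without the extra recovery point.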
done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", + ), + ) + field_rename_duration = round( + time.perf_counter() - field_rename_started, 3 + ) + _notify("field_rename", f"done ({field_rename_duration}s)") - # STEP 3: Drop the index - _notify("drop", "Dropping index definition...") - drop_started = time.perf_counter() - await source_index.delete(drop=False) - drop_duration = round(time.perf_counter() - drop_started, 3) - _notify("drop", f"done ({drop_duration}s)") + # STEP 3: Drop the index + _notify("drop", "Dropping index definition...") + drop_started = time.perf_counter() + await source_index.delete(drop=False) + drop_duration = round(time.perf_counter() - drop_started, 3) + _notify("drop", f"done ({drop_duration}s)") # STEP 4: Key renames (after drop, before recreate) - if has_prefix_change and keys_to_process: + # On resume, key renames were already done before the crash. + if has_prefix_change and keys_to_process and not resuming_from_checkpoint: _notify("key_rename", "Renaming keys...") key_rename_started = time.perf_counter() old_prefix = plan.source.keyspace.prefixes[0] @@ -466,11 +643,15 @@ def _notify(step: str, detail: Optional[str] = None) -> None: ) # STEP 5: Re-encode vectors using pre-enumerated keys - if datatype_changes and keys_to_process: + if needs_quantization and keys_to_process: _notify("quantize", "Re-encoding vectors...") quantize_started = time.perf_counter() - # If we renamed keys, update keys_to_process to new names - if has_prefix_change and rename_ops.change_prefix: + # If we renamed keys (non-resume), update keys_to_process + if ( + has_prefix_change + and rename_ops.change_prefix + and not resuming_from_checkpoint + ): old_prefix = plan.source.keyspace.prefixes[0] new_prefix = rename_ops.change_prefix keys_to_process = [ @@ -481,6 +662,7 @@ def _notify(step: str, detail: Optional[str] = None) -> None: ) for k in keys_to_process ] + keys_to_process = 
normalize_keys(keys_to_process) docs_quantized = await self._async_quantize_vectors( source_index, datatype_changes, @@ -488,6 +670,7 @@ def _notify(step: str, detail: Optional[str] = None) -> None: progress_callback=lambda done, total: _notify( "quantize", f"{done:,}/{total:,} docs" ), + checkpoint_path=checkpoint_path, ) quantize_duration = round(time.perf_counter() - quantize_started, 3) _notify( @@ -498,6 +681,15 @@ def _notify(step: str, detail: Optional[str] = None) -> None: f"Re-encoded {docs_quantized} documents for vector quantization: " f"{datatype_changes}" ) + elif datatype_changes and storage_type == "json": + if checkpoint_path and not resuming_from_checkpoint: + checkpoint = QuantizationCheckpoint( + index_name=source_index.name, + total_keys=len(keys_to_process), + checkpoint_path=checkpoint_path, + ) + checkpoint.save() + _notify("quantize", "skipped (JSON vectors are re-indexed on recreate)") _notify("create", "Creating index with new schema...") recreate_started = time.perf_counter() @@ -608,20 +800,24 @@ def _index_progress(indexed: int, total: int, pct: float) -> None: async def _async_quantize_vectors( self, source_index: AsyncSearchIndex, - datatype_changes: Dict[str, Dict[str, str]], + datatype_changes: Dict[str, Dict[str, Any]], keys: List[str], progress_callback: Optional[Callable[[int, int], None]] = None, + checkpoint_path: Optional[str] = None, ) -> int: """Re-encode vectors in documents for datatype changes (quantization). Uses pre-enumerated keys (from _enumerate_indexed_keys) to process only the documents that were in the index, avoiding full keyspace scan. + Includes idempotent skip (already-quantized vectors), bounded undo + buffer for per-batch rollback, and optional checkpointing for resume. 
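Step 5 above rewrites the pre-enumerated key list after a prefix rename so quantization targets the new key names (and skips this on resume, when the rename already happened before the crash). The list comprehension in the diff reduces to a small helper like this sketch:

```python
from typing import List


def remap_prefix(keys: List[str], old_prefix: str, new_prefix: str) -> List[str]:
    """Point a pre-enumerated key list at post-rename names.

    Keys that never carried the old prefix are left untouched, matching
    the conditional in the diff's comprehension.
    """
    return [
        new_prefix + k[len(old_prefix):] if k.startswith(old_prefix) else k
        for k in keys
    ]
```

Doing this in memory, instead of re-enumerating, preserves the invariant that only documents that were in the index get re-encoded.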
Args: source_index: The source AsyncSearchIndex (already dropped but client available) - datatype_changes: Dict mapping field_name -> {"source": dtype, "target": dtype} + datatype_changes: Dict mapping field_name -> {"source", "target", "dims"} keys: Pre-enumerated list of document keys to process progress_callback: Optional callback(docs_done, total_docs) + checkpoint_path: Optional path for checkpoint file (enables resume) Returns: Number of documents processed @@ -632,32 +828,145 @@ async def _async_quantize_vectors( total_keys = len(keys) docs_processed = 0 + docs_quantized = 0 + skipped = 0 batch_size = 500 - for i in range(0, total_keys, batch_size): + # Load or create checkpoint for resume support + checkpoint: Optional[QuantizationCheckpoint] = None + if checkpoint_path: + checkpoint = QuantizationCheckpoint.load(checkpoint_path) + if checkpoint: + # Skip if checkpoint shows a completed migration + if checkpoint.status == "completed": + logger.info( + "Checkpoint already marked as completed for index '%s'. " + "Skipping quantization. Remove the checkpoint file to force re-run.", + checkpoint.index_name, + ) + return 0 + # Validate checkpoint matches current migration + if checkpoint.index_name != source_index.name: + raise ValueError( + f"Checkpoint index '{checkpoint.index_name}' does not match " + f"source index '{source_index.name}'. " + f"Use the correct checkpoint file or remove it to start fresh." + ) + if checkpoint.total_keys != total_keys: + if checkpoint.processed_keys: + current_keys = set(keys) + missing_processed = [ + key + for key in checkpoint.processed_keys + if key not in current_keys + ] + if missing_processed or total_keys < checkpoint.total_keys: + raise ValueError( + f"Checkpoint total_keys={checkpoint.total_keys} does not match " + f"the current key set ({total_keys}). " + "Use the correct checkpoint file or remove it to start fresh." + ) + logger.warning( + "Checkpoint total_keys=%d differs from current key set size=%d. 
" + "Proceeding because all legacy processed keys are present.", + checkpoint.total_keys, + total_keys, + ) + else: + raise ValueError( + f"Checkpoint total_keys={checkpoint.total_keys} does not match " + f"the current key set ({total_keys}). " + "Use the correct checkpoint file or remove it to start fresh." + ) + remaining = checkpoint.get_remaining_keys(keys) + logger.info( + "Resuming from checkpoint: %d/%d keys already processed", + total_keys - len(remaining), + total_keys, + ) + docs_processed = total_keys - len(remaining) + keys = remaining + total_keys_for_progress = total_keys + else: + checkpoint = QuantizationCheckpoint( + index_name=source_index.name, + total_keys=total_keys, + checkpoint_path=checkpoint_path, + ) + checkpoint.save() + total_keys_for_progress = total_keys + else: + total_keys_for_progress = total_keys + + remaining_keys = len(keys) + + for i in range(0, remaining_keys, batch_size): batch = keys[i : i + batch_size] pipe = client.pipeline() + undo = BatchUndoBuffer() keys_updated_in_batch: set[str] = set() - for key in batch: - # Read all vector fields that need conversion - for field_name, change in datatype_changes.items(): - field_data: bytes | None = await client.hget(key, field_name) # type: ignore[misc,assignment] - if field_data: - # Convert: source dtype -> array -> target dtype -> bytes + try: + for key in batch: + for field_name, change in datatype_changes.items(): + field_data: bytes | None = await client.hget(key, field_name) # type: ignore[misc,assignment] + if not field_data: + continue + + # Idempotent: skip if already converted to target dtype + dims = change.get("dims", 0) + if dims and is_already_quantized( + field_data, dims, change["source"], change["target"] + ): + skipped += 1 + continue + + undo.store(key, field_name, field_data) array = buffer_to_array(field_data, change["source"]) new_bytes = array_to_buffer(array, change["target"]) pipe.hset(key, field_name, new_bytes) # type: ignore[arg-type] 
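The loop above skips fields where `is_already_quantized(...)` is true, which is what makes a resumed run idempotent. The natural check is byte length: a vector of `dims` elements stored in the target dtype has a predictable size. A sketch under that assumption (the real helper may inspect more than length), which also shows why same-width conversions were rejected earlier, since the lengths coincide:

```python
DTYPE_SIZES = {
    "float64": 8,
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "int8": 1,
    "uint8": 1,
}


def is_already_quantized(raw: bytes, dims: int, source: str, target: str) -> bool:
    """True if the stored blob already matches the target byte length.

    Only reliable when source and target widths differ; for same-width
    pairs the lengths are equal, so we never skip (and resume is blocked
    upstream for exactly that reason).
    """
    src_len = dims * DTYPE_SIZES[source]
    tgt_len = dims * DTYPE_SIZES[target]
    if src_len == tgt_len:
        return False  # cannot distinguish converted from unconverted bytes
    return len(raw) == tgt_len
```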
keys_updated_in_batch.add(key) - if keys_updated_in_batch: - await pipe.execute() - docs_processed += len(keys_updated_in_batch) - if progress_callback: - progress_callback(docs_processed, total_keys) + if keys_updated_in_batch: + await pipe.execute() + except Exception: + logger.warning( + "Batch %d failed, rolling back %d entries", + i // batch_size, + undo.size, + ) + rollback_pipe = client.pipeline() + await undo.async_rollback(rollback_pipe) + if checkpoint: + checkpoint.save() + raise + finally: + undo.clear() + + docs_quantized += len(keys_updated_in_batch) + docs_processed += len(batch) + + if checkpoint: + # Record all keys in batch (including skipped) so they + # are not re-scanned on resume + checkpoint.record_batch(batch) + checkpoint.save() - logger.info(f"Quantized {docs_processed} documents: {datatype_changes}") - return docs_processed + if progress_callback: + progress_callback(docs_processed, total_keys_for_progress) + + if checkpoint: + checkpoint.mark_complete() + checkpoint.save() + + if skipped: + logger.info("Skipped %d already-quantized vector fields", skipped) + logger.info( + "Quantized %d documents across %d fields", + docs_quantized, + len(datatype_changes), + ) + return docs_quantized async def _async_wait_for_index_ready( self, diff --git a/redisvl/migration/async_validation.py b/redisvl/migration/async_validation.py index 7242784f..7b9691c4 100644 --- a/redisvl/migration/async_validation.py +++ b/redisvl/migration/async_validation.py @@ -121,7 +121,7 @@ async def _run_query_checks( passed=fetched is not None, details=( "Document fetched successfully" - if fetched + if fetched is not None else "Document not found" ), ) diff --git a/redisvl/migration/batch_executor.py b/redisvl/migration/batch_executor.py index 60dff9bd..68ba6362 100644 --- a/redisvl/migration/batch_executor.py +++ b/redisvl/migration/batch_executor.py @@ -170,6 +170,11 @@ def resume( """ state = self._load_state(state_path) plan_path = batch_plan_path or state.plan_path 
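The `try/except/finally` around each batch above stores original field bytes in a `BatchUndoBuffer` before mutating, and on failure replays them via a rollback pipeline, so a half-written batch never survives. A minimal synchronous sketch of the buffer's contract (the real class in `redisvl/migration/reliability.py` also has an `async_rollback`; this stub pipeline is only for illustration):

```python
class BatchUndoBuffer:
    """Holds original field bytes for one batch so a failed pipeline
    execute can be rolled back by rewriting the old values."""

    def __init__(self) -> None:
        self._entries: list = []  # (key, field, original_bytes)

    def store(self, key: str, field: str, original: bytes) -> None:
        self._entries.append((key, field, original))

    @property
    def size(self) -> int:
        return len(self._entries)

    def rollback(self, pipe) -> None:
        """Queue HSET writes of the saved values and execute them."""
        for key, field, original in self._entries:
            pipe.hset(key, field, original)
        pipe.execute()

    def clear(self) -> None:
        self._entries.clear()
```

Clearing the buffer in `finally` bounds memory to one batch (500 keys in the diff), which is what makes the undo buffer "bounded".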
+ if not plan_path or not plan_path.strip(): + raise ValueError( + "No batch plan path available. Provide batch_plan_path explicitly, " + "or ensure the checkpoint state contains a valid plan_path." + ) batch_plan = self._load_batch_plan(plan_path) # Optionally retry failed indexes diff --git a/redisvl/migration/batch_planner.py b/redisvl/migration/batch_planner.py index 00a5d9c1..33c265c4 100644 --- a/redisvl/migration/batch_planner.py +++ b/redisvl/migration/batch_planner.py @@ -149,7 +149,9 @@ def _load_indexes_from_file(self, file_path: str) -> List[str]: lines = f.readlines() return [ - line.strip() for line in lines if line.strip() and not line.startswith("#") + stripped + for line in lines + if (stripped := line.strip()) and not stripped.startswith("#") ] def _check_index_applicability( diff --git a/redisvl/migration/executor.py b/redisvl/migration/executor.py index 523129a1..327c4922 100644 --- a/redisvl/migration/executor.py +++ b/redisvl/migration/executor.py @@ -15,8 +15,20 @@ MigrationValidation, ) from redisvl.migration.planner import MigrationPlanner +from redisvl.migration.reliability import ( + BatchUndoBuffer, + QuantizationCheckpoint, + is_already_quantized, + is_same_width_dtype_conversion, + trigger_bgsave_and_wait, +) from redisvl.migration.utils import ( + build_scan_match_patterns, current_source_matches_snapshot, + detect_aof_enabled, + estimate_disk_space, + get_schema_field_path, + normalize_keys, timestamp_utc, wait_for_index_ready, ) @@ -36,6 +48,7 @@ def _enumerate_indexed_keys( client: SyncRedisClient, index_name: str, batch_size: int = 1000, + key_separator: str = ":", ) -> Generator[str, None, None]: """Enumerate document keys using FT.AGGREGATE with SCAN fallback. 
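The `batch_planner.py` hunk just below replaces a double `line.strip()` with a walrus-operator comprehension. Besides stripping each line once, this also changes behavior slightly for the better: the old code tested `startswith("#")` on the unstripped line, so an indented comment like `"  # note"` slipped through the filter. A self-contained sketch of the new form:

```python
from typing import List


def parse_index_list(lines: List[str]) -> List[str]:
    """Keep non-empty, non-comment lines from an index-list file.

    The walrus operator binds the stripped line once, so both the
    emptiness test and the comment test see the same stripped value;
    indented comments are therefore filtered out too.
    """
    return [
        stripped
        for line in lines
        if (stripped := line.strip()) and not stripped.startswith("#")
    ]
```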
@@ -48,6 +61,7 @@ def _enumerate_indexed_keys( client: Redis client index_name: Name of the index to enumerate batch_size: Number of keys per batch + key_separator: Separator between prefix and key ID Yields: Document keys as strings @@ -61,11 +75,15 @@ def _enumerate_indexed_keys( f"Index '{index_name}' has {failures} indexing failures. " "Using SCAN for complete enumeration." ) - yield from self._enumerate_with_scan(client, index_name, batch_size) + yield from self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ) return except Exception as e: logger.warning(f"Failed to check index info: {e}. Using SCAN fallback.") - yield from self._enumerate_with_scan(client, index_name, batch_size) + yield from self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ) return # Try FT.AGGREGATE enumeration @@ -75,7 +93,9 @@ def _enumerate_indexed_keys( logger.warning( f"FT.AGGREGATE failed: {e}. Falling back to SCAN enumeration." ) - yield from self._enumerate_with_scan(client, index_name, batch_size) + yield from self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ) def _enumerate_with_aggregate( self, @@ -141,6 +161,7 @@ def _enumerate_with_scan( client: SyncRedisClient, index_name: str, batch_size: int = 1000, + key_separator: str = ":", ) -> Generator[str, None, None]: """Enumerate keys using SCAN with prefix matching. 
@@ -166,28 +187,32 @@ def _enumerate_with_scan( if d in (b"prefixes", "prefixes") and j + 1 < len(defn): prefixes = defn[j + 1] break - prefix = prefixes[0] if prefixes else "" - if isinstance(prefix, bytes): - prefix = prefix.decode() + normalized_prefixes = [ + p.decode() if isinstance(p, bytes) else str(p) for p in prefixes + ] except Exception as e: logger.warning(f"Failed to get prefix from index info: {e}") - prefix = "" - - if not prefix: - logger.warning("No prefix found for index, SCAN may return unexpected keys") + normalized_prefixes = [] - cursor = 0 - while True: - cursor, keys = client.scan( # type: ignore[misc] - cursor=cursor, - match=f"{prefix}*" if prefix else "*", - count=batch_size, - ) - for key in keys: - yield key.decode() if isinstance(key, bytes) else str(key) + seen_keys: set[str] = set() + for match_pattern in build_scan_match_patterns( + normalized_prefixes, key_separator + ): + cursor = 0 + while True: + cursor, keys = client.scan( # type: ignore[misc] + cursor=cursor, + match=match_pattern, + count=batch_size, + ) + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else str(key) + if key_str not in seen_keys: + seen_keys.add(key_str) + yield key_str - if cursor == 0: - break + if cursor == 0: + break def _rename_keys( self, @@ -342,6 +367,7 @@ def apply( redis_client: Optional[Any] = None, query_check_file: Optional[str] = None, progress_callback: Optional[Callable[[str, Optional[str]], None]] = None, + checkpoint_path: Optional[str] = None, ) -> MigrationReport: """Apply a migration plan. @@ -353,6 +379,8 @@ def apply( progress_callback: Optional callback(step, detail) for progress updates. step: Current step name (e.g., "drop", "quantize", "create", "index", "validate") detail: Optional detail string (e.g., "1000/5000 docs (20%)") + checkpoint_path: Optional path for quantization checkpoint file. + When provided, enables crash-safe resume for vector re-encoding. 
""" started_at = timestamp_utc() started = time.perf_counter() @@ -374,26 +402,63 @@ def apply( report.finished_at = timestamp_utc() return report - if not current_source_matches_snapshot( - plan.source.index_name, - plan.source.schema_snapshot, - redis_url=redis_url, - redis_client=redis_client, - ): - report.validation.errors.append( - "The current live source schema no longer matches the saved source snapshot." + # Check if we are resuming from a checkpoint (post-drop crash). + # If so, the source index may no longer exist in Redis, so we + # skip live schema validation and construct from the plan snapshot. + resuming_from_checkpoint = False + if checkpoint_path: + existing_checkpoint = QuantizationCheckpoint.load(checkpoint_path) + if existing_checkpoint is not None: + # Validate checkpoint belongs to this migration and is incomplete + if existing_checkpoint.index_name != plan.source.index_name: + logger.warning( + "Checkpoint index '%s' does not match plan index '%s', ignoring", + existing_checkpoint.index_name, + plan.source.index_name, + ) + elif existing_checkpoint.status == "completed": + logger.info( + "Checkpoint at %s is already completed, ignoring", + checkpoint_path, + ) + else: + resuming_from_checkpoint = True + logger.info( + "Checkpoint found at %s, skipping source index validation " + "(index may have been dropped before crash)", + checkpoint_path, + ) + + if not resuming_from_checkpoint: + if not current_source_matches_snapshot( + plan.source.index_name, + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ): + report.validation.errors.append( + "The current live source schema no longer matches the saved source snapshot." + ) + report.manual_actions.append( + "Re-run `rvl migrate plan` to refresh the migration plan before applying." 
+ ) + report.finished_at = timestamp_utc() + return report + + source_index = SearchIndex.from_existing( + plan.source.index_name, + redis_url=redis_url, + redis_client=redis_client, ) - report.manual_actions.append( - "Re-run `rvl migrate plan` to refresh the migration plan before applying." + else: + # Source index was dropped before crash; reconstruct from snapshot + # to get a valid SearchIndex with a Redis client attached. + source_index = SearchIndex.from_dict( + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, ) - report.finished_at = timestamp_utc() - return report - source_index = SearchIndex.from_existing( - plan.source.index_name, - redis_url=redis_url, - redis_client=redis_client, - ) target_index = SearchIndex.from_dict( plan.merged_target_schema, redis_url=redis_url, @@ -410,6 +475,7 @@ def apply( target_info: Dict[str, Any] = {} docs_quantized = 0 keys_to_process: List[str] = [] + storage_type = plan.source.keyspace.storage_type # Check if we need to re-encode vectors for datatype changes datatype_changes = MigrationPlanner.get_vector_datatype_changes( @@ -420,7 +486,24 @@ def apply( rename_ops = plan.rename_operations has_prefix_change = bool(rename_ops.change_prefix) has_field_renames = bool(rename_ops.rename_fields) - needs_enumeration = datatype_changes or has_prefix_change or has_field_renames + needs_quantization = bool(datatype_changes) and storage_type != "json" + needs_enumeration = needs_quantization or has_prefix_change or has_field_renames + has_same_width_quantization = any( + is_same_width_dtype_conversion(change["source"], change["target"]) + for change in datatype_changes.values() + ) + + if checkpoint_path and has_same_width_quantization: + report.validation.errors.append( + "Crash-safe resume is not supported for same-width datatype " + "changes (float16<->bfloat16 or int8<->uint8)." 
+ ) + report.manual_actions.append( + "Re-run without --resume for same-width vector conversions, or " + "split the migration to avoid same-width datatype changes." + ) + report.finished_at = timestamp_utc() + return report def _notify(step: str, detail: Optional[str] = None) -> None: if progress_callback: @@ -428,69 +511,143 @@ def _notify(step: str, detail: Optional[str] = None) -> None: try: client = source_index._redis_client - storage_type = plan.source.keyspace.storage_type - - # STEP 1: Enumerate keys BEFORE any modifications - # Needed for: quantization, prefix change, or field renames - if needs_enumeration: - _notify("enumerate", "Enumerating indexed documents...") - enumerate_started = time.perf_counter() - keys_to_process = list( - self._enumerate_indexed_keys( - client, plan.source.index_name, batch_size=1000 - ) - ) - enumerate_duration = round(time.perf_counter() - enumerate_started, 3) - _notify( - "enumerate", - f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + aof_enabled = detect_aof_enabled(client) + disk_estimate = estimate_disk_space(plan, aof_enabled=aof_enabled) + if disk_estimate.has_quantization: + logger.info( + "Disk space estimate: RDB ~%d bytes, AOF ~%d bytes, total ~%d bytes", + disk_estimate.rdb_snapshot_disk_bytes, + disk_estimate.aof_growth_bytes, + disk_estimate.total_new_disk_bytes, ) + report.disk_space_estimate = disk_estimate + + if resuming_from_checkpoint: + # On resume after a post-drop crash, the index no longer + # exists. Enumerate keys via SCAN using the plan prefix, + # and skip BGSAVE / field renames / drop (already done). + if needs_enumeration: + _notify("enumerate", "Enumerating documents via SCAN (resume)...") + enumerate_started = time.perf_counter() + prefixes = list(plan.source.keyspace.prefixes) + # If a prefix change was part of the migration, keys + # were already renamed before the crash, so scan with + # the new prefix instead. 
+ if has_prefix_change and rename_ops.change_prefix: + prefixes = [rename_ops.change_prefix] + seen_keys: set[str] = set() + for match_pattern in build_scan_match_patterns( + prefixes, plan.source.keyspace.key_separator + ): + cursor: int = 0 + while True: + cursor, scanned = client.scan( # type: ignore[misc] + cursor=cursor, + match=match_pattern, + count=1000, + ) + for k in scanned: + key = k.decode() if isinstance(k, bytes) else str(k) + if key not in seen_keys: + seen_keys.add(key) + keys_to_process.append(key) + if cursor == 0: + break + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + ) - # STEP 2: Field renames (before dropping index, while docs are still indexed) - if has_field_renames and keys_to_process: - _notify("field_rename", "Renaming fields in documents...") - field_rename_started = time.perf_counter() - for field_rename in rename_ops.rename_fields: - if storage_type == "json": - # For JSON, use JSON paths - old_path = f"$.{field_rename.old_name}" - new_path = f"$.{field_rename.new_name}" - self._rename_field_in_json( - client, - keys_to_process, - old_path, - new_path, - progress_callback=lambda done, total: _notify( - "field_rename", - f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", - ), - ) - else: - # For HASH, use field names directly - self._rename_field_in_hash( + _notify("bgsave", "skipped (resume)") + _notify("drop", "skipped (already dropped)") + else: + # Normal (non-resume) path + # STEP 1: Enumerate keys BEFORE any modifications + # Needed for: quantization, prefix change, or field renames + if needs_enumeration: + _notify("enumerate", "Enumerating indexed documents...") + enumerate_started = time.perf_counter() + keys_to_process = list( + self._enumerate_indexed_keys( client, - keys_to_process, - field_rename.old_name, - 
field_rename.new_name, - progress_callback=lambda done, total: _notify( - "field_rename", - f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", - ), + plan.source.index_name, + batch_size=1000, + key_separator=plan.source.keyspace.key_separator, ) - field_rename_duration = round( - time.perf_counter() - field_rename_started, 3 - ) - _notify("field_rename", f"done ({field_rename_duration}s)") + ) + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + ) + + # BGSAVE safety net: snapshot data before mutations begin + if needs_enumeration and keys_to_process: + _notify("bgsave", "Triggering BGSAVE safety snapshot...") + try: + trigger_bgsave_and_wait(client) + _notify("bgsave", "done") + except Exception as e: + logger.warning("BGSAVE safety snapshot failed: %s", e) + _notify("bgsave", f"skipped ({e})") + + # STEP 2: Field renames (before dropping index) + if has_field_renames and keys_to_process: + _notify("field_rename", "Renaming fields in documents...") + field_rename_started = time.perf_counter() + for field_rename in rename_ops.rename_fields: + if storage_type == "json": + old_path = get_schema_field_path( + plan.source.schema_snapshot, field_rename.old_name + ) + new_path = get_schema_field_path( + plan.merged_target_schema, field_rename.new_name + ) + if not old_path or not new_path or old_path == new_path: + continue + self._rename_field_in_json( + client, + keys_to_process, + old_path, + new_path, + progress_callback=lambda done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", + ), + ) + else: + self._rename_field_in_hash( + client, + keys_to_process, + field_rename.old_name, + field_rename.new_name, + progress_callback=lambda done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> 
{field_rename.new_name}: {done:,}/{total:,}", + ), + ) + field_rename_duration = round( + time.perf_counter() - field_rename_started, 3 + ) + _notify("field_rename", f"done ({field_rename_duration}s)") - # STEP 3: Drop the index - _notify("drop", "Dropping index definition...") - drop_started = time.perf_counter() - source_index.delete(drop=False) - drop_duration = round(time.perf_counter() - drop_started, 3) - _notify("drop", f"done ({drop_duration}s)") + # STEP 3: Drop the index + _notify("drop", "Dropping index definition...") + drop_started = time.perf_counter() + source_index.delete(drop=False) + drop_duration = round(time.perf_counter() - drop_started, 3) + _notify("drop", f"done ({drop_duration}s)") # STEP 4: Key renames (after drop, before recreate) - if has_prefix_change and keys_to_process: + # On resume, key renames were already done before the crash. + if has_prefix_change and keys_to_process and not resuming_from_checkpoint: _notify("key_rename", "Renaming keys...") key_rename_started = time.perf_counter() old_prefix = plan.source.keyspace.prefixes[0] @@ -512,11 +669,15 @@ def _notify(step: str, detail: Optional[str] = None) -> None: ) # STEP 5: Re-encode vectors using pre-enumerated keys - if datatype_changes and keys_to_process: + if needs_quantization and keys_to_process: _notify("quantize", "Re-encoding vectors...") quantize_started = time.perf_counter() - # If we renamed keys, update keys_to_process to new names - if has_prefix_change and rename_ops.change_prefix: + # If we renamed keys (non-resume), update keys_to_process + if ( + has_prefix_change + and rename_ops.change_prefix + and not resuming_from_checkpoint + ): old_prefix = plan.source.keyspace.prefixes[0] new_prefix = rename_ops.change_prefix keys_to_process = [ @@ -527,6 +688,7 @@ def _notify(step: str, detail: Optional[str] = None) -> None: ) for k in keys_to_process ] + keys_to_process = normalize_keys(keys_to_process) docs_quantized = self._quantize_vectors( source_index, 
datatype_changes, @@ -534,6 +696,7 @@ def _notify(step: str, detail: Optional[str] = None) -> None: progress_callback=lambda done, total: _notify( "quantize", f"{done:,}/{total:,} docs" ), + checkpoint_path=checkpoint_path, ) quantize_duration = round(time.perf_counter() - quantize_started, 3) _notify( @@ -544,6 +707,15 @@ def _notify(step: str, detail: Optional[str] = None) -> None: f"Re-encoded {docs_quantized} documents for vector quantization: " f"{datatype_changes}" ) + elif datatype_changes and storage_type == "json": + if checkpoint_path and not resuming_from_checkpoint: + checkpoint = QuantizationCheckpoint( + index_name=source_index.name, + total_keys=len(keys_to_process), + checkpoint_path=checkpoint_path, + ) + checkpoint.save() + _notify("quantize", "skipped (JSON vectors are re-indexed on recreate)") _notify("create", "Creating index with new schema...") recreate_started = time.perf_counter() @@ -652,20 +824,24 @@ def _index_progress(indexed: int, total: int, pct: float) -> None: def _quantize_vectors( self, source_index: SearchIndex, - datatype_changes: Dict[str, Dict[str, str]], + datatype_changes: Dict[str, Dict[str, Any]], keys: List[str], progress_callback: Optional[Callable[[int, int], None]] = None, + checkpoint_path: Optional[str] = None, ) -> int: """Re-encode vectors in documents for datatype changes (quantization). Uses pre-enumerated keys (from _enumerate_indexed_keys) to process only the documents that were in the index, avoiding full keyspace scan. + Includes idempotent skip (already-quantized vectors), bounded undo + buffer for per-batch rollback, and optional checkpointing for resume. 
Args: source_index: The source SearchIndex (already dropped but client available) - datatype_changes: Dict mapping field_name -> {"source": dtype, "target": dtype} + datatype_changes: Dict mapping field_name -> {"source", "target", "dims"} keys: Pre-enumerated list of document keys to process progress_callback: Optional callback(docs_done, total_docs) + checkpoint_path: Optional path for checkpoint file (enables resume) Returns: Number of documents processed @@ -673,32 +849,145 @@ def _quantize_vectors( client = source_index._redis_client total_keys = len(keys) docs_processed = 0 + docs_quantized = 0 + skipped = 0 batch_size = 500 - for i in range(0, total_keys, batch_size): + # Load or create checkpoint for resume support + checkpoint: Optional[QuantizationCheckpoint] = None + if checkpoint_path: + checkpoint = QuantizationCheckpoint.load(checkpoint_path) + if checkpoint: + # Skip if checkpoint shows a completed migration + if checkpoint.status == "completed": + logger.info( + "Checkpoint already marked as completed for index '%s'. " + "Skipping quantization. Remove the checkpoint file to force re-run.", + checkpoint.index_name, + ) + return 0 + # Validate checkpoint matches current migration + if checkpoint.index_name != source_index.name: + raise ValueError( + f"Checkpoint index '{checkpoint.index_name}' does not match " + f"source index '{source_index.name}'. " + f"Use the correct checkpoint file or remove it to start fresh." + ) + if checkpoint.total_keys != total_keys: + if checkpoint.processed_keys: + current_keys = set(keys) + missing_processed = [ + key + for key in checkpoint.processed_keys + if key not in current_keys + ] + if missing_processed or total_keys < checkpoint.total_keys: + raise ValueError( + f"Checkpoint total_keys={checkpoint.total_keys} does not match " + f"the current key set ({total_keys}). " + "Use the correct checkpoint file or remove it to start fresh." 
+ ) + logger.warning( + "Checkpoint total_keys=%d differs from current key set size=%d. " + "Proceeding because all legacy processed keys are present.", + checkpoint.total_keys, + total_keys, + ) + else: + raise ValueError( + f"Checkpoint total_keys={checkpoint.total_keys} does not match " + f"the current key set ({total_keys}). " + "Use the correct checkpoint file or remove it to start fresh." + ) + remaining = checkpoint.get_remaining_keys(keys) + logger.info( + "Resuming from checkpoint: %d/%d keys already processed", + total_keys - len(remaining), + total_keys, + ) + docs_processed = total_keys - len(remaining) + keys = remaining + total_keys_for_progress = total_keys + else: + checkpoint = QuantizationCheckpoint( + index_name=source_index.name, + total_keys=total_keys, + checkpoint_path=checkpoint_path, + ) + checkpoint.save() + total_keys_for_progress = total_keys + else: + total_keys_for_progress = total_keys + + remaining_keys = len(keys) + + for i in range(0, remaining_keys, batch_size): batch = keys[i : i + batch_size] pipe = client.pipeline() - keys_updated_in_batch = set() + undo = BatchUndoBuffer() + keys_updated_in_batch: set[str] = set() - for key in batch: - # Read all vector fields that need conversion - for field_name, change in datatype_changes.items(): - field_data: bytes | None = client.hget(key, field_name) # type: ignore[misc,assignment] - if field_data: - # Convert: source dtype -> array -> target dtype -> bytes + try: + for key in batch: + for field_name, change in datatype_changes.items(): + field_data: bytes | None = client.hget(key, field_name) # type: ignore[misc,assignment] + if not field_data: + continue + + # Idempotent: skip if already converted to target dtype + dims = change.get("dims", 0) + if dims and is_already_quantized( + field_data, dims, change["source"], change["target"] + ): + skipped += 1 + continue + + undo.store(key, field_name, field_data) array = buffer_to_array(field_data, change["source"]) new_bytes = 
array_to_buffer(array, change["target"]) pipe.hset(key, field_name, new_bytes) # type: ignore[arg-type] keys_updated_in_batch.add(key) - if keys_updated_in_batch: - pipe.execute() - docs_processed += len(keys_updated_in_batch) - if progress_callback: - progress_callback(docs_processed, total_keys) + if keys_updated_in_batch: + pipe.execute() + except Exception: + logger.warning( + "Batch %d failed, rolling back %d entries", + i // batch_size, + undo.size, + ) + rollback_pipe = client.pipeline() + undo.rollback(rollback_pipe) + if checkpoint: + checkpoint.save() + raise + finally: + undo.clear() + + docs_quantized += len(keys_updated_in_batch) + docs_processed += len(batch) + + if checkpoint: + # Record all keys in batch (including skipped) so they + # are not re-scanned on resume + checkpoint.record_batch(batch) + checkpoint.save() - logger.info(f"Quantized {docs_processed} documents: {datatype_changes}") - return docs_processed + if progress_callback: + progress_callback(docs_processed, total_keys_for_progress) + + if checkpoint: + checkpoint.mark_complete() + checkpoint.save() + + if skipped: + logger.info("Skipped %d already-quantized vector fields", skipped) + logger.info( + "Quantized %d documents across %d fields", + docs_quantized, + len(datatype_changes), + ) + return docs_quantized def _build_benchmark_summary( self, diff --git a/redisvl/migration/models.py b/redisvl/migration/models.py index 9d84044c..b03f0398 100644 --- a/redisvl/migration/models.py +++ b/redisvl/migration/models.py @@ -139,10 +139,135 @@ class MigrationReport(BaseModel): benchmark_summary: MigrationBenchmarkSummary = Field( default_factory=MigrationBenchmarkSummary ) + disk_space_estimate: Optional["DiskSpaceEstimate"] = None warnings: List[str] = Field(default_factory=list) manual_actions: List[str] = Field(default_factory=list) +# ----------------------------------------------------------------------------- +# Disk Space Estimation +# 
-----------------------------------------------------------------------------
+
+# Bytes per element for each vector datatype
+DTYPE_BYTES: Dict[str, int] = {
+    "float64": 8,
+    "float32": 4,
+    "float16": 2,
+    "bfloat16": 2,
+    "int8": 1,
+    "uint8": 1,
+}
+
+# AOF protocol overhead per HSET command (RESP framing)
+AOF_HSET_OVERHEAD_BYTES = 114
+# JSON.SET has slightly larger RESP framing
+AOF_JSON_SET_OVERHEAD_BYTES = 140
+# RDB compression ratio for pseudo-random vector data (compresses poorly)
+RDB_COMPRESSION_RATIO = 0.95
+
+
+class VectorFieldEstimate(BaseModel):
+    """Per-field disk space breakdown for a single vector field."""
+
+    field_name: str
+    dims: int
+    source_dtype: str
+    target_dtype: str
+    source_bytes_per_doc: int
+    target_bytes_per_doc: int
+
+
+class DiskSpaceEstimate(BaseModel):
+    """Pre-migration estimate of disk and memory costs.
+
+    Produced by estimate_disk_space() as a pure calculation from the migration
+    plan. No Redis mutations are performed.
+    """
+
+    # Index metadata
+    index_name: str
+    doc_count: int
+    storage_type: str = "hash"
+
+    # Per-field breakdowns
+    vector_fields: List[VectorFieldEstimate] = Field(default_factory=list)
+
+    # Aggregate vector data sizes
+    total_source_vector_bytes: int = 0
+    total_target_vector_bytes: int = 0
+
+    # RDB snapshot cost (BGSAVE before migration)
+    rdb_snapshot_disk_bytes: int = 0
+    rdb_cow_memory_if_concurrent_bytes: int = 0
+
+    # AOF growth cost (only if aof_enabled is True)
+    aof_enabled: bool = False
+    aof_growth_bytes: int = 0
+
+    # Totals
+    total_new_disk_bytes: int = 0
+    memory_savings_after_bytes: int = 0
+
+    @property
+    def has_quantization(self) -> bool:
+        return len(self.vector_fields) > 0
+
+    def summary(self) -> str:
+        """Human-readable summary for CLI output."""
+        if not self.has_quantization:
+            return "No vector quantization in this migration. No additional disk space required."
+
+        lines = [
+            "Pre-migration disk space estimate:",
+            f"  Index: {self.index_name} ({self.doc_count:,} documents)",
+        ]
+        for vf in self.vector_fields:
+            lines.append(
+                f"  Vector field '{vf.field_name}': {vf.dims} dims, "
+                f"{vf.source_dtype} -> {vf.target_dtype}"
+            )
+
+        lines.append("")
+        lines.append(
+            f"  RDB snapshot (BGSAVE): ~{_format_bytes(self.rdb_snapshot_disk_bytes)}"
+        )
+        if self.aof_enabled:
+            lines.append(
+                f"  AOF growth (appendonly=yes): ~{_format_bytes(self.aof_growth_bytes)}"
+            )
+        else:
+            lines.append(
+                "  AOF growth: not estimated (pass aof_enabled=True if AOF is on)"
+            )
+        lines.append(
+            f"  Total new disk required: ~{_format_bytes(self.total_new_disk_bytes)}"
+        )
+        lines.append("")
+        lines.append(
+            f"  Post-migration memory savings: ~{_format_bytes(self.memory_savings_after_bytes)} "
+            f"({self._savings_pct()}% reduction)"
+        )
+        return "\n".join(lines)
+
+    def _savings_pct(self) -> int:
+        if self.total_source_vector_bytes == 0:
+            return 0
+        return round(
+            100 * self.memory_savings_after_bytes / self.total_source_vector_bytes
+        )
+
+
+def _format_bytes(n: int) -> str:
+    """Format byte count as human-readable string."""
+    if n >= 1_073_741_824:
+        return f"{n / 1_073_741_824:.2f} GB"
+    if n >= 1_048_576:
+        return f"{n / 1_048_576:.1f} MB"
+    if n >= 1024:
+        return f"{n / 1024:.1f} KB"
+    return f"{n} bytes"
+
+
 # -----------------------------------------------------------------------------
 # Batch Migration Models
 # -----------------------------------------------------------------------------
diff --git a/redisvl/migration/planner.py b/redisvl/migration/planner.py
index 85bfe511..bdd300a7 100644
--- a/redisvl/migration/planner.py
+++ b/redisvl/migration/planner.py
@@ -425,11 +425,6 @@ def classify_diff(
         has_field_renames = (
             rename_operations and len(rename_operations.rename_fields) > 0
         )
-        renamed_field_names = set()
-        if has_field_renames and rename_operations:
-            renamed_field_names = {
-                fr.old_name for fr in rename_operations.rename_fields
-            }

         for index_key, target_value in changes.index.items():
             source_value = source_dict["index"].get(index_key)
@@ -602,13 +597,17 @@ def _classify_vector_field_change(
     @staticmethod
     def get_vector_datatype_changes(
         source_schema: Dict[str, Any], target_schema: Dict[str, Any]
-    ) -> Dict[str, Dict[str, str]]:
+    ) -> Dict[str, Dict[str, Any]]:
         """Identify vector fields that need datatype conversion (quantization).

         Returns:
-            Dict mapping field_name -> {"source": source_dtype, "target": target_dtype}
+            Dict mapping field_name -> {
+                "source": source_dtype,
+                "target": target_dtype,
+                "dims": int  # vector dimensions for idempotent detection
+            }
         """
-        changes: Dict[str, Dict[str, str]] = {}
+        changes: Dict[str, Dict[str, Any]] = {}
         source_fields = {f["name"]: f for f in source_schema.get("fields", [])}
         target_fields = {f["name"]: f for f in target_schema.get("fields", [])}
@@ -621,9 +620,14 @@ def get_vector_datatype_changes(
             source_dtype = source_field.get("attrs", {}).get("datatype", "float32")
             target_dtype = target_field.get("attrs", {}).get("datatype", "float32")
+            dims = source_field.get("attrs", {}).get("dims", 0)

             if source_dtype != target_dtype:
-                changes[name] = {"source": source_dtype, "target": target_dtype}
+                changes[name] = {
+                    "source": source_dtype,
+                    "target": target_dtype,
+                    "dims": dims,
+                }

         return changes
diff --git a/redisvl/migration/reliability.py b/redisvl/migration/reliability.py
new file mode 100644
index 00000000..1037c9f6
--- /dev/null
+++ b/redisvl/migration/reliability.py
@@ -0,0 +1,328 @@
+"""Crash-safe quantization utilities for index migration.
+
+Provides idempotent dtype detection, checkpointing, BGSAVE safety,
+and bounded undo buffering for reliable vector re-encoding.
+"""
+
+import asyncio
+import logging
+import os
+import tempfile
+import time
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import yaml
+from pydantic import BaseModel, Field
+
+from redisvl.migration.models import DTYPE_BYTES
+
+logger = logging.getLogger(__name__)
+
+# Dtypes that share byte widths and are functionally interchangeable
+# for idempotent detection purposes (same byte length per element).
+_DTYPE_FAMILY: Dict[str, str] = {
+    "float64": "8byte",
+    "float32": "4byte",
+    "float16": "2byte",
+    "bfloat16": "2byte",
+    "int8": "1byte",
+    "uint8": "1byte",
+}
+
+
+def is_same_width_dtype_conversion(source_dtype: str, target_dtype: str) -> bool:
+    """Return True when two dtypes share byte width but differ in encoding."""
+    if source_dtype == target_dtype:
+        return False
+    return _DTYPE_FAMILY.get(source_dtype) == _DTYPE_FAMILY.get(target_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Idempotent Dtype Detection
+# ---------------------------------------------------------------------------
+
+
+def detect_vector_dtype(data: bytes, expected_dims: int) -> Optional[str]:
+    """Inspect raw vector bytes and infer the storage dtype.
+
+    Uses byte length and expected dimensions to determine which dtype
+    the vector is currently stored as. Returns the canonical representative
+    for each byte-width family (float16 for 2-byte, int8 for 1-byte),
+    since dtypes within a family cannot be distinguished by length alone.
+
+    Args:
+        data: Raw vector bytes from Redis.
+        expected_dims: Number of dimensions expected for this vector field.
+
+    Returns:
+        Detected dtype string (e.g. "float32", "float16", "int8") or None
+        if the size does not match any known dtype.
+    """
+    if not data or expected_dims <= 0:
+        return None
+
+    nbytes = len(data)
+
+    # Check each dtype in decreasing element size to avoid ambiguity.
+    # Only canonical representatives are checked (float16 covers bfloat16,
+    # int8 covers uint8) since they share byte widths.
+    for dtype in ("float64", "float32", "float16", "int8"):
+        if nbytes == expected_dims * DTYPE_BYTES[dtype]:
+            return dtype
+
+    return None
+
+
+def is_already_quantized(
+    data: bytes,
+    expected_dims: int,
+    source_dtype: str,
+    target_dtype: str,
+) -> bool:
+    """Check whether a vector has already been converted to the target dtype.
+
+    Uses byte-width families to handle ambiguous dtypes. For example,
+    if source is float32 and target is float16, a vector detected as
+    2-bytes-per-element is considered already quantized (the byte width
+    shrank from 4 to 2, so conversion already happened).
+
+    However, same-width conversions (e.g. float16 -> bfloat16 or
+    int8 -> uint8) are NOT skipped because the encoding semantics
+    differ even though the byte length is identical. We cannot
+    distinguish these by length, so we must always re-encode.
+
+    Args:
+        data: Raw vector bytes.
+        expected_dims: Number of dimensions.
+        source_dtype: The dtype the vector was originally stored as.
+        target_dtype: The dtype we want to convert to.
+
+    Returns:
+        True if the vector already matches the target dtype (skip conversion).
+    """
+    detected = detect_vector_dtype(data, expected_dims)
+    if detected is None:
+        return False
+
+    detected_family = _DTYPE_FAMILY.get(detected)
+    target_family = _DTYPE_FAMILY.get(target_dtype)
+    source_family = _DTYPE_FAMILY.get(source_dtype)
+
+    # If detected byte-width matches target family, the vector looks converted.
+    # But if source and target share the same byte-width family (e.g.
+    # float16 -> bfloat16), we cannot tell whether conversion happened,
+    # so we must NOT skip -- always re-encode for same-width migrations.
+    if source_family == target_family:
+        return False
+
+    return detected_family == target_family
+
+
+# ---------------------------------------------------------------------------
+# Quantization Checkpoint
+# ---------------------------------------------------------------------------
+
+
+class QuantizationCheckpoint(BaseModel):
+    """Tracks migration progress for crash-safe resume."""
+
+    index_name: str
+    total_keys: int
+    completed_keys: int = 0
+    completed_batches: int = 0
+    last_batch_keys: List[str] = Field(default_factory=list)
+    # Retained for backward compatibility with older checkpoint files.
+    # New checkpoints rely on completed_keys with deterministic key ordering
+    # instead of rewriting an ever-growing processed key list on every batch.
+    processed_keys: List[str] = Field(default_factory=list)
+    status: str = "in_progress"
+    checkpoint_path: str = ""
+
+    def record_batch(self, keys: List[str]) -> None:
+        """Record a successfully processed batch.
+
+        Does not auto-save to disk. Call save() after record_batch()
+        to persist the checkpoint for crash recovery.
+        """
+        self.completed_keys += len(keys)
+        self.completed_batches += 1
+        self.last_batch_keys = list(keys)
+        if self.processed_keys:
+            self.processed_keys.extend(keys)
+
+    def mark_complete(self) -> None:
+        """Mark the migration as completed."""
+        self.status = "completed"
+
+    def save(self) -> None:
+        """Persist checkpoint to disk atomically.
+
+        Writes to a temporary file first, then renames. This ensures a
+        crash mid-write does not corrupt the checkpoint file.
+        """
+        path = Path(self.checkpoint_path)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        fd, tmp_path = tempfile.mkstemp(
+            dir=path.parent, suffix=".tmp", prefix=".checkpoint_"
+        )
+        try:
+            exclude = set()
+            if not self.processed_keys:
+                exclude.add("processed_keys")
+            with os.fdopen(fd, "w") as f:
+                yaml.safe_dump(
+                    self.model_dump(exclude=exclude),
+                    f,
+                    sort_keys=False,
+                )
+            os.replace(tmp_path, str(path))
+        except BaseException:
+            # Clean up temp file on any failure
+            try:
+                os.unlink(tmp_path)
+            except OSError:
+                pass
+            raise
+
+    @classmethod
+    def load(cls, path: str) -> Optional["QuantizationCheckpoint"]:
+        """Load a checkpoint from disk. Returns None if file does not exist.
+
+        Always sets checkpoint_path to the path used to load, not the
+        value stored in the file. This ensures subsequent save() calls
+        write to the correct location even if the file was moved.
+        """
+        p = Path(path)
+        if not p.exists():
+            return None
+        with open(p, "r") as f:
+            data = yaml.safe_load(f)
+        if not data:
+            return None
+        checkpoint = cls.model_validate(data)
+        if checkpoint.processed_keys and checkpoint.completed_keys < len(
+            checkpoint.processed_keys
+        ):
+            checkpoint.completed_keys = len(checkpoint.processed_keys)
+        checkpoint.checkpoint_path = str(p)
+        return checkpoint
+
+    def get_remaining_keys(self, all_keys: List[str]) -> List[str]:
+        """Return keys that have not yet been processed."""
+        if self.processed_keys:
+            done = set(self.processed_keys)
+            return [k for k in all_keys if k not in done]
+
+        if self.completed_keys <= 0:
+            return list(all_keys)
+
+        return all_keys[self.completed_keys :]
+
+
+# ---------------------------------------------------------------------------
+# BGSAVE Safety Net
+# ---------------------------------------------------------------------------
+
+
+def trigger_bgsave_and_wait(
+    client: Any,
+    *,
+    timeout_seconds: int = 300,
+    poll_interval: float = 1.0,
+) -> bool:
+    """Trigger a Redis BGSAVE and wait for it to complete.
+
+    If a BGSAVE is already in progress, waits for it instead.
+
+    Args:
+        client: Sync Redis client.
+        timeout_seconds: Max seconds to wait for BGSAVE to finish.
+        poll_interval: Seconds between status polls.
+
+    Returns:
+        True if BGSAVE completed successfully.
+    """
+    try:
+        client.bgsave()
+    except Exception as exc:
+        if "already in progress" not in str(exc).lower():
+            raise
+        logger.info("BGSAVE already in progress, waiting for it to finish.")
+
+    deadline = time.monotonic() + timeout_seconds
+    while time.monotonic() < deadline:
+        info = client.info()
+        if isinstance(info, dict) and not info.get("rdb_bgsave_in_progress", 0):
+            return True
+        time.sleep(poll_interval)
+
+    raise TimeoutError(f"BGSAVE did not complete within {timeout_seconds}s")
+
+
+async def async_trigger_bgsave_and_wait(
+    client: Any,
+    *,
+    timeout_seconds: int = 300,
+    poll_interval: float = 1.0,
+) -> bool:
+    """Async version of trigger_bgsave_and_wait."""
+    try:
+        await client.bgsave()
+    except Exception as exc:
+        if "already in progress" not in str(exc).lower():
+            raise
+        logger.info("BGSAVE already in progress, waiting for it to finish.")
+
+    deadline = time.monotonic() + timeout_seconds
+    while time.monotonic() < deadline:
+        info = await client.info()
+        if isinstance(info, dict) and not info.get("rdb_bgsave_in_progress", 0):
+            return True
+        await asyncio.sleep(poll_interval)
+
+    raise TimeoutError(f"BGSAVE did not complete within {timeout_seconds}s")
+
+
+# ---------------------------------------------------------------------------
+# Bounded Undo Buffer
+# ---------------------------------------------------------------------------
+
+
+class BatchUndoBuffer:
+    """Stores original vector values for the current batch to allow rollback.
+
+    Memory-bounded: only holds data for one batch at a time. Call clear()
+    after each successful batch commit.
+    """
+
+    def __init__(self) -> None:
+        self._entries: List[Tuple[str, str, bytes]] = []
+
+    @property
+    def size(self) -> int:
+        return len(self._entries)
+
+    def store(self, key: str, field: str, original_value: bytes) -> None:
+        """Record the original value of a field before mutation."""
+        self._entries.append((key, field, original_value))
+
+    def rollback(self, pipe: Any) -> None:
+        """Restore all stored originals via the given pipeline (sync)."""
+        if not self._entries:
+            return
+        for key, field, value in self._entries:
+            pipe.hset(key, field, value)
+        pipe.execute()
+
+    async def async_rollback(self, pipe: Any) -> None:
+        """Restore all stored originals via the given pipeline (async)."""
+        if not self._entries:
+            return
+        for key, field, value in self._entries:
+            pipe.hset(key, field, value)
+        await pipe.execute()
+
+    def clear(self) -> None:
+        """Discard all stored entries."""
+        self._entries.clear()
diff --git a/redisvl/migration/utils.py b/redisvl/migration/utils.py
index ac377b76..18beb3c2 100644
--- a/redisvl/migration/utils.py
+++ b/redisvl/migration/utils.py
@@ -3,12 +3,21 @@
 import json
 import time
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import yaml

 from redisvl.index import SearchIndex
-from redisvl.migration.models import MigrationPlan, MigrationReport
+from redisvl.migration.models import (
+    AOF_HSET_OVERHEAD_BYTES,
+    AOF_JSON_SET_OVERHEAD_BYTES,
+    DTYPE_BYTES,
+    RDB_COMPRESSION_RATIO,
+    DiskSpaceEstimate,
+    MigrationPlan,
+    MigrationReport,
+    VectorFieldEstimate,
+)
 from redisvl.redis.connection import RedisConnectionFactory
 from redisvl.schema.schema import IndexSchema
@@ -66,6 +75,60 @@ def write_benchmark_report(report: MigrationReport, path: str) -> None:
     write_yaml(benchmark_report, path)


+def normalize_keys(keys: List[str]) -> List[str]:
+    """Deduplicate and sort keys for deterministic resume behavior."""
+    return sorted(set(keys))
+
+
+def build_scan_match_patterns(prefixes: List[str], key_separator: str) -> List[str]:
+    """Build SCAN patterns for all configured prefixes."""
+    if not prefixes:
+        return ["*"]
+
+    patterns = set()
+    for prefix in prefixes:
+        if not prefix:
+            return ["*"]
+        if key_separator and not prefix.endswith(key_separator):
+            patterns.add(f"{prefix}{key_separator}*")
+        else:
+            patterns.add(f"{prefix}*")
+    return sorted(patterns)
+
+
+def detect_aof_enabled(client: Any) -> bool:
+    """Best-effort detection of whether AOF is enabled on the live Redis."""
+    try:
+        info = client.info("persistence")
+        if isinstance(info, dict) and "aof_enabled" in info:
+            return bool(int(info["aof_enabled"]))
+    except Exception:
+        pass
+
+    try:
+        config = client.config_get("appendonly")
+        if isinstance(config, dict):
+            value = config.get("appendonly")
+            if value is not None:
+                return str(value).lower() in {"yes", "1", "true", "on"}
+    except Exception:
+        pass
+
+    return False
+
+
+def get_schema_field_path(schema: Dict[str, Any], field_name: str) -> Optional[str]:
+    """Return the JSON path configured for a field, if present."""
+    for field in schema.get("fields", []):
+        if field.get("name") != field_name:
+            continue
+        path = field.get("path")
+        if path is None:
+            path = field.get("attrs", {}).get("path")
+        return str(path) if path is not None else None
+    return None
+
+
 # Attributes excluded from schema validation comparison.
 # These are query-time or creation-hint parameters that FT.INFO does not return
 # and are not relevant for index structure validation (confirmed by RediSearch team).
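The prefix-to-pattern rule in `build_scan_match_patterns` is easy to check in isolation. The snippet below is a re-typed copy of the helper from the hunk above (standalone, not an import from `redisvl`), so its behavior can be exercised without a Redis connection:

```python
# Re-typed sketch of build_scan_match_patterns: a separator is appended
# before the wildcard unless the prefix already ends with one, duplicates
# collapse via the set, and an empty prefix list (or any empty prefix)
# falls back to scanning everything.

def build_scan_match_patterns(prefixes, key_separator):
    if not prefixes:
        return ["*"]
    patterns = set()
    for prefix in prefixes:
        if not prefix:
            return ["*"]
        if key_separator and not prefix.endswith(key_separator):
            patterns.add(f"{prefix}{key_separator}*")
        else:
            patterns.add(f"{prefix}*")
    return sorted(patterns)

print(build_scan_match_patterns(["user", "user:"], ":"))  # ['user:*']
print(build_scan_match_patterns([], ":"))                 # ['*']
```

Note the deduplication: `"user"` and `"user:"` normalize to the same pattern, so the resume-time SCAN never visits a keyspace region twice.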
@@ -148,7 +211,7 @@ def canonicalize_schema(
         schema["index"]["prefix"] = sorted(prefixes)
     stopwords = schema["index"].get("stopwords")
     if isinstance(stopwords, list):
-        schema["index"]["stopwords"] = list(stopwords)
+        schema["index"]["stopwords"] = sorted(stopwords)
     return schema
@@ -196,7 +259,7 @@ def wait_for_index_ready(
     deadline = start + timeout_seconds
     latest_info = index.info()
-    stable_ready_checks = 0
+    stable_ready_checks: Optional[int] = None
     while time.perf_counter() < deadline:
         latest_info = index.info()
         indexing = latest_info.get("indexing")
@@ -214,7 +277,7 @@
             if current_docs is None:
                 ready = True
             else:
-                if stable_ready_checks == 0:
+                if stable_ready_checks is None:
                     stable_ready_checks = int(current_docs)
                     time.sleep(poll_interval_seconds)
                     continue
@@ -247,3 +310,119 @@ def current_source_matches_snapshot(

 def timestamp_utc() -> str:
     return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+
+
+def estimate_disk_space(
+    plan: MigrationPlan,
+    *,
+    aof_enabled: bool = False,
+) -> DiskSpaceEstimate:
+    """Estimate disk space required for a migration with quantization.
+
+    This is a pure calculation based on the migration plan. No Redis
+    operations are performed.
+
+    Args:
+        plan: The migration plan containing source/target schemas.
+        aof_enabled: Whether AOF persistence is active on the Redis instance.
+
+    Returns:
+        DiskSpaceEstimate with projected costs.
+    """
+    doc_count = int(plan.source.stats_snapshot.get("num_docs", 0) or 0)
+    storage_type = plan.source.keyspace.storage_type
+    index_name = plan.source.index_name
+
+    # Find vector fields with datatype changes
+    source_fields = {
+        f["name"]: f for f in plan.source.schema_snapshot.get("fields", [])
+    }
+    target_fields = {f["name"]: f for f in plan.merged_target_schema.get("fields", [])}
+
+    vector_field_estimates: list[VectorFieldEstimate] = []
+    total_source_bytes = 0
+    total_target_bytes = 0
+    total_aof_growth = 0
+
+    aof_overhead = (
+        AOF_JSON_SET_OVERHEAD_BYTES
+        if storage_type == "json"
+        else AOF_HSET_OVERHEAD_BYTES
+    )
+
+    for name, source_field in source_fields.items():
+        if source_field.get("type") != "vector":
+            continue
+        target_field = target_fields.get(name)
+        if not target_field or target_field.get("type") != "vector":
+            continue
+
+        source_attrs = source_field.get("attrs", {})
+        target_attrs = target_field.get("attrs", {})
+        source_dtype = source_attrs.get("datatype", "float32").lower()
+        target_dtype = target_attrs.get("datatype", "float32").lower()
+
+        if source_dtype == target_dtype:
+            continue
+
+        if source_dtype not in DTYPE_BYTES:
+            raise ValueError(
+                f"Unknown source vector datatype '{source_dtype}' for field '{name}'. "
+                f"Supported datatypes: {', '.join(sorted(DTYPE_BYTES.keys()))}"
+            )
+        if target_dtype not in DTYPE_BYTES:
+            raise ValueError(
+                f"Unknown target vector datatype '{target_dtype}' for field '{name}'. "
+                f"Supported datatypes: {', '.join(sorted(DTYPE_BYTES.keys()))}"
+            )
+
+        if storage_type == "json":
+            # JSON-backed migrations do not rewrite per-document vector payloads
+            # during apply(); they rely on recreate + re-index instead.
+            continue
+
+        dims = int(source_attrs.get("dims", 0))
+        source_bpe = DTYPE_BYTES[source_dtype]
+        target_bpe = DTYPE_BYTES[target_dtype]
+
+        source_vec_size = dims * source_bpe
+        target_vec_size = dims * target_bpe
+
+        vector_field_estimates.append(
+            VectorFieldEstimate(
+                field_name=name,
+                dims=dims,
+                source_dtype=source_dtype,
+                target_dtype=target_dtype,
+                source_bytes_per_doc=source_vec_size,
+                target_bytes_per_doc=target_vec_size,
+            )
+        )
+
+        field_source_total = doc_count * source_vec_size
+        field_target_total = doc_count * target_vec_size
+        total_source_bytes += field_source_total
+        total_target_bytes += field_target_total
+
+        if aof_enabled:
+            total_aof_growth += doc_count * (target_vec_size + aof_overhead)
+
+    rdb_snapshot_disk = int(total_source_bytes * RDB_COMPRESSION_RATIO)
+    rdb_cow_memory = total_source_bytes
+    total_new_disk = rdb_snapshot_disk + total_aof_growth
+    memory_savings = total_source_bytes - total_target_bytes
+
+    return DiskSpaceEstimate(
+        index_name=index_name,
+        doc_count=doc_count,
+        storage_type=storage_type,
+        vector_fields=vector_field_estimates,
+        total_source_vector_bytes=total_source_bytes,
+        total_target_vector_bytes=total_target_bytes,
+        rdb_snapshot_disk_bytes=rdb_snapshot_disk,
+        rdb_cow_memory_if_concurrent_bytes=rdb_cow_memory,
+        aof_enabled=aof_enabled,
+        aof_growth_bytes=total_aof_growth,
+        total_new_disk_bytes=total_new_disk,
+        memory_savings_after_bytes=memory_savings,
+    )
diff --git a/redisvl/migration/validation.py b/redisvl/migration/validation.py
index efa46381..25b39c33 100644
--- a/redisvl/migration/validation.py
+++ b/redisvl/migration/validation.py
@@ -108,7 +108,7 @@ def _run_query_checks(
                 passed=fetched is not None,
                 details=(
                     "Document fetched successfully"
-                    if fetched
+                    if fetched is not None
                     else "Document not found"
                 ),
             )
diff --git a/tests/integration/test_migration_comprehensive.py b/tests/integration/test_migration_comprehensive.py
index 2678abac..0b335357 100644
--- a/tests/integration/test_migration_comprehensive.py
+++ b/tests/integration/test_migration_comprehensive.py
@@ -1278,28 +1278,22 @@ def json_schema(self, unique_ids):

     @pytest.fixture
     def json_sample_docs(self):
-        """Sample JSON documents."""
-        import json
-
+        """Sample JSON documents (as dicts for RedisJSON)."""
         return [
-            json.dumps(
-                {
-                    "doc_id": "1",
-                    "title": "Alpha Product",
-                    "category": "electronics",
-                    "price": 99.99,
-                    "embedding": [0.1, 0.2, 0.3, 0.4],
-                }
-            ),
-            json.dumps(
-                {
-                    "doc_id": "2",
-                    "title": "Beta Service",
-                    "category": "software",
-                    "price": 149.99,
-                    "embedding": [0.2, 0.3, 0.4, 0.5],
-                }
-            ),
+            {
+                "doc_id": "1",
+                "title": "Alpha Product",
+                "category": "electronics",
+                "price": 99.99,
+                "embedding": [0.1, 0.2, 0.3, 0.4],
+            },
+            {
+                "doc_id": "2",
+                "title": "Beta Service",
+                "category": "software",
+                "price": 149.99,
+                "embedding": [0.2, 0.3, 0.4, 0.5],
+            },
         ]

     def test_json_add_field(
@@ -1348,11 +1342,9 @@ def test_json_rename_field(
         index.create(overwrite=True)

         # Load JSON docs
-        import json as json_module
-
         for i, doc in enumerate(json_sample_docs):
             key = f"{unique_ids['prefix']}:{i+1}"
-            client.json().set(key, "$", json_module.loads(doc))
+            client.json().set(key, "$", doc)

         try:
             result = run_migration(
diff --git a/tests/integration/test_migration_routes.py b/tests/integration/test_migration_routes.py
index c666d7c7..5d897d01 100644
--- a/tests/integration/test_migration_routes.py
+++ b/tests/integration/test_migration_routes.py
@@ -8,10 +8,12 @@
 import uuid

 import pytest
+from redis import Redis

 from redisvl.index import SearchIndex
 from redisvl.migration import MigrationExecutor, MigrationPlanner
 from redisvl.migration.models import FieldUpdate, SchemaPatch
+from tests.conftest import skip_if_redis_version_below


 def create_source_index(redis_url, worker_id, source_attrs):
@@ -136,6 +138,8 @@ def test_flat_datatype_change(
     @pytest.mark.parametrize("target_dtype", ["int8", "uint8"])
     def test_flat_quantized_datatype(self,
redis_url, worker_id, target_dtype): """Test INT8/UINT8 datatypes (requires Redis 8.0+).""" + client = Redis.from_url(redis_url) + skip_if_redis_version_below(client, "8.0.0", "INT8/UINT8 requires Redis 8.0+") index, index_name = create_source_index( redis_url, worker_id, {"algorithm": "flat"} ) @@ -169,6 +173,8 @@ def test_hnsw_datatype_change( @pytest.mark.parametrize("target_dtype", ["int8", "uint8"]) def test_hnsw_quantized_datatype(self, redis_url, worker_id, target_dtype): """Test INT8/UINT8 datatypes with HNSW (requires Redis 8.0+).""" + client = Redis.from_url(redis_url) + skip_if_redis_version_below(client, "8.0.0", "INT8/UINT8 requires Redis 8.0+") index, index_name = create_source_index( redis_url, worker_id, {"algorithm": "hnsw"} ) @@ -308,6 +314,8 @@ def test_flat_to_hnsw_with_datatype_and_metric(self, redis_url, worker_id): def test_flat_to_hnsw_with_int8(self, redis_url, worker_id): """Combined algorithm + quantized datatype (requires Redis 8.0+).""" + client = Redis.from_url(redis_url) + skip_if_redis_version_below(client, "8.0.0", "INT8 requires Redis 8.0+") index, index_name = create_source_index( redis_url, worker_id, {"algorithm": "flat"} ) diff --git a/tests/unit/test_async_migration_executor.py b/tests/unit/test_async_migration_executor.py index da43ba2f..ac65c384 100644 --- a/tests/unit/test_async_migration_executor.py +++ b/tests/unit/test_async_migration_executor.py @@ -1,17 +1,27 @@ -"""Unit tests for AsyncMigrationExecutor. +"""Unit tests for migration executors and disk space estimator. These tests mirror the sync MigrationExecutor patterns but use async/await. +Also includes pure-calculation tests for estimate_disk_space(). 
""" +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + import pytest -from redisvl.migration import AsyncMigrationExecutor +from redisvl.migration import AsyncMigrationExecutor, MigrationExecutor from redisvl.migration.models import ( DiffClassification, KeyspaceSnapshot, MigrationPlan, SourceSnapshot, ValidationPolicy, + _format_bytes, +) +from redisvl.migration.utils import ( + build_scan_match_patterns, + estimate_disk_space, + normalize_keys, ) @@ -123,3 +133,960 @@ async def test_async_executor_validates_redis_url(): # For a proper test, we'd need to mock AsyncSearchIndex.from_existing # For now, we just verify the executor is created assert executor is not None + + +# ============================================================================= +# Disk Space Estimator Tests +# ============================================================================= + + +def _make_quantize_plan( + source_dtype="float32", + target_dtype="float16", + dims=3072, + doc_count=100_000, + storage_type="hash", +): + """Helper to create a migration plan with a vector datatype change.""" + return MigrationPlan( + mode="drop_recreate", + source=SourceSnapshot( + index_name="test_index", + keyspace=KeyspaceSnapshot( + storage_type=storage_type, + prefixes=["test"], + key_separator=":", + ), + schema_snapshot={ + "index": { + "name": "test_index", + "prefix": "test", + "storage_type": storage_type, + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": dims, + "distance_metric": "cosine", + "datatype": source_dtype, + }, + }, + ], + }, + stats_snapshot={"num_docs": doc_count}, + ), + requested_changes={}, + merged_target_schema={ + "index": { + "name": "test_index", + "prefix": "test", + "storage_type": storage_type, + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": 
dims, + "distance_metric": "cosine", + "datatype": target_dtype, + }, + }, + ], + }, + diff_classification=DiffClassification(supported=True, blocked_reasons=[]), + validation=ValidationPolicy(require_doc_count_match=True), + ) + + +def test_estimate_fp32_to_fp16(): + """FP32->FP16 with 3072 dims, 100K docs should produce expected byte counts.""" + plan = _make_quantize_plan("float32", "float16", dims=3072, doc_count=100_000) + est = estimate_disk_space(plan) + + assert est.has_quantization is True + assert len(est.vector_fields) == 1 + vf = est.vector_fields[0] + assert vf.source_bytes_per_doc == 3072 * 4 # 12288 + assert vf.target_bytes_per_doc == 3072 * 2 # 6144 + + assert est.total_source_vector_bytes == 100_000 * 12288 + assert est.total_target_vector_bytes == 100_000 * 6144 + assert est.memory_savings_after_bytes == 100_000 * (12288 - 6144) + + # RDB = source * 0.95 + assert est.rdb_snapshot_disk_bytes == int(100_000 * 12288 * 0.95) + # COW = full source + assert est.rdb_cow_memory_if_concurrent_bytes == 100_000 * 12288 + # AOF disabled by default + assert est.aof_enabled is False + assert est.aof_growth_bytes == 0 + assert est.total_new_disk_bytes == est.rdb_snapshot_disk_bytes + + +def test_estimate_with_aof_enabled(): + """AOF growth should include RESP overhead per HSET.""" + plan = _make_quantize_plan("float32", "float16", dims=3072, doc_count=100_000) + est = estimate_disk_space(plan, aof_enabled=True) + + assert est.aof_enabled is True + target_vec_size = 3072 * 2 + expected_aof = 100_000 * (target_vec_size + 114) # 114 = HSET overhead + assert est.aof_growth_bytes == expected_aof + assert est.total_new_disk_bytes == est.rdb_snapshot_disk_bytes + expected_aof + + +def test_estimate_json_storage_aof(): + """JSON storage quantization should not report in-place rewrite costs.""" + plan = _make_quantize_plan( + "float32", "float16", dims=128, doc_count=1000, storage_type="json" + ) + est = estimate_disk_space(plan, aof_enabled=True) + + assert 
est.has_quantization is False + assert est.aof_growth_bytes == 0 + assert est.total_new_disk_bytes == 0 + + +def test_estimate_no_quantization(): + """Same dtype source and target should produce empty estimate.""" + plan = _make_quantize_plan("float32", "float32", dims=128, doc_count=1000) + est = estimate_disk_space(plan) + + assert est.has_quantization is False + assert len(est.vector_fields) == 0 + assert est.total_new_disk_bytes == 0 + assert est.memory_savings_after_bytes == 0 + + +def test_estimate_fp32_to_int8(): + """FP32->INT8 should use 1 byte per element.""" + plan = _make_quantize_plan("float32", "int8", dims=768, doc_count=50_000) + est = estimate_disk_space(plan) + + assert est.vector_fields[0].source_bytes_per_doc == 768 * 4 + assert est.vector_fields[0].target_bytes_per_doc == 768 * 1 + assert est.memory_savings_after_bytes == 50_000 * 768 * 3 + + +def test_estimate_summary_with_quantization(): + """Summary string should contain key information.""" + plan = _make_quantize_plan("float32", "float16", dims=128, doc_count=1000) + est = estimate_disk_space(plan) + summary = est.summary() + + assert "Pre-migration disk space estimate" in summary + assert "test_index" in summary + assert "1,000 documents" in summary + assert "float32 -> float16" in summary + assert "RDB snapshot" in summary + assert "memory savings" in summary + + +def test_estimate_summary_no_quantization(): + """Summary for non-quantization migration should say no disk needed.""" + plan = _make_quantize_plan("float32", "float32", dims=128, doc_count=1000) + est = estimate_disk_space(plan) + summary = est.summary() + + assert "No vector quantization" in summary + + +def test_format_bytes_gb(): + assert _format_bytes(1_073_741_824) == "1.00 GB" + assert _format_bytes(2_147_483_648) == "2.00 GB" + + +def test_format_bytes_mb(): + assert _format_bytes(1_048_576) == "1.0 MB" + assert _format_bytes(10_485_760) == "10.0 MB" + + +def test_format_bytes_kb(): + assert _format_bytes(1024) == "1.0 
KB" + assert _format_bytes(2048) == "2.0 KB" + + +def test_format_bytes_bytes(): + assert _format_bytes(500) == "500 bytes" + assert _format_bytes(0) == "0 bytes" + + +def test_savings_pct(): + """Verify savings percentage calculation.""" + plan = _make_quantize_plan("float32", "float16", dims=128, doc_count=100) + est = estimate_disk_space(plan) + # FP32->FP16 = 50% savings + assert est._savings_pct() == 50 + + +# ============================================================================= +# TDD RED Phase: Idempotent Dtype Detection Tests +# ============================================================================= +# These test detect_vector_dtype() and is_already_quantized() which inspect +# raw vector bytes to determine whether a key needs conversion or can be skipped. + + +def test_detect_dtype_float32_by_size(): + """A 3072-dim vector stored as FP32 should be 12288 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.random.randn(3072).astype(np.float32).tobytes() + detected = detect_vector_dtype(vec, expected_dims=3072) + assert detected == "float32" + + +def test_detect_dtype_float16_by_size(): + """A 3072-dim vector stored as FP16 should be 6144 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.random.randn(3072).astype(np.float16).tobytes() + detected = detect_vector_dtype(vec, expected_dims=3072) + assert detected == "float16" + + +def test_detect_dtype_int8_by_size(): + """A 768-dim vector stored as INT8 should be 768 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.zeros(768, dtype=np.int8).tobytes() + detected = detect_vector_dtype(vec, expected_dims=768) + assert detected == "int8" + + +def test_detect_dtype_bfloat16_by_size(): + """A 768-dim bfloat16 vector should be 1536 bytes (same as float16).""" + import numpy as np + + from redisvl.migration.reliability import 
detect_vector_dtype + + # bfloat16 and float16 are both 2 bytes per element + vec = np.random.randn(768).astype(np.float16).tobytes() + detected = detect_vector_dtype(vec, expected_dims=768) + # Cannot distinguish float16 from bfloat16 by size alone; returns "float16" + assert detected in ("float16", "bfloat16") + + +def test_detect_dtype_empty_returns_none(): + """Empty bytes should return None.""" + from redisvl.migration.reliability import detect_vector_dtype + + assert detect_vector_dtype(b"", expected_dims=128) is None + + +def test_detect_dtype_unknown_size(): + """Bytes that don't match any known dtype should return None.""" + from redisvl.migration.reliability import detect_vector_dtype + + # 7 bytes doesn't match any dtype for 3 dims + assert detect_vector_dtype(b"\x00" * 7, expected_dims=3) is None + + +def test_is_already_quantized_skip(): + """If source is float32 and vector is already float16, should return True.""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="float16" + ) + assert result is True + + +def test_is_already_quantized_needs_conversion(): + """If source is float32 and vector IS float32, should return False.""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float32).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="float16" + ) + assert result is False + + +def test_is_already_quantized_bfloat16_target(): + """If target is bfloat16 and vector is 2-bytes-per-element, should return True. + + bfloat16 and float16 share the same byte width (2 bytes per element) + and are treated as the same dtype family for idempotent detection. 
+ """ + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="bfloat16" + ) + assert result is True + + +def test_is_already_quantized_uint8_target(): + """If target is uint8 and vector is 1-byte-per-element, should return True. + + uint8 and int8 share the same byte width (1 byte per element) + and are treated as the same dtype family for idempotent detection. + """ + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.int8).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="uint8" + ) + assert result is True + + +def test_is_already_quantized_same_width_float16_to_bfloat16(): + """float16 -> bfloat16 should NOT be skipped (same byte width, different encoding).""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float16", target_dtype="bfloat16" + ) + assert result is False + + +def test_is_already_quantized_same_width_int8_to_uint8(): + """int8 -> uint8 should NOT be skipped (same byte width, different encoding).""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.int8).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="int8", target_dtype="uint8" + ) + assert result is False + + +# ============================================================================= +# TDD RED Phase: Checkpoint File Tests +# ============================================================================= + + +def test_checkpoint_create_new(tmp_path): + """Creating a new checkpoint should 
initialize with zero progress.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="test_index", + total_keys=10000, + checkpoint_path=str(tmp_path / "checkpoint.yaml"), + ) + assert cp.index_name == "test_index" + assert cp.total_keys == 10000 + assert cp.completed_keys == 0 + assert cp.completed_batches == 0 + assert cp.last_batch_keys == [] + assert cp.status == "in_progress" + + +def test_checkpoint_save_and_load(tmp_path): + """Checkpoint should persist to disk and reload with same state.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + path = str(tmp_path / "checkpoint.yaml") + cp = QuantizationCheckpoint( + index_name="test_index", + total_keys=5000, + checkpoint_path=path, + ) + cp.record_batch(["key:1", "key:2", "key:3"]) + cp.save() + + loaded = QuantizationCheckpoint.load(path) + assert loaded.index_name == "test_index" + assert loaded.total_keys == 5000 + assert loaded.completed_keys == 3 + assert loaded.completed_batches == 1 + assert loaded.last_batch_keys == ["key:1", "key:2", "key:3"] + + +def test_checkpoint_record_multiple_batches(tmp_path): + """Recording multiple batches should accumulate counts.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=100, + checkpoint_path=str(tmp_path / "cp.yaml"), + ) + cp.record_batch(["k1", "k2"]) + cp.record_batch(["k3", "k4", "k5"]) + + assert cp.completed_keys == 5 + assert cp.completed_batches == 2 + assert cp.last_batch_keys == ["k3", "k4", "k5"] + + +def test_checkpoint_mark_complete(tmp_path): + """Marking complete should set status to 'completed'.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=2, + checkpoint_path=str(tmp_path / "cp.yaml"), + ) + cp.record_batch(["k1", "k2"]) + cp.mark_complete() + + assert cp.status == "completed" + + +def 
test_checkpoint_get_remaining_keys(tmp_path): + """get_remaining_keys should return only keys not yet processed.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=5, + checkpoint_path=str(tmp_path / "cp.yaml"), + ) + all_keys = ["k1", "k2", "k3", "k4", "k5"] + cp.record_batch(["k1", "k2"]) + + remaining = cp.get_remaining_keys(all_keys) + assert remaining == ["k3", "k4", "k5"] + + +def test_checkpoint_get_remaining_keys_uses_completed_offset_when_compact(tmp_path): + """Compact checkpoints should resume via completed_keys ordering.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=5, + checkpoint_path=str(tmp_path / "cp.yaml"), + ) + cp.record_batch(["k1", "k2"]) + + remaining = cp.get_remaining_keys(["k1", "k2", "k3", "k4", "k5"]) + assert remaining == ["k3", "k4", "k5"] + + +def test_checkpoint_save_excludes_processed_keys(tmp_path): + """New checkpoints should persist compact state without processed_keys.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + path = tmp_path / "checkpoint.yaml" + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=3, + checkpoint_path=str(path), + ) + cp.save() + + raw = path.read_text() + assert "processed_keys" not in raw + + +def test_checkpoint_load_nonexistent_returns_none(tmp_path): + """Loading a nonexistent checkpoint file should return None.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + result = QuantizationCheckpoint.load( + str(tmp_path / "nonexistent_checkpoint_xyz.yaml") + ) + assert result is None + + +def test_checkpoint_load_forces_path(tmp_path): + """load() should set checkpoint_path to the file used to load, not the stored value.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + original_path = str(tmp_path / "original.yaml") + cp = QuantizationCheckpoint( + 
index_name="idx", + total_keys=10, + checkpoint_path=original_path, + ) + cp.record_batch(["k1"]) + cp.save() + + # Copy the file to a new location (the copy stands in for a moved file) + new_path = str(tmp_path / "moved.yaml") + import shutil + + shutil.copy(original_path, new_path) + + loaded = QuantizationCheckpoint.load(new_path) + assert loaded.checkpoint_path == new_path # should use load path, not stored + + + def test_checkpoint_save_preserves_legacy_processed_keys(tmp_path): + """Legacy checkpoints should keep processed_keys across subsequent saves.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + path = tmp_path / "legacy.yaml" + path.write_text( + "index_name: idx\n" + "total_keys: 4\n" + "processed_keys:\n" + " - k1\n" + " - k2\n" + "status: in_progress\n" + ) + + checkpoint = QuantizationCheckpoint.load(str(path)) + checkpoint.record_batch(["k3"]) + checkpoint.save() + + reloaded = QuantizationCheckpoint.load(str(path)) + assert reloaded.processed_keys == ["k1", "k2", "k3"] + assert reloaded.completed_keys == 3 + + + def test_quantize_vectors_saves_checkpoint_before_processing(monkeypatch, tmp_path): + """Checkpoint save should happen before the first HGET in a fresh run.""" + import numpy as np + + executor = MigrationExecutor() + checkpoint_path = str(tmp_path / "checkpoint.yaml") + field_bytes = np.array([1.0, 2.0], dtype=np.float32).tobytes() + events: list[str] = [] + + original_save = executor._quantize_vectors.__globals__[ + "QuantizationCheckpoint" + ].save + + def tracking_save(self): + events.append("save") + return original_save(self) + + monkeypatch.setattr( + executor._quantize_vectors.__globals__["QuantizationCheckpoint"], + "save", + tracking_save, + ) + + client = MagicMock() + client.hget.side_effect = lambda key, field: (events.append("hget") or field_bytes) + pipe = MagicMock() + client.pipeline.return_value = pipe + source_index = MagicMock() + source_index._redis_client = client + source_index.name = "idx" + + result = executor._quantize_vectors( 
+ source_index, + {"embedding": {"source": "float32", "target": "float16", "dims": 2}}, + ["doc:1"], + checkpoint_path=checkpoint_path, + ) + + assert result == 1 + assert events[0] == "save" + assert Path(checkpoint_path).exists() + + +def test_quantize_vectors_returns_reencoded_docs_not_scanned_docs(): + """Quantize count should reflect converted docs, not skipped docs.""" + import numpy as np + + executor = MigrationExecutor() + already_quantized = np.array([1.0, 2.0], dtype=np.float16).tobytes() + needs_quantization = np.array([1.0, 2.0], dtype=np.float32).tobytes() + + client = MagicMock() + client.hget.side_effect = lambda key, field: { + "doc:1": already_quantized, + "doc:2": needs_quantization, + }[key] + pipe = MagicMock() + client.pipeline.return_value = pipe + source_index = MagicMock() + source_index._redis_client = client + source_index.name = "idx" + + progress: list[tuple[int, int]] = [] + result = executor._quantize_vectors( + source_index, + {"embedding": {"source": "float32", "target": "float16", "dims": 2}}, + ["doc:1", "doc:2"], + progress_callback=lambda done, total: progress.append((done, total)), + ) + + assert result == 1 + assert progress[-1] == (2, 2) + + +def test_build_scan_match_patterns_uses_separator(): + assert build_scan_match_patterns(["test"], ":") == ["test:*"] + assert build_scan_match_patterns(["test:"], ":") == ["test:*"] + assert build_scan_match_patterns([], ":") == ["*"] + assert build_scan_match_patterns(["b", "a"], ":") == ["a:*", "b:*"] + + +def test_normalize_keys_dedupes_and_sorts(): + assert normalize_keys(["b", "a", "b"]) == ["a", "b"] + + +def test_detect_aof_enabled_from_info(): + from redisvl.migration.utils import detect_aof_enabled + + client = MagicMock() + client.info.return_value = {"aof_enabled": 1} + assert detect_aof_enabled(client) is True + + +@pytest.mark.asyncio +async def test_async_detect_aof_enabled_from_info(): + executor = AsyncMigrationExecutor() + client = MagicMock() + client.info = 
AsyncMock(return_value={"aof_enabled": 1}) + assert await executor._detect_aof_enabled(client) is True + + + def test_estimate_unknown_dtype_raises(): + plan = _make_quantize_plan("madeup32", "float16", dims=128, doc_count=10) + + with pytest.raises(ValueError, match="Unknown source vector datatype"): + estimate_disk_space(plan) + + + def test_enumerate_with_scan_uses_all_prefixes(): + executor = MigrationExecutor() + client = MagicMock() + client.ft.return_value.info.return_value = { + "index_definition": {"prefixes": ["alpha", "beta"]} + } + client.scan.side_effect = [ + (0, [b"alpha:1", b"shared:1"]), + (0, [b"beta:2", b"shared:1"]), + ] + + keys = list(executor._enumerate_with_scan(client, "idx", batch_size=1000)) + + assert keys == ["alpha:1", "shared:1", "beta:2"] + + + @pytest.mark.asyncio + async def test_async_enumerate_with_scan_uses_all_prefixes(): + executor = AsyncMigrationExecutor() + client = MagicMock() + client.ft.return_value.info = AsyncMock( + return_value={"index_definition": {"prefixes": ["alpha", "beta"]}} + ) + client.scan = AsyncMock( + side_effect=[ + (0, [b"alpha:1", b"shared:1"]), + (0, [b"beta:2", b"shared:1"]), + ] + ) + + keys = [ + key + async for key in executor._enumerate_with_scan(client, "idx", batch_size=1000) + ] + + assert keys == ["alpha:1", "shared:1", "beta:2"] + + + def test_apply_rejects_same_width_resume(monkeypatch): + plan = _make_quantize_plan("float16", "bfloat16", dims=2, doc_count=1) + executor = MigrationExecutor() + + def _make_index(*args, **kwargs): + index = MagicMock() + index._redis_client = MagicMock() + index.name = 
"test_index" + return index + + monkeypatch.setattr( + "redisvl.migration.executor.current_source_matches_snapshot", + lambda *args, **kwargs: True, + ) + monkeypatch.setattr( + "redisvl.migration.executor.SearchIndex.from_existing", + _make_index, + ) + monkeypatch.setattr( + "redisvl.migration.executor.SearchIndex.from_dict", + _make_index, + ) + + report = executor.apply( + plan, + redis_client=MagicMock(), + checkpoint_path="resume.yaml", + ) + + assert report.result == "failed" + assert "same-width datatype changes" in report.validation.errors[0] + + +@pytest.mark.asyncio +async def test_async_quantize_vectors_saves_checkpoint_before_processing( + monkeypatch, tmp_path +): + """Async checkpoint save should happen before the first HGET in a fresh run.""" + import numpy as np + + executor = AsyncMigrationExecutor() + checkpoint_path = str(tmp_path / "checkpoint.yaml") + field_bytes = np.array([1.0, 2.0], dtype=np.float32).tobytes() + events: list[str] = [] + + original_save = executor._async_quantize_vectors.__globals__[ + "QuantizationCheckpoint" + ].save + + def tracking_save(self): + events.append("save") + return original_save(self) + + monkeypatch.setattr( + executor._async_quantize_vectors.__globals__["QuantizationCheckpoint"], + "save", + tracking_save, + ) + + client = MagicMock() + client.hget = AsyncMock( + side_effect=lambda key, field: (events.append("hget") or field_bytes) + ) + pipe = MagicMock() + pipe.execute = AsyncMock(return_value=[]) + client.pipeline.return_value = pipe + source_index = MagicMock() + source_index._redis_client = client + source_index.name = "idx" + + result = await executor._async_quantize_vectors( + source_index, + {"embedding": {"source": "float32", "target": "float16", "dims": 2}}, + ["doc:1"], + checkpoint_path=checkpoint_path, + ) + + assert result == 1 + assert events[0] == "save" + assert Path(checkpoint_path).exists() + + +@pytest.mark.asyncio +async def 
test_async_quantize_vectors_returns_reencoded_docs_not_scanned_docs(): + """Async quantize count should reflect converted docs, not skipped docs.""" + import numpy as np + + executor = AsyncMigrationExecutor() + already_quantized = np.array([1.0, 2.0], dtype=np.float16).tobytes() + needs_quantization = np.array([1.0, 2.0], dtype=np.float32).tobytes() + + client = MagicMock() + client.hget = AsyncMock( + side_effect=lambda key, field: { + "doc:1": already_quantized, + "doc:2": needs_quantization, + }[key] + ) + pipe = MagicMock() + pipe.execute = AsyncMock(return_value=[]) + client.pipeline.return_value = pipe + source_index = MagicMock() + source_index._redis_client = client + source_index.name = "idx" + + progress: list[tuple[int, int]] = [] + result = await executor._async_quantize_vectors( + source_index, + {"embedding": {"source": "float32", "target": "float16", "dims": 2}}, + ["doc:1", "doc:2"], + progress_callback=lambda done, total: progress.append((done, total)), + ) + + assert result == 1 + assert progress[-1] == (2, 2) + + +# ============================================================================= +# TDD RED Phase: BGSAVE Safety Net Tests +# ============================================================================= + + +def test_trigger_bgsave_success(): + """BGSAVE should be triggered and waited on; returns True on success.""" + from unittest.mock import MagicMock + + from redisvl.migration.reliability import trigger_bgsave_and_wait + + mock_client = MagicMock() + mock_client.bgsave.return_value = True + mock_client.info.return_value = {"rdb_bgsave_in_progress": 0} + + result = trigger_bgsave_and_wait(mock_client, timeout_seconds=5) + assert result is True + mock_client.bgsave.assert_called_once() + + +def test_trigger_bgsave_already_in_progress(): + """If BGSAVE is already running, wait for it instead of starting a new one.""" + from unittest.mock import MagicMock, call + + from redisvl.migration.reliability import trigger_bgsave_and_wait + + 
mock_client = MagicMock() + # First bgsave raises because one is already in progress + mock_client.bgsave.side_effect = Exception("Background save already in progress") + # First check: still running; second check: done + mock_client.info.side_effect = [ + {"rdb_bgsave_in_progress": 1}, + {"rdb_bgsave_in_progress": 0}, + ] + + result = trigger_bgsave_and_wait(mock_client, timeout_seconds=5, poll_interval=0.01) + assert result is True + + +@pytest.mark.asyncio +async def test_async_trigger_bgsave_success(): + """Async BGSAVE should work the same as sync.""" + from unittest.mock import AsyncMock + + from redisvl.migration.reliability import async_trigger_bgsave_and_wait + + mock_client = AsyncMock() + mock_client.bgsave.return_value = True + mock_client.info.return_value = {"rdb_bgsave_in_progress": 0} + + result = await async_trigger_bgsave_and_wait(mock_client, timeout_seconds=5) + assert result is True + mock_client.bgsave.assert_called_once() + + +# ============================================================================= +# TDD RED Phase: Bounded Undo Buffer Tests +# ============================================================================= + + +def test_undo_buffer_store_and_rollback(): + """Undo buffer should store original values and rollback via pipeline.""" + from unittest.mock import MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + buf.store("key:1", "embedding", b"\x00\x01\x02\x03") + buf.store("key:2", "embedding", b"\x04\x05\x06\x07") + + assert buf.size == 2 + + mock_pipe = MagicMock() + buf.rollback(mock_pipe) + + # Should have called hset twice to restore originals + assert mock_pipe.hset.call_count == 2 + mock_pipe.execute.assert_called_once() + + +def test_undo_buffer_clear(): + """After clear, buffer should be empty.""" + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + buf.store("key:1", "field", b"\x00") + assert buf.size == 1 + + buf.clear() + 
assert buf.size == 0 + + +def test_undo_buffer_empty_rollback(): + """Rolling back an empty buffer should be a no-op.""" + from unittest.mock import MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + mock_pipe = MagicMock() + buf.rollback(mock_pipe) + + # No hset calls, no execute + mock_pipe.hset.assert_not_called() + mock_pipe.execute.assert_not_called() + + +def test_undo_buffer_multiple_fields_same_key(): + """Should handle multiple fields for the same key.""" + from unittest.mock import MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + buf.store("key:1", "embedding", b"\x00\x01") + buf.store("key:1", "embedding2", b"\x02\x03") + + assert buf.size == 2 + + mock_pipe = MagicMock() + buf.rollback(mock_pipe) + assert mock_pipe.hset.call_count == 2 + + +@pytest.mark.asyncio +async def test_undo_buffer_async_rollback(): + """async_rollback should await pipe.execute() for async Redis pipelines.""" + from unittest.mock import AsyncMock, MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + buf.store("key:1", "embedding", b"\x00\x01") + buf.store("key:2", "embedding", b"\x02\x03") + + mock_pipe = MagicMock() + mock_pipe.execute = AsyncMock() + + await buf.async_rollback(mock_pipe) + assert mock_pipe.hset.call_count == 2 + mock_pipe.execute.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_undo_buffer_async_rollback_empty(): + """async_rollback on empty buffer should be a no-op.""" + from unittest.mock import AsyncMock, MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + mock_pipe = MagicMock() + mock_pipe.execute = AsyncMock() + + await buf.async_rollback(mock_pipe) + mock_pipe.hset.assert_not_called() + mock_pipe.execute.assert_not_awaited()
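The dtype-detection tests above pin down behavior purely in terms of byte widths (bytes-per-element inferred from payload length and expected dims). A minimal sketch of that heuristic may help reviewers; `detect_vector_dtype_sketch` and its width table are hypothetical stand-ins here, and the shipped `redisvl.migration.reliability.detect_vector_dtype` may be implemented differently:

```python
from typing import Optional

# Assumed byte widths per element; mirrors the dims * bytes-per-element
# arithmetic used by the estimator tests above.
WIDTH_TO_DTYPE = {4: "float32", 2: "float16", 1: "int8"}


def detect_vector_dtype_sketch(raw: bytes, expected_dims: int) -> Optional[str]:
    """Infer a vector's stored dtype from its byte length alone."""
    if not raw or expected_dims <= 0:
        return None
    width, remainder = divmod(len(raw), expected_dims)
    if remainder:
        return None  # length is not a whole number of elements
    # Width-sharing families (float16/bfloat16, int8/uint8) are
    # indistinguishable by size; report the canonical member.
    return WIDTH_TO_DTYPE.get(width)
```

By construction, same-width pairs such as float16/bfloat16 or int8/uint8 cannot be told apart by size, which is consistent with the tests above accepting either answer for bfloat16 and refusing to skip same-width conversions during resume.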