PSv2: Use connection pooling and retries for NATS #1130
carlosgjs wants to merge 20 commits into RolnickLab:main from
Conversation
✅ Deploy Preview for antenna-ssec canceled.
✅ Deploy Preview for antenna-preview canceled.
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
📝 Walkthrough — Introduces a per-event-loop NATS ConnectionPool and retry decorator, refactors TaskQueueManager to use pool-provided (nc, js) per operation, and updates publish/reserve/ack/delete/cleanup flows and tests to use lifecycle-managed connections with retry/backoff.
Sequence Diagram

```mermaid
sequenceDiagram
    participant App as Application
    participant TQM as TaskQueueManager
    participant Retry as RetryDecorator
    participant Pool as ConnectionPool
    participant NATS as NATS Client
    participant JS as JetStream

    App->>TQM: publish_task(job_id, task)
    activate TQM
    TQM->>Retry: wrapped call
    activate Retry
    Retry->>TQM: attempt
    TQM->>Pool: _get_connection()
    activate Pool
    Pool->>NATS: ensure/connect (nc)
    Pool->>JS: provide JetStreamContext (js)
    Pool-->>TQM: (nc, js)
    deactivate Pool
    TQM->>JS: publish(stream, data)
    JS-->>TQM: ack / error
    alt Connection error
        TQM-->>Retry: raise connection error
        Retry->>Pool: reset_connection()
        Retry->>Retry: backoff wait
        Retry->>TQM: retry attempt (up to max)
    end
    Retry-->>App: result
    deactivate Retry
    deactivate TQM

    App->>TQM: reserve_task(job_id, timeout)
    activate TQM
    TQM->>Retry: wrapped call
    activate Retry
    Retry->>TQM: attempt
    TQM->>Pool: _get_connection()
    Pool-->>TQM: (nc, js)
    TQM->>JS: pull_subscribe + fetch(timeout)
    alt Message received
        JS-->>TQM: message -> PipelineProcessingTask
        Retry-->>App: task
    else Timeout
        JS-->>TQM: TimeoutError
        Retry-->>App: None
    end
    deactivate Retry
    deactivate TQM
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ Passed checks (4 passed)
Pull request overview
This PR refactors NATS JetStream task-queue interactions to reuse a shared (process-local) connection and to add retry/backoff behavior for improved reliability and reduced connection churn.
Changes:
- Introduces a new `ConnectionPool` module to lazily create and reuse a NATS + JetStream connection.
- Updates `TaskQueueManager` to always obtain connections from the pool, removes the async context-manager lifecycle, and adds a retry decorator for connection-related failures (sketched below).
- Updates call sites and unit tests to match the new non-context-manager usage.
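To make the pool-plus-retry pattern concrete, here is a minimal sketch under stated assumptions: a single module-level cache standing in for the per-event-loop `ConnectionPool`, a hard-coded localhost NATS URL, and a guessed set of retryable exceptions. The PR's actual signatures and error handling may differ.

```python
import asyncio
import functools
import logging

import nats
from nats import errors as nats_errors

logger = logging.getLogger(__name__)

_nc = None  # simplified stand-in for the per-event-loop ConnectionPool


async def get_connection():
    """Return a cached NATS client + JetStream context, connecting lazily."""
    global _nc
    if _nc is None or _nc.is_closed:
        _nc = await nats.connect("nats://localhost:4222")  # assumed URL
    return _nc, _nc.jetstream()


async def reset_connection():
    """Drop the cached client so the next get_connection() reconnects."""
    global _nc
    if _nc is not None and not _nc.is_closed:
        try:
            await _nc.close()
        except Exception:
            pass  # connection may already be broken
    _nc = None


def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5):
    """Retry an async NATS operation, resetting the shared connection between attempts."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except (nats_errors.ConnectionClosedError, nats_errors.NoServersError, OSError) as exc:
                    if attempt == max_retries:
                        raise
                    logger.warning("NATS operation failed (%s); retrying after reset", exc)
                    await reset_connection()
                    await asyncio.sleep(backoff_seconds * (attempt + 1))

        return wrapper

    return decorator
```

In the PR itself the cache is keyed per event loop and shared by all `TaskQueueManager` operations; this sketch only illustrates the decorator/reset interaction.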
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| ami/ml/orchestration/tests/test_nats_queue.py | Refactors tests to mock the new connection pool and remove context-manager assumptions. |
| ami/ml/orchestration/nats_queue.py | Adds retry decorator, removes per-operation connection creation, and routes all operations through the shared pool. |
| ami/ml/orchestration/nats_connection_pool.py | New module implementing a process-local cached NATS connection + JetStream context. |
| ami/ml/orchestration/jobs.py | Removes async with TaskQueueManager() usage in orchestration job helpers. |
| ami/jobs/views.py | Removes async with TaskQueueManager() usage in the tasks endpoint. |
| ami/jobs/tasks.py | Removes async with TaskQueueManager() usage when ACKing tasks from Celery. |
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@ami/ml/orchestration/nats_connection_pool.py`:
- Around line 104-112: get_pool() can race when multiple threads call it
concurrently; make initialization of the module-global _connection_pool
thread-safe by introducing a module-level threading.Lock (e.g.,
_connection_pool_lock) and use a double-checked locking pattern inside
get_pool(): check _connection_pool, acquire _connection_pool_lock, check again,
and only then instantiate ConnectionPool and assign to _connection_pool; ensure
you import threading and keep the global _connection_pool and lock declarations
at module scope so get_pool() uses them.
- Around line 88-97: The reset() method currently nulls out self._nc and
self._js and leaks an active NATS connection; modify reset (or add an async
reset_async) to attempt to close/drain the existing connection before clearing
references: if self._nc exists, call its close/drain routine (await if reset
becomes async, or schedule with asyncio.create_task(self._nc.close()) if keeping
reset synchronous) wrapped in try/except to swallow errors, then set self._nc =
None and self._js = None; update any callers (including the retry decorator) to
call the async version or rely on the scheduled background close so the old TCP
socket is not leaked.
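As an illustration of the two prompts above (thread-safe pool initialization and a reset that closes the old client before clearing it), here is a minimal sketch; the `ConnectionPool` body is reduced to the relevant parts and is not the PR's actual implementation.

```python
import threading


class ConnectionPool:
    """Reduced sketch: only the fields relevant to reset()."""

    def __init__(self):
        self._nc = None
        self._js = None

    async def reset(self):
        """Close the old client (best effort) before clearing cached references."""
        nc, self._nc, self._js = self._nc, None, None
        if nc is not None and not nc.is_closed:
            try:
                await nc.close()  # drain/close so the old TCP socket is not leaked
            except Exception:
                pass  # connection may already be broken


_connection_pool: ConnectionPool | None = None
_connection_pool_lock = threading.Lock()


def get_pool() -> ConnectionPool:
    """Return the module-global pool, creating it at most once across threads."""
    global _connection_pool
    if _connection_pool is None:  # fast path, no lock
        with _connection_pool_lock:
            if _connection_pool is None:  # double-check after acquiring the lock
                _connection_pool = ConnectionPool()
    return _connection_pool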
In `@ami/ml/orchestration/nats_queue.py`:
- Line 318: The log message contains a missing space between "job" and the job
id; update the logger.info call that logs stream deletion (the line using
logger.info with f"Deleted stream {stream_name} for job'{job_id}'") to insert a
space so it reads f"Deleted stream {stream_name} for job '{job_id}'", leaving
the surrounding code unchanged.
🧹 Nitpick comments (6)
ami/ml/orchestration/jobs.py (1)
107-109: Use `logger.exception` to preserve the stack trace. When a publish fails, the traceback is valuable for diagnosing whether it's a connection issue that exhausted retries, a serialization bug, etc.
Proposed fix
```diff
 except Exception as e:
-    logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}")
+    logger.exception(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}")
     success = False
```

ami/ml/orchestration/tests/test_nats_queue.py (2)
88-100: Consider adding a test for the `TimeoutError` path in `reserve_task`. This test mocks `fetch` to return `[]`, which exercises the `if msgs:` falsy branch. However, in practice, NATS's `fetch()` raises `nats_errors.TimeoutError` when no messages are available (handled at line 243 of `nats_queue.py`). Consider adding a test case where `fetch` raises `TimeoutError` to cover that code path as well.

Additional test case
```python
async def test_reserve_task_timeout(self):
    """Test reserve_task when fetch raises TimeoutError (no messages)."""
    from nats import errors as nats_errors

    with self._mock_nats_setup() as (_, js, _):
        mock_psub = MagicMock()
        mock_psub.fetch = AsyncMock(side_effect=nats_errors.TimeoutError)
        mock_psub.unsubscribe = AsyncMock()
        js.pull_subscribe = AsyncMock(return_value=mock_psub)

        manager = TaskQueueManager()
        task = await manager.reserve_task(123)

        self.assertIsNone(task)
        mock_psub.unsubscribe.assert_called_once()
```
49-131: No tests for the retry/backoff behavior. The `retry_on_connection_error` decorator is a core part of this PR. Consider adding at least one test that verifies a retried operation succeeds on a subsequent attempt after a connection error, and one that verifies the error is raised after exhausting retries. This would validate the most important new behavior introduced.
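For example, tests along these lines could exercise both paths. This is a sketch with assumptions: that the decorator is importable from `nats_queue`, that it treats `nats.errors.ConnectionClosedError` as retryable, and that it resets the pool via a `reset_connection` function in `ami.ml.orchestration.nats_connection` (the patch target may need adjusting to the real module layout).

```python
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, patch

from nats import errors as nats_errors

from ami.ml.orchestration.nats_queue import retry_on_connection_error


class RetryDecoratorTest(IsolatedAsyncioTestCase):
    async def test_retries_then_succeeds(self):
        calls = {"n": 0}

        @retry_on_connection_error(max_retries=2, backoff_seconds=0)
        async def flaky():
            calls["n"] += 1
            if calls["n"] == 1:
                raise nats_errors.ConnectionClosedError  # assumed to be retryable
            return "ok"

        # Patch target is an assumption; adjust to wherever the decorator resets the pool.
        with patch("ami.ml.orchestration.nats_connection.reset_connection", new=AsyncMock()):
            self.assertEqual(await flaky(), "ok")
        self.assertEqual(calls["n"], 2)

    async def test_raises_after_exhausting_retries(self):
        @retry_on_connection_error(max_retries=1, backoff_seconds=0)
        async def always_fails():
            raise nats_errors.ConnectionClosedError

        with patch("ami.ml.orchestration.nats_connection.reset_connection", new=AsyncMock()):
            with self.assertRaises(nats_errors.ConnectionClosedError):
                await always_fails()
```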
ami/ml/orchestration/nats_connection_pool.py (1)

50-58: Fast-path check outside the lock is fine for asyncio but worth a comment. The health check at line 51 and stale-connection cleanup at lines 55-58 happen outside the lock, relying on asyncio's cooperative scheduling (no preemption between awaits). This is correct but non-obvious. A brief inline comment would help future readers.
Suggested comment
```diff
 # Fast path: connection exists, is open, and is connected
+# Safe without lock: no await/yield between check and return (cooperative scheduling)
 if self._nc is not None and not self._nc.is_closed and self._nc.is_connected:
     return self._nc, self._js  # type: ignore
```

ami/ml/orchestration/nats_queue.py (2)
169-197: Redundant `_get_connection()` call — the `js` fetched at line 182 is also re-fetched inside `_ensure_stream` and `_ensure_consumer`.

`publish_task` fetches `js` at line 182, but `_ensure_stream` (line 124) and `_ensure_consumer` (line 144) each fetch their own `js` from the pool internally. The `js` from line 182 is only used at line 194. This works because the pool returns the same object, but the pattern is slightly misleading — it looks like the connection from line 182 is used throughout when it's not.

Consider either (a) passing `js` into `_ensure_stream`/`_ensure_consumer`, or (b) moving the `_get_connection()` call to after the ensure calls since that's where it's first needed. Same applies to `reserve_task`.
199-249: `reserve_task`: implicit `None` return when `msgs` is empty. If `psub.fetch()` returns an empty list (rather than raising `TimeoutError`), the function falls through the `if msgs:` block and `except` clause, runs `finally`, and returns `None` implicitly. This is technically correct per the return type, but an explicit `return None` after the `if msgs:` block would make the intent clearer.

Suggested improvement
```diff
             if msgs:
                 msg = msgs[0]
                 task_data = json.loads(msg.data.decode())
                 metadata = msg.metadata

                 # Parse the task data into PipelineProcessingTask
                 task = PipelineProcessingTask(**task_data)
                 # Set the reply_subject for acknowledgment
                 task.reply_subject = msg.reply

                 logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}")
                 return task
+            return None
         except nats_errors.TimeoutError:
```
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Hi Carlos, here is a Claude-written comment after I spent some time with it reviewing these changes. Ha. The retry decorator is a great addition — the nats.py client doesn't buffer operations during reconnection (unlike the Go client), so application-level retry is genuinely needed. I think we can simplify the connection pooling, it may be unnecessary for our scale after all. We run ~3 Celery workers on separate VMs. Since More importantly, the event-loop-keyed
Suggested simplification: Keep the retry decorator (real value) and drop the pool. The connection scoping is already natural:
Concretely, this would mean:
This keeps the retry logic (the real fix) while avoiding the event-loop-keyed pool complexity. What do you think?
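For reference, a minimal sketch of what that context-manager scoping could look like, assuming `TaskQueueManager` gains `__aenter__`/`__aexit__` and a hard-coded default URL; the publish helper and subject naming here are illustrative, not the project's real API.

```python
import nats


class TaskQueueManager:
    """Sketch: one NATS connection per `async with` block."""

    def __init__(self, nats_url: str = "nats://localhost:4222"):  # assumed default
        self._nats_url = nats_url
        self._nc = None
        self._js = None

    async def __aenter__(self) -> "TaskQueueManager":
        self._nc = await nats.connect(self._nats_url)
        self._js = self._nc.jetstream()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._nc is not None and not self._nc.is_closed:
            await self._nc.close()
        self._nc = None
        self._js = None

    async def publish_raw(self, subject: str, payload: bytes) -> None:
        """Illustrative stand-in for publish_task(); requires an open block."""
        await self._js.publish(subject, payload)


async def queue_all_images(job_id: int, payloads: list[bytes]) -> None:
    # One TCP connection for the whole loop instead of one per publish.
    async with TaskQueueManager() as manager:
        for payload in payloads:
            await manager.publish_raw(f"tasks.job.{job_id}", payload)
```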
…rop pool Replace the event-loop-keyed WeakKeyDictionary connection pool with a straightforward async context manager on TaskQueueManager. Each async_to_sync() call now scopes one connection for its block of operations (e.g. queue_all_images reuses one connection for all publishes, _ack_task_via_nats gets one for the single ACK). The retry decorator is preserved — on connection error it closes the stale connection so the next _get_connection() creates a fresh one. Also adds reconnected_cb/disconnected_cb logging callbacks to nats.connect() and narrows bare except clauses to NotFoundError. Co-Authored-By: Claude <noreply@anthropic.com>
ami/ml/orchestration/nats_queue.py (outdated)

```
    await manager.publish_task(job_id, task)

The connection is created on entry and closed on exit. Within the block, the retry
decorator handles transient connection errors by clearing and recreating the connection.
```
@mihow - The downside of this approach is that there will still be a large number of connections opened and closed, which was one of the original problems we were trying to solve. The TaskQueueManager is used in 3 scenarios (assume a 1000-image job):
- Queuing up all tasks for a job: Here we do ok since we can queue all 1000 images in one async context, so 1 connection is used.
- Fetching `/tasks`: The workers will call this at least 1000/batch-size times, so we'll likely have 250-500 connections.
- Acknowledging results: Each result is saved/acked on a separate Celery task, so this will be 1000 connections.
Thanks for breaking that down @carlosgjs. I am fine to go with your implementation, but I got a little nervous. It seems like maybe you were fighting Django's async_to_sync() stuff, and I was just thinking about the first scenario (Queuing up all tasks for a job).
@carlosgjs I reverted back to your implementation, but I got some help to add more comments & docstrings so I know what's going on.
I just read your/claude's comment. Makes sense, I agree it's fine to keep the simplified approach and only optimize the connections if needed later.
Mark connection handling as done (PR RolnickLab#1130), add worktree/remote mapping and docker testing notes for future sessions. Co-Authored-By: Claude <noreply@anthropic.com>
… churn Reverts c384199 which replaced the event-loop-keyed connection pool with a plain async context manager. The context manager approach opened and closed a TCP connection per async block, causing ~1500 connections per 1000-image job (250-500 for task fetches, 1000 for ACKs). The connection pool keeps one persistent connection per event loop and reuses it across all TaskQueueManager operations. Co-Authored-By: Claude <noreply@anthropic.com>
Extract connection pool to a pluggable design with two strategies:
- "pool" (default): persistent connection reuse per event loop
- "per_operation": fresh TCP connection each time, for debugging

Controlled by NATS_CONNECTION_STRATEGY Django setting. Both strategies implement the same interface (get_connection, reset, close) so TaskQueueManager is agnostic to which one is active.

Changes:
- Rename nats_connection_pool.py to nats_connection.py
- Rename get_pool() to get_provider()
- Use settings.NATS_URL directly instead of getattr with divergent defaults
- Narrow except clauses in _ensure_stream/_ensure_consumer to NotFoundError
- Add _js guard to fast path, add strategy logging
- Enhanced module and class docstrings

Co-Authored-By: Claude <noreply@anthropic.com>
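A rough sketch of the interface that commit describes (the method names, strategy names, and `NATS_CONNECTION_STRATEGY` setting come from the commit message; the class bodies here are simplified assumptions):

```python
from typing import Protocol

import nats


class ConnectionProvider(Protocol):
    """Shape shared by both strategies: "pool" and "per_operation"."""

    async def get_connection(self): ...  # returns (nc, js)
    async def reset(self) -> None: ...
    async def close(self) -> None: ...


class PerOperationConnection:
    """'per_operation' strategy: a fresh client on every call (useful for debugging)."""

    def __init__(self, url: str):
        self._url = url
        self._nc = None

    async def get_connection(self):
        self._nc = await nats.connect(self._url)
        return self._nc, self._nc.jetstream()

    async def reset(self) -> None:
        await self.close()

    async def close(self) -> None:
        if self._nc is not None and not self._nc.is_closed:
            await self._nc.close()
        self._nc = None


# Selection sketch (setting name from the commit message):
#   strategy = settings.NATS_CONNECTION_STRATEGY  # "pool" (default) or "per_operation"
#   provider: ConnectionProvider = pool if strategy == "pool" else PerOperationConnection(settings.NATS_URL)
```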
Remove the switchable strategy pattern (Protocol + factory + Django setting) and expose the connection pool directly via module-level get_connection() and reset_connection() functions. The PerOperationConnection is archived as ContextManagerConnection for reference. Remove NATS_CONNECTION_STRATEGY setting. Co-Authored-By: Claude <noreply@anthropic.com>
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@ami/ml/orchestration/nats_connection.py`:
- Around line 186-192: The except block that catches RuntimeError from
asyncio.get_running_loop() should preserve the original exception chain: capture
the caught exception (e.g., "except RuntimeError as err:") and re-raise the
RuntimeError with your existing message using "raise ... from err" so the
original traceback is linked; update the block inside get_connection()/where
get_running_loop() is called to use "except RuntimeError as err" and "raise
RuntimeError(... ) from err".
- Around line 1-27: The pool keyed by event loop (ConnectionPool /
WeakKeyDictionary) doesn't guarantee reuse when callers use async_to_sync(), so
either confirm your runtime actually reuses the same event loop across
async_to_sync() or switch to a per-call connection strategy: remove/stop using
the loop-keyed pool lookup in ConnectionPool (and associated WeakKeyDictionary),
instead create and close a fresh nats.Client per operation (or default to
ContextManagerConnection) and rely on the existing retry_on_connection_error /
reset_connection logic for resilience; update any helper functions that call
ConnectionPool.create/get to use the per-call connect/close flow and keep
ContextManagerConnection as the simple fallback.
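The async_to_sync concern can be demonstrated in isolation: each call from plain sync code runs the coroutine on its own event loop, so a pool keyed by event loop won't be shared across such calls. A small self-contained check (not project code):

```python
import asyncio

from asgiref.sync import async_to_sync


async def current_loop() -> asyncio.AbstractEventLoop:
    return asyncio.get_running_loop()


# When invoked from plain sync code (e.g. a Celery task), each async_to_sync()
# call typically runs the coroutine on a fresh event loop, so an event-loop-keyed
# pool cannot reuse a connection between these two calls.
first = async_to_sync(current_loop)()
second = async_to_sync(current_loop)()
print(first is second)  # usually False: each call got its own loop
```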
In `@ami/ml/orchestration/nats_queue.py`:
- Around line 253-276: The finally block currently awaits psub.unsubscribe()
which can raise and mask earlier exceptions and also makes the empty-msgs path
implicit; change the flow so that after calling psub.fetch(1, timeout=timeout)
you explicitly return None if msgs is falsy, and move the unsubscribe call into
its own try/except that catches and logs/unwarns errors without suppressing the
original exception (i.e., if an exception is active, call psub.unsubscribe()
inside a nested try/except and re-raise the original exception; if no exception,
safely await unsubscribe and log any unsubscribe errors). Update the block that
constructs PipelineProcessingTask and the handling around
psub.fetch/psub.unsubscribe accordingly so psub.unsubscribe failures do not hide
earlier exceptions.
- Line 217: Replace the deprecated Pydantic v1 call: in the line that builds
task_data (task_data = json.dumps(data.dict())), call data.model_dump() instead
so it uses Pydantic v2; update the expression that creates task_data in
nats_queue.py (the variable task_data and the object data) to use model_dump()
to avoid deprecation warnings and preserve the same serialized structure.
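For illustration, the Pydantic v1 → v2 change is a one-method swap. A self-contained sketch (the model fields here are made up, not the real `PipelineProcessingTask`):

```python
import json

from pydantic import BaseModel


class PipelineProcessingTask(BaseModel):  # simplified stand-in, fields are illustrative
    job_id: int
    image_id: int


data = PipelineProcessingTask(job_id=1, image_id=42)

# Pydantic v1 style (deprecated in v2): json.dumps(data.dict())
# Pydantic v2 replacement — same serialized structure, no deprecation warning:
task_data = json.dumps(data.model_dump())
```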
🧹 Nitpick comments (4)
ami/ml/orchestration/nats_connection.py (3)
59-68: Race between `reset()` and `_ensure_lock()` when `self._lock` is set to `None`.

`reset()` (line 137) sets `self._lock = None`. If a concurrent coroutine has already passed the `self._lock is None` check at line 66 but hasn't yet assigned the new lock, `reset()` could clear the newly-created lock. In practice, within a single-threaded asyncio event loop, cooperative scheduling means there's no true interleaving between lines 66–67. However, `_ensure_lock` should still be tightened to avoid fragility if `reset()` ever runs concurrently (e.g., from a callback):

Suggested improvement
```diff
 async def reset(self):
     ...
     self._nc = None
     self._js = None
-    self._lock = None  # Clear lock so new one is created for fresh connection
+    self._lock = None  # Will be lazily recreated on next get_connection()
```

Consider documenting that `reset()` must only be called from the same event loop that owns this pool (which is already implied by the design).
80-107: Fast-path state mutation outside the lock may confuse future maintainers.

Lines 86–89 set `self._nc = None` and `self._js = None` before acquiring the lock. While safe in a cooperative single-threaded asyncio context (the double-check at line 95 handles the reconnection), this pattern is unusual for a lock-protected resource and could lead to bugs if the code is ever adapted for truly concurrent access. Consider moving the clearing inside the lock:

Suggested refactor
```diff
-        # Connection is stale or doesn't exist — clear references before reconnecting
-        if self._nc is not None:
-            logger.warning("NATS connection is closed or disconnected, will reconnect")
-            self._nc = None
-            self._js = None
-
         # Slow path: acquire lock to prevent concurrent reconnection attempts
         lock = self._ensure_lock()
         async with lock:
             # Double-check after acquiring lock (another coroutine may have reconnected)
             if self._nc is not None and self._js is not None and not self._nc.is_closed and self._nc.is_connected:
                 return self._nc, self._js
+
+            # Connection is stale or doesn't exist — clear references before reconnecting
+            if self._nc is not None:
+                logger.warning("NATS connection is closed or disconnected, will reconnect")
+                self._nc = None
+                self._js = None
+
             nats_url = settings.NATS_URL
```
140-175: `ContextManagerConnection` is dead code — consider removing it.

This class is described as "archived" and "kept as a drop-in fallback," but it is never referenced anywhere in the codebase. Keeping dead code increases the maintenance surface. If it's valuable as documentation, a comment or ADR would serve better.
```bash
#!/bin/bash
# Verify ContextManagerConnection is not used anywhere
rg -n "ContextManagerConnection" --type=py
```

ami/ml/orchestration/nats_queue.py (1)
196-224: `_ensure_stream` and `_ensure_consumer` are called on every `publish_task` invocation.

For a 1000-image job, this results in ~2000 extra `stream_info`/`consumer_info` round trips to NATS, even though the stream and consumer only need to be created once. Consider caching the "ensured" status per `job_id` for the lifetime of the event loop (e.g., a set on the manager or pool), or moving the ensure calls to a separate setup method invoked once before the publish loop.

Sketch
```diff
+    _ensured_jobs: set[int] = set()  # class-level or instance-level cache
+
+    async def _ensure_resources(self, job_id: int):
+        if job_id in self._ensured_jobs:
+            return
+        await self._ensure_stream(job_id)
+        await self._ensure_consumer(job_id)
+        self._ensured_jobs.add(job_id)
+
     @retry_on_connection_error(max_retries=2, backoff_seconds=0.5)
     async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool:
         _, js = await self._get_connection()
-        await self._ensure_stream(job_id)
-        await self._ensure_consumer(job_id)
+        await self._ensure_resources(job_id)
         ...
```
```python
try:
    loop = asyncio.get_running_loop()
except RuntimeError:
    raise RuntimeError(
        "get_connection() must be called from an async context with a running event loop. "
        "If calling from sync code, use async_to_sync() to wrap the async function."
    )
```
Use raise ... from err in the except clause (B904).
The static analysis hint is valid. This ensures the exception chain is preserved.
Fix
```diff
 try:
     loop = asyncio.get_running_loop()
 except RuntimeError:
-    raise RuntimeError(
+    raise RuntimeError(
         "get_connection() must be called from an async context with a running event loop. "
         "If calling from sync code, use async_to_sync() to wrap the async function."
-    )
+    ) from None
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
try:
    loop = asyncio.get_running_loop()
except RuntimeError:
    raise RuntimeError(
        "get_connection() must be called from an async context with a running event loop. "
        "If calling from sync code, use async_to_sync() to wrap the async function."
    ) from None
```
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 189-192: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
[warning] 189-192: Avoid specifying long messages outside the exception class
(TRY003)
In this case, we are catching the exception and returning our own. Wouldn't the exception chain just add noise?
@mihow You're absolutely right that the chain would add noise here! Using from None (as suggested) explicitly suppresses the exception chain, which is exactly what you want.
Here's why from None is the right choice:
- The original exception doesn't add value — `asyncio.get_running_loop()` just raises "no running event loop", which your custom message already explains more clearly with actionable advice.
- Signals intent — `from None` tells readers "I'm intentionally replacing this exception, not accidentally swallowing context."
- Satisfies the linter — B904 requires explicit chaining syntax (`from err` or `from None`) to distinguish intentional replacement from accidental context loss.
Without from None, the linter can't tell if you deliberately chose not to chain or simply forgot. The suggested fix makes your intent explicit while keeping the output clean.
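A small self-contained illustration of the two cases (hypothetical code, not from this PR):

```python
import asyncio
import json


class ConfigError(Exception):
    pass


def load_config(path: str) -> dict:
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError as err:
        # The original error carries useful context (which path, which layer failed):
        # chain it explicitly so tracebacks show both.
        raise ConfigError(f"Missing config file: {path}") from err


def require_running_loop() -> asyncio.AbstractEventLoop:
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        # The original "no running event loop" message adds nothing the new one
        # doesn't already say: suppress the chain intentionally.
        raise RuntimeError("Call this from async code, or wrap it with async_to_sync().") from None
```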
✏️ Learnings added
Learnt from: mihow
Repo: RolnickLab/antenna PR: 1130
File: ami/ml/orchestration/nats_connection.py:186-192
Timestamp: 2026-02-13T23:12:21.720Z
Learning: In Python exception handling, use `raise ... from None` when intentionally replacing an exception with a more informative one where the original exception would add noise. Use `raise ... from err` when the original exception provides valuable debugging context. The B904 linting rule requires explicit chaining syntax to distinguish intentional replacement from accidental context loss.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
```diff
     try:
         # Fetch a single message
         msgs = await psub.fetch(1, timeout=timeout)

-        # Parse the task data into PipelineProcessingTask
-        task = PipelineProcessingTask(**task_data)
-        # Set the reply_subject for acknowledgment
-        task.reply_subject = msg.reply
+        if msgs:
+            msg = msgs[0]
+            task_data = json.loads(msg.data.decode())
+            metadata = msg.metadata

-        logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}")
-        return task
+            # Parse the task data into PipelineProcessingTask
+            task = PipelineProcessingTask(**task_data)
+            # Set the reply_subject for acknowledgment
+            task.reply_subject = msg.reply

-    except nats.errors.TimeoutError:
-        # No messages available
-        logger.debug(f"No tasks available in stream for job '{job_id}'")
-        return None
-    finally:
-        # Always unsubscribe
-        await psub.unsubscribe()
+            logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}")
+            return task

-    except Exception as e:
-        logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}")
+    except nats_errors.TimeoutError:
+        # No messages available (expected behavior)
+        logger.debug(f"No tasks available in stream for job '{job_id}'")
+        return None
+    finally:
+        # Always unsubscribe
+        await psub.unsubscribe()
```
psub.unsubscribe() failure in finally can mask the original exception.
If unsubscribe() raises (e.g., on a broken connection), it will mask whatever error was being propagated. Also note: when fetch returns an empty list (no messages and no timeout), the function falls through the if msgs: block and try without returning, reaching finally and then implicitly returning None. This works but is subtle.
Suggested improvement
```diff
 finally:
     # Always unsubscribe
-    await psub.unsubscribe()
+    try:
+        await psub.unsubscribe()
+    except Exception:
+        logger.debug("Failed to unsubscribe pull subscription", exc_info=True)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
try:
    # Fetch a single message
    msgs = await psub.fetch(1, timeout=timeout)

    if msgs:
        msg = msgs[0]
        task_data = json.loads(msg.data.decode())
        metadata = msg.metadata

        # Parse the task data into PipelineProcessingTask
        task = PipelineProcessingTask(**task_data)
        # Set the reply_subject for acknowledgment
        task.reply_subject = msg.reply

        logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}")
        return task

except nats_errors.TimeoutError:
    # No messages available (expected behavior)
    logger.debug(f"No tasks available in stream for job '{job_id}'")
    return None
finally:
    # Always unsubscribe
    try:
        await psub.unsubscribe()
    except Exception:
        logger.debug("Failed to unsubscribe pull subscription", exc_info=True)
```
The @retry_on_connection_error decorator wrapping this method already catches connection-level errors (the likely failure mode for unsubscribe()) and handles them with reset + backoff. If we suppress unsubscribe errors here, we'd actually hide a signal that the retry decorator needs to trigger reconnection. The current finally pattern lets cleanup errors propagate to the right layer.
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
The docstring previously implied the pool reused connections across all async_to_sync() calls. In practice, each async_to_sync() creates a new event loop, so reuse only happens within a single boundary. Updated to be explicit about where the pool helps (bulk publishes, batch reserves) and where it doesn't (single-operation calls like ACKs). Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Claude <noreply@anthropic.com>
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@ami/ml/orchestration/nats_connection.py`:
- Around line 122-142: The reset() method currently sets self._lock = None which
allows another coroutine to create a new lock while reset is awaiting
self._nc.close(), enabling two coroutines to enter the critical "slow path" in
get_connection(); fix this by removing the line that clears the lock and only
reset connection state (keep self._lock intact), i.e., in reset() clear/close
self._nc and self._js but do not assign self._lock = None so _ensure_lock()
continues to use the same event-loop-bound lock instance.
🧹 Nitpick comments (1)
ami/ml/orchestration/nats_connection.py (1)
85-112: Add an `is_reconnecting` check to prevent discarding clients during automatic reconnection.

The fast-path check on line 87 rejects clients in the RECONNECTING state (`is_connected=False`), causing them to be cleared and forcibly reconnected instead of allowing nats.py's automatic reconnection logic to complete. This creates redundant TCP connections during temporary outages.

The suggested fix is correct and prevents this regression:
Suggested change
```diff
-        if self._nc is not None and self._js is not None and not self._nc.is_closed and self._nc.is_connected:
+        if self._nc is not None and self._js is not None and not self._nc.is_closed and (self._nc.is_connected or self._nc.is_reconnecting):
             return self._nc, self._js
```

Note: The same pattern exists in the double-check at line 100 and should also be updated for consistency.
```python
async def reset(self):
    """
    Close the current connection and clear all state so the next call to
    get_connection() creates a fresh one.

    Called by retry_on_connection_error when an operation hits a connection
    error (e.g. network blip, NATS restart). The lock is also cleared so it
    gets recreated bound to the current event loop.
    """
    logger.warning("Resetting NATS connection pool due to connection error")
    if self._nc is not None:
        try:
            if not self._nc.is_closed:
                await self._nc.close()
                logger.debug("Successfully closed existing NATS connection during reset")
        except Exception as e:
            # Swallow errors - connection may already be broken
            logger.debug(f"Error closing connection during reset (expected): {e}")
    self._nc = None
    self._js = None
    self._lock = None  # Clear lock so new one is created for fresh connection
```
self._lock = None on line 142 can cause concurrent coroutines to bypass mutual exclusion.
When reset() awaits self._nc.close() (line 135), it yields control. If another coroutine calls get_connection() during that yield, _ensure_lock() sees self._lock is None (set on line 142 — or even before, if ordering shifts) and creates a new lock. The original lock held by the first coroutine is now a different object, so two coroutines can enter the "slow path" critical section simultaneously and race to create connections.
Instead of clearing the lock, keep it and only reset the connection state:
Suggested fix
```diff
 async def reset(self):
     logger.warning("Resetting NATS connection pool due to connection error")
-    if self._nc is not None:
-        try:
-            if not self._nc.is_closed:
-                await self._nc.close()
-                logger.debug("Successfully closed existing NATS connection during reset")
-        except Exception as e:
-            # Swallow errors - connection may already be broken
-            logger.debug(f"Error closing connection during reset (expected): {e}")
-    self._nc = None
-    self._js = None
-    self._lock = None  # Clear lock so new one is created for fresh connection
+    lock = self._ensure_lock()
+    async with lock:
+        if self._nc is not None:
+            try:
+                if not self._nc.is_closed:
+                    await self._nc.close()
+                    logger.debug("Successfully closed existing NATS connection during reset")
+            except Exception as e:
+                logger.debug(f"Error closing connection during reset (expected): {e}")
+        self._nc = None
+        self._js = None
```

The lock is bound to the event loop at creation time. Since reset() runs on the same event loop as get_connection(), there's no cross-loop issue — the same lock instance remains valid throughout the pool's lifetime on that loop.
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 137-137: Do not catch blind exception: Exception
(BLE001)
The _setup_mock_nats helper was configuring TaskQueueManager as an async context manager (__aenter__/__aexit__), but _ack_task_via_nats uses plain instantiation. The await on a non-awaitable MagicMock failed silently in the except clause, causing acknowledge_task assertions to always fail. Co-Authored-By: Claude <noreply@anthropic.com>
Summary
This pull request improves NATS JetStream task queue management for better connection reliability, efficiency, and error handling. The main change is the introduction of a process-local NATS connection pool, which replaces the previous pattern of creating and closing a connection for every operation. The code now uses retry logic with exponential backoff for all NATS operations. The context manager pattern for `TaskQueueManager` is removed, all methods are updated to use the shared connection pool, and several methods are now decorated to automatically retry on connection errors.
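For example, call sites now look roughly like this (a sketch; the exact `reserve_task` parameters and import paths should be checked against the code):

```python
from asgiref.sync import async_to_sync

from ami.ml.orchestration.nats_queue import TaskQueueManager


async def _reserve_next_task(job_id: int):
    # No `async with` needed any more: the connection comes from the shared pool,
    # and the retry decorator on reserve_task handles transient connection errors.
    manager = TaskQueueManager()
    return await manager.reserve_task(job_id, timeout=5)  # timeout value is illustrative


def reserve_next_task(job_id: int):
    # Sync call sites (Django views, Celery tasks) wrap the coroutine.
    return async_to_sync(_reserve_next_task)(job_id)
```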
Testing

Tested locally with multiple runs of 100 images, verifying tasks are acknowledged and NATS resources are cleaned up.

Checklist
Summary by CodeRabbit
New Features
Refactor
Bug Fixes
Tests
Documentation