Fix GCG OOM on long runs by detaching gradients & explicit cleanup (#961) #1324
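For context before the raw diffs: a minimal, generic sketch of the pattern this PR applies, namely clone-and-detach the gradient so nothing returned holds autograd state, drop local references, then release cached GPU memory. All names below are illustrative; the actual changes follow in the diffs.

```python
import gc

import torch


def gradient_with_cleanup(one_hot: torch.Tensor, loss: torch.Tensor) -> torch.Tensor:
    """Backpropagate, then return a gradient copy that carries no autograd history."""
    loss.backward()
    # clone() copies the accumulated gradient into fresh storage and detach()
    # guarantees the returned tensor carries no autograd history.
    grad = one_hot.grad.clone().detach()

    # Drop the local references and force a collection so the tensors they
    # name can be reclaimed promptly rather than at some later GC pass.
    del one_hot, loss
    gc.collect()

    # Hand cached-but-unused blocks back to the CUDA driver (no-op on CPU-only runs).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return grad
```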
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
```diff
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
 from __future__ import annotations
+
+import gc
 import json
 import logging
```
```diff
@@ -814,6 +816,12 @@ def control_weight_fn(_: int) -> float:
         self.control_str = last_control
 
+        # Clean up memory after test_all() which creates temporary PromptManagers
+        del model_tests
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return self.control_str, loss, steps
 
     def test(
```
```diff
@@ -1644,7 +1652,12 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]
             ob, fn, args, kwargs = task
             if fn == "grad":
                 with torch.enable_grad():  # type: ignore[no-untyped-call, unused-ignore]
-                    results.put(ob.grad(*args, **kwargs))
+                    result = ob.grad(*args, **kwargs)
+                    results.put(result)
+                    del result
+                    # Clear CUDA cache after gradient computation to prevent memory accumulation
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
             else:
                 with torch.no_grad():
                     if fn == "logits":
```
|
```diff
@@ -1657,6 +1670,9 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]
                         results.put(ob.test_loss(*args, **kwargs))
                     else:
                         results.put(fn(*args, **kwargs))
+            # Clean up the task object to free memory
+            del ob
+            gc.collect()
             tasks.task_done()
 
     def start(self) -> "ModelWorker":
```
```diff
@@ -68,7 +68,17 @@ def token_gradients(
     loss.backward()
 
-    return one_hot.grad.clone()
+    # Clone and detach the gradient to break the computation graph
+    grad = one_hot.grad.clone().detach()
+
+    # Explicitly clear references to free memory
+    del one_hot, input_embeds, embeds, full_embeds, logits, targets, loss
+
+    # Clear CUDA cache to release GPU memory
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
```
Comment on lines +77 to +80, with a suggested change:

```diff
-    # Clear CUDA cache to release GPU memory
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    # Conditionally clear CUDA cache to mitigate memory pressure without
+    # incurring synchronization overhead on every gradient computation.
+    if torch.cuda.is_available():
+        device = getattr(model, "device", torch.device("cuda"))
+        try:
+            reserved_memory: int = torch.cuda.memory_reserved(device)
+            total_memory: int = torch.cuda.get_device_properties(device).total_memory
+        except Exception:
+            reserved_memory = 0
+            total_memory = 1
+        if total_memory > 0 and reserved_memory / total_memory > 0.9:
+            torch.cuda.empty_cache()
```
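Not part of the PR: the sketch below only factors the suggested threshold check into a standalone helper to show the idea in isolation. The name `maybe_empty_cache`, the `Optional[torch.device]` parameter, and the 0.9 default are illustrative choices, not repository code.

```python
from typing import Optional

import torch


def maybe_empty_cache(device: Optional[torch.device] = None, threshold: float = 0.9) -> None:
    """Clear the CUDA cache only when reserved memory exceeds `threshold` of the device total."""
    if not torch.cuda.is_available():
        return
    dev = device if device is not None else torch.device("cuda")
    try:
        reserved = torch.cuda.memory_reserved(dev)
        total = torch.cuda.get_device_properties(dev).total_memory
    except Exception:
        # If the memory query fails, skip the clear rather than guessing.
        return
    if total > 0 and reserved / total > threshold:
        torch.cuda.empty_cache()
```

Gating the call this way keeps PyTorch's caching allocator effective in the common case and only pays the synchronization cost of `empty_cache()` when the device is genuinely close to full.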
Copilot AI commented on Feb 6, 2026:
The comment says this is a periodic CUDA cache clear, but the code clears the cache unconditionally for every cand iteration. Either update the comment to match reality or add an actual periodic condition (e.g., every N candidates/steps) to avoid unnecessary cache thrash.
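A minimal sketch of the "every N candidates/steps" option the comment mentions; `maybe_clear_cache`, `step`, and the interval of 32 are hypothetical names and values, not identifiers from this PR.

```python
import torch


def maybe_clear_cache(step: int, interval: int = 32) -> None:
    """Release cached CUDA blocks only every `interval` steps instead of on every candidate."""
    if interval <= 0 or step <= 0:
        return
    if torch.cuda.is_available() and step % interval == 0:
        torch.cuda.empty_cache()
```

Calling this with the loop counter (for example, the candidate index) makes the cache clear genuinely periodic and keeps the in-code comment honest.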
On lines +1674 to 1676:
`del ob` here doesn't immediately free the task payload because the original `task` tuple still holds a reference to `ob` until the next loop iteration. If the intent is to drop references before `gc.collect()`, also `del task` (and potentially `args`/`kwargs`) before calling `gc.collect()`.
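A sketch of a worker loop with the reviewer's fix applied, simplified to the string-dispatch case only; `worker_loop` is a stand-in for the `run()` worker shown in the diff above, and the shutdown sentinel is illustrative.

```python
import gc
import multiprocessing as mp
from typing import Any


def worker_loop(tasks: "mp.JoinableQueue[Any]", results: "mp.JoinableQueue[Any]") -> None:
    """Abbreviated worker loop: drop every reference to the payload before collecting."""
    while True:
        task = tasks.get()
        if task is None:  # illustrative shutdown sentinel
            tasks.task_done()
            break
        ob, fn, args, kwargs = task
        results.put(getattr(ob, fn)(*args, **kwargs))
        # `del ob` alone leaves the payload reachable through `task`, so delete
        # the tuple (and args/kwargs) as well before forcing a collection.
        del ob, fn, args, kwargs, task
        gc.collect()
        tasks.task_done()
```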