18 changes: 17 additions & 1 deletion pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import gc
import json
import logging
@@ -814,6 +816,12 @@ def control_weight_fn(_: int) -> float:

self.control_str = last_control

# Clean up memory after test_all() which creates temporary PromptManagers
del model_tests
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

return self.control_str, loss, steps

def test(
@@ -1644,7 +1652,12 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]
ob, fn, args, kwargs = task
if fn == "grad":
with torch.enable_grad(): # type: ignore[no-untyped-call, unused-ignore]
results.put(ob.grad(*args, **kwargs))
result = ob.grad(*args, **kwargs)
results.put(result)
del result
# Clear CUDA cache after gradient computation to prevent memory accumulation
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
with torch.no_grad():
if fn == "logits":
Expand All @@ -1657,6 +1670,9 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]
results.put(ob.test_loss(*args, **kwargs))
else:
results.put(fn(*args, **kwargs))
# Clean up the task object to free memory
del ob

Copilot AI Feb 6, 2026

del ob here doesn’t immediately free the task payload because the original task tuple still holds a reference to ob until the next loop iteration. If the intent is to drop references before gc.collect(), also del task (and potentially args/kwargs) before calling gc.collect().

Suggested change
- del ob
+ del ob
+ del task
+ del args
+ del kwargs
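A tiny illustration of the point above (illustrative only, not code from the PR): deleting the local name does not free the payload while the task tuple still references it.

```python
import sys

payload = bytearray(10**6)         # stands in for the large object held by `ob`
task = (payload, "grad", (), {})   # the queue item keeps its own reference

ob = task[0]
del ob                             # drops only the local name
print(sys.getrefcount(task[0]))    # the tuple still keeps the payload alive

del task                           # dropping the tuple is what actually releases it
```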

gc.collect()
tasks.task_done()
Comment on lines +1674 to 1676

Copilot AI Feb 6, 2026

Calling gc.collect() on every task processed can be a major CPU-side bottleneck, especially in long GCG runs where the worker loop is hot. Consider collecting less frequently (e.g., every N tasks or only after known large allocations like grad) or making it configurable.
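A minimal sketch of the throttled collection suggested here, assuming a hypothetical GC_INTERVAL constant and counter that are not part of this PR:

```python
import gc

# Sketch only: GC_INTERVAL and the counter are illustrative, not from the PR.
GC_INTERVAL = 50
_tasks_since_gc = 0

def maybe_collect(force: bool = False) -> None:
    """Run gc.collect() every GC_INTERVAL calls, or immediately when force=True
    (e.g., right after a large allocation such as a 'grad' task)."""
    global _tasks_since_gc
    _tasks_since_gc += 1
    if force or _tasks_since_gc >= GC_INTERVAL:
        gc.collect()
        _tasks_since_gc = 0
```

Calling something like maybe_collect(force=(fn == "grad")) in the worker loop would tie the expensive collections to the allocations that actually matter.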


def start(self) -> "ModelWorker":
20 changes: 19 additions & 1 deletion pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
@@ -68,7 +68,17 @@ def token_gradients(

loss.backward()

return one_hot.grad.clone()
# Clone and detach the gradient to break the computation graph
grad = one_hot.grad.clone().detach()

# Explicitly clear references to free memory
del one_hot, input_embeds, embeds, full_embeds, logits, targets, loss

# Clear CUDA cache to release GPU memory
if torch.cuda.is_available():
torch.cuda.empty_cache()

Comment on lines +77 to +80

Copilot AI Feb 6, 2026

torch.cuda.empty_cache() inside token_gradients() will run on every gradient call (hot path) and can introduce significant synchronization/throughput overhead. Consider making cache eviction conditional (e.g., behind a flag, every N iterations, or based on torch.cuda.memory_reserved()/max_memory_allocated() thresholds) rather than unconditionally emptying the cache each call.

Suggested change
- # Clear CUDA cache to release GPU memory
- if torch.cuda.is_available():
-     torch.cuda.empty_cache()
+ # Conditionally clear CUDA cache to mitigate memory pressure without
+ # incurring synchronization overhead on every gradient computation.
+ if torch.cuda.is_available():
+     device = getattr(model, "device", torch.device("cuda"))
+     try:
+         reserved_memory: int = torch.cuda.memory_reserved(device)
+         total_memory: int = torch.cuda.get_device_properties(device).total_memory
+     except Exception:
+         reserved_memory = 0
+         total_memory = 1
+     if total_memory > 0 and reserved_memory / total_memory > 0.9:
+         torch.cuda.empty_cache()

return grad


class GCGAttackPrompt(AttackPrompt):
@@ -144,9 +154,11 @@ def step(
j - 1, control_cand, filter_cand=filter_cand, curr_control=self.control_str
)
)
del grad # Explicitly delete old grad before reassignment
grad = new_grad
else:
grad += new_grad
del new_grad # Clean up new_grad after use

with torch.no_grad():
control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
@@ -155,6 +167,8 @@
)
del grad, control_cand
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Search
loss = torch.zeros(len(control_cands) * batch_size).to(main_device)
@@ -192,6 +206,10 @@
f"loss={loss[j * batch_size : (j + 1) * batch_size].min().item() / (i + 1):.4f}" # type: ignore[operator]
)

# Periodically clear CUDA cache during search to prevent memory buildup
if torch.cuda.is_available():
torch.cuda.empty_cache()
Comment on lines +209 to +211

Copilot AI Feb 6, 2026

The comment says this is a periodic CUDA cache clear, but the code clears the cache unconditionally for every cand iteration. Either update the comment to match reality or add an actual periodic condition (e.g., every N candidates/steps) to avoid unnecessary cache thrash.
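A minimal sketch of an actual periodic condition, assuming a hypothetical CACHE_CLEAR_EVERY constant that is not part of this PR:

```python
import torch

# Sketch only: CACHE_CLEAR_EVERY is an illustrative constant, not from the PR.
CACHE_CLEAR_EVERY = 8

def maybe_empty_cache(iteration: int) -> None:
    """Clear the CUDA cache only on every CACHE_CLEAR_EVERY-th iteration."""
    if torch.cuda.is_available() and iteration % CACHE_CLEAR_EVERY == 0:
        torch.cuda.empty_cache()
```

Calling maybe_empty_cache(j) from the candidate loop (assuming j is the loop index at the flagged spot) would make the code match the "periodically" wording; otherwise the comment could simply say the cache is cleared on every candidate.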


min_idx = loss.argmin()
model_idx = min_idx // batch_size
batch_idx = min_idx % batch_size