Skip to content

Commit e5ae670

Browse files
authored
Update ace15.py to allow min_p sampling (Comfy-Org#12373)
1 parent 3fe61ce commit e5ae670

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

comfy/text_encoders/ace15.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def sample_manual_loop_no_classes(
1616
temperature: float = 0.85,
1717
top_p: float = 0.9,
1818
top_k: int = None,
19+
min_p: float = 0.000,
1920
seed: int = 1,
2021
min_tokens: int = 1,
2122
max_new_tokens: int = 2048,
@@ -80,6 +81,12 @@ def sample_manual_loop_no_classes(
8081
min_val = top_k_vals[..., -1, None]
8182
cfg_logits[cfg_logits < min_val] = remove_logit_value
8283

84+
if min_p is not None and min_p > 0:
85+
probs = torch.softmax(cfg_logits, dim=-1)
86+
p_max = probs.max(dim=-1, keepdim=True).values
87+
indices_to_remove = probs < (min_p * p_max)
88+
cfg_logits[indices_to_remove] = remove_logit_value
89+
8390
if top_p is not None and top_p < 1.0:
8491
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
8592
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
@@ -110,7 +117,7 @@ def sample_manual_loop_no_classes(
110117
return output_audio_codes
111118

112119

113-
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
120+
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
114121
positive = [[token for token, _ in inner_list] for inner_list in positive]
115122
positive = positive[0]
116123

@@ -134,7 +141,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
134141
paddings = []
135142
ids = [positive]
136143

137-
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
144+
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
138145

139146

140147
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -192,6 +199,7 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
192199
temperature = kwargs.get("temperature", 0.85)
193200
top_p = kwargs.get("top_p", 0.9)
194201
top_k = kwargs.get("top_k", 0.0)
202+
min_p = kwargs.get("min_p", 0.000)
195203

196204
duration = math.ceil(duration)
197205
kwargs["duration"] = duration
@@ -239,6 +247,7 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
239247
"temperature": temperature,
240248
"top_p": top_p,
241249
"top_k": top_k,
250+
"min_p": min_p,
242251
}
243252
return out
244253

@@ -299,7 +308,7 @@ def encode_token_weights(self, token_weight_pairs):
299308

300309
lm_metadata = token_weight_pairs["lm_metadata"]
301310
if lm_metadata["generate_audio_codes"]:
302-
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
311+
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
303312
out["audio_codes"] = [audio_codes]
304313

305314
return base_out, None, out

comfy_extras/nodes_ace.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ def define_schema(cls):
4949
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
5050
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
5151
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
52+
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
5253
],
5354
outputs=[io.Conditioning.Output()],
5455
)
5556

5657
@classmethod
57-
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
58-
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
58+
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
59+
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
5960
conditioning = clip.encode_from_tokens_scheduled(tokens)
6061
return io.NodeOutput(conditioning)
6162

0 commit comments

Comments (0)