@@ -16,6 +16,7 @@ def sample_manual_loop_no_classes(
1616 temperature : float = 0.85 ,
1717 top_p : float = 0.9 ,
1818 top_k : int = None ,
19+ min_p : float = 0.000 ,
1920 seed : int = 1 ,
2021 min_tokens : int = 1 ,
2122 max_new_tokens : int = 2048 ,
@@ -80,6 +81,12 @@ def sample_manual_loop_no_classes(
8081 min_val = top_k_vals [..., - 1 , None ]
8182 cfg_logits [cfg_logits < min_val ] = remove_logit_value
8283
84+ if min_p is not None and min_p > 0 :
85+ probs = torch .softmax (cfg_logits , dim = - 1 )
86+ p_max = probs .max (dim = - 1 , keepdim = True ).values
87+ indices_to_remove = probs < (min_p * p_max )
88+ cfg_logits [indices_to_remove ] = remove_logit_value
89+
8390 if top_p is not None and top_p < 1.0 :
8491 sorted_logits , sorted_indices = torch .sort (cfg_logits , descending = True )
8592 cumulative_probs = torch .cumsum (torch .softmax (sorted_logits , dim = - 1 ), dim = - 1 )
@@ -110,7 +117,7 @@ def sample_manual_loop_no_classes(
110117 return output_audio_codes
111118
112119
113- def generate_audio_codes (model , positive , negative , min_tokens = 1 , max_tokens = 1024 , seed = 0 , cfg_scale = 2.0 , temperature = 0.85 , top_p = 0.9 , top_k = 0 ):
120+ def generate_audio_codes (model , positive , negative , min_tokens = 1 , max_tokens = 1024 , seed = 0 , cfg_scale = 2.0 , temperature = 0.85 , top_p = 0.9 , top_k = 0 , min_p = 0.000 ):
114121 positive = [[token for token , _ in inner_list ] for inner_list in positive ]
115122 positive = positive [0 ]
116123
@@ -134,7 +141,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
134141 paddings = []
135142 ids = [positive ]
136143
137- return sample_manual_loop_no_classes (model , ids , paddings , cfg_scale = cfg_scale , temperature = temperature , top_p = top_p , top_k = top_k , seed = seed , min_tokens = min_tokens , max_new_tokens = max_tokens )
144+ return sample_manual_loop_no_classes (model , ids , paddings , cfg_scale = cfg_scale , temperature = temperature , top_p = top_p , top_k = top_k , min_p = min_p , seed = seed , min_tokens = min_tokens , max_new_tokens = max_tokens )
138145
139146
140147class ACE15Tokenizer (sd1_clip .SD1Tokenizer ):
@@ -192,6 +199,7 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
192199 temperature = kwargs .get ("temperature" , 0.85 )
193200 top_p = kwargs .get ("top_p" , 0.9 )
194201 top_k = kwargs .get ("top_k" , 0.0 )
202+ min_p = kwargs .get ("min_p" , 0.000 )
195203
196204 duration = math .ceil (duration )
197205 kwargs ["duration" ] = duration
@@ -239,6 +247,7 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
239247 "temperature" : temperature ,
240248 "top_p" : top_p ,
241249 "top_k" : top_k ,
250+ "min_p" : min_p ,
242251 }
243252 return out
244253
@@ -299,7 +308,7 @@ def encode_token_weights(self, token_weight_pairs):
299308
300309 lm_metadata = token_weight_pairs ["lm_metadata" ]
301310 if lm_metadata ["generate_audio_codes" ]:
302- audio_codes = generate_audio_codes (getattr (self , self .lm_model , self .qwen3_06b ), token_weight_pairs ["lm_prompt" ], token_weight_pairs ["lm_prompt_negative" ], min_tokens = lm_metadata ["min_tokens" ], max_tokens = lm_metadata ["max_tokens " ], seed = lm_metadata ["seed" ], cfg_scale = lm_metadata ["cfg_scale" ], temperature = lm_metadata ["temperature" ], top_p = lm_metadata ["top_p" ], top_k = lm_metadata ["top_k" ])
311+ audio_codes = generate_audio_codes (getattr (self , self .lm_model , self .qwen3_06b ), token_weight_pairs ["lm_prompt" ], token_weight_pairs ["lm_prompt_negative" ], min_tokens = lm_metadata ["min_tokens" ], max_tokens = lm_metadata ["max_tokens " ], seed = lm_metadata ["seed" ], cfg_scale = lm_metadata ["cfg_scale" ], temperature = lm_metadata ["temperature" ], top_p = lm_metadata ["top_p" ], top_k = lm_metadata ["top_k" ], min_p = lm_metadata ["min_p" ])
303312 out ["audio_codes" ] = [audio_codes ]
304313
305314 return base_out , None , out
0 commit comments