diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py index db2cede0..411f0747 100644 --- a/sgm/modules/diffusionmodules/guiders.py +++ b/sgm/modules/diffusionmodules/guiders.py @@ -147,7 +147,7 @@ def __init__( self.scale = torch.cat( [ rise_steps, - torch.ones(num_frames - 2 * int(num_frames * edge_perc)), + torch.full((num_frames - 2 * int(num_frames * edge_perc),), max_scale), fall_steps, ] ).unsqueeze(0)