diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index 06f5ada89bf3..862bc844ab70 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -234,7 +234,7 @@ def generate_freq_grid(self, spacing, dtype, device): return indices - def precompute_freqs_cis(self, indices_grid, spacing="exp"): + def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None): dim = self.inner_dim n_elem = 2 # 2 because of cos and sin freqs = self.precompute_freqs(indices_grid, spacing) @@ -247,7 +247,7 @@ def precompute_freqs_cis(self, indices_grid, spacing="exp"): ) else: cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) - return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope + return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope def forward( self, @@ -288,7 +288,7 @@ def forward( hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device ) indices_grid = indices_grid[None, None, :] - freqs_cis = self.precompute_freqs_cis(indices_grid) + freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype) # 2. Blocks for block_idx, block in enumerate(self.transformer_1d_blocks):