diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index bcd192c1f166..8f878e8f4450 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -556,7 +556,7 @@ def forward(self, latent):
             height, width = latent.shape[-2:]
         else:
             height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-        latent = self.proj(latent)
+        latent = self.proj(latent.to(self.proj.weight.dtype))
         if self.flatten:
             latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
         if self.layer_norm:
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 78a77ebcfea9..61eb34cd9d8c 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -677,7 +677,7 @@ def forward(
             `tuple` where the first element is the sample tensor.
         """
-        hidden_states = self.x_embedder(hidden_states)
+        hidden_states = self.x_embedder(hidden_states.to(self.x_embedder.weight.dtype))
 
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
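Both hunks apply the same fix: the input tensor is cast to the projection layer's weight dtype before it enters the layer, so a float32 latent no longer triggers a dtype-mismatch `RuntimeError` when the model weights were loaded in fp16/bf16. Below is a minimal sketch of the failure mode and the fix; it uses a hypothetical standalone `proj` layer for illustration, not the actual diffusers `PatchEmbed` / `FluxTransformer2DModel` modules:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the patch-embedding projection; the real layers
# live in diffusers' PatchEmbed and FluxTransformer2DModel.
proj = nn.Conv2d(4, 8, kernel_size=2, stride=2).to(torch.bfloat16)

# A float32 input, e.g. a latent produced by an fp32 pipeline step.
latent = torch.randn(1, 4, 16, 16)

# Without the cast, this raises a dtype-mismatch RuntimeError,
# since the weights are bf16 but the input is fp32:
#   proj(latent)

# With the cast, the input is aligned to the weight dtype first:
out = proj(latent.to(proj.weight.dtype))
print(out.dtype)  # torch.bfloat16
```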