From 62c5201c5f9723ecf0a62ff777868622da83bb31 Mon Sep 17 00:00:00 2001
From: Jason
Date: Sun, 22 Mar 2026 00:54:50 +0100
Subject: [PATCH] fix: cast input dtype in PatchEmbed and FluxTransformer to
 prevent dtype mismatch

When running SD3/FLUX models on CPU with torch_dtype=float16, the
projection layers (Conv2d in PatchEmbed, Linear in
FluxTransformer2DModel) may have float32 weights/bias while receiving
half-precision inputs, causing 'Input type (c10::Half) and bias type
(float) should be the same'.

Cast inputs to match layer weight dtype before calling the projection.

Fixes #13300
---
 src/diffusers/models/embeddings.py                    | 2 +-
 src/diffusers/models/transformers/transformer_flux.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index bcd192c1f166..8f878e8f4450 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -556,7 +556,7 @@ def forward(self, latent):
             height, width = latent.shape[-2:]
         else:
             height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-        latent = self.proj(latent)
+        latent = self.proj(latent.to(self.proj.weight.dtype))
         if self.flatten:
             latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
         if self.layer_norm:
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 78a77ebcfea9..61eb34cd9d8c 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -677,7 +677,7 @@ def forward(
             `tuple` where the first element is the sample tensor.
         """

-        hidden_states = self.x_embedder(hidden_states)
+        hidden_states = self.x_embedder(hidden_states.to(self.x_embedder.weight.dtype))

         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
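
Reproduction sketch (illustrative, not part of the patch): the model id,
prompt, and step count below are assumptions; per the commit message, any
SD3/FLUX checkpoint loaded with torch_dtype=torch.float16 and run on CPU
should hit the dtype mismatch before this fix.

    import torch
    from diffusers import FluxPipeline

    # Load in half precision but keep the pipeline on CPU (no .to("cuda")).
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
        torch_dtype=torch.float16,
    )

    # Before this patch, the x_embedder projection can end up with float32
    # weights/bias while the packed latents arrive in float16, raising a
    # dtype-mismatch error like the one quoted in the commit message.
    image = pipe("a photo of a cat", num_inference_steps=1).images[0]

With the casts in place, the same call runs: both projections convert
their input to the layer's weight dtype before the matmul/convolution.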