From 7e5310b908fb70d70fd60a0ceff65b6e8d1cceb7 Mon Sep 17 00:00:00 2001 From: Krzysztof Rymski Date: Mon, 9 Feb 2026 08:28:46 -0800 Subject: [PATCH] Internal changes PiperOrigin-RevId: 867617121 --- gemma/flash_attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index b8c105aa..488d4253 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -446,7 +446,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, template , class DF4 = hn::CappedTag, class VF4 = hn::Vec, class VF = hn::Vec, typename F> -static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3, +static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3, F reducer) { const DF4 df4; constexpr size_t kMaxLanes = hn::MaxLanes(df); @@ -469,7 +469,7 @@ static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3, // Handles Up to 4 Q rows by NF*2 timesteps of flash attention. template > -static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( +static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap( DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1, VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1, float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,