@@ -106,6 +106,8 @@ def __init__(
106106 names_which_can_be_saved : list = [],
107107 names_which_can_be_offloaded : list = [],
108108 attention_kernel : str = "flash" ,
109+ a2v_attention_kernel : str = "flash" ,
110+ v2a_attention_kernel : str = "dot_product" ,
109111 flash_block_sizes : BlockSizes = None ,
110112 flash_min_seq_length : int = 4096 ,
111113 ):
@@ -243,7 +245,7 @@ def __init__(
243245 eps = norm_eps ,
244246 dtype = dtype ,
245247 mesh = mesh ,
246- attention_kernel = "flash" ,
248+ attention_kernel = a2v_attention_kernel ,
247249 rope_type = rope_type ,
248250 flash_block_sizes = flash_block_sizes ,
249251 flash_min_seq_length = 0 ,
@@ -270,7 +272,7 @@ def __init__(
270272 eps = norm_eps ,
271273 dtype = dtype ,
272274 mesh = mesh ,
273- attention_kernel = self . attention_kernel ,
275+ attention_kernel = v2a_attention_kernel ,
274276 rope_type = rope_type ,
275277 flash_block_sizes = flash_block_sizes ,
276278 flash_min_seq_length = flash_min_seq_length ,
@@ -571,6 +573,8 @@ def __init__(
571573 names_which_can_be_offloaded : list = [],
572574 scan_layers : bool = True ,
573575 attention_kernel : str = "flash" ,
576+ a2v_attention_kernel : str = "flash" ,
577+ v2a_attention_kernel : str = "dot_product" ,
574578 qk_norm : str = "rms_norm_across_heads" ,
575579 flash_block_sizes : BlockSizes = None ,
576580 flash_min_seq_length : int = 4096 ,
@@ -620,6 +624,8 @@ def __init__(
620624 self .names_which_can_be_offloaded = names_which_can_be_offloaded
621625 self .scan_layers = scan_layers
622626 self .attention_kernel = attention_kernel
627+ self .a2v_attention_kernel = a2v_attention_kernel
628+ self .v2a_attention_kernel = v2a_attention_kernel
623629 self .flash_min_seq_length = flash_min_seq_length
624630
625631 _out_channels = self .out_channels or self .in_channels
@@ -813,6 +819,8 @@ def init_block(rngs):
813819 names_which_can_be_saved = self .names_which_can_be_saved ,
814820 names_which_can_be_offloaded = self .names_which_can_be_offloaded ,
815821 attention_kernel = self .attention_kernel ,
822+ a2v_attention_kernel = self .a2v_attention_kernel ,
823+ v2a_attention_kernel = self .v2a_attention_kernel ,
816824 flash_block_sizes = flash_block_sizes ,
817825 flash_min_seq_length = self .flash_min_seq_length ,
818826 )
@@ -846,6 +854,8 @@ def init_block(rngs):
846854 names_which_can_be_saved = self .names_which_can_be_saved ,
847855 names_which_can_be_offloaded = self .names_which_can_be_offloaded ,
848856 attention_kernel = self .attention_kernel ,
857+ a2v_attention_kernel = self .a2v_attention_kernel ,
858+ v2a_attention_kernel = self .v2a_attention_kernel ,
849859 flash_block_sizes = flash_block_sizes ,
850860 flash_min_seq_length = self .flash_min_seq_length ,
851861 )
0 commit comments