@@ -341,38 +341,6 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
             b_start_loc, b_seq_len, max_input_len, other_kv_index);
 }
 
-// void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
-//                                       const at::Tensor& v, at::Tensor& out,
-//                                       const at::Tensor& b_loc,
-//                                       const at::Tensor& b_start_loc,
-//                                       const at::Tensor& b_seq_len,
-//                                       int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
-//                                               const at::Tensor& v, at::Tensor& out,
-//                                               const at::Tensor& b_loc,
-//                                               const at::Tensor& b_start_loc,
-//                                               const at::Tensor& b_seq_len,
-//                                               int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
-//                             const at::Tensor& v, at::Tensor& out,
-//                             const int head, const char* layout,
-//                             const c10::optional<at::Tensor>& padding_mask = {},
-//                             const c10::optional<at::Tensor>& atten_mask = {},
-//                             const OptionalIntArray& actual_seq_lengths = {},
-//                             int64_t num_heads = 1, double scale_value = 1.0,
-//                             const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
-//   callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
-//             actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
-// }
-
 void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                              const at::Tensor& k, const at::Tensor& v,
                              const at::Tensor& atten_mask,
@@ -412,11 +380,11 @@ void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
 }
 
 void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                       const at::IntArrayRef& actual_seq_lengths,
-                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
-                       const at::Tensor& block_table,
-                       int64_t block_size) {
-  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
+                       const c10::optional<at::Tensor>& atten_mask = {},
+                       const at::IntArrayRef& actual_seq_lengths = {},
+                       int64_t numHeads = 1, int64_t numKeyValueHeads = 1, int64_t dim = 1,
+                       const c10::optional<at::Tensor>& block_table = {}, int64_t block_size = 1) {
+  callDiopi(diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
             numHeads, numKeyValueHeads, dim,
             block_table, block_size);
 }
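
A minimal caller-side sketch of the updated extPagedAttention signature, for review context only (not part of the patch). The shapes, head counts, and block-table layout below are illustrative assumptions; the real contract is set by the diopiPagedAttention backend.

    // Hypothetical usage; every shape and constant here is an assumption.
    at::Tensor q = at::randn({1, 32, 128});            // [batch, numHeads, dim]
    at::Tensor k = at::randn({64, 32, 128});           // paged KV cache
    at::Tensor v = at::randn({64, 32, 128});
    at::Tensor out = at::empty_like(q);
    at::Tensor block_table = at::zeros({1, 4}, at::kInt);
    std::vector<int64_t> seq_lens = {57};              // converts to at::IntArrayRef
    extPagedAttention(out, q, k, v,
                      /*atten_mask=*/c10::nullopt,
                      /*actual_seq_lengths=*/seq_lens,
                      /*numHeads=*/32, /*numKeyValueHeads=*/32, /*dim=*/128,
                      /*block_table=*/block_table, /*block_size=*/16);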
@@ -501,18 +469,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
           "deeplink ext_token_softmax_reducev_inference");
   }
-  // if (&diopiTokenDecodeAttentionInference != nullptr) {
-  //   m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
-  //   m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiIncreFlashAttention != nullptr) {
-  //   m.def("incre_flash_attention", &extIncreFlashAttention,
-  //         "deeplink incre_flash_attention");
-  // }
   if (&diopiPromptFlashAttention != nullptr) {
     m.def("prompt_flash_attention", &extPromptFlashAttention,
           "deeplink ext_prompt_flash_attention");
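
One caveat on the new default arguments, noted here for reviewers: pybind11 does not export C++ default arguments automatically, so if the defaults on extPagedAttention are meant to be optional from Python, the m.def registration (not shown in this excerpt) needs explicit py::arg defaults. A sketch under that assumption, with the binding name "paged_attention" guessed by analogy with the other registrations:

    // Sketch only: the actual registration for extPagedAttention is not part
    // of this diff; the binding name and arg names are assumptions.
    if (&diopiPagedAttention != nullptr) {
      m.def("paged_attention", &extPagedAttention,
            "deeplink ext_paged_attention",
            py::arg("out"), py::arg("q"), py::arg("k"), py::arg("v"),
            py::arg("atten_mask") = c10::optional<at::Tensor>{},
            py::arg("actual_seq_lengths") = std::vector<int64_t>{},
            py::arg("numHeads") = 1, py::arg("numKeyValueHeads") = 1,
            py::arg("dim") = 1,
            py::arg("block_table") = c10::optional<at::Tensor>{},
            py::arg("block_size") = 1);
    }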