-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpybind.cpp
More file actions
34 lines (33 loc) · 1.51 KB
/
pybind.cpp
File metadata and controls
34 lines (33 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include "grouped_gemm/src/grouped_gemm.h"
#include "grouped_soft_gemv/src/grouped_soft_gemv.h"
#ifdef ENABLE_GROUPED_QUERY_ATTENTION
#include "ascend-closed/grouped_query_attention/src/grouped_query_attention.h"
#endif
#ifdef ENABLE_INCRE_FLASH_ATTENTION
#include "ascend-closed/incre_flash_attention/src/incre_flash_attention.h"
#endif
#include <pybind11/pybind11.h>
namespace cinfer_ascendc {

// Python extension module `cinfer_ascendc`.
//
// Every `mod.def(...)` below exposes one C++ entry point as a Python
// function. The `py::arg(...)` annotations are applied to the bound
// function's parameters positionally. Their count and order must therefore
// match that function's C++ signature; they are what allows Python callers
// to pass arguments by keyword. The final string literal of each call
// becomes the Python-side docstring.
//
// The two attention bindings are compiled in only when the matching
// ENABLE_GROUPED_QUERY_ATTENTION / ENABLE_INCRE_FLASH_ATTENTION macro is
// defined at build time. Their headers are guarded by the same macros at
// the top of this file.
PYBIND11_MODULE(cinfer_ascendc, mod) {
  // Function-local alias so that no extra names are introduced into the
  // enclosing cinfer_ascendc namespace.
  namespace py = pybind11;

  // Always-available grouped matrix kernels.
  mod.def("grouped_soft_gemv", &grouped_soft_gemv::GroupedSoftGemv,
          py::arg("x"), py::arg("weight"), py::arg("scale"),
          py::arg("groupList"), py::arg("computeType"), py::arg("output"),
          "GROUP GEMV.");

  mod.def("grouped_gemm", &grouped_gemm::GroupedGemm,
          py::arg("x"), py::arg("weight"),
          py::arg("antiquantScaleOptional"),
          py::arg("antiquantOffsetOptional"),
          py::arg("groupListOptional"), py::arg("computeType"),
          py::arg("output"),
          "GROUP GEMM.");

#ifdef ENABLE_GROUPED_QUERY_ATTENTION
  // Optional: present only in builds that define the macro above.
  mod.def("grouped_query_attention",
          &grouped_query_attention::GroupedQueryAttention,
          py::arg("query"), py::arg("key"), py::arg("value"),
          py::arg("seqLen"), py::arg("attentionOut"), py::arg("batchSize"),
          py::arg("layout"), py::arg("scale"),
          "GROUP QUERY ATTENTION.");
#endif

#ifdef ENABLE_INCRE_FLASH_ATTENTION
  // Optional: present only in builds that define the macro above.
  mod.def("incre_flash_attention",
          &incre_flash_attention::IncreFlashAttention,
          py::arg("query"), py::arg("key"), py::arg("value"),
          py::arg("seqLen"), py::arg("maxSeqLen"),
          py::arg("startIdxEachCore"), py::arg("attenMask"),
          py::arg("attentionOut"), py::arg("batch"), py::arg("numHeads"),
          py::arg("scale"), py::arg("layout"), py::arg("kvNumHeads"),
          "INCRE FLASH ATTENTION.");
#endif
}

} // namespace cinfer_ascendc