Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
62c49c9
Add cuDNN score_mod attention plumbing
vcherepanov-nv Mar 11, 2026
2de2847
Fix score_mod helper return type
vcherepanov-nv Mar 11, 2026
6a02b84
Fix score_mod PyGraph callback lifetime
vcherepanov-nv Mar 11, 2026
a462007
Add test_dpa_score_mod and fix score_mod trampoline/caching
vcherepanov-nv Mar 11, 2026
b1e7adf
Restore original trampoline
vcherepanov-nv Mar 12, 2026
8fed706
Debug WIP...
vcherepanov-nv Mar 12, 2026
cdb81fe
Fixes: lifetime issue in trampoline, callback identities instead of b…
vcherepanov-nv Mar 12, 2026
f7636eb
Another fix
vcherepanov-nv Mar 12, 2026
1aa3b84
Fix the case of callbacks returning None
vcherepanov-nv Mar 12, 2026
9fd5c81
Causal attn test
vcherepanov-nv Mar 12, 2026
f146c4a
Don't require score_mod_bprop
vcherepanov-nv Mar 12, 2026
279b7e3
Avoid owning py::function in cached score modifier
vcherepanov-nv Mar 13, 2026
0763c02
Tests cleanup
vcherepanov-nv Mar 13, 2026
f34d825
Support extra score_mod tensors in fused attention
vcherepanov-nv Mar 13, 2026
59643ca
Use vendored DLPack header in fused attention
vcherepanov-nv Mar 13, 2026
6b632cf
Add DLPack dependency to common build
vcherepanov-nv Mar 13, 2026
c3493bd
Fix fused attention score_mod_tensors call site
vcherepanov-nv Mar 13, 2026
591cb06
Remove pygraph tensor_like dependency from score_mod tensors
vcherepanov-nv Mar 13, 2026
c663752
Fix MultiheadAttention typing import for score mod tensors
vcherepanov-nv Mar 13, 2026
2035fe6
Fix score mod extra tensor DLPack lifetime
vcherepanov-nv Mar 13, 2026
9d0843f
Support host score mod tensors via retained DLPack capsules
vcherepanov-nv Mar 13, 2026
6a62399
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2026
ff8826e
Fix JAX fused attention call signatures
vcherepanov-nv Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
436 changes: 436 additions & 0 deletions tests/pytorch/attention/test_attention.py

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

cmake_minimum_required(VERSION 3.21)

include(FetchContent)

# Language options
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
Expand Down Expand Up @@ -97,6 +99,21 @@ set(CUTLASS_TOOLS_INCLUDE_DIR

# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
find_package(pybind11 CONFIG REQUIRED)

# Select the DLPack dependency: either a system-installed package or a
# pinned copy fetched at configure time. Either way, the chosen target is
# recorded in NVTE_DLPACK_TARGET for later target_link_libraries() calls.
option(NVTE_USE_SYSTEM_DLPACK "Whether to use a system-installed dlpack package" OFF)
if(NVTE_USE_SYSTEM_DLPACK)
# System package provides the namespaced imported target dlpack::dlpack.
find_package(dlpack REQUIRED)
set(NVTE_DLPACK_TARGET dlpack::dlpack)
else()
# Fetch DLPack v1.1 from upstream at configure time (requires FetchContent,
# included at the top of this file).
FetchContent_Declare(
dlpack
GIT_REPOSITORY https://github.com/dmlc/dlpack
GIT_TAG v1.1
)
FetchContent_MakeAvailable(dlpack)
# The FetchContent build exposes the plain (non-namespaced) "dlpack" target.
set(NVTE_DLPACK_TARGET dlpack)
endif()

# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
Expand Down Expand Up @@ -259,6 +276,8 @@ target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDNN::cudnn_all)
target_link_libraries(transformer_engine PRIVATE Python::Module pybind11::headers)
target_link_libraries(transformer_engine PRIVATE ${NVTE_DLPACK_TARGET})

target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
Expand All @@ -268,6 +287,9 @@ target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
${CUTLASS_TOOLS_INCLUDE_DIR})
target_include_directories(transformer_engine PRIVATE
${Python_INCLUDE_DIRS}
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/python/pygraph")

# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
Expand Down
41 changes: 21 additions & 20 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,19 +535,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}

// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) {
void nvte_fused_attn_fwd(
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
void *score_mod, void *score_mod_tensors, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
Expand Down Expand Up @@ -635,10 +633,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, score_mod, score_mod_tensors,
input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
Expand Down Expand Up @@ -670,7 +669,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream) {
void *score_mod, void *score_mod_bprop, void *score_mod_tensors,
void *score_mod_bprop_tensors, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
Expand Down Expand Up @@ -743,8 +743,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO,
input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
bottom_right_diagonal, deterministic, score_mod, score_mod_bprop, score_mod_tensors,
score_mod_bprop_tensors, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
Expand Down
Loading
Loading