From 62c49c9fe153b5c644bbedd9057469c60978675e Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 11 Mar 2026 04:47:08 +0000 Subject: [PATCH 01/23] Add cuDNN score_mod attention plumbing --- tests/pytorch/attention/test_attention.py | 97 +++++++++++++++++++ transformer_engine/common/CMakeLists.txt | 5 + .../common/fused_attn/fused_attn.cpp | 19 ++-- .../fused_attn_f16_arbitrary_seqlen.cu | 91 ++++++++++++----- .../fused_attn_f16_arbitrary_seqlen.h | 10 +- .../include/transformer_engine/fused_attn.h | 9 +- .../jax/csrc/extensions/attention.cpp | 14 +-- .../dot_product_attention/backends.py | 26 +++++ .../dot_product_attention.py | 8 ++ .../attention/dot_product_attention/utils.py | 30 ++++++ .../pytorch/attention/multi_head_attention.py | 8 ++ .../pytorch/cpp_extensions/fused_attn.py | 22 +++++ transformer_engine/pytorch/csrc/extensions.h | 8 +- .../pytorch/csrc/extensions/attention.cpp | 22 +++-- 14 files changed, 314 insertions(+), 55 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 60ade522e3..f6131d6d4b 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -99,6 +99,11 @@ def reset_global_fp8_state(): param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] + +def _identity_score_mod(_graph, score_tensor): + return score_tensor + + model_configs_base = { # test: ModelConfig(b, sq, hq, dqk) "base_1_0": ModelConfig(8, 128, 16, 64), @@ -1416,6 +1421,98 @@ def test_transformer_layer( torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols) +@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.") +def test_fused_attn_score_mod_smoke(): + pytest.importorskip("cudnn") + + batch_size, seqlen, num_heads, head_dim = 2, 64, 4, 64 + dtype = torch.float16 + device = "cuda" + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen, device=device, dtype=torch.int32) + + q = torch.randn(batch_size, seqlen, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn_like(q) + v = torch.randn_like(q) + + out, aux_ctx_tensors = fused_attn_fwd( + True, + seqlen, + seqlen, + cu_seqlens, + cu_seqlens, + q, + k, + v, + dtype, + FusedAttnBackend["F16_arbitrary_seqlen"], + qkv_layout="bshd_bshd_bshd", + attn_mask_type="no_mask", + dropout=0.0, + score_mod=_identity_score_mod, + ) + + dq, dk, dv, *_ = fused_attn_bwd( + seqlen, + seqlen, + cu_seqlens, + cu_seqlens, + q, + k, + v, + out, + torch.randn_like(out), + dtype, + tex.DType.kFloat16, + aux_ctx_tensors, + FusedAttnBackend["F16_arbitrary_seqlen"], + qkv_layout="bshd_bshd_bshd", + attn_mask_type="no_mask", + dropout=0.0, + score_mod=_identity_score_mod, + score_mod_bprop=_identity_score_mod, + ) + + assert dq.shape == q.shape + assert dk.shape == k.shape + assert dv.shape == v.shape + + +@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.") +def test_multihead_attention_score_mod_forces_fused_backend(): + pytest.importorskip("cudnn") + + _attention_backends["attention_params"] = None + _attention_backends["backend_selection_requires_update"] = True + + hidden_size = 256 + num_heads = 4 + seq_len = 64 + batch_size = 2 + mha = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_heads, + attention_dropout=0.0, + params_dtype=torch.float16, + device="cuda", + qkv_format="sbhd", + ).cuda() + mha.train() + + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float16, 
requires_grad=True + ) + output = mha( + hidden_states, + attn_mask_type="causal", + score_mod=_identity_score_mod, + score_mod_bprop=_identity_score_mod, + ) + output.sum().backward() + + assert _attention_backends["use_fused_attention"] + assert not _attention_backends["use_flash_attention"] + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b9e2b907e0..36b3f39bae 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -97,6 +97,7 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) +find_package(pybind11 CONFIG REQUIRED) # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) @@ -259,6 +260,7 @@ target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) +target_link_libraries(transformer_engine PRIVATE Python::Module pybind11::headers) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) @@ -268,6 +270,9 @@ target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_ target_include_directories(transformer_engine PRIVATE ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR}) +target_include_directories(transformer_engine PRIVATE + ${Python_INCLUDE_DIRS} + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/python/pygraph") # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6a136c67e4..78f3918d01 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -547,7 +547,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { + bool bottom_right_diagonal, void *score_mod, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -635,8 +636,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, - input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + window_size_left, window_size_right, bottom_right_diagonal, score_mod, input_Q, input_K, + input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else @@ -670,7 +671,8 @@ 
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + void *score_mod, void *score_mod_bprop, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -743,10 +745,11 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, - input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + bottom_right_diagonal, deterministic, score_mod, score_mod_bprop, input_Q, input_K, + input_V, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, + output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eb2ebcff39..6508cc67ef 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -8,10 +8,13 @@ #include #include #include +#include +#include #include #include +#include "../../../3rdparty/cudnn-frontend/python/pygraph/pygraph.h" #include "../common.h" #include "../cudnn_utils.h" #include "../util/cuda_runtime.h" @@ -48,6 +51,30 @@ namespace transformer_engine { namespace fused_attn { +namespace py = pybind11; + +namespace { + +auto make_attention_score_modifier(void *callback) + -> cudnn_frontend::graph::SDPA_attributes::AttentionScoreModifier_t { + if (callback == nullptr) { + return nullptr; + } + + auto *py_callback = reinterpret_cast(callback); + return [py_callback](std::shared_ptr graph, + std::shared_ptr score_tensor) { + py::gil_scoped_acquire gil; + py::module_::import("cudnn"); + auto py_graph = std::make_shared(graph); + py::object result = py::reinterpret_borrow(py_callback)( + py::cast(py_graph), py::cast(score_tensor)); + return result.cast>(); + }; +} + +} // namespace + void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, @@ -56,8 +83,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void 
*devPtrBias, - void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, + bool bottom_right_diagonal, void *score_mod, void *devPtrQ, void *devPtrK, void *devPtrV, + void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, @@ -257,6 +284,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + if (score_mod != nullptr) { + sdpa_options.set_score_mod(make_attention_score_modifier(score_mod)); + } fe::DiagonalAlignment_t const &diagonal_alignment = bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT @@ -555,13 +585,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, - void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, - void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, - void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + bool bottom_right_diagonal, bool deterministic, void *score_mod, void *score_mod_bprop, + void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, + void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, + void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -790,6 +821,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + if (score_mod != nullptr) { + sdpa_backward_options.set_score_mod(make_attention_score_modifier(score_mod)); + } + if (score_mod_bprop != nullptr) { + sdpa_backward_options.set_score_mod_bprop( + make_attention_score_modifier(score_mod_bprop)); + } if (is_ragged_q && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_q(s_q); @@ -1069,8 +1107,9 @@ void fused_attn_arbitrary_seqlen_fwd( bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, 
NVTETensorPack *Aux_CTX_Tensors, + void *score_mod, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -1208,11 +1247,11 @@ void fused_attn_arbitrary_seqlen_fwd( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, - devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, score_mod, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1235,10 +1274,11 @@ void fused_attn_arbitrary_seqlen_bwd( size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + bool deterministic, void *score_mod, void *score_mod_bprop, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -1306,11 +1346,12 @@ void fused_attn_arbitrary_seqlen_bwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, - devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, 
devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + window_size_right, bottom_right_diagonal, deterministic, score_mod, score_mod_bprop, devPtrQ, + devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, + devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 4dd7f3d1da..9e1bbbbcef 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -26,6 +26,7 @@ void fused_attn_arbitrary_seqlen_fwd( bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + void *score_mod, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, @@ -38,10 +39,11 @@ void fused_attn_arbitrary_seqlen_bwd( size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + bool deterministic, void *score_mod, void *score_mod_bprop, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 8169bf22e2..a570bc1974 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -280,6 +280,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). 
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. + * \param[in] score_mod Opaque pointer to a score modifier callback. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -295,7 +296,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); + bool bottom_right_diagonal, void *score_mod, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -355,6 +357,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] score_mod Opaque pointer to a score modifier callback. + * \param[in] score_mod_bprop Opaque pointer to a score modifier backward callback. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -370,7 +374,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + void *score_mod, void *score_mod_bprop, NVTETensor workspace, + cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 92e67ac191..3c7d48933b 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -192,8 +192,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(), - nullptr); + window_size_left, window_size_right, bottom_right_diagonal, nullptr, + query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -329,7 +329,8 @@ static void FusedAttnForwardImpl( k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream); + window_size_left, window_size_right, bottom_right_diagonal, nullptr, + workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -479,8 +480,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, false, - query_workspace_tensor.data(), nullptr); + window_size_right, bottom_right_diagonal, deterministic, false, nullptr, + nullptr, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -606,7 +607,8 @@ static void FusedAttnBackwardImpl( q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream); + bottom_right_diagonal, deterministic, false, nullptr, nullptr, + workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_input_tensors); } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a6a8b0b26a..2322db90db 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1182,6 +1182,8 @@ def forward( fp8_output, layer_number, return_max_logit, + score_mod, + score_mod_bprop, ): # pylint: disable=missing-function-docstring @@ -1277,6 +1279,7 @@ def forward( bottom_right_diagonal, rng_gen, softmax_offset, + score_mod=score_mod, cuda_graph=is_graph_capturing(), ) @@ -1357,6 +1360,7 @@ def forward( softmax_offset, return_max_logit, is_graph_capturing(), + score_mod=score_mod, ) out = out_ out_ret = out_ @@ -1446,6 +1450,8 @@ def forward( ) ctx.use_FAv2_bwd = use_FAv2_bwd ctx.deterministic = deterministic + ctx.score_mod = score_mod + ctx.score_mod_bprop = score_mod_bprop if return_max_logit: return out_ret, *max_logit @@ -1594,6 +1600,8 @@ def backward(ctx, d_out, 
*_args): ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), + score_mod=ctx.score_mod, + score_mod_bprop=ctx.score_mod_bprop, ) # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1660,6 +1668,8 @@ def backward(ctx, d_out, *_args): ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), + score_mod=ctx.score_mod, + score_mod_bprop=ctx.score_mod_bprop, ) d_bias = None @@ -1702,6 +1712,8 @@ def backward(ctx, d_out, *_args): None, None, None, + None, + None, ) @@ -1810,6 +1822,8 @@ def forward( pad_between_seqs: bool = False, inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, fp8_output: bool = False, ) -> torch.Tensor: """fused attention fprop""" @@ -1826,6 +1840,13 @@ def forward( assert ( qkv_layout in QKVLayouts ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" + if score_mod is not None or score_mod_bprop is not None: + assert ( + fused_attention_backend + == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + ), "score_mod and score_mod_bprop require the F16_arbitrary_seqlen fused backend." + if self.training and score_mod is not None: + assert score_mod_bprop is not None, "score_mod_bprop is required when training with score_mod." cp_size = 1 if isinstance(cp_group, dist_group_type): @@ -1934,6 +1955,9 @@ def forward( ) if context_parallel: + assert ( + score_mod is None and score_mod_bprop is None + ), "score_mod and score_mod_bprop are not supported with context parallelism." assert ( fp8 or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen @@ -2015,6 +2039,8 @@ def forward( fp8_output, self.layer_number, self.return_max_logit, + score_mod, + score_mod_bprop, ) if self.return_max_logit: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..db2556c352 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -826,6 +826,8 @@ def forward( fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, pad_between_seqs: Optional[bool] = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, ) -> torch.Tensor: @@ -1372,6 +1374,8 @@ def forward( inference_params=inference_params, softmax_type=self.softmax_type, return_max_logit=self.return_max_logit, + has_score_mod=score_mod is not None, + has_score_mod_bprop=score_mod_bprop is not None, cuda_graph=is_graph_capturing(), num_splits=num_splits, ) @@ -1520,6 +1524,8 @@ def forward( pad_between_seqs=pad_between_seqs, inference_params=inference_params, softmax_offset=softmax_offset, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, fp8_output=fp8_output, ) return self.fused_attention( @@ -1551,6 +1557,8 @@ def forward( pad_between_seqs=pad_between_seqs, inference_params=inference_params, softmax_offset=softmax_offset, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, fp8_output=fp8_output, ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..cccb4c25ae 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -234,6 +234,10 @@ class AttentionParams: The type of softmax operation. See DotProductAttention for details. return_max_logit : bool, default = False Whether to output max_logit. + has_score_mod : bool, default = False + Whether a cuDNN flexible-graph score modifier callback is present. + has_score_mod_bprop : bool, default = False + Whether a cuDNN flexible-graph score modifier backward callback is present. cuda_graph : bool, default = `False` Whether support for cuda graph capture is needed or not. num_splits : int, default = 1 @@ -268,6 +272,8 @@ class AttentionParams: inference_params: Optional[InferenceParams] = None softmax_type: str = "vanilla" return_max_logit: bool = False + has_score_mod: bool = False + has_score_mod_bprop: bool = False cuda_graph: bool = False num_splits: int = 1 @@ -345,6 +351,8 @@ def get_attention_backend( inference_params = attention_params.inference_params softmax_type = attention_params.softmax_type return_max_logit = attention_params.return_max_logit + has_score_mod = attention_params.has_score_mod + has_score_mod_bprop = attention_params.has_score_mod_bprop cuda_graph = attention_params.cuda_graph num_splits = attention_params.num_splits @@ -534,6 +542,17 @@ def get_attention_backend( logger.debug("Disabling UnfusedDotProductAttention for num_splits") use_unfused_attention = False + if has_score_mod or has_score_mod_bprop: + if use_flash_attention: + logger.debug("Disabling FlashAttention for score_mod callbacks") + use_flash_attention = False + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for score_mod callbacks") + use_unfused_attention = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug("Disabling FusedAttention for score_mod callbacks in FP8") + use_fused_attention = False + # Filter: Return max_logit if return_max_logit: if use_flash_attention: @@ -1009,6 +1028,17 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None + if ( + use_fused_attention + and (has_score_mod or has_score_mod_bprop) + and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + ): + logger.debug( + "Disabling FusedAttention as score_mod callbacks require the F16_arbitrary_seqlen" + " sub-backend" + ) + use_fused_attention = False + fused_attention_backend = None if ( use_fused_attention and window_size is not None diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d95d327c78..0036916b25 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -647,6 +647,8 @@ def forward( rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, alibi_slopes: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, @@ -721,6 +723,10 @@ def forward( core_attention_bias: Optional[torch.Tensor], default = None Bias tensor for :math:`Q \cdot K^T`, shape ``[1, num_head, 
max_seqlen_q, max_seqlen_kv]``. It should be ``None`` for ``"no_bias"`` and ``"alibi"`` bias types. + score_mod: Optional[Callable], default = None + Optional cuDNN flexible-graph score modifier callback. + score_mod_bprop: Optional[Callable], default = None + Optional cuDNN flexible-graph score modifier backward callback. alibi_slopes: Optional[torch.Tensor], default = None ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``. It adds a bias of ``(-alibi_slope * (i + seqlen_k - seqlen_q - j))`` @@ -1041,6 +1047,8 @@ def forward( checkpoint_core_attention=checkpoint_core_attention, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, alibi_slopes=alibi_slopes, fast_zero_fill=fast_zero_fill, inference_params=inference_params, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 2de4576e05..deb90d101f 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -143,6 +143,7 @@ def fused_attn_fwd( softmax_offset: torch.Tensor = None, return_max_logit: bool = False, cuda_graph: bool = False, + score_mod=None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -223,6 +224,8 @@ def fused_attn_fwd( softmax_offset : torch.Tensor, default = None softmax offset tensor of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. + score_mod : Callable, default = None + Optional cuDNN flexible-graph score modifier callback. return_max_logit : bool, default = False whether to return the maximum attention score cuda_graph : bool, default = False @@ -289,6 +292,11 @@ def fused_attn_fwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) + if score_mod is not None: + assert ( + fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + ), "score_mod is only supported by the cuDNN F16_arbitrary_seqlen backend." + # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: rng_elts_per_thread = ( @@ -345,6 +353,7 @@ def fused_attn_fwd( o_quantizer, attn_bias, softmax_offset, + score_mod, rng_gen, rng_elts_per_thread, return_max_logit, @@ -398,6 +407,8 @@ def fused_attn_bwd( bottom_right_diagonal: bool = None, deterministic: bool = False, cuda_graph: bool = False, + score_mod=None, + score_mod_bprop=None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -473,6 +484,10 @@ def fused_attn_bwd( bottom right (True) corner of the softmax matrix. deterministic : bool, default = False whether to execute the backward pass with deterministic behaviours. + score_mod : Callable, default = None + Optional cuDNN flexible-graph score modifier callback. + score_mod_bprop : Callable, default = None + Optional cuDNN flexible-graph score modifier backward callback. cuda_graph : bool, default = False whether or not cuda graph capture is enabled. @@ -509,6 +524,11 @@ def fused_attn_bwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) + if score_mod is not None or score_mod_bprop is not None: + assert ( + fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + ), "score_mod and score_mod_bprop are only supported by the cuDNN F16_arbitrary_seqlen backend." 
+ if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: if len(aux_ctx_tensors) < 1: raise ValueError( @@ -568,6 +588,8 @@ def fused_attn_bwd( s_quantizer, dp_quantizer, dqkv_quantizer, + score_mod, + score_mod_bprop, cuda_graph, ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1c5116a8da..2b1ef99c8a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -93,8 +93,9 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); + const std::optional SoftmaxOffset, py::handle score_mod, + const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, + bool cuda_graph); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -106,7 +107,8 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph); + py::handle dp_quantizer, py::handle dqkv_quantizer, py::handle score_mod, + py::handle score_mod_bprop, bool cuda_graph); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bf62db8c33..ab4ee5e65b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -107,8 +107,9 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { + const std::optional SoftmaxOffset, py::handle score_mod, + const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, + bool cuda_graph) { // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. @@ -236,7 +237,8 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + score_mod.is_none() ? 
nullptr : score_mod.ptr(), workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -296,7 +298,8 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + score_mod.is_none() ? nullptr : score_mod.ptr(), workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -318,7 +321,8 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { + py::handle dp_quantizer, py::handle dqkv_quantizer, py::handle score_mod, + py::handle score_mod_bprop, bool cuda_graph) { auto none = py::none(); // create QKV, O, dO tensor wrappers @@ -539,7 +543,9 @@ std::vector fused_attn_bwd( te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, + score_mod.is_none() ? nullptr : score_mod.ptr(), + score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(), workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -556,7 +562,9 @@ std::vector fused_attn_bwd( te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, + score_mod.is_none() ? nullptr : score_mod.ptr(), + score_mod_bprop.is_none() ? 
nullptr : score_mod_bprop.ptr(), workspace.data(),
        at::cuda::getCurrentCUDAStream());
  });

From 2de2847d346a879f23402342e12c9aea17366f99 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Wed, 11 Mar 2026 06:00:09 +0000
Subject: [PATCH 02/23] Fix score_mod helper return type

---
 .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 6508cc67ef..e819e43c0c 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -56,7 +56,9 @@ namespace py = pybind11;
 namespace {

 auto make_attention_score_modifier(void *callback)
-    -> cudnn_frontend::graph::SDPA_attributes::AttentionScoreModifier_t {
+    -> std::function(
+        std::shared_ptr,
+        std::shared_ptr)> {
   if (callback == nullptr) {
     return nullptr;
   }

From 6a02b84f0cca9f9430ffc1fad6aa3f37ef39af29 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Wed, 11 Mar 2026 06:17:45 +0000
Subject: [PATCH 03/23] Fix score_mod PyGraph callback lifetime

---
 .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index e819e43c0c..78efc4e821 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -68,9 +68,9 @@ auto make_attention_score_modifier(void *callback)
                          std::shared_ptr score_tensor) {
     py::gil_scoped_acquire gil;
     py::module_::import("cudnn");
-    auto py_graph = std::make_shared(graph);
+    cudnn_frontend::python_bindings::PyGraph py_graph(graph);
     py::object result = py::reinterpret_borrow(py_callback)(
-        py::cast(py_graph), py::cast(score_tensor));
+        py::cast(&py_graph, py::return_value_policy::reference), py::cast(score_tensor));
     return result.cast>();
   };
 }

From a4620074e6ea5bf6e3f5f4177ecdab967a95ed1d Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Wed, 11 Mar 2026 21:32:36 +0000
Subject: [PATCH 04/23] Add test_dpa_score_mod and fix score_mod trampoline/caching

- tests/pytorch/attention/test_attention.py: add test_dpa_score_mod testing
  DotProductAttention with an identity score_mod callable (requires
  cuDNN >= 9.7.0, F16_arbitrary_seqlen backend, no_mask). Pass score_mod_bprop
  alongside score_mod to satisfy the training assertion added in this branch.

- utils.h: add has_score_mod/has_score_mod_bprop bool fields to FADescriptor_v1
  and include them in operator< so graphs built with vs. without a score_mod
  callback get separate cache entries (fixes the score_mod callback never
  being invoked on cache hit).

- fused_attn_f16_arbitrary_seqlen.cu: replace pybind11/PyGraph trampoline with
  PyCapsule-based trampoline (matching the simpler approach from
  cudnn-score-mod branch). Raw C++ pointers are wrapped as PyCapsule objects;
  GIL is acquired via PyGILState_Ensure. Removes the pybind11 and pygraph.h
  dependencies.

- CMakeLists.txt: remove find_package(pybind11), the pybind11::headers link,
  and the pygraph include dir that were only needed for the old trampoline.
  Python_INCLUDE_DIRS (for <Python.h>) is retained.
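  For reference, a minimal sketch of the caching rule this change enforces
  (illustrative Python, not TE's actual cache; all names below are
  hypothetical):

      from dataclasses import dataclass

      @dataclass(frozen=True)
      class GraphKey:
          b: int
          h: int
          s_q: int
          s_kv: int
          has_score_mod: bool = False
          has_score_mod_bprop: bool = False

      graph_cache: dict[GraphKey, object] = {}

      def get_graph(key: GraphKey, build):
          # Two calls that differ only in whether a score_mod callback is
          # present now map to distinct cache entries, so a graph built
          # without the callback is never reused for a call that passes one.
          if key not in graph_cache:
              graph_cache[key] = build(key)
          return graph_cache[key]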
Co-Authored-By: Claude Sonnet 4.6 --- tests/pytorch/attention/test_attention.py | 107 ++++++++++++++++++ transformer_engine/common/CMakeLists.txt | 7 +- .../fused_attn_f16_arbitrary_seqlen.cu | 46 +++++--- transformer_engine/common/fused_attn/utils.h | 7 +- 4 files changed, 142 insertions(+), 25 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f6131d6d4b..c49b0914b8 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1513,6 +1513,113 @@ def test_multihead_attention_score_mod_forces_fused_backend(): assert not _attention_backends["use_flash_attention"] +# Configs for score_mod tests: batch=2, seqlen=512, heads=16, head_dim=64 +# Note: score_mod disables other cuDNN subgraphs (e.g. causal masking), so only no_mask is tested. +model_configs_score_mod = { + "score_mod_0": ModelConfig(2, 512, 16, 64, attn_mask_type="no_mask"), + "score_mod_1": ModelConfig(4, 256, 8, 64, attn_mask_type="no_mask"), +} + + +@pytest.mark.skipif(get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_score_mod]) +@pytest.mark.parametrize("model", model_configs_score_mod.keys()) +def test_dpa_score_mod(dtype, model_configs, model): + """Test DotProductAttention with score_mod=None (plumbing test). + + Verifies that passing score_mod=None produces the same output as not passing it, + and that passing a Python callable as score_mod (identity: returns None) does not crash + and produces identical output when F16_arbitrary_seqlen backend is used. + """ + config = model_configs[model] + qkv_layout = "bshd_bshd_bshd" + qkv_format = "bshd" + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + pad_between_seqs=False, + is_training=True, + deterministic=False, + ) + _, fused_attn_supported, _ = available_backends + + if not fused_attn_supported or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends: + pytest.skip("F16_arbitrary_seqlen backend not available.") + + reset_rng_states() + + b, sq, h, d = config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk + q = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + k = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda") + out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 0.01).detach() + + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: + return _DUMMY_CUDA_RNG_STATE_TRACKER + + block = DotProductAttention( + h, d, + attention_dropout=0.0, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + ).to(dtype=dtype, device="cuda") + + # Reference: run without score_mod 
(score_mod=None by default) + out_ref = block(q, k, v, qkv_format=qkv_format, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, max_seqlen_kv=sq, attn_mask_type=config.attn_mask_type) + out_ref.backward(out_grad) + dq_ref = q.grad.clone() + dk_ref = k.grad.clone() + dv_ref = v.grad.clone() + q.grad = None + k.grad = None + v.grad = None + + # score_mod callable that returns None (identity: cuDNN trampoline returns original score) + score_mod_called = [False] + + def identity_score_mod(graph, score): + """Identity score_mod: verify we get called, return None to keep original score.""" + score_mod_called[0] = True + return None + + # Run with score_mod identity callable + out_sm = block(q, k, v, qkv_format=qkv_format, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, max_seqlen_kv=sq, attn_mask_type=config.attn_mask_type, + score_mod=identity_score_mod, score_mod_bprop=identity_score_mod) + out_sm.backward(out_grad) + dq_sm = q.grad.clone() + dk_sm = k.grad.clone() + dv_sm = v.grad.clone() + + assert score_mod_called[0], "score_mod callable was not invoked by the cuDNN trampoline." + + tols = dict(atol=1e-3, rtol=1e-3) + if dtype == torch.bfloat16: + tols = dict(atol=1.5e-2, rtol=1.5e-2) + torch.testing.assert_close(out_sm, out_ref, **tols) + torch.testing.assert_close(dq_sm, dq_ref, **tols) + torch.testing.assert_close(dk_sm, dk_ref, **tols) + torch.testing.assert_close(dv_sm, dv_ref, **tols) + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 36b3f39bae..12bb170331 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -97,7 +97,6 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) -find_package(pybind11 CONFIG REQUIRED) # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) @@ -260,19 +259,15 @@ target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) -target_link_libraries(transformer_engine PRIVATE Python::Module pybind11::headers) - target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +target_include_directories(transformer_engine PRIVATE ${Python_INCLUDE_DIRS}) target_include_directories(transformer_engine PRIVATE ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR}) -target_include_directories(transformer_engine PRIVATE - ${Python_INCLUDE_DIRS} - "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/python/pygraph") # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 78efc4e821..e0d46a8d55 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -4,17 +4,15 @@ * See LICENSE for license information. 
************************************************************************/ +#include #include #include #include #include -#include -#include #include #include -#include "../../../3rdparty/cudnn-frontend/python/pygraph/pygraph.h" #include "../common.h" #include "../cudnn_utils.h" #include "../util/cuda_runtime.h" @@ -51,27 +49,37 @@ namespace transformer_engine { namespace fused_attn { -namespace py = pybind11; - namespace { +using Tensor_t = std::shared_ptr; +using Graph_t = std::shared_ptr; + auto make_attention_score_modifier(void *callback) - -> std::function( - std::shared_ptr, - std::shared_ptr)> { + -> std::function { if (callback == nullptr) { return nullptr; } - auto *py_callback = reinterpret_cast(callback); - return [py_callback](std::shared_ptr graph, - std::shared_ptr score_tensor) { - py::gil_scoped_acquire gil; - py::module_::import("cudnn"); - cudnn_frontend::python_bindings::PyGraph py_graph(graph); - py::object result = py::reinterpret_borrow(py_callback)( - py::cast(&py_graph, py::return_value_policy::reference), py::cast(score_tensor)); - return result.cast>(); + PyObject *py_callback = static_cast(callback); + return [py_callback](Graph_t graph, Tensor_t score) -> Tensor_t { + PyGILState_STATE gstate = PyGILState_Ensure(); + PyObject *py_graph = PyCapsule_New(graph.get(), "fe::graph::Graph", nullptr); + PyObject *py_score = PyCapsule_New(score.get(), "fe::graph::Tensor_attributes", nullptr); + PyObject *result = PyObject_CallFunctionObjArgs(py_callback, py_graph, py_score, nullptr); + Py_DECREF(py_graph); + Py_DECREF(py_score); + Tensor_t result_tensor = score; // default: return input score unchanged + if (result != nullptr) { + void *ptr = PyCapsule_IsValid(result, "fe::graph::Tensor_attributes") + ? PyCapsule_GetPointer(result, "fe::graph::Tensor_attributes") + : nullptr; + if (ptr != nullptr) { + result_tensor = Tensor_t(score, static_cast(ptr)); + } + Py_DECREF(result); + } + PyGILState_Release(gstate); + return result_tensor; }; } @@ -169,6 +177,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( cudnn_frontend::DataType_t::NOT_SET, cudnn_frontend::DataType_t::NOT_SET, return_max_logit, + (score_mod != nullptr), + false, }; namespace fe = cudnn_frontend; @@ -677,6 +687,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( cudnn_frontend::DataType_t::NOT_SET, cudnn_frontend::DataType_t::NOT_SET, false, + (score_mod != nullptr), + (score_mod_bprop != nullptr), }; namespace fe = cudnn_frontend; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 08a56cda6b..1a95be661c 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -119,6 +119,8 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t do_tensor_type; cudnn_frontend::DataType_t dqkv_tensor_type; bool generate_max_sum_exp; + bool has_score_mod; + bool has_score_mod_bprop; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, @@ -126,7 +128,7 @@ struct FADescriptor_v1 { bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, - dqkv_tensor_type, generate_max_sum_exp) < + dqkv_tensor_type, generate_max_sum_exp, has_score_mod, has_score_mod_bprop) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, 
rhs.page_size_v, rhs.max_pages_per_seq_k,
                     rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
@@ -134,7 +136,8 @@ struct FADescriptor_v1 {
                     rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
                     rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
                     rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type,
                     rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
+                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp, rhs.has_score_mod,
+                    rhs.has_score_mod_bprop);
   }
 };

From b1e7adf230be3b776e45e30b07d16b3adc92d349 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Thu, 12 Mar 2026 05:32:54 +0000
Subject: [PATCH 05/23] Restore original trampoline

Signed-off-by: Vladimir Cherepanov
---
 transformer_engine/common/CMakeLists.txt      |  7 +++-
 .../fused_attn_f16_arbitrary_seqlen.cu        | 42 ++++++++-----------
 2 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index 12bb170331..36b3f39bae 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -97,6 +97,7 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
 # Python
 find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
+find_package(pybind11 CONFIG REQUIRED)

 # Configure Transformer Engine library
 include_directories(${PROJECT_SOURCE_DIR}/..)
@@ -259,15 +260,19 @@ target_link_libraries(transformer_engine PUBLIC
                       CUDA::cublas
                       CUDA::cudart
                       CUDNN::cudnn_all)
+target_link_libraries(transformer_engine PRIVATE Python::Module pybind11::headers)
+
 target_include_directories(transformer_engine PRIVATE
                            ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
 target_include_directories(transformer_engine SYSTEM PRIVATE
                            ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
 target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
 target_include_directories(transformer_engine PRIVATE
                            ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR})
+target_include_directories(transformer_engine PRIVATE
+    ${Python_INCLUDE_DIRS}
+    "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/python/pygraph")

 # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
 option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index e0d46a8d55..3037bb6f4b 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -4,15 +4,17 @@
 * See LICENSE for license information.
 ************************************************************************/

-#include
 #include
 #include
 #include
 #include
+#include
+#include
 #include
 #include
+#include "../../../3rdparty/cudnn-frontend/python/pygraph/pygraph.h"
 #include "../common.h"
 #include "../cudnn_utils.h"
 #include "../util/cuda_runtime.h"
@@ -49,37 +51,27 @@ namespace transformer_engine {
 namespace fused_attn {

-namespace {
+namespace py = pybind11;

-using Tensor_t = std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>;
-using Graph_t = std::shared_ptr<cudnn_frontend::graph::Graph>;
+namespace {

 auto make_attention_score_modifier(void *callback)
-    -> std::function<Tensor_t(Graph_t, Tensor_t)> {
+    -> std::function<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>(
+        std::shared_ptr<cudnn_frontend::graph::Graph>,
+        std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>)> {
   if (callback == nullptr) {
     return nullptr;
   }
-  PyObject *py_callback = static_cast<PyObject *>(callback);
-  return [py_callback](Graph_t graph, Tensor_t score) -> Tensor_t {
-    PyGILState_STATE gstate = PyGILState_Ensure();
-    PyObject *py_graph = PyCapsule_New(graph.get(), "fe::graph::Graph", nullptr);
-    PyObject *py_score = PyCapsule_New(score.get(), "fe::graph::Tensor_attributes", nullptr);
-    PyObject *result = PyObject_CallFunctionObjArgs(py_callback, py_graph, py_score, nullptr);
-    Py_DECREF(py_graph);
-    Py_DECREF(py_score);
-    Tensor_t result_tensor = score;  // default: return input score unchanged
-    if (result != nullptr) {
-      void *ptr = PyCapsule_IsValid(result, "fe::graph::Tensor_attributes")
-                      ? PyCapsule_GetPointer(result, "fe::graph::Tensor_attributes")
-                      : nullptr;
-      if (ptr != nullptr) {
-        result_tensor = Tensor_t(score, static_cast<cudnn_frontend::graph::Tensor_attributes *>(ptr));
-      }
-      Py_DECREF(result);
-    }
-    PyGILState_Release(gstate);
-    return result_tensor;
+  auto *py_callback = reinterpret_cast<PyObject *>(callback);
+  return [py_callback](std::shared_ptr<cudnn_frontend::graph::Graph> graph,
+                       std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> score_tensor) {
+    py::gil_scoped_acquire gil;
+    py::module_::import("cudnn");
+    cudnn_frontend::python_bindings::PyGraph py_graph(graph);
+    py::object result = py::reinterpret_borrow<py::function>(py_callback)(
+        py::cast(&py_graph, py::return_value_policy::reference), py::cast(score_tensor));
+    return result.cast<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>>();
   };
 }

From 8fed706bfcdcc38c0abd75d5324cc0c2c4379ecf Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Thu, 12 Mar 2026 07:07:36 +0000
Subject: [PATCH 06/23] Debug WIP...
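
A score_mod callback receives the cudnn-frontend pygraph and the Q*K^T score
tensor and returns a modified score tensor. For reference, the shape of such a
callback looks like the following minimal sketch (in the style of the
foo_score_mod helper added below; the name relative_bias_mod is illustrative
only, not part of any API):

    def relative_bias_mod(sdpa_graph, q_kt_tensor):
        # gen_index materializes the row/column coordinate of each score entry
        row = sdpa_graph.gen_index(input=q_kt_tensor, axis=2)
        row.set_data_type(cudnn.data_type.INT32)
        col = sdpa_graph.gen_index(input=q_kt_tensor, axis=3)
        col.set_data_type(cudnn.data_type.INT32)
        # bias each score by its signed distance from the diagonal
        return sdpa_graph.add(q_kt_tensor, sdpa_graph.sub(row, col))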
Signed-off-by: Vladimir Cherepanov
---
 tests/pytorch/attention/test_attention.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index c49b0914b8..2429b41bf0 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -10,6 +10,8 @@
 import pytest
 import torch

+import cudnn
+
 from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import (
@@ -100,9 +102,21 @@ def reset_global_fp8_state():
 param_types_lean = [torch.bfloat16]


-def _identity_score_mod(_graph, score_tensor):
+def _identity_score_mod(sdpa_graph, score_tensor):
     return score_tensor

+
+def foo_score_mod(sdpa_graph, q_kt_tensor):
+    row_index = sdpa_graph.gen_index(input=q_kt_tensor, axis=2)
+    row_index.set_data_type(cudnn.data_type.INT32)
+
+    col_index = sdpa_graph.gen_index(input=q_kt_tensor, axis=3)
+    col_index.set_data_type(cudnn.data_type.INT32)
+
+    d = sdpa_graph.sub(row_index, col_index)
+    ret = sdpa_graph.add(q_kt_tensor, d)
+    return ret
+
+
 model_configs_base = {
     # test:             ModelConfig(b,  sq, hq, dqk)
@@ -1448,7 +1462,7 @@ def test_fused_attn_score_mod_smoke():
         qkv_layout="bshd_bshd_bshd",
         attn_mask_type="no_mask",
         dropout=0.0,
-        score_mod=_identity_score_mod,
+        score_mod=foo_score_mod,
     )

     dq, dk, dv, *_ = fused_attn_bwd(

From cdb81fe50192e82f5a447a4a6518a1749fa80d55 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Thu, 12 Mar 2026 13:05:43 -0700
Subject: [PATCH 07/23] Fixes: lifetime issue in trampoline, callback
 identities instead of booleans

Signed-off-by: Vladimir Cherepanov
---
 .../fused_attn_f16_arbitrary_seqlen.cu       | 38 +++++++++++--------
 transformer_engine/common/fused_attn/utils.h | 10 ++---
 2 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 3037bb6f4b..18972419dc 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -55,24 +55,32 @@ namespace py = pybind11;

 namespace {

-auto make_attention_score_modifier(void *callback)
+auto make_attention_score_modifier(void *callback_ptr)
     -> std::function<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>(
         std::shared_ptr<cudnn_frontend::graph::Graph>,
         std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>)> {
-  if (callback == nullptr) {
+  if (callback_ptr == nullptr) {
     return nullptr;
   }
-  auto *py_callback = reinterpret_cast<PyObject *>(callback);
-  return [py_callback](std::shared_ptr<cudnn_frontend::graph::Graph> graph,
-                       std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> score_tensor) {
+  py::object callback;
+  {
     py::gil_scoped_acquire gil;
-    py::module_::import("cudnn");
-    cudnn_frontend::python_bindings::PyGraph py_graph(graph);
-    py::object result = py::reinterpret_borrow<py::function>(py_callback)(
-        py::cast(&py_graph, py::return_value_policy::reference), py::cast(score_tensor));
-    return result.cast<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>>();
-  };
+    callback = py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject *>(callback_ptr));
+  }
+
+  return [callback = std::move(callback)](
+             std::shared_ptr<cudnn_frontend::graph::Graph> graph,
+             std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> score_tensor) {
+    py::gil_scoped_acquire gil;
+    py::module_::import("cudnn");
+
+    auto py_graph =
+        std::make_shared<cudnn_frontend::python_bindings::PyGraph>(std::move(graph));
+
+    py::object result = callback(py::cast(py_graph), py::cast(std::move(score_tensor)));
+    return result.cast<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>>();
+  };
 }

 }  // namespace
@@ -168,6 +176,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
       cudnn_frontend::DataType_t::NOT_SET,
       cudnn_frontend::DataType_t::NOT_SET,
       cudnn_frontend::DataType_t::NOT_SET,
+      reinterpret_cast<std::uintptr_t>(score_mod),
+      0,
       return_max_logit,
-      (score_mod != nullptr),
-      false,
   };

   namespace fe = cudnn_frontend;
@@ -678,9 +686,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       cudnn_frontend::DataType_t::NOT_SET,
       cudnn_frontend::DataType_t::NOT_SET,
       cudnn_frontend::DataType_t::NOT_SET,
+      reinterpret_cast<std::uintptr_t>(score_mod),
+      reinterpret_cast<std::uintptr_t>(score_mod_bprop),
       false,
-      (score_mod != nullptr),
-      (score_mod_bprop != nullptr),
   };

   namespace fe = cudnn_frontend;

diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index 1a95be661c..e4455d1c15 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -118,9 +118,9 @@ struct FADescriptor_v1 {
   cudnn_frontend::DataType_t o_tensor_type;
   cudnn_frontend::DataType_t do_tensor_type;
   cudnn_frontend::DataType_t dqkv_tensor_type;
+  std::uintptr_t score_mod_id;
+  std::uintptr_t score_mod_bprop_id;
   bool generate_max_sum_exp;
-  bool has_score_mod;
-  bool has_score_mod_bprop;

   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
@@ -128,7 +128,7 @@ struct FADescriptor_v1 {
                     bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
                     softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
                     deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type,
-                    dqkv_tensor_type, generate_max_sum_exp, has_score_mod, has_score_mod_bprop) <
+                    dqkv_tensor_type, score_mod_id, score_mod_bprop_id, generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v,
                     rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k,
                     rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v,
                     rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
@@ -134,7 +136,8 @@ struct FADescriptor_v1 {
                     rhs.mask_type, rhs.softmax_type, rhs.window_size_left,
                     rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic,
                     rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp, rhs.has_score_mod,
-                    rhs.has_score_mod_bprop);
+                    rhs.dqkv_tensor_type, rhs.score_mod_id, rhs.score_mod_bprop_id,
+                    rhs.generate_max_sum_exp);
   }
 };

From f7636ebd748fb35cbee90522f1be57a37ed7c2f3 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Thu, 12 Mar 2026 13:16:33 -0700
Subject: [PATCH 08/23] Another fix

Signed-off-by: Vladimir Cherepanov
---
 .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 18972419dc..84b539067e 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -63,10 +63,10 @@ auto make_attention_score_modifier(void *callback_ptr)
     return nullptr;
   }

-  py::object callback;
+  py::function callback;
   {
     py::gil_scoped_acquire gil;
-    callback = py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject *>(callback_ptr));
+    callback = py::reinterpret_borrow<py::function>(reinterpret_cast<PyObject *>(callback_ptr));
   }

   return [callback = std::move(callback)](
@@ -76,9 +76,9 @@ auto make_attention_score_modifier(void *callback_ptr)
     py::module_::import("cudnn");

     auto py_graph =
-        std::make_shared<cudnn_frontend::python_bindings::PyGraph>(std::move(graph));
+        std::make_shared<cudnn_frontend::python_bindings::PyGraph>(graph);

-    py::object result = callback(py::cast(py_graph), py::cast(std::move(score_tensor)));
+    py::object result = callback(*py_graph, score_tensor);
     return result.cast<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>>();
   };
 }

From 1aa3b84bc60fbe21b0dc2490d97bf4bc1aa2a9df Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Thu, 12 Mar 2026 13:27:59 -0700
Subject: [PATCH 09/23] Fix the case of callbacks returning None

Signed-off-by: Vladimir Cherepanov
---
 .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 84b539067e..6f940fdbb2 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -79,6 +79,9 @@ auto make_attention_score_modifier(void *callback_ptr)
         std::make_shared<cudnn_frontend::python_bindings::PyGraph>(graph);

     py::object result = callback(*py_graph, score_tensor);
+    if (result.is_none()) {
+      return score_tensor;
+    }
     return result.cast<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>>();
   };
 }

From 9fd5c81d9c33dddbd7ec8a0a8d1fc9324377fd2c Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Thu, 12 Mar 2026 13:49:12 -0700
Subject: [PATCH 10/23] Causal attn test

Signed-off-by: Vladimir Cherepanov
---
 tests/pytorch/attention/test_attention.py | 128 ++++++++++++++++++++
 1 file changed, 128 insertions(+)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 2429b41bf0..56f5ea8e76 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -105,6 +105,30 @@ def reset_global_fp8_state():
 def _identity_score_mod(sdpa_graph, score_tensor):
     return score_tensor

+
+def _causal_score_mod(sdpa_graph, score_tensor):
+    row_index = sdpa_graph.gen_index(input=score_tensor, axis=2)
+    row_index.set_data_type(cudnn.data_type.INT32)
+
+    col_index = sdpa_graph.gen_index(input=score_tensor, axis=3)
+    col_index.set_data_type(cudnn.data_type.INT32)
+
+    causal_mask = sdpa_graph.cmp_ge(
+        input=row_index,
+        comparison=col_index,
+        compute_data_type=cudnn.data_type.BOOLEAN,
+    )
+    causal_mask.set_data_type(cudnn.data_type.BOOLEAN)
+
+    zero = sdpa_graph.sub(a=row_index, b=row_index, compute_data_type=cudnn.data_type.FLOAT)
+    zero.set_data_type(cudnn.data_type.FLOAT)
+
+    neg_inf = sdpa_graph.log(input=zero, compute_data_type=cudnn.data_type.FLOAT)
+    neg_inf.set_data_type(cudnn.data_type.FLOAT)
+
+    return sdpa_graph.binary_select(input0=score_tensor, input1=neg_inf, mask=causal_mask)
+
+
 def foo_score_mod(sdpa_graph, q_kt_tensor):
     row_index = sdpa_graph.gen_index(input=q_kt_tensor, axis=2)
@@ -1634,6 +1658,110 @@ def identity_score_mod(graph, score):
     torch.testing.assert_close(dv_sm, dv_ref, **tols)


+@pytest.mark.skipif(get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_score_mod])
+@pytest.mark.parametrize("model", model_configs_score_mod.keys())
+def test_dpa_score_mod_causal(dtype, model_configs, model):
+    """Test DotProductAttention causal masking implemented via score_mod."""
+
+    config = model_configs[model]
+    qkv_layout = "bshd_bshd_bshd"
+    qkv_format = "bshd"
+
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
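+    # score_mod hooks exist only in the cuDNN fused backend, so pin backend
+    # selection to fused attention rather than flash-attn for this test.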
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + pad_between_seqs=False, + is_training=True, + deterministic=False, + ) + _, fused_attn_supported, _ = available_backends + + if not fused_attn_supported or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends: + pytest.skip("F16_arbitrary_seqlen backend not available.") + + reset_rng_states() + + b, sq, h, d = config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk + q = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + k = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda") + out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 0.01).detach() + + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: + return _DUMMY_CUDA_RNG_STATE_TRACKER + + block = DotProductAttention( + h, + d, + attention_dropout=0.0, + qkv_format=qkv_format, + attn_mask_type="no_mask", + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out_ref = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type="causal", + ) + out_ref.backward(out_grad) + dq_ref = q.grad.clone() + dk_ref = k.grad.clone() + dv_ref = v.grad.clone() + q.grad = None + k.grad = None + v.grad = None + + out_sm = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type="no_mask", + score_mod=_causal_score_mod, + score_mod_bprop=_causal_score_mod, + ) + out_sm.backward(out_grad) + dq_sm = q.grad.clone() + dk_sm = k.grad.clone() + dv_sm = v.grad.clone() + + tols = dict(atol=1e-3, rtol=1e-3) + if dtype == torch.bfloat16: + tols = dict(atol=1.5e-2, rtol=1.5e-2) + + torch.testing.assert_close(out_sm, out_ref, **tols) + torch.testing.assert_close(dq_sm, dq_ref, **tols) + torch.testing.assert_close(dk_sm, dk_ref, **tols) + torch.testing.assert_close(dv_sm, dv_ref, **tols) + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) From f146c4a836973fec17c5a59b3e96700d6fddbe9a Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Thu, 12 Mar 2026 21:04:59 +0000 Subject: [PATCH 11/23] Don't require score_mod_bprop Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 1 - .../pytorch/attention/dot_product_attention/backends.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 56f5ea8e76..105a245b6f 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1745,7 +1745,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: 
         max_seqlen_kv=sq,
         attn_mask_type="no_mask",
         score_mod=_causal_score_mod,
-        score_mod_bprop=_causal_score_mod,
     )
     out_sm.backward(out_grad)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 2322db90db..d20b3380f9 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -1845,8 +1845,6 @@ def forward(
                 fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
             ), "score_mod and score_mod_bprop require the F16_arbitrary_seqlen fused backend."

-            if self.training and score_mod is not None:
-                assert score_mod_bprop is not None, "score_mod_bprop is required when training with score_mod."

         cp_size = 1
         if isinstance(cp_group, dist_group_type):

From 279b7e33a37ec58d08b10537a4dbc32af19fd488 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Thu, 12 Mar 2026 23:25:28 -0700
Subject: [PATCH 12/23] Avoid owning py::function in cached score modifier

Signed-off-by: Vladimir Cherepanov
---
 .../fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 6f940fdbb2..baf73de21c 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -63,22 +63,19 @@ auto make_attention_score_modifier(void *callback_ptr)
     return nullptr;
   }

-  py::function callback;
-  {
-    py::gil_scoped_acquire gil;
-    callback = py::reinterpret_borrow<py::function>(reinterpret_cast<PyObject *>(callback_ptr));
-  }
+  auto *callback = reinterpret_cast<PyObject *>(callback_ptr);

-  return [callback = std::move(callback)](
+  return [callback](
             std::shared_ptr<cudnn_frontend::graph::Graph> graph,
             std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> score_tensor) {
     py::gil_scoped_acquire gil;
     py::module_::import("cudnn");
+    py::function callback_fn = py::reinterpret_borrow<py::function>(callback);

     auto py_graph =
         std::make_shared<cudnn_frontend::python_bindings::PyGraph>(graph);

-    py::object result = callback(*py_graph, score_tensor);
+    py::object result = callback_fn(*py_graph, score_tensor);
     if (result.is_none()) {
       return score_tensor;
     }

From 0763c02c310d31573e84e79e055d5939d74a05c4 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Fri, 13 Mar 2026 06:48:10 +0000
Subject: [PATCH 13/23] Tests cleanup

Signed-off-by: Vladimir Cherepanov
---
 tests/pytorch/attention/test_attention.py | 75 +----------------------
 1 file changed, 2 insertions(+), 73 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 105a245b6f..7eb363122c 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -129,19 +129,6 @@ def _causal_score_mod(sdpa_graph, score_tensor):
     return sdpa_graph.binary_select(input0=score_tensor, input1=neg_inf, mask=causal_mask)

-def foo_score_mod(sdpa_graph, q_kt_tensor):
-    row_index = sdpa_graph.gen_index(input=q_kt_tensor, axis=2)
-    row_index.set_data_type(cudnn.data_type.INT32)
-
-    col_index = sdpa_graph.gen_index(input=q_kt_tensor, axis=3)
-    col_index.set_data_type(cudnn.data_type.INT32)
-
-    d = sdpa_graph.sub(row_index, col_index)
-    ret = sdpa_graph.add(q_kt_tensor, d)
-    return ret
-
-
-
 model_configs_base = {
     # test:             ModelConfig(b,  sq, hq, dqk)
     "base_1_0": ModelConfig(8, 128, 16, 64),
@@ -1459,66 +1446,8 @@ def test_transformer_layer(
     torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)


-@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
-def test_fused_attn_score_mod_smoke():
-    pytest.importorskip("cudnn")
-
-    batch_size, seqlen, num_heads, head_dim = 2, 64, 4, 64
-    dtype = torch.float16
-    device = "cuda"
-    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen, device=device, dtype=torch.int32)
-
-    q = torch.randn(batch_size, seqlen, num_heads, head_dim, device=device, dtype=dtype)
-    k = torch.randn_like(q)
-    v = torch.randn_like(q)
-
-    out, aux_ctx_tensors = fused_attn_fwd(
-        True,
-        seqlen,
-        seqlen,
-        cu_seqlens,
-        cu_seqlens,
-        q,
-        k,
-        v,
-        dtype,
-        FusedAttnBackend["F16_arbitrary_seqlen"],
-        qkv_layout="bshd_bshd_bshd",
-        attn_mask_type="no_mask",
-        dropout=0.0,
-        score_mod=foo_score_mod,
-    )
-
-    dq, dk, dv, *_ = fused_attn_bwd(
-        seqlen,
-        seqlen,
-        cu_seqlens,
-        cu_seqlens,
-        q,
-        k,
-        v,
-        out,
-        torch.randn_like(out),
-        dtype,
-        tex.DType.kFloat16,
-        aux_ctx_tensors,
-        FusedAttnBackend["F16_arbitrary_seqlen"],
-        qkv_layout="bshd_bshd_bshd",
-        attn_mask_type="no_mask",
-        dropout=0.0,
-        score_mod=_identity_score_mod,
-        score_mod_bprop=_identity_score_mod,
-    )
-
-    assert dq.shape == q.shape
-    assert dk.shape == k.shape
-    assert dv.shape == v.shape
-
-
-@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required.")
 def test_multihead_attention_score_mod_forces_fused_backend():
-    pytest.importorskip("cudnn")
-
     _attention_backends["attention_params"] = None
     _attention_backends["backend_selection_requires_update"] = True

@@ -1541,7 +1470,7 @@ def test_multihead_attention_score_mod_forces_fused_backend():
     )
     output = mha(
         hidden_states,
-        attn_mask_type="causal",
+        attn_mask_type="no_mask",
         score_mod=_identity_score_mod,
         score_mod_bprop=_identity_score_mod,
     )
     output.sum().backward()

From f34d825dde52a55f1fd23c013d64ed9a63a91e5b Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Fri, 13 Mar 2026 00:34:59 -0700
Subject: [PATCH 14/23] Support extra score_mod tensors in fused attention

Signed-off-by: Vladimir Cherepanov
---
 tests/pytorch/attention/test_attention.py     | 124 ++++++++++
 .../common/fused_attn/fused_attn.cpp          |  27 +-
 .../fused_attn_f16_arbitrary_seqlen.cu        | 233 ++++++++++++++----
 .../fused_attn_f16_arbitrary_seqlen.h         |   5 +-
 transformer_engine/common/fused_attn/utils.h  |   6 +-
 .../include/transformer_engine/fused_attn.h   |  11 +-
 .../dot_product_attention/backends.py         |  16 ++
 .../dot_product_attention.py                  |   6 +
 .../pytorch/attention/multi_head_attention.py |   4 +
 .../pytorch/cpp_extensions/fused_attn.py      |  18 ++
 transformer_engine/pytorch/csrc/extensions.h  |   4 +-
 .../pytorch/csrc/extensions/attention.cpp     |  19 +-
 12 files changed, 401 insertions(+), 72 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 7eb363122c..069644d997 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -129,6 +129,27 @@ def _causal_score_mod(sdpa_graph, score_tensor):
     return sdpa_graph.binary_select(input0=score_tensor, input1=neg_inf, mask=causal_mask)

+
+def _causal_score_mod_external(sdpa_graph, score_tensor, score_mod_tensors):
+    row_index = sdpa_graph.gen_index(input=score_tensor, axis=2)
+    row_index.set_data_type(cudnn.data_type.INT32)
+
+    col_index = sdpa_graph.gen_index(input=score_tensor, axis=3)
+    col_index.set_data_type(cudnn.data_type.INT32)
+
+    causal_mask = sdpa_graph.cmp_ge(
+        input=row_index,
+        comparison=col_index,
+        compute_data_type=cudnn.data_type.BOOLEAN,
+    )
+    causal_mask.set_data_type(cudnn.data_type.BOOLEAN)
+
+    return sdpa_graph.binary_select(
+        input0=score_tensor,
+        input1=score_mod_tensors["neg_inf"],
+        mask=causal_mask,
+    )
+
+
 model_configs_base = {
     # test:             ModelConfig(b,  sq, hq, dqk)
     "base_1_0": ModelConfig(8, 128, 16, 64),
@@ -1690,6 +1711,109 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     torch.testing.assert_close(dv_sm, dv_ref, **tols)


+@pytest.mark.skipif(get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_score_mod])
+@pytest.mark.parametrize("model", model_configs_score_mod.keys())
+def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model):
+    """Test DotProductAttention causal masking via score_mod with external variant-pack tensor."""
+
+    config = model_configs[model]
+    qkv_layout = "bshd_bshd_bshd"
+    qkv_format = "bshd"
+
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "0"
+    _attention_backends["backend_selection_requires_update"] = True
+
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=qkv_layout,
+        pad_between_seqs=False,
+        is_training=True,
+        deterministic=False,
+    )
+    _, fused_attn_supported, _ = available_backends
+
+    if not fused_attn_supported or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends:
+        pytest.skip("F16_arbitrary_seqlen backend not available.")
+
+    reset_rng_states()
+
+    b, sq, h, d = config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk
+    q = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True)
+    k = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True)
+    v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True)
+    cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda")
+    out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 0.01).detach()
+    neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32)
+
+    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
+        return _DUMMY_CUDA_RNG_STATE_TRACKER
+
+    block = DotProductAttention(
+        h,
+        d,
+        attention_dropout=0.0,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        sequence_parallel=False,
+        tp_size=1,
+        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
+        tp_group=None,
+        layer_number=1,
+    ).to(dtype=dtype, device="cuda")
+
+    out_ref = block(
+        q,
+        k,
+        v,
+        qkv_format=qkv_format,
+        cu_seqlens_q=cu_seqlens,
+        cu_seqlens_kv=cu_seqlens,
+        max_seqlen_q=sq,
+        max_seqlen_kv=sq,
+        attn_mask_type="causal",
+    )
+    out_ref.backward(out_grad)
+    dq_ref = q.grad.clone()
+    dk_ref = k.grad.clone()
+    dv_ref = v.grad.clone()
+    q.grad = None
+    k.grad = None
+    v.grad = None
+
+    out_sm = block(
+        q,
+        k,
+        v,
+        qkv_format=qkv_format,
+        cu_seqlens_q=cu_seqlens,
+        cu_seqlens_kv=cu_seqlens,
+        max_seqlen_q=sq,
+        max_seqlen_kv=sq,
+        attn_mask_type="no_mask",
+        score_mod=_causal_score_mod_external,
+        score_mod_tensors={"neg_inf": neg_inf},
+    )
+    out_sm.backward(out_grad)
+    dq_sm = q.grad.clone()
+    dk_sm = k.grad.clone()
+    dv_sm = v.grad.clone()
+
+    tols = dict(atol=1.5e-2, rtol=1.5e-2)
+
+    torch.testing.assert_close(out_sm, out_ref, **tols)
+    torch.testing.assert_close(dq_sm, dq_ref, **tols)
+    torch.testing.assert_close(dk_sm, dk_ref, **tols)
+    torch.testing.assert_close(dv_sm, dv_ref, **tols)
+
+
 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_te_layer])

diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 78f3918d01..54b5bd80bf 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -547,8 +547,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                          int64_t window_size_left, int64_t window_size_right,
-                         bool bottom_right_diagonal, void *score_mod, NVTETensor workspace,
-                         cudaStream_t stream) {
+                         bool bottom_right_diagonal, void *score_mod, void *score_mod_tensors,
+                         NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -636,10 +636,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
         page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
         return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-        window_size_left, window_size_right, bottom_right_diagonal, score_mod, input_Q, input_K,
-        input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
-        input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
-        input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
+        window_size_left, window_size_right, bottom_right_diagonal, score_mod,
+        score_mod_tensors, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
+        Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+        input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
+        wkspace, stream, handle);
 #else
     NVTE_ERROR(
        "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
" @@ -671,8 +672,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - void *score_mod, void *score_mod_bprop, NVTETensor workspace, - cudaStream_t stream) { + void *score_mod, void *score_mod_bprop, void *score_mod_tensors, + void *score_mod_bprop_tensors, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -745,11 +746,11 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, score_mod, score_mod_bprop, input_Q, input_K, - input_V, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, - output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_rng_state, wkspace, stream, handle); + bottom_right_diagonal, deterministic, score_mod, score_mod_bprop, score_mod_tensors, + score_mod_bprop_tensors, input_Q, input_K, input_V, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index baf73de21c..1307813025 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -8,10 +8,13 @@ #include #include #include +#include #include #include +#include #include +#include #include #include "../../../3rdparty/cudnn-frontend/python/pygraph/pygraph.h" @@ -55,7 +58,118 @@ namespace py = pybind11; namespace { -auto make_attention_score_modifier(void *callback_ptr) +using TensorAttr = std::shared_ptr; +using ExtraTensorList = std::vector>; + +constexpr char kDlpackCapsuleName[] = "dltensor"; + +std::vector get_sorted_extra_tensor_names(const py::dict &extra_tensors) { + std::vector names; + names.reserve(extra_tensors.size()); + for (auto item : extra_tensors) { + names.push_back(py::cast(item.first)); + } + std::sort(names.begin(), names.end()); + return names; +} + +DLManagedTensor *get_dlpack_tensor(py::handle tensor_obj) { + NVTE_CHECK(py::hasattr(tensor_obj, "__dlpack__"), + "score_mod_tensors entries must support __dlpack__()."); + py::capsule capsule = tensor_obj.attr("__dlpack__")(); + NVTE_CHECK(!capsule.is_none(), "Failed to retrieve DLPack capsule for score_mod_tensors entry."); + auto *managed = + static_cast(PyCapsule_GetPointer(capsule.ptr(), kDlpackCapsuleName)); + NVTE_CHECK(managed != nullptr, "Invalid DLPack capsule in score_mod_tensors entry."); + return managed; +} + +void *get_dlpack_data_pointer(py::handle tensor_obj) { + auto *managed = get_dlpack_tensor(tensor_obj); + return 
static_cast(managed->dl_tensor.data) + managed->dl_tensor.byte_offset; +} + +template +void hash_combine(std::uint64_t &seed, const T &value) { + seed ^= std::hash{}(value) + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); +} + +std::uint64_t get_extra_tensor_signature(void *extra_tensors_ptr) { + if (extra_tensors_ptr == nullptr) { + return 0; + } + + py::gil_scoped_acquire gil; + py::dict extra_tensors = py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + std::uint64_t signature = 0; + + for (const auto &name : get_sorted_extra_tensor_names(extra_tensors)) { + py::handle tensor_obj = extra_tensors[py::str(name)]; + auto *managed = get_dlpack_tensor(tensor_obj); + + hash_combine(signature, name); + hash_combine(signature, managed->dl_tensor.device.device_type); + hash_combine(signature, managed->dl_tensor.device.device_id); + hash_combine(signature, managed->dl_tensor.dtype.code); + hash_combine(signature, managed->dl_tensor.dtype.bits); + hash_combine(signature, managed->dl_tensor.dtype.lanes); + hash_combine(signature, managed->dl_tensor.ndim); + for (int i = 0; i < managed->dl_tensor.ndim; ++i) { + hash_combine(signature, managed->dl_tensor.shape[i]); + hash_combine(signature, managed->dl_tensor.strides ? managed->dl_tensor.strides[i] : -1); + } + } + + return signature; +} + +py::dict get_score_mod_tensor_attrs( + const std::shared_ptr &py_graph, void *extra_tensors_ptr, + ExtraTensorList *extra_tensor_attrs) { + py::dict callback_tensors; + if (extra_tensors_ptr == nullptr) { + return callback_tensors; + } + + if (extra_tensor_attrs != nullptr && !extra_tensor_attrs->empty()) { + for (const auto &[name, tensor_attr] : *extra_tensor_attrs) { + callback_tensors[py::str(name)] = py::cast(tensor_attr); + } + return callback_tensors; + } + + py::dict extra_tensors = py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + for (const auto &name : get_sorted_extra_tensor_names(extra_tensors)) { + py::handle tensor_obj = extra_tensors[py::str(name)]; + auto tensor_attr = py_graph->tensor_like(py::reinterpret_borrow(tensor_obj)); + tensor_attr->set_name(name); + callback_tensors[py::str(name)] = py::cast(tensor_attr); + if (extra_tensor_attrs != nullptr) { + extra_tensor_attrs->emplace_back(name, std::move(tensor_attr)); + } + } + + return callback_tensors; +} + +void extend_variant_pack_with_extra_tensors( + void *extra_tensors_ptr, const ExtraTensorList &extra_tensor_attrs, + std::unordered_map &variant_pack) { + if (extra_tensors_ptr == nullptr || extra_tensor_attrs.empty()) { + return; + } + + py::gil_scoped_acquire gil; + py::dict extra_tensors = py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + for (const auto &[name, tensor_attr] : extra_tensor_attrs) { + py::str key(name); + NVTE_CHECK(extra_tensors.contains(key), "Missing score_mod tensor entry: ", name); + variant_pack[tensor_attr] = get_dlpack_data_pointer(extra_tensors[key]); + } +} + +auto make_attention_score_modifier(void *callback_ptr, void *extra_tensors_ptr, + ExtraTensorList *extra_tensor_attrs) -> std::function( std::shared_ptr, std::shared_ptr)> { @@ -65,17 +179,21 @@ auto make_attention_score_modifier(void *callback_ptr) auto *callback = reinterpret_cast(callback_ptr); - return [callback]( + return [callback, extra_tensors_ptr, extra_tensor_attrs]( std::shared_ptr graph, - std::shared_ptr score_tensor) { + std::shared_ptr score_tensor) mutable { py::gil_scoped_acquire gil; py::module_::import("cudnn"); py::function callback_fn = py::reinterpret_borrow(callback); auto py_graph = 
std::make_shared(graph); + py::dict callback_tensors = + get_score_mod_tensor_attrs(py_graph, extra_tensors_ptr, extra_tensor_attrs); - py::object result = callback_fn(*py_graph, score_tensor); + py::object result = callback_tensors.empty() + ? callback_fn(*py_graph, score_tensor) + : callback_fn(*py_graph, score_tensor, callback_tensors); if (result.is_none()) { return score_tensor; } @@ -93,12 +211,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, void *score_mod, void *devPtrQ, void *devPtrK, void *devPtrV, - void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, - void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + bool bottom_right_diagonal, void *score_mod, void *score_mod_tensors, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, + void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, + void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, + cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -178,6 +297,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( cudnn_frontend::DataType_t::NOT_SET, reinterpret_cast(score_mod), 0, + get_extra_tensor_signature(score_mod_tensors), + 0, return_max_logit, }; @@ -203,7 +324,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // offset_o std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed - std::shared_ptr>; // dropout_offset + std::shared_ptr, // dropout_offset + ExtraTensorList>; // score_mod extra tensors using CacheType = std::map; static thread_local CacheType sdpa_f16_fprop_cache; @@ -229,6 +351,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr offset_q, offset_k, offset_v, offset_o, offset_stats; std::shared_ptr dropout_seed, dropout_offset; + ExtraTensorList score_mod_extra_tensors; std::vector q_stride(4); std::vector k_stride(4); @@ -297,7 +420,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); if (score_mod != nullptr) { - sdpa_options.set_score_mod(make_attention_score_modifier(score_mod)); + sdpa_options.set_score_mod( + make_attention_score_modifier(score_mod, score_mod_tensors, &score_mod_extra_tensors)); } fe::DiagonalAlignment_t const &diagonal_alignment = @@ -464,7 +588,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple, - offset_kv_tuple, offset_s_tuple, dropout_tuple); + offset_kv_tuple, offset_s_tuple, dropout_tuple, + std::make_tuple(score_mod_extra_tensors)); cache.insert({descriptor, return_tuple}); 
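+    // The cached tuple retains the fe::graph tensor handles created for the score_mod
+    // extra tensors, so a cache hit can still re-bind their current device pointers
+    // into the variant pack at execute time.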
     return return_tuple;
@@ -472,7 +597,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   auto [mha_graph, Q, K, V, attn_scale, O, S1, S2, bias, softmax_offset, seq_q, seq_kv,
         page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats,
-        dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor);
+        dropout_seed, dropout_offset, score_mod_extra_tensors] =
+      get_graph(sdpa_f16_fprop_cache, descriptor);

   // Exit to request upper level API to allocate memory if needed
   // n.b. Care should be taken to align each of the added workspace tensors to their type.
@@ -585,6 +711,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     variant_pack[softmax_offset] = devPtrSoftmaxOffset;
   }

+  extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack);
+
   NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
 } catch (cudnn_frontend::cudnnException &e) {
   NVTE_ERROR(e.what());
@@ -598,13 +726,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
     bool bottom_right_diagonal, bool deterministic, void *score_mod, void *score_mod_bprop,
-    void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO,
-    void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ,
-    void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
-    void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
-    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
-    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
-    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    void *score_mod_tensors, void *score_mod_bprop_tensors, void *devPtrQ,
+    void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
+    void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
+    void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
+    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
+    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
+    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -688,6 +816,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       cudnn_frontend::DataType_t::NOT_SET,
       reinterpret_cast<std::uintptr_t>(score_mod),
       reinterpret_cast<std::uintptr_t>(score_mod_bprop),
+      get_extra_tensor_signature(score_mod_tensors),
+      get_extra_tensor_signature(score_mod_bprop_tensors),
       false,
   };
@@ -716,7 +846,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                  std::shared_ptr<fe::graph::Tensor_attributes>,   // offset_o
                  std::shared_ptr<fe::graph::Tensor_attributes>,   // offset_stats
                  std::shared_ptr<fe::graph::Tensor_attributes>,   // dropout_seed
-                 std::shared_ptr<fe::graph::Tensor_attributes>>;  // dropout_offset
+                 std::shared_ptr<fe::graph::Tensor_attributes>,   // dropout_offset
+                 ExtraTensorList,                                 // score_mod extra tensors
+                 ExtraTensorList>;                                // score_mod_bprop extra tensors
   using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
   static thread_local CacheType sdpa_f16_bprop_cache;
@@ -742,6 +874,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o, offset_stats;
     std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
+    ExtraTensorList score_mod_extra_tensors;
+    ExtraTensorList score_mod_bprop_extra_tensors;

     std::vector<int64_t> q_stride(4);
     std::vector<int64_t> k_stride(4);
@@ -836,11 +970,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
             .set_causal_mask_bottom_right(is_bottom_right)
             .set_attn_scale(attn_scale);
     if (score_mod != nullptr) {
-      sdpa_backward_options.set_score_mod(make_attention_score_modifier(score_mod));
+      sdpa_backward_options.set_score_mod(make_attention_score_modifier(
+          score_mod, score_mod_tensors, &score_mod_extra_tensors));
     }
     if (score_mod_bprop != nullptr) {
-      sdpa_backward_options.set_score_mod_bprop(
-          make_attention_score_modifier(score_mod_bprop));
+      sdpa_backward_options.set_score_mod_bprop(make_attention_score_modifier(
+          score_mod_bprop, score_mod_bprop_tensors, &score_mod_bprop_extra_tensors));
     }

     if (is_ragged_q && cudnn_runtime_version >= 90600) {
@@ -980,7 +1115,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     auto return_tuple =
         std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple,
                        softmax_offset_tuple, padding_tuple, offset_qo_tuple,
-                       offset_kv_tuple, offset_s_tuple, dropout_tuple);
+                       offset_kv_tuple, offset_s_tuple, dropout_tuple,
+                       std::make_tuple(score_mod_extra_tensors),
+                       std::make_tuple(score_mod_bprop_extra_tensors));
     cache.insert({descriptor, return_tuple});

     return return_tuple;
@@ -988,7 +1125,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset,
        d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats,
-        dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor);
+        dropout_seed, dropout_offset, score_mod_extra_tensors,
+        score_mod_bprop_extra_tensors] = get_graph(sdpa_f16_bprop_cache, descriptor);

   // Exit to request upper level API to allocate memory if needed
   // n.b. Care should be taken to align each of the added workspace tensors to their type.
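(For intuition: the get_extra_tensor_signature() keying used above is roughly
equivalent to this Python sketch -- illustrative only, assuming torch-like
tensors; the function name is not part of the patch:

    def extra_tensor_signature(tensors):
        # Hash name, device, dtype, shape and strides -- not the data -- so a
        # cached graph is reused across steps that rebind equivalent tensors.
        sig = 0
        for name in sorted(tensors):
            t = tensors[name]
            sig = hash((sig, name, str(t.device), str(t.dtype),
                        tuple(t.shape), tuple(t.stride())))
        return sig

The descriptor therefore distinguishes cached graphs by callback identity and
by the metadata of the bound extra tensors, not by their contents.)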
@@ -1105,6 +1243,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
   }

+  extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack);
+  extend_variant_pack_with_extra_tensors(score_mod_bprop_tensors, score_mod_bprop_extra_tensors,
+                                         variant_pack);
+
   NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
 } catch (cudnn_frontend::cudnnException &e) {
   NVTE_ERROR(e.what());
@@ -1121,8 +1263,8 @@ void fused_attn_arbitrary_seqlen_fwd(
     bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
-    void *score_mod, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
-    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
+    void *score_mod, void *score_mod_tensors, const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
     const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
     const Tensor *page_table_k, const Tensor *page_table_v,
@@ -1261,11 +1403,12 @@ void fused_attn_arbitrary_seqlen_fwd(
       max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
       page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
       is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
-      softmax_type, window_size_left, window_size_right, bottom_right_diagonal, score_mod, devPtrQ,
-      devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO,
-      devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
-      devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
-      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+      softmax_type, window_size_left, window_size_right, bottom_right_diagonal, score_mod,
+      score_mod_tensors, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1,
+      devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
+      devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
+      devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
+      stream, handle);

   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
@@ -1288,11 +1431,11 @@ void fused_attn_arbitrary_seqlen_bwd(
     size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
-    bool deterministic, void *score_mod, void *score_mod_bprop, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
-    Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
-    Tensor *output_dSoftmaxOffset,
+    bool deterministic, void *score_mod, void *score_mod_bprop, void *score_mod_tensors,
+    void *score_mod_bprop_tensors, const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle) {
@@ -1360,12 +1503,12 @@ void fused_attn_arbitrary_seqlen_bwd(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
       max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale,
       p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
-      window_size_right, bottom_right_diagonal, deterministic, score_mod, score_mod_bprop, devPtrQ,
-      devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ,
-      devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed,
-      devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
-      devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
-      stream, handle);
+      window_size_right, bottom_right_diagonal, deterministic, score_mod, score_mod_bprop,
+      score_mod_tensors, score_mod_bprop_tensors, devPtrQ, devPtrK, devPtrV, devPtrO,
+      devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
+      devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
+      devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
+      workspace->data.dptr, &workspace_size, stream, handle);

   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
index 9e1bbbbcef..84ed29c99e 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
@@ -26,7 +26,7 @@ void fused_attn_arbitrary_seqlen_fwd(
     bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
-    void *score_mod,
+    void *score_mod, void *score_mod_tensors,
     const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
     const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
@@ -39,7 +39,8 @@ void fused_attn_arbitrary_seqlen_bwd(
     size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
-    bool deterministic, void *score_mod, void *score_mod_bprop, const Tensor *input_Q,
+    bool deterministic, void *score_mod, void *score_mod_bprop, void *score_mod_tensors,
+    void *score_mod_bprop_tensors, const Tensor *input_Q,
     const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
     const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
     Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,

diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index e4455d1c15..73aca9879d 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -120,6 +120,8 @@ struct FADescriptor_v1 {
   cudnn_frontend::DataType_t dqkv_tensor_type;
   std::uintptr_t score_mod_id;
   std::uintptr_t score_mod_bprop_id;
+  std::uint64_t score_mod_tensors_id;
+  std::uint64_t score_mod_bprop_tensors_id;
   bool generate_max_sum_exp;

   bool operator<(const FADescriptor_v1 &rhs) const {
@@ -128,7 +130,8 @@ struct FADescriptor_v1 {
                     bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
                     softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
                     deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type,
-                    dqkv_tensor_type, score_mod_id, score_mod_bprop_id, generate_max_sum_exp) <
+                    dqkv_tensor_type, score_mod_id, score_mod_bprop_id, score_mod_tensors_id,
+                    score_mod_bprop_tensors_id, generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v,
                     rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k,
                     rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v,
                     rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
@@ -137,6 +140,7 @@ struct FADescriptor_v1 {
                     rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type,
                     rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
                     rhs.dqkv_tensor_type, rhs.score_mod_id, rhs.score_mod_bprop_id,
+                    rhs.score_mod_tensors_id, rhs.score_mod_bprop_tensors_id,
                     rhs.generate_max_sum_exp);
   }
 };

diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index a570bc1974..44b38918af 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -281,6 +281,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 *  \param[in]     window_size_right          Sliding window size (the right half).
 *  \param[in]     bottom_right_diagonal      Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
 *  \param[in]     score_mod                  Opaque pointer to a score modifier callback.
+ *  \param[in]     score_mod_tensors          Opaque pointer to a mapping of extra score modifier tensors.
 *  \param[in]     workspace                  Workspace tensor.
 *  \param[in]     stream                     CUDA stream used for this operation.
 */
@@ -296,8 +297,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                          int64_t window_size_left, int64_t window_size_right,
-                         bool bottom_right_diagonal, void *score_mod, NVTETensor workspace,
-                         cudaStream_t stream);
+                         bool bottom_right_diagonal, void *score_mod, void *score_mod_tensors,
+                         NVTETensor workspace, cudaStream_t stream);

/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
 *
@@ -359,6 +360,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 *  \param[in]     cuda_graph                 Whether cuda graph capture is enabled or not.
 *  \param[in]     score_mod                  Opaque pointer to a score modifier callback.
 *  \param[in]     score_mod_bprop            Opaque pointer to a score modifier backward callback.
+ *  \param[in]     score_mod_tensors          Opaque pointer to a mapping of extra score modifier tensors.
+ *  \param[in]     score_mod_bprop_tensors    Opaque pointer to a mapping of extra score modifier backward tensors.
* \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -374,8 +377,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - void *score_mod, void *score_mod_bprop, NVTETensor workspace, - cudaStream_t stream); + void *score_mod, void *score_mod_bprop, void *score_mod_tensors, + void *score_mod_bprop_tensors, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index d20b3380f9..65bd9fd9b4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1184,6 +1184,8 @@ def forward( return_max_logit, score_mod, score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, ): # pylint: disable=missing-function-docstring @@ -1280,6 +1282,7 @@ def forward( rng_gen, softmax_offset, score_mod=score_mod, + score_mod_tensors=score_mod_tensors, cuda_graph=is_graph_capturing(), ) @@ -1361,6 +1364,7 @@ def forward( return_max_logit, is_graph_capturing(), score_mod=score_mod, + score_mod_tensors=score_mod_tensors, ) out = out_ out_ret = out_ @@ -1452,6 +1456,8 @@ def forward( ctx.deterministic = deterministic ctx.score_mod = score_mod ctx.score_mod_bprop = score_mod_bprop + ctx.score_mod_tensors = score_mod_tensors + ctx.score_mod_bprop_tensors = score_mod_bprop_tensors if return_max_logit: return out_ret, *max_logit @@ -1602,6 +1608,8 @@ def backward(ctx, d_out, *_args): is_graph_capturing(), score_mod=ctx.score_mod, score_mod_bprop=ctx.score_mod_bprop, + score_mod_tensors=ctx.score_mod_tensors, + score_mod_bprop_tensors=ctx.score_mod_bprop_tensors, ) # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1670,6 +1678,8 @@ def backward(ctx, d_out, *_args): is_graph_capturing(), score_mod=ctx.score_mod, score_mod_bprop=ctx.score_mod_bprop, + score_mod_tensors=ctx.score_mod_tensors, + score_mod_bprop_tensors=ctx.score_mod_bprop_tensors, ) d_bias = None @@ -1714,6 +1724,8 @@ def backward(ctx, d_out, *_args): None, None, None, + None, + None, ) @@ -1824,6 +1836,8 @@ def forward( softmax_offset: torch.Tensor = None, score_mod: Optional[Callable] = None, score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, fp8_output: bool = False, ) -> torch.Tensor: """fused attention fprop""" @@ -2039,6 +2053,8 @@ def forward( self.return_max_logit, score_mod, score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, ) if self.return_max_logit: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index db2556c352..98b61e6550 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -828,6 +828,8 @@ def forward( pad_between_seqs: Optional[bool] = None, score_mod: Optional[Callable] = None, score_mod_bprop: Optional[Callable] = None, + 
score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, ) -> torch.Tensor: @@ -1526,6 +1528,8 @@ def forward( softmax_offset=softmax_offset, score_mod=score_mod, score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, fp8_output=fp8_output, ) return self.fused_attention( @@ -1559,6 +1563,8 @@ def forward( softmax_offset=softmax_offset, score_mod=score_mod, score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, fp8_output=fp8_output, ) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 0036916b25..696f3d43fc 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -649,6 +649,8 @@ def forward( core_attention_bias: Optional[torch.Tensor] = None, score_mod: Optional[Callable] = None, score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, alibi_slopes: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, @@ -1049,6 +1051,8 @@ def forward( core_attention_bias=core_attention_bias, score_mod=score_mod, score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, alibi_slopes=alibi_slopes, fast_zero_fill=fast_zero_fill, inference_params=inference_params, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index deb90d101f..b784155107 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -144,6 +144,7 @@ def fused_attn_fwd( return_max_logit: bool = False, cuda_graph: bool = False, score_mod=None, + score_mod_tensors=None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -226,6 +227,8 @@ def fused_attn_fwd( See softmax_type in DotProductAttention for details. score_mod : Callable, default = None Optional cuDNN flexible-graph score modifier callback. + score_mod_tensors : Dict[str, torch.Tensor], default = None + Extra tensors to expose to the score modifier callback via the variant pack. return_max_logit : bool, default = False whether to return the maximum attention score cuda_graph : bool, default = False @@ -292,6 +295,8 @@ def fused_attn_fwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) + if score_mod_tensors is not None: + assert score_mod is not None, "score_mod_tensors requires score_mod." if score_mod is not None: assert ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] @@ -354,6 +359,7 @@ def fused_attn_fwd( attn_bias, softmax_offset, score_mod, + score_mod_tensors, rng_gen, rng_elts_per_thread, return_max_logit, @@ -409,6 +415,8 @@ def fused_attn_bwd( cuda_graph: bool = False, score_mod=None, score_mod_bprop=None, + score_mod_tensors=None, + score_mod_bprop_tensors=None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -488,6 +496,10 @@ def fused_attn_bwd( Optional cuDNN flexible-graph score modifier callback. 
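A hedged end-to-end sketch of the module-level plumbing added above, assuming a Transformer Engine build that carries this series and a GPU/cuDNN combination where the F16_arbitrary_seqlen backend is selected; the identity modifier and the "neg_inf" key mirror the tests later in the series:

import torch
import transformer_engine.pytorch as te

def identity_mod(graph, score, tensors):
    # `tensors` holds the cuDNN graph handles created for score_mod_tensors;
    # returning `score` (or None) applies no modification.
    return score

b, s, h, d = 2, 64, 4, 64
dpa = te.DotProductAttention(h, d, attention_dropout=0.0, qkv_format="sbhd")
q = torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32, device="cuda")

out = dpa(
    q, k, v,
    score_mod=identity_mod,
    score_mod_bprop=identity_mod,
    score_mod_tensors={"neg_inf": neg_inf},
    score_mod_bprop_tensors={"neg_inf": neg_inf},
)

Passing score_mod_tensors without a matching score_mod is rejected by the assertions added in cpp_extensions/fused_attn.py.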
score_mod_bprop : Callable, default = None Optional cuDNN flexible-graph score modifier backward callback. + score_mod_tensors : Dict[str, torch.Tensor], default = None + Extra tensors to expose to the score modifier callback via the variant pack. + score_mod_bprop_tensors : Dict[str, torch.Tensor], default = None + Extra tensors to expose to the score modifier backward callback via the variant pack. cuda_graph : bool, default = False whether or not cuda graph capture is enabled. @@ -524,6 +536,10 @@ def fused_attn_bwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) + if score_mod_tensors is not None: + assert score_mod is not None, "score_mod_tensors requires score_mod." + if score_mod_bprop_tensors is not None: + assert score_mod_bprop is not None, "score_mod_bprop_tensors requires score_mod_bprop." if score_mod is not None or score_mod_bprop is not None: assert ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] @@ -590,6 +606,8 @@ def fused_attn_bwd( dqkv_quantizer, score_mod, score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, cuda_graph, ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b1ef99c8a..059de63c6d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -94,6 +94,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, py::handle score_mod, + py::handle score_mod_tensors, const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); @@ -108,7 +109,8 @@ std::vector fused_attn_bwd( const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer, py::handle score_mod, - py::handle score_mod_bprop, bool cuda_graph); + py::handle score_mod_bprop, py::handle score_mod_tensors, + py::handle score_mod_bprop_tensors, bool cuda_graph); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index ab4ee5e65b..112900e55e 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -108,6 +108,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, py::handle score_mod, + py::handle score_mod_tensors, const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { // Ensure that cuDNN handle is created on the correct device, @@ -238,7 +239,8 @@ std::vector fused_attn_fwd( te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, - score_mod.is_none() ? nullptr : score_mod.ptr(), workspace.data(), + score_mod.is_none() ? nullptr : score_mod.ptr(), + score_mod_tensors.is_none() ? 
nullptr : score_mod_tensors.ptr(), workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -322,7 +324,8 @@ std::vector fused_attn_bwd( const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer, py::handle score_mod, - py::handle score_mod_bprop, bool cuda_graph) { + py::handle score_mod_bprop, py::handle score_mod_tensors, + py::handle score_mod_bprop_tensors, bool cuda_graph) { auto none = py::none(); // create QKV, O, dO tensor wrappers @@ -545,8 +548,10 @@ std::vector fused_attn_bwd( attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, score_mod.is_none() ? nullptr : score_mod.ptr(), - score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(), workspace.data(), - at::cuda::getCurrentCUDAStream()); + score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(), + score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(), + score_mod_bprop_tensors.is_none() ? nullptr : score_mod_bprop_tensors.ptr(), + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -564,8 +569,10 @@ std::vector fused_attn_bwd( attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, score_mod.is_none() ? nullptr : score_mod.ptr(), - score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(), workspace.data(), - at::cuda::getCurrentCUDAStream()); + score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(), + score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(), + score_mod_bprop_tensors.is_none() ? nullptr : score_mod_bprop_tensors.ptr(), + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers From 59643ca0cbf0e881a9d5fb89248fa792b1ce4e50 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 13 Mar 2026 00:49:23 -0700 Subject: [PATCH 15/23] Use vendored DLPack header in fused attention Signed-off-by: Vladimir Cherepanov --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 1307813025..9db25bec53 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -8,7 +8,6 @@ #include #include #include -#include #include #include @@ -18,6 +17,7 @@ #include #include "../../../3rdparty/cudnn-frontend/python/pygraph/pygraph.h" +#include "../../../3rdparty/dlpack/include/dlpack/dlpack.h" #include "../common.h" #include "../cudnn_utils.h" #include "../util/cuda_runtime.h" From 6b632cfcfdf5d6d3e8585760c856382535755b2a Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 13 Mar 2026 01:03:17 -0700 Subject: [PATCH 16/23] Add DLPack dependency to common build Signed-off-by: Vladimir Cherepanov --- transformer_engine/common/CMakeLists.txt | 17 +++++++++++++++++ .../fused_attn_f16_arbitrary_seqlen.cu | 3 ++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 36b3f39bae..bbaf449076 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -4,6 +4,8 @@ 
cmake_minimum_required(VERSION 3.21) +include(FetchContent) + # Language options set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) @@ -99,6 +101,20 @@ set(CUTLASS_TOOLS_INCLUDE_DIR find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) find_package(pybind11 CONFIG REQUIRED) +option(NVTE_USE_SYSTEM_DLPACK "Whether to use a system-installed dlpack package" OFF) +if(NVTE_USE_SYSTEM_DLPACK) + find_package(dlpack REQUIRED) + set(NVTE_DLPACK_TARGET dlpack::dlpack) +else() + FetchContent_Declare( + dlpack + GIT_REPOSITORY https://github.com/dmlc/dlpack + GIT_TAG v1.1 + ) + FetchContent_MakeAvailable(dlpack) + set(NVTE_DLPACK_TARGET dlpack) +endif() + # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) @@ -261,6 +277,7 @@ target_link_libraries(transformer_engine PUBLIC CUDA::cudart CUDNN::cudnn_all) target_link_libraries(transformer_engine PRIVATE Python::Module pybind11::headers) +target_link_libraries(transformer_engine PRIVATE ${NVTE_DLPACK_TARGET}) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 9db25bec53..a429737cde 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -11,13 +11,14 @@ #include #include +#include + #include #include #include #include #include "../../../3rdparty/cudnn-frontend/python/pygraph/pygraph.h" -#include "../../../3rdparty/dlpack/include/dlpack/dlpack.h" #include "../common.h" #include "../cudnn_utils.h" #include "../util/cuda_runtime.h" From c3493bd98eeb7465ddbd3faf4b5a1076532be9b9 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 13 Mar 2026 01:16:24 -0700 Subject: [PATCH 17/23] Fix fused attention score_mod_tensors call site Signed-off-by: Vladimir Cherepanov --- transformer_engine/pytorch/csrc/extensions/attention.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 112900e55e..db4f8a67f3 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -301,7 +301,8 @@ std::vector fused_attn_fwd( te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, - score_mod.is_none() ? nullptr : score_mod.ptr(), workspace.data(), + score_mod.is_none() ? nullptr : score_mod.ptr(), + score_mod_tensors.is_none() ? 
nullptr : score_mod_tensors.ptr(), workspace.data(), at::cuda::getCurrentCUDAStream()); }); From 591cb06443a72fb6ff1054605a9ead0cccd942e2 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 13 Mar 2026 01:21:02 -0700 Subject: [PATCH 18/23] Remove pygraph tensor_like dependency from score_mod tensors Signed-off-by: Vladimir Cherepanov --- .../fused_attn_f16_arbitrary_seqlen.cu | 83 +++++++++++++++++-- 1 file changed, 77 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index a429737cde..1c66eec852 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -64,6 +64,47 @@ using ExtraTensorList = std::vector>; constexpr char kDlpackCapsuleName[] = "dltensor"; +cudnn_frontend::DataType_t convert_dlpack_dtype(const DLDataType &dtype) { + switch (dtype.code) { + case DLDataTypeCode::kDLUInt: + if (dtype.bits == 8) { + return cudnn_frontend::DataType_t::UINT8; + } + break; + case DLDataTypeCode::kDLInt: + switch (dtype.bits) { + case 8: + return cudnn_frontend::DataType_t::INT8; + case 32: + return cudnn_frontend::DataType_t::INT32; + case 64: + return cudnn_frontend::DataType_t::INT64; + } + break; + case DLDataTypeCode::kDLFloat: + switch (dtype.bits) { + case 16: + return cudnn_frontend::DataType_t::HALF; + case 32: + return cudnn_frontend::DataType_t::FLOAT; + case 64: + return cudnn_frontend::DataType_t::DOUBLE; + } + break; + case DLDataTypeCode::kDLBfloat: + if (dtype.bits == 16) { + return cudnn_frontend::DataType_t::BFLOAT16; + } + break; + case DLDataTypeCode::kDLBool: + if (dtype.bits == 8) { + return cudnn_frontend::DataType_t::BOOLEAN; + } + break; + } + return cudnn_frontend::DataType_t::NOT_SET; +} + std::vector get_sorted_extra_tensor_names(const py::dict &extra_tensors) { std::vector names; names.reserve(extra_tensors.size()); @@ -90,6 +131,38 @@ void *get_dlpack_data_pointer(py::handle tensor_obj) { return static_cast(managed->dl_tensor.data) + managed->dl_tensor.byte_offset; } +TensorAttr create_tensor_attr_from_dlpack(const std::shared_ptr &graph, + py::handle tensor_obj, const std::string &name) { + auto *managed = get_dlpack_tensor(tensor_obj); + const auto device_type = managed->dl_tensor.device.device_type; + NVTE_CHECK(device_type == kDLCPU || device_type == kDLCUDAHost || device_type == kDLCUDA || + device_type == kDLCUDAManaged, + "Invalid device type in score_mod_tensors entry."); + + const auto ndim = managed->dl_tensor.ndim; + std::vector dims(managed->dl_tensor.shape, managed->dl_tensor.shape + ndim); + const auto tensor_dtype = convert_dlpack_dtype(managed->dl_tensor.dtype); + NVTE_CHECK(tensor_dtype != cudnn_frontend::DataType_t::NOT_SET, + "Unsupported DLPack dtype in score_mod_tensors entry."); + + auto props = cudnn_frontend::graph::Tensor_attributes() + .set_name(name) + .set_data_type(tensor_dtype) + .set_is_virtual(false) + .set_is_pass_by_value(device_type == kDLCPU) + .set_dim(dims); + + if (managed->dl_tensor.strides == nullptr) { + auto stride_order = cudnn_frontend::detail::generate_row_major_stride_order(ndim); + props.set_stride(cudnn_frontend::detail::generate_stride(dims, stride_order)); + } else { + std::vector strides(managed->dl_tensor.strides, managed->dl_tensor.strides + ndim); + props.set_stride(strides); + } + + return graph->tensor(props); +} + template void hash_combine(std::uint64_t 
&seed, const T &value) { seed ^= std::hash{}(value) + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); @@ -124,9 +197,8 @@ std::uint64_t get_extra_tensor_signature(void *extra_tensors_ptr) { return signature; } -py::dict get_score_mod_tensor_attrs( - const std::shared_ptr &py_graph, void *extra_tensors_ptr, - ExtraTensorList *extra_tensor_attrs) { +py::dict get_score_mod_tensor_attrs(const std::shared_ptr &graph, + void *extra_tensors_ptr, ExtraTensorList *extra_tensor_attrs) { py::dict callback_tensors; if (extra_tensors_ptr == nullptr) { return callback_tensors; @@ -142,8 +214,7 @@ py::dict get_score_mod_tensor_attrs( py::dict extra_tensors = py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); for (const auto &name : get_sorted_extra_tensor_names(extra_tensors)) { py::handle tensor_obj = extra_tensors[py::str(name)]; - auto tensor_attr = py_graph->tensor_like(py::reinterpret_borrow(tensor_obj)); - tensor_attr->set_name(name); + auto tensor_attr = create_tensor_attr_from_dlpack(graph, tensor_obj, name); callback_tensors[py::str(name)] = py::cast(tensor_attr); if (extra_tensor_attrs != nullptr) { extra_tensor_attrs->emplace_back(name, std::move(tensor_attr)); @@ -190,7 +261,7 @@ auto make_attention_score_modifier(void *callback_ptr, void *extra_tensors_ptr, auto py_graph = std::make_shared(graph); py::dict callback_tensors = - get_score_mod_tensor_attrs(py_graph, extra_tensors_ptr, extra_tensor_attrs); + get_score_mod_tensor_attrs(graph, extra_tensors_ptr, extra_tensor_attrs); py::object result = callback_tensors.empty() ? callback_fn(*py_graph, score_tensor) From c663752076243db4eb54bc8b058cb7f3cf9b4767 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 13 Mar 2026 01:26:22 -0700 Subject: [PATCH 19/23] Fix MultiheadAttention typing import for score mod tensors Signed-off-by: Vladimir Cherepanov --- transformer_engine/pytorch/attention/multi_head_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 696f3d43fc..3b6ff13c85 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -5,7 +5,7 @@ """Multi-head Attention.""" import os import collections -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformer_engine.pytorch.quantization import FP8GlobalStateManager From 2035fe64242d3c47bf5b70e42b814f6b464f0f45 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 13 Mar 2026 01:33:25 -0700 Subject: [PATCH 20/23] Fix score mod extra tensor DLPack lifetime Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 2 +- .../fused_attn_f16_arbitrary_seqlen.cu | 91 ++++++++++--------- 2 files changed, 49 insertions(+), 44 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 069644d997..12f8eae3c9 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1748,7 +1748,7 @@ def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model): v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda") out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 
0.01).detach() - neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32) + neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32, device="cuda") _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 1c66eec852..97c1f1dc77 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -115,7 +115,8 @@ std::vector get_sorted_extra_tensor_names(const py::dict &extra_ten return names; } -DLManagedTensor *get_dlpack_tensor(py::handle tensor_obj) { +template +decltype(auto) with_dlpack_tensor(py::handle tensor_obj, Fn &&fn) { NVTE_CHECK(py::hasattr(tensor_obj, "__dlpack__"), "score_mod_tensors entries must support __dlpack__()."); py::capsule capsule = tensor_obj.attr("__dlpack__")(); @@ -123,44 +124,48 @@ DLManagedTensor *get_dlpack_tensor(py::handle tensor_obj) { auto *managed = static_cast(PyCapsule_GetPointer(capsule.ptr(), kDlpackCapsuleName)); NVTE_CHECK(managed != nullptr, "Invalid DLPack capsule in score_mod_tensors entry."); - return managed; + return std::forward(fn)(*managed); } void *get_dlpack_data_pointer(py::handle tensor_obj) { - auto *managed = get_dlpack_tensor(tensor_obj); - return static_cast(managed->dl_tensor.data) + managed->dl_tensor.byte_offset; + return with_dlpack_tensor(tensor_obj, [](const DLManagedTensor &managed) -> void * { + return static_cast(managed.dl_tensor.data) + managed.dl_tensor.byte_offset; + }); } TensorAttr create_tensor_attr_from_dlpack(const std::shared_ptr &graph, py::handle tensor_obj, const std::string &name) { - auto *managed = get_dlpack_tensor(tensor_obj); - const auto device_type = managed->dl_tensor.device.device_type; - NVTE_CHECK(device_type == kDLCPU || device_type == kDLCUDAHost || device_type == kDLCUDA || - device_type == kDLCUDAManaged, - "Invalid device type in score_mod_tensors entry."); - - const auto ndim = managed->dl_tensor.ndim; - std::vector dims(managed->dl_tensor.shape, managed->dl_tensor.shape + ndim); - const auto tensor_dtype = convert_dlpack_dtype(managed->dl_tensor.dtype); - NVTE_CHECK(tensor_dtype != cudnn_frontend::DataType_t::NOT_SET, - "Unsupported DLPack dtype in score_mod_tensors entry."); - - auto props = cudnn_frontend::graph::Tensor_attributes() - .set_name(name) - .set_data_type(tensor_dtype) - .set_is_virtual(false) - .set_is_pass_by_value(device_type == kDLCPU) - .set_dim(dims); - - if (managed->dl_tensor.strides == nullptr) { - auto stride_order = cudnn_frontend::detail::generate_row_major_stride_order(ndim); - props.set_stride(cudnn_frontend::detail::generate_stride(dims, stride_order)); - } else { - std::vector strides(managed->dl_tensor.strides, managed->dl_tensor.strides + ndim); - props.set_stride(strides); - } + return with_dlpack_tensor( + tensor_obj, [&](const DLManagedTensor &managed) -> TensorAttr { + const auto device_type = managed.dl_tensor.device.device_type; + NVTE_CHECK(device_type == kDLCPU || device_type == kDLCUDAHost || device_type == kDLCUDA || + device_type == kDLCUDAManaged, + "Invalid device type in score_mod_tensors entry."); + + const auto ndim = managed.dl_tensor.ndim; + std::vector dims(managed.dl_tensor.shape, managed.dl_tensor.shape + ndim); + const auto tensor_dtype = convert_dlpack_dtype(managed.dl_tensor.dtype); + 
NVTE_CHECK(tensor_dtype != cudnn_frontend::DataType_t::NOT_SET, + "Unsupported DLPack dtype in score_mod_tensors entry."); + + auto props = cudnn_frontend::graph::Tensor_attributes() + .set_name(name) + .set_data_type(tensor_dtype) + .set_is_virtual(false) + .set_is_pass_by_value(device_type == kDLCPU) + .set_dim(dims); + + if (managed.dl_tensor.strides == nullptr) { + auto stride_order = cudnn_frontend::detail::generate_row_major_stride_order(ndim); + props.set_stride(cudnn_frontend::detail::generate_stride(dims, stride_order)); + } else { + std::vector strides(managed.dl_tensor.strides, + managed.dl_tensor.strides + ndim); + props.set_stride(strides); + } - return graph->tensor(props); + return graph->tensor(props); + }); } template @@ -179,19 +184,19 @@ std::uint64_t get_extra_tensor_signature(void *extra_tensors_ptr) { for (const auto &name : get_sorted_extra_tensor_names(extra_tensors)) { py::handle tensor_obj = extra_tensors[py::str(name)]; - auto *managed = get_dlpack_tensor(tensor_obj); - hash_combine(signature, name); - hash_combine(signature, managed->dl_tensor.device.device_type); - hash_combine(signature, managed->dl_tensor.device.device_id); - hash_combine(signature, managed->dl_tensor.dtype.code); - hash_combine(signature, managed->dl_tensor.dtype.bits); - hash_combine(signature, managed->dl_tensor.dtype.lanes); - hash_combine(signature, managed->dl_tensor.ndim); - for (int i = 0; i < managed->dl_tensor.ndim; ++i) { - hash_combine(signature, managed->dl_tensor.shape[i]); - hash_combine(signature, managed->dl_tensor.strides ? managed->dl_tensor.strides[i] : -1); - } + with_dlpack_tensor(tensor_obj, [&](const DLManagedTensor &managed) { + hash_combine(signature, managed.dl_tensor.device.device_type); + hash_combine(signature, managed.dl_tensor.device.device_id); + hash_combine(signature, managed.dl_tensor.dtype.code); + hash_combine(signature, managed.dl_tensor.dtype.bits); + hash_combine(signature, managed.dl_tensor.dtype.lanes); + hash_combine(signature, managed.dl_tensor.ndim); + for (int i = 0; i < managed.dl_tensor.ndim; ++i) { + hash_combine(signature, managed.dl_tensor.shape[i]); + hash_combine(signature, managed.dl_tensor.strides ? 
managed.dl_tensor.strides[i] : -1); + } + }); } return signature; From 9d0843ff3dd1f4e084076fd96c8b6216999fffbc Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 13 Mar 2026 01:46:08 -0700 Subject: [PATCH 21/23] Support host score mod tensors via retained DLPack capsules Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 5 +- .../fused_attn_f16_arbitrary_seqlen.cu | 79 ++++++++++++++++--- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 12f8eae3c9..4c1da9969a 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1715,7 +1715,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("model_configs", [model_configs_score_mod]) @pytest.mark.parametrize("model", model_configs_score_mod.keys()) -def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model): +@pytest.mark.parametrize("neg_inf_device", ["cuda", "cpu"], ids=["cuda_tensor", "cpu_by_value_tensor"]) +def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model, neg_inf_device): """Test DotProductAttention causal masking via score_mod with external variant-pack tensor.""" config = model_configs[model] @@ -1748,7 +1749,7 @@ def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model): v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda") out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 0.01).detach() - neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32, device="cuda") + neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32, device=neg_inf_device) _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 97c1f1dc77..ac65950aea 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -64,6 +64,49 @@ using ExtraTensorList = std::vector>; constexpr char kDlpackCapsuleName[] = "dltensor"; +struct DlpackTensorView { + PyObject *capsule = nullptr; + DLManagedTensor *managed = nullptr; + + DlpackTensorView() = default; + DlpackTensorView(PyObject *capsule_obj, DLManagedTensor *managed_tensor) + : capsule(capsule_obj), managed(managed_tensor) {} + + DlpackTensorView(const DlpackTensorView &) = delete; + auto operator=(const DlpackTensorView &) -> DlpackTensorView & = delete; + + DlpackTensorView(DlpackTensorView &&other) noexcept + : capsule(other.capsule), managed(other.managed) { + other.capsule = nullptr; + other.managed = nullptr; + } + + auto operator=(DlpackTensorView &&other) noexcept -> DlpackTensorView & { + if (this != &other) { + reset(); + capsule = other.capsule; + managed = other.managed; + other.capsule = nullptr; + other.managed = nullptr; + } + return *this; + } + + ~DlpackTensorView() { reset(); } + + void reset() { + if (capsule == nullptr) { + return; + } + py::gil_scoped_acquire gil; + Py_DECREF(capsule); + capsule = nullptr; + managed = nullptr; + } +}; + +using DlpackTensorViews = std::vector; + 
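The RAII view above exists because a DLPack capsule is single-use and the data pointer it exposes is only guaranteed valid while the capsule (and its producing tensor) stays alive. A small PyTorch illustration of the consumption rule, for reference:

import torch
from torch.utils.dlpack import from_dlpack

src = torch.arange(4, dtype=torch.float32)
capsule = src.__dlpack__()      # one-shot handle to src's buffer
alias = from_dlpack(capsule)    # consumes the capsule; reusing it raises
alias[0] = 42.0
assert src[0].item() == 42.0    # no copy: both names see the same memory

Retaining a reference to the capsule (Py_INCREF here) until the views are destroyed after graph execution is what makes the pass-by-value host (kDLCPU) entries safe to read at execute time.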
cudnn_frontend::DataType_t convert_dlpack_dtype(const DLDataType &dtype) { switch (dtype.code) { case DLDataTypeCode::kDLUInt: @@ -127,10 +170,16 @@ decltype(auto) with_dlpack_tensor(py::handle tensor_obj, Fn &&fn) { return std::forward(fn)(*managed); } -void *get_dlpack_data_pointer(py::handle tensor_obj) { - return with_dlpack_tensor(tensor_obj, [](const DLManagedTensor &managed) -> void * { - return static_cast(managed.dl_tensor.data) + managed.dl_tensor.byte_offset; - }); +DlpackTensorView make_dlpack_tensor_view(py::handle tensor_obj) { + NVTE_CHECK(py::hasattr(tensor_obj, "__dlpack__"), + "score_mod_tensors entries must support __dlpack__()."); + py::capsule capsule = tensor_obj.attr("__dlpack__")(); + NVTE_CHECK(!capsule.is_none(), "Failed to retrieve DLPack capsule for score_mod_tensors entry."); + auto *managed = + static_cast(PyCapsule_GetPointer(capsule.ptr(), kDlpackCapsuleName)); + NVTE_CHECK(managed != nullptr, "Invalid DLPack capsule in score_mod_tensors entry."); + Py_INCREF(capsule.ptr()); + return DlpackTensorView(capsule.ptr(), managed); } TensorAttr create_tensor_attr_from_dlpack(const std::shared_ptr &graph, @@ -229,20 +278,26 @@ py::dict get_score_mod_tensor_attrs(const std::shared_ptr &variant_pack) { + DlpackTensorViews views; if (extra_tensors_ptr == nullptr || extra_tensor_attrs.empty()) { - return; + return views; } py::gil_scoped_acquire gil; py::dict extra_tensors = py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + views.reserve(extra_tensor_attrs.size()); for (const auto &[name, tensor_attr] : extra_tensor_attrs) { py::str key(name); NVTE_CHECK(extra_tensors.contains(key), "Missing score_mod tensor entry: ", name); - variant_pack[tensor_attr] = get_dlpack_data_pointer(extra_tensors[key]); + auto view = make_dlpack_tensor_view(extra_tensors[key]); + variant_pack[tensor_attr] = + static_cast(view.managed->dl_tensor.data) + view.managed->dl_tensor.byte_offset; + views.emplace_back(std::move(view)); } + return views; } auto make_attention_score_modifier(void *callback_ptr, void *extra_tensors_ptr, @@ -788,7 +843,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[softmax_offset] = devPtrSoftmaxOffset; } - extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack); + auto score_mod_tensor_views = + extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { @@ -1320,9 +1376,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; } - extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack); - extend_variant_pack_with_extra_tensors(score_mod_bprop_tensors, score_mod_bprop_extra_tensors, - variant_pack); + auto score_mod_tensor_views = + extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack); + auto score_mod_bprop_tensor_views = extend_variant_pack_with_extra_tensors( + score_mod_bprop_tensors, score_mod_bprop_extra_tensors, variant_pack); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { From 6a623996e867884b1db8a211e3ec8a973c5feeb4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Mar 2026 20:25:28 +0000 Subject: [PATCH 22/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for 
more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 63 ++++- .../common/fused_attn/fused_attn.cpp | 31 +-- .../fused_attn_f16_arbitrary_seqlen.cu | 260 +++++++++--------- .../fused_attn_f16_arbitrary_seqlen.h | 17 +- .../include/transformer_engine/fused_attn.h | 25 +- .../jax/csrc/extensions/attention.cpp | 41 ++- .../dot_product_attention/backends.py | 3 +- .../pytorch/cpp_extensions/fused_attn.py | 7 +- transformer_engine/pytorch/csrc/extensions.h | 9 +- .../pytorch/csrc/extensions/attention.cpp | 57 ++-- 10 files changed, 268 insertions(+), 245 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4c1da9969a..ae7ab9e1e7 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1509,7 +1509,9 @@ def test_multihead_attention_score_mod_forces_fused_backend(): } -@pytest.mark.skipif(get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod.") +@pytest.mark.skipif( + get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod." +) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("model_configs", [model_configs_score_mod]) @pytest.mark.parametrize("model", model_configs_score_mod.keys()) @@ -1539,7 +1541,10 @@ def test_dpa_score_mod(dtype, model_configs, model): ) _, fused_attn_supported, _ = available_backends - if not fused_attn_supported or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends: + if ( + not fused_attn_supported + or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends + ): pytest.skip("F16_arbitrary_seqlen backend not available.") reset_rng_states() @@ -1558,7 +1563,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: return _DUMMY_CUDA_RNG_STATE_TRACKER block = DotProductAttention( - h, d, + h, + d, attention_dropout=0.0, qkv_format=qkv_format, attn_mask_type=config.attn_mask_type, @@ -1570,8 +1576,17 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ).to(dtype=dtype, device="cuda") # Reference: run without score_mod (score_mod=None by default) - out_ref = block(q, k, v, qkv_format=qkv_format, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, - max_seqlen_q=sq, max_seqlen_kv=sq, attn_mask_type=config.attn_mask_type) + out_ref = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type=config.attn_mask_type, + ) out_ref.backward(out_grad) dq_ref = q.grad.clone() dk_ref = k.grad.clone() @@ -1589,9 +1604,19 @@ def identity_score_mod(graph, score): return None # Run with score_mod identity callable - out_sm = block(q, k, v, qkv_format=qkv_format, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, - max_seqlen_q=sq, max_seqlen_kv=sq, attn_mask_type=config.attn_mask_type, - score_mod=identity_score_mod, score_mod_bprop=identity_score_mod) + out_sm = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type=config.attn_mask_type, + score_mod=identity_score_mod, + score_mod_bprop=identity_score_mod, + ) out_sm.backward(out_grad) dq_sm = q.grad.clone() dk_sm = k.grad.clone() @@ -1608,7 +1633,9 @@ def identity_score_mod(graph, score): torch.testing.assert_close(dv_sm, dv_ref, **tols) -@pytest.mark.skipif(get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod.") +@pytest.mark.skipif( + 
get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod." +) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("model_configs", [model_configs_score_mod]) @pytest.mark.parametrize("model", model_configs_score_mod.keys()) @@ -1634,7 +1661,10 @@ def test_dpa_score_mod_causal(dtype, model_configs, model): ) _, fused_attn_supported, _ = available_backends - if not fused_attn_supported or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends: + if ( + not fused_attn_supported + or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends + ): pytest.skip("F16_arbitrary_seqlen backend not available.") reset_rng_states() @@ -1711,11 +1741,15 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: torch.testing.assert_close(dv_sm, dv_ref, **tols) -@pytest.mark.skipif(get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod.") +@pytest.mark.skipif( + get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod." +) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("model_configs", [model_configs_score_mod]) @pytest.mark.parametrize("model", model_configs_score_mod.keys()) -@pytest.mark.parametrize("neg_inf_device", ["cuda", "cpu"], ids=["cuda_tensor", "cpu_by_value_tensor"]) +@pytest.mark.parametrize( + "neg_inf_device", ["cuda", "cpu"], ids=["cuda_tensor", "cpu_by_value_tensor"] +) def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model, neg_inf_device): """Test DotProductAttention causal masking via score_mod with external variant-pack tensor.""" @@ -1738,7 +1772,10 @@ def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model, neg_ ) _, fused_attn_supported, _ = available_backends - if not fused_attn_supported or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends: + if ( + not fused_attn_supported + or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends + ): pytest.skip("F16_arbitrary_seqlen backend not available.") reset_rng_states() diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 54b5bd80bf..1828e67df2 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -535,20 +535,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, void *score_mod, void *score_mod_tensors, - NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack 
*Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + void *score_mod, void *score_mod_tensors, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -636,9 +633,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, score_mod, - score_mod_tensors, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + window_size_left, window_size_right, bottom_right_diagonal, score_mod, score_mod_tensors, + input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index ac65950aea..654166cebe 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -8,11 +8,10 @@ #include #include #include +#include #include #include -#include - #include #include #include @@ -182,39 +181,38 @@ DlpackTensorView make_dlpack_tensor_view(py::handle tensor_obj) { return DlpackTensorView(capsule.ptr(), managed); } -TensorAttr create_tensor_attr_from_dlpack(const std::shared_ptr &graph, - py::handle tensor_obj, const std::string &name) { - return with_dlpack_tensor( - tensor_obj, [&](const DLManagedTensor &managed) -> TensorAttr { - const auto device_type = managed.dl_tensor.device.device_type; - NVTE_CHECK(device_type == kDLCPU || device_type == kDLCUDAHost || device_type == kDLCUDA || - device_type == kDLCUDAManaged, - "Invalid device type in score_mod_tensors entry."); - - const auto ndim = managed.dl_tensor.ndim; - std::vector dims(managed.dl_tensor.shape, managed.dl_tensor.shape + ndim); - const auto tensor_dtype = convert_dlpack_dtype(managed.dl_tensor.dtype); - NVTE_CHECK(tensor_dtype != cudnn_frontend::DataType_t::NOT_SET, - "Unsupported DLPack dtype in score_mod_tensors entry."); - - auto props = cudnn_frontend::graph::Tensor_attributes() - .set_name(name) - .set_data_type(tensor_dtype) - .set_is_virtual(false) - .set_is_pass_by_value(device_type == kDLCPU) - .set_dim(dims); - - if (managed.dl_tensor.strides == nullptr) { - auto stride_order = cudnn_frontend::detail::generate_row_major_stride_order(ndim); - 
props.set_stride(cudnn_frontend::detail::generate_stride(dims, stride_order)); - } else { - std::vector strides(managed.dl_tensor.strides, - managed.dl_tensor.strides + ndim); - props.set_stride(strides); - } +TensorAttr create_tensor_attr_from_dlpack( + const std::shared_ptr &graph, py::handle tensor_obj, + const std::string &name) { + return with_dlpack_tensor(tensor_obj, [&](const DLManagedTensor &managed) -> TensorAttr { + const auto device_type = managed.dl_tensor.device.device_type; + NVTE_CHECK(device_type == kDLCPU || device_type == kDLCUDAHost || device_type == kDLCUDA || + device_type == kDLCUDAManaged, + "Invalid device type in score_mod_tensors entry."); + + const auto ndim = managed.dl_tensor.ndim; + std::vector dims(managed.dl_tensor.shape, managed.dl_tensor.shape + ndim); + const auto tensor_dtype = convert_dlpack_dtype(managed.dl_tensor.dtype); + NVTE_CHECK(tensor_dtype != cudnn_frontend::DataType_t::NOT_SET, + "Unsupported DLPack dtype in score_mod_tensors entry."); + + auto props = cudnn_frontend::graph::Tensor_attributes() + .set_name(name) + .set_data_type(tensor_dtype) + .set_is_virtual(false) + .set_is_pass_by_value(device_type == kDLCPU) + .set_dim(dims); + + if (managed.dl_tensor.strides == nullptr) { + auto stride_order = cudnn_frontend::detail::generate_row_major_stride_order(ndim); + props.set_stride(cudnn_frontend::detail::generate_stride(dims, stride_order)); + } else { + std::vector strides(managed.dl_tensor.strides, managed.dl_tensor.strides + ndim); + props.set_stride(strides); + } - return graph->tensor(props); - }); + return graph->tensor(props); + }); } template @@ -228,7 +226,8 @@ std::uint64_t get_extra_tensor_signature(void *extra_tensors_ptr) { } py::gil_scoped_acquire gil; - py::dict extra_tensors = py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + py::dict extra_tensors = + py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); std::uint64_t signature = 0; for (const auto &name : get_sorted_extra_tensor_names(extra_tensors)) { @@ -265,7 +264,8 @@ py::dict get_score_mod_tensor_attrs(const std::shared_ptr(reinterpret_cast(extra_tensors_ptr)); + py::dict extra_tensors = + py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); for (const auto &name : get_sorted_extra_tensor_names(extra_tensors)) { py::handle tensor_obj = extra_tensors[py::str(name)]; auto tensor_attr = create_tensor_attr_from_dlpack(graph, tensor_obj, name); @@ -287,7 +287,8 @@ DlpackTensorViews extend_variant_pack_with_extra_tensors( } py::gil_scoped_acquire gil; - py::dict extra_tensors = py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + py::dict extra_tensors = + py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); views.reserve(extra_tensor_attrs.size()); for (const auto &[name, tensor_attr] : extra_tensor_attrs) { py::str key(name); @@ -314,23 +315,22 @@ auto make_attention_score_modifier(void *callback_ptr, void *extra_tensors_ptr, return [callback, extra_tensors_ptr, extra_tensor_attrs]( std::shared_ptr graph, std::shared_ptr score_tensor) mutable { - py::gil_scoped_acquire gil; - py::module_::import("cudnn"); - py::function callback_fn = py::reinterpret_borrow(callback); - - auto py_graph = - std::make_shared(graph); - py::dict callback_tensors = - get_score_mod_tensor_attrs(graph, extra_tensors_ptr, extra_tensor_attrs); - - py::object result = callback_tensors.empty() - ? 
callback_fn(*py_graph, score_tensor) - : callback_fn(*py_graph, score_tensor, callback_tensors); - if (result.is_none()) { - return score_tensor; - } - return result.cast>(); - }; + py::gil_scoped_acquire gil; + py::module_::import("cudnn"); + py::function callback_fn = py::reinterpret_borrow(callback); + + auto py_graph = std::make_shared(graph); + py::dict callback_tensors = + get_score_mod_tensor_attrs(graph, extra_tensors_ptr, extra_tensor_attrs); + + py::object result = callback_tensors.empty() + ? callback_fn(*py_graph, score_tensor) + : callback_fn(*py_graph, score_tensor, callback_tensors); + if (result.is_none()) { + return score_tensor; + } + return result.cast>(); + }; } } // namespace @@ -346,10 +346,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool bottom_right_diagonal, void *score_mod, void *score_mod_tensors, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, - void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, - cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -437,27 +436,27 @@ void fused_attn_arbitrary_seqlen_fwd_impl( namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // Q - std::shared_ptr, // K - std::shared_ptr, // V - std::shared_ptr, // attn_scale - std::shared_ptr, // O - std::shared_ptr, // S1 - std::shared_ptr, // S2 - std::shared_ptr, // bias - std::shared_ptr, // softmax_offset - std::shared_ptr, // seq_q - std::shared_ptr, // seq_kv - std::shared_ptr, // page_table_k - std::shared_ptr, // page_table_v - std::shared_ptr, // offset_q - std::shared_ptr, // offset_k - std::shared_ptr, // offset_v - std::shared_ptr, // offset_o - std::shared_ptr, // offset_stats - std::shared_ptr, // dropout_seed - std::shared_ptr, // dropout_offset - ExtraTensorList>; // score_mod extra tensors + std::shared_ptr, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // attn_scale + std::shared_ptr, // O + std::shared_ptr, // S1 + std::shared_ptr, // S2 + std::shared_ptr, // bias + std::shared_ptr, // softmax_offset + std::shared_ptr, // seq_q + std::shared_ptr, // seq_kv + std::shared_ptr, // page_table_k + std::shared_ptr, // page_table_v + std::shared_ptr, // offset_q + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v + std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats + std::shared_ptr, // dropout_seed + std::shared_ptr, // dropout_offset + ExtraTensorList>; // score_mod extra tensors using CacheType = std::map; static thread_local CacheType sdpa_f16_fprop_cache; @@ -717,11 +716,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = - std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, - softmax_offset_tuple, padding_tuple, page_table_tuple, 
offset_qo_tuple, - offset_kv_tuple, offset_s_tuple, dropout_tuple, - std::make_tuple(score_mod_extra_tensors)); + auto return_tuple = std::tuple_cat( + std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple, offset_kv_tuple, + offset_s_tuple, dropout_tuple, std::make_tuple(score_mod_extra_tensors)); cache.insert({descriptor, return_tuple}); return return_tuple; @@ -843,8 +841,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[softmax_offset] = devPtrSoftmaxOffset; } - auto score_mod_tensor_views = - extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack); + auto score_mod_tensor_views = extend_variant_pack_with_extra_tensors( + score_mod_tensors, score_mod_extra_tensors, variant_pack); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { @@ -859,10 +857,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void *score_mod, void *score_mod_bprop, - void *score_mod_tensors, void *score_mod_bprop_tensors, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, - void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, + void *score_mod_tensors, void *score_mod_bprop_tensors, void *devPtrQ, void *devPtrKTranspose, + void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, + void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, + void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -957,31 +955,31 @@ void fused_attn_arbitrary_seqlen_bwd_impl( namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // dO - std::shared_ptr, // stats - std::shared_ptr, // attn_scale - std::shared_ptr, // dQ - std::shared_ptr, // dK - std::shared_ptr, // dV - std::shared_ptr, // bias - std::shared_ptr, // dBias - std::shared_ptr, // softmax_offset - std::shared_ptr, // d_softmax_offset - std::shared_ptr, // seq_q - std::shared_ptr, // seq_kv - std::shared_ptr, // offset_q - std::shared_ptr, // offset_k - std::shared_ptr, // offset_v - std::shared_ptr, // offset_o - std::shared_ptr, // offset_stats - std::shared_ptr, // dropout_seed - std::shared_ptr, // dropout_offset - ExtraTensorList, // score_mod extra tensors - ExtraTensorList>; // score_mod_bprop extra tensors + std::shared_ptr, // q + std::shared_ptr, // k + std::shared_ptr, // v + std::shared_ptr, // o + std::shared_ptr, // dO + std::shared_ptr, // stats + std::shared_ptr, // attn_scale + std::shared_ptr, // dQ + std::shared_ptr, // dK + std::shared_ptr, // dV + std::shared_ptr, // bias + std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset + std::shared_ptr, // seq_q + 
std::shared_ptr, // seq_kv + std::shared_ptr, // offset_q + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v + std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats + std::shared_ptr, // dropout_seed + std::shared_ptr, // dropout_offset + ExtraTensorList, // score_mod extra tensors + ExtraTensorList>; // score_mod_bprop extra tensors using CacheType = std::map; static thread_local CacheType sdpa_f16_bprop_cache; @@ -1103,8 +1101,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); if (score_mod != nullptr) { - sdpa_backward_options.set_score_mod(make_attention_score_modifier( - score_mod, score_mod_tensors, &score_mod_extra_tensors)); + sdpa_backward_options.set_score_mod( + make_attention_score_modifier(score_mod, score_mod_tensors, &score_mod_extra_tensors)); } if (score_mod_bprop != nullptr) { sdpa_backward_options.set_score_mod_bprop(make_attention_score_modifier( @@ -1246,11 +1244,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - softmax_offset_tuple, padding_tuple, offset_qo_tuple, - offset_kv_tuple, offset_s_tuple, dropout_tuple, - std::make_tuple(score_mod_extra_tensors), - std::make_tuple(score_mod_bprop_extra_tensors)); + auto return_tuple = std::tuple_cat( + std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, softmax_offset_tuple, + padding_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple, + std::make_tuple(score_mod_extra_tensors), std::make_tuple(score_mod_bprop_extra_tensors)); cache.insert({descriptor, return_tuple}); return return_tuple; @@ -1258,8 +1255,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset, d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats, - dropout_seed, dropout_offset, score_mod_extra_tensors, - score_mod_bprop_extra_tensors] = get_graph(sdpa_f16_bprop_cache, descriptor); + dropout_seed, dropout_offset, score_mod_extra_tensors, score_mod_bprop_extra_tensors] = + get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. 
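Since the backward cache above is keyed by FADescriptor_v1, the extra-tensor dicts enter the key only through their metadata signature: names in sorted order, then device, dtype, rank, shape, and strides of each entry, never data pointers. A rough Python model of get_extra_tensor_signature, approximate and for intuition only:

import torch

MASK64 = (1 << 64) - 1

def hash_combine(seed, value):
    # Mirrors the C++ helper:
    # seed ^= hash(value) + 0x9e3779b97f4a7c15 + (seed << 6) + (seed >> 2)
    mixed = ((hash(value) & MASK64) + 0x9E3779B97F4A7C15
             + ((seed << 6) & MASK64) + (seed >> 2)) & MASK64
    return (seed ^ mixed) & MASK64

def extra_tensor_signature(extra_tensors):
    sig = 0
    for name in sorted(extra_tensors):
        t = extra_tensors[name]
        for item in (name, str(t.device), str(t.dtype), t.dim()):
            sig = hash_combine(sig, item)
        for size, stride in zip(t.shape, t.stride()):
            sig = hash_combine(sig, size)
            sig = hash_combine(sig, stride)
    return sig

a = torch.zeros(1, 1, 1, 1)
b = torch.ones(1, 1, 1, 1)
# Same layout, different contents: the cached graph is reused.
assert extra_tensor_signature({"x": a}) == extra_tensor_signature({"x": b})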
@@ -1376,8 +1373,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
     }
 
-    auto score_mod_tensor_views =
-        extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack);
+    auto score_mod_tensor_views = extend_variant_pack_with_extra_tensors(
+        score_mod_tensors, score_mod_extra_tensors, variant_pack);
     auto score_mod_bprop_tensor_views = extend_variant_pack_with_extra_tensors(
         score_mod_bprop_tensors, score_mod_bprop_extra_tensors, variant_pack);
 
@@ -1398,9 +1395,9 @@ void fused_attn_arbitrary_seqlen_fwd(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
     void *score_mod, void *score_mod_tensors, const Tensor *input_Q, const Tensor *input_K,
-    const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors,
-    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
+    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
@@ -1540,9 +1537,8 @@ void fused_attn_arbitrary_seqlen_fwd(
         softmax_type, window_size_left, window_size_right, bottom_right_diagonal, score_mod,
         score_mod_tensors, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1,
         devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
-        devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
-        devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
-        stream, handle);
+        devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
+        get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
 
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
index 84ed29c99e..a6e2bf25fa 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
@@ -26,10 +26,10 @@ void fused_attn_arbitrary_seqlen_fwd(
     bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
-    void *score_mod, void *score_mod_tensors,
-    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
-    const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    void *score_mod, void *score_mod_tensors, const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
+    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
 
@@ -40,11 +40,10 @@ void fused_attn_arbitrary_seqlen_bwd(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
     bool deterministic, void *score_mod, void *score_mod_bprop, void *score_mod_tensors,
-    void *score_mod_bprop_tensors, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
-    Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
-    Tensor *output_dSoftmaxOffset,
+    void *score_mod_bprop_tensors, const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle);
diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index 44b38918af..a12a1dbb4f 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -285,20 +285,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
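+ *  \param[in]     score_mod          Optional opaque handle to a score-mod callback; pass NULL when unused.
+ *  \param[in]     score_mod_tensors  Optional opaque handle to extra tensors consumed by score_mod; pass NULL when unused.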
  * \param[in] workspace Workspace tensor.
  * \param[in] stream CUDA stream used for this operation.
  */
-void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
-                         const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
-                         NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-                         const NVTETensor cu_seqlens_q_padded,
-                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
-                         const NVTETensor page_table_v, const NVTETensor rng_state,
-                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
-                         bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
-                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                         int64_t window_size_left, int64_t window_size_right,
-                         bool bottom_right_diagonal, void *score_mod, void *score_mod_tensors,
-                         NVTETensor workspace, cudaStream_t stream);
+void nvte_fused_attn_fwd(
+    const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
+    const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+    const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
+    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
+    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
+    bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    void *score_mod, void *score_mod_tensors, NVTETensor workspace, cudaStream_t stream);
 
 /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
  *
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index 3c7d48933b..ad52d38db8 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -322,15 +322,15 @@ static void FusedAttnForwardImpl(
   auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
   auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
 
-  nvte_fused_attn_fwd(
-      q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-      softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
-      k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
-      rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
-      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
-      window_size_left, window_size_right, bottom_right_diagonal, nullptr,
-      workspace_tensor.data(), stream);
+  nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
+                      softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(),
+                      &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+                      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
+                      dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
+                      rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false,
+                      false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+                      softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
+                      nullptr, workspace_tensor.data(), stream);
 
   nvte_tensor_pack_destroy(&aux_output_tensors);
 }
@@ -597,18 +597,17 @@ static void FusedAttnBackwardImpl(
     }
   }
 
-  nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-                      doutput_tensor.data(),
-                      s_tensor.data(),  // not used for F16
-                      s_tensor.data(),  // not used for F16
-                      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                      dbias_tensor.data(), dsoftmax_offset_tensor.data(),
-                      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
-                      kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
-                      mask_type, softmax_type, window_size_left, window_size_right,
-                      bottom_right_diagonal, deterministic, false, nullptr, nullptr,
-                      workspace_tensor.data(), stream);
+  nvte_fused_attn_bwd(
+      q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+      doutput_tensor.data(),
+      s_tensor.data(),  // not used for F16
+      s_tensor.data(),  // not used for F16
+      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
+      dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
+      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+      window_size_left, window_size_right, bottom_right_diagonal, deterministic, false, nullptr,
+      nullptr, workspace_tensor.data(), stream);
 
   nvte_tensor_pack_destroy(&aux_input_tensors);
 }
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 65bd9fd9b4..a5a1fa8c1b 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -1856,8 +1856,7 @@ def forward(
         ), f"FusedAttention does not support qkv_layout = {qkv_layout}!"
         if score_mod is not None or score_mod_bprop is not None:
             assert (
-                fused_attention_backend
-                == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
+                fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
             ), "score_mod and score_mod_bprop require the F16_arbitrary_seqlen fused backend."
 
         cp_size = 1
diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
index b784155107..58176d035d 100644
--- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py
+++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -541,9 +541,10 @@ def fused_attn_bwd(
     if score_mod_bprop_tensors is not None:
         assert score_mod_bprop is not None, "score_mod_bprop_tensors requires score_mod_bprop."
     if score_mod is not None or score_mod_bprop is not None:
-        assert (
-            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
-        ), "score_mod and score_mod_bprop are only supported by the cuDNN F16_arbitrary_seqlen backend."
+        assert fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"], (
+            "score_mod and score_mod_bprop are only supported by the cuDNN F16_arbitrary_seqlen"
+            " backend."
+        )
 
     if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]:
         if len(aux_ctx_tensors) < 1:
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 059de63c6d..f9dfe464c7 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -94,9 +94,8 @@ std::vector<py::object> fused_attn_fwd(
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
     const std::optional<at::Tensor> SoftmaxOffset, py::handle score_mod,
-    py::handle score_mod_tensors,
-    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread, bool return_max_logit,
-    bool cuda_graph);
+    py::handle score_mod_tensors, const std::optional<at::Generator> rng_gen,
+    size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph);
 
 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
@@ -109,8 +108,8 @@ std::vector<py::object> fused_attn_bwd(
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer, py::handle score_mod,
-    py::handle score_mod_bprop, py::handle score_mod_tensors,
-    py::handle score_mod_bprop_tensors, bool cuda_graph);
+    py::handle score_mod_bprop, py::handle score_mod_tensors, py::handle score_mod_bprop_tensors,
+    bool cuda_graph);
 
 at::Tensor fa_prepare_fwd(at::Tensor qkvi);
 at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp
index db4f8a67f3..0f73641964 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -108,9 +108,8 @@ std::vector<py::object> fused_attn_fwd(
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
     const std::optional<at::Tensor> SoftmaxOffset, py::handle score_mod,
-    py::handle score_mod_tensors,
-    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread, bool return_max_logit,
-    bool cuda_graph) {
+    py::handle score_mod_tensors, const std::optional<at::Generator> rng_gen,
+    size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) {
   // Ensure that cuDNN handle is created on the correct device,
   // overriding torch.cuda.set_device calls from user side.
   // Assumes all tensors passed are on the same device.
@@ -325,8 +324,8 @@ std::vector<py::object> fused_attn_bwd(
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer, py::handle score_mod,
-    py::handle score_mod_bprop, py::handle score_mod_tensors,
-    py::handle score_mod_bprop_tensors, bool cuda_graph) {
+    py::handle score_mod_bprop, py::handle score_mod_tensors, py::handle score_mod_bprop_tensors,
+    bool cuda_graph) {
   auto none = py::none();
 
   // create QKV, O, dO tensor wrappers
@@ -541,18 +540,18 @@ std::vector<py::object> fused_attn_bwd(
 
   // populate tensors with appropriate shapes and dtypes
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_fused_attn_bwd(
-        te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
-        &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
-        te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
-        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
-        window_size[1], bottom_right_diagonal, deterministic, cuda_graph,
-        score_mod.is_none() ? nullptr : score_mod.ptr(),
-        score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(),
-        score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(),
-        score_mod_bprop_tensors.is_none() ? nullptr : score_mod_bprop_tensors.ptr(),
-        workspace.data(), at::cuda::getCurrentCUDAStream());
+    nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
+                        te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
+                        te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
+                        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+                        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
+                        max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
+                        softmax_type, window_size[0], window_size[1], bottom_right_diagonal,
+                        deterministic, cuda_graph, score_mod.is_none() ? nullptr : score_mod.ptr(),
+                        score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(),
+                        score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(),
+                        score_mod_bprop_tensors.is_none() ? nullptr : score_mod_bprop_tensors.ptr(),
+                        workspace.data(), at::cuda::getCurrentCUDAStream());
   });
 
   // allocate memory for workspace
@@ -562,18 +561,18 @@ std::vector<py::object> fused_attn_bwd(
 
   // execute kernel
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_fused_attn_bwd(
-        te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
-        &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
-        te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
-        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
-        window_size[1], bottom_right_diagonal, deterministic, cuda_graph,
-        score_mod.is_none() ? nullptr : score_mod.ptr(),
-        score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(),
-        score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(),
-        score_mod_bprop_tensors.is_none() ? nullptr : score_mod_bprop_tensors.ptr(),
-        workspace.data(), at::cuda::getCurrentCUDAStream());
+    nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
+                        te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
+                        te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
+                        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+                        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
+                        max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
+                        softmax_type, window_size[0], window_size[1], bottom_right_diagonal,
+                        deterministic, cuda_graph, score_mod.is_none() ? nullptr : score_mod.ptr(),
+                        score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(),
+                        score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(),
+                        score_mod_bprop_tensors.is_none() ? nullptr : score_mod_bprop_tensors.ptr(),
+                        workspace.data(), at::cuda::getCurrentCUDAStream());
   });
 
   // destroy tensor wrappers

From ff8826e3dda9560bfef74509be6d5014f77265c9 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov
Date: Mon, 16 Mar 2026 14:33:58 -0700
Subject: [PATCH 23/23] Fix JAX fused attention call signatures

Signed-off-by: Vladimir Cherepanov
---
 transformer_engine/jax/csrc/extensions/attention.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index ad52d38db8..635b92b4d6 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -192,7 +192,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
       ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
       dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
       scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
-      window_size_left, window_size_right, bottom_right_diagonal, nullptr,
+      window_size_left, window_size_right, bottom_right_diagonal, nullptr, nullptr,
       query_workspace_tensor.data(), nullptr);
 }
 
@@ -330,7 +330,7 @@ static void FusedAttnForwardImpl(
                       rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false,
                       false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                       softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
-                      nullptr, workspace_tensor.data(), stream);
+                      nullptr, nullptr, workspace_tensor.data(), stream);
 
   nvte_tensor_pack_destroy(&aux_output_tensors);
 }
@@ -481,7 +481,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
       q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
       mask_type, softmax_type, window_size_left, window_size_right,
       bottom_right_diagonal, deterministic, false, nullptr,
-      nullptr, query_workspace_tensor.data(), nullptr);
+      nullptr, nullptr, nullptr, query_workspace_tensor.data(), nullptr);
 }
 
 nvte_tensor_pack_destroy(&aux_input_tensors);
@@ -607,7 +607,7 @@ static void FusedAttnBackwardImpl(
       q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
       scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
       window_size_left, window_size_right, bottom_right_diagonal, deterministic, false, nullptr,
-      nullptr, workspace_tensor.data(), stream);
+      nullptr, nullptr, nullptr, workspace_tensor.data(), stream);
 
   nvte_tensor_pack_destroy(&aux_input_tensors);
 }
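
Note on the extended C API: after this series, nvte_fused_attn_fwd takes two extra arguments, score_mod and score_mod_tensors, immediately before workspace and stream. A minimal sketch of a call site that does not use score modifiers follows; the variable names are illustrative placeholders for tensors and settings the caller already has in scope, and the two new arguments are passed as NULL, exactly as the JAX bindings above do:

    // Sketch only: callers without a score modifier pass NULL for both new
    // arguments, which preserves the pre-series behavior of the fused kernel.
    nvte_fused_attn_fwd(Q, K, V, Bias, SoftmaxOffset, S, O, &aux_ctx_tensors,
                        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded,
                        page_table_k, page_table_v, rng_state, max_seqlen_q, max_seqlen_kv,
                        /*is_training=*/true, /*return_max_logit=*/false, /*cuda_graph=*/false,
                        attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
                        window_size_left, window_size_right, bottom_right_diagonal,
                        /*score_mod=*/nullptr, /*score_mod_tensors=*/nullptr,
                        workspace, stream);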