diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 60ade522e3..ae7ab9e1e7 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -10,6 +10,8 @@ import pytest import torch +import cudnn + from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype from transformer_engine.common import recipe from transformer_engine.pytorch import ( @@ -99,6 +101,55 @@ def reset_global_fp8_state(): param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] + +def _identity_score_mod(sdpa_graph, score_tensor): + return score_tensor + + +def _causal_score_mod(sdpa_graph, score_tensor): + row_index = sdpa_graph.gen_index(input=score_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + + col_index = sdpa_graph.gen_index(input=score_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + + causal_mask = sdpa_graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + causal_mask.set_data_type(cudnn.data_type.BOOLEAN) + + zero = sdpa_graph.sub(a=row_index, b=row_index, compute_data_type=cudnn.data_type.FLOAT) + zero.set_data_type(cudnn.data_type.FLOAT) + + neg_inf = sdpa_graph.log(input=zero, compute_data_type=cudnn.data_type.FLOAT) + neg_inf.set_data_type(cudnn.data_type.FLOAT) + + return sdpa_graph.binary_select(input0=score_tensor, input1=neg_inf, mask=causal_mask) + + +def _causal_score_mod_external(sdpa_graph, score_tensor, score_mod_tensors): + row_index = sdpa_graph.gen_index(input=score_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + + col_index = sdpa_graph.gen_index(input=score_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + + causal_mask = sdpa_graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + causal_mask.set_data_type(cudnn.data_type.BOOLEAN) + + return sdpa_graph.binary_select( + input0=score_tensor, + input1=score_mod_tensors["neg_inf"], + mask=causal_mask, + ) + + model_configs_base = { # test: ModelConfig(b, sq, hq, dqk) "base_1_0": ModelConfig(8, 128, 16, 64), @@ -1416,6 +1467,391 @@ def test_transformer_layer( torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols) +@pytest.mark.skipif(get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required.") +def test_multihead_attention_score_mod_forces_fused_backend(): + _attention_backends["attention_params"] = None + _attention_backends["backend_selection_requires_update"] = True + + hidden_size = 256 + num_heads = 4 + seq_len = 64 + batch_size = 2 + mha = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_heads, + attention_dropout=0.0, + params_dtype=torch.float16, + device="cuda", + qkv_format="sbhd", + ).cuda() + mha.train() + + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float16, requires_grad=True + ) + output = mha( + hidden_states, + attn_mask_type="no_mask", + score_mod=_identity_score_mod, + score_mod_bprop=_identity_score_mod, + ) + output.sum().backward() + + assert _attention_backends["use_fused_attention"] + assert not _attention_backends["use_flash_attention"] + + +# Configs for score_mod tests: batch=2, seqlen=512, heads=16, head_dim=64 +# Note: score_mod disables other cuDNN subgraphs (e.g. causal masking), so only no_mask is tested. 
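+# Each score_mod callback receives the cuDNN SDPA graph and the pre-softmax score tensor
+# (plus an optional dict of extra graph tensors for the *_external variant) and returns a
+# modified score tensor attribute; returning None keeps the original score unchanged.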
+model_configs_score_mod = { + "score_mod_0": ModelConfig(2, 512, 16, 64, attn_mask_type="no_mask"), + "score_mod_1": ModelConfig(4, 256, 8, 64, attn_mask_type="no_mask"), +} + + +@pytest.mark.skipif( + get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod." +) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_score_mod]) +@pytest.mark.parametrize("model", model_configs_score_mod.keys()) +def test_dpa_score_mod(dtype, model_configs, model): + """Test DotProductAttention with score_mod=None (plumbing test). + + Verifies that passing score_mod=None produces the same output as not passing it, + and that passing a Python callable as score_mod (identity: returns None) does not crash + and produces identical output when F16_arbitrary_seqlen backend is used. + """ + config = model_configs[model] + qkv_layout = "bshd_bshd_bshd" + qkv_format = "bshd" + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + pad_between_seqs=False, + is_training=True, + deterministic=False, + ) + _, fused_attn_supported, _ = available_backends + + if ( + not fused_attn_supported + or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends + ): + pytest.skip("F16_arbitrary_seqlen backend not available.") + + reset_rng_states() + + b, sq, h, d = config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk + q = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + k = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda") + out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 0.01).detach() + + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: + return _DUMMY_CUDA_RNG_STATE_TRACKER + + block = DotProductAttention( + h, + d, + attention_dropout=0.0, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + ).to(dtype=dtype, device="cuda") + + # Reference: run without score_mod (score_mod=None by default) + out_ref = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type=config.attn_mask_type, + ) + out_ref.backward(out_grad) + dq_ref = q.grad.clone() + dk_ref = k.grad.clone() + dv_ref = v.grad.clone() + q.grad = None + k.grad = None + v.grad = None + + # score_mod callable that returns None (identity: cuDNN trampoline returns original score) + score_mod_called = [False] + + def identity_score_mod(graph, score): + """Identity score_mod: verify we get called, return None to keep original score.""" + score_mod_called[0] = True + return None + + # Run with score_mod identity callable + out_sm = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + 
max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type=config.attn_mask_type, + score_mod=identity_score_mod, + score_mod_bprop=identity_score_mod, + ) + out_sm.backward(out_grad) + dq_sm = q.grad.clone() + dk_sm = k.grad.clone() + dv_sm = v.grad.clone() + + assert score_mod_called[0], "score_mod callable was not invoked by the cuDNN trampoline." + + tols = dict(atol=1e-3, rtol=1e-3) + if dtype == torch.bfloat16: + tols = dict(atol=1.5e-2, rtol=1.5e-2) + torch.testing.assert_close(out_sm, out_ref, **tols) + torch.testing.assert_close(dq_sm, dq_ref, **tols) + torch.testing.assert_close(dk_sm, dk_ref, **tols) + torch.testing.assert_close(dv_sm, dv_ref, **tols) + + +@pytest.mark.skipif( + get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod." +) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_score_mod]) +@pytest.mark.parametrize("model", model_configs_score_mod.keys()) +def test_dpa_score_mod_causal(dtype, model_configs, model): + """Test DotProductAttention causal masking implemented via score_mod.""" + + config = model_configs[model] + qkv_layout = "bshd_bshd_bshd" + qkv_format = "bshd" + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + pad_between_seqs=False, + is_training=True, + deterministic=False, + ) + _, fused_attn_supported, _ = available_backends + + if ( + not fused_attn_supported + or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends + ): + pytest.skip("F16_arbitrary_seqlen backend not available.") + + reset_rng_states() + + b, sq, h, d = config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk + q = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + k = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda") + out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 0.01).detach() + + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: + return _DUMMY_CUDA_RNG_STATE_TRACKER + + block = DotProductAttention( + h, + d, + attention_dropout=0.0, + qkv_format=qkv_format, + attn_mask_type="no_mask", + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out_ref = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type="causal", + ) + out_ref.backward(out_grad) + dq_ref = q.grad.clone() + dk_ref = k.grad.clone() + dv_ref = v.grad.clone() + q.grad = None + k.grad = None + v.grad = None + + out_sm = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type="no_mask", + score_mod=_causal_score_mod, + ) + out_sm.backward(out_grad) + dq_sm = q.grad.clone() + dk_sm = k.grad.clone() + 
dv_sm = v.grad.clone() + + tols = dict(atol=1e-3, rtol=1e-3) + if dtype == torch.bfloat16: + tols = dict(atol=1.5e-2, rtol=1.5e-2) + + torch.testing.assert_close(out_sm, out_ref, **tols) + torch.testing.assert_close(dq_sm, dq_ref, **tols) + torch.testing.assert_close(dk_sm, dk_ref, **tols) + torch.testing.assert_close(dv_sm, dv_ref, **tols) + + +@pytest.mark.skipif( + get_cudnn_version() < (9, 7, 0), reason="cuDNN 9.7.0+ is required for score_mod." +) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_score_mod]) +@pytest.mark.parametrize("model", model_configs_score_mod.keys()) +@pytest.mark.parametrize( + "neg_inf_device", ["cuda", "cpu"], ids=["cuda_tensor", "cpu_by_value_tensor"] +) +def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model, neg_inf_device): + """Test DotProductAttention causal masking via score_mod with external variant-pack tensor.""" + + config = model_configs[model] + qkv_layout = "bshd_bshd_bshd" + qkv_format = "bshd" + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + pad_between_seqs=False, + is_training=True, + deterministic=False, + ) + _, fused_attn_supported, _ = available_backends + + if ( + not fused_attn_supported + or FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_attn_backends + ): + pytest.skip("F16_arbitrary_seqlen backend not available.") + + reset_rng_states() + + b, sq, h, d = config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk + q = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + k = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True) + cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda") + out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 0.01).detach() + neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32, device=neg_inf_device) + + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: + return _DUMMY_CUDA_RNG_STATE_TRACKER + + block = DotProductAttention( + h, + d, + attention_dropout=0.0, + qkv_format=qkv_format, + attn_mask_type="no_mask", + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out_ref = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type="causal", + ) + out_ref.backward(out_grad) + dq_ref = q.grad.clone() + dk_ref = k.grad.clone() + dv_ref = v.grad.clone() + q.grad = None + k.grad = None + v.grad = None + + out_sm = block( + q, + k, + v, + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=sq, + max_seqlen_kv=sq, + attn_mask_type="no_mask", + score_mod=_causal_score_mod_external, + score_mod_tensors={"neg_inf": neg_inf}, + ) + out_sm.backward(out_grad) + dq_sm = q.grad.clone() + dk_sm = k.grad.clone() + dv_sm = 
v.grad.clone() + + tols = dict(atol=1.5e-2, rtol=1.5e-2) + + torch.testing.assert_close(out_sm, out_ref, **tols) + torch.testing.assert_close(dq_sm, dq_ref, **tols) + torch.testing.assert_close(dk_sm, dk_ref, **tols) + torch.testing.assert_close(dv_sm, dv_ref, **tols) + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b9e2b907e0..bbaf449076 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -4,6 +4,8 @@ cmake_minimum_required(VERSION 3.21) +include(FetchContent) + # Language options set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) @@ -97,6 +99,21 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) +find_package(pybind11 CONFIG REQUIRED) + +option(NVTE_USE_SYSTEM_DLPACK "Whether to use a system-installed dlpack package" OFF) +if(NVTE_USE_SYSTEM_DLPACK) + find_package(dlpack REQUIRED) + set(NVTE_DLPACK_TARGET dlpack::dlpack) +else() + FetchContent_Declare( + dlpack + GIT_REPOSITORY https://github.com/dmlc/dlpack + GIT_TAG v1.1 + ) + FetchContent_MakeAvailable(dlpack) + set(NVTE_DLPACK_TARGET dlpack) +endif() # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) @@ -259,6 +276,8 @@ target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) +target_link_libraries(transformer_engine PRIVATE Python::Module pybind11::headers) +target_link_libraries(transformer_engine PRIVATE ${NVTE_DLPACK_TARGET}) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) @@ -268,6 +287,9 @@ target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_ target_include_directories(transformer_engine PRIVATE ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR}) +target_include_directories(transformer_engine PRIVATE + ${Python_INCLUDE_DIRS} + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/python/pygraph") # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6a136c67e4..1828e67df2 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -535,19 +535,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool 
bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + void *score_mod, void *score_mod_tensors, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -635,10 +633,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, - input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, score_mod, score_mod_tensors, + input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
" @@ -670,7 +669,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + void *score_mod, void *score_mod_bprop, void *score_mod_tensors, + void *score_mod_bprop_tensors, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -743,8 +743,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, - input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + bottom_right_diagonal, deterministic, score_mod, score_mod_bprop, score_mod_tensors, + score_mod_bprop_tensors, input_Q, input_K, input_V, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eb2ebcff39..654166cebe 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -8,10 +8,16 @@ #include #include #include +#include +#include +#include +#include #include +#include #include +#include "../../../3rdparty/cudnn-frontend/python/pygraph/pygraph.h" #include "../common.h" #include "../cudnn_utils.h" #include "../util/cuda_runtime.h" @@ -48,6 +54,287 @@ namespace transformer_engine { namespace fused_attn { +namespace py = pybind11; + +namespace { + +using TensorAttr = std::shared_ptr; +using ExtraTensorList = std::vector>; + +constexpr char kDlpackCapsuleName[] = "dltensor"; + +struct DlpackTensorView { + PyObject *capsule = nullptr; + DLManagedTensor *managed = nullptr; + + DlpackTensorView() = default; + DlpackTensorView(PyObject *capsule_obj, DLManagedTensor *managed_tensor) + : capsule(capsule_obj), managed(managed_tensor) {} + + DlpackTensorView(const DlpackTensorView &) = delete; + auto operator=(const DlpackTensorView &) -> DlpackTensorView & = delete; + + DlpackTensorView(DlpackTensorView &&other) noexcept + : capsule(other.capsule), managed(other.managed) { + other.capsule = nullptr; + other.managed = nullptr; + } + + auto operator=(DlpackTensorView &&other) noexcept -> DlpackTensorView & { + if (this != &other) { + reset(); + capsule = other.capsule; + managed = other.managed; + other.capsule = nullptr; + other.managed = nullptr; + } + return *this; + } + + ~DlpackTensorView() { reset(); } + + void reset() { + if (capsule == nullptr) { + return; + } + py::gil_scoped_acquire gil; + Py_DECREF(capsule); + capsule = nullptr; + managed = nullptr; + } +}; + +using DlpackTensorViews = std::vector; + +cudnn_frontend::DataType_t convert_dlpack_dtype(const DLDataType &dtype) { + switch (dtype.code) { + case 
DLDataTypeCode::kDLUInt: + if (dtype.bits == 8) { + return cudnn_frontend::DataType_t::UINT8; + } + break; + case DLDataTypeCode::kDLInt: + switch (dtype.bits) { + case 8: + return cudnn_frontend::DataType_t::INT8; + case 32: + return cudnn_frontend::DataType_t::INT32; + case 64: + return cudnn_frontend::DataType_t::INT64; + } + break; + case DLDataTypeCode::kDLFloat: + switch (dtype.bits) { + case 16: + return cudnn_frontend::DataType_t::HALF; + case 32: + return cudnn_frontend::DataType_t::FLOAT; + case 64: + return cudnn_frontend::DataType_t::DOUBLE; + } + break; + case DLDataTypeCode::kDLBfloat: + if (dtype.bits == 16) { + return cudnn_frontend::DataType_t::BFLOAT16; + } + break; + case DLDataTypeCode::kDLBool: + if (dtype.bits == 8) { + return cudnn_frontend::DataType_t::BOOLEAN; + } + break; + } + return cudnn_frontend::DataType_t::NOT_SET; +} + +std::vector get_sorted_extra_tensor_names(const py::dict &extra_tensors) { + std::vector names; + names.reserve(extra_tensors.size()); + for (auto item : extra_tensors) { + names.push_back(py::cast(item.first)); + } + std::sort(names.begin(), names.end()); + return names; +} + +template +decltype(auto) with_dlpack_tensor(py::handle tensor_obj, Fn &&fn) { + NVTE_CHECK(py::hasattr(tensor_obj, "__dlpack__"), + "score_mod_tensors entries must support __dlpack__()."); + py::capsule capsule = tensor_obj.attr("__dlpack__")(); + NVTE_CHECK(!capsule.is_none(), "Failed to retrieve DLPack capsule for score_mod_tensors entry."); + auto *managed = + static_cast(PyCapsule_GetPointer(capsule.ptr(), kDlpackCapsuleName)); + NVTE_CHECK(managed != nullptr, "Invalid DLPack capsule in score_mod_tensors entry."); + return std::forward(fn)(*managed); +} + +DlpackTensorView make_dlpack_tensor_view(py::handle tensor_obj) { + NVTE_CHECK(py::hasattr(tensor_obj, "__dlpack__"), + "score_mod_tensors entries must support __dlpack__()."); + py::capsule capsule = tensor_obj.attr("__dlpack__")(); + NVTE_CHECK(!capsule.is_none(), "Failed to retrieve DLPack capsule for score_mod_tensors entry."); + auto *managed = + static_cast(PyCapsule_GetPointer(capsule.ptr(), kDlpackCapsuleName)); + NVTE_CHECK(managed != nullptr, "Invalid DLPack capsule in score_mod_tensors entry."); + Py_INCREF(capsule.ptr()); + return DlpackTensorView(capsule.ptr(), managed); +} + +TensorAttr create_tensor_attr_from_dlpack( + const std::shared_ptr &graph, py::handle tensor_obj, + const std::string &name) { + return with_dlpack_tensor(tensor_obj, [&](const DLManagedTensor &managed) -> TensorAttr { + const auto device_type = managed.dl_tensor.device.device_type; + NVTE_CHECK(device_type == kDLCPU || device_type == kDLCUDAHost || device_type == kDLCUDA || + device_type == kDLCUDAManaged, + "Invalid device type in score_mod_tensors entry."); + + const auto ndim = managed.dl_tensor.ndim; + std::vector dims(managed.dl_tensor.shape, managed.dl_tensor.shape + ndim); + const auto tensor_dtype = convert_dlpack_dtype(managed.dl_tensor.dtype); + NVTE_CHECK(tensor_dtype != cudnn_frontend::DataType_t::NOT_SET, + "Unsupported DLPack dtype in score_mod_tensors entry."); + + auto props = cudnn_frontend::graph::Tensor_attributes() + .set_name(name) + .set_data_type(tensor_dtype) + .set_is_virtual(false) + .set_is_pass_by_value(device_type == kDLCPU) + .set_dim(dims); + + if (managed.dl_tensor.strides == nullptr) { + auto stride_order = cudnn_frontend::detail::generate_row_major_stride_order(ndim); + props.set_stride(cudnn_frontend::detail::generate_stride(dims, stride_order)); + } else { + std::vector 
strides(managed.dl_tensor.strides, managed.dl_tensor.strides + ndim); + props.set_stride(strides); + } + + return graph->tensor(props); + }); +} + +template +void hash_combine(std::uint64_t &seed, const T &value) { + seed ^= std::hash{}(value) + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); +} + +std::uint64_t get_extra_tensor_signature(void *extra_tensors_ptr) { + if (extra_tensors_ptr == nullptr) { + return 0; + } + + py::gil_scoped_acquire gil; + py::dict extra_tensors = + py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + std::uint64_t signature = 0; + + for (const auto &name : get_sorted_extra_tensor_names(extra_tensors)) { + py::handle tensor_obj = extra_tensors[py::str(name)]; + hash_combine(signature, name); + with_dlpack_tensor(tensor_obj, [&](const DLManagedTensor &managed) { + hash_combine(signature, managed.dl_tensor.device.device_type); + hash_combine(signature, managed.dl_tensor.device.device_id); + hash_combine(signature, managed.dl_tensor.dtype.code); + hash_combine(signature, managed.dl_tensor.dtype.bits); + hash_combine(signature, managed.dl_tensor.dtype.lanes); + hash_combine(signature, managed.dl_tensor.ndim); + for (int i = 0; i < managed.dl_tensor.ndim; ++i) { + hash_combine(signature, managed.dl_tensor.shape[i]); + hash_combine(signature, managed.dl_tensor.strides ? managed.dl_tensor.strides[i] : -1); + } + }); + } + + return signature; +} + +py::dict get_score_mod_tensor_attrs(const std::shared_ptr &graph, + void *extra_tensors_ptr, ExtraTensorList *extra_tensor_attrs) { + py::dict callback_tensors; + if (extra_tensors_ptr == nullptr) { + return callback_tensors; + } + + if (extra_tensor_attrs != nullptr && !extra_tensor_attrs->empty()) { + for (const auto &[name, tensor_attr] : *extra_tensor_attrs) { + callback_tensors[py::str(name)] = py::cast(tensor_attr); + } + return callback_tensors; + } + + py::dict extra_tensors = + py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + for (const auto &name : get_sorted_extra_tensor_names(extra_tensors)) { + py::handle tensor_obj = extra_tensors[py::str(name)]; + auto tensor_attr = create_tensor_attr_from_dlpack(graph, tensor_obj, name); + callback_tensors[py::str(name)] = py::cast(tensor_attr); + if (extra_tensor_attrs != nullptr) { + extra_tensor_attrs->emplace_back(name, std::move(tensor_attr)); + } + } + + return callback_tensors; +} + +DlpackTensorViews extend_variant_pack_with_extra_tensors( + void *extra_tensors_ptr, const ExtraTensorList &extra_tensor_attrs, + std::unordered_map &variant_pack) { + DlpackTensorViews views; + if (extra_tensors_ptr == nullptr || extra_tensor_attrs.empty()) { + return views; + } + + py::gil_scoped_acquire gil; + py::dict extra_tensors = + py::reinterpret_borrow(reinterpret_cast(extra_tensors_ptr)); + views.reserve(extra_tensor_attrs.size()); + for (const auto &[name, tensor_attr] : extra_tensor_attrs) { + py::str key(name); + NVTE_CHECK(extra_tensors.contains(key), "Missing score_mod tensor entry: ", name); + auto view = make_dlpack_tensor_view(extra_tensors[key]); + variant_pack[tensor_attr] = + static_cast(view.managed->dl_tensor.data) + view.managed->dl_tensor.byte_offset; + views.emplace_back(std::move(view)); + } + return views; +} + +auto make_attention_score_modifier(void *callback_ptr, void *extra_tensors_ptr, + ExtraTensorList *extra_tensor_attrs) + -> std::function( + std::shared_ptr, + std::shared_ptr)> { + if (callback_ptr == nullptr) { + return nullptr; + } + + auto *callback = reinterpret_cast(callback_ptr); + + return [callback, extra_tensors_ptr, 
extra_tensor_attrs]( + std::shared_ptr graph, + std::shared_ptr score_tensor) mutable { + py::gil_scoped_acquire gil; + py::module_::import("cudnn"); + py::function callback_fn = py::reinterpret_borrow(callback); + + auto py_graph = std::make_shared(graph); + py::dict callback_tensors = + get_score_mod_tensor_attrs(graph, extra_tensors_ptr, extra_tensor_attrs); + + py::object result = callback_tensors.empty() + ? callback_fn(*py_graph, score_tensor) + : callback_fn(*py_graph, score_tensor, callback_tensors); + if (result.is_none()) { + return score_tensor; + } + return result.cast>(); + }; +} + +} // namespace + void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, @@ -56,10 +343,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + bool bottom_right_diagonal, void *score_mod, void *score_mod_tensors, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, + void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -139,32 +426,37 @@ void fused_attn_arbitrary_seqlen_fwd_impl( cudnn_frontend::DataType_t::NOT_SET, cudnn_frontend::DataType_t::NOT_SET, cudnn_frontend::DataType_t::NOT_SET, + reinterpret_cast(score_mod), + 0, + get_extra_tensor_signature(score_mod_tensors), + 0, return_max_logit, }; namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // Q - std::shared_ptr, // K - std::shared_ptr, // V - std::shared_ptr, // attn_scale - std::shared_ptr, // O - std::shared_ptr, // S1 - std::shared_ptr, // S2 - std::shared_ptr, // bias - std::shared_ptr, // softmax_offset - std::shared_ptr, // seq_q - std::shared_ptr, // seq_kv - std::shared_ptr, // page_table_k - std::shared_ptr, // page_table_v - std::shared_ptr, // offset_q - std::shared_ptr, // offset_k - std::shared_ptr, // offset_v - std::shared_ptr, // offset_o - std::shared_ptr, // offset_stats - std::shared_ptr, // dropout_seed - std::shared_ptr>; // dropout_offset + std::shared_ptr, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // attn_scale + std::shared_ptr, // O + std::shared_ptr, // S1 + std::shared_ptr, // S2 + std::shared_ptr, // bias + std::shared_ptr, // softmax_offset + std::shared_ptr, // seq_q + std::shared_ptr, // seq_kv + std::shared_ptr, // page_table_k + std::shared_ptr, // page_table_v + std::shared_ptr, // offset_q + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v + std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats + 
std::shared_ptr, // dropout_seed + std::shared_ptr, // dropout_offset + ExtraTensorList>; // score_mod extra tensors using CacheType = std::map; static thread_local CacheType sdpa_f16_fprop_cache; @@ -190,6 +482,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr offset_q, offset_k, offset_v, offset_o, offset_stats; std::shared_ptr dropout_seed, dropout_offset; + ExtraTensorList score_mod_extra_tensors; std::vector q_stride(4); std::vector k_stride(4); @@ -257,6 +550,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + if (score_mod != nullptr) { + sdpa_options.set_score_mod( + make_attention_score_modifier(score_mod, score_mod_tensors, &score_mod_extra_tensors)); + } fe::DiagonalAlignment_t const &diagonal_alignment = bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT @@ -419,10 +716,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = - std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, - softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple, - offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = std::tuple_cat( + std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple, offset_kv_tuple, + offset_s_tuple, dropout_tuple, std::make_tuple(score_mod_extra_tensors)); cache.insert({descriptor, return_tuple}); return return_tuple; @@ -430,7 +727,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto [mha_graph, Q, K, V, attn_scale, O, S1, S2, bias, softmax_offset, seq_q, seq_kv, page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, - dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); + dropout_seed, dropout_offset, score_mod_extra_tensors] = + get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. 
@@ -543,6 +841,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
       variant_pack[softmax_offset] = devPtrSoftmaxOffset;
     }

+  auto score_mod_tensor_views = extend_variant_pack_with_extra_tensors(
+      score_mod_tensors, score_mod_extra_tensors, variant_pack);
+
   NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
 } catch (cudnn_frontend::cudnnException &e) {
   NVTE_ERROR(e.what());
@@ -555,7 +856,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability,
     NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
+    bool bottom_right_diagonal, bool deterministic, void *score_mod, void *score_mod_bprop,
+    void *score_mod_tensors, void *score_mod_bprop_tensors, void *devPtrQ, void *devPtrKTranspose,
     void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
     void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO,
     void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
@@ -643,35 +945,41 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
      cudnn_frontend::DataType_t::NOT_SET,
      cudnn_frontend::DataType_t::NOT_SET,
      cudnn_frontend::DataType_t::NOT_SET,
+     reinterpret_cast<std::uintptr_t>(score_mod),
+     reinterpret_cast<std::uintptr_t>(score_mod_bprop),
+     get_extra_tensor_signature(score_mod_tensors),
+     get_extra_tensor_signature(score_mod_bprop_tensors),
      false,
   };

   namespace fe = cudnn_frontend;
   using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // q
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // k
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // v
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // o
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // dO
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // stats
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // attn_scale
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // dQ
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // dK
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // dV
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // bias
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // dBias
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // softmax_offset
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // d_softmax_offset
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // seq_q
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // seq_kv
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // offset_q
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // offset_k
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // offset_v
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // offset_o
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // offset_stats
-                                       std::shared_ptr<fe::graph::Tensor_attributes>,   // dropout_seed
-                                       std::shared_ptr<fe::graph::Tensor_attributes>>;  // dropout_offset
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // q
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // k
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // v
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // o
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // dO
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // stats
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // attn_scale
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // dQ
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // dK
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // dV
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // bias
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // dBias
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // softmax_offset
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // d_softmax_offset
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_q
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_kv
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // offset_q
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // offset_k
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // offset_v
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // offset_o
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // offset_stats
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // dropout_seed
+                                       std::shared_ptr<fe::graph::Tensor_attributes>,  // dropout_offset
+                                       ExtraTensorList,   // score_mod extra tensors
+                                       ExtraTensorList>;  // score_mod_bprop extra tensors
   using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
   static thread_local CacheType sdpa_f16_bprop_cache;
@@ -697,6 +1005,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o, offset_stats;
     std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
+    ExtraTensorList score_mod_extra_tensors;
+    ExtraTensorList score_mod_bprop_extra_tensors;

     std::vector<int64_t> q_stride(4);
     std::vector<int64_t> k_stride(4);
@@ -790,6 +1100,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                                      .set_causal_mask(is_causal)
                                      .set_causal_mask_bottom_right(is_bottom_right)
                                      .set_attn_scale(attn_scale);
+    if (score_mod != nullptr) {
+      sdpa_backward_options.set_score_mod(
+          make_attention_score_modifier(score_mod, score_mod_tensors, &score_mod_extra_tensors));
+    }
+    if (score_mod_bprop != nullptr) {
+      sdpa_backward_options.set_score_mod_bprop(make_attention_score_modifier(
+          score_mod_bprop, score_mod_bprop_tensors, &score_mod_bprop_extra_tensors));
+    }

     if (is_ragged_q && cudnn_runtime_version >= 90600) {
       sdpa_backward_options.set_max_total_seq_len_q(s_q);
@@ -926,9 +1244,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
     NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));

-    auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple,
-                                       softmax_offset_tuple, padding_tuple, offset_qo_tuple,
-                                       offset_kv_tuple, offset_s_tuple, dropout_tuple);
+    auto return_tuple = std::tuple_cat(
+        std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, softmax_offset_tuple,
+        padding_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple,
+        std::make_tuple(score_mod_extra_tensors), std::make_tuple(score_mod_bprop_extra_tensors));
     cache.insert({descriptor, return_tuple});
     return return_tuple;
@@ -936,7 +1255,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset,
         d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats,
-        dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor);
+        dropout_seed, dropout_offset, score_mod_extra_tensors, score_mod_bprop_extra_tensors] =
+      get_graph(sdpa_f16_bprop_cache, descriptor);

   // Exit to request upper level API to allocate memory if needed
   // n.b. Care should be taken to align each of the added worksapce tensors to their type.
@@ -1053,6 +1373,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; } + auto score_mod_tensor_views = extend_variant_pack_with_extra_tensors( + score_mod_tensors, score_mod_extra_tensors, variant_pack); + auto score_mod_bprop_tensor_views = extend_variant_pack_with_extra_tensors( + score_mod_bprop_tensors, score_mod_bprop_extra_tensors, variant_pack); + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -1069,9 +1394,10 @@ void fused_attn_arbitrary_seqlen_fwd( bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + void *score_mod, void *score_mod_tensors, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1208,11 +1534,11 @@ void fused_attn_arbitrary_seqlen_fwd( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, - devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, score_mod, + score_mod_tensors, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, + devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1235,8 +1561,9 @@ void fused_attn_arbitrary_seqlen_bwd( size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + bool deterministic, void *score_mod, void *score_mod_bprop, void *score_mod_tensors, + void *score_mod_bprop_tensors, const Tensor *input_Q, const Tensor 
*input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, @@ -1306,7 +1633,8 @@ void fused_attn_arbitrary_seqlen_bwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, + window_size_right, bottom_right_diagonal, deterministic, score_mod, score_mod_bprop, + score_mod_tensors, score_mod_bprop_tensors, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 4dd7f3d1da..a6e2bf25fa 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -26,9 +26,10 @@ void fused_attn_arbitrary_seqlen_fwd( bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + void *score_mod, void *score_mod_tensors, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -38,8 +39,9 @@ void fused_attn_arbitrary_seqlen_bwd( size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + bool deterministic, void *score_mod, void *score_mod_bprop, void *score_mod_tensors, + void *score_mod_bprop_tensors, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, Tensor 
*output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 08a56cda6b..73aca9879d 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -118,6 +118,10 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t o_tensor_type; cudnn_frontend::DataType_t do_tensor_type; cudnn_frontend::DataType_t dqkv_tensor_type; + std::uintptr_t score_mod_id; + std::uintptr_t score_mod_bprop_id; + std::uint64_t score_mod_tensors_id; + std::uint64_t score_mod_bprop_tensors_id; bool generate_max_sum_exp; bool operator<(const FADescriptor_v1 &rhs) const { @@ -126,7 +130,8 @@ struct FADescriptor_v1 { bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, - dqkv_tensor_type, generate_max_sum_exp) < + dqkv_tensor_type, score_mod_id, score_mod_bprop_id, score_mod_tensors_id, + score_mod_bprop_tensors_id, generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, @@ -134,7 +139,9 @@ struct FADescriptor_v1 { rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); + rhs.dqkv_tensor_type, rhs.score_mod_id, rhs.score_mod_bprop_id, + rhs.score_mod_tensors_id, rhs.score_mod_bprop_tensors_id, + rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 8169bf22e2..a12a1dbb4f 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -280,22 +280,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. + * \param[in] score_mod Opaque pointer to a score modifier callback. + * \param[in] score_mod_tensors Opaque pointer to a mapping of extra score modifier tensors. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + void *score_mod, void *score_mod_tensors, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -355,6 +355,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] score_mod Opaque pointer to a score modifier callback. + * \param[in] score_mod_bprop Opaque pointer to a score modifier backward callback. + * \param[in] score_mod_tensors Opaque pointer to a mapping of extra score modifier tensors. + * \param[in] score_mod_bprop_tensors Opaque pointer to a mapping of extra score modifier backward tensors. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -370,7 +374,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + void *score_mod, void *score_mod_bprop, void *score_mod_tensors, + void *score_mod_bprop_tensors, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 92e67ac191..635b92b4d6 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -192,8 +192,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(), - nullptr); + window_size_left, window_size_right, bottom_right_diagonal, nullptr, nullptr, + query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -322,14 +322,15 @@ static void FusedAttnForwardImpl( auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype); auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype); - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream); + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, + false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + nullptr, nullptr, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -479,8 +480,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, false, - query_workspace_tensor.data(), nullptr); + window_size_right, bottom_right_diagonal, deterministic, false, nullptr, + nullptr, nullptr, nullptr, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -596,17 +597,17 @@ static void FusedAttnBackwardImpl( } } - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dsoftmax_offset_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, 
bias_type, - mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream); + nvte_fused_attn_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), + dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, false, nullptr, + nullptr, nullptr, nullptr, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_input_tensors); } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a6a8b0b26a..a5a1fa8c1b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1182,6 +1182,10 @@ def forward( fp8_output, layer_number, return_max_logit, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, ): # pylint: disable=missing-function-docstring @@ -1277,6 +1281,8 @@ def forward( bottom_right_diagonal, rng_gen, softmax_offset, + score_mod=score_mod, + score_mod_tensors=score_mod_tensors, cuda_graph=is_graph_capturing(), ) @@ -1357,6 +1363,8 @@ def forward( softmax_offset, return_max_logit, is_graph_capturing(), + score_mod=score_mod, + score_mod_tensors=score_mod_tensors, ) out = out_ out_ret = out_ @@ -1446,6 +1454,10 @@ def forward( ) ctx.use_FAv2_bwd = use_FAv2_bwd ctx.deterministic = deterministic + ctx.score_mod = score_mod + ctx.score_mod_bprop = score_mod_bprop + ctx.score_mod_tensors = score_mod_tensors + ctx.score_mod_bprop_tensors = score_mod_bprop_tensors if return_max_logit: return out_ret, *max_logit @@ -1594,6 +1606,10 @@ def backward(ctx, d_out, *_args): ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), + score_mod=ctx.score_mod, + score_mod_bprop=ctx.score_mod_bprop, + score_mod_tensors=ctx.score_mod_tensors, + score_mod_bprop_tensors=ctx.score_mod_bprop_tensors, ) # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1660,6 +1676,10 @@ def backward(ctx, d_out, *_args): ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), + score_mod=ctx.score_mod, + score_mod_bprop=ctx.score_mod_bprop, + score_mod_tensors=ctx.score_mod_tensors, + score_mod_bprop_tensors=ctx.score_mod_bprop_tensors, ) d_bias = None @@ -1702,6 +1722,10 @@ def backward(ctx, d_out, *_args): None, None, None, + None, + None, + None, + None, ) @@ -1810,6 +1834,10 @@ def forward( pad_between_seqs: bool = False, inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, fp8_output: bool = False, ) -> torch.Tensor: """fused attention fprop""" @@ -1826,6 +1854,10 @@ def forward( assert ( qkv_layout in QKVLayouts ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" 
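# --- Editor's sketch (not part of the patch) ---------------------------------------
# The checks added just below pin score_mod/score_mod_bprop to the cuDNN
# F16_arbitrary_seqlen sub-backend and rule out context parallelism. Callers who want
# that requirement to be explicit can steer backend selection with the usual TE
# environment variables before the first forward pass; the filters added in
# dot_product_attention/utils.py later in this patch apply the same narrowing
# automatically whenever a score_mod callback is supplied.
import os

os.environ["NVTE_FLASH_ATTN"] = "0"  # rule out FlashAttention
os.environ["NVTE_FUSED_ATTN"] = "1"  # keep cuDNN fused attention enabled
# ------------------------------------------------------------------------------------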
+ if score_mod is not None or score_mod_bprop is not None: + assert ( + fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + ), "score_mod and score_mod_bprop require the F16_arbitrary_seqlen fused backend." cp_size = 1 if isinstance(cp_group, dist_group_type): @@ -1934,6 +1966,9 @@ def forward( ) if context_parallel: + assert ( + score_mod is None and score_mod_bprop is None + ), "score_mod and score_mod_bprop are not supported with context parallelism." assert ( fp8 or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen @@ -2015,6 +2050,10 @@ def forward( fp8_output, self.layer_number, self.return_max_logit, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, ) if self.return_max_logit: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..98b61e6550 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -826,6 +826,10 @@ def forward( fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, pad_between_seqs: Optional[bool] = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, ) -> torch.Tensor: @@ -1372,6 +1376,8 @@ def forward( inference_params=inference_params, softmax_type=self.softmax_type, return_max_logit=self.return_max_logit, + has_score_mod=score_mod is not None, + has_score_mod_bprop=score_mod_bprop is not None, cuda_graph=is_graph_capturing(), num_splits=num_splits, ) @@ -1520,6 +1526,10 @@ def forward( pad_between_seqs=pad_between_seqs, inference_params=inference_params, softmax_offset=softmax_offset, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, fp8_output=fp8_output, ) return self.fused_attention( @@ -1551,6 +1561,10 @@ def forward( pad_between_seqs=pad_between_seqs, inference_params=inference_params, softmax_offset=softmax_offset, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, fp8_output=fp8_output, ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..cccb4c25ae 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -234,6 +234,10 @@ class AttentionParams: The type of softmax operation. See DotProductAttention for details. return_max_logit : bool, default = False Whether to output max_logit. + has_score_mod : bool, default = False + Whether a cuDNN flexible-graph score modifier callback is present. + has_score_mod_bprop : bool, default = False + Whether a cuDNN flexible-graph score modifier backward callback is present. cuda_graph : bool, default = `False` Whether support for cuda graph capture is needed or not. 
num_splits : int, default = 1 @@ -268,6 +272,8 @@ class AttentionParams: inference_params: Optional[InferenceParams] = None softmax_type: str = "vanilla" return_max_logit: bool = False + has_score_mod: bool = False + has_score_mod_bprop: bool = False cuda_graph: bool = False num_splits: int = 1 @@ -345,6 +351,8 @@ def get_attention_backend( inference_params = attention_params.inference_params softmax_type = attention_params.softmax_type return_max_logit = attention_params.return_max_logit + has_score_mod = attention_params.has_score_mod + has_score_mod_bprop = attention_params.has_score_mod_bprop cuda_graph = attention_params.cuda_graph num_splits = attention_params.num_splits @@ -534,6 +542,17 @@ def get_attention_backend( logger.debug("Disabling UnfusedDotProductAttention for num_splits") use_unfused_attention = False + if has_score_mod or has_score_mod_bprop: + if use_flash_attention: + logger.debug("Disabling FlashAttention for score_mod callbacks") + use_flash_attention = False + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for score_mod callbacks") + use_unfused_attention = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug("Disabling FusedAttention for score_mod callbacks in FP8") + use_fused_attention = False + # Filter: Return max_logit if return_max_logit: if use_flash_attention: @@ -1009,6 +1028,17 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None + if ( + use_fused_attention + and (has_score_mod or has_score_mod_bprop) + and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + ): + logger.debug( + "Disabling FusedAttention as score_mod callbacks require the F16_arbitrary_seqlen" + " sub-backend" + ) + use_fused_attention = False + fused_attention_backend = None if ( use_fused_attention and window_size is not None diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d95d327c78..3b6ff13c85 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -5,7 +5,7 @@ """Multi-head Attention.""" import os import collections -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformer_engine.pytorch.quantization import FP8GlobalStateManager @@ -647,6 +647,10 @@ def forward( rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, alibi_slopes: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, @@ -721,6 +725,10 @@ def forward( core_attention_bias: Optional[torch.Tensor], default = None Bias tensor for :math:`Q \cdot K^T`, shape ``[1, num_head, max_seqlen_q, max_seqlen_kv]``. It should be ``None`` for ``"no_bias"`` and ``"alibi"`` bias types. + score_mod: Optional[Callable], default = None + Optional cuDNN flexible-graph score modifier callback. 
+ score_mod_bprop: Optional[Callable], default = None + Optional cuDNN flexible-graph score modifier backward callback. alibi_slopes: Optional[torch.Tensor], default = None ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``. It adds a bias of ``(-alibi_slope * (i + seqlen_k - seqlen_q - j))`` @@ -1041,6 +1049,10 @@ def forward( checkpoint_core_attention=checkpoint_core_attention, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, alibi_slopes=alibi_slopes, fast_zero_fill=fast_zero_fill, inference_params=inference_params, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 2de4576e05..58176d035d 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -143,6 +143,8 @@ def fused_attn_fwd( softmax_offset: torch.Tensor = None, return_max_logit: bool = False, cuda_graph: bool = False, + score_mod=None, + score_mod_tensors=None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -223,6 +225,10 @@ def fused_attn_fwd( softmax_offset : torch.Tensor, default = None softmax offset tensor of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. + score_mod : Callable, default = None + Optional cuDNN flexible-graph score modifier callback. + score_mod_tensors : Dict[str, torch.Tensor], default = None + Extra tensors to expose to the score modifier callback via the variant pack. return_max_logit : bool, default = False whether to return the maximum attention score cuda_graph : bool, default = False @@ -289,6 +295,13 @@ def fused_attn_fwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) + if score_mod_tensors is not None: + assert score_mod is not None, "score_mod_tensors requires score_mod." + if score_mod is not None: + assert ( + fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + ), "score_mod is only supported by the cuDNN F16_arbitrary_seqlen backend." + # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: rng_elts_per_thread = ( @@ -345,6 +358,8 @@ def fused_attn_fwd( o_quantizer, attn_bias, softmax_offset, + score_mod, + score_mod_tensors, rng_gen, rng_elts_per_thread, return_max_logit, @@ -398,6 +413,10 @@ def fused_attn_bwd( bottom_right_diagonal: bool = None, deterministic: bool = False, cuda_graph: bool = False, + score_mod=None, + score_mod_bprop=None, + score_mod_tensors=None, + score_mod_bprop_tensors=None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -473,6 +492,14 @@ def fused_attn_bwd( bottom right (True) corner of the softmax matrix. deterministic : bool, default = False whether to execute the backward pass with deterministic behaviours. + score_mod : Callable, default = None + Optional cuDNN flexible-graph score modifier callback. + score_mod_bprop : Callable, default = None + Optional cuDNN flexible-graph score modifier backward callback. + score_mod_tensors : Dict[str, torch.Tensor], default = None + Extra tensors to expose to the score modifier callback via the variant pack. 
+ score_mod_bprop_tensors : Dict[str, torch.Tensor], default = None + Extra tensors to expose to the score modifier backward callback via the variant pack. cuda_graph : bool, default = False whether or not cuda graph capture is enabled. @@ -509,6 +536,16 @@ def fused_attn_bwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) + if score_mod_tensors is not None: + assert score_mod is not None, "score_mod_tensors requires score_mod." + if score_mod_bprop_tensors is not None: + assert score_mod_bprop is not None, "score_mod_bprop_tensors requires score_mod_bprop." + if score_mod is not None or score_mod_bprop is not None: + assert fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"], ( + "score_mod and score_mod_bprop are only supported by the cuDNN F16_arbitrary_seqlen" + " backend." + ) + if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: if len(aux_ctx_tensors) < 1: raise ValueError( @@ -568,6 +605,10 @@ def fused_attn_bwd( s_quantizer, dp_quantizer, dqkv_quantizer, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, cuda_graph, ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1c5116a8da..f9dfe464c7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -93,7 +93,8 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional SoftmaxOffset, const std::optional rng_gen, + const std::optional SoftmaxOffset, py::handle score_mod, + py::handle score_mod_tensors, const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); std::vector fused_attn_bwd( @@ -106,7 +107,9 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph); + py::handle dp_quantizer, py::handle dqkv_quantizer, py::handle score_mod, + py::handle score_mod_bprop, py::handle score_mod_tensors, py::handle score_mod_bprop_tensors, + bool cuda_graph); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bf62db8c33..0f73641964 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -107,7 +107,8 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional SoftmaxOffset, const std::optional rng_gen, + const std::optional SoftmaxOffset, py::handle score_mod, + py::handle score_mod_tensors, const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. 
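For orientation between the two binding hunks: the opaque py::handle arguments above receive ordinary Python objects from cpp_extensions/fused_attn.py, with None mapped to nullptr at the call sites below. A minimal sketch of what those objects look like on the Python side; the callback signature and the tensor-dict key are illustrative assumptions, not something this excerpt defines:

import torch

# Hypothetical extra tensor exposed to the score modifier through the variant pack;
# the key name "rel_bias" is invented for illustration.
score_mod_tensors = {"rel_bias": torch.zeros(1, 16, 128, 128, device="cuda")}

def my_score_mod(graph, score):
    # Assumed contract: receive the cuDNN graph and the pre-softmax score node,
    # and return the node to use in its place (returned unchanged here).
    return score

# Pairing rules enforced by the Python wrappers earlier in this patch:
#   * score_mod_tensors may only be passed together with score_mod (same for *_bprop),
#   * both are accepted only on the F16_arbitrary_seqlen fused backend.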
@@ -236,7 +237,9 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + score_mod.is_none() ? nullptr : score_mod.ptr(), + score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(), workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -296,7 +299,9 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + score_mod.is_none() ? nullptr : score_mod.ptr(), + score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(), workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -318,7 +323,9 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { + py::handle dp_quantizer, py::handle dqkv_quantizer, py::handle score_mod, + py::handle score_mod_bprop, py::handle score_mod_tensors, py::handle score_mod_bprop_tensors, + bool cuda_graph) { auto none = py::none(); // create QKV, O, dO tensor wrappers @@ -533,14 +540,18 @@ std::vector fused_attn_bwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd( - te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), + te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), + te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, score_mod.is_none() ? nullptr : score_mod.ptr(), + score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(), + score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(), + score_mod_bprop_tensors.is_none() ? 
nullptr : score_mod_bprop_tensors.ptr(), + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -550,14 +561,18 @@ std::vector fused_attn_bwd( // execute kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd( - te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), + te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), + te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, score_mod.is_none() ? nullptr : score_mod.ptr(), + score_mod_bprop.is_none() ? nullptr : score_mod_bprop.ptr(), + score_mod_tensors.is_none() ? nullptr : score_mod_tensors.ptr(), + score_mod_bprop_tensors.is_none() ? nullptr : score_mod_bprop_tensors.ptr(), + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers
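Taken together, the changes above thread the new arguments from the public PyTorch modules down to nvte_fused_attn_fwd/nvte_fused_attn_bwd. A usage sketch under stated assumptions: the callback contract and its identity body are placeholders rather than the definitive cuDNN flexible-graph API, and backend narrowing to cuDNN fused attention happens through the filters added in dot_product_attention/utils.py:

import torch
from transformer_engine.pytorch import DotProductAttention

b, s, h, d = 2, 128, 16, 64
q, k, v = (
    torch.randn(b, s, h, d, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
)

def noop_score_mod(graph, score):
    # Placeholder callback: assumed to receive the cuDNN graph and the pre-softmax
    # score node; returning it unchanged keeps the standard softmax(QK^T/sqrt(d))V.
    return score

dpa = DotProductAttention(h, d, attention_dropout=0.0, qkv_format="bshd").cuda()
out = dpa(
    q,
    k,
    v,
    attn_mask_type="no_mask",        # assumption: no built-in mask alongside score_mod
    score_mod=noop_score_mod,        # forward score modifier
    score_mod_bprop=noop_score_mod,  # backward score modifier
)
out.sum().backward()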