From c9daeba3759ea1273655f70f3cc6f1e9cd5a13d2 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 17 Mar 2026 16:05:09 -0700 Subject: [PATCH 1/6] add blackwell support filter for 9.7<=cudnn<9.18.1 Signed-off-by: Sudhakar Singh --- .../attention/dot_product_attention/utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..e6e75271f3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1090,6 +1090,19 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None + if ( + fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + and deterministic + and ( + cudnn_version >= (9, 7) + and cudnn_version < (9, 18, 1) + ) + and device_compute_capability >= (10, 0) + ): + logger.debug("Determinism not supported on Blackwell for BF16 with 9.7 <= cuDNN < 9.18.1") + use_fused_attention = False + fused_attention_backend = None + # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 From 9c30493fc7c315a9e6cbb53e29270830ae8a38ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 23:09:52 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 
e6e75271f3..329e803c20 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1093,17 +1093,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and deterministic - and ( - cudnn_version >= (9, 7) - and cudnn_version < (9, 18, 1) - ) + and (cudnn_version >= (9, 7) and cudnn_version < (9, 18, 1)) and device_compute_capability >= (10, 0) ): - logger.debug("Determinism not supported on Blackwell for BF16 with 9.7 <= cuDNN < 9.18.1") + logger.debug( + "Determinism not supported on Blackwell for BF16 with 9.7 <= cuDNN < 9.18.1" + ) use_fused_attention = False fused_attention_backend = None - # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 use_flash_attention_3 = use_flash_attention and use_flash_attention_3 From 37800b95cf1f83524e7699f1077748b3dd37fd98 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 17 Mar 2026 16:13:09 -0700 Subject: [PATCH 3/6] simplify conditionals Signed-off-by: Sudhakar Singh --- .../pytorch/attention/dot_product_attention/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index e6e75271f3..45068d990b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1093,10 +1093,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and deterministic - and ( - cudnn_version >= (9, 7) - and cudnn_version < (9, 18, 1) - ) + and cudnn_version >= (9, 7) + and cudnn_version < (9, 18, 1) and 
device_compute_capability >= (10, 0) ): logger.debug("Determinism not supported on Blackwell for BF16 with 9.7 <= cuDNN < 9.18.1") From 342c68ad78e60dfa96a39190799c99a430fd5e5f Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 17 Mar 2026 16:15:51 -0700 Subject: [PATCH 4/6] fix conditionals again Signed-off-by: Sudhakar Singh --- .../pytorch/attention/dot_product_attention/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 5a547f7a78..0bf69aefc6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1092,13 +1092,13 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - and deterministic + and is_training and cudnn_version >= (9, 7) and cudnn_version < (9, 18, 1) and device_compute_capability >= (10, 0) ): logger.debug( - "Determinism not supported on Blackwell for BF16 with 9.7 <= cuDNN < 9.18.1" + "Determinism not supported on Blackwell for FP16/BF16 with 9.7 <= cuDNN < 9.18.1" ) use_fused_attention = False fused_attention_backend = None From 73880c533bc544e19b7c80f245f64c940cecc320 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 17 Mar 2026 16:16:48 -0700 Subject: [PATCH 5/6] fix conditionals again Signed-off-by: Sudhakar Singh --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 0bf69aefc6..1db8aac1fb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1093,7 +1093,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and is_training - and cudnn_version >= (9, 7) + and cudnn_version >= (9, 7, 0) and cudnn_version < (9, 18, 1) and device_compute_capability >= (10, 0) ): From 9e4f358849dae969da32757c04a1df86180ba645 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 19 Mar 2026 11:56:59 -0700 Subject: [PATCH 6/6] update the error log Signed-off-by: Sudhakar Singh --- .../pytorch/attention/dot_product_attention/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1db8aac1fb..4548ff3932 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1098,7 +1098,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt and device_compute_capability >= (10, 0) ): logger.debug( - "Determinism not supported on Blackwell for FP16/BF16 with 9.7 <= cuDNN < 9.18.1" + "Disabling FusedAttention because determinism is not supported on Blackwell for " + "FP16/BF16 with 9.7 <= cuDNN < 9.18.1" ) use_fused_attention = False fused_attention_backend = None