diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 567fd17c34..4548ff3932 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -1090,6 +1090,19 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
         use_fused_attention = False
         fused_attention_backend = None
+    if (
+        fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
+        and is_training
+        and cudnn_version >= (9, 7, 0)
+        and cudnn_version < (9, 18, 1)
+        and device_compute_capability >= (10, 0)
+    ):
+        logger.debug(
+            "Disabling FusedAttention because determinism is not supported on Blackwell for "
+            "FP16/BF16 with 9.7 <= cuDNN < 9.18.1"
+        )
+        use_fused_attention = False
+        fused_attention_backend = None
 
     # use_flash_attention may have been set above
     use_flash_attention_2 = use_flash_attention and use_flash_attention_2