add blackwell support filter for 9.7<=cudnn<9.18.1 #2775
base: main

Changes from all commits: c9daeba, 9c30493, 37800b9, 25da675, 342c68a, 73880c5, 481d712, 9e4f358
```diff
@@ -1090,6 +1090,19 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
         use_fused_attention = False
         fused_attention_backend = None
+    if (
+        fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
+        and is_training
+        and cudnn_version >= (9, 7, 0)
+        and cudnn_version < (9, 18, 1)
+        and device_compute_capability >= (10, 0)
+    ):
+        logger.debug(
+            "Disabling FusedAttention because determinism is not supported on Blackwell for "
+            "FP16/BF16 with 9.7 <= cuDNN < 9.18.1"
+        )
+        use_fused_attention = False
+        fused_attention_backend = None
```
Comment on lines +1093 to +1105

Contributor:
Every other determinism filter in this same function is gated on `is_training`, but this one is not. If the cuDNN bug only manifests during training (backward pass), the filter is overly broad and will unnecessarily fall back to a slower backend during inference. If it truly affects the forward pass as well, a comment explaining that would help reviewers and future maintainers understand the deviation from the existing pattern. Consider either gating on `is_training`:

```python
if (
    fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    and is_training
    and (cudnn_version >= (9, 7) and cudnn_version < (9, 18, 1))
    and device_compute_capability >= (10, 0)
):
```

or, if inference is also affected, add a comment explaining why.
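To make the reviewer's point concrete, here is a minimal standalone sketch of the gating pattern, not TransformerEngine's actual backend-selection code. The helper name `should_disable_fused_attention` is hypothetical, and the sketch assumes, per the discussion above, that the cuDNN issue only manifests during training:

```python
# Hypothetical helper illustrating the review suggestion; not the actual
# TransformerEngine selection logic.
def should_disable_fused_attention(is_training, cudnn_version, device_compute_capability):
    """Return True when the Blackwell determinism workaround should apply.

    Gating on is_training keeps inference on the fused backend, on the
    assumption (per the review discussion) that the cuDNN bug only
    manifests in the backward pass.
    """
    return (
        is_training  # inference is unaffected under this assumption
        and (9, 7) <= cudnn_version < (9, 18, 1)  # affected cuDNN window
        and device_compute_capability >= (10, 0)  # Blackwell (SM 10.0+)
    )

assert should_disable_fused_attention(True, (9, 10, 2), (10, 0))
assert not should_disable_fused_attention(False, (9, 10, 2), (10, 0))  # inference keeps the fused path
```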
Collaborator (Author):

updated, check again
```diff
 
     # use_flash_attention may have been set above
     use_flash_attention_2 = use_flash_attention and use_flash_attention_2
```
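A side note on the version check itself: the filter relies on Python's element-wise tuple comparison, so the bounds behave as in this small sketch. It assumes, as the diff's message states, that `(9, 18, 1)` is the first fixed cuDNN release and that SM 10.0 corresponds to Blackwell:

```python
# Element-wise tuple comparison, as used by the version window in the diff.
# A shorter tuple that is a prefix of a longer one compares as smaller,
# so (9, 7) <= (9, 7, 0) holds.
assert (9, 7) <= (9, 7, 0) < (9, 18, 1)   # oldest affected release is in the window
assert (9, 18, 0) < (9, 18, 1)            # still inside the window
assert not ((9, 18, 1) < (9, 18, 1))      # first release outside the window
assert not ((9, 6, 2) >= (9, 7, 0))       # predates the affected window
assert (10, 0) >= (10, 0)                 # SM 10.0, i.e. Blackwell
```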