add blackwell support filter for 9.7<=cudnn<9.18.1 #2775
sudhakarsingh27 wants to merge 8 commits into NVIDIA:main from
Conversation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Greptile Summary

This PR adds a targeted determinism filter to `get_attention_backend()`: when deterministic execution is requested with the F16_arbitrary_seqlen backend on Blackwell GPUs (sm >= 10.0) and 9.7 <= cuDNN < 9.18.1, FusedAttention is disabled.

Key observations:
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["get_attention_backend()"] --> B{"use_fused_attention\nAND deterministic?"}
B -- No --> Z["Continue to flash-attn selection"]
B -- Yes --> C{"softmax_type\n≠ 'vanilla'?"}
C -- Yes --> D["Disable FusedAttention\n(non-deterministic softmax)"]
C -- No --> E{"backend == FP8\nAND is_training\nAND (sm < 9.0 OR cuDNN < 9.19.0)?"}
E -- Yes --> F["Disable FusedAttention\n(FP8 determinism)"]
E -- No --> G{"backend == F16_arbitrary_seqlen\nAND is_training\nAND (sm < 9.0 OR bias_grad OR cuDNN < 8.9.5)?"}
G -- Yes --> H["Disable FusedAttention\n(F16 legacy determinism)"]
G -- No --> I{"backend == F16_arbitrary_seqlen\nAND is_training\nAND 9.7.0 ≤ cuDNN < 9.18.1\nAND sm ≥ 10.0 (Blackwell)?"}
I -- Yes --> J["Disable FusedAttention\n🆕 Blackwell + cuDNN bug filter"]
I -- No --> Z
D --> Z
F --> Z
H --> Z
J --> Z
```

Last reviewed commit: "update the error log"
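The decision chain in the flowchart can be sketched as a standalone predicate. This is an illustrative sketch only: the function name `should_disable_fused_attention` and its parameters are taken from the flowchart labels, not from the actual Transformer Engine source.

```python
# Illustrative sketch of the determinism filter chain from the flowchart.
# All names mirror the flowchart labels, not the real TE code paths.

def should_disable_fused_attention(
    backend,            # "FP8" or "F16_arbitrary_seqlen"
    softmax_type,       # "vanilla" or a learned variant
    is_training,        # True when the backward pass will run
    cudnn_version,      # version tuple, e.g. (9, 17, 0)
    sm,                 # compute capability tuple, e.g. (10, 0)
    bias_grad=False,    # whether a bias gradient is required
    deterministic=True,
):
    if not deterministic:
        return False  # B -- No: continue to flash-attn selection
    if softmax_type != "vanilla":
        return True   # D: non-deterministic softmax
    if backend == "FP8" and is_training and (sm < (9, 0) or cudnn_version < (9, 19, 0)):
        return True   # F: FP8 determinism
    if backend == "F16_arbitrary_seqlen" and is_training:
        if sm < (9, 0) or bias_grad or cudnn_version < (8, 9, 5):
            return True  # H: F16 legacy determinism
        if (9, 7) <= cudnn_version < (9, 18, 1) and sm >= (10, 0):
            return True  # J: new Blackwell + cuDNN bug filter
    return False
```

Under these assumptions, deterministic F16 training on sm (10, 0) with cuDNN (9, 17, 0) is filtered out, while cuDNN (9, 18, 1) passes through.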
transformer_engine/pytorch/attention/dot_product_attention/utils.py
```python
if (
    fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    and deterministic
    and (cudnn_version >= (9, 7) and cudnn_version < (9, 18, 1))
    and device_compute_capability >= (10, 0)
):
    logger.debug(
        "Determinism not supported on Blackwell for BF16 with 9.7 <= cuDNN < 9.18.1"
    )
    use_fused_attention = False
    fused_attention_backend = None
```
Missing `is_training` guard — may incorrectly disable FusedAttention for inference
Every other determinism filter in this same `if use_fused_attention and deterministic:` block guards on `is_training` (see lines 1070–1080 for FP8 and 1081–1092 for F16_arbitrary_seqlen), conveying that those non-determinism issues are backward-pass–specific. The new Blackwell / cuDNN-version filter does not include `and is_training`, so it will also disable FusedAttention during deterministic inference on Blackwell GPUs with 9.7 ≤ cuDNN < 9.18.1.

If the cuDNN bug only manifests during training (the backward pass), the filter is overly broad and will unnecessarily fall back to a slower backend during inference. If it truly affects the forward pass as well, a comment explaining that would help reviewers and future maintainers understand the deviation from the existing pattern.
Consider either:

```python
if (
    fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
    and is_training
    and (cudnn_version >= (9, 7) and cudnn_version < (9, 18, 1))
    and device_compute_capability >= (10, 0)
):
```

or, if inference is also affected, add a comment explaining why `is_training` is deliberately omitted.
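A side note on the version test itself: Python compares tuples lexicographically, so the two-sided range check above can also be written as a single chained comparison. A quick sketch with an illustrative value:

```python
# Version tuples compare lexicographically, element by element.
cudnn_version = (9, 17, 2)  # illustrative value

# Equivalent to: cudnn_version >= (9, 7) and cudnn_version < (9, 18, 1)
in_buggy_range = (9, 7) <= cudnn_version < (9, 18, 1)

# A shorter tuple that is a prefix of a longer one compares as smaller,
# so (9, 7) <= (9, 7, 0) holds and cuDNN 9.7.0 falls inside the range,
# matching the "9.7 <= cuDNN" bound in the PR title.
prefix_included = (9, 7) <= (9, 7, 0)
```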
updated, check again
transformer_engine/pytorch/attention/dot_product_attention/utils.py
transformer_engine/pytorch/attention/dot_product_attention/utils.py
Outdated
Show resolved
Hide resolved
…p8_determinism_check
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: