
feat: FA3-FP8-extension #552

Open

Marius-Graml wants to merge 9 commits into main from feat/fa3-fp8-extension

Conversation


@Marius-Graml Marius-Graml commented Feb 24, 2026

Description

Add FP8 quantization to the flashattn3 algorithm as an optional boolean hyperparameter. Additionally, add optional target modules for FA3 as well as for FP8 quantization. The logic is as follows:

  1. Target modules + no FP8: FA3 is applied only to the target modules.
  2. Target modules + FP8: FP8 quantization is applied only to the target modules; FA3 is applied to all modules.
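The two rules above can be sketched as a small dispatch helper (illustrative names only, not pruna's actual API):

```python
# Hypothetical sketch of the per-module dispatch described above.
def plan_modes(module_names, target_set, use_fp8):
    """Return a {module: mode} map following the two rules."""
    modes = {}
    for name in module_names:
        targeted = name in target_set
        if use_fp8:
            # Rule 2: FA3 everywhere, FP8 only on the targeted modules
            modes[name] = "fa3+fp8" if targeted else "fa3"
        else:
            # Rule 1: FA3 only on the targeted modules
            modes[name] = "fa3" if targeted else "original"
    return modes
```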

The following speed-ups/results are for Wan-AI/Wan2.2-TI2V-5B-Diffusers.

Settings

  • num_frames = 241
  • height = 1280
  • width = 704
  • guidance_scale = 6.0
  • num_inference_steps = 30
  • seed = 20

1) Original

Inference speed: 483s

original.mp4

2) FA3-FP16

Results are identical to the original, as FA3 is lossless:

>= 0.35.0.dev0

Inference speed: 362s

fa3.mp4

< 0.35.0.dev0

Inference speed: 361s

fa3_wrapped.mp4

3) FA3-FP8

>= 0.35.0.dev0

Inference speed: 321s

fa3_fp8.mp4

< 0.35.0.dev0

Inference speed: 326s

fa3_fp8_wrapped.mp4

4) FA3-FP8 (first and last transformer block excluded from quantization)

>= 0.35.0.dev0

Inference speed: 324s

fa3_fp8_excluded.mp4

< 0.35.0.dev0

Inference speed: 329s

fa3_fp8_excluded_wrapped.mp4

5) FA3-FP8 (first and last two transformer blocks excluded from quantization)

>= 0.35.0.dev0

Inference speed: 327s

fa3_fp8_excluded2.mp4

< 0.35.0.dev0

Inference speed: 331s

fa3_fp8_excluded2_wrapped.mp4

Related Issue

/

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Manual test runs, see above

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

/

@Marius-Graml Marius-Graml changed the title Feat: FA3-FP8-extension feat: FA3-FP8-extension Feb 24, 2026

ParagEkbote commented Feb 24, 2026

I think it fixes #373, feel free to rectify if needed.

@Marius-Graml
Contributor Author

I think it fixes #373, feel free to rectify if needed.

Yes you're right. Thanks for pointing that out :)

@Marius-Graml Marius-Graml requested review from gsprochette and removed request for gsprochette and johannaSommer February 25, 2026 09:18
@johannaSommer johannaSommer left a comment

Really cool PR already, I just have a few comments :)

# as kernel expects per tensor quantization
descale_shape = (q.shape[0], k.shape[2]) # (B, H) as input format = (B, N, H, D)
# Quantize sequentially and delete originals to reduce peak memory (otherwise risk of OOM)
q_fp8, descale_q = _quantize_fp8(q, descale_shape)

did we check whether we can just assign the quantized tensor to q, k, v directly? Also not sure about this tbh

@Marius-Graml Marius-Graml Mar 9, 2026

Yes, that's a very good point. Changed.
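The per-tensor FP8 scaling discussed in this thread can be illustrated in plain Python (a simplified stand-in for the torch implementation: a single scalar scale instead of the per-(batch, head) descale tensor; 448.0 is the largest finite value of the float8 e4m3 format):

```python
# Illustrative sketch only -- the real code quantizes torch tensors to
# torch.float8_e4m3fn and keeps a descale tensor of shape (B, H).
FP8_E4M3_MAX = 448.0  # largest finite e4m3 value

def quantize_fp8(values):
    """Scale values so the max magnitude maps onto the FP8 range."""
    amax = max(abs(v) for v in values)
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
    quantized = [v * scale for v in values]  # cast to float8 in torch
    descale = 1.0 / scale  # the kernel multiplies by this to undo scaling
    return quantized, descale
```

Quantizing q, k, v sequentially and freeing each original, as in the quoted snippet, bounds the peak-memory overhead to roughly one extra tensor at a time.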

register_pruna_flash_attn_op(kernel, use_fp8=False)

# Build the version-specific apply function for non-FP8
use_new_backend = Version(diffusers_version) >= Version("0.35.0.dev0")

sorry to add this to your PR, it's a bit unrelated but now would be a great time. As we discussed, the older diffusers version is no longer needed, as we pin this. However, we sometimes use this functionality for non-diffusers-pipeline models. Could you please, if you have time, add another hyperparameter to the algorithm called "custom_model" and replace the diffusers version check with it, s.t. if this parameter is True, we use _apply_via_forward_wrap? Thank you!

@Marius-Graml Marius-Graml Mar 9, 2026

  • Replaced the diffusers version check with a custom_model hyperparameter. When custom_model=True, _apply_via_forward_wrap is used instead of the backend API.

  • For model_check_fn, I kept the existing DiffusionPipeline check (which validates components, set_attention_backend, and dtype) and added a generic check for custom models that accepts any nn.Module or any object containing at least one nn.Module attribute. A more specific check, e.g., verifying that the model actually calls scaled_dot_product_attention, is difficult without running a forward pass or parsing source code, which would probably create too much overhead.

  • Since the generic check for custom models cannot guarantee that the model contains attention modules, I considered adding a warning in _apply_via_forward_wrap if no sdpa call is intercepted. However, the wrapper is installed during smash() but only executed during inference, so you would only know after the first forward pass whether sdpa was actually intercepted, which makes an automatic check overly complex. Thus I would say, when using a custom model, it is the user's responsibility to ensure that the targeted submodules actually use scaled_dot_product_attention. If they don't, the algorithm runs silently without effect.

  • Note that without the nn.Module check, expand_dict_of_roots_and_subpaths would also fail downstream, but catching incompatible models early in model_check_fn is cleaner, I believe.

  • Also note that I removed the pinned fa3 kernel version, as the torch 2.10 version now works for pruna. In fact, when installing pruna, the pinned version would actually throw an error, as it does not work with torch 2.10. Since the flashattn3 interface on huggingface differs slightly between kernel versions in terms of return parameters (sometimes a tuple is returned, sometimes a single tensor), this also has to be handled (see lines 546-548).

  • Also added an ignore in the pyproject.toml. It seems that the ty checks cannot handle the torch.ops dispatches

If you have a better idea for the model_check_fn for custom models, feel free to let me know.
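The tuple-vs-single-return handling mentioned above might look like this minimal sketch (the helper name is illustrative, not pruna's actual code):

```python
# Some fa3 kernel versions return (out, lse), others just out;
# normalize to the attention output in either case.
def unpack_attn_output(result):
    return result[0] if isinstance(result, tuple) else result
```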

@johannaSommer

@BugBot run


cursor bot commented Mar 5, 2026

Bugbot couldn't run

Bugbot is not enabled for your user on this team.

Ask your team administrator to increase your team's hard limit for Bugbot seats or add you to the allowlist in the Cursor dashboard.

@johannaSommer

@BugBot run

@cursor cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix prepared a fix for the issue found in the latest run.

  • ✅ Fixed: Unused kernel parameter threaded through new code paths
    • Removed the unused kernel argument from the forward-wrap path, context object, and _flash_attention3 signature so only effective parameters are propagated.

Preview (9762604c9f)
diff --git a/src/pruna/algorithms/flash_attn3.py b/src/pruna/algorithms/flash_attn3.py
--- a/src/pruna/algorithms/flash_attn3.py
+++ b/src/pruna/algorithms/flash_attn3.py
@@ -118,7 +118,7 @@
             backend_name = register_custom_backend(imported_packages, use_fp8=False)
             apply_fn = functools.partial(_apply_via_backend, backend=backend_name)
         else:
-            apply_fn = functools.partial(_apply_via_forward_wrap, kernel=kernel, use_fp8=False)
+            apply_fn = functools.partial(_apply_via_forward_wrap, use_fp8=False)
 
         if use_fp8:
             # FA3 fp16 on ALL compatible modules
@@ -130,7 +130,7 @@
                 backend_name_fp8 = register_custom_backend(imported_packages, use_fp8=True)
                 apply_fn_fp8 = functools.partial(_apply_via_backend, backend=backend_name_fp8)
             else:
-                apply_fn_fp8 = functools.partial(_apply_via_forward_wrap, kernel=kernel, use_fp8=True)
+                apply_fn_fp8 = functools.partial(_apply_via_forward_wrap, use_fp8=True)
             model = map_targeted_nn_roots(apply_fn_fp8, model, target_modules)
         else:
             # FA3 fp16 only on targeted modules
@@ -290,15 +290,12 @@
 
     Parameters
     ----------
-    kernel : Any
-        The kernel to use for the flash attention 3.
     use_fp8 : bool
         Whether to quantize Q, K, V to FP8 before the attention computation.
     """
 
-    def __init__(self, kernel: Any, use_fp8: bool = False):
+    def __init__(self, use_fp8: bool = False):
         super().__init__()
-        self.kernel = kernel
         self.use_fp8 = use_fp8
 
     def __torch_function__(self, func, types, args=(), kwargs=None):  # noqa: D105
@@ -332,14 +329,14 @@
                 kwargs.pop("dropout_p", None)
                 kwargs.pop("enable_gqa", None)
                 kwargs["softmax_scale"] = kwargs.pop("scale", None)
-                return _flash_attention3(*args, **kwargs, kernel=self.kernel, use_fp8=self.use_fp8)
+                return _flash_attention3(*args, **kwargs, use_fp8=self.use_fp8)
             else:
                 return func(*args, **kwargs)
         else:
             return func(*args, **kwargs)
 
 
-def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, kernel=None, use_fp8=False):
+def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, use_fp8=False):
     # convert (B, H, S, D) → (B, S, H, D)
     q, k, v = [x.transpose(1, 2).contiguous() for x in (query, key, value)]
     _ops = torch.ops.flash_attn_pruna
@@ -392,7 +389,6 @@
     root_name: str | None,
     root_nn_module: torch.nn.Module,
     relative_target_paths: List[str],
-    kernel: Any,
     use_fp8: bool,
 ) -> torch.nn.Module:
     """
@@ -410,8 +406,6 @@
         The root nn.Module.
     relative_target_paths : List[str]
         Relative paths to targeted submodules within the root.
-    kernel : Any
-        The flash attention 3 kernel module.
     use_fp8 : bool
         Whether to quantize Q, K, V to FP8 before the attention computation.
 
@@ -435,8 +429,8 @@
             original_forward = original_forward.__wrapped__
 
         @functools.wraps(original_forward)
-        def new_forward(*args, _orig=original_forward, _kernel=kernel, _fp8=use_fp8, **kwargs):
-            with FlashAttention3Context(kernel=_kernel, use_fp8=_fp8):
+        def new_forward(*args, _orig=original_forward, _fp8=use_fp8, **kwargs):
+            with FlashAttention3Context(use_fp8=_fp8):
                 return _orig(*args, **kwargs)
 
         sub_module.forward = new_forward

out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale)
_ops = torch.ops.flash_attn_pruna
op_fn = _ops._flash_attn_forward_fp8 if use_fp8 else _ops._flash_attn_forward
out, _ = op_fn(q, k, v, causal=is_causal, softmax_scale=softmax_scale)

Unused kernel parameter threaded through new code paths

Low Severity

The kernel parameter in _flash_attention3 is accepted but never read in the function body — the function uses torch.ops.flash_attn_pruna directly instead. The new _apply_via_forward_wrap function propagates kernel through to FlashAttention3Context and then to _flash_attention3, creating a misleading chain of unused plumbing that suggests the kernel object is required for the computation when it is not.


…heir quantized versions to their fp16 ones. Instead of checking for diffusers pipeline version, check for diffusers pipeline (any version) and custom model (additional boolean hyperparameter). For diffusers pipelines, the code set_attention_backend method is applied. For custom models, the wrapper is applied. Note, that for the wrapper no straightforward, bulletproof model check is possible and therefore the algorithm might run silently without actually affecting the model.
@Marius-Graml Marius-Graml force-pushed the feat/fa3-fp8-extension branch from 11f9164 to f7055a1 on March 9, 2026 at 14:48
…orch version 2.10, so no need to pin here. Different fa3 kernels from huggingface return either a tuple or a single parameter, which needs to be handled as well.
@gsprochette gsprochette left a comment

Great PR, can't wait to have this merged! Thanks a lot and sorry for the review delay. I left some comments, 3 topics seem important:

  • with this PR FA3 no longer supports diffusers<=0.34 because of the broken import, even though I'm 90% sure the logic for wrapping the forward would still support it. Can you protect the import and test on diffusers==0.34?
  • I left a lot of context about target modules. Everything you did target modules related works, I just pointed out some patterns that already exist at other places in the code. Maybe this is linked with the introduction of a "custom_model" hyperparameter that seems superfluous?
  • please please please do not add the global ignore in the pyproject.toml; it disables typing for the whole project.

possibly-missing-import = "ignore"
possibly-missing-attribute = "ignore"
missing-argument = "ignore"
invalid-argument-type = "ignore" # torch.ops dynamic dispatch is not understood by ty

I understand that torch.ops creates typing errors, but please use local ignore comments in the places that are using the torch.ops in question. Invalid argument type corresponds to the following type of errors:

def func(x: int) -> None: ...
func("foo")  # called with argument of incorrect type

By adding the ignore here in the pyproject.toml, you're allowing such type error anywhere in pruna, making the use of typing essentially obsolete. We managed to remove this ignore only 2 weeks ago...
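A targeted suppression, as the reviewer suggests, could look like the following (a runnable toy stand-in for the `torch.ops` call sites; the inline syntax `# ty: ignore[rule]` is per ty's documentation):

```python
def scale_tokens(n: int) -> int:
    return 2 * n

# ty would flag this call site with invalid-argument-type; a local
# suppression keeps the rest of the project fully type-checked, unlike
# a project-wide ignore in pyproject.toml.
result = scale_tokens("3")  # ty: ignore[invalid-argument-type]
```

(Here Python's `2 * "3"` simply repeats the string, so the call runs despite the type mismatch; the point is only where the ignore comment lives.)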

if Version(diffusers_version) >= Version("0.35.0.dev0"):
if not isinstance(model, DiffusionPipeline) or not hasattr(model, "components"):
# Standard diffusers pipeline path
if isinstance(model, DiffusionPipeline):

There used to be a version check to apply FA3 to model.transformer by default for diffusers versions before 0.35.0.dev0, because components was not always there. Are DiffusionPipelines just not compatible with FA3 anymore if the version is 0.34 or below? Or would it detect that the transformer is an nn.Module within the pipeline and return True from the last line?

_check_shape,
_native_attention,
)
from diffusers.models.attention_dispatch import (

this import breaks for diffusers==0.34 and below

Boolean(
"custom_model",
default=False,
meta=dict(desc="Apply FlashAttention3 to a custom (=non-diffusers pipeline) model."),

Why does this need to be a hyperparameter? It seems you could determine this from the model passed as input in _apply with the conditions of the first if block in the model_check_fn.

return [
# We do not set specific default target modules as FA3 is lossless if not used with FP8 quantization
# and therefore can be applied to any attn module without any performance degradation.
TargetModules(name="target_modules", default_value={"include": ["*"], "exclude": []}),

I'll let you decide whether this is something you want to use or not: TargetModules accepts a None value, meaning that it has not been set by the user. The pattern in other algorithms is that when TargetModules is None, you set it based on the model architecture in get_model_dependent_hyperparameter_defaults.

A simple and important example of this is in quanto, where the function target_backbone automatically targets the unet or transformer attributes of a DiffusionPipeline. The flow is therefore:

  • if the user provides a target_modules, use it without question
  • otherwise (None), deduce what to target based on the model, looking for unets or transformers.
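The flow the reviewer describes could be sketched as follows (hypothetical helper; `FakePipeline` stands in for a DiffusionPipeline, and the dict shape mirrors the `{"include": [...], "exclude": []}` default above):

```python
class FakePipeline:
    transformer = object()  # stand-in for a pipeline's backbone attribute

def resolve_target_modules(model, target_modules):
    if target_modules is not None:
        return target_modules  # user-provided: use without question
    # None: deduce the target from the model, looking for a backbone
    for attr in ("unet", "transformer"):
        if hasattr(model, attr):
            return {"include": [attr], "exclude": []}
    return {"include": ["*"], "exclude": []}  # fallback: everything
```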

The (modified) root module.
"""
if root_nn_module.dtype not in (torch.bfloat16, torch.float16):
return root_nn_module

I could see a debug or info log here. Also know that the target modules allow you to filter the modules that are listed using filter_targeted_modules, so you can do target_modules_16 = filter_targeted_modules(lambda module: module.dtype in (torch.bfloat16, torch.float16), model, target_modules) before you map _apply_via_backend. Then you can be sure that root_nn_module.get_submodule(rel_path) always has the correct dtype (although you're not actually checking the root's dtype).

Comment on lines +452 to +457
pruna_logger.warning(
"FlashAttention3 forward wrap was not applied to module '%s' "
"because its dtype is not bfloat16 or float16.",
root_name,
)
return root_nn_module

Suggested change
pruna_logger.warning(
"FlashAttention3 forward wrap was not applied to module '%s' "
"because its dtype is not bfloat16 or float16.",
root_name,
)
return root_nn_module
pruna_logger.warning(
f"FlashAttention3 forward wrap skipped {root_name}: expected dtype bfloat16 or float16, got {module_dtype}"
)
return root_nn_module

# If already wrapped by a previous pass, unwrap to the true original
# to avoid nested TorchFunctionMode contexts (inner would always win).
while hasattr(original_forward, "__wrapped__"):
original_forward = original_forward.__wrapped__

Is there any risk to this, like undoing some caching monkey patch, for example?

  • There should probably be a debug log if we're removing a wrap and not applying it again.
  • The other option is to make the sub_module.forward = new_forward wrap more complex by replacing sub_module.forward.__wrapped__ instead, to preserve the outer wrap but still have the FA3 wrap applied as the innermost one.
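The risk in the first bullet can be illustrated with nothing but `functools.wraps` semantics (pure-Python sketch; the wrapper names are invented for the demo):

```python
import functools

def add_wrap(fn, tag):
    @functools.wraps(fn)  # functools.wraps sets wrapper.__wrapped__ = fn
    def wrapper(*args, **kwargs):
        return f"{tag}({fn(*args, **kwargs)})"
    return wrapper

def base_forward():
    return "base"

# Simulate a forward that was monkey-patched twice, e.g. by a caching
# patch first and some other wrapper second.
forward = add_wrap(add_wrap(base_forward, "cache"), "other")

# The unwrap loop quoted above walks the __wrapped__ chain back to the
# true original -- discarding *every* intermediate wrap, including ones
# installed by other patches (here, "cache"):
original = forward
while hasattr(original, "__wrapped__"):
    original = original.__wrapped__
```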
