Conversation
I think it fixes #373, feel free to rectify if needed.

Yes, you're right. Thanks for pointing that out :)
johannaSommer
left a comment
Really cool PR already, I just have a few comments :)
src/pruna/algorithms/flash_attn3.py
Outdated
```python
# as kernel expects per tensor quantization
descale_shape = (q.shape[0], k.shape[2])  # (B, H) as input format = (B, N, H, D)
# Quantize sequentially and delete originals to reduce peak memory (otherwise risk of OOM)
q_fp8, descale_q = _quantize_fp8(q, descale_shape)
```
did we check whether we can just assign the quantized tensor to q, k, v directly? Also not sure about this tbh
Yes, that's a very good point. Changed.
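As a toy illustration of the rebinding pattern agreed on above, here is a minimal pure-Python sketch (no torch): per-tensor scaling to the FP8 E4M3 range, with the quantized result assigned back to the same name so the fp16 original becomes collectible. `_quantize_fp8` and the constant here are illustrative stand-ins, not pruna's actual implementation.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3

def _quantize_fp8(values):
    # per-tensor quantization: a single scale for the whole tensor
    amax = max(abs(v) for v in values) or 1.0
    scale = amax / FP8_E4M3_MAX
    # the descale factor is returned so attention can undo the scaling later
    return [v / scale for v in values], scale

q = [0.5, -448.0, 2.0]
# rebind q immediately: the original list is dropped, keeping peak memory low
q, descale_q = _quantize_fp8(q)
```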
src/pruna/algorithms/flash_attn3.py
Outdated
```python
register_pruna_flash_attn_op(kernel, use_fp8=False)

# Build the version-specific apply function for non-FP8
use_new_backend = Version(diffusers_version) >= Version("0.35.0.dev0")
```
sorry to add this to your PR, it's a bit unrelated but now would be a great time. As we discussed, the older diffusers version is no longer needed, as we pin this. However, we sometimes use this functionality for non-diffusers-pipeline models. Could you please, if you have time, add another hyperparameter to the algorithm called `custom_model` and replace the diffusers version check with it, s.t. if this parameter is True, we use `_apply_via_forward_wrap`? Thank you!
- Replaced the diffusers version check with a `custom_model` hyperparameter. When `custom_model=True`, `_apply_via_forward_wrap` is used instead of the backend API.
- For `model_check_fn`, I kept the existing `DiffusionPipeline` check (which validates components, `set_attention_backend`, and dtype) and added a generic check for custom models that accepts any `nn.Module` or any object containing at least one `nn.Module` attribute. A more specific check, e.g. verifying that the model actually calls `scaled_dot_product_attention`, is difficult without running a forward pass or parsing source code, which would probably create too much overhead.
- Since the generic check for custom models cannot guarantee that the model contains attention modules, I considered adding a warning in `_apply_via_forward_wrap` if no sdpa call is intercepted. However, the wrapper is installed during `smash()` but only executed during inference, so you would only know after the first forward pass whether sdpa was actually intercepted, which makes an automatic check overly complex. I would therefore say that when using a custom model, it is the user's responsibility to ensure that the targeted submodules actually use `scaled_dot_product_attention`. If they don't, the algorithm runs silently without effect.
- Note that without the `nn.Module` check, `expand_dict_of_roots_and_subpaths` would also fail downstream, but catching incompatible models early in `model_check_fn` is cleaner, I believe.
- Also note that I removed the pinned FA3 kernel version, since the torch 2.10 build now works with pruna. In fact, when installing pruna, the pinned version would actually throw an error as it does not work with torch 2.10. Because the Hugging Face FlashAttention-3 interface differs slightly between kernel versions in its return values (sometimes a tuple is returned, sometimes a single tensor), this also has to be handled (see lines 546-548).
- Also added an ignore in the `pyproject.toml`; it seems that the ty checks cannot handle the `torch.ops` dispatches.

If you have a better idea for the `model_check_fn` for custom models, feel free to let me know.
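The tuple-vs-tensor return difference mentioned above could be normalized with a small helper along these lines (a sketch; the function name is illustrative, not the PR's actual code):

```python
def normalize_fa3_output(result):
    # some kernel builds return (out, lse), others return out alone
    if isinstance(result, tuple):
        return result[0]
    return result
```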
@BugBot run

Bugbot couldn't run. Bugbot is not enabled for your user on this team. Ask your team administrator to increase your team's hard limit for Bugbot seats or add you to the allowlist in the Cursor dashboard.

@BugBot run
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix prepared a fix for the issue found in the latest run.
- ✅ Fixed: Unused `kernel` parameter threaded through new code paths. Removed the unused `kernel` argument from the forward-wrap path, context object, and `_flash_attention3` signature so only effective parameters are propagated.
Or push these changes by commenting:
@cursor push 9762604c9f
Preview (9762604c9f)
diff --git a/src/pruna/algorithms/flash_attn3.py b/src/pruna/algorithms/flash_attn3.py
--- a/src/pruna/algorithms/flash_attn3.py
+++ b/src/pruna/algorithms/flash_attn3.py
@@ -118,7 +118,7 @@
backend_name = register_custom_backend(imported_packages, use_fp8=False)
apply_fn = functools.partial(_apply_via_backend, backend=backend_name)
else:
- apply_fn = functools.partial(_apply_via_forward_wrap, kernel=kernel, use_fp8=False)
+ apply_fn = functools.partial(_apply_via_forward_wrap, use_fp8=False)
if use_fp8:
# FA3 fp16 on ALL compatible modules
@@ -130,7 +130,7 @@
backend_name_fp8 = register_custom_backend(imported_packages, use_fp8=True)
apply_fn_fp8 = functools.partial(_apply_via_backend, backend=backend_name_fp8)
else:
- apply_fn_fp8 = functools.partial(_apply_via_forward_wrap, kernel=kernel, use_fp8=True)
+ apply_fn_fp8 = functools.partial(_apply_via_forward_wrap, use_fp8=True)
model = map_targeted_nn_roots(apply_fn_fp8, model, target_modules)
else:
# FA3 fp16 only on targeted modules
@@ -290,15 +290,12 @@
Parameters
----------
- kernel : Any
- The kernel to use for the flash attention 3.
use_fp8 : bool
Whether to quantize Q, K, V to FP8 before the attention computation.
"""
- def __init__(self, kernel: Any, use_fp8: bool = False):
+ def __init__(self, use_fp8: bool = False):
super().__init__()
- self.kernel = kernel
self.use_fp8 = use_fp8
def __torch_function__(self, func, types, args=(), kwargs=None): # noqa: D105
@@ -332,14 +329,14 @@
kwargs.pop("dropout_p", None)
kwargs.pop("enable_gqa", None)
kwargs["softmax_scale"] = kwargs.pop("scale", None)
- return _flash_attention3(*args, **kwargs, kernel=self.kernel, use_fp8=self.use_fp8)
+ return _flash_attention3(*args, **kwargs, use_fp8=self.use_fp8)
else:
return func(*args, **kwargs)
else:
return func(*args, **kwargs)
-def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, kernel=None, use_fp8=False):
+def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, use_fp8=False):
# convert (B, H, S, D) → (B, S, H, D)
q, k, v = [x.transpose(1, 2).contiguous() for x in (query, key, value)]
_ops = torch.ops.flash_attn_pruna
@@ -392,7 +389,6 @@
root_name: str | None,
root_nn_module: torch.nn.Module,
relative_target_paths: List[str],
- kernel: Any,
use_fp8: bool,
) -> torch.nn.Module:
"""
@@ -410,8 +406,6 @@
The root nn.Module.
relative_target_paths : List[str]
Relative paths to targeted submodules within the root.
- kernel : Any
- The flash attention 3 kernel module.
use_fp8 : bool
Whether to quantize Q, K, V to FP8 before the attention computation.
@@ -435,8 +429,8 @@
original_forward = original_forward.__wrapped__
@functools.wraps(original_forward)
- def new_forward(*args, _orig=original_forward, _kernel=kernel, _fp8=use_fp8, **kwargs):
- with FlashAttention3Context(kernel=_kernel, use_fp8=_fp8):
+ def new_forward(*args, _orig=original_forward, _fp8=use_fp8, **kwargs):
+ with FlashAttention3Context(use_fp8=_fp8):
return _orig(*args, **kwargs)
sub_module.forward = new_forward

Comment `@cursor review` or `bugbot run` to trigger another review on this PR
src/pruna/algorithms/flash_attn3.py
Outdated
```python
out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale)
_ops = torch.ops.flash_attn_pruna
op_fn = _ops._flash_attn_forward_fp8 if use_fp8 else _ops._flash_attn_forward
out, _ = op_fn(q, k, v, causal=is_causal, softmax_scale=softmax_scale)
```
Unused kernel parameter threaded through new code paths
Low Severity
The kernel parameter in _flash_attention3 is accepted but never read in the function body — the function uses torch.ops.flash_attn_pruna directly instead. The new _apply_via_forward_wrap function propagates kernel through to FlashAttention3Context and then to _flash_attention3, creating a misleading chain of unused plumbing that suggests the kernel object is required for the computation when it is not.
Additional Locations (1)
…heir quantized versions to their fp16 ones. Instead of checking for diffusers pipeline version, check for diffusers pipeline (any version) and custom model (additional boolean hyperparameter). For diffusers pipelines, the `set_attention_backend` method is applied. For custom models, the wrapper is applied. Note that for the wrapper no straightforward, bulletproof model check is possible and therefore the algorithm might run silently without actually affecting the model.
11f9164 to f7055a1
…orch version 2.10 so no need to pin here. Different fa3 kernels from huggingface return either a tuple or a single parameter, which needs to be handled as well.
…ch.ops dispatches
gsprochette
left a comment
Great PR, can't wait to have this merged! Thanks a lot and sorry for the review delay. I left some comments, 3 topics seem important:
- with this PR FA3 no longer supports diffusers<=0.34 because of the broken import, even though I'm 90% sure the logic for wrapping the forward would still support it. Can you protect the import and test on diffusers==0.34?
- I left a lot of context about target modules. Everything you did target modules related works, I just pointed out some patterns that already exist at other places in the code. Maybe this is linked with the introduction of a "custom_model" hyperparameter that seems superfluous?
- please please please do not add the global ignore in the pyproject.toml it's disabling typing for the whole project.
```toml
possibly-missing-import = "ignore"
possibly-missing-attribute = "ignore"
missing-argument = "ignore"
invalid-argument-type = "ignore"  # torch.ops dynamic dispatch is not understood by ty
```
I understand that torch.ops creates typing errors, but please use local ignore comments in the places that are using the torch.ops in question. Invalid argument type corresponds to the following type of errors:
```python
def func(x: int) -> None: ...
func("foo")  # called with argument of incorrect type
```
By adding the ignore here in the pyproject.toml, you're allowing such type error anywhere in pruna, making the use of typing essentially obsolete. We managed to remove this ignore only 2 weeks ago...
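A local suppression along the lines the reviewer suggests might look like this, using ty's `# ty: ignore[...]` comment syntax on the reviewer's toy example (the rule name is assumed from the pyproject excerpt above):

```python
def func(x: int) -> None: ...

# silence the checker only at this call site instead of project-wide
func("foo")  # ty: ignore[invalid-argument-type]
```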
```
if Version(diffusers_version) >= Version("0.35.0.dev0"):
if not isinstance(model, DiffusionPipeline) or not hasattr(model, "components"):
# Standard diffusers pipeline path
if isinstance(model, DiffusionPipeline):
```
There used to be a version check to apply FA3 to `model.transformer` by default for diffusers versions before 0.35.0.dev0, because `components` was not always there. Are DiffusionPipelines just not compatible with FA3 anymore if the version is 0.34 or below? Or would it detect that the transformer is an nn.Module within the pipeline and return True from the last line?
```
    _check_shape,
    _native_attention,
)
from diffusers.models.attention_dispatch import (
```
this import breaks for diffusers==0.34 and below
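One way to protect that import, sketched here with an illustrative feature flag (the module path is taken from the excerpt above; the flag name is hypothetical):

```python
try:
    # available in diffusers >= 0.35; raises ImportError on 0.34 and below
    from diffusers.models.attention_dispatch import _check_shape, _native_attention
    HAS_ATTENTION_DISPATCH = True
except ImportError:
    HAS_ATTENTION_DISPATCH = False  # fall back to the forward-wrap path
```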
```python
Boolean(
    "custom_model",
    default=False,
    meta=dict(desc="Apply FlashAttention3 to a custom (=non-diffusers pipeline) model."),
```
Why does this need to be a hyperparameter? It seems you could determine this from the model passed as input in _apply with the conditions of the first if block in the model_check_fn.
```python
return [
    # We do not set specific default target modules as FA3 is lossless if not used with FP8 quantization
    # and therefore can be applied to any attn module without any performance degradation.
    TargetModules(name="target_modules", default_value={"include": ["*"], "exclude": []}),
```
I'll let you decide whether this is something you want to use or not: TargetModules accepts a None value meaning that this has not been set by the user. The pattern in other algorithm is that when TargetModules is None, you set it based on the model architecture in get_model_dependent_hyperparameter_defaults.
A simple and important example of this is in quanto, where the function `target_backbone` automatically targets the `unet` or `transformer` attributes of a DiffusionPipeline. The flow is therefore:
- if the user provides a target_modules, use it without question
- otherwise (None), deduce what to target based on the model, looking for unets or transformers.
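That flow could be sketched roughly as follows, with plain stand-ins for pruna's `TargetModules` machinery (the names and dict structure here are illustrative, not the actual API):

```python
def resolve_target_modules(user_value, model_attr_names):
    # if the user provided target_modules, use it without question
    if user_value is not None:
        return user_value
    # otherwise (None), deduce what to target based on the model
    for name in ("transformer", "unet"):
        if name in model_attr_names:
            return {"include": [f"{name}*"], "exclude": []}
    # no known backbone found: target everything
    return {"include": ["*"], "exclude": []}
```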
```python
    The (modified) root module.
    """
    if root_nn_module.dtype not in (torch.bfloat16, torch.float16):
        return root_nn_module
```
I could see a debug or info log here. Also note that target modules allow you to filter the modules that are listed using `filter_targeted_modules`, so you can do `target_modules_16 = filter_targeted_modules(lambda module: module.dtype in (torch.bfloat16, torch.float16), model, target_modules)` before you map `_apply_via_backend`. Then you can be sure that `root_nn_module.get_submodule(rel_path)` always has the correct dtype (although you're not actually checking for the root's dtype).
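The pre-filtering idea could be sketched with a plain predicate in place of pruna's helper (the helper name and the dict-based module representation below are illustrative):

```python
def filter_targeted_modules(predicate, modules):
    # keep only the targeted modules that satisfy the predicate
    return [m for m in modules if predicate(m)]

# e.g. keep only half-precision modules before mapping the apply function
half_precision = filter_targeted_modules(
    lambda m: m["dtype"] in ("bfloat16", "float16"),
    [{"name": "attn1", "dtype": "bfloat16"}, {"name": "proj_out", "dtype": "float32"}],
)
```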
```python
pruna_logger.warning(
    "FlashAttention3 forward wrap was not applied to module '%s' "
    "because its dtype is not bfloat16 or float16.",
    root_name,
)
return root_nn_module
```

Suggested change:

```python
pruna_logger.warning(
    f"FlashAttention3 forward wrap skipped {root_name}: expected dtype bfloat16 or float16, got {module_dtype}"
)
return root_nn_module
```
```python
# If already wrapped by a previous pass, unwrap to the true original
# to avoid nested TorchFunctionMode contexts (inner would always win).
while hasattr(original_forward, "__wrapped__"):
    original_forward = original_forward.__wrapped__
```
Is there any risk to this, like undoing some caching monkey patch for example?
- There should probably be a debug log if we're removing a wrap and not applying it again.
- The other option is to make the `sub_module.forward = new_forward` wrap more complex by replacing `sub_module.forward.__wrapped__` instead, to preserve the existing wrap but still have the FA3 wrap applied as the innermost one.
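The first option (logging each wrapper discarded during unwrap) could look like this minimal sketch; `unwrap_with_log` and the demo functions are hypothetical, not pruna code:

```python
import functools

def unwrap_with_log(forward, log=print):
    # unwrap to the true original, but report every wrapper we discard so a
    # silently removed monkey patch (e.g. caching) is at least visible
    while hasattr(forward, "__wrapped__"):
        log(f"FA3: removing existing forward wrapper {getattr(forward, '__qualname__', forward)!r}")
        forward = forward.__wrapped__
    return forward

def base_forward():
    return "out"

@functools.wraps(base_forward)  # sets cached_forward.__wrapped__ = base_forward
def cached_forward(*args, **kwargs):
    return base_forward(*args, **kwargs)
```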



Description
Add FP8 quantization to the flashattn3 algorithm as an optional boolean hyperparameter. Further, add optional target modules for FA3 as well as FP8 quantization.
The following speedups / results are for Wan-AI/Wan2.2-TI2V-5B-Diffusers.
Settings
1) Original
Inference speed: 483s
original.mp4
2) FA3-FP16
Results identical to original as FA3 is lossless:
>= 0.35.0.dev0
Inference speed: 362s
fa3.mp4
< 0.35.0.dev0
Inference speed: 361s
fa3_wrapped.mp4
3) FA3-FP8
>= 0.35.0.dev0
Inference speed: 321s
fa3_fp8.mp4
< 0.35.0.dev0
Inference speed: 326s
fa3_fp8_wrapped.mp4
4) FA3-FP8 (first and last transformer block excluded from quantization)
>= 0.35.0.dev0
Inference speed: 324s
fa3_fp8_excluded.mp4
< 0.35.0.dev0
Inference speed: 329s
fa3_fp8_excluded_wrapped.mp4
5) FA3-FP8 (first and last two transformer blocks excluded from quantization)
>= 0.35.0.dev0
Inference speed: 327s
fa3_fp8_excluded2.mp4
< 0.35.0.dev0
Inference speed: 331s
fa3_fp8_excluded2_wrapped.mp4
Related Issue
/
Type of Change
How Has This Been Tested?
Manual test runs, see above
Checklist
Additional Notes
/