feat: token merging for image classification #537
rensortino wants to merge 9 commits into PrunaAI:main
Conversation
llcnt
left a comment
Thanks a lot for the nice contribution :)
Could you provide a working example of this new integration please (e.g. a script or a notebook)? I tried to run it with a ViTForImageClassification but it fails (see comment below). I also tried to run it with a pipeline from transformers: the smashing works, but inference fails. The base pipeline can accept str (URLs of images) or raw images, but the smashed pipeline cannot. It would be nice to fix this so that the base model and the smashed one behave similarly. I tried to follow what you did in the new test by preprocessing the image before feeding it into the smashed pipeline, but I still get the error: TypeError: ViTAttention.forward() got an unexpected keyword argument 'output_attentions'.
Could you also fix the cursor[bot] comments (some of them are quite relevant ;) )?
Thx in advance!
This PR has been inactive for 10 days and is now marked as stale.
Hi, thanks a lot for the feedback! I am working on the issues raised by Bugbot and will shortly provide a notebook with a basic example of how to test the algorithm.
Force-pushed from a6a79ec to be3a6ed
Here you can find a notebook to test the algorithm on HF models and pipelines.
This PR has been inactive for 10 days and is now marked as stale.
bugbot run
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Comment @cursor review or bugbot run to trigger another review on this PR
```python
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Initialise ToMe state and forward through the wrapped model."""
    self._tome_info["r"] = self.parsed_r
```
Forward pass mutates parsed_r, breaking subsequent inferences
High Severity
ToMeModelWrapper.forward assigns self.parsed_r directly to self._tome_info["r"], making them the same list object. Each ToMeViTLayer.forward then calls .pop(0) on that list, which destructively mutates self.parsed_r. After the first forward pass, self.parsed_r is empty. Any subsequent forward call assigns the empty list, and the first layer's .pop(0) raises an IndexError. The assignment needs to be a copy, e.g. list(self.parsed_r).
Additional Locations (1)
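The aliasing bug can be reproduced in isolation. A minimal sketch, using stand-in names that mirror the report (not the PR's actual classes):

```python
# Minimal reproduction of the reported aliasing bug. The wrapper hands its
# schedule list to shared state, and each "layer" pops from it; without a
# copy, the schedule is destroyed after the first forward pass.
class ToMeModelWrapperBuggy:
    def __init__(self, parsed_r):
        self.parsed_r = parsed_r
        self._tome_info = {}

    def forward(self, num_layers):
        self._tome_info["r"] = self.parsed_r        # BUG: same list object
        return [self._tome_info["r"].pop(0) for _ in range(num_layers)]


class ToMeModelWrapperFixed(ToMeModelWrapperBuggy):
    def forward(self, num_layers):
        self._tome_info["r"] = list(self.parsed_r)  # FIX: fresh copy per call
        return [self._tome_info["r"].pop(0) for _ in range(num_layers)]


buggy = ToMeModelWrapperBuggy([16, 16, 16])
buggy.forward(3)
try:
    buggy.forward(3)  # parsed_r is now empty, so the first pop(0) fails
except IndexError:
    print("buggy: IndexError on second forward")

fixed = ToMeModelWrapperFixed([16, 16, 16])
fixed.forward(3)
print(fixed.forward(3))  # works repeatedly: [16, 16, 16]
```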
llcnt
left a comment
Sorry for the delay, I was stuck with other tasks :(
Thanks for the updates, and for the notebook!
I was not able to run the whole notebook though, because:
- is_vit is not implemented. I guess you added it lately, but did not re-run the notebook;
- TypeError: ViTAttention.forward() got an unexpected keyword argument 'output_attentions' pops out when I run inference of the smashed model. I guess this is because you are using an old version of transformers. Can you print it? And make sure it is compatible with newer versions (4.56.0 is a good starting point, I would say) ;)
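One common way to stay compatible when transformers changes a forward signature between versions is to accept and absorb unknown keyword arguments. A hedged sketch with a stand-in class (not the PR's actual fix):

```python
# Sketch: a patched attention forward that tolerates signature drift across
# transformers versions by absorbing extra keyword arguments such as
# output_attentions instead of crashing on them.
from typing import Any


class PatchedAttention:
    def forward(self, hidden_states: Any, **kwargs: Any) -> Any:
        # Newer transformers may pass output_attentions, head_mask, etc.
        # Read what we need and ignore the rest.
        output_attentions = kwargs.get("output_attentions", False)
        out = hidden_states  # placeholder for the real attention computation
        return (out, None) if output_attentions else (out,)


attn = PatchedAttention()
print(attn.forward([1, 2, 3], output_attentions=True))  # ([1, 2, 3], None)
print(attn.forward([1, 2, 3]))                          # ([1, 2, 3],)
```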
Thank you again for your contribution :)


Description
This PR introduces the Token Merging (ToMe) algorithm for HuggingFace Vision Transformer models. Token Merging progressively merges similar tokens between the attention and MLP stages of each transformer block, significantly reducing the number of tokens and speeding up inference with minimal quality loss.
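As a rough illustration of the merging step, here is a toy sketch that greedily averages the most similar token pairs. It is deliberately simplified: the actual ToMe algorithm uses bipartite soft matching on attention keys, not this O(n²) greedy search.

```python
# Toy illustration of token merging: repeatedly find the most similar pair
# of tokens (by cosine similarity) and replace it with its average, so r
# merges reduce the token count by r.
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)


def merge_tokens(tokens, r):
    """Greedily merge the r most similar token pairs by averaging."""
    tokens = [list(t) for t in tokens]
    for _ in range(r):
        if len(tokens) < 2:
            break
        i, j = max(
            ((i, j) for i in range(len(tokens)) for j in range(i + 1, len(tokens))),
            key=lambda ij: cosine(tokens[ij[0]], tokens[ij[1]]),
        )
        merged = [(x + y) / 2 for x, y in zip(tokens[i], tokens[j])]
        tokens = [t for k, t in enumerate(tokens) if k not in (i, j)]
        tokens.append(merged)
    return tokens


toks = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(len(merge_tokens(toks, 2)))  # 4 tokens, 2 merges -> 2 tokens left
```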
Using model google/vit-base-patch16-224, speedup is over 2x with r=8.

Key Changes:
Token Merging Algorithm:
- New classes (ToMeViTLayer, ToMeViTSelfAttention) that extend HuggingFace transformers

Testing Infrastructure:
Related Issue
Fixes #399
Type of Change
How Has This Been Tested?
- Tested with the google/vit-base-patch16-224 model family

Implementation Details
Token Merging Core Features:
- Configurable r across all layers

Hyperparameters:
- r (int, 0-128): Number of tokens to merge per layer (default: 16)
- trace_source (bool): Track merge provenance for visualization
- prop_attn (bool): Enable proportional attention weighting (default: True)

Checklist
Additional Notes
Design Decisions:
- Module-level class definitions: ToMeViTLayer and related classes are defined at module level (not inside methods) to ensure they are picklable for distributed training and model serialization.
- Eager attention enforcement: The ToMeViTSelfAttention class uses eager attention computation to inject the proportional attention bias between the QK matmul and softmax operations.
- Shared mutable state: All ToMe modules share a single tome_info dict for efficient state management across layers.

Future Enhancements:
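The proportional attention bias mentioned above can be illustrated with a small sketch. In ToMe, each merged token carries a size s (how many original tokens it represents), and log(s) is added to the attention logits before softmax so merged tokens keep their aggregate influence. Names here are illustrative, not the PR's code:

```python
# Sketch of proportional attention for a single query's attention row:
# softmax(logits + log(sizes)). A token representing s merged tokens gets
# s times the weight of a size-1 token with the same logit.
import math


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]


def prop_attn_row(logits, sizes):
    """Apply softmax(logits + log(sizes)) to one row of attention logits."""
    return softmax([l + math.log(s) for l, s in zip(logits, sizes)])


# A token of size 2 with the same logit gets twice the weight:
row = prop_attn_row([0.0, 0.0], [1, 2])
print(row)  # [1/3, 2/3]
```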
References: