
feat: token merging for image classification #537

Open
rensortino wants to merge 9 commits into PrunaAI:main from rensortino:feat/token-merging

Conversation

@rensortino
Contributor

Description

This PR introduces the Token Merging (ToMe) algorithm for HuggingFace Vision Transformer models. Token Merging progressively merges similar tokens between the attention and MLP stages of each transformer block, significantly reducing the number of tokens and speeding up inference with minimal quality loss.

With google/vit-base-patch16-224, the speedup is over 2x at r=8.
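For intuition, here is a back-of-the-envelope sketch (assuming the standard ViT-Base setup of 12 layers and 197 tokens for a 224x224 image with 16x16 patches) of how a constant schedule shrinks the sequence:

```python
def token_counts(n_tokens: int = 197, n_layers: int = 12, r: int = 8) -> list[int]:
    """Number of tokens entering each transformer layer when r tokens
    are merged away per layer (constant schedule)."""
    counts = []
    for _ in range(n_layers):
        counts.append(n_tokens)
        n_tokens = max(n_tokens - r, 1)  # never merge below a single token
    return counts

print(token_counts())
# -> [197, 189, 181, 173, 165, 157, 149, 141, 133, 125, 117, 109]
```

The final layer still merges its r tokens, so only 101 of the original 197 tokens reach the classification head; since self-attention cost grows quadratically with sequence length, roughly halving the token count is consistent with the reported 2x speedup.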

Key Changes:

Token Merging Algorithm:

  • Implements the ToMe algorithm, adapted from the facebook/ToMe reference implementation
  • Custom ViT module classes (ToMeViTLayer, ToMeViTSelfAttention) that extend HuggingFace transformers
  • Supports proportional attention weighting based on merged token sizes
  • Bipartite soft matching for intelligent token pair selection
  • Configurable token reduction schedule with per-layer control
  • Model wrapper for state management across forward passes
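A minimal sketch of the bipartite soft matching step listed above (simplified: the PR's implementation additionally protects the CLS token, tracks merged-token sizes, and returns merge/unmerge functions rather than merging in place):

```python
import torch

def bipartite_merge(x: torch.Tensor, metric: torch.Tensor, r: int) -> torch.Tensor:
    """Merge the r most similar token pairs of x.

    x: (B, N, C) token features; metric: (B, N, D) similarity features
    (ToMe uses the mean of the attention keys). Merged pairs are averaged.
    """
    metric = metric / metric.norm(dim=-1, keepdim=True)  # cosine similarity
    a, b = metric[:, ::2], metric[:, 1::2]               # alternating bipartite split
    scores = a @ b.transpose(-1, -2)                     # a-to-b similarities

    node_max, node_idx = scores.max(dim=-1)              # best b-match per a-token
    order = node_max.argsort(dim=-1, descending=True)    # most similar pairs first
    src_idx, unm_idx = order[:, :r], order[:, r:]        # merge r, keep the rest
    dst_idx = node_idx.gather(-1, src_idx)               # merge targets in b

    B, _, C = x.shape
    xa, xb = x[:, ::2], x[:, 1::2].clone()
    src = xa.gather(1, src_idx[..., None].expand(B, r, C))
    unm = xa.gather(1, unm_idx[..., None].expand(B, unm_idx.shape[1], C))
    xb.scatter_reduce_(1, dst_idx[..., None].expand(B, r, C), src,
                       reduce="mean", include_self=True)  # average into targets
    return torch.cat([unm, xb], dim=1)                    # (B, N - r, C)
```

The alternating split guarantees no chains of merges within one layer, which is what keeps the matching cheap enough to run between attention and MLP at every block.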

Testing Infrastructure:

  • Added ViT model fixtures for comprehensive testing
  • Token Merging test class with validation scenarios

Related Issue

Fixes #399

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

  • Token Merging algorithm tested with HuggingFace ViT models
  • Test fixtures added for google/vit-base-patch16-224 model family
  • Integration tests verify proper token reduction and attention output handling
  • Validated compatibility with existing Pruna pipeline

Implementation Details

Token Merging Core Features:

  1. Bipartite Soft Matching: Intelligently selects which token pairs to merge based on key similarity
  2. Proportional Attention: Adjusts attention weights by the log of merged token sizes
  3. Configurable Reduction Schedule:
    • Constant r across all layers
    • Per-layer list specification
    • Inflection-based schedules (increasing/decreasing/constant)
  4. Class Swapping Pattern: Dynamically replaces HF module classes at runtime to inject ToMe behavior
  5. Metric Storage: Uses key layer mean as similarity metric for matching
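The proportional attention of item 2 can be sketched as follows (a simplified single-head version; `size` holds the number of original tokens behind each key and starts as all ones):

```python
import torch

def proportional_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           size: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention where each merged key counts
    proportionally to the number of original tokens it represents.

    q, k, v: (B, N, D); size: (B, N, 1).
    """
    logits = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # (B, N, N)
    logits = logits + size.log().transpose(-2, -1)         # bias before softmax
    return logits.softmax(dim=-1) @ v
```

This also motivates the eager-attention requirement: fused kernels give no hook between the QK matmul and the softmax, so the log(size) bias must either be injected in an eager implementation, as here, or folded into an additive attention mask.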

Hyperparameters:

  • r (int, 0-128): Number of tokens to merge per layer (default: 16)
  • trace_source (bool): Track merge provenance for visualization
  • prop_attn (bool): Enable proportional attention weighting (default: True)
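One plausible way to expand these hyperparameters into the per-layer schedule (illustrative only; the `parse_r` name and the exact inflection formula here are assumptions, not the PR's code):

```python
def parse_r(num_layers: int, r, inflection: float = 0.0) -> list[int]:
    """Expand r into a list of per-layer merge counts.

    r: an int for a constant schedule, or an explicit per-layer list.
    inflection: > 0 merges more in later layers, < 0 in earlier ones.
    """
    if isinstance(r, (list, tuple)):
        return list(r) + [0] * (num_layers - len(r))  # pad short lists
    if inflection == 0.0:
        return [int(r)] * num_layers
    # linear ramp around the mean, keeping the total budget ~ r * num_layers
    weights = [1.0 + inflection * (2 * i / (num_layers - 1) - 1)
               for i in range(num_layers)]
    scale = int(r) * num_layers / sum(weights)
    return [round(w * scale) for w in weights]
```

For example, `parse_r(12, 8)` yields a constant `[8] * 12`, while `parse_r(12, 8, 1.0)` ramps from 0 up to 16 so that later, more redundant layers absorb most of the merging.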

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

Design Decisions:

  1. Module-level class definitions: ToMeViTLayer and related classes are defined at module level (not inside methods) to ensure they are picklable for distributed training and model serialization.

  2. Eager attention enforcement: The ToMeViTSelfAttention class uses eager attention computation to inject the proportional attention bias between QK matmul and softmax operations.

  3. Shared mutable state: All ToMe modules share a single tome_info dict for efficient state management across layers.
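Decisions 1 and 3 can be illustrated together with a stripped-down sketch (class and attribute names mirror the description above but are illustrative, not the PR's exact API):

```python
class ViTLayer:                          # stand-in for the HF transformers class
    def forward(self, x):
        return x

class ToMeViTLayer(ViTLayer):            # module-level definition => picklable
    """Consumes this layer's merge budget from state shared by all layers."""
    def forward(self, x):
        r = self._tome_info["r"].pop(0)  # per-layer budget, set by the wrapper
        # ... attention, merge r token pairs, MLP ...
        return x

def apply_tome(layers, tome_info: dict) -> None:
    """Swap each layer's class in place and attach one shared state dict."""
    for layer in layers:
        layer.__class__ = ToMeViTLayer   # runtime class swap, no re-init
        layer._tome_info = tome_info     # every layer sees the same dict
```

Because `__class__` is reassigned in place, the layers keep their trained weights, and because `ToMeViTLayer` lives at module level, pickle can resolve it by qualified name when the model is serialized.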

Future Enhancements:

  • Extension to other transformer architectures (Flux, SAM, etc.)
  • Support for custom attention mechanisms

@sdiazlor sdiazlor requested a review from llcnt February 17, 2026 14:38
@llcnt (Collaborator) left a comment:

Thanks a lot for the nice contribution :)
Could you provide a working example of this new integration, please (e.g. a script or a notebook)? I tried to run it with a ViTForImageClassification but it fails (see comment below). I also tried to run it with a pipeline from transformers: the smashing works, but the inference fails. The base pipeline can accept str (image URLs) or raw images, but the smashed pipeline cannot; it would be nice to fix this so that the base model and the smashed one behave similarly. I tried to follow what you did in the new test by preprocessing the image before feeding it into the smashed pipeline, but I still get the error: TypeError: ViTAttention.forward() got an unexpected keyword argument 'output_attentions'.
Also, could you fix the cursor[bot] comments (some of them are quite relevant ;) )?
Thanks in advance!

@github-actions

github-actions bot commented Mar 6, 2026

This PR has been inactive for 10 days and is now marked as stale.

@github-actions github-actions bot added the stale label Mar 6, 2026
@rensortino
Contributor Author

Hi, thanks a lot for the feedback! I am working on the issues raised by Bugbot and will shortly provide a notebook with a basic example of how to test the algorithm.

@rensortino rensortino force-pushed the feat/token-merging branch from a6a79ec to be3a6ed Compare March 6, 2026 17:23
@rensortino
Contributor Author

Here you can find a notebook to test the algorithm on HF models and pipelines.

@github-actions github-actions bot removed the stale label Mar 7, 2026
@github-actions

This PR has been inactive for 10 days and is now marked as stale.

@github-actions github-actions bot added the stale label Mar 17, 2026
@llcnt
Collaborator

llcnt commented Mar 18, 2026

bugbot run

@llcnt llcnt removed the stale label Mar 18, 2026
@cursor (bot) left a comment:

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Comment @cursor review or bugbot run to trigger another review on this PR.

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Initialise ToMe state and forward through the wrapped model."""
        self._tome_info["r"] = self.parsed_r

Forward pass mutates parsed_r, breaking subsequent inferences

High Severity

ToMeModelWrapper.forward assigns self.parsed_r directly to self._tome_info["r"], making them the same list object. Each ToMeViTLayer.forward then calls .pop(0) on that list, which destructively mutates self.parsed_r. After the first forward pass, self.parsed_r is empty. Any subsequent forward call assigns the empty list, and the first layer's .pop(0) raises an IndexError. The assignment needs to be a copy, e.g. list(self.parsed_r).
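The aliasing is easy to reproduce in isolation, and a shallow copy at assignment time fixes it (variable names mirror the report):

```python
parsed_r = [8] * 12               # schedule computed once at smash time

tome_info = {"r": parsed_r}       # BUG: both names alias one list object
for _ in range(12):
    tome_info["r"].pop(0)         # each ToMeViTLayer pops its budget
assert parsed_r == []             # the stored schedule was destroyed

parsed_r = [8] * 12
tome_info["r"] = list(parsed_r)   # FIX: copy, so pops drain only the copy
for _ in range(12):
    tome_info["r"].pop(0)
assert parsed_r == [8] * 12       # schedule intact for the next forward pass
```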

Additional Locations (1)

@llcnt (Collaborator) left a comment:

Sorry for the delay, I was stuck with other tasks :(
Thanks for the updates, and for the notebook!
I was not able to run the whole notebook, though, because:

  • is_vit is not implemented. I guess you added it recently but did not re-run the notebook;
  • TypeError: ViTAttention.forward() got an unexpected keyword argument 'output_attentions' pops up when I run inference with the smashed model, I guess because you are using an old version of transformers. Can you print your version? And make sure the code is compatible with newer versions (4.56.0 is a good starting point, I would say) ;)

Thank you again for your contribution :)



Development

Successfully merging this pull request may close these issues.

[FEATURE] Implement Token Merging

2 participants