Conversation

@Marius-Graml

Description

This PR integrates the Sage Attention algorithm into the Pruna framework. The current version applies the attention backend from Diffusers and selects the Sage Attention kernel from the Kernel Hub, because the original sageattn function appears to be broken (its outputs were pure noise). Tests for the Sage Attention algorithm were also added.
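For illustration, a minimal sketch of switching a pipeline component to a different attention backend via Diffusers. This is not the PR's code: the model id is only an example, the backend string "sage" is a placeholder (the PR instead picks a Sage Attention kernel from the Kernel Hub), and a recent Diffusers version exposing set_attention_backend is assumed.

```python
import torch
from diffusers import DiffusionPipeline

# Example model id; any pipeline whose components expose set_attention_backend works.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

for component in pipe.components.values():
    # Only modules that expose the Diffusers backend hook can be switched.
    if hasattr(component, "set_attention_backend"):
        # "sage" is an assumption; the PR selects a hub kernel instead of the
        # default sageattn function, which reportedly produces pure noise.
        component.set_attention_backend("sage")
```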

Related Issue

No issues were fixed.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

The existing tests for flashattn3 were reused and adapted to Sage Attention.

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

None.

@Marius-Graml changed the title from "feat/sage attn" to "feat: Sage Attention Algorithm" on Dec 8, 2025.
@johannaSommer (Member) left a comment:

First PR and already almost flawless, big 👏🏻👏🏻👏🏻 coming your way soon!

runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = []
compatible_after: Iterable[str] = ["torch_compile"]
Review comment (Member):

compatible_after would also include tags.CACHERS, and compatible_before probably also tags.QUANTIZERS.
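A sketch of the suggested change, reusing the attribute names from the snippet above. The tags import path and whether tags.CACHERS / tags.QUANTIZERS expand to lists of algorithm names are assumptions, not confirmed by this PR.

```python
# Sketch only: attribute names come from the PR diff above; the exact tags
# API (import path, whether CACHERS/QUANTIZERS are iterable groups) is assumed.
runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = [*tags.QUANTIZERS]               # assumed: quantizers run before
compatible_after: Iterable[str] = ["torch_compile", *tags.CACHERS]  # assumed: cachers run after
```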

Review comment (Member):

Then add this compatibility to the other algorithms as well.

return False

return any(
hasattr(component, "set_attention_backend") and component.dtype in [torch.bfloat16, torch.float16]
Review comment (Member):

I recall this dtype check on the components from flash attention (because attention needs to be computed in this precision for FA3 to work); did we double-check that this is also the case here?

# We simply apply the sage attention backend from diffusers
# Furthermore, we use the sage attention kernel from the hub as the default sageattn function
# is broken (at least at the moment)
for component in model.components.values():
Review comment (Member):

As discussed, let's also add target modules here :)
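A rough sketch of what restricting the backend switch to user-selected target modules could look like. The smash_config["target_modules"] access and the set_attention_backend check are taken from the PR diff; matching components by name and the backend string "sage" are assumptions.

```python
# Sketch only: apply the backend switch to targeted components instead of all of them.
target_modules = smash_config["target_modules"]

for name, component in model.components.items():
    if target_modules and name not in target_modules:
        continue  # component not targeted by the user
    if hasattr(component, "set_attention_backend"):
        component.set_attention_backend("sage")  # backend string is a placeholder
```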

…antizers as compatible after and before, add sage_attn as compatible in the corresponding cacher and quantizer algorithms, add dtype check as sage_attn only works for float16/bfloat16 (double checked), add target modules (but not fully finished yet)
…ast attention block per attention component. Remove dtype guard, as the dtypes of q, k, and v per attention module are implicitly checked by the sage attention kernel.
configuration system.
"""
return [
Boolean(
Review comment (Member):

This is actually not needed and we can remove it, as the user can specify this exactly through the target modules anyway (there is a smash config interface for this).

The wrapped model.
"""
target_modules = smash_config["target_modules"]
exclude_first_and_last_transformer_blocks = smash_config["exclude_first_and_last_transformer_blocks"]
Review comment (Member):

For the target modules, let's please use the functionality we already have; otherwise we end up with a lot of duplicate code here.
