feat: Sage Attention Algorithm #455
base: main
Conversation
Force-pushed from e2e6e9f to 69c9679
johannaSommer left a comment
First PR and already almost flawless, big 👏🏻👏🏻👏🏻 coming your way soon!
src/pruna/algorithms/sage_attn.py
Outdated
```python
runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = []
compatible_after: Iterable[str] = ["torch_compile"]
```
compatible_after would also include tags.CACHERS, and compatible_before probably also tags.QUANTIZERS.
Then add this compatibility in the other algorithms as well.
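A minimal sketch of what the suggested declaration could look like; only the `tags.CACHERS` / `tags.QUANTIZERS` names come from the comment above, and the import path is an assumption, not verified against the Pruna codebase:

```python
# Sketch only: the import path for `tags` is an assumption, not verified against Pruna.
from collections.abc import Iterable

from pruna.algorithms import tags  # hypothetical import location for algorithm tags

runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
# Quantizers may run before sage attention is applied ...
compatible_before: Iterable[str] = [*tags.QUANTIZERS]
# ... while cachers and torch_compile remain compatible afterwards.
compatible_after: Iterable[str] = ["torch_compile", *tags.CACHERS]
```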
src/pruna/algorithms/sage_attn.py
Outdated
```python
    return False

return any(
    hasattr(component, "set_attention_backend") and component.dtype in [torch.bfloat16, torch.float16]
```
I recall this dtype check for the components from flash attention (because attention needs to be computed in this precision for FA3 to work); did we double-check that this is also the case here?
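For context, a self-contained sketch of the eligibility check being discussed; the function name and the component iteration are illustrative, only `set_attention_backend` and the half-precision dtype check come from the diff above:

```python
import torch


def model_check_fn(model) -> bool:
    """Illustrative sketch: accept the model if at least one pipeline component
    exposes the diffusers attention-backend hook and runs in half precision."""
    components = getattr(model, "components", {})
    if not components:
        return False
    return any(
        hasattr(component, "set_attention_backend")
        and getattr(component, "dtype", None) in (torch.bfloat16, torch.float16)
        for component in components.values()
    )
```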
src/pruna/algorithms/sage_attn.py
Outdated
```python
# We simply apply the sage attention backend from diffusers
# Furthermore, we use the sage attention kernel from the hub as the default sageattn function
# is broken (at least at the moment)
for component in model.components.values():
```
As discussed, let's add target modules here as well :)
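As a rough illustration of the suggestion, a hedged sketch of applying the backend only to targeted components; the backend string and the simple name-based filtering are assumptions for illustration, only `set_attention_backend` and iterating `model.components` are taken from the diff:

```python
def apply_sage_backend(model, target_modules=None, backend: str = "sage") -> None:
    """Sketch: switch targeted pipeline components to a sage attention backend.

    The backend name and the name-based filtering are assumptions; the PR itself
    is meant to use Pruna's target-module machinery instead.
    """
    for name, component in getattr(model, "components", {}).items():
        if not hasattr(component, "set_attention_backend"):
            continue
        # Only touch components the user targeted (None means "all eligible").
        if target_modules is not None and name not in target_modules:
            continue
        component.set_attention_backend(backend)
```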
…antizers as compatible after and before, add sage_attn in corresponding cachers and quantizers algorithms as compatible, add dtype check as sage_attn only works for float/bfloat16 (double checked), add target modules (but not fully finished yet)
…ast attention block per attention component. Remove dtype guard as the dtypes of q, k, and v per attn module are implicitly checked by the sage attention kernel.
src/pruna/algorithms/sage_attn.py
Outdated
```python
    configuration system.
    """
    return [
        Boolean(
```
This is actually not needed and we can remove it, as the user can specify this exactly through the target modules anyway (there is a smash config interface for this).
src/pruna/algorithms/sage_attn.py
Outdated
```python
    The wrapped model.
    """
    target_modules = smash_config["target_modules"]
    exclude_first_and_last_transformer_blocks = smash_config["exclude_first_and_last_transformer_blocks"]
```
For the target modules, let's please use the functionality we already have; otherwise we have a lot of duplicate code here.
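For readers unfamiliar with the pattern, a generic illustration of what target-module filtering means (plain `fnmatch` matching on qualified module names); this is not Pruna's existing helper, which the comment asks the author to reuse rather than re-implement:

```python
from fnmatch import fnmatch

import torch.nn as nn


def iter_target_modules(root: nn.Module, patterns: list[str]):
    """Yield (name, module) pairs whose qualified name matches any pattern."""
    for name, module in root.named_modules():
        if any(fnmatch(name, pattern) for pattern in patterns):
            yield name, module
```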
Description
Integration of the Sage Attention algorithm into the Pruna framework. The current version applies the attention backend from Diffusers, choosing the Sage Attention kernel from the Kernel Hub. This is because the original sageattn function appears to be broken (its outputs were pure noise). Additionally, tests for the Sage Attention algorithm were implemented.
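A minimal usage sketch of the Diffusers hook this PR builds on; the model id and the "sage" backend string are assumptions for illustration, and the PR itself routes to the Kernel Hub sage attention kernel rather than the default sageattn function:

```python
import torch
from diffusers import DiffusionPipeline

# Sketch only: model id and backend name are assumptions for illustration.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Diffusers exposes a per-component hook to swap the attention implementation;
# the sage attention integration switches eligible components to such a backend.
pipe.transformer.set_attention_backend("sage")

image = pipe("a photo of an astronaut riding a horse").images[0]
```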
Related Issue
No issues were fixed.
Type of Change
How Has This Been Tested?
Reuse of the flashattn3 tests, adapted to Sage Attention.
Checklist
Additional Notes
/