adds jais2 model support #42684
Conversation
Rocketknight1 left a comment

Hi @sarathc-cerebras, thank you for the PR! The main thing missing is a conversion to the modular format. You can look at the modular files for other models to see how it works; it reduces the size of the PR a lot by importing duplicated code from other models.
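The "modular" idea the reviewer refers to can be illustrated with a plain-Python sketch (the class names below are stand-ins, not the library's real classes): the new model file only declares what differs from an existing model and inherits everything else, and the converter regenerates the full standalone modeling file from it.

```python
# Plain-Python illustration of the modular pattern (names hypothetical):
# the new model reuses an existing model's components by subclassing.
class LlamaMLP:  # stand-in for an existing model's MLP in the library
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size

    def forward(self, x):
        return [v * 2 for v in x]  # placeholder computation


class Jais2MLP(LlamaMLP):
    # Reused unchanged: only the class name differs, which is enough for a
    # modular converter to emit a full Jais2 modeling file from this stub.
    pass


mlp = Jais2MLP(hidden_size=4)
print(mlp.forward([1, 2, 3]))  # → [2, 4, 6]
```

The win is that duplicated code lives in one place: the PR only carries the deltas, and the generated file stays in sync with the base model.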
@Rocketknight1 thanks for bringing this up, I have updated it to use the modular format.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Rocketknight1
left a comment
Yes, this looks good! I made a few comments but they're small.
vasqu
left a comment
Left some comments. I think we can still simplify a bit and update a few things to match our current standards. Overall, it's looking really good already, though.
ArthurZucker
left a comment
LGTM, good review @vasqu! Small nits, but let's go!
```python
generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Static cache generated text: {generated_text}")
# ...
self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
```
Would be better to have explicit expected outputs here!
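The point about explicit expected outputs can be sketched as follows: instead of only asserting that generation produced extra tokens, a slow integration test pins the exact decoded string. The `EXPECTED_TEXT` below is a made-up placeholder; a real test would record the checkpoint's actual output once and pin that.

```python
# Sketch: pin an explicit expected string rather than a shape check.
# EXPECTED_TEXT is a placeholder, not real Jais-2 output.
EXPECTED_TEXT = "The capital of France is Paris."


def check_generation(generated_text: str) -> None:
    # An exact-match assertion catches any regression in the decoded text,
    # not just a change in the number of generated tokens.
    assert generated_text == EXPECTED_TEXT, f"unexpected output: {generated_text!r}"


check_generation("The capital of France is Paris.")
print("ok")
```

This makes the test sensitive to subtle numerical regressions that a length-only assertion like `assertGreater` would silently pass.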
vasqu
left a comment
Please check out the comments from the last review; mostly nits otherwise, and let's make the tests more explicit (I've linked an example in one of the review comments).
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, jais2
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42684&sha=9398dd
vasqu
left a comment
Last comments from my side (I hope): small fixes and finishing touches.
```diff
 End of stream token id.
-pretraining_tp (`int`, *optional*, defaults to 1):
-    Tensor parallelism rank used during pretraining.
```

TP is no longer handled that way.
```diff
 The attention head dimension.
-rope_theta (`float`, *optional*, defaults to 500000.0):
-    The base period of the RoPE embeddings.
```

Let's move this to `default_theta`:

```python
default_theta = 12000000.0
```
```diff
 pad_token_id: Optional[int] = None,
 bos_token_id: Optional[int] = 0,
 eos_token_id: Optional[int] = 150024,
-pretraining_tp: Optional[int] = 1,
```
```python
# If rope_parameters not provided, create default with rope_theta
if rope_parameters is None:
    rope_parameters = RopeParameters(rope_theta=rope_theta)
```

We should not need this; we have a mixin in the config that should handle it for us.
```python
    The RoPE parameters.
"""

model_type = "jais2"
```

Sorry, seems like I was wrong about the TP plan; I didn't notice that we have a different MLP. Can you re-add the correct version?
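For context, transformers configs declare tensor parallelism as a mapping from module-name patterns to sharding styles. The sketch below shows the general shape of such a plan; the actual entries for Jais2 depend on its MLP layout, so treat these keys and values as placeholders, not the real plan.

```python
# Hypothetical per-module tensor-parallel plan in the style transformers
# configs use; the concrete Jais2 entries would differ for its MLP.
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
print(sorted(set(base_model_tp_plan.values())))  # → ['colwise', 'rowwise']
```

Column-wise splits feed row-wise merges, so a projection pair like `up_proj`/`down_proj` shards cleanly across devices; this is also why a structurally different MLP needs its own plan rather than an inherited one.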
```python
model = Jais2ForCausalLM.from_pretrained(
    "inceptionai/Jais-2-8B-Chat", torch_dtype=torch.float16, device_map="auto"
)
input_text = "The capital of France is"
```

Can we find something that generates more tokens, e.g. 32 tokens? This is only a few tokens, so let's make the test a bit more sensitive to changes.
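The reviewer's point can be shown with a toy stand-in for `model.generate` (the "model" below is a deterministic placeholder, not a real checkpoint): asking for a fixed, larger number of new tokens gives a pinned-output comparison more surface to catch regressions than a short completion would.

```python
# Toy generate() stand-in: the +1 "model" is a placeholder so the sketch
# stays runnable without loading any checkpoint.
def toy_generate(input_ids, max_new_tokens=32):
    out = list(input_ids)
    for _ in range(max_new_tokens):
        out.append(out[-1] + 1)  # deterministic placeholder "model"
    return out


prompt_ids = [101, 102, 103]
generated = toy_generate(prompt_ids, max_new_tokens=32)
print(len(generated) - len(prompt_ids))  # → 32
```

With 32 pinned tokens, a regression anywhere in the sequence breaks the exact-match assertion, whereas a five-token completion might coincidentally still match.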
What does this PR do?
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.