
Commit 672e38a

adds jais2 model support
1 parent 85ced0f commit 672e38a

File tree

10 files changed: +1421 −0 lines


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -547,6 +547,8 @@
   title: HunYuanMoEV1
 - local: model_doc/ibert
   title: I-BERT
+- local: model_doc/jais2
+  title: Jais2
 - local: model_doc/jamba
   title: Jamba
 - local: model_doc/jetmoe

docs/source/en/model_doc/jais2.md

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-09.*

# Jais2

## Overview

Jais2 is a large language model developed by MBZUAI, Inception and Cerebras Systems. It is based on the transformer architecture with several modifications, including:

- LayerNorm instead of RMSNorm
- ReLU² activation function
- Rotary Position Embeddings (RoPE)

## Jais2Config

[[autodoc]] Jais2Config

## Jais2Model

[[autodoc]] Jais2Model
    - forward

## Jais2ForCausalLM

[[autodoc]] Jais2ForCausalLM
    - forward

## Jais2ForSequenceClassification

[[autodoc]] Jais2ForSequenceClassification
    - forward

## Jais2ForTokenClassification

[[autodoc]] Jais2ForTokenClassification
    - forward

## Jais2ForQuestionAnswering

[[autodoc]] Jais2ForQuestionAnswering
    - forward
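
To complement the Overview above, here is a minimal usage sketch (not part of this commit). It assumes the [inceptionai/Jais-2-8B-Chat](https://huggingface.co/inceptionai/Jais-2-8B-Chat) checkpoint referenced in the `Jais2Config` docstring is published with weights compatible with the new classes:

```python
# Hedged sketch: generate text with the Jais2 causal LM through the auto classes.
# Assumption: the inceptionai/Jais-2-8B-Chat checkpoint matches the Jais2 architecture added here.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "inceptionai/Jais-2-8B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

inputs = tokenizer("Give me a short introduction to large language models.", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```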

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -215,6 +215,7 @@
     ("instructblipvideo", "InstructBlipVideoConfig"),
     ("internvl", "InternVLConfig"),
     ("internvl_vision", "InternVLVisionConfig"),
+    ("jais2", "Jais2Config"),
     ("jamba", "JambaConfig"),
     ("janus", "JanusConfig"),
     ("jetmoe", "JetMoeConfig"),
@@ -658,6 +659,7 @@
     ("instructblipvideo", "InstructBlipVideo"),
     ("internvl", "InternVL"),
     ("internvl_vision", "InternVLVision"),
+    ("jais2", "Jais2"),
     ("jamba", "Jamba"),
     ("janus", "Janus"),
     ("jetmoe", "JetMoe"),
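
The two entries above register the `jais2` model type with the auto-config machinery. A quick sketch (not part of the diff) of what this enables:

```python
# Sketch: the new ("jais2", "Jais2Config") mapping lets AutoConfig resolve the model type.
from transformers import AutoConfig

config = AutoConfig.for_model("jais2")
print(type(config).__name__)  # Jais2Config
print(config.model_type)      # "jais2"
```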

src/transformers/models/auto/modeling_auto.py

Lines changed: 5 additions & 0 deletions
@@ -216,6 +216,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ("instructblipvideo", "InstructBlipVideoModel"),
     ("internvl", "InternVLModel"),
     ("internvl_vision", "InternVLVisionModel"),
+    ("jais2", "Jais2Model"),
     ("jamba", "JambaModel"),
     ("janus", "JanusModel"),
     ("jetmoe", "JetMoeModel"),
@@ -689,6 +690,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ("helium", "HeliumForCausalLM"),
     ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"),
     ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
+    ("jais2", "Jais2ForCausalLM"),
     ("jamba", "JambaForCausalLM"),
     ("jetmoe", "JetMoeForCausalLM"),
     ("lfm2", "Lfm2ForCausalLM"),
@@ -1245,6 +1247,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"),
     ("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"),
     ("ibert", "IBertForSequenceClassification"),
+    ("jais2", "Jais2ForSequenceClassification"),
     ("jamba", "JambaForSequenceClassification"),
     ("jetmoe", "JetMoeForSequenceClassification"),
     ("layoutlm", "LayoutLMForSequenceClassification"),
@@ -1343,6 +1346,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ("gpt_neox", "GPTNeoXForQuestionAnswering"),
     ("gptj", "GPTJForQuestionAnswering"),
     ("ibert", "IBertForQuestionAnswering"),
+    ("jais2", "Jais2ForQuestionAnswering"),
     ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
     ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
     ("led", "LEDForQuestionAnswering"),
@@ -1458,6 +1462,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
     ("gpt_oss", "GptOssForTokenClassification"),
     ("helium", "HeliumForTokenClassification"),
     ("ibert", "IBertForTokenClassification"),
+    ("jais2", "Jais2ForTokenClassification"),
     ("layoutlm", "LayoutLMForTokenClassification"),
     ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
     ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
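
These mappings plug Jais2 into the task-specific auto classes (base model, causal LM, sequence classification, question answering and token classification). A small illustrative sketch, assuming `Jais2Config` is also exported at the top level of `transformers` by this commit; the tiny sizes are arbitrary and the weights are randomly initialized:

```python
# Sketch: the auto classes now resolve a Jais2Config to the matching Jais2 head classes.
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, Jais2Config

config = Jais2Config(
    vocab_size=1000,       # arbitrary small values for a quick local instantiation
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
)
lm = AutoModelForCausalLM.from_config(config)
tagger = AutoModelForTokenClassification.from_config(config)
print(type(lm).__name__)      # Jais2ForCausalLM
print(type(tagger).__name__)  # Jais2ForTokenClassification
```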
src/transformers/models/jais2/__init__.py

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {
    "configuration_jais2": ["Jais2Config"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_jais2"] = [
        "Jais2ForCausalLM",
        "Jais2ForQuestionAnswering",
        "Jais2ForSequenceClassification",
        "Jais2ForTokenClassification",
        "Jais2Model",
        "Jais2PreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_jais2 import Jais2Config

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_jais2 import (
            Jais2ForCausalLM,
            Jais2ForQuestionAnswering,
            Jais2ForSequenceClassification,
            Jais2ForTokenClassification,
            Jais2Model,
            Jais2PreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
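
The lazy-module `__init__` above keeps `import transformers` lightweight: `Jais2Config` is always importable, while the modeling classes are only registered when torch is available. A rough sketch of the resulting behaviour:

```python
# Sketch of the import behaviour implied by the lazy __init__ above.
from transformers.models.jais2 import Jais2Config  # always exposed via _import_structure

try:
    # Only registered when is_torch_available() is True.
    from transformers.models.jais2 import Jais2ForCausalLM
except ImportError:
    print("torch is not installed, so the Jais2 modeling classes are not exposed")
```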
src/transformers/models/jais2/configuration_jais2.py

Lines changed: 169 additions & 0 deletions

@@ -0,0 +1,169 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/jais2/modular_jais2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_jais2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters


class Jais2Config(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Jais2Model`].
    It inherits from [`LlamaConfig`] and can be used to control the model outputs.

    Read more at [inceptionai/Jais-2-8B-Chat](https://huggingface.co/inceptionai/Jais-2-8B-Chat).

    Args:
        vocab_size (`int`, *optional*, defaults to 150272):
            Vocabulary size of the Jais2 model.
        hidden_size (`int`, *optional*, defaults to 3328):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 26624):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 26):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*):
            Number of key_value heads for Grouped Query Attention.
        hidden_act (`str`, *optional*, defaults to `"relu2"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to return the last key/value attentions.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 150024):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Tensor parallelism rank used during pretraining.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to use a bias in the query, key, value and output projection layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `True`):
            Whether to use a bias in the up_proj, down_proj and gate_proj layers.
        head_dim (`int`, *optional*):
            The attention head dimension.
        rope_theta (`float`, *optional*, defaults to 500000.0):
            The base period of the RoPE embeddings.
        rope_parameters (`dict`, *optional*):
            The RoPE parameters.
    """

    model_type = "jais2"
    keys_to_ignore_at_inference = ["past_key_values"]

    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size: Optional[int] = 150272,
        hidden_size: Optional[int] = 3328,
        intermediate_size: Optional[int] = 26624,
        num_hidden_layers: Optional[int] = 32,
        num_attention_heads: Optional[int] = 26,
        num_key_value_heads: Optional[int] = None,
        hidden_act: Optional[str] = "relu2",
        max_position_embeddings: Optional[int] = 8192,
        initializer_range: Optional[float] = 0.02,
        layer_norm_eps: Optional[float] = 1e-5,
        use_cache: Optional[bool] = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = 0,
        eos_token_id: Optional[int] = 150024,
        pretraining_tp: Optional[int] = 1,
        tie_word_embeddings: Optional[bool] = False,
        attention_bias: Optional[bool] = True,
        attention_dropout: Optional[float] = 0.0,
        mlp_bias: Optional[bool] = True,
        head_dim: Optional[int] = None,
        rope_theta: Optional[float] = 500000.0,
        rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
        **kwargs,
    ):
        # If rope_parameters is not provided, create a default from rope_theta
        if rope_parameters is None:
            rope_parameters = RopeParameters(rope_theta=rope_theta)

        # Define rms_norm_eps for the parent init to use
        rms_norm_eps = layer_norm_eps
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        self.rope_parameters = rope_parameters

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        # Mirror rms_norm_eps under the layer_norm_eps attribute name
        self.layer_norm_eps = self.rms_norm_eps

        # Validate and standardize RoPE parameters
        self.standardize_rope_params()
        self.validate_rope()


__all__ = ["Jais2Config"]
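
A short sketch of the behaviour encoded in `__init__` above: the epsilon is stored as `rms_norm_eps` for the inherited Llama-style internals and mirrored to `layer_norm_eps`, and `head_dim` falls back to `hidden_size // num_attention_heads` when not given (this assumes `Jais2Config` is exported at the top level of `transformers` by this commit):

```python
# Sketch: inspect the default Jais2Config and its derived attributes.
from transformers import Jais2Config

config = Jais2Config()
print(config.hidden_size)          # 3328
print(config.num_attention_heads)  # 26
print(config.head_dim)             # 3328 // 26 = 128
print(config.layer_norm_eps)       # 1e-05, mirrored from rms_norm_eps
print(config.hidden_act)           # "relu2"
```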
