
Commit 85ced0f

Authored by SunMarc and MekkCyber
Update replace_with_ for quants methods to not use recursion (#42711)
* Fix replace
* fix bnb
* fix
* style
* fix
* fix
* style
* fix
* style
* Apply suggestions from code review

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
1 parent ec37fc8 commit 85ced0f

File tree

13 files changed, +326 -604 lines changed
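
The pattern this commit adopts across the quantization integrations is to drop the per-child recursion and instead walk the flat `model.named_modules()` view, swapping layers by their dotted path with `nn.Module.set_submodule`. Below is a minimal sketch of that pattern, not the library code: the toy `replace_linears_flat` helper and the `nn.Identity` stand-in are assumptions for illustration, and `set_submodule` assumes a recent PyTorch that provides it.

```python
import torch.nn as nn


def replace_linears_flat(model: nn.Module, modules_to_not_convert=None):
    """Toy sketch: swap every nn.Linear for nn.Identity, skipping excluded paths."""
    modules_to_not_convert = modules_to_not_convert or []
    # snapshot the names first so replacing modules does not disturb iteration
    for name, module in list(model.named_modules()):
        # skip excluded submodules (e.g. "lm_head") and anything that is not a Linear
        if any(skip in name for skip in modules_to_not_convert):
            continue
        if isinstance(module, nn.Linear):
            # set_submodule resolves the dotted path ("1.0", "decoder.fc", ...),
            # so no current_key_name bookkeeping or recursion is needed
            model.set_submodule(name, nn.Identity())
    return model


model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
print(replace_linears_flat(model, modules_to_not_convert=["1.0"]))
```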

src/transformers/integrations/aqlm.py

Lines changed: 38 additions & 66 deletions
@@ -13,88 +13,60 @@
 # limitations under the License.
 "AQLM (Additive Quantization of Language Model) integration file"

-from ..utils import ACCELERATE_MIN_VERSION, is_accelerate_available, is_aqlm_available, is_torch_available
+from ..quantizers.quantizers_utils import should_convert_module
+from ..utils import is_accelerate_available, is_torch_available, logging


+if is_accelerate_available():
+    from accelerate import init_empty_weights
+
 if is_torch_available():
     import torch.nn as nn

+logger = logging.get_logger(__name__)

-def replace_with_aqlm_linear(
-    model,
-    quantization_config=None,
-    linear_weights_not_to_quantize=None,
-    current_key_name=None,
-    has_been_replaced=False,
-):
+
+def replace_with_aqlm_linear(model, modules_to_not_convert: list[str] | None = None, quantization_config=None):
     """
     Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
-    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
-    conversion has been successful or not.

     Args:
         model (`torch.nn.Module`):
             The model to convert, can be any `torch.nn.Module` instance.
-        quantization_config (`AqlmConfig`):
-            The quantization config object that contains the quantization parameters.
-        linear_weights_not_to_quantize (`list[str]`, *optional*):
+        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
             A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
             converted.
-        current_key_name (`list`, *optional*):
-            A list that contains the current key name. This is used for recursion and should not be passed by the user.
-        has_been_replaced (`bool`, *optional*):
-            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
-            should not be passed by the user.
+        quantization_config (`AqlmConfig`):
+            The quantization config object that contains the quantization parameters.
     """
-    if not is_aqlm_available():
-        raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")
-
-    if not is_accelerate_available():
-        raise ValueError(
-            f"AQLM requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
-        )
-
-    if linear_weights_not_to_quantize is None:
-        linear_weights_not_to_quantize = []
-
-    from accelerate import init_empty_weights
     from aqlm import QuantizedLinear

-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-        current_key_name.append(name)
-
-        if isinstance(module, nn.Linear):
-            # Check if the current key is not in the `linear_weights_not_to_quantize`
-            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
-                with init_empty_weights():
-                    in_features = module.in_features
-                    out_features = module.out_features
+    has_been_replaced = False
+    # we need this to correctly materialize the weights during quantization
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        with init_empty_weights():
+            if isinstance(module, nn.Linear):
+                new_module = QuantizedLinear(
+                    module.in_features,
+                    module.out_features,
+                    bias=module.bias is not None,
+                    in_group_size=quantization_config.in_group_size,
+                    out_group_size=quantization_config.out_group_size,
+                    num_codebooks=quantization_config.num_codebooks,
+                    nbits_per_codebook=quantization_config.nbits_per_codebook,
+                )
+                new_module.source_cls = type(module)
+                new_module.requires_grad_(False)
+                model.set_submodule(module_name, new_module)
+                has_been_replaced = True

-                    model._modules[name] = QuantizedLinear(
-                        in_features,
-                        out_features,
-                        bias=module.bias is not None,
-                        in_group_size=quantization_config.in_group_size,
-                        out_group_size=quantization_config.out_group_size,
-                        num_codebooks=quantization_config.num_codebooks,
-                        nbits_per_codebook=quantization_config.nbits_per_codebook,
-                    )
-                    has_been_replaced = True
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model using eetq but no linear modules were found in your model."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )

-                    # Store the module class in case we need to transpose the weight later
-                    model._modules[name].source_cls = type(module)
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = replace_with_aqlm_linear(
-                module,
-                quantization_config=quantization_config,
-                linear_weights_not_to_quantize=linear_weights_not_to_quantize,
-                current_key_name=current_key_name,
-                has_been_replaced=has_been_replaced,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
+    return model
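
A hedged usage sketch of the new `replace_with_aqlm_linear` signature follows. In practice the AQLM quantizer drives this call during `from_pretrained`, and the converted layers are created on the meta device and still need their quantized weights loaded; it is invoked directly here only to show the new arguments. The model name and the default `AqlmConfig()` are illustrative assumptions, and both `aqlm` and `accelerate` must be installed for the imports inside the function to succeed.

```python
from transformers import AqlmConfig, AutoModelForCausalLM
from transformers.integrations.aqlm import replace_with_aqlm_linear

quantization_config = AqlmConfig()  # default codebook / group-size settings
model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative small model

# keep lm_head in full precision, mirroring the modules_to_not_convert contract
model = replace_with_aqlm_linear(
    model,
    modules_to_not_convert=["lm_head"],
    quantization_config=quantization_config,
)
# the function now returns only the model; callers that previously unpacked
# (model, has_been_replaced) need to drop the second value
```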

src/transformers/integrations/bitnet.py

Lines changed: 45 additions & 100 deletions
@@ -1,3 +1,4 @@
+from ..quantizers.quantizers_utils import should_convert_module
 from ..utils import is_accelerate_available, is_torch_available, logging


@@ -314,113 +315,57 @@ def forward(self, input):
         return output


-def _replace_with_bitnet_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-    pre_quantized=False,
-):
+def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None = None, quantization_config=None):
     """
-    Private method that wraps the recursion for module replacement.
+    Public method that replaces the linear layers of the given model with bitnet quantized layers.

-    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
-    """
-
-    if current_key_name is None:
-        current_key_name = []
-
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-        current_key_name.append(name)
-
-        # Check if the current key is not in the `modules_to_not_convert`
-        if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
-            with init_empty_weights():
-                if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    if quantization_config and quantization_config.linear_class == "autobitlinear":
-                        model._modules[name] = AutoBitLinear(
-                            in_features=in_features,
-                            out_features=out_features,
-                            bias=module.bias is not None,
-                            device=module.weight.device,
-                            dtype=module.weight.dtype,
-                            online_quant=(quantization_config.quantization_mode == "online"),
-                            use_rms_norm=quantization_config.use_rms_norm,
-                            rms_norm_eps=quantization_config.rms_norm_eps,
-                        )
-                        if quantization_config.quantization_mode == "offline":
-                            model._modules[name].requires_grad_(False)
-                    else:
-                        model._modules[name] = BitLinear(
-                            in_features=in_features,
-                            out_features=out_features,
-                            bias=module.bias is not None,
-                            device=module.weight.device,
-                            dtype=module.weight.dtype,
-                            use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
-                            rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
-                        )
-                        model._modules[name].requires_grad_(False)
-                    has_been_replaced = True
-
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_bitnet_linear(
-                module,
-                modules_to_not_convert=modules_to_not_convert,
-                current_key_name=current_key_name,
-                quantization_config=quantization_config,
-                has_been_replaced=has_been_replaced,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
-def replace_with_bitnet_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    pre_quantized=False,
-):
-    """
-    A helper function to replace all `torch.nn.Linear` modules by `BitLinear158` modules`.
-
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-    CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
-
-    Parameters:
+    Args:
         model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
-            for numerical stability reasons.
-        current_key_name (`list[`str`]`, *optional*):
-            An array to track the current key of the recursion. This is used to check whether the current key (part of
-            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-            `disk`).
+            The model to convert, can be any `torch.nn.Module` instance.
+        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
+            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
+            converted.
+        quantization_config (`BitNetConfig`):
+            The quantization config object that contains the quantization parameters.
     """
-    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-    if quantization_config and quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_bitnet_linear(
-        model,
-        modules_to_not_convert,
-        current_key_name,
-        quantization_config,
-        pre_quantized=pre_quantized,
-    )
+
+    has_been_replaced = False
+    # we need this to correctly materialize the weights during quantization
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        with init_empty_weights():
+            if isinstance(module, nn.Linear):
+                if quantization_config and quantization_config.linear_class == "autobitlinear":
+                    new_module = AutoBitLinear(
+                        in_features=module.in_features,
+                        out_features=module.out_features,
+                        bias=module.bias is not None,
+                        device=module.weight.device,
+                        dtype=module.weight.dtype,
+                        online_quant=(quantization_config.quantization_mode == "online"),
+                        use_rms_norm=quantization_config.use_rms_norm,
+                        rms_norm_eps=quantization_config.rms_norm_eps,
+                    )
+                    if quantization_config.quantization_mode == "offline":
+                        new_module.requires_grad_(False)
+                else:
+                    new_module = BitLinear(
+                        in_features=module.in_features,
+                        out_features=module.out_features,
+                        bias=module.bias is not None,
+                        device=module.weight.device,
+                        dtype=module.weight.dtype,
+                        use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
+                        rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
+                    )
+                    new_module.requires_grad_(False)
+                model.set_submodule(module_name, new_module)
+                has_been_replaced = True

     if not has_been_replaced:
         logger.warning(
-            "You are loading your model using bitnet but no linear modules were found in your model."
+            "You are loading your model using eetq but no linear modules were found in your model."
             " Please double check your model architecture, or submit an issue on github if you think this is"
             " a bug."
         )
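
Both diffs delegate the skip logic to `should_convert_module` from `..quantizers.quantizers_utils`. The sketch below is a plausible equivalent inferred only from how the helper is called here; it is an assumption about its behavior, not the actual implementation.

```python
def should_convert_module(module_name: str, modules_to_not_convert: list[str] | None) -> bool:
    """Return True when the module at `module_name` should be converted/quantized."""
    if not modules_to_not_convert:
        return True
    # a module is skipped when any excluded pattern occurs in its dotted path,
    # e.g. "lm_head" excludes both "lm_head" and "model.lm_head"
    return not any(pattern in module_name for pattern in modules_to_not_convert)
```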
