|
| 1 | +from ..quantizers.quantizers_utils import should_convert_module |
1 | 2 | from ..utils import is_accelerate_available, is_torch_available, logging |
2 | 3 |
|
3 | 4 |
|
@@ -314,113 +315,57 @@ def forward(self, input): |
314 | 315 | return output |
315 | 316 |
|
316 | 317 |
|
317 | | -def _replace_with_bitnet_linear( |
318 | | - model, |
319 | | - modules_to_not_convert=None, |
320 | | - current_key_name=None, |
321 | | - quantization_config=None, |
322 | | - has_been_replaced=False, |
323 | | - pre_quantized=False, |
324 | | -): |
| 318 | +def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None = None, quantization_config=None): |
325 | 319 | """ |
326 | | - Private method that wraps the recursion for module replacement. |
| 320 | +    Public function that replaces the linear layers of the given model with bitnet quantized layers.
327 | 321 |
|
328 | | - Returns the converted model and a boolean that indicates if the conversion has been successful or not. |
329 | | - """ |
330 | | - |
331 | | - if current_key_name is None: |
332 | | - current_key_name = [] |
333 | | - |
334 | | - for name, module in model.named_children(): |
335 | | - if current_key_name is None: |
336 | | - current_key_name = [] |
337 | | - current_key_name.append(name) |
338 | | - |
339 | | - # Check if the current key is not in the `modules_to_not_convert` |
340 | | - if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): |
341 | | - with init_empty_weights(): |
342 | | - if isinstance(module, nn.Linear) and name not in modules_to_not_convert: |
343 | | - in_features = module.in_features |
344 | | - out_features = module.out_features |
345 | | - if quantization_config and quantization_config.linear_class == "autobitlinear": |
346 | | - model._modules[name] = AutoBitLinear( |
347 | | - in_features=in_features, |
348 | | - out_features=out_features, |
349 | | - bias=module.bias is not None, |
350 | | - device=module.weight.device, |
351 | | - dtype=module.weight.dtype, |
352 | | - online_quant=(quantization_config.quantization_mode == "online"), |
353 | | - use_rms_norm=quantization_config.use_rms_norm, |
354 | | - rms_norm_eps=quantization_config.rms_norm_eps, |
355 | | - ) |
356 | | - if quantization_config.quantization_mode == "offline": |
357 | | - model._modules[name].requires_grad_(False) |
358 | | - else: |
359 | | - model._modules[name] = BitLinear( |
360 | | - in_features=in_features, |
361 | | - out_features=out_features, |
362 | | - bias=module.bias is not None, |
363 | | - device=module.weight.device, |
364 | | - dtype=module.weight.dtype, |
365 | | - use_rms_norm=quantization_config.use_rms_norm if quantization_config else False, |
366 | | - rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6, |
367 | | - ) |
368 | | - model._modules[name].requires_grad_(False) |
369 | | - has_been_replaced = True |
370 | | - |
371 | | - if len(list(module.children())) > 0: |
372 | | - _, has_been_replaced = _replace_with_bitnet_linear( |
373 | | - module, |
374 | | - modules_to_not_convert=modules_to_not_convert, |
375 | | - current_key_name=current_key_name, |
376 | | - quantization_config=quantization_config, |
377 | | - has_been_replaced=has_been_replaced, |
378 | | - ) |
379 | | - # Remove the last key for recursion |
380 | | - current_key_name.pop(-1) |
381 | | - return model, has_been_replaced |
382 | | - |
383 | | - |
384 | | -def replace_with_bitnet_linear( |
385 | | - model, |
386 | | - modules_to_not_convert=None, |
387 | | - current_key_name=None, |
388 | | - quantization_config=None, |
389 | | - pre_quantized=False, |
390 | | -): |
391 | | - """ |
392 | | - A helper function to replace all `torch.nn.Linear` modules by `BitLinear158` modules`. |
393 | | -
|
394 | | - The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should |
395 | | - be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no |
396 | | - CPU/GPU memory is required to run this function. Each weight will be quantized along the channel. |
397 | | -
|
398 | | - Parameters: |
| 322 | + Args: |
399 | 323 | model (`torch.nn.Module`): |
400 | | - Input model or `torch.nn.Module` as the function is run recursively. |
401 | | - modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`): |
402 | | - Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision |
403 | | - for numerical stability reasons. |
404 | | - current_key_name (`list[`str`]`, *optional*): |
405 | | - An array to track the current key of the recursion. This is used to check whether the current key (part of |
406 | | - it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or |
407 | | - `disk`). |
| 324 | + The model to convert, can be any `torch.nn.Module` instance. |
| 325 | + modules_to_not_convert (`list[str]`, *optional*, defaults to `None`): |
| 326 | + A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be |
| 327 | + converted. |
| 328 | + quantization_config (`BitNetConfig`): |
| 329 | + The quantization config object that contains the quantization parameters. |
408 | 330 | """ |
409 | | - modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert |
410 | | - if quantization_config and quantization_config.modules_to_not_convert is not None: |
411 | | - modules_to_not_convert.extend(quantization_config.modules_to_not_convert) |
412 | | - modules_to_not_convert = list(set(modules_to_not_convert)) |
413 | | - model, has_been_replaced = _replace_with_bitnet_linear( |
414 | | - model, |
415 | | - modules_to_not_convert, |
416 | | - current_key_name, |
417 | | - quantization_config, |
418 | | - pre_quantized=pre_quantized, |
419 | | - ) |
| 331 | + |
| 332 | + has_been_replaced = False |
| 333 | +    # replacements are built under `init_empty_weights` so the quantizer can correctly materialize the weights afterwards
| 334 | + for module_name, module in model.named_modules(): |
| 335 | + if not should_convert_module(module_name, modules_to_not_convert): |
| 336 | + continue |
| 337 | + with init_empty_weights(): |
| 338 | + if isinstance(module, nn.Linear): |
| 339 | + if quantization_config and quantization_config.linear_class == "autobitlinear": |
| 340 | + new_module = AutoBitLinear( |
| 341 | + in_features=module.in_features, |
| 342 | + out_features=module.out_features, |
| 343 | + bias=module.bias is not None, |
| 344 | + device=module.weight.device, |
| 345 | + dtype=module.weight.dtype, |
| 346 | + online_quant=(quantization_config.quantization_mode == "online"), |
| 347 | + use_rms_norm=quantization_config.use_rms_norm, |
| 348 | + rms_norm_eps=quantization_config.rms_norm_eps, |
| 349 | + ) |
| 350 | + if quantization_config.quantization_mode == "offline": |
| 351 | + new_module.requires_grad_(False) |
| 352 | + else: |
| 353 | + new_module = BitLinear( |
| 354 | + in_features=module.in_features, |
| 355 | + out_features=module.out_features, |
| 356 | + bias=module.bias is not None, |
| 357 | + device=module.weight.device, |
| 358 | + dtype=module.weight.dtype, |
| 359 | + use_rms_norm=quantization_config.use_rms_norm if quantization_config else False, |
| 360 | + rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6, |
| 361 | + ) |
| 362 | + new_module.requires_grad_(False) |
| 363 | + model.set_submodule(module_name, new_module) |
| 364 | + has_been_replaced = True |
420 | 365 |
|
421 | 366 | if not has_been_replaced: |
422 | 367 | logger.warning( |
423 | 368 |     "You are loading your model using bitnet but no linear modules were found in your model."
424 | 369 | " Please double check your model architecture, or submit an issue on github if you think this is" |
425 | 370 | " a bug." |
426 | 371 | ) |
|