
Commit 26030d0

Update inference.py

Fix multiple LoRAs loading and setting of adapters.

1 parent 9c2426a

File tree

1 file changed: +33 −10 lines changed

optillm/inference.py

Lines changed: 33 additions & 10 deletions
@@ -40,7 +40,7 @@ class ModelConfig:
     # Advanced parameters
     use_memory_efficient_attention: bool = True
     enable_prompt_caching: bool = True
-    dynamic_temperature: bool = True
+    dynamic_temperature: bool = False


 @dataclass
@@ -490,6 +490,17 @@ def __init__(self, cache_manager: CacheManager):
         self.cache_manager = cache_manager
         self.loaded_adapters = {}  # Maps model -> list of loaded adapter_ids

+    def _get_adapter_name(self, adapter_id: str) -> str:
+        """Create a valid adapter name from adapter_id by removing invalid characters"""
+        # Replace invalid characters with underscore
+        name = adapter_id.replace('.', '_').replace('-', '_')
+        # Remove any other non-alphanumeric characters
+        name = ''.join(c if c.isalnum() or c == '_' else '' for c in name)
+        # Ensure it starts with a letter or underscore
+        if name[0].isdigit():
+            name = f"adapter_{name}"
+        return name
+
     def validate_adapter(self, adapter_id: str) -> bool:
         """Validate if adapter exists and is compatible"""
         try:
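For reference, a minimal standalone sketch of the sanitization rule that _get_adapter_name applies, with hypothetical adapter IDs. Note that characters such as "/" in a Hugging Face repo ID are dropped rather than replaced, so distinct IDs can in principle map to the same name:

def get_adapter_name(adapter_id: str) -> str:
    # Same rule as _get_adapter_name above, lifted out of the class for testing
    name = adapter_id.replace('.', '_').replace('-', '_')
    name = ''.join(c for c in name if c.isalnum() or c == '_')
    if name[0].isdigit():
        name = f"adapter_{name}"
    return name

assert get_adapter_name("my-lora-v1.2") == "my_lora_v1_2"
assert get_adapter_name("7b-math-lora") == "adapter_7b_math_lora"
assert get_adapter_name("user/my-lora") == "usermy_lora"  # "/" is dropped, not replaced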
@@ -516,12 +527,20 @@ def _load_adapter():
             raise ValueError(error_msg)

         try:
+            # Generate a consistent name for this adapter
+            adapter_name = self._get_adapter_name(adapter_id)
+
+            config = PeftConfig.from_pretrained(
+                adapter_id,
+                trust_remote_code=True,
+                use_auth_token=os.getenv("HF_TOKEN")  # Support private repos
+            )

             # Load adapter into existing PeftModel
             model = base_model
-            model.load_adapter(
-                adapter_id,
-                token=os.getenv("HF_TOKEN"),
+            model.add_adapter(
+                config,
+                adapter_name=adapter_name,
             )

             # Track loaded adapter
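The add_adapter call above registers each LoRA under its sanitized name so several adapters can coexist on one base model. As a point of comparison only, a hedged sketch of the same multi-adapter pattern using peft's PeftModel API directly, which is a different route than the commit takes; the base model and repo IDs are placeholders, not the repository's configuration:

from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
# The first adapter creates the PeftModel; later ones load under distinct names
model = PeftModel.from_pretrained(base, "user/lora-a", adapter_name="lora_a")
model.load_adapter("user/lora-b", adapter_name="lora_b")
model.set_adapter("lora_b")  # only the named adapter is active for generation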
@@ -553,9 +572,6 @@ def _load_adapter():

     def set_active_adapter(self, model: PeftModel, adapter_id: str = None) -> bool:
         """Set a specific adapter as active with error handling"""
-        if not isinstance(model, PeftModel):
-            logger.warning("Model is not a PeftModel, cannot set adapter")
-            return False

         available_adapters = self.loaded_adapters.get(model, [])

@@ -570,7 +586,8 @@ def set_active_adapter(self, model: PeftModel, adapter_id: str = None) -> bool:
         if adapter_id in available_adapters:
             try:
                 model.enable_adapters()
-                model.set_adapter(adapter_id)
+                adapter_name = self._get_adapter_name(adapter_id)
+                model.set_adapter(adapter_name)
                 logger.info(f"Successfully set active adapter to: {adapter_id}")
                 return True
             except Exception as e:
@@ -628,8 +645,7 @@ def __init__(self, model_config: ModelConfig, cache_manager: CacheManager,
             except Exception as e:
                 logger.error(f"Failed to load adapter {adapter_id}: {e}")

-        if isinstance(self.current_model, PeftModel):
-            self.lora_manager.set_active_adapter(self.current_model)
+        self.lora_manager.set_active_adapter(self.current_model)

         # Setup optimizations
         if model_config.use_memory_efficient_attention:
@@ -916,6 +932,7 @@ def __call__(self, input_ids, scores, **kwargs):
                 input_length=input_length
             )
         ])
+
     def process_batch(
         self,
         system_prompts: List[str],
@@ -1203,13 +1220,19 @@ def create(
         seed: Optional[int] = None,
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
+        active_adapter: Optional[Dict[str, Any]] = None,
         **kwargs
     ) -> ChatCompletion:
         """Create a chat completion with OpenAI-compatible parameters"""
         if stream:
             raise NotImplementedError("Streaming is not yet supported")

         pipeline = self.client.get_pipeline(model)
+
+        # Set active adapter if specified in extra_body
+        if active_adapter is not None:
+            logger.info(f"Setting active adapter to: {active_adapter}")
+            pipeline.lora_manager.set_active_adapter(pipeline.current_model, active_adapter)

         # Apply chat template to messages
         prompt = pipeline.tokenizer.apply_chat_template(
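On the client side, the new active_adapter parameter arrives through the OpenAI SDK's extra_body passthrough. A hedged sketch, assuming a local optillm server; the base_url, model name, and adapter ID are placeholders, and the string payload assumes the adapter ID is forwarded directly to set_active_adapter as the hunk above does:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="optillm")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",       # placeholder model
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={"active_adapter": "user/my-lora"},  # routed to set_active_adapter
)
print(response.choices[0].message.content)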

0 commit comments