
Commit 4e61ff1

feat(backend): add sglang lora params of gpu (#272)

wasamtc authored
Co-authored-by: wasamtc <wasam@qq.com>

1 parent 97d82ed, commit 4e61ff1

5 files changed: +172 -0 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ dependencies = [
     "dijkstar==2.6.0",
     "lattica==1.0.13",
     "orjson",
+    "transformers==4.55.2",
 ]
 
 [project.scripts]

src/parallax/server/executor.py

Lines changed: 27 additions & 0 deletions
@@ -102,6 +102,15 @@ def __init__(
         # GPU/SGLang Specialized Configs
         attention_backend: Optional[str] = "flashinfer",
         moe_runner_backend: Optional[str] = "auto",
+        enable_lora: Optional[bool] = False,
+        max_lora_rank: Optional[int] = None,
+        lora_target_modules: Optional[List[str]] = None,
+        lora_paths: Optional[List[str]] = None,
+        max_loras_per_batch: Optional[int] = None,
+        max_loaded_loras: Optional[int] = None,
+        lora_eviction_policy: Optional[str] = "lru",
+        lora_backend: Optional[str] = "triton",
+        max_lora_chunk_size: Optional[int] = 128,
         # Tensor Parallel Configs
         tp_rank: Optional[int] = 0,
         tp_size: Optional[int] = 1,
@@ -155,6 +164,15 @@ def __init__(
             "tp_size": tp_size,
             "nccl_port": nccl_port,
             "using_hfcache": use_hfcache,
+            "enable_lora": enable_lora,
+            "max_lora_rank": max_lora_rank,
+            "lora_target_modules": lora_target_modules,
+            "lora_paths": lora_paths,
+            "max_loras_per_batch": max_loras_per_batch,
+            "max_loaded_loras": max_loaded_loras,
+            "lora_eviction_policy": lora_eviction_policy,
+            "lora_backend": lora_backend,
+            "max_lora_chunk_size": max_lora_chunk_size,
         }
 
         self.model_runner, self.config, self.tokenizer = initialize_cuda_model_runner(
@@ -1615,5 +1633,14 @@ def create_executor_config(args: argparse.Namespace, gradient_server=None):
         "nccl_port": args.nccl_port,
         "gradient_server": gradient_server,
         "use_hfcache": args.use_hfcache,
+        "enable_lora": args.enable_lora,
+        "max_lora_rank": args.max_lora_rank,
+        "lora_target_modules": args.lora_target_modules,
+        "lora_paths": args.lora_paths,
+        "max_loras_per_batch": args.max_loras_per_batch,
+        "max_loaded_loras": args.max_loaded_loras,
+        "lora_eviction_policy": args.lora_eviction_policy,
+        "lora_backend": args.lora_backend,
+        "max_lora_chunk_size": args.max_lora_chunk_size,
     }
     return config
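For quick orientation, here is a minimal stand-alone sketch (not part of the commit) of the LoRA slice of the mapping that create_executor_config performs from argparse attributes to config keys; the adapter name and path are placeholders, and the unrelated executor fields the real function also reads are omitted.

import argparse

# Hypothetical illustration only: the nine LoRA fields added in this commit are
# copied verbatim from the parsed args into the executor config dict.
LORA_FIELDS = [
    "enable_lora", "max_lora_rank", "lora_target_modules", "lora_paths",
    "max_loras_per_batch", "max_loaded_loras", "lora_eviction_policy",
    "lora_backend", "max_lora_chunk_size",
]

args = argparse.Namespace(
    enable_lora=True,
    max_lora_rank=16,
    lora_target_modules=None,
    lora_paths=["demo=/adapters/demo-lora"],  # placeholder adapter
    max_loras_per_batch=4,
    max_loaded_loras=8,
    lora_eviction_policy="lru",
    lora_backend="triton",
    max_lora_chunk_size=128,
)

lora_config = {name: getattr(args, name) for name in LORA_FIELDS}
print(lora_config)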

src/parallax/server/server_args.py

Lines changed: 63 additions & 0 deletions
@@ -171,6 +171,69 @@ def parse_args() -> argparse.Namespace:
         help="Choose the GPU moe kernels",
     )
 
+    parser.add_argument(
+        "--enable-lora", action="store_true", help="Enable LoRA adapter support for SGLang backend"
+    )
+
+    parser.add_argument(
+        "--max-lora-rank",
+        type=int,
+        default=None,
+        help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
+    )
+
+    parser.add_argument(
+        "--lora-target-modules",
+        nargs="*",
+        type=str,
+        default=None,
+        help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, all supported modules will be targeted.",
+    )
+
+    parser.add_argument(
+        "--lora-paths",
+        nargs="*",
+        type=str,
+        default=None,
+        help="The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {'lora_name':str,'lora_path':str,'pinned':bool}.",
+    )
+
+    parser.add_argument(
+        "--max-loras-per-batch",
+        type=int,
+        default=8,
+        help="Maximum number of adapters for a running batch, include base-only request.",
+    )
+
+    parser.add_argument(
+        "--max-loaded-loras",
+        type=int,
+        default=None,
+        help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to --max-loras-per-batch.",
+    )
+
+    parser.add_argument(
+        "--lora-eviction-policy",
+        choices=["lru", "fifo"],
+        default="lru",
+        help="LoRA adapter eviction policy when memory pool is full. 'lru': Least Recently Used (default, better cache efficiency). 'fifo': First-In-First-Out.",
+    )
+
+    parser.add_argument(
+        "--lora-backend",
+        choices=["triton", "csgmv"],
+        default="triton",
+        help="Choose the kernel backend for multi-LoRA serving.",
+    )
+
+    parser.add_argument(
+        "--max-lora-chunk-size",
+        choices=[16, 32, 64, 128],
+        type=int,
+        default=16,
+        help="Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when --lora-backend is 'csgmv'. Choosing a larger value might improve performance.",
+    )
+
     # Tensor parallel configuration
     parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
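The --lora-paths help text above allows three adapter formats. Below is a small self-contained sketch (not part of the commit) that registers just that argument as defined in the diff and parses one example of each format; the adapter names and paths are placeholders, and the strings are passed through uninterpreted at this stage.

import argparse

parser = argparse.ArgumentParser()
# Same definition as the --lora-paths argument added in this commit.
parser.add_argument("--lora-paths", nargs="*", type=str, default=None)

args = parser.parse_args([
    "--lora-paths",
    "/adapters/sql-lora",        # <PATH>
    "chat=/adapters/chat-lora",  # <NAME>=<PATH>
    '{"lora_name": "pinned", "lora_path": "/adapters/pinned-lora", "pinned": true}',  # JSON
])
print(args.lora_paths)  # three raw strings; interpretation happens downstream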

src/parallax/sglang/model_runner.py

Lines changed: 37 additions & 0 deletions
@@ -7,6 +7,7 @@
 import logging
 import os
 import random
+from typing import List, Optional
 
 import sglang
 import sglang.srt.distributed.parallel_state
@@ -207,6 +208,15 @@ def form_sgl_server_args(
     attention_backend: str = "flashinfer",
     kv_block_size: int = 64,
     moe_runner_backend="auto",
+    enable_lora: Optional[bool] = False,
+    max_lora_rank: Optional[int] = None,
+    lora_target_modules: Optional[List[str]] = None,
+    lora_paths: Optional[List[str]] = None,
+    max_loras_per_batch: Optional[int] = None,
+    max_loaded_loras: Optional[int] = None,
+    lora_eviction_policy: Optional[str] = "lru",
+    lora_backend: Optional[str] = "triton",
+    max_lora_chunk_size: Optional[int] = 128,
 ):
     """Creates a SGL ServerArgs object"""
     sgl_server_args = ServerArgs(
@@ -218,6 +228,15 @@ def form_sgl_server_args(
         moe_runner_backend=moe_runner_backend,
         tp_size=tp_size,
         trust_remote_code=True,
+        enable_lora=enable_lora,
+        max_lora_rank=max_lora_rank,
+        lora_target_modules=lora_target_modules,
+        lora_paths=lora_paths,
+        max_loras_per_batch=max_loras_per_batch,
+        max_loaded_loras=max_loaded_loras,
+        lora_eviction_policy=lora_eviction_policy,
+        lora_backend=lora_backend,
+        max_lora_chunk_size=max_lora_chunk_size,
     )
     return sgl_server_args
 
@@ -231,6 +250,15 @@ def initialize_sgl_model_runner(
     kv_block_size: int,
     moe_runner_backend: str,
     max_num_tokens_per_batch: int = 1024,
+    enable_lora: Optional[bool] = False,
+    max_lora_rank: Optional[int] = None,
+    lora_target_modules: Optional[List[str]] = None,
+    lora_paths: Optional[List[str]] = None,
+    max_loras_per_batch: Optional[int] = None,
+    max_loaded_loras: Optional[int] = None,
+    lora_eviction_policy: Optional[str] = "lru",
+    lora_backend: Optional[str] = "triton",
+    max_lora_chunk_size: Optional[int] = 128,
     **kwargs,
 ):
     """
@@ -285,6 +313,15 @@ def initialize_sgl_model_runner(
         attention_backend,
         kv_block_size,
         moe_runner_backend,
+        enable_lora,
+        max_lora_rank,
+        lora_target_modules,
+        lora_paths,
+        max_loras_per_batch,
+        max_loaded_loras,
+        lora_eviction_policy,
+        lora_backend,
+        max_lora_chunk_size,
     )
     initialize_moe_config(server_args)
     quant_method = None
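To show where these parameters end up, here is a minimal sketch (not part of the commit) of the kind of ServerArgs object that form_sgl_server_args now builds when LoRA is enabled; the model and adapter paths are placeholders, every other field is left at its SGLang default, and constructing ServerArgs may validate the local environment.

from sglang.srt.server_args import ServerArgs

# Illustrative only: the LoRA-related fields threaded through to SGLang.
server_args = ServerArgs(
    model_path="Qwen/Qwen3-0.6B",             # placeholder model
    trust_remote_code=True,
    attention_backend="flashinfer",
    enable_lora=True,
    max_lora_rank=16,
    lora_paths=["demo=/adapters/demo-lora"],  # placeholder adapter
    max_loras_per_batch=4,
    max_loaded_loras=8,
    lora_eviction_policy="lru",
    lora_backend="triton",
)
print(server_args.enable_lora, server_args.lora_backend)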

tests/test_server_args.py

Lines changed: 44 additions & 0 deletions
@@ -71,6 +71,41 @@ class TestCreateExecutorConfig:
 
     def test_create_config(self):
         """Test creating executor configuration."""
+        args = argparse.Namespace(
+            model_path="mlx-community/Qwen3-0.6B-bf16",
+            start_layer=0,
+            end_layer=10,
+            dtype="float16",
+            gpu_backend="sglang",
+            max_sequence_length=2048,
+            max_batch_size=8,
+            kv_block_size=64,
+            kv_cache_memory_fraction=0.8,
+            enable_prefix_cache=False,
+            max_num_tokens_per_batch=1024,
+            prefill_priority=0,
+            micro_batch_ratio=2,
+            scheduler_wait_ms=500,
+            send_to_peer_addr=None,
+            recv_from_peer_addr=None,
+            executor_input_ipc="ipc://test_input",
+            executor_output_ipc="ipc://test_output",
+            attention_backend="flashinfer",
+            moe_runner_backend="auto",
+            tp_rank=0,
+            tp_size=1,
+            nccl_port=4000,
+            use_hfcache=False,
+            enable_lora=False,
+            max_lora_rank=None,
+            lora_target_modules=None,
+            lora_paths=None,
+            max_loras_per_batch=1,
+            max_loaded_loras=8,
+            lora_eviction_policy="lru",
+            lora_backend="triton",
+            max_lora_chunk_size=128,
+        )
         args = argparse.Namespace(
             model_path="mlx-community/Qwen3-0.6B-bf16",
             start_layer=0,
@@ -92,6 +127,15 @@ def test_create_config(self):
             tp_size=1,
             nccl_port=4001,
             use_hfcache=False,
+            enable_lora=False,
+            max_lora_rank=None,
+            lora_target_modules=None,
+            lora_paths=None,
+            max_loras_per_batch=1,
+            max_loaded_loras=8,
+            lora_eviction_policy="lru",
+            lora_backend="triton",
+            max_lora_chunk_size=128,
         )
 
         config = create_executor_config(args)
