
Commit 820ed76

[infer] support lora adapter for SGLang backend (hiyouga#8067)
1 parent 66f719d

3 files changed: 22 additions & 1 deletion


src/llamafactory/chat/sglang_engine.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -79,6 +79,10 @@ def __init__(
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False

         launch_cmd = [
             "python3 -m sglang.launch_server",
```
```diff
@@ -90,6 +94,15 @@ def __init__(
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
```
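
When an adapter is configured, the joined command string would look roughly like this; a minimal sketch with hypothetical model and adapter paths, reusing the flags added above:

```python
# Sketch of the assembled server command when self.lora_request is True.
# The model path and adapter path are hypothetical placeholders; the four
# LoRA-related flags are the ones appended in the diff above.
launch_cmd = " ".join(
    [
        "python3 -m sglang.launch_server",
        "--model-path /path/to/base-model",  # hypothetical
        "--max-loras-per-batch 1",
        "--lora-backend triton",  # value of model_args.sglang_lora_backend
        "--lora-paths lora0=/path/to/adapter",  # hypothetical adapter path
        "--disable-radix-cache",
    ]
)
print(launch_cmd)
```

Disabling the radix cache here presumably avoids reusing cached prefixes across requests that target different adapters.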
```diff
@@ -202,6 +215,8 @@ def stream_request():
                 "sampling_params": sampling_params,
                 "stream": True,
             }
+            if self.lora_request:
+                json_data["lora_request"] = ["lora0"]
             response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
             if response.status_code != 200:
                 raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
```
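
The same request can be reproduced outside the engine; a minimal sketch assuming a server started as above on SGLang's default local port, with a hypothetical prompt (the request schema beyond the fields shown in the diff is an assumption):

```python
import requests

# Sketch of a streaming /generate call that activates the adapter registered
# as "lora0" via --lora-paths. Base URL and prompt are hypothetical; the
# "lora_request" field mirrors the one added in the diff above.
json_data = {
    "text": "Hello, how are you?",  # hypothetical prompt
    "sampling_params": {"max_new_tokens": 64, "temperature": 0.7},
    "stream": True,
    "lora_request": ["lora0"],
}
response = requests.post("http://127.0.0.1:30000/generate", json=json_data, stream=True)
response.raise_for_status()
for line in response.iter_lines():
    if line:
        print(line.decode("utf-8"))
```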

src/llamafactory/hparams/model_args.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "The backend for running GEMM kernels for LoRA modules. The Triton backend is recommended for better performance and stability."
+        },
+    )

     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
```
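
On the hyperparameter side, the new field sits alongside `sglang_config` in `SGLangArguments`; a minimal sketch of reading and overriding it, assuming the dataclass is constructible with defaults (every field visible in the diff has one):

```python
# Sketch: the import path follows this repository's layout; constructing
# SGLangArguments with no arguments is an assumption based on the defaults
# visible in the diff.
from llamafactory.hparams.model_args import SGLangArguments

args = SGLangArguments()
print(args.sglang_lora_backend)  # "triton" (the default)

# Literal["triton", "flashinfer"] restricts the accepted values.
args = SGLangArguments(sglang_lora_backend="flashinfer")
```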

tests/e2e/test_sglang.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -20,7 +20,7 @@
 from llamafactory.extras.packages import is_sglang_available


-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+MODEL_NAME = "Qwen/Qwen2.5-0.5B"


 INFER_ARGS = {
```
