Skip to content

Commit e396f79

Browse files
committed
支持 z_image
1 parent 5f2518f commit e396f79

File tree

7 files changed

+140
-19
lines changed

7 files changed

+140
-19
lines changed

gpt_server/model_worker/qwen_image.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,19 @@ def __init__(
6868
async def get_image_output(self, params):
6969
self.call_ct += 1
7070
prompt = params["prompt"]
71-
if contains_chinese(prompt):
72-
prompt += positive_magic["zh"]
73-
else:
74-
prompt += positive_magic["en"]
7571
response_format = params.get("response_format", "b64_json")
7672
inputs = {
7773
"prompt": prompt,
7874
"negative_prompt": " ",
79-
"height": height,
80-
"width": width,
8175
"num_inference_steps": 50,
8276
"true_cfg_scale": 4.0,
8377
"generator": torch.Generator(self.device).manual_seed(0),
8478
}
79+
size = params.get("size", None)
80+
if size:
81+
size_split = size.split("x")
82+
width, height = int(size_split[0]), int(size_split[1])
83+
inputs.update({"width": width, "height": height})
8584
output = await asyncio.to_thread(self.pipe, **inputs)
8685
image = output.images[0]
8786
result = {}

gpt_server/model_worker/z_image.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import asyncio
2+
import os
3+
from typing import List
4+
import uuid
5+
from loguru import logger
6+
import shortuuid
7+
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
8+
from gpt_server.model_worker.utils import pil_to_base64
9+
import torch
10+
from diffusers import ZImagePipeline
11+
from gpt_server.utils import STATIC_DIR
12+
13+
# Directory three levels above this module file (dirname applied three times).
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

# Named aspect ratios mapped to pixel dimensions, stored as (width, height).
aspect_ratios = dict(
    [
        ("1:1", (1328, 1328)),
        ("16:9", (1664, 928)),
        ("9:16", (928, 1664)),
        ("4:3", (1472, 1140)),
        ("3:4", (1140, 1472)),
        ("3:2", (1584, 1056)),
        ("2:3", (1056, 1584)),
    ]
)

# Module-level default dimensions, taken from the 16:9 entry.
width = aspect_ratios["16:9"][0]
height = aspect_ratios["16:9"][1]
26+
import re
27+
28+
29+
def contains_chinese(text):
    """Return True if *text* has at least one character in U+4E00..U+9FFF."""
    return re.search(r"[\u4e00-\u9fff]", text) is not None
32+
33+
34+
class ZImageWorker(ModelWorkerBase):
    """Image-generation worker backed by a diffusers ``ZImagePipeline``.

    The pipeline is loaded once in ``__init__`` and invoked once per request
    from ``get_image_output``.
    """

    # Response formats that ``get_image_output`` knows how to produce.
    _RESPONSE_FORMATS = ("b64_json", "url")

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        """Register the worker and load the pipeline from ``model_path``.

        All arguments are forwarded unchanged to ``ModelWorkerBase`` together
        with ``model_type="image"``. ``model_path`` is additionally used to
        load the pipeline, and ``model_names[0]`` is logged.
        """
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="image",
        )
        # Use the GPU when torch reports one, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = ZImagePipeline.from_pretrained(
            model_path, torch_dtype=torch.bfloat16
        ).to(self.device)

        logger.warning(f"模型:{model_names[0]}")

    @staticmethod
    def _parse_size(size):
        """Parse a ``"<width>x<height>"`` string into ``(width, height)`` ints.

        The separator is matched case-insensitively, so ``"1024X768"`` is
        accepted as well.

        Raises:
            ValueError: If ``size`` is not two positive integers joined by
                a single ``x``.
        """
        width_text, separator, height_text = size.lower().partition("x")
        try:
            if not separator:
                raise ValueError("missing 'x' separator")
            width, height = int(width_text), int(height_text)
        except ValueError as err:
            raise ValueError(
                f"Invalid size {size!r}: expected '<width>x<height>', "
                "e.g. '1664x928'."
            ) from err
        if width <= 0 or height <= 0:
            raise ValueError(
                f"Invalid size {size!r}: width and height must be positive."
            )
        return width, height

    @staticmethod
    def _build_response(data_item):
        """Wrap a single ``data`` entry in the response envelope."""
        return {
            # NOTE(review): the OpenAI images API documents ``created`` as a
            # Unix timestamp; a random short id is returned here. Left
            # unchanged so existing consumers see the same shape - confirm.
            "created": shortuuid.random(),
            "data": [data_item],
            "usage": {
                "total_tokens": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
            },
        }

    async def get_image_output(self, params):
        """Generate one image for ``params["prompt"]``.

        Args:
            params: Request payload. Reads ``prompt`` (required),
                ``response_format`` (defaults to ``"b64_json"``) and ``size``
                (optional ``"<width>x<height>"`` string; when missing or
                empty, no width/height is passed to the pipeline).

        Returns:
            A dict with ``created``, ``data`` and ``usage`` keys. ``data``
            holds one entry: ``{"b64_json": ...}`` or ``{"url": ...}``
            depending on ``response_format``.

        Raises:
            KeyError: If ``prompt`` is missing, or if ``WORKER_HOST`` /
                ``WORKER_PORT`` are unset for a ``url`` response.
            ValueError: If ``response_format`` is unsupported or ``size`` is
                malformed.
        """
        self.call_ct += 1
        prompt = params["prompt"]
        response_format = params.get("response_format", "b64_json")
        # Fail before running inference instead of implicitly returning None
        # after the image has already been generated.
        if response_format not in self._RESPONSE_FORMATS:
            raise ValueError(
                f"Unsupported response_format {response_format!r}: "
                f"expected one of {self._RESPONSE_FORMATS}."
            )

        inputs = {
            "prompt": prompt,
            "negative_prompt": " ",
            "num_inference_steps": 8,
            "guidance_scale": 0.0,
            # NOTE(review): the generator is re-seeded with 42 on every
            # request - confirm that a fixed seed is intended.
            "generator": torch.Generator(self.device).manual_seed(42),
        }
        size = params.get("size", None)
        if size:
            width, height = self._parse_size(size)
            inputs.update({"width": width, "height": height})

        # The pipeline call is blocking, so run it off the event loop.
        output = await asyncio.to_thread(self.pipe, **inputs)
        image = output.images[0]

        if response_format == "b64_json":
            # Convert the PIL image to base64.
            return self._build_response({"b64_json": pil_to_base64(pil_img=image)})

        # response_format == "url"
        # Read the address first so that a missing variable does not leave an
        # orphaned file behind in STATIC_DIR.
        worker_port = os.environ["WORKER_PORT"]
        worker_host = os.environ["WORKER_HOST"]
        # Generate a unique file name to avoid collisions.
        file_name = str(uuid.uuid4()) + ".png"
        save_path = STATIC_DIR / file_name
        # PNG encoding and the disk write are blocking as well.
        await asyncio.to_thread(image.save, save_path, format="PNG")
        url = f"http://{worker_host}:{worker_port}/static/{file_name}"
        return self._build_response({"url": url})


if __name__ == "__main__":
    ZImageWorker.run()

gpt_server/openai_api_protocol/custom_api_protocol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ class ImagesGenRequest(BaseModel):
274274
default="url",
275275
description="生成图像时返回的格式。必须为“ur”或“b64_json”之一。URL仅在图像生成后60分钟内有效。",
276276
)
277+
size: str | None = None
277278

278279

279280
# copy from https://github.com/remsky/Kokoro-FastAPI/blob/master/api/src/routers/openai_compatible.py

gpt_server/serving/openai_api_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,6 +1274,7 @@ async def images_generations(request: ImagesGenRequest):
12741274
"prompt": request.prompt,
12751275
"output_format": request.output_format,
12761276
"response_format": request.response_format,
1277+
"size": request.size,
12771278
}
12781279
result = await get_images_gen(payload=payload)
12791280
return result

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "gpt_server"
3-
version = "0.6.7"
3+
version = "0.6.8"
44
description = "gpt_server是一个用于生产级部署LLMs、Embedding、Reranker、ASR和TTS的开源框架。"
55
readme = "README.md"
66
license = { text = "Apache 2.0" }
@@ -55,6 +55,9 @@ gpt_server = "gpt_server.cli:main"
5555
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
5656
default = true
5757

58+
[tool.uv.sources]
59+
diffusers = { git = "https://gitee.com/liuyu_1997/diffusers.git" }
60+
5861
# [[tool.uv.index]]
5962
# name = "vllm-custom"
6063
# url = "https://wheels.vllm.ai/006e7a34aeb3e905ca4131a3251fe079f0511e2f"

tests/test_image_gen.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8082/v1")
55
# 两种响应方式
66
## response_format = "url" 默认为 url
7-
img = client.images.generate(model="flux", prompt="A red pig", response_format="url")
8-
print(img.data[0])
9-
## response_format = "b64_json"
7+
prompt = "身着粉色汉服、精致刺绣的中国年轻女子。无可挑剔的妆容,额头上的红色花卉图案。精致的高髻,金凤头饰,红花,珠子。持有圆形折扇,上面有女士、树木、鸟。霓虹灯闪电灯(⚡️),明亮的黄色光芒,位于伸出的左手掌上方。室外夜景柔和,剪影的西安大雁塔,远处的七彩灯光模糊。"
8+
model = "z_image"
109
img = client.images.generate(
11-
model="flux", prompt="A red pig", response_format="b64_json"
10+
model=model, prompt=prompt, response_format="url", size="1664x928"
1211
)
12+
print(img.data[0])
13+
# response_format = "b64_json"
14+
img = client.images.generate(model=model, prompt=prompt, response_format="b64_json")
1315
image_bytes = base64.b64decode(img.data[0].b64_json)
1416
with open("output.png", "wb") as f:
1517
f.write(image_bytes)

uv.lock

Lines changed: 5 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)