
Commit 31e77e1

feat: add SD2.x support (#40)
1 parent c542a77

File tree

3 files changed: 498 additions & 131 deletions

README.md

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Accelerated memory-efficient CPU inference
 - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image
 - AVX, AVX2 and AVX512 support for x86 architectures
+- SD1.x and SD2.x support
 - Original `txt2img` and `img2img` mode
 - Negative prompt
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
@@ -60,10 +61,12 @@ git submodule update
 - download original weights (.ckpt or .safetensors). For example
   - Stable Diffusion v1.4 from https://huggingface.co/CompVis/stable-diffusion-v-1-4-original
   - Stable Diffusion v1.5 from https://huggingface.co/runwayml/stable-diffusion-v1-5
+  - Stable Diffusion v2.1 from https://huggingface.co/stabilityai/stable-diffusion-2-1

     ```shell
     curl -L -O https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
     # curl -L -O https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
+    # curl -L -O https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-nonema-pruned.safetensors
     ```

 - convert weights to ggml model format
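For the conversion step referenced above, here is a minimal usage sketch of the `convert()` function defined in models/convert.py (its signature appears in the diff below). Running from the models/ directory, so that `vocab.json` is readable, and the output filename are both assumptions for illustration:

```python
# Minimal sketch, assuming it is run from the models/ directory so that
# convert.py and vocab.json are importable/readable; the output filename
# is a hypothetical choice, not one mandated by the script.
from convert import convert

convert(
    "v2-1_768-nonema-pruned.safetensors",  # downloaded SD2.1 checkpoint
    out_type="f16",                        # one of the keys of ggml_ftype_str_to_int
    out_file="v2-1-ggml-model-f16.bin",    # hypothetical output name
)
```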
@@ -182,5 +185,6 @@ docker run -v /path/to/models:/models -v /path/to/output/:/output sd [args...]

 - [ggml](https://github.com/ggerganov/ggml)
 - [stable-diffusion](https://github.com/CompVis/stable-diffusion)
+- [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
 - [k-diffusion](https://github.com/crowsonkb/k-diffusion)

models/convert.py

Lines changed: 100 additions & 9 deletions
@@ -9,6 +9,9 @@
 this_file_dir = os.path.dirname(__file__)
 vocab_dir = this_file_dir

+SD1 = 0
+SD2 = 1
+
 ggml_ftype_str_to_int = {
     "f32": 0,
     "f16": 1,
@@ -155,19 +158,17 @@ def get_alpha_comprod(linear_start=0.00085, linear_end=0.0120, timesteps=1000):
     "posterior_mean_coef1",
     "posterior_mean_coef2",
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
+    "cond_stage_model.model.logit_scale",
+    "cond_stage_model.model.text_projection",
     "model_ema.decay",
     "model_ema.num_updates",
     "control_model",
     "lora_te_text_model",
     "embedding_manager"
 ]

-def convert(model_path, out_type = None, out_file=None):
-    # load model
-    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
-        clip_vocab = json.load(f)
-
-    state_dict = load_model_from_file(model_path)
+
+def preprocess(state_dict):
     alphas_cumprod = state_dict.get("alphas_cumprod")
     if alphas_cumprod != None:
         # print((np.abs(get_alpha_comprod().numpy() - alphas_cumprod.numpy()) < 0.000001).all())
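For context, `get_alpha_comprod` (named in the hunk header and in the check above) regenerates the noise schedule when a checkpoint lacks `alphas_cumprod`. A hedged sketch of what it presumably computes, based on its default arguments and the standard Stable Diffusion "linear" beta schedule; the function body itself is not shown in this diff:

```python
import numpy as np
import torch

# Sketch (assumed, not from this diff): betas spaced linearly in sqrt
# space between linear_start and linear_end, then the running product
# of (1 - beta) over all timesteps, as in the reference ldm schedule.
def get_alpha_comprod_sketch(linear_start=0.00085, linear_end=0.0120, timesteps=1000):
    betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=np.float64) ** 2
    alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
    return torch.tensor(alphas_cumprod, dtype=torch.float32)
```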
@@ -176,11 +177,100 @@ def convert(model_path, out_type = None, out_file=None):
         print("no alphas_cumprod in file, generate new one")
         alphas_cumprod = get_alpha_comprod()
     state_dict["alphas_cumprod"] = alphas_cumprod
+
+    new_state_dict = {}
+    for name in state_dict.keys():
+        # ignore unused tensors
+        if not isinstance(state_dict[name], torch.Tensor):
+            continue
+        skip = False
+        for unused_tensor in unused_tensors:
+            if name.startswith(unused_tensor):
+                skip = True
+                break
+        if skip:
+            continue
+
+        # convert open_clip to hf CLIPTextModel (for SD2.x)
+        open_clip_to_hf_clip_model = {
+            "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
+            "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
+            "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+            "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
+        }
+        open_clip_to_hf_clip_resblock = {
+            "attn.out_proj.bias": "self_attn.out_proj.bias",
+            "attn.out_proj.weight": "self_attn.out_proj.weight",
+            "ln_1.bias": "layer_norm1.bias",
+            "ln_1.weight": "layer_norm1.weight",
+            "ln_2.bias": "layer_norm2.bias",
+            "ln_2.weight": "layer_norm2.weight",
+            "mlp.c_fc.bias": "mlp.fc1.bias",
+            "mlp.c_fc.weight": "mlp.fc1.weight",
+            "mlp.c_proj.bias": "mlp.fc2.bias",
+            "mlp.c_proj.weight": "mlp.fc2.weight",
+        }
+        open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks."
+        hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
+        if name in open_clip_to_hf_clip_model:
+            new_name = open_clip_to_hf_clip_model[name]
+            new_state_dict[new_name] = state_dict[name]
+            print(f"preprocess {name} => {new_name}")
+            continue
+        if name.startswith(open_clip_resblock_prefix):
+            remain = name[len(open_clip_resblock_prefix):]
+            idx = remain.split(".")[0]
+            suffix = remain[len(idx)+1:]
+            if suffix == "attn.in_proj_weight":
+                w = state_dict[name]
+                w_q, w_k, w_v = w.chunk(3)
+                for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
+                    new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                    new_state_dict[new_name] = new_w
+                    print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
+            elif suffix == "attn.in_proj_bias":
+                w = state_dict[name]
+                w_q, w_k, w_v = w.chunk(3)
+                for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
+                    new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                    new_state_dict[new_name] = new_w
+                    print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
+            else:
+                new_suffix = open_clip_to_hf_clip_resblock[suffix]
+                new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                new_state_dict[new_name] = state_dict[name]
+                print(f"preprocess {name} => {new_name}")
+            continue
+
+        # convert unet transformer linear to conv2d 1x1
+        if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
+            w = state_dict[name]
+            if len(state_dict[name].shape) == 2:
+                new_w = w.unsqueeze(2).unsqueeze(3)
+                new_state_dict[name] = new_w
+                print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
+            continue
+
+        new_state_dict[name] = state_dict[name]
+    return new_state_dict

+def convert(model_path, out_type = None, out_file=None):
+    # load model
+    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
+        clip_vocab = json.load(f)
+
+    state_dict = load_model_from_file(model_path)
+    model_type = SD1
+    if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
+        model_type = SD2
+        print("Stable Diffusion 2.x")
+    else:
+        print("Stable Diffusion 1.x")
+    state_dict = preprocess(state_dict)

     # output option
     if out_type == None:
-        weight = state_dict["cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight"].numpy()
+        weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
         if weight.dtype == np.float32:
             out_type = "f32"
         elif weight.dtype == np.float16:
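The trickiest part of the open_clip → HF CLIPTextModel mapping above is the attention projection: open_clip stores Q, K and V fused in a single `in_proj` tensor, while HF CLIP keeps three separate projections. A toy check of the `chunk(3)` split used above; the sizes here are made up for illustration:

```python
import torch

# open_clip stacks Q, K and V along dim 0 as one (3*hidden, hidden)
# matrix; chunk(3) recovers the three (hidden, hidden) HF-style matrices
# and concatenating them back reproduces the original tensor.
hidden = 8  # toy size for illustration only
w = torch.randn(3 * hidden, hidden)
w_q, w_k, w_v = w.chunk(3)
assert w_q.shape == (hidden, hidden)
assert torch.equal(torch.cat([w_q, w_k, w_v], dim=0), w)
```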
@@ -198,8 +288,9 @@ def convert(model_path, out_type = None, out_file=None):
     with open(out_file, "wb") as file:
         # magic: ggml in hex
         file.write(struct.pack("i", 0x67676D6C))
-        # out type
-        file.write(struct.pack("i", ggml_ftype_str_to_int[out_type]))
+        # model & file type
+        ftype = (model_type << 16) | ggml_ftype_str_to_int[out_type]
+        file.write(struct.pack("i", ftype))

         # vocab
         byte_encoder = bytes_to_unicode()
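With this change the single header int carries both fields. A reader-side sketch of how a loader could split them back apart, following the bit layout of the packing above; the file name is hypothetical:

```python
import struct

# Unpack the 4-byte magic and the combined type field written above:
# low 16 bits = ggml file type (ggml_ftype_str_to_int value),
# high 16 bits = model type (SD1 = 0, SD2 = 1).
with open("v2-1-ggml-model-f16.bin", "rb") as f:  # hypothetical file
    magic, combined = struct.unpack("ii", f.read(8))
assert magic == 0x67676D6C
model_type = combined >> 16   # 0 -> SD1.x, 1 -> SD2.x
ftype = combined & 0xFFFF     # data/file type index
```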
