 this_file_dir = os.path.dirname(__file__)
 vocab_dir = this_file_dir

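+# model-type tags; convert() packs one of these into the high bits of the
+# file-type field written to the ggml header (see ftype below)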
+SD1 = 0
+SD2 = 1
+
 ggml_ftype_str_to_int = {
     "f32": 0,
     "f16": 1,
@@ -155,19 +158,17 @@ def get_alpha_comprod(linear_start=0.00085, linear_end=0.0120, timesteps=1000):
155158 "posterior_mean_coef1" ,
156159 "posterior_mean_coef2" ,
157160 "cond_stage_model.transformer.text_model.embeddings.position_ids" ,
161+ "cond_stage_model.model.logit_scale" ,
162+ "cond_stage_model.model.text_projection" ,
158163 "model_ema.decay" ,
159164 "model_ema.num_updates" ,
160165 "control_model" ,
161166 "lora_te_text_model" ,
162167 "embedding_manager"
163168]
164169
-def convert(model_path, out_type = None, out_file=None):
-    # load model
-    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
-        clip_vocab = json.load(f)
-
-    state_dict = load_model_from_file(model_path)
+
+def preprocess(state_dict):
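+    # Normalize a checkpoint state_dict before writing: drop unused tensors,
+    # rename open_clip (SD2.x) text-encoder tensors to hf CLIP names, and
+    # reshape 2D UNet proj_in/proj_out weights to conv2d 1x1 form.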
     alphas_cumprod = state_dict.get("alphas_cumprod")
     if alphas_cumprod != None:
         # print((np.abs(get_alpha_comprod().numpy() - alphas_cumprod.numpy()) < 0.000001).all())
@@ -176,11 +177,100 @@ def convert(model_path, out_type = None, out_file=None):
         print("no alphas_cumprod in file, generate new one")
         alphas_cumprod = get_alpha_comprod()
         state_dict["alphas_cumprod"] = alphas_cumprod
+
+    new_state_dict = {}
+    for name in state_dict.keys():
+        # ignore unused tensors
+        if not isinstance(state_dict[name], torch.Tensor):
+            continue
+        skip = False
+        for unused_tensor in unused_tensors:
+            if name.startswith(unused_tensor):
+                skip = True
+                break
+        if skip:
+            continue
+
+        # convert open_clip to hf CLIPTextModel (for SD2.x)
+        open_clip_to_hf_clip_model = {
+            "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
+            "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
+            "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+            "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
+        }
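+        # per-resblock mapping from open_clip attention/MLP tensor names to hf CLIP layer names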
+        open_clip_to_hf_clip_resblock = {
+            "attn.out_proj.bias": "self_attn.out_proj.bias",
+            "attn.out_proj.weight": "self_attn.out_proj.weight",
+            "ln_1.bias": "layer_norm1.bias",
+            "ln_1.weight": "layer_norm1.weight",
+            "ln_2.bias": "layer_norm2.bias",
+            "ln_2.weight": "layer_norm2.weight",
+            "mlp.c_fc.bias": "mlp.fc1.bias",
+            "mlp.c_fc.weight": "mlp.fc1.weight",
+            "mlp.c_proj.bias": "mlp.fc2.bias",
+            "mlp.c_proj.weight": "mlp.fc2.weight",
+        }
+        open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks."
+        hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
+        if name in open_clip_to_hf_clip_model:
+            new_name = open_clip_to_hf_clip_model[name]
+            new_state_dict[new_name] = state_dict[name]
+            print(f"preprocess {name} => {new_name}")
+            continue
+        if name.startswith(open_clip_resblock_prefix):
+            remain = name[len(open_clip_resblock_prefix):]
+            idx = remain.split(".")[0]
+            suffix = remain[len(idx)+1:]
+            if suffix == "attn.in_proj_weight":
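+                # open_clip stores q/k/v as one fused in_proj tensor; split it
+                # into the three separate hf-style projection weights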
+                w = state_dict[name]
+                w_q, w_k, w_v = w.chunk(3)
+                for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
+                    new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                    new_state_dict[new_name] = new_w
+                    print(f"preprocess {name} {w.size()} => {new_name} {new_w.size()}")
+            elif suffix == "attn.in_proj_bias":
+                w = state_dict[name]
+                w_q, w_k, w_v = w.chunk(3)
+                for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
+                    new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                    new_state_dict[new_name] = new_w
+                    print(f"preprocess {name} {w.size()} => {new_name} {new_w.size()}")
+            else:
+                new_suffix = open_clip_to_hf_clip_resblock[suffix]
+                new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
+                new_state_dict[new_name] = state_dict[name]
+                print(f"preprocess {name} => {new_name}")
+            continue
+
+        # convert unet transformer linear to conv2d 1x1
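+        # (a 2D nn.Linear weight [out, in] becomes a Conv2d 1x1 weight [out, in, 1, 1])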
+        if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
+            w = state_dict[name]
+            if len(state_dict[name].shape) == 2:
+                new_w = w.unsqueeze(2).unsqueeze(3)
+                new_state_dict[name] = new_w
+                print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
+                continue
+
+        new_state_dict[name] = state_dict[name]
+    return new_state_dict

+def convert(model_path, out_type = None, out_file=None):
+    # load model
+    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
+        clip_vocab = json.load(f)
+
+    state_dict = load_model_from_file(model_path)
+    model_type = SD1
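+    # only SD2.x checkpoints carry the open_clip token_embedding key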
264+ if "cond_stage_model.model.token_embedding.weight" in state_dict .keys ():
265+ model_type = SD2
266+ print ("Stable diffuison 2.x" )
267+ else :
268+ print ("Stable diffuison 1.x" )
269+ state_dict = preprocess (state_dict )
180270
     # output option
     if out_type == None:
-        weight = state_dict["cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight"].numpy()
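+        # sample the dtype of a UNet tensor to pick the default output type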
+        weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
         if weight.dtype == np.float32:
             out_type = "f32"
         elif weight.dtype == np.float16:
@@ -198,8 +288,9 @@ def convert(model_path, out_type = None, out_file=None):
     with open(out_file, "wb") as file:
         # magic: ggml in hex
         file.write(struct.pack("i", 0x67676D6C))
-        # out type
-        file.write(struct.pack("i", ggml_ftype_str_to_int[out_type]))
+        # model & file type
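+        # low 16 bits: ggml ftype; high bits: model type (SD1 = 0, SD2 = 1)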
+        ftype = (model_type << 16) | ggml_ftype_str_to_int[out_type]
+        file.write(struct.pack("i", ftype))

         # vocab
         byte_encoder = bytes_to_unicode()