
 QK4_0 = 32
 def quantize_q4_0(x):
-    assert x.shape[-1] % QK4_0 == 0
+    assert x.shape[-1] % QK4_0 == 0 and x.shape[-1] > QK4_0
     x = x.reshape(-1, QK4_0)
     max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
     d = max / -8
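
As an aside, here is a minimal NumPy sketch of what a complete Q4_0 block quantizer plausibly looks like. The scale computation repeats the hunk above; the nibble packing (element j in the low nibble, element j + 16 in the high nibble) is assumed from ggml's Q4_0 block layout and is not part of this diff.

import numpy as np

QK4_0 = 32

def quantize_q4_0_sketch(x):
    # Illustrative sketch only, not the code from this PR.
    x = x.reshape(-1, QK4_0)
    # signed value with the largest magnitude in each block (as in the hunk above)
    max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
    d = max / -8
    # map each value to an unsigned 4-bit quant in [0, 15]
    with np.errstate(divide="ignore", invalid="ignore"):
        qs = np.where(d != 0, x / d, 0)
    qs = np.clip(qs + 8.5, 0, 15).astype(np.uint8)
    # assumed ggml Q4_0 packing: element j in the low nibble, element j + 16 in the high nibble
    packed = qs[:, :QK4_0 // 2] | (qs[:, QK4_0 // 2:] << 4)
    # each block is stored as a float16 scale followed by 16 packed bytes
    return d.astype(np.float16), packed
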
@@ -44,7 +44,7 @@ def quantize_q4_0(x):

 QK4_1 = 32
 def quantize_q4_1(x):
-    assert x.shape[-1] % QK4_1 == 0
+    assert x.shape[-1] % QK4_1 == 0 and x.shape[-1] > QK4_1
     x = x.reshape(-1, QK4_1)
     min = np.min(x, axis=-1, keepdims=True)
     max = np.max(x, axis=-1, keepdims=True)
@@ -59,7 +59,7 @@ def quantize_q4_1(x):

 QK5_0 = 32
 def quantize_q5_0(x):
-    assert x.shape[1] % QK5_0 == 0
+    assert x.shape[-1] % QK5_0 == 0 and x.shape[-1] > QK5_0
     x = x.reshape(-1, QK5_0)
     max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
     d = max / -16
@@ -76,7 +76,7 @@ def quantize_q5_0(x):

 QK5_1 = 32
 def quantize_q5_1(x):
-    assert x.shape[-1] % QK5_1 == 0
+    assert x.shape[-1] % QK5_1 == 0 and x.shape[-1] > QK5_1
     x = x.reshape(-1, QK5_1)
     min = np.min(x, axis=-1, keepdims=True)
     max = np.max(x, axis=-1, keepdims=True)
@@ -95,7 +95,7 @@ def quantize_q5_1(x):

 QK8_0 = 32
 def quantize_q8_0(x):
-    assert x.shape[-1] % QK8_0 == 0
+    assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
     x = x.reshape(-1, QK8_0)
     amax = np.max(np.abs(x), axis=-1, keepdims=True)
     d = amax / ((1 << 7) - 1)
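
In the same spirit, a hedged sketch of a complete Q8_0 block quantizer, assuming ggml's layout of one float16 scale plus 32 signed 8-bit quants per block; only the reshape and scale lines are taken from the hunk above.

import numpy as np

QK8_0 = 32

def quantize_q8_0_sketch(x):
    # Illustrative sketch only; rounding and storage details are assumed.
    x = x.reshape(-1, QK8_0)
    amax = np.max(np.abs(x), axis=-1, keepdims=True)
    d = amax / ((1 << 7) - 1)            # per-block scale: amax maps to 127
    with np.errstate(divide="ignore", invalid="ignore"):
        qs = np.where(d != 0, x / d, 0)
    qs = np.rint(qs).astype(np.int8)     # 32 signed 8-bit quants per block
    # assumed storage: one float16 scale followed by the 32 int8 values per block
    return d.astype(np.float16), qs
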
@@ -156,7 +156,10 @@ def get_alpha_comprod(linear_start=0.00085, linear_end=0.0120, timesteps=1000):
156156 "posterior_mean_coef2" ,
157157 "cond_stage_model.transformer.text_model.embeddings.position_ids" ,
158158 "model_ema.decay" ,
159- "model_ema.num_updates"
159+ "model_ema.num_updates" ,
160+ "control_model" ,
161+ "lora_te_text_model" ,
162+ "embedding_manager"
160163]
161164
162165def convert (model_path , out_type = None , out_file = None ):
@@ -182,6 +185,10 @@ def convert(model_path, out_type = None, out_file=None):
             out_type = "f32"
         elif weight.dtype == np.float16:
             out_type = "f16"
+        elif weight.dtype == np.float64:
+            out_type = "f32"
+        else:
+            raise Exception("unsupported weight type %s" % weight.dtype)
     if out_file == None:
         out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
         out_file = os.path.join(os.getcwd(), out_file)
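
The new float64 branch only chooses the output type string; presumably the tensor data itself still has to be narrowed to 32-bit floats before it is written, since GGML has no 64-bit float tensor type. A minimal sketch of that assumption (the helper name is invented):

import numpy as np

# Hypothetical helper, not from the diff: narrow float64 weights so they can be
# stored as f32 tensors in the output file.
def to_storable(data: np.ndarray) -> np.ndarray:
    if data.dtype == np.float64:
        return data.astype(np.float32)
    return data
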
@@ -207,6 +214,13 @@ def convert(model_path, out_type = None, out_file=None):
     for name in state_dict.keys():
         if not isinstance(state_dict[name], torch.Tensor):
             continue
+        skip = False
+        for unused_tensor in unused_tensors:
+            if name.startswith(unused_tensor):
+                skip = True
+                break
+        if skip:
+            continue
         if name in unused_tensors:
             continue
         data = state_dict[name].numpy()
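
The prefix loop added here runs before the existing exact-match check, so list entries such as "control_model", "lora_te_text_model" and "embedding_manager" now skip whole families of tensors. The same test could be written more compactly; a hypothetical equivalent:

def is_unused(name: str, unused_tensors: list[str]) -> bool:
    # True when the tensor name starts with any unused prefix,
    # e.g. "control_model", "lora_te_text_model" or "embedding_manager".
    return any(name.startswith(prefix) for prefix in unused_tensors)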