Skip to content

Commit c542a77

Browse files
committed
fix: correct the handling of weight loading
1 parent 1b5a868 commit c542a77

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

models/convert.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
QK4_0 = 32
3333
def quantize_q4_0(x):
34-
assert x.shape[-1] % QK4_0 == 0
34+
assert x.shape[-1] % QK4_0 == 0 and x.shape[-1] > QK4_0
3535
x = x.reshape(-1, QK4_0)
3636
max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
3737
d = max / -8
@@ -44,7 +44,7 @@ def quantize_q4_0(x):
4444

4545
QK4_1 = 32
4646
def quantize_q4_1(x):
47-
assert x.shape[-1] % QK4_1 == 0
47+
assert x.shape[-1] % QK4_1 == 0 and x.shape[-1] > QK4_1
4848
x = x.reshape(-1, QK4_1)
4949
min = np.min(x, axis=-1, keepdims=True)
5050
max = np.max(x, axis=-1, keepdims=True)
@@ -59,7 +59,7 @@ def quantize_q4_1(x):
5959

6060
QK5_0 = 32
6161
def quantize_q5_0(x):
62-
assert x.shape[1] % QK5_0 == 0
62+
assert x.shape[-1] % QK5_0 == 0 and x.shape[-1] > QK5_0
6363
x = x.reshape(-1, QK5_0)
6464
max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
6565
d = max / -16
@@ -76,7 +76,7 @@ def quantize_q5_0(x):
7676

7777
QK5_1 = 32
7878
def quantize_q5_1(x):
79-
assert x.shape[-1] % QK5_1 == 0
79+
assert x.shape[-1] % QK5_1 == 0 and x.shape[-1] > QK5_1
8080
x = x.reshape(-1, QK5_1)
8181
min = np.min(x, axis=-1, keepdims=True)
8282
max = np.max(x, axis=-1, keepdims=True)
@@ -95,7 +95,7 @@ def quantize_q5_1(x):
9595

9696
QK8_0 = 32
9797
def quantize_q8_0(x):
98-
assert x.shape[-1] % QK8_0 == 0
98+
assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
9999
x = x.reshape(-1, QK8_0)
100100
amax = np.max(np.abs(x), axis=-1, keepdims=True)
101101
d = amax / ((1 << 7) - 1)
@@ -156,7 +156,10 @@ def get_alpha_comprod(linear_start=0.00085, linear_end=0.0120, timesteps=1000):
156156
"posterior_mean_coef2",
157157
"cond_stage_model.transformer.text_model.embeddings.position_ids",
158158
"model_ema.decay",
159-
"model_ema.num_updates"
159+
"model_ema.num_updates",
160+
"control_model",
161+
"lora_te_text_model",
162+
"embedding_manager"
160163
]
161164

162165
def convert(model_path, out_type = None, out_file=None):
@@ -182,6 +185,10 @@ def convert(model_path, out_type = None, out_file=None):
182185
out_type = "f32"
183186
elif weight.dtype == np.float16:
184187
out_type = "f16"
188+
elif weight.dtype == np.float64:
189+
out_type = "f32"
190+
else:
191+
raise Exception("unsupported weight type %s" % weight.dtype)
185192
if out_file == None:
186193
out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
187194
out_file = os.path.join(os.getcwd(), out_file)
@@ -207,6 +214,13 @@ def convert(model_path, out_type = None, out_file=None):
207214
for name in state_dict.keys():
208215
if not isinstance(state_dict[name], torch.Tensor):
209216
continue
217+
skip = False
218+
for unused_tensor in unused_tensors:
219+
if name.startswith(unused_tensor):
220+
skip = True
221+
break
222+
if skip:
223+
continue
210224
if name in unused_tensors:
211225
continue
212226
data = state_dict[name].numpy()

stable-diffusion.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2864,6 +2864,8 @@ class StableDiffusionGGML {
28642864
nelements *= ne[i];
28652865
}
28662866

2867+
const size_t num_bytes = nelements / ggml_blck_size(ggml_type(ttype)) * ggml_type_size(ggml_type(ttype));
2868+
28672869
std::string name(length, 0);
28682870
file.read(&name[0], length);
28692871

@@ -2891,7 +2893,7 @@ class StableDiffusionGGML {
28912893
return false;
28922894
}
28932895
}
2894-
file.ignore(nelements * ggml_type_size((ggml_type)ttype));
2896+
file.ignore(num_bytes);
28952897
continue;
28962898
}
28972899

@@ -2919,8 +2921,6 @@ class StableDiffusionGGML {
29192921
return false;
29202922
}
29212923

2922-
const size_t num_bytes = nelements / ggml_blck_size(ggml_type(ttype)) * ggml_type_size(ggml_type(ttype));
2923-
29242924
file.read(reinterpret_cast<char*>(tensor->data), num_bytes);
29252925

29262926
total_size += ggml_nbytes(tensor);

0 commit comments

Comments (0)