1-
2-
3- from io import BytesIO
41import json
52import os
3+ import pickle as pkl
64import random
7-
85from glob import glob
6+ from io import BytesIO
97from pathlib import Path
10- import pickle as pkl
118from typing import Callable
129
10+ import h5py as h5
1311import numpy as np
14-
1512import tensorflow as tf
16- from tensorflow import keras
17-
18- from tqdm .auto import tqdm
19- from matplotlib import pyplot as plt
20-
2113import zstd
22- import h5py as h5
23-
24- from keras .src .saving .legacy import hdf5_format
25- from keras .src .layers .convolutional .base_conv import Conv
26- from keras .layers import Dense
27-
2814from HGQ .bops import trace_minmax
15+ from keras .layers import Dense
16+ from keras .src .layers .convolutional .base_conv import Conv
17+ from keras .src .saving .legacy import hdf5_format
18+ from matplotlib import pyplot as plt
19+ from tensorflow import keras
20+ from tqdm .auto import tqdm
2921
3022
3123class NumpyFloatValuesEncoder (json .JSONEncoder ):
@@ -36,14 +28,15 @@ def default(self, obj):
3628
3729
3830class SaveTopN (keras .callbacks .Callback ):
39- def __init__ (self ,
40- metric_fn : Callable [[dict ], float ],
41- n : int ,
42- path : str | Path ,
43- side : str = 'max' ,
44- fname_format = 'epoch={epoch}-metric={metric:.4e}.h5' ,
45- cond_fn : Callable [[dict ], bool ] = lambda x : True ,
46- ):
31+ def __init__ (
32+ self ,
33+ metric_fn : Callable [[dict ], float ],
34+ n : int ,
35+ path : str | Path ,
36+ side : str = 'max' ,
37+ fname_format = 'epoch={epoch}-metric={metric:.4e}.h5' ,
38+ cond_fn : Callable [[dict ], bool ] = lambda x : True ,
39+ ):
4740 self .n = n
4841 self .metric_fn = metric_fn
4942 self .path = Path (path )
@@ -188,9 +181,11 @@ def absorb_batchNorm(model_target, model_original):
188181 if layer .__class__ .__name__ == 'Functional' :
189182 absorb_batchNorm (layer , model_original .get_layer (layer .name ))
190183 continue
191- if (isinstance (layer , Dense ) or isinstance (layer , Conv )) and \
192- len (nodes := model_original .get_layer (layer .name )._outbound_nodes ) > 0 and \
193- isinstance (nodes [0 ].outbound_layer , keras .layers .BatchNormalization ):
184+ if (
185+ (isinstance (layer , Dense ) or isinstance (layer , Conv ))
186+ and len (nodes := model_original .get_layer (layer .name )._outbound_nodes ) > 0
187+ and isinstance (nodes [0 ].outbound_layer , keras .layers .BatchNormalization )
188+ ):
194189 _gamma , _beta , _mu , _var = model_original .get_layer (layer .name )._outbound_nodes [0 ].outbound_layer .get_weights ()
195190 _ratio = _gamma / np .sqrt (0.001 + _var )
196191 _bias = - _gamma * _mu / np .sqrt (0.001 + _var ) + _beta
@@ -213,7 +208,7 @@ def absorb_batchNorm(model_target, model_original):
213208 weights = layer .get_weights ()
214209 new_weights = model_original .get_layer (layer .name ).get_weights ()
215210 l = len (new_weights )
216- layer .set_weights ([* new_weights , * weights [l :]][:len (weights )])
211+ layer .set_weights ([* new_weights , * weights [l :]][: len (weights )])
217212
218213
219214def set_seed (seed ):
@@ -225,9 +220,10 @@ def set_seed(seed):
225220 tf .config .experimental .enable_op_determinism ()
226221
227222
228- import h5py as h5
229223import json
230224
225+ import h5py as h5
226+
231227
232228def get_best_ckpt (save_path : Path , take_min = False ):
233229 ckpts = list (save_path .glob ('*.h5' ))
@@ -245,13 +241,14 @@ def rank(ckpt: Path):
245241
246242
247243class PeratoFront (keras .callbacks .Callback ):
248- def __init__ (self ,
249- path : str | Path ,
250- fname_format : str ,
251- metrics_names : list [str ],
252- sides : list [int ],
253- cond_fn : Callable [[dict ], bool ] = lambda x : True ,
254- ):
244+ def __init__ (
245+ self ,
246+ path : str | Path ,
247+ fname_format : str ,
248+ metrics_names : list [str ],
249+ sides : list [int ],
250+ cond_fn : Callable [[dict ], bool ] = lambda x : True ,
251+ ):
255252 self .path = Path (path )
256253 self .fname_format = fname_format
257254 os .makedirs (path , exist_ok = True )
0 commit comments