@@ -1144,20 +1144,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
11441144 self .assertLess (
11451145 max_diff , expected_max_diff , "running CPU offloading 2nd time should not affect the inference results"
11461146 )
1147- offloaded_modules = [
1148- v
1147+ offloaded_modules = {
1148+ k : v
11491149 for k , v in pipe .components .items ()
11501150 if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1151- ]
1152- (
1153- self . assertTrue ( all (v .device .type == "cpu" for v in offloaded_modules )),
1154- f"Not offloaded: { [v for v in offloaded_modules if v .device .type != 'cpu' ]} " ,
1151+ }
1152+ self . assertTrue (
1153+ all (v .device .type == "cpu" for v in offloaded_modules . values ( )),
1154+ f"Not offloaded: { [k for k , v in offloaded_modules . items () if v .device .type != 'cpu' ]} " ,
11551155 )
11561156
1157- offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr (v , "_hf_hook" )]
1158- (
1159- self .assertTrue (all (isinstance (v , accelerate .hooks .CpuOffload ) for v in offloaded_modules_with_hooks )),
1160- f"Not installed correct hook: { [v for v in offloaded_modules_with_hooks if not isinstance (v , accelerate .hooks .CpuOffload )]} " ,
1157+ offloaded_modules_with_incorrect_hooks = {}
1158+ for k , v in offloaded_modules .items ():
1159+ if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .CpuOffload ):
1160+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1161+
1162+ self .assertTrue (
1163+ len (offloaded_modules_with_incorrect_hooks ) == 0 ,
1164+ f"Not installed correct hook: { offloaded_modules_with_incorrect_hooks } " ,
11611165 )
11621166
11631167 @unittest .skipIf (
@@ -1189,22 +1193,23 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
11891193 self .assertLess (
11901194 max_diff , expected_max_diff , "running sequential offloading second time should have the inference results"
11911195 )
1192- offloaded_modules = [
1193- v
1196+ offloaded_modules = {
1197+ k : v
11941198 for k , v in pipe .components .items ()
11951199 if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1196- ]
1197- (
1198- self . assertTrue ( all (v .device .type == "meta" for v in offloaded_modules )),
1199- f"Not offloaded: { [v for v in offloaded_modules if v .device .type != 'meta' ]} " ,
1200+ }
1201+ self . assertTrue (
1202+ all (v .device .type == "meta" for v in offloaded_modules . values ( )),
1203+ f"Not offloaded: { [k for k , v in offloaded_modules . items () if v .device .type != 'meta' ]} " ,
12001204 )
1205+ offloaded_modules_with_incorrect_hooks = {}
1206+ for k , v in offloaded_modules .items ():
1207+ if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .AlignDevicesHook ):
1208+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
12011209
1202- offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr (v , "_hf_hook" )]
1203- (
1204- self .assertTrue (
1205- all (isinstance (v , accelerate .hooks .AlignDevicesHook ) for v in offloaded_modules_with_hooks )
1206- ),
1207- f"Not installed correct hook: { [v for v in offloaded_modules_with_hooks if not isinstance (v , accelerate .hooks .AlignDevicesHook )]} " ,
1210+ self .assertTrue (
1211+ len (offloaded_modules_with_incorrect_hooks ) == 0 ,
1212+ f"Not installed correct hook: { offloaded_modules_with_incorrect_hooks } " ,
12081213 )
12091214
12101215 @unittest .skipIf (
0 commit comments