@@ -205,3 +205,89 @@ def test_simple_inference_with_transformer_lora_and_scale(self):
             np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
             "Lora + 0 scale should lead to same result as no LoRA",
         )
+
+    def test_simple_inference_with_transformer_fused(self):
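+        # Fusing a freshly added LoRA adapter should change the output relative to the no-LoRA baseline.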
+        components = self.get_dummy_components()
+        transformer_lora_config = self.get_lora_config_for_transformer()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_no_lora = pipe(**inputs).images
+
+        pipe.transformer.add_adapter(transformer_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        pipe.fuse_lora()
+        # Fusing should still keep the LoRA layers
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_fused = pipe(**inputs).images
+        self.assertFalse(
+            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+        )
+
+    def test_simple_inference_with_transformer_fused_with_no_fusion(self):
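+        # The fused output should match the output of the same LoRA adapter applied without fusion.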
+        components = self.get_dummy_components()
+        transformer_lora_config = self.get_lora_config_for_transformer()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_no_lora = pipe(**inputs).images
+
+        pipe.transformer.add_adapter(transformer_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        inputs = self.get_dummy_inputs(torch_device)
+        output_lora = pipe(**inputs).images
+
+        pipe.fuse_lora()
+        # Fusing should still keep the LoRA layers
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_fused = pipe(**inputs).images
+        self.assertFalse(
+            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+        )
+        self.assertTrue(
+            np.allclose(output_fused, output_lora, atol=1e-3, rtol=1e-3),
+            "Fused lora output should match the output with LoRA applied but not fused.",
+        )
+
+    def test_simple_inference_with_transformer_fuse_unfuse(self):
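+        # A fuse/unfuse round trip should be lossless: the unfused output should match the fused output.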
+        components = self.get_dummy_components()
+        transformer_lora_config = self.get_lora_config_for_transformer()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_no_lora = pipe(**inputs).images
+
+        pipe.transformer.add_adapter(transformer_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        pipe.fuse_lora()
+        # Fusing should still keep the LoRA layers
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        inputs = self.get_dummy_inputs(torch_device)
+        output_fused = pipe(**inputs).images
+        self.assertFalse(
+            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+        )
+
+        pipe.unfuse_lora()
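+        # Unfusing restores the original weights while keeping the LoRA layers loaded and active.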
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        inputs = self.get_dummy_inputs(torch_device)
+        output_unfused_lora = pipe(**inputs).images
+        self.assertTrue(
+            np.allclose(output_fused, output_unfused_lora, atol=1e-3, rtol=1e-3), "Output after unfusing should match the fused output"
+        )