@@ -34,6 +34,7 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
3434 main_input_name = "clip_input"
3535
3636 _no_split_modules = ["CLIPEncoderLayer" ]
37+
3738
3839 def __init__ (self , config : CLIPConfig ):
3940 super ().__init__ (config )
@@ -46,8 +47,14 @@ def __init__(self, config: CLIPConfig):
4647
4748 self .concept_embeds_weights = nn .Parameter (torch .ones (17 ), requires_grad = False )
4849 self .special_care_embeds_weights = nn .Parameter (torch .ones (3 ), requires_grad = False )
50+
51+ self .adjustment = 0.0
4952
5053 def update_safety_checker_Level (self , Level ):
54+ """
55+ Args:
56+ Level (`int` or `float` or one of the following [`WEAK`], [`MEDIUM`], [`NOMAL`], [`STRONG`], [`MAX`])
57+ """
5158 Level_dict = {
5259 "WEAK" : - 1.0 ,
5360 "MEDIUM" : - 0.5 ,
@@ -56,11 +63,21 @@ def update_safety_checker_Level(self, Level):
5663 "MAX" : 1.0 ,
5764 }
5865 if Level in Level_dict :
59- Level = Level_dict [Level ]
66+ Level = Level_dict [Level ]
6067 if isinstance (Level , (float , int )):
6168 setattr (self ,"adjustment" ,Level )
6269 else :
6370 raise ValueError ("`int` or `float` or one of the following ['WEAK'], ['MEDIUM'], ['NOMAL'], ['STRONG'], ['MAX']" )
71+
72+ if self .adjustment < 0 :
73+ logger .warning (
74+ f"You have weakened the safety checker for { self .__class__ } by setting a negative adjustment. Ensure"
75+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
76+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
77+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
78+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
79+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
80+ )
6481
6582 @torch .no_grad ()
6683 def forward (self , clip_input , images ):
@@ -78,20 +95,20 @@ def forward(self, clip_input, images):
7895
7996 # increase this value to create a stronger `nsfw` filter
8097 # at the cost of increasing the possibility of filtering benign images
81- adjustment = 0.0
98+ # adjustment = 0.0
8299
83100 for concept_idx in range (len (special_cos_dist [0 ])):
84101 concept_cos = special_cos_dist [i ][concept_idx ]
85102 concept_threshold = self .special_care_embeds_weights [concept_idx ].item ()
86- result_img ["special_scores" ][concept_idx ] = round (concept_cos - concept_threshold + adjustment , 3 )
103+ result_img ["special_scores" ][concept_idx ] = round (concept_cos - concept_threshold + self . adjustment , 3 )
87104 if result_img ["special_scores" ][concept_idx ] > 0 :
88105 result_img ["special_care" ].append ({concept_idx , result_img ["special_scores" ][concept_idx ]})
89- adjustment = 0.01
106+ self . adjustment = 0.01
90107
91108 for concept_idx in range (len (cos_dist [0 ])):
92109 concept_cos = cos_dist [i ][concept_idx ]
93110 concept_threshold = self .concept_embeds_weights [concept_idx ].item ()
94- result_img ["concept_scores" ][concept_idx ] = round (concept_cos - concept_threshold + adjustment , 3 )
111+ result_img ["concept_scores" ][concept_idx ] = round (concept_cos - concept_threshold + self . adjustment , 3 )
95112 if result_img ["concept_scores" ][concept_idx ] > 0 :
96113 result_img ["bad_concepts" ].append (concept_idx )
97114
@@ -124,9 +141,9 @@ def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
124141
125142 # increase this value to create a stronger `nsfw` filter
126143 # at the cost of increasing the possibility of filtering benign images
127- adjustment = 0.0
144+ # adjustment = 0.0
128145
129- special_scores = special_cos_dist - self .special_care_embeds_weights + adjustment
146+ special_scores = special_cos_dist - self .special_care_embeds_weights + self . adjustment
130147 # special_scores = special_scores.round(decimals=3)
131148 special_care = torch .any (special_scores > 0 , dim = 1 )
132149 special_adjustment = special_care * 0.01
0 commit comments