@@ -235,17 +235,14 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client, mo
235235
236236 pattern = r"Rating: \[\[(\d+)\]\]"
237237 match = re .search (pattern , rating_response )
238- if match :
239- rating_response = match .group (1 )
240- else :
241- rating_response = "0"
238+ rating_response = match .group (1 ) if match else "0"
242239
243240 try :
244241 ratings .append (float (rating_response ))
245242 except ValueError :
246243 ratings .append (0 )
247244
248- rating_messages = rating_messages [:- 2 ] # remove the last two messages
245+ rating_messages = rating_messages [:- 2 ] # clear the last two messages to start over in the next iteration
249246
250247 best_index = ratings .index (max (ratings ))
251248 cb_log ["ratings" ] = ratings
@@ -325,10 +322,12 @@ def cepo(system_prompt: str, initial_query: str, client, model: str, cepo_config
325322
326323 # Rate the completions
327324 if cepo_config .bestofn_rating_type == "absolute" :
328- best_completion , completion_tokens_rating , cb_log = rate_completions_absolute ( system_prompt , initial_query , client , model , completions , cepo_config , cb_log )
325+ rate_completions_fn = rate_completions_absolute
329326 elif cepo_config .bestofn_rating_type == "pairwise" :
330- best_completion , completion_tokens_rating , cb_log = rate_completions_pairwise ( system_prompt , initial_query , client , model , completions , cepo_config , cb_log )
327+ rate_completions_fn = rate_completions_pairwise
331328 else :
332329 raise ValueError ("Invalid rating type in cepo_config" )
330+
331+ best_completion , completion_tokens_rating , cb_log = rate_completions_fn (system_prompt , initial_query , client , model , completions , cepo_config , cb_log )
333332
334333 return best_completion , completion_tokens_planning + completion_tokens_rating
0 commit comments