11import re
22import yaml
3- import logging
43
54from dataclasses import dataclass
65from typing import Literal , Any
76from cerebras .cloud .sdk import BadRequestError as CerebrasBadRequestError
87from openai import BadRequestError as OpenAIBadRequestError
98
109
11- logger = logging .getLogger (__name__ )
12-
13-
1410@dataclass
1511class CepoConfig :
1612 bestofn_n : int # number of responses to be generated in best of n stage
@@ -27,6 +23,7 @@ class CepoConfig:
2723 planning_max_tokens_step2 : int # maximum number of tokens in step 2 of planning stage
2824 planning_max_tokens_step3 : int # maximum number of tokens in step 3 of planning stage
2925 planning_max_tokens_step4 : int # maximum number of tokens in step 4 of planning stage
26+ print_output : bool = False # whether to print the output of each stage
3027
3128
3229# given command line arguments which includes a yaml file path, initialize a CePO configuration
@@ -113,13 +110,15 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
113110 if response .choices [0 ].finish_reason == "length" :
114111 messages .append ({"role" : "assistant" , "content" : response .choices [0 ].message .content })
115112 cb_log [f"messages_planning_{ i } _rejected_due_to_length" ] = messages
116- logger .debug (f"Plan proposal rejected due to length. Attempt { i + 1 } out of { cepo_config .planning_m } .\n Messages: { messages } " )
113+ if cepo_config .print_output :
114+ print (f"\n CePO: Plan proposal rejected due to length. Attempt { i + 1 } out of { cepo_config .planning_m } .\n Messages: { messages } " )
117115 continue
118116
119117 plans .append (response .choices [0 ].message .content )
120118 messages .append ({"role" : "assistant" , "content" : response .choices [0 ].message .content })
121119 cb_log [f"messages_planning_{ i } " ] = messages
122- logger .debug (f"Plan proposal generated. Attempt { i + 1 } out of { cepo_config .planning_m } .\n Messages: { messages } " )
120+ if cepo_config .print_output :
121+ print (f"\n CePO: Plan proposal generated. Attempt { i + 1 } out of { cepo_config .planning_m } .\n Messages: { messages } " )
123122
124123 if len (plans ) == cepo_config .planning_n :
125124 break
@@ -129,7 +128,8 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
129128 plans .append (response .choices [0 ].message .content )
130129 messages .append ({"role" : "assistant" , "content" : response .choices [0 ].message .content })
131130 cb_log [f"messages_planning_{ i } _no_plans_so_taking_the_last_one" ] = messages
132- logger .debug (f"No plans generated successfully. Taking the last one from rejected due to length.\n Messages: { messages } " )
131+ if cepo_config .print_output :
132+ print (f"\n CePO: No plans generated successfully. Taking the last one from rejected due to length.\n Messages: { messages } " )
133133
134134 # Step 3 - Review and address inconsistencies
135135 try :
@@ -169,7 +169,8 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
169169 completion_tokens += response .usage .completion_tokens
170170
171171 cb_log ["messages" ] = messages
172- logger .debug (f"Answer generated.\n Messages: { messages } " )
172+ if cepo_config .print_output :
173+ print (f"\n CePO: Answer generated.\n Messages: { messages } " )
173174 return response .choices [0 ].message .content , completion_tokens , cb_log
174175
175176
@@ -192,7 +193,8 @@ def generate_n_completions(system_prompt: str, initial_query: str, client: Any,
192193 completions = []
193194
194195 for i in range (cepo_config .bestofn_n ):
195- logger .debug (f"Generating completion { i + 1 } out of { cepo_config .bestofn_n } " )
196+ if cepo_config .print_output :
197+ print (f"\n CePO: Generating completion { i + 1 } out of { cepo_config .bestofn_n } " )
196198 response_i , completion_tokens_i , cb_log_i = generate_completion (system_prompt , initial_query , client , model , cepo_config )
197199 completions .append (response_i )
198200 completion_tokens += completion_tokens_i
@@ -264,7 +266,8 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An
264266
265267 rating_response = rating_response .choices [0 ].message .content .strip ()
266268 cb_log [f"rating_response_{ i } " ] = rating_response
267- logger .debug (f"Rating response for completion { i } : { rating_response } " )
269+ if cepo_config .print_output :
270+ print (f"\n CePO: Rating response for completion { i } : { rating_response } " )
268271
269272 pattern = r"Rating: \[\[(\d+)\]\]"
270273 match = re .search (pattern , rating_response )
@@ -280,7 +283,8 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An
280283 best_index = ratings .index (max (ratings ))
281284 cb_log ["ratings" ] = ratings
282285 cb_log ["best_index" ] = best_index
283- logger .debug (f"Finished rating completions. Ratings: { ratings } , best completion index: { best_index } " )
286+ if cepo_config .print_output :
287+ print (f"\n CePO: Finished rating completions. Ratings: { ratings } , best completion index: { best_index } " )
284288 return completions [best_index ], completion_tokens , cb_log
285289
286290
@@ -340,7 +344,8 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An
340344
341345 rating_response = rating_response .choices [0 ].message .content .strip ()
342346 cb_log [f"rating_response_for_pair_{ pair [0 ]} _{ pair [1 ]} " ] = rating_response
343- logger .debug (f"Rating response for pair { pair } : { rating_response } " )
347+ if cepo_config .print_output :
348+ print (f"\n CePO: Rating response for pair { pair } : { rating_response } " )
344349
345350 pattern = r"Better Response: \[\[(\d+)\]\]"
346351 match = re .search (pattern , rating_response )
@@ -359,7 +364,8 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An
359364 best_index = ratings .index (max (ratings ))
360365 cb_log ["ratings" ] = ratings
361366 cb_log ["best_index" ] = best_index
362- logger .debug (f"Finished rating completions. Ratings: { ratings } , best completion index: { best_index } " )
367+ if cepo_config .print_output :
368+ print (f"\n CePO: Finished rating completions. Ratings: { ratings } , best completion index: { best_index } " )
363369 return completions [best_index ], completion_tokens , cb_log
364370
365371
0 commit comments