33import openai
44
55from dataclasses import dataclass
6- from typing import Optional , Literal
6+ from typing import Optional , Literal , Any
7+ from cerebras .cloud .sdk import BadRequestError as CerebrasBadRequestError
8+ from openai import BadRequestError as OpenAIBadRequestError
79
810import yaml
911
12+
1013@dataclass
1114class CepoConfig :
1215 bestofn_n : int # number of responses to be generated in best of n stage
@@ -24,6 +27,7 @@ class CepoConfig:
2427 planning_max_tokens_step3 : int # maximum number of tokens in step 3 of planning stage
2528 planning_max_tokens_step4 : int # maximum number of tokens in step 4 of planning stage
2629
30+
2731# given command line arguments which includes a yaml file path, initialize a CePO configuration
2832def init_cepo_config (cmd_line_args : dict ) -> CepoConfig :
2933 # get the command line arguments
@@ -35,14 +39,15 @@ def init_cepo_config(cmd_line_args: dict) -> CepoConfig:
3539
3640 # get the yaml file arguments
3741 cepo_config_yaml = {}
38- if cmd_line_args [ "cepo_config_file" ] :
42+ if cmd_line_args . get ( "cepo_config_file" , None ) :
3943 with open (cmd_line_args ["cepo_config_file" ], "r" ) as yaml_file :
4044 cepo_config_yaml = yaml .safe_load (yaml_file )
4145
4246 # merge cepo args from command line and yaml file, args from command line will overwrite the ones from yaml file
4347 cepo_args = {** cepo_config_yaml , ** cepo_args }
4448 return CepoConfig (** cepo_args )
4549
50+
4651def extract_question_only (task : str ) -> str :
4752 """We noticed that sometimes if the task includes specific formatting instructions, they may interfere with the reasoning flow. This
4853 is a temporary workaround to extract the question only from the task. Work in progress.
@@ -52,7 +57,20 @@ def extract_question_only(task: str) -> str:
5257 return question_only
5358
5459
55- def generate_completion (system_prompt : str , task : str , client , model : str , cepo_config : CepoConfig ) -> str :
60+ def generate_completion (system_prompt : str , task : str , client : Any , model : str , cepo_config : CepoConfig ) -> str :
61+ """
62+ Generates a completion based on the provided system prompt and task.
63+
64+ Parameters:
65+ system_prompt (str): The system prompt to guide the model.
66+ task (str): The task or question to be addressed.
67+ client (Any): The client instance for interacting with the AI model.
68+ model (str): The model name to be used for generating completions.
69+ cepo_config (CepoConfig): Configuration parameters for CePO flow.
70+
71+ Returns:
72+ Tuple[str, int, dict]: The generated completion, number of tokens used, and a log dictionary.
73+ """
5674 completion_tokens = 0
5775 question_only = extract_question_only (task )
5876 cb_log = {}
@@ -128,7 +146,7 @@ def generate_completion(system_prompt: str, task: str, client, model: str, cepo_
128146 )
129147 final_solution = response .choices [0 ].message .content
130148 completion_tokens += response .usage .completion_tokens
131- except (cerebras . cloud . sdk . BadRequestError , openai . BadRequestError ) as e :
149+ except (CerebrasBadRequestError , OpenAIBadRequestError ) as e :
132150 # In case of an error, take the first plan as the final solution
133151 final_solution = plans [0 ]
134152 messages = []
@@ -150,7 +168,20 @@ def generate_completion(system_prompt: str, task: str, client, model: str, cepo_
150168 return response .choices [0 ].message .content , completion_tokens , cb_log
151169
152170
153- def generate_n_completions (system_prompt : str , initial_query : str , client , model : str , cepo_config : CepoConfig ) -> tuple [list [str ], int , dict ]:
171+ def generate_n_completions (system_prompt : str , initial_query : str , client : Any , model : str , cepo_config : CepoConfig ) -> tuple [list [str ], int , dict ]:
172+ """
173+ Generates n completions for the Best of N step of CePO.
174+
175+ Parameters:
176+ system_prompt (str): The system prompt to guide the model.
177+ initial_query (str): The task or question to be addressed.
178+ client (Any): The client instance for interacting with the AI model.
179+ model (str): The model name to be used for generating completions.
180+ cepo_config (CepoConfig): Configuration parameters for CePO flow.
181+
182+ Returns:
183+ Tuple[str, int, dict]: The generated completion, number of tokens used, and a log dictionary.
184+ """
154185 completion_tokens = 0
155186 cb_log = {}
156187 completions = []
@@ -166,7 +197,21 @@ def generate_n_completions(system_prompt: str, initial_query: str, client, model
166197 return completions , completion_tokens , cb_log
167198
168199
169- def rate_completions_absolute (system_prompt : str , initial_query : str , client , model : str , completions : list [str ], cepo_config : CepoConfig , cb_log : dict ) -> tuple [str , int , dict ]:
200+ def rate_completions_absolute (system_prompt : str , initial_query : str , client : Any , model : str , completions : list [str ], cepo_config : CepoConfig , cb_log : dict ) -> tuple [str , int , dict ]:
201+ """
202+ Rates completions for the Best of N step of CePO. Each completion is rated on a scale of 1 to 10 individually.
203+
204+ Parameters:
205+ system_prompt (str): The system prompt to guide the model.
206+ initial_query (str): The task or question to be addressed.
207+ client (Any): The client instance for interacting with the AI model.
208+ model (str): The model name to be used for generating completions.
209+ completions (list[str]): List of completions to be rated.
210+ cepo_config (CepoConfig): Configuration parameters for CePO flow.
211+
212+ Returns:
213+ Tuple[str, int, dict]: The generated completion, number of tokens used, and a log dictionary.
214+ """
170215 completion_tokens = 0
171216 rating_messages = [{"role" : "system" , "content" : system_prompt },
172217 {"role" : "user" , "content" : initial_query }]
@@ -231,7 +276,21 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client, mo
231276 return completions [best_index ], completion_tokens , cb_log
232277
233278
234- def rate_completions_pairwise (system_prompt : str , initial_query : str , client , model : str , completions : list [str ], cepo_config : CepoConfig , cb_log : dict ) -> tuple [str , int , dict ]:
279+ def rate_completions_pairwise (system_prompt : str , initial_query : str , client : Any , model : str , completions : list [str ], cepo_config : CepoConfig , cb_log : dict ) -> tuple [str , int , dict ]:
280+ """
281+ Rates completions for the Best of N step of CePO. Completions are rated pairwise against each other in both orders (A vs B and B vs A).
282+
283+ Parameters:
284+ system_prompt (str): The system prompt to guide the model.
285+ initial_query (str): The task or question to be addressed.
286+ client (Any): The client instance for interacting with the AI model.
287+ model (str): The model name to be used for generating completions.
288+ completions (list[str]): List of completions to be rated.
289+ cepo_config (CepoConfig): Configuration parameters for CePO flow.
290+
291+ Returns:
292+ Tuple[str, int, dict]: The generated completion, number of tokens used, and a log dictionary.
293+ """
235294 completion_tokens = 0
236295 rating_messages = [{"role" : "system" , "content" : system_prompt },
237296 {"role" : "user" , "content" : initial_query }]
@@ -294,9 +353,30 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client, mo
294353 return completions [best_index ], completion_tokens , cb_log
295354
296355
297- def cepo (system_prompt : str , initial_query : str , client , model : str , cepo_config : Optional [CepoConfig ] = None ) -> tuple [str , int , dict ]:
298- if cepo_config is None :
299- cepo_config = CepoConfig ()
356+ def cepo (system_prompt : str , initial_query : str , client : Any , model : str , cepo_config : Optional [CepoConfig ]) -> tuple [str , int ]:
357+ """
358+ Applies CePO reasoning flow for the given task. First, it generates multiple completions, and then rates them to select the best one.
359+ Each completion is generated as follows:
360+
361+ Generate `planning_n` solution proposals:
362+ Step 1: Plan Generation - The model generates a detailed, step-by-step plan to solve the problem, along with its confidence level for
363+ each step.
364+ Step 2: Initial Solution - Using the plan from Step 1, the model produces an initial solution.
365+
366+ Step 3: Plan Refinement - The model reviews all generated solution proposals and their associated plans, identifying inconsistencies.
367+ Based on this analysis, a refined, final step-by-step plan is constructed.
368+ Step 4: Final Solution - The model uses the refined plan from Step 3 to produce the final answer.
369+
370+ Parameters:
371+ system_prompt (str): The system prompt to guide the model.
372+ initial_query (str): The task or question to be addressed.
373+ client (Any): The client instance for interacting with the AI model.
374+ model (str): The model name to be used for generating completions.
375+ cepo_config (CepoConfig): Configuration parameters for CePO flow.
376+
377+ Returns:
378+ Tuple[str, int, dict]: The generated completion, number of tokens used
379+ """
300380
301381 # Generate completions
302382 completions , completion_tokens_planning , cb_log = generate_n_completions (system_prompt , initial_query , client , model , cepo_config ) # cb_log is a dictionary for debugging purposes
0 commit comments