Skip to content

Commit 2337abb

Browse files
Minor fixes and add docstrings
1 parent 3807724 commit 2337abb

File tree

1 file changed

+90
-10
lines changed

1 file changed

+90
-10
lines changed

optillm/cepo.py

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
import openai
44

55
from dataclasses import dataclass
6-
from typing import Optional, Literal
6+
from typing import Optional, Literal, Any
7+
from cerebras.cloud.sdk import BadRequestError as CerebrasBadRequestError
8+
from openai import BadRequestError as OpenAIBadRequestError
79

810
import yaml
911

12+
1013
@dataclass
1114
class CepoConfig:
1215
bestofn_n: int # number of responses to be generated in best of n stage
@@ -24,6 +27,7 @@ class CepoConfig:
2427
planning_max_tokens_step3: int # maximum number of tokens in step 3 of planning stage
2528
planning_max_tokens_step4: int # maximum number of tokens in step 4 of planning stage
2629

30+
2731
# given command line arguments which includes a yaml file path, initialize a CePO configuration
2832
def init_cepo_config(cmd_line_args: dict) -> CepoConfig:
2933
# get the command line arguments
@@ -35,14 +39,15 @@ def init_cepo_config(cmd_line_args: dict) -> CepoConfig:
3539

3640
# get the yaml file arguments
3741
cepo_config_yaml = {}
38-
if cmd_line_args["cepo_config_file"]:
42+
if cmd_line_args.get("cepo_config_file", None):
3943
with open(cmd_line_args["cepo_config_file"], "r") as yaml_file:
4044
cepo_config_yaml = yaml.safe_load(yaml_file)
4145

4246
# merge cepo args from command line and yaml file, args from command line will overwrite the ones from yaml file
4347
cepo_args = {**cepo_config_yaml, **cepo_args}
4448
return CepoConfig(**cepo_args)
4549

50+
4651
def extract_question_only(task: str) -> str:
4752
"""We noticed that sometimes if the task includes specific formatting instructions, they may interfere with the reasoning flow. This
4853
is a temporary workaround to extract the question only from the task. Work in progress.
@@ -52,7 +57,20 @@ def extract_question_only(task: str) -> str:
5257
return question_only
5358

5459

55-
def generate_completion(system_prompt: str, task: str, client, model: str, cepo_config: CepoConfig) -> str:
60+
def generate_completion(system_prompt: str, task: str, client: Any, model: str, cepo_config: CepoConfig) -> str:
61+
"""
62+
Generates a completion based on the provided system prompt and task.
63+
64+
Parameters:
65+
system_prompt (str): The system prompt to guide the model.
66+
task (str): The task or question to be addressed.
67+
client (Any): The client instance for interacting with the AI model.
68+
model (str): The model name to be used for generating completions.
69+
cepo_config (CepoConfig): Configuration parameters for CePO flow.
70+
71+
Returns:
72+
Tuple[str, int, dict]: The generated completion, number of tokens used, and a log dictionary.
73+
"""
5674
completion_tokens = 0
5775
question_only = extract_question_only(task)
5876
cb_log = {}
@@ -128,7 +146,7 @@ def generate_completion(system_prompt: str, task: str, client, model: str, cepo_
128146
)
129147
final_solution = response.choices[0].message.content
130148
completion_tokens += response.usage.completion_tokens
131-
except (cerebras.cloud.sdk.BadRequestError, openai.BadRequestError) as e:
149+
except (CerebrasBadRequestError, OpenAIBadRequestError) as e:
132150
# In case of an error, take the first plan as the final solution
133151
final_solution = plans[0]
134152
messages = []
@@ -150,7 +168,20 @@ def generate_completion(system_prompt: str, task: str, client, model: str, cepo_
150168
return response.choices[0].message.content, completion_tokens, cb_log
151169

152170

153-
def generate_n_completions(system_prompt: str, initial_query: str, client, model: str, cepo_config: CepoConfig) -> tuple[list[str], int, dict]:
171+
def generate_n_completions(system_prompt: str, initial_query: str, client: Any, model: str, cepo_config: CepoConfig) -> tuple[list[str], int, dict]:
172+
"""
173+
Generates n completions for the Best of N step of CePO.
174+
175+
Parameters:
176+
system_prompt (str): The system prompt to guide the model.
177+
initial_query (str): The task or question to be addressed.
178+
client (Any): The client instance for interacting with the AI model.
179+
model (str): The model name to be used for generating completions.
180+
cepo_config (CepoConfig): Configuration parameters for CePO flow.
181+
182+
Returns:
183+
Tuple[list[str], int, dict]: The list of generated completions, number of tokens used, and a log dictionary.
184+
"""
154185
completion_tokens = 0
155186
cb_log = {}
156187
completions = []
@@ -166,7 +197,21 @@ def generate_n_completions(system_prompt: str, initial_query: str, client, model
166197
return completions, completion_tokens, cb_log
167198

168199

169-
def rate_completions_absolute(system_prompt: str, initial_query: str, client, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict) -> tuple[str, int, dict]:
200+
def rate_completions_absolute(system_prompt: str, initial_query: str, client: Any, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict) -> tuple[str, int, dict]:
201+
"""
202+
Rates completions for the Best of N step of CePO. Each completion is rated on a scale of 1 to 10 individually.
203+
204+
Parameters:
205+
system_prompt (str): The system prompt to guide the model.
206+
initial_query (str): The task or question to be addressed.
207+
client (Any): The client instance for interacting with the AI model.
208+
model (str): The model name to be used for generating completions.
209+
completions (list[str]): List of completions to be rated.
210+
cepo_config (CepoConfig): Configuration parameters for CePO flow.
cb_log (dict): Dictionary for logging debug information.
211+
212+
Returns:
213+
Tuple[str, int, dict]: The best-rated completion, number of tokens used, and a log dictionary.
214+
"""
170215
completion_tokens = 0
171216
rating_messages = [{"role": "system", "content": system_prompt},
172217
{"role": "user", "content": initial_query}]
@@ -231,7 +276,21 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client, mo
231276
return completions[best_index], completion_tokens, cb_log
232277

233278

234-
def rate_completions_pairwise(system_prompt: str, initial_query: str, client, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict) -> tuple[str, int, dict]:
279+
def rate_completions_pairwise(system_prompt: str, initial_query: str, client: Any, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict) -> tuple[str, int, dict]:
280+
"""
281+
Rates completions for the Best of N step of CePO. Completions are rated pairwise against each other in both orders (A vs B and B vs A).
282+
283+
Parameters:
284+
system_prompt (str): The system prompt to guide the model.
285+
initial_query (str): The task or question to be addressed.
286+
client (Any): The client instance for interacting with the AI model.
287+
model (str): The model name to be used for generating completions.
288+
completions (list[str]): List of completions to be rated.
289+
cepo_config (CepoConfig): Configuration parameters for CePO flow.
290+
291+
Returns:
292+
Tuple[str, int, dict]: The best-rated completion, number of tokens used, and a log dictionary.
293+
"""
235294
completion_tokens = 0
236295
rating_messages = [{"role": "system", "content": system_prompt},
237296
{"role": "user", "content": initial_query}]
@@ -294,9 +353,30 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client, mo
294353
return completions[best_index], completion_tokens, cb_log
295354

296355

297-
def cepo(system_prompt: str, initial_query: str, client, model: str, cepo_config: Optional[CepoConfig] = None) -> tuple[str, int, dict]:
298-
if cepo_config is None:
299-
cepo_config = CepoConfig()
356+
def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_config: Optional[CepoConfig]) -> tuple[str, int]:
357+
"""
358+
Applies CePO reasoning flow for the given task. First, it generates multiple completions, and then rates them to select the best one.
359+
Each completion is generated as follows:
360+
361+
Generate `planning_n` solution proposals:
362+
Step 1: Plan Generation - The model generates a detailed, step-by-step plan to solve the problem, along with its confidence level for
363+
each step.
364+
Step 2: Initial Solution - Using the plan from Step 1, the model produces an initial solution.
365+
366+
Step 3: Plan Refinement - The model reviews all generated solution proposals and their associated plans, identifying inconsistencies.
367+
Based on this analysis, a refined, final step-by-step plan is constructed.
368+
Step 4: Final Solution - The model uses the refined plan from Step 3 to produce the final answer.
369+
370+
Parameters:
371+
system_prompt (str): The system prompt to guide the model.
372+
initial_query (str): The task or question to be addressed.
373+
client (Any): The client instance for interacting with the AI model.
374+
model (str): The model name to be used for generating completions.
375+
cepo_config (CepoConfig): Configuration parameters for CePO flow.
376+
377+
Returns:
378+
Tuple[str, int]: The generated completion and the number of tokens used.
379+
"""
300380

301381
# Generate completions
302382
completions, completion_tokens_planning, cb_log = generate_n_completions(system_prompt, initial_query, client, model, cepo_config) # cb_log is a dictionary for debugging purposes

0 commit comments

Comments
 (0)