Skip to content

Commit 8b4617e

Browse files
erich-cerebraspawelf-cerebras
authored andcommitted
Add modification of CePO configs through yaml and cli arguments
This will allow users to do the following: 1. Cli where if they can pass in anything that's "cepo_<name-of-attribute>" 2. Yaml file where if they pass it in as "<name-of-attribute>" 3. If none of them have a specific attribute, we use the default setting 4. If both of them have the specific attribute, we error out
1 parent 0617e0c commit 8b4617e

File tree

3 files changed

+76
-9
lines changed

3 files changed

+76
-9
lines changed

configs/cepo_config.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
bestofn_n: 4
2+
bestofn_temperature: 0.2324214
3+
bestofn_max_tokens: 4096
4+
bestofn_rating_type: "absolute" # or "pairwise"
5+
planning_n: 2
6+
planning_m: 2
7+
planning_temperature_step1: 0.00055
8+
planning_temperature_step2: 0.25
9+
planning_temperature_step3: 0.999
10+
planning_temperature_step4: 1.2
11+
planning_max_tokens_step1: 96
12+
planning_max_tokens_step2: 2
13+
planning_max_tokens_step3: 8000
14+
planning_max_tokens_step4: 0

optillm.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from optillm.plansearch import plansearch
2929
from optillm.leap import leap
3030
from optillm.reread import re2_approach
31-
from optillm.cepo import cepo
31+
from optillm.cepo import cepo, CepoConfig, init_cepo_config
3232

3333
# Setup logging
3434
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -292,7 +292,9 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
292292
elif approach == 're2':
293293
return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
294294
elif approach == 'cepo':
295-
return cepo(system_prompt, initial_query, client, model)
295+
# build the cepo config based on the cmd line arguments and the
296+
logger.debug(f"Calling with {cepo_config}")
297+
return cepo(system_prompt, initial_query, client, model, cepo_config)
296298
elif approach in plugin_approaches:
297299
return plugin_approaches[approach](system_prompt, initial_query, client, model)
298300
else:
@@ -701,6 +703,13 @@ def parse_args():
701703
parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default,
702704
help="Base url for OpenAI compatible endpoint")
703705

706+
# Special handling of all the Cepo Configurations
707+
for key, value in CepoConfig.__dict__.items():
708+
if not key.startswith('__'):
709+
parser.add_argument(f"--cepo_{key}", dest=f"cepo_{key}", type=type(value), default=None, help=f"CePO configuration for {key}")
710+
711+
parser.add_argument(f"--cepo_config_file", dest=f"cepo_config_file", type=str, default=None, help="Path to CePO configuration file")
712+
704713
args = parser.parse_args()
705714

706715
# Convert argument names to match server_config keys
@@ -714,6 +723,7 @@ def parse_args():
714723

715724
def main():
716725
global server_config
726+
global cepo_config
717727
# Call this function at the start of main()
718728
args = parse_args()
719729
# Update server_config with all argument values
@@ -728,6 +738,11 @@ def main():
728738
if logging_level in logging_levels.keys():
729739
logger.setLevel(logging_levels[logging_level])
730740

741+
# set and log the cepo configs
742+
cepo_config = init_cepo_config(server_config)
743+
if args.approach == 'cepo':
744+
logger.info(f"CePO Config: {cepo_config}")
745+
731746
logger.info(f"Starting server with approach: {server_config['approach']}")
732747
server_config_clean = server_config.copy()
733748
if server_config_clean['optillm_api_key']:

optillm/cepo.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,65 @@
33
import openai
44

55
from dataclasses import dataclass
6-
from typing import Optional
6+
from typing import Optional, Literal
77

8+
import yaml
89

910
@dataclass
1011
class CepoConfig:
1112
bestofn_n: int = 3
12-
bestofn_temperature: int = 0.1
13+
bestofn_temperature: float = 0.1
1314
bestofn_max_tokens: int = 4096
14-
bestofn_rating_type: str = "absolute" # or "pairwise"
15+
bestofn_rating_type: Literal["absolute", "pairwise"] = "absolute"
1516
planning_n: int = 3
1617
planning_m: int = 6
17-
planning_temperature_step1: int = 0.55
18-
planning_temperature_step2: int = 0.25
19-
planning_temperature_step3: int = 0.1
20-
planning_temperature_step4: int = 0
18+
planning_temperature_step1: float = 0.55
19+
planning_temperature_step2: float = 0.25
20+
planning_temperature_step3: float = 0.1
21+
planning_temperature_step4: float = 0
2122
planning_max_tokens_step1: int = 4096
2223
planning_max_tokens_step2: int = 4096
2324
planning_max_tokens_step3: int = 4096
2425
planning_max_tokens_step4: int = 4096
2526

2627

28+
# given command line arguments which includes a yaml file path, initialize a CePO configuration
29+
def init_cepo_config(cmd_line_args: dict) -> CepoConfig:
30+
# get the command line arguments
31+
cepo_args = {
32+
key.split("cepo_")[1]: value
33+
for key, value in cmd_line_args.items()
34+
if "cepo" in key and "cepo_config_file" != key and value is not None
35+
}
36+
37+
# get the yaml file arguments
38+
cepo_config_yaml = {}
39+
if "cepo_config_file" in cmd_line_args.keys():
40+
with open(cmd_line_args["cepo_config_file"], "r") as yaml_file:
41+
cepo_config_yaml = yaml.safe_load(yaml_file)
42+
43+
# check if any of the keys overlap, and if they do, error out
44+
for key in cepo_config_yaml.keys():
45+
if key in cepo_args.keys():
46+
raise RuntimeError(f"Key {key} is found in both yaml file and command line arguments")
47+
48+
# if not, then we take both of them and add them to the cepo config
49+
cepo_config = CepoConfig()
50+
cepo_attrs = [key for key, _ in cepo_config.__dict__.items() if not key.startswith('__')]
51+
52+
# add command line arguments
53+
for key, value in cepo_args.items():
54+
# this assert should not be raised as the cli parser should catch this
55+
assert key in cepo_attrs, f"Command line argument {key} is not found in CepoConfig"
56+
setattr(cepo_config, key, value)
57+
58+
# add yaml arguments
59+
for key, value in cepo_config_yaml.items():
60+
assert key in cepo_attrs, f"Yaml argument {key} is not found in CepoConfig"
61+
setattr(cepo_config, key, value)
62+
63+
return cepo_config
64+
2765
def extract_question_only(task: str) -> str:
2866
"""We noticed that sometimes if the task includes specific formatting instructions, they may interfere with the reasoning flow. This
2967
is a temporary workaround to extract the question only from the task. Work in progress.

0 commit comments

Comments
 (0)