Skip to content

Commit bef08cf

Browse files
Make cepo_config.yaml define the default values instead of the dataclass for single source of truth
1 parent 05ff108 commit bef08cf

File tree

2 files changed

+23
-41
lines changed

2 files changed

+23
-41
lines changed

optillm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from concurrent.futures import ThreadPoolExecutor
1515
from typing import Tuple, Optional, Union, Dict, Any, List
1616
from importlib.metadata import version
17+
from dataclasses import fields
1718

1819
# Import approach modules
1920
from optillm.mcts import chat_with_mcts
@@ -133,7 +134,7 @@ def none_approach(
133134
model: Model identifier
134135
original_messages: Original messages from the request
135136
**kwargs: Additional parameters to pass through
136-
137+
137138
Returns:
138139
Dict[str, Any]: Full OpenAI API response
139140
"""
@@ -702,12 +703,11 @@ def parse_args():
702703
parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default,
703704
help="Base url for OpenAI compatible endpoint")
704705

705-
# Special handling of all the Cepo Configurations
706-
for key, value in CepoConfig.__dict__.items():
707-
if not key.startswith('__'):
708-
parser.add_argument(f"--cepo_{key}", dest=f"cepo_{key}", type=type(value), default=None, help=f"CePO configuration for {key}")
706+
# Special handling of all the CePO Configurations
707+
for field in fields(CepoConfig):
708+
parser.add_argument(f"--cepo_{field.name}", dest=f"cepo_{field.name}", type=field.type, default=None, help=f"CePO configuration for {field.name}")
709709

710-
parser.add_argument(f"--cepo_config_file", dest=f"cepo_config_file", type=str, default=None, help="Path to CePO configuration file")
710+
parser.add_argument(f"--cepo_config_file", dest=f"cepo_config_file", type=str, default="./configs/cepo_config.yaml", help="Path to CePO configuration file")
711711

712712
args = parser.parse_args()
713713

optillm/cepo.py

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99

1010
@dataclass
1111
class CepoConfig:
12-
bestofn_n: int = 3
13-
bestofn_temperature: float = 0.1
14-
bestofn_max_tokens: int = 4096
15-
bestofn_rating_type: Literal["absolute", "pairwise"] = "absolute"
16-
planning_n: int = 3
17-
planning_m: int = 6
18-
planning_temperature_step1: float = 0.55
19-
planning_temperature_step2: float = 0.25
20-
planning_temperature_step3: float = 0.1
21-
planning_temperature_step4: float = 0
22-
planning_max_tokens_step1: int = 4096
23-
planning_max_tokens_step2: int = 4096
24-
planning_max_tokens_step3: int = 4096
25-
planning_max_tokens_step4: int = 4096
12+
bestofn_n: int
13+
bestofn_temperature: float
14+
bestofn_max_tokens: int
15+
bestofn_rating_type: Literal["absolute", "pairwise"]
16+
planning_n: int
17+
planning_m: int
18+
planning_temperature_step1: float
19+
planning_temperature_step2: float
20+
planning_temperature_step3: float
21+
planning_temperature_step4: float
22+
planning_max_tokens_step1: int
23+
planning_max_tokens_step2: int
24+
planning_max_tokens_step3: int
25+
planning_max_tokens_step4: int
2626

2727

2828
# given command line arguments which includes a yaml file path, initialize a CePO configuration
@@ -40,27 +40,9 @@ def init_cepo_config(cmd_line_args: dict) -> CepoConfig:
4040
with open(cmd_line_args["cepo_config_file"], "r") as yaml_file:
4141
cepo_config_yaml = yaml.safe_load(yaml_file)
4242

43-
# check if any of the keys overlap, and if they do, error out
44-
for key in cepo_config_yaml.keys():
45-
if key in cepo_args.keys():
46-
raise RuntimeError(f"Key {key} is found in both yaml file and command line arguments")
47-
48-
# if not, then we take both of them and add them to the cepo config
49-
cepo_config = CepoConfig()
50-
cepo_attrs = [key for key, _ in cepo_config.__dict__.items() if not key.startswith('__')]
51-
52-
# add command line arguments
53-
for key, value in cepo_args.items():
54-
# this assert should not be raised as the cli parser should catch this
55-
assert key in cepo_attrs, f"Command line argument {key} is not found in CepoConfig"
56-
setattr(cepo_config, key, value)
57-
58-
# add yaml arguments
59-
for key, value in cepo_config_yaml.items():
60-
assert key in cepo_attrs, f"Yaml argument {key} is not found in CepoConfig"
61-
setattr(cepo_config, key, value)
62-
63-
return cepo_config
43+
# merge cepo args from command line and yaml file, args from command line will overwrite the ones from yaml file
44+
cepo_args = {**cepo_config_yaml, **cepo_args}
45+
return CepoConfig(**cepo_args)
6446

6547
def extract_question_only(task: str) -> str:
6648
"""We noticed that sometimes if the task includes specific formatting instructions, they may interfere with the reasoning flow. This

0 commit comments

Comments
 (0)