Skip to content

Commit f72224e

Browse files
Add a flag to print intermediate outputs in CePO
1 parent e7dbd1a commit f72224e

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ optillm supports various command-line arguments and environment variables for co
331331
| `--cepo_planning_max_tokens_step2` | Maximum number of tokens in step 2 of planning stage | 4096 |
332332
| `--cepo_planning_max_tokens_step3` | Maximum number of tokens in step 3 of planning stage | 4096 |
333333
| `--cepo_planning_max_tokens_step4` | Maximum number of tokens in step 4 of planning stage | 4096 |
334+
| `--cepo_print_output` | Whether to print the output of each stage | False |
334335
| `--cepo_config_file` | Path to CePO configuration file | None |
335336

336337
When using Docker, these can be set as environment variables prefixed with `OPTILLM_`.

optillm/cepo.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
import re
22
import yaml
3-
import logging
43

54
from dataclasses import dataclass
65
from typing import Literal, Any
76
from cerebras.cloud.sdk import BadRequestError as CerebrasBadRequestError
87
from openai import BadRequestError as OpenAIBadRequestError
98

109

11-
logger = logging.getLogger(__name__)
12-
13-
1410
@dataclass
1511
class CepoConfig:
1612
bestofn_n: int # number of responses to be generated in best of n stage
@@ -27,6 +23,7 @@ class CepoConfig:
2723
planning_max_tokens_step2: int # maximum number of tokens in step 2 of planning stage
2824
planning_max_tokens_step3: int # maximum number of tokens in step 3 of planning stage
2925
planning_max_tokens_step4: int # maximum number of tokens in step 4 of planning stage
26+
print_output: bool = False # whether to print the output of each stage
3027

3128

3229
# given command line arguments which includes a yaml file path, initialize a CePO configuration
@@ -113,13 +110,15 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
113110
if response.choices[0].finish_reason == "length":
114111
messages.append({"role": "assistant", "content": response.choices[0].message.content})
115112
cb_log[f"messages_planning_{i}_rejected_due_to_length"] = messages
116-
logger.debug(f"Plan proposal rejected due to length. Attempt {i + 1} out of {cepo_config.planning_m}.\nMessages: {messages}")
113+
if cepo_config.print_output:
114+
print(f"\nCePO: Plan proposal rejected due to length. Attempt {i + 1} out of {cepo_config.planning_m}.\nMessages: {messages}")
117115
continue
118116

119117
plans.append(response.choices[0].message.content)
120118
messages.append({"role": "assistant", "content": response.choices[0].message.content})
121119
cb_log[f"messages_planning_{i}"] = messages
122-
logger.debug(f"Plan proposal generated. Attempt {i + 1} out of {cepo_config.planning_m}.\nMessages: {messages}")
120+
if cepo_config.print_output:
121+
print(f"\nCePO: Plan proposal generated. Attempt {i + 1} out of {cepo_config.planning_m}.\nMessages: {messages}")
123122

124123
if len(plans) == cepo_config.planning_n:
125124
break
@@ -129,7 +128,8 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
129128
plans.append(response.choices[0].message.content)
130129
messages.append({"role": "assistant", "content": response.choices[0].message.content})
131130
cb_log[f"messages_planning_{i}_no_plans_so_taking_the_last_one"] = messages
132-
logger.debug(f"No plans generated successfully. Taking the last one from rejected due to length.\nMessages: {messages}")
131+
if cepo_config.print_output:
132+
print(f"\nCePO: No plans generated successfully. Taking the last one from rejected due to length.\nMessages: {messages}")
133133

134134
# Step 3 - Review and address inconsistencies
135135
try:
@@ -169,7 +169,8 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
169169
completion_tokens += response.usage.completion_tokens
170170

171171
cb_log["messages"] = messages
172-
logger.debug(f"Answer generated.\nMessages: {messages}")
172+
if cepo_config.print_output:
173+
print(f"\nCePO: Answer generated.\nMessages: {messages}")
173174
return response.choices[0].message.content, completion_tokens, cb_log
174175

175176

@@ -192,7 +193,8 @@ def generate_n_completions(system_prompt: str, initial_query: str, client: Any,
192193
completions = []
193194

194195
for i in range(cepo_config.bestofn_n):
195-
logger.debug(f"Generating completion {i + 1} out of {cepo_config.bestofn_n}")
196+
if cepo_config.print_output:
197+
print(f"\nCePO: Generating completion {i + 1} out of {cepo_config.bestofn_n}")
196198
response_i, completion_tokens_i, cb_log_i = generate_completion(system_prompt, initial_query, client, model, cepo_config)
197199
completions.append(response_i)
198200
completion_tokens += completion_tokens_i
@@ -264,7 +266,8 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An
264266

265267
rating_response = rating_response.choices[0].message.content.strip()
266268
cb_log[f"rating_response_{i}"] = rating_response
267-
logger.debug(f"Rating response for completion {i}: {rating_response}")
269+
if cepo_config.print_output:
270+
print(f"\nCePO: Rating response for completion {i}: {rating_response}")
268271

269272
pattern = r"Rating: \[\[(\d+)\]\]"
270273
match = re.search(pattern, rating_response)
@@ -280,7 +283,8 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An
280283
best_index = ratings.index(max(ratings))
281284
cb_log["ratings"] = ratings
282285
cb_log["best_index"] = best_index
283-
logger.debug(f"Finished rating completions. Ratings: {ratings}, best completion index: {best_index}")
286+
if cepo_config.print_output:
287+
print(f"\nCePO: Finished rating completions. Ratings: {ratings}, best completion index: {best_index}")
284288
return completions[best_index], completion_tokens, cb_log
285289

286290

@@ -340,7 +344,8 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An
340344

341345
rating_response = rating_response.choices[0].message.content.strip()
342346
cb_log[f"rating_response_for_pair_{pair[0]}_{pair[1]}"] = rating_response
343-
logger.debug(f"Rating response for pair {pair}: {rating_response}")
347+
if cepo_config.print_output:
348+
print(f"\nCePO: Rating response for pair {pair}: {rating_response}")
344349

345350
pattern = r"Better Response: \[\[(\d+)\]\]"
346351
match = re.search(pattern, rating_response)
@@ -359,7 +364,8 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An
359364
best_index = ratings.index(max(ratings))
360365
cb_log["ratings"] = ratings
361366
cb_log["best_index"] = best_index
362-
logger.debug(f"Finished rating completions. Ratings: {ratings}, best completion index: {best_index}")
367+
if cepo_config.print_output:
368+
print(f"\nCePO: Finished rating completions. Ratings: {ratings}, best completion index: {best_index}")
363369
return completions[best_index], completion_tokens, cb_log
364370

365371

0 commit comments

Comments
 (0)