Skip to content

Commit 41e365b

Browse files
committed
fixes
1 parent 87caa1c commit 41e365b

File tree

2 files changed

+66
-24
lines changed

2 files changed

+66
-24
lines changed

examples/function_minimization/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Configuration for function minimization example
2-
max_iterations: 10
2+
max_iterations: 100
33
checkpoint_interval: 10
44
log_level: "INFO"
55

examples/function_minimization/evaluator.py

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import time
77
import concurrent.futures
88
import threading
9+
import traceback
10+
import sys
911

1012
def run_with_timeout(func, args=(), kwargs={}, timeout_seconds=5):
1113
"""
@@ -27,6 +29,14 @@ def run_with_timeout(func, args=(), kwargs={}, timeout_seconds=5):
2729
except concurrent.futures.TimeoutError:
2830
raise TimeoutError(f"Function {func.__name__} timed out after {timeout_seconds} seconds")
2931

32+
def safe_float(value):
33+
"""Convert a value to float safely"""
34+
try:
35+
return float(value)
36+
except (TypeError, ValueError):
37+
print(f"Warning: Could not convert {value} of type {type(value)} to float")
38+
return 0.0
39+
3040
def evaluate(program_path):
3141
"""
3242
Evaluate the program by running it multiple times and checking how close
@@ -72,33 +82,45 @@ def evaluate(program_path):
7282
start_time = time.time()
7383

7484
# Run with timeout
75-
x, y, value = run_with_timeout(program.run_search, timeout_seconds=5)
85+
result = run_with_timeout(program.run_search, timeout_seconds=5)
86+
87+
# Check if we got a tuple of 3 values
88+
if not isinstance(result, tuple) or len(result) != 3:
89+
print(f"Trial {trial}: Invalid result format, expected tuple of 3 values but got {type(result)}")
90+
continue
91+
92+
x, y, value = result
7693

7794
end_time = time.time()
7895

96+
# Ensure all values are float
97+
x = safe_float(x)
98+
y = safe_float(y)
99+
value = safe_float(value)
100+
79101
# Check if the result is valid (not NaN or infinite)
80102
if (np.isnan(x) or np.isnan(y) or np.isnan(value) or
81103
np.isinf(x) or np.isinf(y) or np.isinf(value)):
82104
print(f"Trial {trial}: Invalid result, got x={x}, y={y}, value={value}")
83105
continue
84106

85-
# Ensure all values are float
86-
x, y, value = float(x), float(y), float(value)
87-
88107
# Calculate metrics
89-
distance_to_global = np.sqrt((x - GLOBAL_MIN_X)**2 + (y - GLOBAL_MIN_Y)**2)
108+
x_diff = safe_float(x) - GLOBAL_MIN_X
109+
y_diff = safe_float(y) - GLOBAL_MIN_Y
110+
distance_to_global = np.sqrt(x_diff**2 + y_diff**2)
90111
value_difference = abs(value - GLOBAL_MIN_VALUE)
91112

92-
values.append(value)
93-
distances.append(distance_to_global)
94-
times.append(end_time - start_time)
113+
values.append(float(value))
114+
distances.append(float(distance_to_global))
115+
times.append(float(end_time - start_time))
95116
success_count += 1
96117

97118
except TimeoutError as e:
98119
print(f"Trial {trial}: {str(e)}")
99120
continue
100121
except Exception as e:
101122
print(f"Trial {trial}: Error - {str(e)}")
123+
print(traceback.format_exc())
102124
continue
103125

104126
# If all trials failed, return zero scores
@@ -112,31 +134,35 @@ def evaluate(program_path):
112134
}
113135

114136
# Calculate metrics
115-
avg_value = np.mean(values)
116-
avg_distance = np.mean(distances)
117-
avg_time = np.mean(times)
137+
avg_value = float(np.mean(values))
138+
avg_distance = float(np.mean(distances))
139+
avg_time = float(np.mean(times)) if times else 1.0
118140

119141
# Convert to scores (higher is better)
120-
value_score = 1.0 / (1.0 + abs(avg_value - GLOBAL_MIN_VALUE)) # Normalize and invert
121-
distance_score = 1.0 / (1.0 + avg_distance)
122-
speed_score = 1.0 / avg_time if avg_time > 0 else 0.0
142+
value_score = float(1.0 / (1.0 + abs(avg_value - GLOBAL_MIN_VALUE))) # Normalize and invert
143+
distance_score = float(1.0 / (1.0 + avg_distance))
144+
speed_score = float(1.0 / avg_time) if avg_time > 0 else 0.0
123145

124146
# Normalize speed score (so it doesn't dominate)
125-
speed_score = min(speed_score, 10.0) / 10.0
147+
speed_score = float(min(speed_score, 10.0) / 10.0)
126148

127149
# Add reliability score based on success rate
128-
reliability_score = success_count / num_trials
150+
reliability_score = float(success_count / num_trials)
151+
152+
# Calculate combined score
153+
combined_score = float(0.5 * value_score + 0.2 * distance_score + 0.1 * speed_score + 0.2 * reliability_score)
129154

130155
return {
131156
"value_score": value_score,
132157
"distance_score": distance_score,
133158
"speed_score": speed_score,
134159
"reliability_score": reliability_score,
135-
"combined_score": 0.5 * value_score + 0.2 * distance_score + 0.1 * speed_score + 0.2 * reliability_score,
160+
"combined_score": combined_score,
136161
"success_rate": reliability_score
137162
}
138163
except Exception as e:
139164
print(f"Evaluation failed completely: {str(e)}")
165+
print(traceback.format_exc())
140166
return {
141167
"value_score": 0.0,
142168
"distance_score": 0.0,
@@ -149,9 +175,9 @@ def evaluate(program_path):
149175
def evaluate_stage1(program_path):
150176
"""First stage evaluation with fewer trials"""
151177
# Known global minimum (approximate)
152-
GLOBAL_MIN_X = -1.76
153-
GLOBAL_MIN_Y = -1.03
154-
GLOBAL_MIN_VALUE = -2.104
178+
GLOBAL_MIN_X = float(-1.76)
179+
GLOBAL_MIN_Y = float(-1.03)
180+
GLOBAL_MIN_VALUE = float(-2.104)
155181

156182
# Quick check to see if the program runs without errors
157183
try:
@@ -167,31 +193,47 @@ def evaluate_stage1(program_path):
167193

168194
try:
169195
# Run a single trial with timeout
170-
x, y, value = run_with_timeout(program.run_search, timeout_seconds=5)
196+
result = run_with_timeout(program.run_search, timeout_seconds=5)
197+
198+
# Check if we got a tuple of 3 values
199+
if not isinstance(result, tuple) or len(result) != 3:
200+
print(f"Stage 1: Invalid result format, expected tuple of 3 values but got {type(result)}")
201+
return {"runs_successfully": 0.0, "error": "Invalid result format"}
202+
203+
x, y, value = result
171204

172205
# Ensure all values are float
173-
x, y, value = float(x), float(y), float(value)
206+
x = safe_float(x)
207+
y = safe_float(y)
208+
value = safe_float(value)
174209

175210
# Check if the result is valid
176211
if np.isnan(x) or np.isnan(y) or np.isnan(value) or np.isinf(x) or np.isinf(y) or np.isinf(value):
177212
print(f"Stage 1 validation: Invalid result, got x={x}, y={y}, value={value}")
178213
return {"runs_successfully": 0.5, "error": "Invalid result values"}
179214

215+
# Calculate distance safely
216+
x_diff = float(x) - GLOBAL_MIN_X
217+
y_diff = float(y) - GLOBAL_MIN_Y
218+
distance = float(np.sqrt(x_diff**2 + y_diff**2))
219+
180220
# Basic metrics
181221
return {
182222
"runs_successfully": 1.0,
183223
"value": float(value),
184-
"distance": float(np.sqrt((x - GLOBAL_MIN_X)**2 + (y - GLOBAL_MIN_Y)**2)) # Distance to known minimum
224+
"distance": distance
185225
}
186226
except TimeoutError as e:
187227
print(f"Stage 1 evaluation timed out: {e}")
188228
return {"runs_successfully": 0.0, "error": "Timeout"}
189229
except Exception as e:
190230
print(f"Stage 1 evaluation failed: {e}")
231+
print(traceback.format_exc())
191232
return {"runs_successfully": 0.0, "error": str(e)}
192233

193234
except Exception as e:
194235
print(f"Stage 1 evaluation failed: {e}")
236+
print(traceback.format_exc())
195237
return {"runs_successfully": 0.0, "error": str(e)}
196238

197239
def evaluate_stage2(program_path):

0 commit comments

Comments
 (0)