
Commit b6e827c

Update train_optillm_classifier.py
bert large for classification
1 parent 1b62fc4 commit b6e827c

File tree: 1 file changed, +54 −52 lines


scripts/train_optillm_classifier.py

Lines changed: 54 additions & 52 deletions
@@ -24,17 +24,17 @@
 MAX_LENGTH = 512

 class OptILMDataset(Dataset):
-    def __init__(self, prompts, approaches, tokenizer):
+    def __init__(self, prompts, best_approaches, tokenizer):
         self.prompts = prompts
-        self.approaches = approaches
+        self.best_approaches = best_approaches
         self.tokenizer = tokenizer

     def __len__(self):
         return len(self.prompts)

     def __getitem__(self, idx):
         prompt = self.prompts[idx]
-        approach = self.approaches[idx]
+        best_approach = self.best_approaches[idx]

         encoding = self.tokenizer.encode_plus(
             prompt,
@@ -49,69 +49,60 @@ def __getitem__(self, idx):
         return {
             'input_ids': encoding['input_ids'].flatten(),
             'attention_mask': encoding['attention_mask'].flatten(),
-            'labels': torch.tensor(APPROACHES.index(approach), dtype=torch.long)
+            'labels': torch.tensor(APPROACHES.index(best_approach), dtype=torch.long)
         }

 def load_and_preprocess_data(tokenizer):
-    dataset = load_dataset('json', data_files='optillm_dataset_1.jsonl')
+    dataset = load_dataset('json', data_files='optillm_dataset.jsonl')

     data_items = []

     for item in dataset['train']:
         prompt = item['prompt']
         results = item['results']

-        valid_results = [r for r in results if 'approach' in r]
-        if not valid_results:
+        if not results:
             continue
+        # Filter the list to exclude items where rank is None
+        filtered_data = [item for item in results if item['rank'] is not None]
+        # Find the best approach (lowest rank)
+        best_result = min(filtered_data, key=lambda x: x['rank'])
+        best_approach = best_result['approach']

-        for result in valid_results:
-            data_items.append({
-                'prompt': prompt,
-                'approach': result['approach']
-            })
+        data_items.append({
+            'prompt': prompt,
+            'best_approach': best_approach
+        })

     # Print some statistics
     print(f"Total data points: {len(data_items)}")
     print(f"Unique prompts: {len(set(item['prompt'] for item in data_items))}")
-    approach_counts = Counter(item['approach'] for item in data_items)
-    print("Approach distribution:")
+    approach_counts = Counter(item['best_approach'] for item in data_items)
+    print("Best Approach distribution:")
     for approach, count in approach_counts.items():
         print(f"  {approach}: {count}")

-    # Calculate class weights for balanced sampling
-    class_weights = {approach: len(data_items) / count for approach, count in approach_counts.items()}
-    sample_weights = [class_weights[item['approach']] for item in data_items]
-
     # Split the data
     train_data, val_data = train_test_split(data_items, test_size=0.2, random_state=42)

     train_dataset = OptILMDataset(
         [item['prompt'] for item in train_data],
-        [item['approach'] for item in train_data],
+        [item['best_approach'] for item in train_data],
         tokenizer
     )
     val_dataset = OptILMDataset(
         [item['prompt'] for item in val_data],
-        [item['approach'] for item in val_data],
+        [item['best_approach'] for item in val_data],
         tokenizer
     )

-    # Create a weighted sampler for the training data
-    train_sampler = WeightedRandomSampler(
-        weights=[class_weights[item['approach']] for item in train_data],
-        num_samples=len(train_data),
-        replacement=True
-    )
-
-    return train_dataset, val_dataset, train_sampler
+    return train_dataset, val_dataset

-def calculate_accuracy(logits, labels):
-    predictions = torch.argmax(logits, dim=-1)
+def calculate_accuracy(predictions, labels):
     return (predictions == labels).float().mean()

 def train(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epochs):
-    best_val_loss = float('inf')
+    best_val_accuracy = 0.0

     for epoch in range(num_epochs):
         model.train()
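The hunk above changes the labeling scheme: instead of one training row per (prompt, approach) result, each prompt now gets a single row labeled with its best-ranked approach. A minimal standalone sketch of that selection on a hypothetical optillm_dataset.jsonl record (the approach names and ranks below are illustrative, not values from the real dataset):

# Hypothetical record shaped like one line of optillm_dataset.jsonl;
# approach names and ranks are made up for illustration.
record = {
    "prompt": "Maximize x + y subject to: x + 2y <= 10, x >= 0, y >= 0",
    "results": [
        {"approach": "cot_reflection", "rank": 2},
        {"approach": "plansearch", "rank": None},  # unranked, filtered out
        {"approach": "moa", "rank": 1},
    ],
}

# Same selection as the updated load_and_preprocess_data:
# drop unranked results, then take the lowest rank as the best approach.
filtered = [r for r in record["results"] if r["rank"] is not None]
best_approach = min(filtered, key=lambda x: x["rank"])["approach"]
print(best_approach)  # -> "moa"

Note that min raises ValueError on an empty sequence, so a record whose ranks are all None would stop the run.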
@@ -134,14 +125,14 @@ def train(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epochs):
             optimizer.zero_grad()

             total_loss += loss.item()
-            total_accuracy += calculate_accuracy(logits, labels)
+            predictions = torch.argmax(logits, dim=-1)
+            total_accuracy += calculate_accuracy(predictions, labels)

         avg_train_loss = total_loss / len(train_dataloader)
         avg_train_accuracy = total_accuracy / len(train_dataloader)

         # Validation
         model.eval()
-        total_val_loss = 0
         total_val_accuracy = 0

         with torch.no_grad():
@@ -150,20 +141,17 @@ def train(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epochs):
                 attention_mask = batch['attention_mask'].to(device)
                 labels = batch['labels'].to(device)

-                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
-                val_loss = outputs.loss
+                outputs = model(input_ids, attention_mask=attention_mask)
                 logits = outputs.logits
+                predictions = torch.argmax(logits, dim=-1)
+                total_val_accuracy += calculate_accuracy(predictions, labels)

-                total_val_loss += val_loss.item()
-                total_val_accuracy += calculate_accuracy(logits, labels)
-
-        avg_val_loss = total_val_loss / len(val_dataloader)
         avg_val_accuracy = total_val_accuracy / len(val_dataloader)

-        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy:.4f}, Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy:.4f}")
+        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy:.4f}, Val Accuracy: {avg_val_accuracy:.4f}")

-        if avg_val_loss < best_val_loss:
-            best_val_loss = avg_val_loss
+        if avg_val_accuracy > best_val_accuracy:
+            best_val_accuracy = avg_val_accuracy
             # Save the best model
             save_model(model, "best_model.safetensors")
@@ -188,10 +176,10 @@ def main(args):
     model.to(device)

     # Load and preprocess data
-    train_dataset, val_dataset, train_sampler = load_and_preprocess_data(tokenizer)
+    train_dataset, val_dataset = load_and_preprocess_data(tokenizer)

     # Create data loaders
-    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler)
+    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size)
     val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size)

     # Optimizer and scheduler
@@ -209,17 +197,31 @@ def main(args):
         model.push_to_hub(args.hub_model_id)
         tokenizer.push_to_hub(args.hub_model_id)

-    # Example inference
-    test_prompt = "Maximize x + y subject to: x + 2y <= 10, x >= 0, y >= 0"
-    predicted_approach, confidence = inference(model, tokenizer, test_prompt)
-    print(f"Test Prompt: {test_prompt}")
-    print(f"Predicted Approach: {predicted_approach}")
-    print(f"Confidence: {confidence:.4f}")
+    # Example inferences
+    test_prompts = [
+        "Maximize x + y subject to: x + 2y <= 10, x >= 0, y >= 0",
+        "Find the shortest path between nodes A and B in the given graph",
+        "Solve the Tower of Hanoi problem with 4 disks",
+        "Determine if the given number is prime",
+        "Find all possible combinations of coins that sum up to $1",
+        "Implement a binary search algorithm",
+        "Design an algorithm to find the longest palindromic substring",
+        "Solve the 8-queens problem",
+        "Implement a depth-first search algorithm for a graph",
+        "Find the maximum subarray sum in a given array of integers"
+    ]
+
+    print("\nInference Examples:")
+    for prompt in test_prompts:
+        predicted_approach, confidence = inference(model, tokenizer, prompt)
+        print(f"\nTest Prompt: {prompt}")
+        print(f"Predicted Approach: {predicted_approach}")
+        print(f"Confidence: {confidence:.4f}")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train OptILM classifier")
-    parser.add_argument("--model_name", type=str, default="roberta-large", help="Pretrained model name")
-    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training")
+    parser.add_argument("--model_name", type=str, default="google-bert/bert-large-uncased", help="Pretrained model name")
+    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training")
     parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
     parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs")
     parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub")
