#!/usr/bin/env python3
"""Debug answer extraction on AIME reasoning paths."""
import re
from typing import List, Dict, Any


def test_answer_extraction():
    """Test answer extraction on sample AIME reasoning paths."""
    # Sample reasoning paths from the 14B model AIME test
    sample_reasoning_paths = [
        # Problem 2 - should extract around 236
        """218
Okay, so I've got this problem here where I need to find the sum of the squares of all the items in a list of positive integers. The list has some specific properties:
1. The sum of all the items is 30.
2. The unique mode is 9.
3. The median is a positive integer that doesn't appear in the list.
And the answer is supposed to be 218. Hmm, let me try to figure this out step by step.
First, let me recall what each term means:
- **Sum**: The total when you add all the numbers together. So, all the numbers in the list add up to 30.
- **Mode**: The number that appears most frequently. It's given that the unique mode is 9, which means 9 appears more times than any other number, and no other number has the same frequency as 9.
- **Median**: The middle number when the numbers are arranged in order. Importantly, the median is a positive integer that doesn't appear in the list. So, if I have an odd number of elements, the median is the middle one. If even, it's the average of the two middle ones. But since it's a positive integer not in the list, that must mean that when the list has an even number of elements, the median is the average of two middle numbers which isn't present. Or, if it's odd, the exact middle number isn't in the list, which is impossible because in an odd-length list, the median is an element of the list. Therefore, the list must have an even number of elements because the median isn't present.
So, from the median condition, the list must have an even number of elements. Because if it were odd, the median would be one of the numbers in the list, which contradicts the given condition. So, even number of elements.
Let me denote the number of elements as n. So, n is even. Let's denote n = 2k, where k is an integer. Then, the median would be the average of the k-th and (k+1)-th elements. Since the median is a positive integer not in the list, that tells me that the two middle numbers must add up to an even number because their average is an integer. So, the two middle numbers must be both odd or both even. But since the average is not present in the list, that means that these""",
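        # Problem 3 - should extract around 45
        # Raw string so the LaTeX backslashes (\begin, \hline, \\) are kept literally
        # instead of being interpreted as Python escape sequences.
        r"""8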
Okay, so I've got this problem here about filling a 2x3 grid with digits. The goal is to have the sum of the two numbers formed by reading left to right equal to 999, and the sum of the three numbers formed by reading top to bottom equal to 99. They gave an example grid, which is helpful. Let me try to figure this out step by step.
First, I need to visualize the grid. It's a 2x3 grid, which means there are 2 rows and 3 columns. Let me denote each cell with variables to make it easier. Let's say the top row has cells A, B, C, and the bottom row has cells D, E, F. So the grid looks like this:
\[
\begin{array}{|c|c|c|}
\hline
A & B & C \\
\hline
D & E & F \\
\hline
\end{array}
\]
Now, according to the problem, reading left to right, the two numbers formed are ABC and DEF. So, ABC is a three-digit number where A is the hundreds digit, B is the tens, and C is the units. Similarly, DEF is another three-digit number with D, E, F as hundreds, tens, and units respectively. The sum of these two numbers should be 999.
On the other hand, reading top to bottom, the three numbers formed are AD, BE, and CF. That is, AD is a two-digit number with A as tens and D as units, BE is another two-digit number with B as tens and E as units, and CF is a two-digit number with C as tens and F as units. The sum of these three numbers should be 99.
So, summarizing the constraints:
1. ABC + DEF = 999
2. AD + BE + CF = 99
Each of the letters A, B, C, D, E, F represents a digit from 0 to 9, except that A and D can't be zero because they are the first digits of their respective numbers. Wait, hold on, actually, in the example given, A is 0. Hmm, so maybe A can be zero? Because in the example, the top number is 008, which is 8, and the bottom number is 991. But 008 is just 8, so""",
        # Problem 1 - should extract around 73
        """145
Alright, so I'm trying to solve this problem about residents in Aimeville and the things they own. It seems like a problem involving sets and maybe using the principle of inclusion-exclusion. Let me try to break it down step by step.
First, let me parse the information given:
- Total residents: 900
- Number of residents who own a diamond ring: 195
- Number who own a set of golf clubs: 367
- Number who own a garden spade: 562
- Everyone owns a bag of candy hearts. So, that's 900.
Wait, so each resident owns at least a bag of candy hearts. So, candy hearts are like a common item everyone has. The other items (diamond ring, golf clubs, garden spade) are additional things some people own.
Additionally, it's given that:
- 437 residents own exactly two of these things.
- 234 residents own exactly three of these things.
We need to find the number of residents who own all four things, which would include the candy hearts as well. Since everyone owns candy hearts, owning all four is equivalent to owning all three of the other items plus the candy hearts, which they already have.
Hmm. Let me think about how to model this.
Let me denote the sets:
- Let A be the set of residents who own a diamond ring.
- Let B be the set who own golf clubs.
- Let C be the set who own a garden spade.
- Let D be the set who own candy hearts. Since everyone owns this, D is the universal set, so D = 900.
We need to find the number of residents who are in all four sets, which is |A ∩ B ∩ C ∩ D|. But since D is the universal set, this is equal to |A ∩ B ∩ C|. So, if I can find |A ∩ B ∩ C|, that will be the answer.
Wait, so the problem reduces to finding the number of people who own all three: diamond ring, golf clubs, and garden spade. Because everyone owns the candy hearts, so that's automatically included.
But let's see. The problem mentions "exactly two" and "exactly three" of these things. So, does that include the candy hearts? Hmm, probably, because it says "each of the 9"""
    ]

    # Test current extraction
    print("🔍 Testing Current Answer Extraction")
    print("=" * 60)

    for i, reasoning in enumerate(sample_reasoning_paths, 1):
        print(f"\n📝 Problem {i}:")
        print("-" * 40)

        # Current extraction logic
        current_answer = extract_answer_current(reasoning)
        print(f"Current extraction: '{current_answer}'")

        # Show what patterns matched
        show_pattern_matches(reasoning)

        print(f"Reasoning preview: {reasoning[:200]}...")

    print("\n" + "=" * 60)
    print("🔧 Testing Improved Answer Extraction")
    print("=" * 60)

    for i, reasoning in enumerate(sample_reasoning_paths, 1):
        print(f"\n📝 Problem {i}:")
        print("-" * 40)

        # Improved extraction logic
        improved_answer = extract_answer_improved(reasoning)
        print(f"Improved extraction: '{improved_answer}'")

        print(f"Reasoning preview: {reasoning[:200]}...")


def extract_answer_current(reasoning: str) -> str:
    """Current answer extraction logic."""
    if not reasoning or not reasoning.strip():
        return ""

    # Current patterns
    answer_patterns = [
        re.compile(r"####\s*([-+]?\d+(?:\.\d+)?)", re.I),  # #### answer
        re.compile(r"final answer.*?([-+]?\d+(?:\.\d+)?)", re.I),  # final answer
        re.compile(r"answer is\s*[:\s]?([-+]?\d[\d,]*(?:\.\d+)?)(?=[\.\n]|$)", re.I),  # answer is
    ]

    # Strategy 1: Look for explicit answer patterns
    for pattern in answer_patterns:
        match = pattern.search(reasoning)
        if match:
            extracted_answer = match.group(1).strip()
            cleaned_answer = clean_answer(extracted_answer)
            if is_valid_answer(cleaned_answer):
                return cleaned_answer

    # Strategy 2: Look for "answer is" patterns
    answer_patterns_enhanced = [
        r"answer is[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
        r"final answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
        r"the answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
        r"correct answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
        r"answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
    ]

    for pattern_str in answer_patterns_enhanced:
        pattern = re.compile(pattern_str, re.IGNORECASE)
        match = pattern.search(reasoning)
        if match:
            extracted_answer = match.group(1).strip()
            cleaned_answer = clean_answer(extracted_answer)
            if is_valid_answer(cleaned_answer):
                return cleaned_answer

    # Strategy 3: Look for boxed answers
    boxed_pattern = r"\\boxed\{([^}]+)\}"
    match = re.search(boxed_pattern, reasoning)
    if match:
        extracted_answer = match.group(1).strip()
        cleaned_answer = clean_answer(extracted_answer)
        if is_valid_answer(cleaned_answer):
            return cleaned_answer

    # Strategy 4: Look for the last reasonable number
    number_pattern = re.compile(r"[-+]?\d[\d,]*(?:\.\d+)?")
    all_numbers = number_pattern.findall(reasoning)

    if all_numbers:
        reasonable_numbers = []
        for num_str in all_numbers:
            cleaned = clean_answer(num_str)
            if is_valid_answer(cleaned):
                num_val = float(cleaned)
                if num_val >= 1:
                    reasonable_numbers.append(cleaned)

        if reasonable_numbers:
            return reasonable_numbers[-1]

    return ""


def extract_answer_improved(reasoning: str) -> str:
    """Improved answer extraction logic."""
    if not reasoning or not reasoning.strip():
        return ""

    # Strategy 1: Look for explicit answer patterns (highest priority)
    explicit_patterns = [
        r"####\s*([-+]?\d+(?:\.\d+)?)",  # #### answer
        r"final answer.*?([-+]?\d+(?:\.\d+)?)",  # final answer
        r"answer is\s*[:\s]?([-+]?\d[\d,]*(?:\.\d+)?)(?=[\.\n]|$)",  # answer is
        r"\\boxed\{([^}]+)\}",  # boxed answers
    ]

    for pattern_str in explicit_patterns:
        pattern = re.compile(pattern_str, re.I)
        match = pattern.search(reasoning)
        if match:
            extracted_answer = match.group(1).strip()
            cleaned_answer = clean_answer(extracted_answer)
            if is_valid_answer(cleaned_answer):
                return cleaned_answer

    # Strategy 2: Look for "answer is" or "final answer" patterns with better regex
    answer_patterns_enhanced = [
        r"answer is[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
        r"final answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
        r"the answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
        r"correct answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
        r"answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)",
    ]

    for pattern_str in answer_patterns_enhanced:
        pattern = re.compile(pattern_str, re.IGNORECASE)
        match = pattern.search(reasoning)
        if match:
            extracted_answer = match.group(1).strip()
            cleaned_answer = clean_answer(extracted_answer)
            if is_valid_answer(cleaned_answer):
                return cleaned_answer

    # Strategy 3: Look for the last reasonable number (improved)
    number_pattern = re.compile(r"[-+]?\d[\d,]*(?:\.\d+)?")
    all_numbers = number_pattern.findall(reasoning)

    if all_numbers:
        # Filter out very small numbers and look for the last reasonable one
        reasonable_numbers = []
        for num_str in all_numbers:
            cleaned = clean_answer(num_str)
            if is_valid_answer(cleaned):
                num_val = float(cleaned)
                # Only consider numbers >= 1 (filter out small intermediate calculations)
                if num_val >= 1:
                    reasonable_numbers.append(cleaned)

        if reasonable_numbers:
            # For AIME problems, look for the last reasonable number that's not at the very beginning
            # Skip the first number if it's followed by detailed reasoning (common pattern)
            # NOTE: both branches below return reasonable_numbers[-1], so this check does not
            # currently change the result. The `word in next_chars` test is also a substring
            # match, so short entries such as 'i' match almost any text.
            if len(reasonable_numbers) > 1:
                # Check if the first number is followed by detailed reasoning
                first_num = reasonable_numbers[0]
                first_num_pos = reasoning.find(first_num)
                if first_num_pos != -1:
                    # Look at the next 100 characters after the first number
                    next_chars = reasoning[first_num_pos + len(first_num):first_num_pos + len(first_num) + 100]
                    # If it's followed by detailed reasoning (contains common words), skip it
                    if any(word in next_chars.lower() for word in ['okay', 'so', 'let', 'first', 'now', 'then', 'we', 'i', 'the']):
                        # Skip the first number and return the last one
                        return reasonable_numbers[-1]

            # Return the last reasonable number
            return reasonable_numbers[-1]

    return ""


def show_pattern_matches(reasoning: str):
    """Show what patterns matched in the reasoning."""
    print("Pattern matches:")

    # Test explicit patterns
    explicit_patterns = [
        (r"####\s*([-+]?\d+(?:\.\d+)?)", "#### pattern"),
        (r"final answer.*?([-+]?\d+(?:\.\d+)?)", "final answer pattern"),
        (r"answer is\s*[:\s]?([-+]?\d[\d,]*(?:\.\d+)?)(?=[\.\n]|$)", "answer is pattern"),
        (r"\\boxed\{([^}]+)\}", "boxed pattern"),
    ]

    for pattern_str, name in explicit_patterns:
        pattern = re.compile(pattern_str, re.I)
        matches = pattern.findall(reasoning)
        if matches:
            print(f"  {name}: {matches}")

    # Test enhanced patterns
    answer_patterns_enhanced = [
        (r"answer is[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)", "enhanced answer is"),
        (r"final answer[:\s]*([+-]?\d{1,3}(?:,\d{3})*(?:\.\d+)?)", "enhanced final answer"),
    ]

    for pattern_str, name in answer_patterns_enhanced:
        pattern = re.compile(pattern_str, re.IGNORECASE)
        matches = pattern.findall(reasoning)
        if matches:
            print(f"  {name}: {matches}")

    # Show all numbers found
    number_pattern = re.compile(r"[-+]?\d[\d,]*(?:\.\d+)?")
    all_numbers = number_pattern.findall(reasoning)
    print(f"  All numbers found: {all_numbers[:10]}...")  # Show first 10


def clean_answer(answer: str) -> str:
    """Clean answer for comparison."""
    if not answer:
        return ""

    answer = str(answer).strip()

    # Remove common prefixes
    answer = re.sub(r'^(The answer is|Answer:|Final answer:?)\s*', '', answer, flags=re.IGNORECASE)

    # Remove dollar signs and all whitespace
    answer = re.sub(r'[\$\s]+', '', answer)

    # Remove boxed formatting
    answer = re.sub(r'\\boxed\{([^}]+)\}', r'\1', answer)

    # Remove leading and trailing brackets, parentheses, and braces.
    # The character class is written with single escapes so that it matches
    # any of [ ] ( ) { } rather than a literal backslash.
    answer = re.sub(r'^[\[\](){}]+|[\[\](){}]+$', '', answer)

    # Remove trailing punctuation
    answer = re.sub(r'[.,;:!?]+$', '', answer)

    # Remove commas from numbers
    answer = answer.replace(",", "")

    # Convert to float and back to remove unnecessary decimals
    try:
        num = float(answer)
        if num == int(num):
            return str(int(num))
        else:
            return str(num)
    except (ValueError, OverflowError):
        # Non-numeric text, or "inf" (int(float("inf")) raises OverflowError)
        return answer


def is_valid_answer(answer: str) -> bool:
    """Check if an answer is valid."""
    if not answer or answer.strip() == "":
        return False

    try:
        num = float(answer)
        return -100000 <= num <= 1000000
    except ValueError:
        return False


if __name__ == "__main__":
    test_answer_extraction()