Skip to content

Commit 325d5af

Browse files
committed
mlir: test register allocation
1 parent e0633e6 commit 325d5af

File tree

3 files changed

+305
-0
lines changed

3 files changed

+305
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(PythonBytecode)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
include(GoogleTest)
2+
3+
add_executable(python_bytecode_tests
4+
RegisterAllocation_tests.cpp)
5+
target_link_libraries(python_bytecode_tests PRIVATE
6+
TargetPythonBytecode
7+
gtest_main)
8+
gtest_discover_tests(python_bytecode_tests)
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
#include "Target/PythonBytecode/LinearScanRegisterAllocation.hpp"
2+
#include "Target/PythonBytecode/RegisterAllocationLogger.hpp"
3+
4+
#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp"
5+
#include "Dialect/Python/IR/Dialect.hpp"
6+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
7+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
8+
#include "mlir/Dialect/Func/IR/FuncOps.h"
9+
#include "mlir/IR/Builders.h"
10+
#include "mlir/IR/BuiltinOps.h"
11+
#include "mlir/IR/MLIRContext.h"
12+
#include "mlir/IR/OwningOpRef.h"
13+
#include "mlir/Parser/Parser.h"
14+
15+
#include "gtest/gtest.h"
16+
#include <llvm-20/llvm/Support/raw_ostream.h>
17+
#include <mlir/IR/ValueRange.h>
18+
#include <mlir/IR/Visitors.h>
19+
#include <spdlog/spdlog.h>
20+
21+
using namespace codegen;
22+
23+
namespace {
24+
25+
// MLIR IR from integration/minimal_foriter_bug.py
26+
// This is the actual IR that triggers the FOR_ITER iterator clobbering bug
27+
constexpr const char *FORITER_BUG_MLIR = R"(
28+
module attributes {llvm.argv = ["integration/minimal_foriter_bug.py"]} {
29+
func.func private @__hidden_init__() -> !python.object attributes {names = ["split"]} {
30+
%0 = "emitpybytecode.LOAD_CONST"() <{value = "\0AMinimal reproducer for the FOR_ITER iterator clobbering bug.\0A\0AThis is the minimal case that triggers the bug:\0A- FOR loop over a method call result (.split())\0A- Complex loop body with nested while loop\0A- Multiple LOAD_NAME operations for global variables\0A- Enough register pressure that the iterator register gets reused\0A\0ABug manifests as: TypeError: 'int' object is not an iterator\0A"}> : () -> !python.object
31+
%1 = "emitpybytecode.LOAD_CONST"() <{value = "L68\0AL30\0AR48"}> : () -> !python.object
32+
%2 = "emitpybytecode.STORE_NAME"(%1) <{name = "input"}> : (!python.object) -> !python.object
33+
%3 = "emitpybytecode.LOAD_CONST"() <{value = 50 : ui6}> : () -> !python.object
34+
%4 = "emitpybytecode.STORE_NAME"(%3) <{name = "position"}> : (!python.object) -> !python.object
35+
%5 = "emitpybytecode.LOAD_NAME"() <{name = "input"}> : () -> !python.object
36+
%6 = "emitpybytecode.LOAD_METHOD"(%5) <{method_name = "split"}> : (!python.object) -> !python.object
37+
%7 = "emitpybytecode.LOAD_CONST"() <{value = "\0A"}> : () -> !python.object
38+
%8 = "emitpybytecode.CALL"(%6, %7) : (!python.object, !python.object) -> !python.object
39+
%9 = "emitpybytecode.GET_ITER"(%8) : (!python.object) -> !python.object
40+
cf.br ^bb1
41+
^bb1:
42+
"emitpybytecode.FOR_ITER"(%9)[^bb2, ^bb19] : (!python.object) -> ()
43+
^bb2(%10: !python.object):
44+
%11 = "emitpybytecode.STORE_NAME"(%10) <{name = "line"}> : (!python.object) -> !python.object
45+
cf.br ^bb3
46+
^bb3:
47+
%12 = "emitpybytecode.LOAD_NAME"() <{name = "line"}> : () -> !python.object
48+
%13 = "emitpybytecode.LOAD_CONST"() <{value = 0 : ui1}> : () -> !python.object
49+
%14 = "emitpybytecode.BINARY_SUBSCRIPT"(%12, %13) : (!python.object, !python.object) -> !python.object
50+
%15 = "emitpybytecode.LOAD_CONST"() <{value = "L"}> : () -> !python.object
51+
%16 = "emitpybytecode.COMPARE"(%14, %15) <{predicate = 0 : ui8}> : (!python.object, !python.object) -> !python.object
52+
%17 = "emitpybytecode.TO_BOOL"(%16) : (!python.object) -> !python.object
53+
emitpybytecode.JUMP_IF_FALSE %16, ^bb4, ^bb5
54+
^bb4:
55+
%18 = "emitpybytecode.LOAD_NAME"() <{name = "int"}> : () -> !python.object
56+
%19 = "emitpybytecode.LOAD_NAME"() <{name = "line"}> : () -> !python.object
57+
%20 = "emitpybytecode.LOAD_CONST"() <{value = 1 : ui1}> : () -> !python.object
58+
%21 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object
59+
%22 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object
60+
%23 = "emitpybytecode.BUILD_SLICE"(%20, %21, %22) : (!python.object, !python.object, !python.object) -> !python.object
61+
%24 = "emitpybytecode.BINARY_SUBSCRIPT"(%19, %23) : (!python.object, !python.object) -> !python.object
62+
%25 = "emitpybytecode.CALL"(%18, %24) : (!python.object, !python.object) -> !python.object
63+
%26 = "emitpybytecode.UNARY"(%25) <{operation_type = 1 : ui8}> : (!python.object) -> !python.object
64+
%27 = "emitpybytecode.STORE_NAME"(%26) <{name = "move_by"}> : (!python.object) -> !python.object
65+
cf.br ^bb6
66+
^bb5:
67+
%28 = "emitpybytecode.LOAD_NAME"() <{name = "int"}> : () -> !python.object
68+
%29 = "emitpybytecode.LOAD_NAME"() <{name = "line"}> : () -> !python.object
69+
%30 = "emitpybytecode.LOAD_CONST"() <{value = 1 : ui1}> : () -> !python.object
70+
%31 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object
71+
%32 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object
72+
%33 = "emitpybytecode.BUILD_SLICE"(%30, %31, %32) : (!python.object, !python.object, !python.object) -> !python.object
73+
%34 = "emitpybytecode.BINARY_SUBSCRIPT"(%29, %33) : (!python.object, !python.object) -> !python.object
74+
%35 = "emitpybytecode.CALL"(%28, %34) : (!python.object, !python.object) -> !python.object
75+
%36 = "emitpybytecode.STORE_NAME"(%35) <{name = "move_by"}> : (!python.object) -> !python.object
76+
cf.br ^bb6
77+
^bb6:
78+
%37 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object
79+
%38 = "emitpybytecode.LOAD_NAME"() <{name = "move_by"}> : () -> !python.object
80+
%39 = "emitpybytecode.INPLACE_OP"(%37, %38) <{operation_type = 0 : ui8}> : (!python.object, !python.object) -> !python.object
81+
%40 = "emitpybytecode.STORE_NAME"(%37) <{name = "position"}> : (!python.object) -> !python.object
82+
cf.br ^bb7
83+
^bb7:
84+
cf.br ^bb8
85+
^bb8:
86+
%41 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object
87+
%42 = "emitpybytecode.LOAD_CONST"() <{value = 0 : ui1}> : () -> !python.object
88+
%43 = "emitpybytecode.COMPARE"(%41, %42) <{predicate = 2 : ui8}> : (!python.object, !python.object) -> !python.object
89+
%44 = "emitpybytecode.TO_BOOL"(%43) : (!python.object) -> !python.object
90+
emitpybytecode.JUMP_IF_FALSE %43, ^bb10(%43 : !python.object), ^bb9
91+
^bb9:
92+
%45 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object
93+
%46 = "emitpybytecode.LOAD_CONST"() <{value = 99 : ui7}> : () -> !python.object
94+
%47 = "emitpybytecode.COMPARE"(%45, %46) <{predicate = 4 : ui8}> : (!python.object, !python.object) -> !python.object
95+
cf.br ^bb10(%47 : !python.object)
96+
^bb10(%48: !python.object):
97+
%49 = "emitpybytecode.TO_BOOL"(%48) : (!python.object) -> !python.object
98+
emitpybytecode.JUMP_IF_FALSE %48, ^bb11, ^bb17
99+
^bb11:
100+
%50 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object
101+
%51 = "emitpybytecode.LOAD_CONST"() <{value = 0 : ui1}> : () -> !python.object
102+
%52 = "emitpybytecode.COMPARE"(%50, %51) <{predicate = 2 : ui8}> : (!python.object, !python.object) -> !python.object
103+
%53 = "emitpybytecode.TO_BOOL"(%52) : (!python.object) -> !python.object
104+
emitpybytecode.JUMP_IF_FALSE %52, ^bb12, ^bb13
105+
^bb12:
106+
%54 = "emitpybytecode.LOAD_CONST"() <{value = 100 : ui7}> : () -> !python.object
107+
%55 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object
108+
%56 = "emitpybytecode.BINARY_OP"(%54, %55) <{operation_type = 0 : ui8}> : (!python.object, !python.object) -> !python.object
109+
%57 = "emitpybytecode.STORE_NAME"(%56) <{name = "position"}> : (!python.object) -> !python.object
110+
cf.br ^bb14
111+
^bb13:
112+
%58 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object
113+
%59 = "emitpybytecode.LOAD_CONST"() <{value = 99 : ui7}> : () -> !python.object
114+
%60 = "emitpybytecode.COMPARE"(%58, %59) <{predicate = 4 : ui8}> : (!python.object, !python.object) -> !python.object
115+
%61 = "emitpybytecode.TO_BOOL"(%60) : (!python.object) -> !python.object
116+
emitpybytecode.JUMP_IF_FALSE %60, ^bb15, ^bb16
117+
^bb14:
118+
cf.br ^bb7
119+
^bb15:
120+
%62 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object
121+
%63 = "emitpybytecode.LOAD_CONST"() <{value = 100 : ui7}> : () -> !python.object
122+
%64 = "emitpybytecode.INPLACE_OP"(%62, %63) <{operation_type = 1 : ui8}> : (!python.object, !python.object) -> !python.object
123+
%65 = "emitpybytecode.STORE_NAME"(%62) <{name = "position"}> : (!python.object) -> !python.object
124+
cf.br ^bb16
125+
^bb16:
126+
cf.br ^bb14
127+
^bb17:
128+
cf.br ^bb18
129+
^bb18:
130+
"emitpybytecode.FOR_ITER"(%9)[^bb2, ^bb19] : (!python.object) -> ()
131+
^bb19:
132+
%66 = "emitpybytecode.LOAD_NAME"() <{name = "print"}> : () -> !python.object
133+
%67 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object
134+
%68 = "emitpybytecode.CALL"(%66, %67) : (!python.object, !python.object) -> !python.object
135+
cf.br ^bb20
136+
^bb20:
137+
%69 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object
138+
return %69 : !python.object
139+
}
140+
}
141+
)";
142+
143+
class RegisterAllocationTest : public ::testing::Test
144+
{
145+
mlir::MLIRContext m_context;
146+
147+
protected:
148+
void SetUp() override
149+
{
150+
// Load required dialects
151+
m_context.getOrLoadDialect<mlir::func::FuncDialect>();
152+
m_context.getOrLoadDialect<mlir::cf::ControlFlowDialect>();
153+
m_context.getOrLoadDialect<mlir::emitpybytecode::EmitPythonBytecodeDialect>();
154+
m_context.getOrLoadDialect<mlir::py::PythonDialect>();
155+
156+
// Disable debug logging for tests unless debugging
157+
auto logger = get_regalloc_logger();
158+
logger->set_level(spdlog::level::warn);
159+
}
160+
161+
// Parse the MLIR IR that reproduces the FOR_ITER bug
162+
mlir::OwningOpRef<mlir::ModuleOp> parseForIterBugIR()
163+
{ return mlir::parseSourceString<mlir::ModuleOp>(FORITER_BUG_MLIR, &m_context); }
164+
};
165+
166+
/**
167+
* This test demonstrates the FOR_ITER iterator clobbering bug.
168+
*
169+
* The bug: When register allocation processes a FOR loop, it allocates the iterator
170+
* to some register (e.g., r2). Inside the loop body, when loading global variables
171+
* (LOAD_NAME operations), the register allocator may reuse the iterator's register
172+
* because it doesn't properly track that the iterator must stay alive for the entire
173+
* loop duration.
174+
*
175+
* This test EXPECTS TO FAIL until the bug is fixed. When the bug is present, at least
176+
* one LOAD_NAME operation inside the loop will be assigned the same register as the
177+
* iterator, causing the iterator to be clobbered.
178+
*/
179+
TEST_F(RegisterAllocationTest, ForIterIteratorRegisterNotReusedInLoopBody)
180+
{
181+
// Parse the real MLIR IR from minimal_foriter_bug.py
182+
auto module = parseForIterBugIR();
183+
ASSERT_TRUE(module) << "Failed to parse MLIR IR";
184+
185+
// Get the function
186+
auto funcs = module->getOps<mlir::func::FuncOp>();
187+
ASSERT_FALSE(funcs.empty()) << "No functions found in module";
188+
auto func = *funcs.begin();
189+
190+
// Run register allocation
191+
mlir::OpBuilder builder(func->getContext());
192+
LinearScanRegisterAllocation regalloc;
193+
regalloc.analyse(func, builder);
194+
195+
// Find the iterator (GET_ITER result - this is %9 in the IR)
196+
mlir::Value iterator;
197+
func.walk([&](mlir::emitpybytecode::GetIter op) {
198+
iterator = op.getResult();
199+
return mlir::WalkResult::interrupt();
200+
});
201+
ASSERT_TRUE(iterator) << "Failed to find GET_ITER operation";
202+
203+
// Get the register assigned to the iterator
204+
auto iteratorLoc = regalloc.value2mem_map.find(iterator);
205+
ASSERT_NE(iteratorLoc, regalloc.value2mem_map.end()) << "Iterator has no register allocation";
206+
ASSERT_TRUE(std::holds_alternative<LinearScanRegisterAllocation::Reg>(iteratorLoc->second))
207+
<< "Iterator is not allocated to a register";
208+
auto iteratorReg = std::get<LinearScanRegisterAllocation::Reg>(iteratorLoc->second).idx;
209+
210+
auto for_iter = mlir::dyn_cast<mlir::emitpybytecode::ForIter>(*iterator.getUsers().begin());
211+
ASSERT_TRUE(for_iter);
212+
213+
auto loop_body = for_iter.getBody();
214+
auto loop_exit = for_iter.getContinuation();
215+
216+
// Collect all blocks that are part of the loop (reachable from loop body but not the exit)
217+
llvm::SmallPtrSet<mlir::Block *, 16> loopBlocks;
218+
llvm::SmallVector<mlir::Block *, 8> worklist;
219+
worklist.push_back(loop_body);
220+
221+
while (!worklist.empty()) {
222+
auto *block = worklist.pop_back_val();
223+
if (block == loop_exit || loopBlocks.contains(block)) {
224+
continue;
225+
}
226+
loopBlocks.insert(block);
227+
228+
// Add successors to worklist
229+
for (auto *successor : block->getSuccessors()) {
230+
if (successor != loop_exit && !loopBlocks.contains(successor)) {
231+
worklist.push_back(successor);
232+
}
233+
}
234+
}
235+
236+
// Find all LOAD_NAME operations inside the FOR loop body only
237+
// The loop body includes all blocks between the entry and exit
238+
// These should NOT reuse the iterator register since the iterator is still alive
239+
llvm::SmallVector<mlir::Value, 8> loopBodyLoadNames;
240+
241+
func.walk([&](mlir::emitpybytecode::LoadNameOp op) {
242+
// Only collect LOAD_NAME operations that are in loop body blocks
243+
if (loopBlocks.contains(op->getBlock())) {
244+
loopBodyLoadNames.push_back(op.getResult());
245+
}
246+
});
247+
248+
ASSERT_FALSE(loopBodyLoadNames.empty()) << "No LOAD_NAME operations found in loop body";
249+
250+
// Check that NONE of the loop body LOAD_NAME operations reuse the iterator register
251+
// THIS IS THE BUG TEST: Currently this WILL fail because the register allocator
252+
// reuses the iterator register for LOAD_NAME operations
253+
int clobberedCount = 0;
254+
for (auto val : loopBodyLoadNames) {
255+
auto valLoc = regalloc.value2mem_map.find(val);
256+
if (valLoc != regalloc.value2mem_map.end()
257+
&& std::holds_alternative<LinearScanRegisterAllocation::Reg>(valLoc->second)) {
258+
auto valReg = std::get<LinearScanRegisterAllocation::Reg>(valLoc->second).idx;
259+
260+
if (valReg == iteratorReg) {
261+
clobberedCount++;
262+
llvm::outs() << val << " clobbers iterator\n";
263+
}
264+
}
265+
}
266+
267+
// THIS EXPECTATION WILL FAIL when the bug is present
268+
// When fixed, clobberedCount should be 0
269+
EXPECT_EQ(clobberedCount, 0) << "BUG DETECTED: " << clobberedCount
270+
<< " LOAD_NAME operation(s) reuse the iterator register r"
271+
<< iteratorReg << " - the iterator will be clobbered!";
272+
}
273+
274+
/**
275+
* Smoke test to verify register allocation runs without crashing
276+
*/
277+
TEST_F(RegisterAllocationTest, RegisterAllocationRunsWithoutCrashing)
278+
{
279+
auto module = parseForIterBugIR();
280+
ASSERT_TRUE(module);
281+
282+
auto funcs = module->getOps<mlir::func::FuncOp>();
283+
ASSERT_FALSE(funcs.empty());
284+
auto func = *funcs.begin();
285+
286+
mlir::OpBuilder builder(module->getContext());
287+
LinearScanRegisterAllocation regalloc;
288+
289+
// Should not crash
290+
EXPECT_NO_THROW(regalloc.analyse(func, builder));
291+
292+
// Should have allocated registers for some values
293+
EXPECT_FALSE(regalloc.value2mem_map.empty());
294+
}
295+
296+
}// namespace

0 commit comments

Comments
 (0)