|
| 1 | +#include "Target/PythonBytecode/LinearScanRegisterAllocation.hpp" |
| 2 | +#include "Target/PythonBytecode/RegisterAllocationLogger.hpp" |
| 3 | + |
| 4 | +#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" |
| 5 | +#include "Dialect/Python/IR/Dialect.hpp" |
| 6 | +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
| 7 | +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 8 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 9 | +#include "mlir/IR/Builders.h" |
| 10 | +#include "mlir/IR/BuiltinOps.h" |
| 11 | +#include "mlir/IR/MLIRContext.h" |
| 12 | +#include "mlir/IR/OwningOpRef.h" |
| 13 | +#include "mlir/Parser/Parser.h" |
| 14 | + |
| 15 | +#include "gtest/gtest.h" |
| 16 | +#include <llvm-20/llvm/Support/raw_ostream.h> |
| 17 | +#include <mlir/IR/ValueRange.h> |
| 18 | +#include <mlir/IR/Visitors.h> |
| 19 | +#include <spdlog/spdlog.h> |
| 20 | + |
| 21 | +using namespace codegen; |
| 22 | + |
| 23 | +namespace { |
| 24 | + |
| 25 | +// MLIR IR from integration/minimal_foriter_bug.py |
| 26 | +// This is the actual IR that triggers the FOR_ITER iterator clobbering bug |
| 27 | +constexpr const char *FORITER_BUG_MLIR = R"( |
| 28 | +module attributes {llvm.argv = ["integration/minimal_foriter_bug.py"]} { |
| 29 | + func.func private @__hidden_init__() -> !python.object attributes {names = ["split"]} { |
| 30 | + %0 = "emitpybytecode.LOAD_CONST"() <{value = "\0AMinimal reproducer for the FOR_ITER iterator clobbering bug.\0A\0AThis is the minimal case that triggers the bug:\0A- FOR loop over a method call result (.split())\0A- Complex loop body with nested while loop\0A- Multiple LOAD_NAME operations for global variables\0A- Enough register pressure that the iterator register gets reused\0A\0ABug manifests as: TypeError: 'int' object is not an iterator\0A"}> : () -> !python.object |
| 31 | + %1 = "emitpybytecode.LOAD_CONST"() <{value = "L68\0AL30\0AR48"}> : () -> !python.object |
| 32 | + %2 = "emitpybytecode.STORE_NAME"(%1) <{name = "input"}> : (!python.object) -> !python.object |
| 33 | + %3 = "emitpybytecode.LOAD_CONST"() <{value = 50 : ui6}> : () -> !python.object |
| 34 | + %4 = "emitpybytecode.STORE_NAME"(%3) <{name = "position"}> : (!python.object) -> !python.object |
| 35 | + %5 = "emitpybytecode.LOAD_NAME"() <{name = "input"}> : () -> !python.object |
| 36 | + %6 = "emitpybytecode.LOAD_METHOD"(%5) <{method_name = "split"}> : (!python.object) -> !python.object |
| 37 | + %7 = "emitpybytecode.LOAD_CONST"() <{value = "\0A"}> : () -> !python.object |
| 38 | + %8 = "emitpybytecode.CALL"(%6, %7) : (!python.object, !python.object) -> !python.object |
| 39 | + %9 = "emitpybytecode.GET_ITER"(%8) : (!python.object) -> !python.object |
| 40 | + cf.br ^bb1 |
| 41 | + ^bb1: |
| 42 | + "emitpybytecode.FOR_ITER"(%9)[^bb2, ^bb19] : (!python.object) -> () |
| 43 | + ^bb2(%10: !python.object): |
| 44 | + %11 = "emitpybytecode.STORE_NAME"(%10) <{name = "line"}> : (!python.object) -> !python.object |
| 45 | + cf.br ^bb3 |
| 46 | + ^bb3: |
| 47 | + %12 = "emitpybytecode.LOAD_NAME"() <{name = "line"}> : () -> !python.object |
| 48 | + %13 = "emitpybytecode.LOAD_CONST"() <{value = 0 : ui1}> : () -> !python.object |
| 49 | + %14 = "emitpybytecode.BINARY_SUBSCRIPT"(%12, %13) : (!python.object, !python.object) -> !python.object |
| 50 | + %15 = "emitpybytecode.LOAD_CONST"() <{value = "L"}> : () -> !python.object |
| 51 | + %16 = "emitpybytecode.COMPARE"(%14, %15) <{predicate = 0 : ui8}> : (!python.object, !python.object) -> !python.object |
| 52 | + %17 = "emitpybytecode.TO_BOOL"(%16) : (!python.object) -> !python.object |
| 53 | + emitpybytecode.JUMP_IF_FALSE %16, ^bb4, ^bb5 |
| 54 | + ^bb4: |
| 55 | + %18 = "emitpybytecode.LOAD_NAME"() <{name = "int"}> : () -> !python.object |
| 56 | + %19 = "emitpybytecode.LOAD_NAME"() <{name = "line"}> : () -> !python.object |
| 57 | + %20 = "emitpybytecode.LOAD_CONST"() <{value = 1 : ui1}> : () -> !python.object |
| 58 | + %21 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object |
| 59 | + %22 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object |
| 60 | + %23 = "emitpybytecode.BUILD_SLICE"(%20, %21, %22) : (!python.object, !python.object, !python.object) -> !python.object |
| 61 | + %24 = "emitpybytecode.BINARY_SUBSCRIPT"(%19, %23) : (!python.object, !python.object) -> !python.object |
| 62 | + %25 = "emitpybytecode.CALL"(%18, %24) : (!python.object, !python.object) -> !python.object |
| 63 | + %26 = "emitpybytecode.UNARY"(%25) <{operation_type = 1 : ui8}> : (!python.object) -> !python.object |
| 64 | + %27 = "emitpybytecode.STORE_NAME"(%26) <{name = "move_by"}> : (!python.object) -> !python.object |
| 65 | + cf.br ^bb6 |
| 66 | + ^bb5: |
| 67 | + %28 = "emitpybytecode.LOAD_NAME"() <{name = "int"}> : () -> !python.object |
| 68 | + %29 = "emitpybytecode.LOAD_NAME"() <{name = "line"}> : () -> !python.object |
| 69 | + %30 = "emitpybytecode.LOAD_CONST"() <{value = 1 : ui1}> : () -> !python.object |
| 70 | + %31 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object |
| 71 | + %32 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object |
| 72 | + %33 = "emitpybytecode.BUILD_SLICE"(%30, %31, %32) : (!python.object, !python.object, !python.object) -> !python.object |
| 73 | + %34 = "emitpybytecode.BINARY_SUBSCRIPT"(%29, %33) : (!python.object, !python.object) -> !python.object |
| 74 | + %35 = "emitpybytecode.CALL"(%28, %34) : (!python.object, !python.object) -> !python.object |
| 75 | + %36 = "emitpybytecode.STORE_NAME"(%35) <{name = "move_by"}> : (!python.object) -> !python.object |
| 76 | + cf.br ^bb6 |
| 77 | + ^bb6: |
| 78 | + %37 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object |
| 79 | + %38 = "emitpybytecode.LOAD_NAME"() <{name = "move_by"}> : () -> !python.object |
| 80 | + %39 = "emitpybytecode.INPLACE_OP"(%37, %38) <{operation_type = 0 : ui8}> : (!python.object, !python.object) -> !python.object |
| 81 | + %40 = "emitpybytecode.STORE_NAME"(%37) <{name = "position"}> : (!python.object) -> !python.object |
| 82 | + cf.br ^bb7 |
| 83 | + ^bb7: |
| 84 | + cf.br ^bb8 |
| 85 | + ^bb8: |
| 86 | + %41 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object |
| 87 | + %42 = "emitpybytecode.LOAD_CONST"() <{value = 0 : ui1}> : () -> !python.object |
| 88 | + %43 = "emitpybytecode.COMPARE"(%41, %42) <{predicate = 2 : ui8}> : (!python.object, !python.object) -> !python.object |
| 89 | + %44 = "emitpybytecode.TO_BOOL"(%43) : (!python.object) -> !python.object |
| 90 | + emitpybytecode.JUMP_IF_FALSE %43, ^bb10(%43 : !python.object), ^bb9 |
| 91 | + ^bb9: |
| 92 | + %45 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object |
| 93 | + %46 = "emitpybytecode.LOAD_CONST"() <{value = 99 : ui7}> : () -> !python.object |
| 94 | + %47 = "emitpybytecode.COMPARE"(%45, %46) <{predicate = 4 : ui8}> : (!python.object, !python.object) -> !python.object |
| 95 | + cf.br ^bb10(%47 : !python.object) |
| 96 | + ^bb10(%48: !python.object): |
| 97 | + %49 = "emitpybytecode.TO_BOOL"(%48) : (!python.object) -> !python.object |
| 98 | + emitpybytecode.JUMP_IF_FALSE %48, ^bb11, ^bb17 |
| 99 | + ^bb11: |
| 100 | + %50 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object |
| 101 | + %51 = "emitpybytecode.LOAD_CONST"() <{value = 0 : ui1}> : () -> !python.object |
| 102 | + %52 = "emitpybytecode.COMPARE"(%50, %51) <{predicate = 2 : ui8}> : (!python.object, !python.object) -> !python.object |
| 103 | + %53 = "emitpybytecode.TO_BOOL"(%52) : (!python.object) -> !python.object |
| 104 | + emitpybytecode.JUMP_IF_FALSE %52, ^bb12, ^bb13 |
| 105 | + ^bb12: |
| 106 | + %54 = "emitpybytecode.LOAD_CONST"() <{value = 100 : ui7}> : () -> !python.object |
| 107 | + %55 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object |
| 108 | + %56 = "emitpybytecode.BINARY_OP"(%54, %55) <{operation_type = 0 : ui8}> : (!python.object, !python.object) -> !python.object |
| 109 | + %57 = "emitpybytecode.STORE_NAME"(%56) <{name = "position"}> : (!python.object) -> !python.object |
| 110 | + cf.br ^bb14 |
| 111 | + ^bb13: |
| 112 | + %58 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object |
| 113 | + %59 = "emitpybytecode.LOAD_CONST"() <{value = 99 : ui7}> : () -> !python.object |
| 114 | + %60 = "emitpybytecode.COMPARE"(%58, %59) <{predicate = 4 : ui8}> : (!python.object, !python.object) -> !python.object |
| 115 | + %61 = "emitpybytecode.TO_BOOL"(%60) : (!python.object) -> !python.object |
| 116 | + emitpybytecode.JUMP_IF_FALSE %60, ^bb15, ^bb16 |
| 117 | + ^bb14: |
| 118 | + cf.br ^bb7 |
| 119 | + ^bb15: |
| 120 | + %62 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object |
| 121 | + %63 = "emitpybytecode.LOAD_CONST"() <{value = 100 : ui7}> : () -> !python.object |
| 122 | + %64 = "emitpybytecode.INPLACE_OP"(%62, %63) <{operation_type = 1 : ui8}> : (!python.object, !python.object) -> !python.object |
| 123 | + %65 = "emitpybytecode.STORE_NAME"(%62) <{name = "position"}> : (!python.object) -> !python.object |
| 124 | + cf.br ^bb16 |
| 125 | + ^bb16: |
| 126 | + cf.br ^bb14 |
| 127 | + ^bb17: |
| 128 | + cf.br ^bb18 |
| 129 | + ^bb18: |
| 130 | + "emitpybytecode.FOR_ITER"(%9)[^bb2, ^bb19] : (!python.object) -> () |
| 131 | + ^bb19: |
| 132 | + %66 = "emitpybytecode.LOAD_NAME"() <{name = "print"}> : () -> !python.object |
| 133 | + %67 = "emitpybytecode.LOAD_NAME"() <{name = "position"}> : () -> !python.object |
| 134 | + %68 = "emitpybytecode.CALL"(%66, %67) : (!python.object, !python.object) -> !python.object |
| 135 | + cf.br ^bb20 |
| 136 | + ^bb20: |
| 137 | + %69 = "emitpybytecode.LOAD_CONST"() <{value}> : () -> !python.object |
| 138 | + return %69 : !python.object |
| 139 | + } |
| 140 | +} |
| 141 | +)"; |
| 142 | + |
| 143 | +class RegisterAllocationTest : public ::testing::Test |
| 144 | +{ |
| 145 | + mlir::MLIRContext m_context; |
| 146 | + |
| 147 | + protected: |
| 148 | + void SetUp() override |
| 149 | + { |
| 150 | + // Load required dialects |
| 151 | + m_context.getOrLoadDialect<mlir::func::FuncDialect>(); |
| 152 | + m_context.getOrLoadDialect<mlir::cf::ControlFlowDialect>(); |
| 153 | + m_context.getOrLoadDialect<mlir::emitpybytecode::EmitPythonBytecodeDialect>(); |
| 154 | + m_context.getOrLoadDialect<mlir::py::PythonDialect>(); |
| 155 | + |
| 156 | + // Disable debug logging for tests unless debugging |
| 157 | + auto logger = get_regalloc_logger(); |
| 158 | + logger->set_level(spdlog::level::warn); |
| 159 | + } |
| 160 | + |
| 161 | + // Parse the MLIR IR that reproduces the FOR_ITER bug |
| 162 | + mlir::OwningOpRef<mlir::ModuleOp> parseForIterBugIR() |
| 163 | + { return mlir::parseSourceString<mlir::ModuleOp>(FORITER_BUG_MLIR, &m_context); } |
| 164 | +}; |
| 165 | + |
| 166 | +/** |
| 167 | + * This test demonstrates the FOR_ITER iterator clobbering bug. |
| 168 | + * |
| 169 | + * The bug: When register allocation processes a FOR loop, it allocates the iterator |
| 170 | + * to some register (e.g., r2). Inside the loop body, when loading global variables |
| 171 | + * (LOAD_NAME operations), the register allocator may reuse the iterator's register |
| 172 | + * because it doesn't properly track that the iterator must stay alive for the entire |
| 173 | + * loop duration. |
| 174 | + * |
| 175 | + * This test EXPECTS TO FAIL until the bug is fixed. When the bug is present, at least |
| 176 | + * one LOAD_NAME operation inside the loop will be assigned the same register as the |
| 177 | + * iterator, causing the iterator to be clobbered. |
| 178 | + */ |
| 179 | +TEST_F(RegisterAllocationTest, ForIterIteratorRegisterNotReusedInLoopBody) |
| 180 | +{ |
| 181 | + // Parse the real MLIR IR from minimal_foriter_bug.py |
| 182 | + auto module = parseForIterBugIR(); |
| 183 | + ASSERT_TRUE(module) << "Failed to parse MLIR IR"; |
| 184 | + |
| 185 | + // Get the function |
| 186 | + auto funcs = module->getOps<mlir::func::FuncOp>(); |
| 187 | + ASSERT_FALSE(funcs.empty()) << "No functions found in module"; |
| 188 | + auto func = *funcs.begin(); |
| 189 | + |
| 190 | + // Run register allocation |
| 191 | + mlir::OpBuilder builder(func->getContext()); |
| 192 | + LinearScanRegisterAllocation regalloc; |
| 193 | + regalloc.analyse(func, builder); |
| 194 | + |
| 195 | + // Find the iterator (GET_ITER result - this is %9 in the IR) |
| 196 | + mlir::Value iterator; |
| 197 | + func.walk([&](mlir::emitpybytecode::GetIter op) { |
| 198 | + iterator = op.getResult(); |
| 199 | + return mlir::WalkResult::interrupt(); |
| 200 | + }); |
| 201 | + ASSERT_TRUE(iterator) << "Failed to find GET_ITER operation"; |
| 202 | + |
| 203 | + // Get the register assigned to the iterator |
| 204 | + auto iteratorLoc = regalloc.value2mem_map.find(iterator); |
| 205 | + ASSERT_NE(iteratorLoc, regalloc.value2mem_map.end()) << "Iterator has no register allocation"; |
| 206 | + ASSERT_TRUE(std::holds_alternative<LinearScanRegisterAllocation::Reg>(iteratorLoc->second)) |
| 207 | + << "Iterator is not allocated to a register"; |
| 208 | + auto iteratorReg = std::get<LinearScanRegisterAllocation::Reg>(iteratorLoc->second).idx; |
| 209 | + |
| 210 | + auto for_iter = mlir::dyn_cast<mlir::emitpybytecode::ForIter>(*iterator.getUsers().begin()); |
| 211 | + ASSERT_TRUE(for_iter); |
| 212 | + |
| 213 | + auto loop_body = for_iter.getBody(); |
| 214 | + auto loop_exit = for_iter.getContinuation(); |
| 215 | + |
| 216 | + // Collect all blocks that are part of the loop (reachable from loop body but not the exit) |
| 217 | + llvm::SmallPtrSet<mlir::Block *, 16> loopBlocks; |
| 218 | + llvm::SmallVector<mlir::Block *, 8> worklist; |
| 219 | + worklist.push_back(loop_body); |
| 220 | + |
| 221 | + while (!worklist.empty()) { |
| 222 | + auto *block = worklist.pop_back_val(); |
| 223 | + if (block == loop_exit || loopBlocks.contains(block)) { |
| 224 | + continue; |
| 225 | + } |
| 226 | + loopBlocks.insert(block); |
| 227 | + |
| 228 | + // Add successors to worklist |
| 229 | + for (auto *successor : block->getSuccessors()) { |
| 230 | + if (successor != loop_exit && !loopBlocks.contains(successor)) { |
| 231 | + worklist.push_back(successor); |
| 232 | + } |
| 233 | + } |
| 234 | + } |
| 235 | + |
| 236 | + // Find all LOAD_NAME operations inside the FOR loop body only |
| 237 | + // The loop body includes all blocks between the entry and exit |
| 238 | + // These should NOT reuse the iterator register since the iterator is still alive |
| 239 | + llvm::SmallVector<mlir::Value, 8> loopBodyLoadNames; |
| 240 | + |
| 241 | + func.walk([&](mlir::emitpybytecode::LoadNameOp op) { |
| 242 | + // Only collect LOAD_NAME operations that are in loop body blocks |
| 243 | + if (loopBlocks.contains(op->getBlock())) { |
| 244 | + loopBodyLoadNames.push_back(op.getResult()); |
| 245 | + } |
| 246 | + }); |
| 247 | + |
| 248 | + ASSERT_FALSE(loopBodyLoadNames.empty()) << "No LOAD_NAME operations found in loop body"; |
| 249 | + |
| 250 | + // Check that NONE of the loop body LOAD_NAME operations reuse the iterator register |
| 251 | + // THIS IS THE BUG TEST: Currently this WILL fail because the register allocator |
| 252 | + // reuses the iterator register for LOAD_NAME operations |
| 253 | + int clobberedCount = 0; |
| 254 | + for (auto val : loopBodyLoadNames) { |
| 255 | + auto valLoc = regalloc.value2mem_map.find(val); |
| 256 | + if (valLoc != regalloc.value2mem_map.end() |
| 257 | + && std::holds_alternative<LinearScanRegisterAllocation::Reg>(valLoc->second)) { |
| 258 | + auto valReg = std::get<LinearScanRegisterAllocation::Reg>(valLoc->second).idx; |
| 259 | + |
| 260 | + if (valReg == iteratorReg) { |
| 261 | + clobberedCount++; |
| 262 | + llvm::outs() << val << " clobbers iterator\n"; |
| 263 | + } |
| 264 | + } |
| 265 | + } |
| 266 | + |
| 267 | + // THIS EXPECTATION WILL FAIL when the bug is present |
| 268 | + // When fixed, clobberedCount should be 0 |
| 269 | + EXPECT_EQ(clobberedCount, 0) << "BUG DETECTED: " << clobberedCount |
| 270 | + << " LOAD_NAME operation(s) reuse the iterator register r" |
| 271 | + << iteratorReg << " - the iterator will be clobbered!"; |
| 272 | +} |
| 273 | + |
| 274 | +/** |
| 275 | + * Smoke test to verify register allocation runs without crashing |
| 276 | + */ |
| 277 | +TEST_F(RegisterAllocationTest, RegisterAllocationRunsWithoutCrashing) |
| 278 | +{ |
| 279 | + auto module = parseForIterBugIR(); |
| 280 | + ASSERT_TRUE(module); |
| 281 | + |
| 282 | + auto funcs = module->getOps<mlir::func::FuncOp>(); |
| 283 | + ASSERT_FALSE(funcs.empty()); |
| 284 | + auto func = *funcs.begin(); |
| 285 | + |
| 286 | + mlir::OpBuilder builder(module->getContext()); |
| 287 | + LinearScanRegisterAllocation regalloc; |
| 288 | + |
| 289 | + // Should not crash |
| 290 | + EXPECT_NO_THROW(regalloc.analyse(func, builder)); |
| 291 | + |
| 292 | + // Should have allocated registers for some values |
| 293 | + EXPECT_FALSE(regalloc.value2mem_map.empty()); |
| 294 | +} |
| 295 | + |
| 296 | +}// namespace |
0 commit comments