# From 7f021846eb84d21b56cc5066f859c4cd265f2d05 Mon Sep 17 00:00:00 2001
# From: Eli
# Date: Thu, 12 Jun 2025 20:34:21 -0400
# Subject: [PATCH 001/106] vibe-coded generator expression bytecode disassembler
# ---
#  effectful/internals/genexpr.py     | 1216 ++++++++++++++++++++++++++++
#  tests/test_ops_syntax_generator.py |  573 +++++++++++++
#  2 files changed, 1789 insertions(+)
#  create mode 100644 effectful/internals/genexpr.py
#  create mode 100644 tests/test_ops_syntax_generator.py
#
# diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py
# new file mode 100644
# index 00000000..e79fba34
# --- /dev/null
# +++ b/effectful/internals/genexpr.py
# @@ -0,0 +1,1216 @@
import ast
import dis
import functools
import inspect
import types
import typing
from types import GeneratorType, FunctionType
from typing import Callable, Any, List, Dict, Iterator, Optional, Union
from dataclasses import dataclass, field, replace


# Opcode names grouped by the role they play in a comprehension's bytecode.
CATEGORIES = {
    "Core Generator": {"GEN_START", "YIELD_VALUE", "RETURN_VALUE"},
    "Loop Control": {
        "GET_ITER",
        "FOR_ITER",
        "JUMP_ABSOLUTE",
        "JUMP_FORWARD",
        "POP_JUMP_IF_FALSE",
        "POP_JUMP_IF_TRUE",
    },
    "Variable Operations": {
        "LOAD_FAST",
        "STORE_FAST",
        "LOAD_GLOBAL",
        "LOAD_DEREF",
        "STORE_DEREF",
        "LOAD_CONST",
        "LOAD_NAME",
        "STORE_NAME",
    },
    "Arithmetic/Logic": {
        "BINARY_ADD",
        "BINARY_SUBTRACT",
        "BINARY_MULTIPLY",
        "BINARY_TRUE_DIVIDE",
        "BINARY_FLOOR_DIVIDE",
        "BINARY_MODULO",
        "BINARY_POWER",
        "BINARY_LSHIFT",
        "BINARY_RSHIFT",
        "BINARY_OR",
        "BINARY_XOR",
        "BINARY_AND",
        "UNARY_POSITIVE",
        "UNARY_NEGATIVE",
        "UNARY_NOT",
        "UNARY_INVERT",
    },
    "Comparisons": {"COMPARE_OP"},
    "Object Access": {
        "LOAD_ATTR",
        "BINARY_SUBSCR",
        "BUILD_SLICE",
        "STORE_ATTR",
        "STORE_SUBSCR",
        "DELETE_SUBSCR",
    },
    "Function Calls": {
        "CALL_FUNCTION",
        "CALL_FUNCTION_KW",
        "CALL_FUNCTION_EX",
        "CALL_METHOD",
        "LOAD_METHOD",
        "CALL",
        "PRECALL",
    },
    "Container Building": {
        "BUILD_TUPLE",
        "BUILD_LIST",
        "BUILD_SET",
        "BUILD_MAP",
        "BUILD_STRING",
        "FORMAT_VALUE",
        "LIST_APPEND",
        "SET_ADD",
        "MAP_ADD",
        "BUILD_CONST_KEY_MAP",
    },
    "Stack Management": {
        "POP_TOP",
        "DUP_TOP",
        "ROT_TWO",
        "ROT_THREE",
        "ROT_FOUR",
        "COPY",
        "SWAP",
    },
    "Unpacking": {"UNPACK_SEQUENCE", "UNPACK_EX"},
    "Other": {"NOP", "EXTENDED_ARG", "CACHE", "RESUME", "MAKE_CELL"},
}

# Reverse index of CATEGORIES: opcode name -> name of the category listing it.
OP_CATEGORIES: dict[str, str] = {
    opname: category_name
    for category_name, opnames in CATEGORIES.items()
    for opname in opnames
}


@dataclass(frozen=True)
class LoopInfo:
    """One ``for`` clause of a comprehension, together with its ``if`` filters.

    A comprehension may contain several nested loops, and each loop may carry
    zero or more filter conditions. For instance::

        [x*y for x in range(3) for y in range(4) if x < y if x + y > 2]

    is described by two LoopInfo objects:

    1. target='x', iter_ast=range(3), conditions=[]
    2. target='y', iter_ast=range(4), conditions=[x < y, x + y > 2]

    Attributes:
        target: AST node for the loop variable(s). Usually an ast.Name, but
            may be a tuple when the clause unpacks (``for i, j in pairs``).
        iter_ast: AST node for the expression following ``in``.
        conditions: AST nodes of the boolean filter expressions attached to
            this loop level.
    """
    target: ast.AST  # loop variable(s)
    iter_ast: ast.AST  # expression being iterated
    conditions: List[ast.AST] = field(default_factory=list)  # `if` filters


@dataclass(frozen=True)
class ReconstructionState:
    """Immutable working memory used while rebuilding an AST from bytecode.

    Reconstruction simulates the Python VM over the comprehension's
    instructions, but builds AST nodes for each operation instead of executing
    it. This record holds both the simulated evaluation stack and the
    comprehension structure accumulated so far.

    Attributes:
        stack: Simulated value stack holding AST nodes or raw values. Loads
            push onto it; operators pop their operands and push the result.
        loops: LoopInfo entries built up as FOR_ITER instructions are seen,
            ordered from outermost to innermost.
        comprehension_type: One of 'generator' (the default), 'list', 'set'
            or 'dict'; selects which AST node type is ultimately created.
        expression: AST of the element expression (``x*2`` in
            ``[x*2 for x in items]``), captured at YIELD_VALUE.
        key_expression: Dict comprehensions only - AST of the key part of the
            key:value pair.
        code_obj: Code object being analysed (from ``generator.gi_code``).
        frame: The generator's frame (from ``generator.gi_frame``); gives
            access to locals such as the '.0' iterator variable.
        current_loop_var: Name of the most recently stored loop variable.
        pending_conditions: Filter conditions collected before it is known
            which loop they belong to.
        or_conditions: Conditions that belong to an OR expression and must be
            joined with ``ast.BoolOp(op=ast.Or())``.
    """
    stack: List[Any] = field(default_factory=list)  # AST nodes or raw values
    loops: List[LoopInfo] = field(default_factory=list)
    comprehension_type: str = 'generator'  # 'generator', 'list', 'set', 'dict'
    expression: Optional[ast.AST] = None  # element expression being yielded
    key_expression: Optional[ast.AST] = None  # dict comprehensions only
    code_obj: Any = None
    frame: Any = None
    current_loop_var: Optional[str] = None  # last stored loop variable
    pending_conditions: List[ast.AST] = field(default_factory=list)
    or_conditions: List[ast.AST] = field(default_factory=list)


# Signature shared by every opcode handler: take the current state and one
# instruction, return the (new) state.
OpHandler = Callable[[ReconstructionState, dis.Instruction], ReconstructionState]

# Global registry mapping an opcode name to its handler.
OP_HANDLERS: dict[str, OpHandler] = {}


@typing.overload
def register_handler(opname: str) -> Callable[[OpHandler], OpHandler]:
    ...


@typing.overload
def register_handler(opname: str, handler: OpHandler) -> OpHandler:
    ...
def register_handler(opname: str, handler = None):
    """Register ``handler`` in OP_HANDLERS under ``opname``.

    Called with only ``opname`` it returns a decorator; called with both
    arguments it registers immediately. The stored (and returned) callable is
    a thin wrapper that asserts the instruction it receives really has the
    registered opname before delegating.

    Raises:
        ValueError: if ``opname`` already has a handler.
    """
    if handler is not None:
        if opname in OP_HANDLERS:
            raise ValueError(f"Handler for '{opname}' already exists.")

        @functools.wraps(handler)
        def _checked(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
            assert instr.opname == opname, f"Handler for '{opname}' called with wrong instruction"
            return handler(state, instr)

        OP_HANDLERS[opname] = _checked
        return _checked

    # Decorator form: defer registration until the function is supplied.
    return functools.partial(register_handler, opname)


# ============================================================================
# CORE GENERATOR HANDLERS
# ============================================================================

@register_handler('GEN_START')
def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """GEN_START carries nothing to reconstruct; the state is returned as-is."""
    return state


# ============================================================================
# LOOP CONTROL HANDLERS
# ============================================================================

@register_handler('FOR_ITER')
def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Open a new loop whose iterable is the item on top of the stack.

    The iterator is popped and recorded in a fresh LoopInfo. The loop target
    is the placeholder name '_temp' until a later STORE_FAST supplies the real
    one. With an empty stack the state is returned unchanged.
    """
    if not state.stack:
        return state

    *remaining, iterator = state.stack
    opened_loop = LoopInfo(
        target=ast.Name(id='_temp', ctx=ast.Store()),
        iter_ast=ensure_ast(iterator),
    )
    return replace(state, stack=remaining, loops=[*state.loops, opened_loop])


@register_handler('JUMP_ABSOLUTE')
def handle_jump_absolute(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Backward jump to the loop head; no effect on the reconstructed state."""
    return state


@register_handler('JUMP_FORWARD')
def handle_jump_forward(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Forward jump (e.g. skipping conditional code); no effect on the state."""
    return state


# ============================================================================
# VARIABLE OPERATIONS HANDLERS
# ============================================================================

@register_handler('LOAD_FAST')
def handle_load_fast(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Push the AST for a local variable.

    Ordinary locals become ``ast.Name`` loads. Names starting with '.' (such
    as '.0', the iterator handed to the generator) are resolved through the
    frame's locals and converted with ``ensure_ast``.

    Raises:
        ValueError: if a dotted name cannot be found in the frame locals (or
            there is no frame).
    """
    name: str = instr.argval

    if name[0] != '.':
        # Regular variable load
        return replace(state, stack=state.stack + [ast.Name(id=name, ctx=ast.Load())])

    # Implicit iterator argument: look up what was actually passed in.
    if not state.frame or name not in state.frame.f_locals:
        raise ValueError(f"Iterator variable '{name}' not found in frame locals.")

    resolved = ensure_ast(state.frame.f_locals[name])
    return replace(state, stack=state.stack + [resolved])


@register_handler('STORE_FAST')
def handle_store_fast(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Record the stored name as the target of the innermost loop.

    The name also becomes ``current_loop_var``. When no loop has been opened
    yet, only ``current_loop_var`` is updated.
    """
    name = instr.argval

    if not state.loops:
        return replace(state, current_loop_var=name)

    *outer_loops, innermost = state.loops
    retargeted = replace(innermost, target=ast.Name(id=name, ctx=ast.Store()))
    return replace(state, loops=[*outer_loops, retargeted], current_loop_var=name)


@register_handler('LOAD_CONST')
def handle_load_const(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Push the constant wrapped in an ``ast.Constant``."""
    constant_node = ast.Constant(value=instr.argval)
    return replace(state, stack=[*state.stack, constant_node])


@register_handler('LOAD_GLOBAL')
def handle_load_global(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Push an ``ast.Name`` load for the global name."""
    name_node = ast.Name(id=instr.argval, ctx=ast.Load())
    return replace(state, stack=[*state.stack, name_node])


@register_handler('LOAD_NAME')
def handle_load_name(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Push an ``ast.Name`` load; treated exactly like LOAD_GLOBAL."""
    name_node = ast.Name(id=instr.argval, ctx=ast.Load())
    return replace(state, stack=[*state.stack, name_node])


@register_handler('STORE_NAME')
def handle_store_name(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Remember the stored name as ``current_loop_var``.

    Unlike the STORE_FAST handler, this does not retarget the innermost loop.
    """
    return replace(state, current_loop_var=instr.argval)


# ============================================================================
# CORE GENERATOR HANDLERS (continued)
# ============================================================================

@register_handler('YIELD_VALUE')
def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Pop the top of the stack and keep it as the comprehension's expression.

    With an empty stack the state is returned unchanged.
    """
    if not state.stack:
        return state

    yielded = ensure_ast(state.stack[-1])
    return replace(state, stack=state.stack[:-1], expression=yielded)
@register_handler('RETURN_VALUE')
def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Drop the trailing ``None`` constant that precedes the generator's return.

    If the top of the stack is not an ``ast.Constant`` holding ``None`` (or the
    stack is empty) the state is returned unchanged.
    """
    if not state.stack:
        return state

    top = state.stack[-1]
    if isinstance(top, ast.Constant) and top.value is None:
        return replace(state, stack=state.stack[:-1])
    return state


# ============================================================================
# STACK MANAGEMENT HANDLERS
# ============================================================================

@register_handler('POP_TOP')
def handle_pop_top(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Discard the top stack item; a no-op on an empty stack."""
    if not state.stack:
        return state
    return replace(state, stack=state.stack[:-1])


@register_handler('DUP_TOP')
def handle_dup_top(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Push a second reference to the top stack item; a no-op on an empty stack."""
    if not state.stack:
        return state
    return replace(state, stack=state.stack + [state.stack[-1]])


@register_handler('ROT_TWO')
def handle_rot_two(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Swap the two topmost stack items; a no-op with fewer than two items."""
    if len(state.stack) < 2:
        return state
    second, top = state.stack[-2:]
    return replace(state, stack=state.stack[:-2] + [top, second])


@register_handler('ROT_THREE')
def handle_rot_three(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rotate the three topmost stack items the way CPython's ROT_THREE does.

    ROT_THREE lifts the second and third items one position up and moves the
    top item down to position three:

        before (bottom -> top): third, second, top
        after  (bottom -> top): top, third, second

    The previous implementation rotated in the opposite direction, which put
    the wrong operands on top for sequences such as DUP_TOP / ROT_THREE /
    COMPARE_OP. With fewer than three items the state is returned unchanged.
    """
    if len(state.stack) < 3:
        return state
    third, second, top = state.stack[-3:]
    return replace(state, stack=state.stack[:-3] + [top, third, second])
@register_handler('ROT_FOUR')
def handle_rot_four(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rotate the four topmost stack items the way CPython's ROT_FOUR does.

    ROT_FOUR lifts the second, third and fourth items one position up and
    moves the top item down to position four:

        before (bottom -> top): fourth, third, second, top
        after  (bottom -> top): top, fourth, third, second

    The previous implementation swapped the two pairs instead of rotating.
    With fewer than four items the state is returned unchanged.
    """
    if len(state.stack) < 4:
        return state
    fourth, third, second, top = state.stack[-4:]
    return replace(state, stack=state.stack[:-4] + [top, fourth, third, second])


# ============================================================================
# ARITHMETIC/LOGIC HANDLERS
# ============================================================================

def _apply_binary_op(state: ReconstructionState, op_type: type) -> ReconstructionState:
    """Replace the two topmost items with ``ast.BinOp(left, op_type(), right)``.

    The top of the stack is the right operand and the item below it the left
    operand. With fewer than two items the state is returned unchanged.
    """
    if len(state.stack) < 2:
        return state
    right = ensure_ast(state.stack[-1])
    left = ensure_ast(state.stack[-2])
    node = ast.BinOp(left=left, op=op_type(), right=right)
    return replace(state, stack=state.stack[:-2] + [node])


@register_handler('BINARY_ADD')
def handle_binary_add(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left + right``."""
    return _apply_binary_op(state, ast.Add)


@register_handler('BINARY_SUBTRACT')
def handle_binary_subtract(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left - right``."""
    return _apply_binary_op(state, ast.Sub)


@register_handler('BINARY_MULTIPLY')
def handle_binary_multiply(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left * right``."""
    return _apply_binary_op(state, ast.Mult)


@register_handler('BINARY_TRUE_DIVIDE')
def handle_binary_true_divide(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left / right``."""
    return _apply_binary_op(state, ast.Div)


@register_handler('BINARY_FLOOR_DIVIDE')
def handle_binary_floor_divide(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left // right``."""
    return _apply_binary_op(state, ast.FloorDiv)


@register_handler('BINARY_MODULO')
def handle_binary_modulo(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left % right``."""
    return _apply_binary_op(state, ast.Mod)


@register_handler('BINARY_POWER')
def handle_binary_power(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left ** right``."""
    return _apply_binary_op(state, ast.Pow)


@register_handler('BINARY_LSHIFT')
def handle_binary_lshift(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left << right``."""
    return _apply_binary_op(state, ast.LShift)


@register_handler('BINARY_RSHIFT')
def handle_binary_rshift(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left >> right``."""
    return _apply_binary_op(state, ast.RShift)


@register_handler('BINARY_OR')
def handle_binary_or(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left | right``."""
    return _apply_binary_op(state, ast.BitOr)


@register_handler('BINARY_XOR')
def handle_binary_xor(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left ^ right``."""
    return _apply_binary_op(state, ast.BitXor)


@register_handler('BINARY_AND')
def handle_binary_and(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``left & right``."""
    return _apply_binary_op(state, ast.BitAnd)


# ============================================================================
# UNARY OPERATION HANDLERS
# ============================================================================

def _apply_unary_op(state: ReconstructionState, op_type: type) -> ReconstructionState:
    """Replace the top stack item with ``ast.UnaryOp(op_type(), operand)``.

    With an empty stack the state is returned unchanged.
    """
    if not state.stack:
        return state
    operand = ensure_ast(state.stack[-1])
    node = ast.UnaryOp(op=op_type(), operand=operand)
    return replace(state, stack=state.stack[:-1] + [node])


@register_handler('UNARY_NEGATIVE')
def handle_unary_negative(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``-operand``."""
    return _apply_unary_op(state, ast.USub)


@register_handler('UNARY_POSITIVE')
def handle_unary_positive(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``+operand``."""
    return _apply_unary_op(state, ast.UAdd)


@register_handler('UNARY_INVERT')
def handle_unary_invert(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``~operand``."""
    return _apply_unary_op(state, ast.Invert)


@register_handler('UNARY_NOT')
def handle_unary_not(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild ``not operand``."""
    return _apply_unary_op(state, ast.Not)


# ============================================================================
# FUNCTION CALL HANDLERS
# ============================================================================

def _apply_call(state: ReconstructionState, arg_count: int) -> ReconstructionState:
    """Replace a callee and its positional arguments with an ``ast.Call``.

    The ``arg_count`` topmost items are the arguments (in call order) and the
    item directly below them is the callee. Keyword arguments are not
    modelled. If the stack holds fewer than ``arg_count + 1`` items the state
    is returned unchanged.
    """
    if len(state.stack) < arg_count + 1:
        return state

    # Index of the first argument; computed explicitly so that arg_count == 0
    # does not turn into the slice stack[-0:], which would be the whole stack.
    first_arg = len(state.stack) - arg_count
    args = [ensure_ast(arg) for arg in state.stack[first_arg:]]
    callee = ensure_ast(state.stack[first_arg - 1])
    call_node = ast.Call(func=callee, args=args, keywords=[])
    return replace(state, stack=state.stack[:first_arg - 1] + [call_node])


@register_handler('CALL_FUNCTION')
def handle_call_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild a positional call ``func(*args)``; ``instr.arg`` is the argument count."""
    return _apply_call(state, instr.arg)


@register_handler('LOAD_METHOD')
def handle_load_method(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Replace the object on top of the stack with ``obj.<method>``.

    Only a single attribute node is pushed, so that the CALL_METHOD handler
    can treat it exactly like a plain callee. With an empty stack the state
    is returned unchanged.
    """
    if not state.stack:
        return state
    obj = ensure_ast(state.stack[-1])
    method_attr = ast.Attribute(value=obj, attr=instr.argval, ctx=ast.Load())
    return replace(state, stack=state.stack[:-1] + [method_attr])


@register_handler('CALL_METHOD')
def handle_call_method(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild a method call; the callee is the attribute pushed by LOAD_METHOD."""
    return _apply_call(state, instr.arg)


@register_handler('CALL')
def handle_call(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild a call for the unified CALL instruction (Python 3.11+).

    Handled the same way as CALL_FUNCTION: ``instr.arg`` arguments plus the
    callee below them are popped.
    """
    return _apply_call(state, instr.arg)


@register_handler('PRECALL')
def handle_precall(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """PRECALL (Python 3.11+) only prepares the following CALL; nothing to do."""
    return state
@register_handler('MAKE_FUNCTION')
def handle_make_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Rebuild a very simple lambda from the code object below the stack top.

    The top two items (function name, then code object) are replaced by an
    ``ast.Lambda`` only when the code object's bytecode is exactly
    ``LOAD_FAST, LOAD_CONST, BINARY_<op>, RETURN_VALUE`` with a supported
    operator, i.e. ``lambda p: p <op> <constant>``. In every other case the
    state, including the stack, is returned unchanged.
    """
    if len(state.stack) < 2:
        return state

    code_const = state.stack[-2]  # code object; the name sits above it at [-1]
    if not (isinstance(code_const, ast.Constant) and hasattr(code_const.value, 'co_code')):
        return state

    body_instrs = list(dis.get_instructions(code_const.value))
    if len(body_instrs) != 4:
        return state

    load_param, load_const, binary, ret = body_instrs
    matches_pattern = (
        load_param.opname == 'LOAD_FAST'
        and load_const.opname == 'LOAD_CONST'
        and binary.opname.startswith('BINARY_')
        and ret.opname == 'RETURN_VALUE'
    )
    if not matches_pattern:
        return state

    # Supported operators for the lambda body.
    supported_ops = {
        'BINARY_MULTIPLY': ast.Mult,
        'BINARY_ADD': ast.Add,
        'BINARY_SUBTRACT': ast.Sub,
        'BINARY_TRUE_DIVIDE': ast.Div,
        'BINARY_FLOOR_DIVIDE': ast.FloorDiv,
        'BINARY_MODULO': ast.Mod,
        'BINARY_POWER': ast.Pow,
    }
    op_type = supported_ops.get(binary.opname)
    if op_type is None:
        return state

    param = load_param.argval
    # lambda <param>: <param> <op> <constant>
    signature = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg=param, annotation=None)],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    body = ast.BinOp(
        left=ast.Name(id=param, ctx=ast.Load()),
        op=op_type(),
        right=ast.Constant(value=load_const.argval),
    )
    lambda_node = ast.Lambda(args=signature, body=body)
    return replace(state, stack=state.stack[:-2] + [lambda_node])


@register_handler('GET_ITER')
def handle_get_iter(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """GET_ITER leaves the reconstructed iterable untouched; the state is returned as-is."""
    return state


# ============================================================================
# OBJECT ACCESS HANDLERS
# ============================================================================

@register_handler('LOAD_ATTR')
def handle_load_attr(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Replace the object on top of the stack with ``obj.<attr>``.

    With an empty stack the state is returned unchanged.
    """
    if not state.stack:
        return state

    *below, owner = state.stack
    attribute = ast.Attribute(value=ensure_ast(owner), attr=instr.argval, ctx=ast.Load())
    return replace(state, stack=[*below, attribute])


@register_handler('BINARY_SUBSCR')
def handle_binary_subscr(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState:
    """Replace container and index (index on top) with ``container[index]``.

    With fewer than two stack items the state is returned unchanged.
    """
    if len(state.stack) < 2:
        return state

    index = ensure_ast(state.stack[-1])
    container = ensure_ast(state.stack[-2])
    subscript = ast.Subscript(value=container, slice=index, ctx=ast.Load())
    return replace(state, stack=[*state.stack[:-2], subscript])


# ============================================================================
# CONTAINER BUILDING HANDLERS
# ============================================================================
+@register_handler('BUILD_TUPLE') +def handle_build_tuple(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + tuple_size: int = instr.arg + if len(state.stack) >= tuple_size: + # Pop elements for the tuple + elements = [ensure_ast(elem) for elem in state.stack[-tuple_size:]] if tuple_size > 0 else [] + new_stack = state.stack[:-tuple_size] if tuple_size > 0 else state.stack + + # Create tuple AST + tuple_node = ast.Tuple(elts=elements, ctx=ast.Load()) + new_stack = new_stack + [tuple_node] + return replace(state, stack=new_stack) + + return state + + +@register_handler('BUILD_LIST') +def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + list_size = instr.arg + if len(state.stack) >= list_size: + # Pop elements for the list + elements = [ensure_ast(elem) for elem in state.stack[-list_size:]] if list_size > 0 else [] + new_stack = state.stack[:-list_size] if list_size > 0 else state.stack + + # Create list AST + list_node = ast.List(elts=elements, ctx=ast.Load()) + new_stack = new_stack + [list_node] + return replace(state, stack=new_stack) + + return state + + +@register_handler('LIST_EXTEND') +def handle_list_extend(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # LIST_EXTEND extends the list at TOS-1 with the iterable at TOS + if len(state.stack) >= 2: + iterable = ensure_ast(state.stack[-1]) + list_obj = state.stack[-2] # This should be a list from BUILD_LIST + new_stack = state.stack[:-2] + + # If the list is empty and we're extending with a tuple/iterable, + # we can convert this to a simple list of the iterable's elements + if isinstance(list_obj, ast.List) and len(list_obj.elts) == 0: + # If extending with a constant tuple, expand it to list elements + if isinstance(iterable, ast.Constant) and isinstance(iterable.value, tuple): + elements = [ast.Constant(value=elem) for elem in iterable.value] + list_node = ast.List(elts=elements, ctx=ast.Load()) + new_stack 
= new_stack + [list_node] + return replace(state, stack=new_stack) + + # Fallback: create a list from the iterable using list() constructor + list_call = ast.Call( + func=ast.Name(id='list', ctx=ast.Load()), + args=[iterable], + keywords=[] + ) + new_stack = new_stack + [list_call] + return replace(state, stack=new_stack) + + return state + + +@register_handler('BUILD_CONST_KEY_MAP') +def handle_build_const_key_map(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # BUILD_CONST_KEY_MAP builds a dictionary with constant keys + # The keys are in a tuple on TOS, values are on the stack below + map_size = instr.arg + if len(state.stack) >= map_size + 1: + # Pop the keys tuple and values + keys_tuple = state.stack[-1] + values = [ensure_ast(val) for val in state.stack[-map_size-1:-1]] + new_stack = state.stack[:-map_size-1] + + # Extract keys from the constant tuple + if isinstance(keys_tuple, ast.Constant) and isinstance(keys_tuple.value, tuple): + keys = [ast.Constant(value=key) for key in keys_tuple.value] + else: + # Fallback if keys are not in expected format + keys = [ast.Constant(value=f'key_{i}') for i in range(len(values))] + + # Create dictionary AST + dict_node = ast.Dict(keys=keys, values=values) + new_stack = new_stack + [dict_node] + return replace(state, stack=new_stack) + + return state + + +# ============================================================================ +# COMPARISON HANDLERS +# ============================================================================ + +@register_handler('COMPARE_OP') +def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + if len(state.stack) >= 2: + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + + # Map comparison operation codes to AST operators + op_map = { + '<': ast.Lt(), + '<=': ast.LtE(), + '>': ast.Gt(), + '>=': ast.GtE(), + '==': ast.Eq(), + '!=': ast.NotEq(), + 'in': ast.In(), + 'not in': ast.NotIn(), + 'is': 
ast.Is(), + 'is not': ast.IsNot(), + } + + op_name = instr.argval + if op_name in op_map: + compare_node = ast.Compare( + left=left, + ops=[op_map[op_name]], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) + + return state + + +@register_handler('CONTAINS_OP') +def handle_contains_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + if len(state.stack) >= 2: + right = ensure_ast(state.stack[-1]) # Container + left = ensure_ast(state.stack[-2]) # Item to check + + # instr.arg determines if it's 'in' (0) or 'not in' (1) + op = ast.NotIn() if instr.arg else ast.In() + + compare_node = ast.Compare( + left=left, + ops=[op], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) + + return state + + +@register_handler('IS_OP') +def handle_is_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + if len(state.stack) >= 2: + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + + # instr.arg determines if it's 'is' (0) or 'is not' (1) + op = ast.IsNot() if instr.arg else ast.Is() + + compare_node = ast.Compare( + left=left, + ops=[op], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) + + return state + + +# ============================================================================ +# CONDITIONAL JUMP HANDLERS +# ============================================================================ + +@register_handler('POP_JUMP_IF_FALSE') +def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false + # In comprehensions, this is used for filter conditions + if state.stack: + condition = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # If we have pending OR conditions, this is the final 
condition in an OR expression + if state.or_conditions: + # Combine all OR conditions into a single BoolOp + all_or_conditions = state.or_conditions + [condition] + combined_condition = ast.BoolOp(op=ast.Or(), values=all_or_conditions) + + # Add the combined condition to the loop and clear OR conditions + if state.loops: + updated_loop = replace( + state.loops[-1], + conditions=state.loops[-1].conditions + [combined_condition] + ) + new_loops = state.loops[:-1] + [updated_loop] + return replace(state, stack=new_stack, loops=new_loops, or_conditions=[]) + else: + new_pending = state.pending_conditions + [combined_condition] + return replace(state, stack=new_stack, pending_conditions=new_pending, or_conditions=[]) + else: + # Regular condition - add to the most recent loop + if state.loops: + updated_loop = replace( + state.loops[-1], + conditions=state.loops[-1].conditions + [condition] + ) + new_loops = state.loops[:-1] + [updated_loop] + return replace(state, stack=new_stack, loops=new_loops) + else: + # If no loops yet, add to pending conditions + new_pending = state.pending_conditions + [condition] + return replace(state, stack=new_stack, pending_conditions=new_pending) + + return state + + +@register_handler('POP_JUMP_IF_TRUE') +def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # POP_JUMP_IF_TRUE pops a value from the stack and jumps if it's true + # This can be: + # 1. Part of an OR expression (jump to YIELD_VALUE) + # 2. 
A negated condition like "not x % 2" (jump back to loop start) + if state.stack: + condition = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # Check if this jumps forward (to YIELD_VALUE - OR pattern) vs back to loop (NOT pattern) + # In OR: POP_JUMP_IF_TRUE jumps forward to yield the value + # In NOT: POP_JUMP_IF_TRUE jumps back to skip this iteration + if instr.argval > instr.offset: + # Jumping forward - part of an OR expression + new_or_conditions = state.or_conditions + [condition] + return replace(state, stack=new_stack, or_conditions=new_or_conditions) + else: + # Jumping backward to loop start - this is a negated condition + # When POP_JUMP_IF_TRUE jumps back, it means "if true, skip this item" + # So we need to negate the condition to get the filter condition + negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) + + if state.loops: + updated_loop = replace( + state.loops[-1], + conditions=state.loops[-1].conditions + [negated_condition] + ) + new_loops = state.loops[:-1] + [updated_loop] + return replace(state, stack=new_stack, loops=new_loops) + else: + new_pending = state.pending_conditions + [negated_condition] + return replace(state, stack=new_stack, pending_conditions=new_pending) + + return state + + +# ============================================================================ +# UNPACKING HANDLERS +# ============================================================================ + +@register_handler('UNPACK_SEQUENCE') +def handle_unpack_sequence(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # UNPACK_SEQUENCE unpacks a sequence into multiple values + # arg is the number of values to unpack + unpack_count = instr.arg + if state.stack: + sequence = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # For tuple unpacking in comprehensions, we typically see patterns like: + # ((k, v) for k, v in items) where items is unpacked into k and v + # Create placeholder variables for the 
unpacked values + for i in range(unpack_count): + var_name = f'_unpack_{i}' + new_stack = new_stack + [ast.Name(id=var_name, ctx=ast.Load())] + + return replace(state, stack=new_stack) + + return state + + +# ============================================================================ +# SIMPLE/UTILITY OPCODE HANDLERS +# ============================================================================ + +@register_handler('NOP') +def handle_nop(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # NOP does nothing + return state + + +@register_handler('CACHE') +def handle_cache(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # CACHE is used for optimization caching, no effect on AST + return state + + +@register_handler('RESUME') +def handle_resume(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # RESUME is used for resuming generators, no effect on AST reconstruction + return state + + +@register_handler('EXTENDED_ARG') +def handle_extended_arg(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # EXTENDED_ARG extends the argument of the next instruction, no direct effect + return state + + +# ============================================================================ +# UTILITY FUNCTIONS +# ============================================================================ + +@functools.singledispatch +def ensure_ast(value) -> ast.AST: + """Ensure value is an AST node""" + return ast.Name(id=str(value), ctx=ast.Load()) + + +@ensure_ast.register +def _ensure_ast_ast(value: ast.AST) -> ast.AST: + """If already an AST node, return it as is""" + return value + + +@ensure_ast.register(int) +@ensure_ast.register(float) +@ensure_ast.register(str) +@ensure_ast.register(bytes) +@ensure_ast.register(bool) +@ensure_ast.register(complex) +@ensure_ast.register(type(None)) +def _ensure_ast_constant(value) -> ast.Constant: + return ast.Constant(value=value) + + 
+@ensure_ast.register +def _ensure_ast_tuple(value: tuple) -> ast.Tuple: + """Convert tuple to AST - special handling for dict items""" + if len(value) > 0 and value[0] == 'dict_item': + return ast.Tuple( + elts=[ensure_ast(value[1]), ensure_ast(value[2])], + ctx=ast.Load() + ) + else: + return ast.Tuple(elts=[ensure_ast(v) for v in value], ctx=ast.Load()) + + +@ensure_ast.register +def _ensure_ast_list(value: list) -> ast.List: + return ast.List(elts=[ensure_ast(v) for v in value], ctx=ast.Load()) + + +@ensure_ast.register +def _ensure_ast_dict(value: dict) -> ast.Dict: + return ast.Dict( + keys=[ensure_ast(k) for k in value.keys()], + values=[ensure_ast(v) for v in value.values()] + ) + + +@ensure_ast.register +def _ensure_ast_range(value: range) -> ast.Call: + """Convert range to AST Call""" + return ast.Call( + func=ast.Name(id='range', ctx=ast.Load()), + args=[ensure_ast(value.start), ensure_ast(value.stop), ensure_ast(value.step)], + keywords=[] + ) + + +@ensure_ast.register(type(iter(range(1)))) +@ensure_ast.register(type(iter([1]))) +@ensure_ast.register(type(iter((1,)))) +def _ensure_ast_iterator(value: Iterator) -> ast.AST: + """Convert iterator to AST - special handling for iterators""" + return ensure_ast(value.__reduce__()[1][0]) + + +def build_comprehension_ast(state: ReconstructionState) -> ast.AST: + """Build the final comprehension AST from the state""" + # Build comprehension generators + generators = [] + + for loop in state.loops: + comp = ast.comprehension( + target=loop.target, + iter=loop.iter_ast, + ifs=loop.conditions, + is_async=0 + ) + generators.append(comp) + + # Add any pending conditions to the last loop + if state.pending_conditions and generators: + generators[-1].ifs.extend(state.pending_conditions) + + # Determine the main expression + if state.expression: + elt = state.expression + elif state.stack: + elt = ensure_ast(state.stack[-1]) + else: + elt = ast.Name(id='item', ctx=ast.Load()) + + # Build the appropriate comprehension 
type + if state.comprehension_type == 'dict' and state.key_expression: + return ast.DictComp( + key=state.key_expression, + value=elt, + generators=generators + ) + elif state.comprehension_type == 'list': + return ast.ListComp( + elt=elt, + generators=generators + ) + elif state.comprehension_type == 'set': + return ast.SetComp( + elt=elt, + generators=generators + ) + else: # generator + return ast.GeneratorExp( + elt=elt, + generators=generators + ) + + +# ============================================================================ +# MAIN RECONSTRUCTION FUNCTION +# ============================================================================ + +def reconstruct(genexpr: GeneratorType) -> ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp: + """ + Reconstructs an AST from a generator expression's bytecode. + + Args: + genexpr (GeneratorType): The generator object to analyze. + + Returns: + An AST node representing the generator expression. + Can be ast.GeneratorExp, ast.ListComp, ast.SetComp, or ast.DictComp. 
+ """ + assert inspect.isgenerator(genexpr), "Input must be a generator expression" + assert inspect.getgeneratorstate(genexpr) == 'GEN_CREATED', "Generator must be in created state" + + # Initialize reconstruction state + state = ReconstructionState( + code_obj=genexpr.gi_code, + frame=genexpr.gi_frame + ) + + # Process each instruction + for instr in dis.get_instructions(genexpr.gi_code): + # Call the handler + state = OP_HANDLERS[instr.opname](state, instr) + + # Build and return the final AST + return build_comprehension_ast(state) diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py new file mode 100644 index 00000000..079c47ea --- /dev/null +++ b/tests/test_ops_syntax_generator.py @@ -0,0 +1,573 @@ +import ast +import pytest +import dis +from types import GeneratorType +from typing import Any, Union + +from effectful.internals.genexpr import reconstruct + + +def compile_and_eval(node: ast.AST, globals_dict: dict = None) -> Any: + """Compile an AST node and evaluate it.""" + if globals_dict is None: + globals_dict = {} + + # Wrap in an Expression node if needed + if not isinstance(node, ast.Expression): + node = ast.Expression(body=node) + + # Fix location info + ast.fix_missing_locations(node) + + # Compile and evaluate + code = compile(node, '', 'eval') + return eval(code, globals_dict) + + +def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict = None): + """Assert that a reconstructed AST produces the same results as the original generator.""" + # Evaluate both to lists for comparison + original_list = list(genexpr) + + # Compile and evaluate the reconstructed AST + reconstructed_gen = compile_and_eval(reconstructed_ast, globals_dict) + reconstructed_list = list(reconstructed_gen) + + assert reconstructed_list == original_list, \ + f"AST produced {reconstructed_list}, expected {original_list}" + + +def assert_ast_structure(ast_node: ast.AST, expected_type: type, + check_target: str = 
None, check_iter: type = None): + """Basic structural assertions for AST nodes.""" + assert isinstance(ast_node, expected_type), \ + f"Expected {expected_type.__name__}, got {type(ast_node).__name__}" + + if hasattr(ast_node, 'generators') and ast_node.generators: + comp = ast_node.generators[0] + if check_target: + assert comp.target.id == check_target + if check_iter: + assert isinstance(comp.iter, check_iter) + + +# ============================================================================ +# BASIC GENERATOR EXPRESSION TESTS +# ============================================================================ + +@pytest.mark.parametrize("genexpr,expected_type,var_name", [ + # Simple generator expressions + ((x for x in range(5)), ast.GeneratorExp, 'x'), + ((y for y in range(10)), ast.GeneratorExp, 'y'), + ((item for item in [1, 2, 3]), ast.GeneratorExp, 'item'), + + # Edge cases for simple generators + ((i for i in range(0)), ast.GeneratorExp, 'i'), # Empty range + ((n for n in range(1)), ast.GeneratorExp, 'n'), # Single item range + ((val for val in range(100)), ast.GeneratorExp, 'val'), # Large range + ((x for x in range(-5, 5)), ast.GeneratorExp, 'x'), # Negative range + ((step for step in range(0, 10, 2)), ast.GeneratorExp, 'step'), # Step range + ((rev for rev in range(10, 0, -1)), ast.GeneratorExp, 'rev'), # Reverse range +]) +def test_simple_generators(genexpr, expected_type, var_name): + """Test reconstruction of simple generator expressions.""" + ast_node = reconstruct(genexpr) + + # Check structure + assert_ast_structure(ast_node, expected_type, check_target=var_name) + + # Check equivalence - only for range() iterators that we can reconstruct + assert_ast_equivalent(genexpr, ast_node) + + +# ============================================================================ +# ARITHMETIC AND EXPRESSION TESTS +# ============================================================================ + +@pytest.mark.parametrize("genexpr", [ + # Basic arithmetic operations + 
(x * 2 for x in range(5)), + (x + 1 for x in range(5)), + (x - 1 for x in range(5)), + (x ** 2 for x in range(5)), + (x % 2 for x in range(10)), + (x / 2 for x in range(1, 6)), + (x // 2 for x in range(10)), + + # Complex expressions + (x * 2 + 1 for x in range(5)), + ((x + 1) * (x - 1) for x in range(5)), + (x ** 2 + 2 * x + 1 for x in range(5)), + + # Unary operations + (-x for x in range(5)), + (+x for x in range(-5, 5)), + (~x for x in range(5)), + + # More complex arithmetic edge cases + (x ** 3 for x in range(1, 5)), # Higher powers + (x * x * x for x in range(5)), # Repeated multiplication + (x + x + x for x in range(5)), # Repeated addition + (x - x + 1 for x in range(5)), # Operations that might simplify + (x / x for x in range(1, 5)), # Division by self + (x % (x + 1) for x in range(1, 10)), # Modulo with expression + + # Nested arithmetic expressions + ((x + 1) ** 2 for x in range(5)), + ((x * 2 + 3) * (x - 1) for x in range(5)), + (x * (x + 1) * (x + 2) for x in range(5)), + + # Mixed operations with precedence + (x + y * 2 for x in range(3) for y in range(3)), + (x * 2 + y / 3 for x in range(1, 4) for y in range(1, 4)), + ((x + y) * (x - y) for x in range(1, 4) for y in range(1, 4)), + + # Edge cases with zero and one + (x * 0 for x in range(5)), + (x * 1 for x in range(5)), + (x + 0 for x in range(5)), + (x ** 1 for x in range(5)), + (0 + x for x in range(5)), + (1 * x for x in range(5)), +]) +def test_arithmetic_expressions(genexpr): + """Test reconstruction of generators with arithmetic expressions.""" + ast_node = reconstruct(genexpr) + assert isinstance(ast_node, ast.GeneratorExp) + assert_ast_equivalent(genexpr, ast_node) + + +# ============================================================================ +# FILTERED GENERATOR TESTS +# ============================================================================ + +@pytest.mark.parametrize("genexpr", [ + # Simple filters + (x for x in range(10) if x % 2 == 0), + (x for x in range(10) if x > 5), + 
(x for x in range(10) if x < 5), + (x for x in range(10) if x != 5), + + # Complex filters + (x for x in range(20) if x % 2 == 0 if x % 3 == 0), + (x for x in range(100) if x > 10 if x < 90 if x % 5 == 0), + + # Filters with expressions + (x * 2 for x in range(10) if x % 2 == 0), + (x ** 2 for x in range(10) if x > 3), + + # Boolean operations in filters + (x for x in range(10) if x > 2 and x < 8), + (x for x in range(10) if x < 3 or x > 7), + (x for x in range(10) if not x % 2), + + # More complex filter edge cases + (x for x in range(50) if x % 7 == 0), # Different modulo + (x for x in range(10) if x >= 0), # Always true condition + (x for x in range(10) if x < 0), # Always false condition + (x for x in range(20) if x % 2 == 0 and x % 3 == 0), # Multiple conditions with and + (x for x in range(20) if x % 2 == 0 or x % 3 == 0), # Multiple conditions with or + + # Nested boolean operations + (x for x in range(20) if (x > 5 and x < 15) or x == 0), + (x for x in range(20) if not (x > 10 and x < 15)), + (x for x in range(50) if x > 10 and (x % 2 == 0 or x % 3 == 0)), + + # Multiple consecutive filters + (x for x in range(100) if x > 20 if x < 80 if x % 10 == 0), + (x for x in range(50) if x % 2 == 0 if x % 3 != 0 if x > 10), + + # Filters with complex expressions + (x + 1 for x in range(20) if (x * 2) % 3 == 0), + (x ** 2 for x in range(10) if x * (x + 1) > 10), + (x / 2 for x in range(1, 20) if x % (x // 2 + 1) == 0), + + # Edge cases with truthiness + (x for x in range(10) if x), # Truthy filter + (x for x in range(-5, 5) if not x), # Falsy filter + (x for x in range(10) if bool(x % 2)), # Explicit bool conversion +]) +def test_filtered_generators(genexpr): + """Test reconstruction of generators with if conditions.""" + ast_node = reconstruct(genexpr) + assert isinstance(ast_node, ast.GeneratorExp) + + # Check that we have conditions + if genexpr.gi_code.co_code.count(dis.opmap['POP_JUMP_IF_FALSE']) > 0: + assert len(ast_node.generators[0].ifs) > 0, "Expected if 
conditions in AST" + + assert_ast_equivalent(genexpr, ast_node) + + +# ============================================================================ +# NESTED LOOP TESTS +# ============================================================================ + +@pytest.mark.parametrize("genexpr", [ + # Basic nested loops + ((x, y) for x in range(3) for y in range(3)), + (x + y for x in range(3) for y in range(3)), + (x * y for x in range(1, 4) for y in range(1, 4)), + + # Nested with filters + ((x, y) for x in range(5) for y in range(5) if x < y), + (x + y for x in range(5) if x % 2 == 0 for y in range(5) if y % 2 == 1), + + # Triple nested + ((x, y, z) for x in range(2) for y in range(2) for z in range(2)), + + # More complex nested loop edge cases + # Different sized ranges + ((x, y) for x in range(2) for y in range(5)), + ((x, y) for x in range(10) for y in range(2)), + + # Asymmetric operations + (x - y for x in range(5) for y in range(3)), + (x / (y + 1) for x in range(1, 6) for y in range(3)), + (x ** y for x in range(1, 4) for y in range(3)), + + # Complex expressions with multiple variables + (x * y + x for x in range(3) for y in range(3)), + (x + y + x * y for x in range(1, 4) for y in range(1, 4)), + ((x + y) ** 2 for x in range(3) for y in range(3)), + + # Filters on different loop levels + ((x, y) for x in range(10) if x % 2 == 0 for y in range(10) if y % 3 == 0), + (x * y for x in range(5) for y in range(5) if x != y), + (x + y for x in range(5) for y in range(5) if x + y < 5), + + # Triple and quadruple nested with various patterns + (x + y + z for x in range(2) for y in range(2) for z in range(2)), + (x * y * z for x in range(1, 3) for y in range(1, 3) for z in range(1, 3)), + ((x, y, z, w) for x in range(2) for y in range(2) for z in range(2) for w in range(2)), + + # Nested loops with complex filters + ((x, y, z) for x in range(5) for y in range(5) for z in range(5) if x < y < z), + (x + y + z for x in range(3) if x > 0 for y in range(3) if y != x for z in 
range(3) if z != x and z != y), + + # Mixed range types + ((x, y) for x in range(-2, 2) for y in range(0, 4, 2)), + (x * y for x in range(5, 0, -1) for y in range(1, 6)), +]) +def test_nested_loops(genexpr): + """Test reconstruction of generators with nested loops.""" + ast_node = reconstruct(genexpr) + assert isinstance(ast_node, ast.GeneratorExp) + + # Check multiple comprehensions + assert len(ast_node.generators) >= 2, "Expected multiple loop comprehensions" + + assert_ast_equivalent(genexpr, ast_node) + + +# ============================================================================ +# DIFFERENT COMPREHENSION TYPES +# ============================================================================ + +@pytest.mark.parametrize("comprehension,expected_type", [ + # List comprehensions + ([x for x in range(5)], ast.ListComp), + ([x * 2 for x in range(5)], ast.ListComp), + ([x for x in range(10) if x % 2 == 0], ast.ListComp), + + # Set comprehensions + ({x for x in range(5)}, ast.SetComp), + ({x * 2 for x in range(5)}, ast.SetComp), + ({x for x in range(10) if x % 2 == 0}, ast.SetComp), + + # Dict comprehensions + ({x: x**2 for x in range(5)}, ast.DictComp), + ({x: x*2 for x in range(5) if x % 2 == 0}, ast.DictComp), + ({str(x): x for x in range(5)}, ast.DictComp), +]) +def test_different_comprehension_types(comprehension, expected_type): + """Test reconstruction of different comprehension types.""" + # Convert to generator for reconstruction + if isinstance(comprehension, list): + genexpr = (x for x in comprehension) + elif isinstance(comprehension, set): + genexpr = (x for x in comprehension) + elif isinstance(comprehension, dict): + genexpr = ((k, v) for k, v in comprehension.items()) + else: + genexpr = comprehension + + # Note: The actual implementation would need to detect the comprehension type + # from the bytecode. This test assumes it can do that. 
+ ast_node = reconstruct(genexpr) + + # For now, we'll check if it's at least a comprehension + assert isinstance(ast_node, (ast.GeneratorExp, ast.ListComp, ast.SetComp, ast.DictComp)) + + +# ============================================================================ +# EDGE CASES AND COMPLEX SCENARIOS +# ============================================================================ + +@pytest.mark.parametrize("genexpr,globals_dict", [ + # Using global functions + ((abs(x) for x in range(-5, 5)), {'abs': abs}), + ((len(s) for s in ["a", "ab", "abc"]), {'len': len}), + ((max(x, 5) for x in range(10)), {'max': max}), + ((min(x, 5) for x in range(10)), {'min': min}), + ((round(x / 3, 2) for x in range(10)), {'round': round}), + + # Using lambdas and functions + (((lambda y: y * 2)(x) for x in range(5)), {}), + (((lambda y: y + 1)(x) for x in range(5)), {}), + (((lambda y: y ** 2)(x) for x in range(5)), {}), + + # More complex lambdas + (((lambda a, b: a + b)(x, x) for x in range(5)), {}), + ((f(x) for x in range(5)), {'f': lambda y: y * 3}), + + # Attribute access + ((x.real for x in [1+2j, 3+4j, 5+6j]), {}), + ((x.imag for x in [1+2j, 3+4j, 5+6j]), {}), + ((x.conjugate() for x in [1+2j, 3+4j, 5+6j]), {}), + + # Method calls + ((s.upper() for s in ["hello", "world"]), {}), + ((s.lower() for s in ["HELLO", "WORLD"]), {}), + ((s.strip() for s in [" hello ", " world "]), {}), + ((x.bit_length() for x in range(1, 10)), {}), + ((str(x).zfill(3) for x in range(10)), {'str': str}), + + # Subscript operations + (([10, 20, 30][i] for i in range(3)), {}), + (({'a': 1, 'b': 2, 'c': 3}[k] for k in ['a', 'b', 'c']), {}), + (("hello"[i] for i in range(5)), {}), + ((data[i][j] for i in range(2) for j in range(2)), {'data': [[1, 2], [3, 4]]}), + + # More complex attribute chains + ((obj.value.bit_length() for obj in [type('', (), {'value': x})() for x in range(1, 5)]), {}), + + # Multiple function calls + ((abs(max(x, -x)) for x in range(-3, 4)), {'abs': abs, 'max': max}), + 
((len(str(x)) for x in range(100, 110)), {'len': len, 'str': str}), + + # Mixed operations + ((abs(x) + len(str(x)) for x in range(-10, 10)), {'abs': abs, 'len': len, 'str': str}), + ((s.upper().lower() for s in ["Hello", "World"]), {}), + + # Edge cases with complex data structures + (([1, 2, 3][x % 3] for x in range(10)), {}), + (({"even": x, "odd": x + 1}["even" if x % 2 == 0 else "odd"] for x in range(5)), {}), + + # Function calls with multiple arguments + ((pow(x, 2, 10) for x in range(5)), {'pow': pow}), + ((divmod(x, 3) for x in range(10)), {'divmod': divmod}), +]) +def test_complex_scenarios(genexpr, globals_dict): + """Test reconstruction of complex generator expressions.""" + ast_node = reconstruct(genexpr) + assert isinstance(ast_node, ast.GeneratorExp) + + # Need to provide the same globals for evaluation + assert_ast_equivalent(genexpr, ast_node, globals_dict) + + +# ============================================================================ +# COMPARISON OPERATORS +# ============================================================================ + +@pytest.mark.parametrize("genexpr", [ + # All comparison operators + (x for x in range(10) if x < 5), + (x for x in range(10) if x <= 5), + (x for x in range(10) if x > 5), + (x for x in range(10) if x >= 5), + (x for x in range(10) if x == 5), + (x for x in range(10) if x != 5), + + # in/not in operators + (x for x in range(10) if x in [2, 4, 6, 8]), + (x for x in range(10) if x not in [2, 4, 6, 8]), + + # is/is not operators (with None) + (x for x in [1, None, 3, None, 5] if x is not None), + (x for x in [1, None, 3, None, 5] if x is None), + + # Boolean operations - these are complex cases that might need special handling + (x for x in range(10) if x > 2 and x < 8), + (x for x in range(10) if x < 3 or x > 7), + (x for x in range(10) if not x % 2), + (x for x in range(10) if not (x > 5)), + + # More complex comparison edge cases + # Chained comparisons + (x for x in range(20) if 5 < x < 15), + (x for x in 
range(20) if 0 <= x <= 10), + (x for x in range(20) if x >= 5 and x <= 15), + + # Comparisons with expressions + (x for x in range(10) if x * 2 > 10), + (x for x in range(10) if x + 1 <= 5), + (x for x in range(10) if x ** 2 < 25), + (x for x in range(10) if (x + 1) * 2 != 6), + + # Complex membership tests + (x for x in range(20) if x in range(5, 15)), + (x for x in range(10) if x not in range(3, 7)), + (x for x in range(10) if x % 2 in [0]), + (x for x in range(10) if x not in []), # Empty container + + # Complex boolean combinations + (x for x in range(20) if x > 5 and x < 15 and x % 2 == 0), + (x for x in range(20) if x < 5 or x > 15 or x == 10), + (x for x in range(20) if not (x > 5 and x < 15)), + (x for x in range(20) if not (x < 5 or x > 15)), + + # Mixed comparison and boolean operations + (x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), + (x for x in range(20) if not (x % 2 == 0 and x % 3 == 0)), + + # Edge cases with identity comparisons + (x for x in [0, 1, 2, None, 4] if x is not None and x > 1), + (x for x in [True, False, 1, 0] if x is True), + (x for x in [True, False, 1, 0] if x is not False), +]) +def test_comparison_operators(genexpr): + """Test reconstruction of all comparison operators.""" + ast_node = reconstruct(genexpr) + assert isinstance(ast_node, ast.GeneratorExp) + assert_ast_equivalent(genexpr, ast_node) + + +# ============================================================================ +# HELPER FUNCTION TESTS +# ============================================================================ + +@pytest.mark.parametrize("value,expected_str", [ + # AST nodes should be returned as-is + (ast.Name(id='x', ctx=ast.Load()), 'x'), + (ast.Constant(value=42), '42'), + (ast.List(elts=[], ctx=ast.Load()), '[]'), + (ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=2)), '1 + 2'), + + # Constants should become ast.Constant nodes + (42, '42'), + (3.14, '3.14'), + (-42, '-42'), + (-3.14, '-3.14'), + 
('hello', "'hello'"), + ("", "''"), + (b'bytes', "b'bytes'"), + (b'', "b''"), + (True, 'True'), + (False, 'False'), + (None, 'None'), + + # Complex numbers + (1+2j, '(1+2j)'), + (0+1j, '1j'), + (3+0j, '(3+0j)'), + (-1-2j, '(-1-2j)'), + + # Tuples should become ast.Tuple nodes + ((), '()'), + ((1,), '(1,)'), + ((1, 2), '(1, 2)'), + (('a', 'b', 'c'), "('a', 'b', 'c')"), + + # Special dict_item tuples + (('dict_item', 'key', 'value'), "('key', 'value')"), + (('dict_item', 42, 'answer'), "(42, 'answer')"), + + # Nested tuples + ((1, (2, 3)), '(1, (2, 3))'), + (((1, 2), (3, 4)), '((1, 2), (3, 4))'), + ((1, 2, (3, (4, 5))), '(1, 2, (3, (4, 5)))'), + + # Lists should become ast.List nodes + ([1, 2, 3], '[1, 2, 3]'), + (['hello', 'world'], "['hello', 'world']"), + ([True, False, None], '[True, False, None]'), + + # Nested lists + ([[1, 2], [3, 4]], '[[1, 2], [3, 4]]'), + ([1, [2, [3, 4]], 5], '[1, [2, [3, 4]], 5]'), + + # Mixed nested structures + ([(1, 2), (3, 4)], '[(1, 2), (3, 4)]'), + (([1, 2], [3, 4]), '([1, 2], [3, 4])'), + + # Dicts should become ast.Dict nodes + ({'a': 1}, "{'a': 1}"), + ({'x': 10, 'y': 20}, "{'x': 10, 'y': 20}"), + ({1: 'one', 2: 'two'}, "{1: 'one', 2: 'two'}"), + + # Nested dicts + ({'a': {'b': 1}}, "{'a': {'b': 1}}"), + ({'nums': [1, 2, 3], 'strs': ['a', 'b']}, "{'nums': [1, 2, 3], 'strs': ['a', 'b']}"), + + # Range objects + (range(5), 'range(0, 5, 1)'), + (range(1, 10), 'range(1, 10, 1)'), + (range(0, 10, 2), 'range(0, 10, 2)'), + (range(10, 0, -1), 'range(10, 0, -1)'), + (range(-5, 5), 'range(-5, 5, 1)'), + + # Empty collections + ([], '[]'), + ((), '()'), + ({}, '{}'), + + # Complex nested structures + ([1, [2, 3], 4], '[1, [2, 3], 4]'), + ({'a': [1, 2], 'b': {'c': 3}}, "{'a': [1, 2], 'b': {'c': 3}}"), + ([(1, {'a': [2, 3]}), ({'b': 4}, 5)], "[(1, {'a': [2, 3]}), ({'b': 4}, 5)]"), + + # Edge cases with special values + ([None, True, False, 0, ''], "[None, True, False, 0, '']"), + ({'': 'empty', None: 'none', 0: 'zero'}, "{'': 'empty', None: 
'none', 0: 'zero'}"), + + # Large numbers + (999999999999999999999, '999999999999999999999'), + (1.7976931348623157e+308, '1.7976931348623157e+308'), # Close to float max + + # Sets - these need special handling as they convert to Name nodes + pytest.param({1, 2, 3}, 'set', marks=pytest.mark.xfail(reason="Sets don't have direct AST representation")) +]) +def test_ensure_ast(value, expected_str): + """Test that ensure_ast correctly converts various values to AST nodes.""" + from effectful.internals.genexpr import ensure_ast + + result = ensure_ast(value) + + # Compare the unparsed strings + result_str = ast.unparse(result) + assert result_str == expected_str, \ + f"ensure_ast({repr(value)}) produced '{result_str}', expected '{expected_str}'" + + +def test_ast_node_properties(): + """Test that reconstructed AST nodes have proper properties.""" + # Simple generator + genexpr = (x * 2 for x in range(5) if x > 2) + ast_node = reconstruct(genexpr) + + # Check AST structure + assert isinstance(ast_node, ast.GeneratorExp) + assert hasattr(ast_node, 'elt') # The expression part + assert hasattr(ast_node, 'generators') # The comprehension part + assert len(ast_node.generators) == 1 + + comp = ast_node.generators[0] + assert hasattr(comp, 'target') # Loop variable + assert hasattr(comp, 'iter') # Iterator + assert hasattr(comp, 'ifs') # Conditions + assert len(comp.ifs) == 1 # One condition + + +def test_error_handling(): + """Test that appropriate errors are raised for unsupported cases.""" + # Test with non-generator input + with pytest.raises(AssertionError): + reconstruct([1, 2, 3]) # Not a generator + + # Test with consumed generator + gen = (x for x in range(5)) + list(gen) # Consume it + with pytest.raises(AssertionError): + reconstruct(gen) From 85275e4bccd0d45da82b65883016fdbadd84e327 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 13 Jun 2025 09:21:05 -0400 Subject: [PATCH 002/106] more test cases and less silent failure --- effectful/internals/genexpr.py | 887 
++++++++++++++--------------- tests/test_ops_syntax_generator.py | 219 +++---- 2 files changed, 554 insertions(+), 552 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index e79fba34..3615b2b0 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -1,3 +1,20 @@ +""" +Generator expression bytecode reconstruction module. + +This module provides functionality to reconstruct AST representations from compiled +generator expressions by analyzing their bytecode. The primary use case is to recover +the original structure of generator comprehensions from their compiled form. + +The only public-facing interface is the `reconstruct` function, which takes a +generator object and returns an AST node representing the original comprehension. +All other functions and classes in this module are internal implementation details. + +Example: + >>> g = (x * 2 for x in range(10) if x % 2 == 0) + >>> ast_node = reconstruct(g) + >>> # ast_node is now an ast.GeneratorExp representing the original expression +""" + import ast import dis import functools @@ -216,9 +233,6 @@ def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> Recon # FOR_ITER pops an iterator from the stack and pushes the next item # If the iterator is exhausted, it jumps to the target instruction # The iterator should be on top of stack - if not state.stack: - return state - # Create new stack without the iterator new_stack = state.stack[:-1] iterator = state.stack[-1] @@ -329,21 +343,17 @@ def handle_store_name(state: ReconstructionState, instr: dis.Instruction) -> Rec def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # YIELD_VALUE pops a value from the stack and yields it # This is the expression part of the generator - if state.stack: - expression = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - return replace(state, stack=new_stack, expression=expression) - return 
state + expression = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + return replace(state, stack=new_stack, expression=expression) @register_handler('RETURN_VALUE') def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # RETURN_VALUE ends the generator # Usually preceded by LOAD_CONST None - if state.stack and isinstance(state.stack[-1], ast.Constant) and state.stack[-1].value is None: - new_stack = state.stack[:-1] # Remove the None - return replace(state, stack=new_stack) - return state + new_stack = state.stack[:-1] # Remove the None + return replace(state, stack=new_stack) # ============================================================================ @@ -354,49 +364,39 @@ def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> R def handle_pop_top(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # POP_TOP removes the top item from the stack # In generators, often used after YIELD_VALUE - if state.stack: - new_stack = state.stack[:-1] - return replace(state, stack=new_stack) - return state + new_stack = state.stack[:-1] + return replace(state, stack=new_stack) @register_handler('DUP_TOP') def handle_dup_top(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # DUP_TOP duplicates the top stack item - if state.stack: - top_item = state.stack[-1] - new_stack = state.stack + [top_item] - return replace(state, stack=new_stack) - return state + top_item = state.stack[-1] + new_stack = state.stack + [top_item] + return replace(state, stack=new_stack) @register_handler('ROT_TWO') def handle_rot_two(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # ROT_TWO swaps the top two stack items - if len(state.stack) >= 2: - new_stack = state.stack[:-2] + [state.stack[-1], state.stack[-2]] - return replace(state, stack=new_stack) - return state + new_stack = state.stack[:-2] + [state.stack[-1], state.stack[-2]] + return 
replace(state, stack=new_stack) @register_handler('ROT_THREE') def handle_rot_three(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # ROT_THREE rotates the top three stack items # TOS -> TOS1, TOS1 -> TOS2, TOS2 -> TOS - if len(state.stack) >= 3: - new_stack = state.stack[:-3] + [state.stack[-2], state.stack[-1], state.stack[-3]] - return replace(state, stack=new_stack) - return state + new_stack = state.stack[:-3] + [state.stack[-2], state.stack[-1], state.stack[-3]] + return replace(state, stack=new_stack) @register_handler('ROT_FOUR') def handle_rot_four(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # ROT_FOUR rotates the top four stack items # TOS -> TOS1, TOS1 -> TOS2, TOS2 -> TOS3, TOS3 -> TOS - if len(state.stack) >= 4: - new_stack = state.stack[:-4] + [state.stack[-2], state.stack[-1], state.stack[-4], state.stack[-3]] - return replace(state, stack=new_stack) - return state + new_stack = state.stack[:-4] + [state.stack[-2], state.stack[-1], state.stack[-4], state.stack[-3]] + return replace(state, stack=new_stack) # ============================================================================ @@ -405,122 +405,98 @@ def handle_rot_four(state: ReconstructionState, instr: dis.Instruction) -> Recon @register_handler('BINARY_ADD') def handle_binary_add(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Add(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Add(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_SUBTRACT') def handle_binary_subtract(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if 
len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Sub(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Sub(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_MULTIPLY') def handle_binary_multiply(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Mult(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Mult(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_TRUE_DIVIDE') def handle_binary_true_divide(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Div(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Div(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_FLOOR_DIVIDE') def handle_binary_floor_divide(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.FloorDiv(), right=right)] - return replace(state, stack=new_stack) - return 
state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.FloorDiv(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_MODULO') def handle_binary_modulo(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Mod(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Mod(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_POWER') def handle_binary_power(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Pow(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Pow(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_LSHIFT') def handle_binary_lshift(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.LShift(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.LShift(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_RSHIFT') def 
handle_binary_rshift(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.RShift(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.RShift(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_OR') def handle_binary_or(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitOr(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitOr(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_XOR') def handle_binary_xor(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitXor(), right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitXor(), right=right)] + return replace(state, stack=new_stack) @register_handler('BINARY_AND') def handle_binary_and(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitAnd(), 
right=right)] - return replace(state, stack=new_stack) - return state + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitAnd(), right=right)] + return replace(state, stack=new_stack) # ============================================================================ @@ -529,38 +505,30 @@ def handle_binary_and(state: ReconstructionState, instr: dis.Instruction) -> Rec @register_handler('UNARY_NEGATIVE') def handle_unary_negative(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if state.stack: - operand = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.USub(), operand=operand)] - return replace(state, stack=new_stack) - return state + operand = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.USub(), operand=operand)] + return replace(state, stack=new_stack) @register_handler('UNARY_POSITIVE') def handle_unary_positive(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if state.stack: - operand = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.UAdd(), operand=operand)] - return replace(state, stack=new_stack) - return state + operand = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.UAdd(), operand=operand)] + return replace(state, stack=new_stack) @register_handler('UNARY_INVERT') def handle_unary_invert(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if state.stack: - operand = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.Invert(), operand=operand)] - return replace(state, stack=new_stack) - return state + operand = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.Invert(), operand=operand)] + return replace(state, stack=new_stack) @register_handler('UNARY_NOT') def handle_unary_not(state: 
ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if state.stack: - operand = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.Not(), operand=operand)] - return replace(state, stack=new_stack) - return state + operand = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.Not(), operand=operand)] + return replace(state, stack=new_stack) # ============================================================================ @@ -571,56 +539,47 @@ def handle_unary_not(state: ReconstructionState, instr: dis.Instruction) -> Reco def handle_call_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # CALL_FUNCTION pops function and arguments from stack arg_count = instr.arg - if len(state.stack) >= arg_count + 1: - # Pop arguments and function - args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] - func = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[:-arg_count - 1] - - # Create function call AST - call_node = ast.Call(func=func, args=args, keywords=[]) - new_stack = new_stack + [call_node] - return replace(state, stack=new_stack) + # Pop arguments and function + args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] + func = ensure_ast(state.stack[-arg_count - 1]) + new_stack = state.stack[:-arg_count - 1] + + # Create function call AST + call_node = ast.Call(func=func, args=args, keywords=[]) + new_stack = new_stack + [call_node] + return replace(state, stack=new_stack) - return state - @register_handler('LOAD_METHOD') def handle_load_method(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # LOAD_METHOD loads a method from an object # It pushes the bound method and the object (for the method call) - if state.stack: - obj = ensure_ast(state.stack[-1]) - method_name = instr.argval - new_stack = state.stack[:-1] - - # Create method access as an attribute - 
method_attr = ast.Attribute(value=obj, attr=method_name, ctx=ast.Load()) - - # For LOAD_METHOD, we push both the method and the object - # But for AST purposes, we just need the method attribute - new_stack = new_stack + [method_attr] - return replace(state, stack=new_stack) + obj = ensure_ast(state.stack[-1]) + method_name = instr.argval + new_stack = state.stack[:-1] - return state + # Create method access as an attribute + method_attr = ast.Attribute(value=obj, attr=method_name, ctx=ast.Load()) + + # For LOAD_METHOD, we push both the method and the object + # But for AST purposes, we just need the method attribute + new_stack = new_stack + [method_attr] + return replace(state, stack=new_stack) @register_handler('CALL_METHOD') def handle_call_method(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # CALL_METHOD calls a method - similar to CALL_FUNCTION but for methods arg_count = instr.arg - if len(state.stack) >= arg_count + 1: - # Pop arguments and method - args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] - method = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[:-arg_count - 1] - - # Create method call AST - call_node = ast.Call(func=method, args=args, keywords=[]) - new_stack = new_stack + [call_node] - return replace(state, stack=new_stack) + # Pop arguments and method + args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] + method = ensure_ast(state.stack[-arg_count - 1]) + new_stack = state.stack[:-arg_count - 1] - return state + # Create method call AST + call_node = ast.Call(func=method, args=args, keywords=[]) + new_stack = new_stack + [call_node] + return replace(state, stack=new_stack) @register_handler('CALL') @@ -628,18 +587,16 @@ def handle_call(state: ReconstructionState, instr: dis.Instruction) -> Reconstru # CALL is the newer unified call instruction (Python 3.11+) # Similar to CALL_FUNCTION but with a different calling convention 
arg_count = instr.arg - if len(state.stack) >= arg_count + 1: - # Pop arguments and function - args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] - func = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[:-arg_count - 1] - - # Create function call AST - call_node = ast.Call(func=func, args=args, keywords=[]) - new_stack = new_stack + [call_node] - return replace(state, stack=new_stack) + + # Pop arguments and function + args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] + func = ensure_ast(state.stack[-arg_count - 1]) + new_stack = state.stack[:-arg_count - 1] - return state + # Create function call AST + call_node = ast.Call(func=func, args=args, keywords=[]) + new_stack = new_stack + [call_node] + return replace(state, stack=new_stack) @register_handler('PRECALL') @@ -653,65 +610,61 @@ def handle_precall(state: ReconstructionState, instr: dis.Instruction) -> Recons def handle_make_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # MAKE_FUNCTION creates a function from code object and name on stack # For lambda functions, we need to reconstruct the lambda expression - if len(state.stack) >= 2: - name = state.stack[-1] # Function name (usually '') - code_obj = state.stack[-2] # Code object - new_stack = state.stack[:-2] + code_obj: ast.Constant = state.stack[-2] # Code object + lambda_code: types.CodeType = code_obj.value + + # For lambda functions, try to reconstruct the lambda expression + lambda_code = code_obj.value + + # Simple lambda reconstruction - try to extract the basic pattern + # This is a simplified approach for common lambda patterns + # Get the lambda's bytecode instructions + lambda_instructions = list(dis.get_instructions(lambda_code)) + + # Map binary operations + op_map = { + 'BINARY_MULTIPLY': ast.Mult(), + 'BINARY_ADD': ast.Add(), + 'BINARY_SUBTRACT': ast.Sub(), + 'BINARY_TRUE_DIVIDE': ast.Div(), + 'BINARY_FLOOR_DIVIDE': 
ast.FloorDiv(), + 'BINARY_MODULO': ast.Mod(), + 'BINARY_POWER': ast.Pow(), + } + + # For simple lambdas like "lambda y: y * 2" + # Look for pattern: LOAD_FAST, LOAD_CONST, BINARY_OP, RETURN_VALUE + if (len(lambda_instructions) == 4 and + lambda_instructions[0].opname == 'LOAD_FAST' and + lambda_instructions[1].opname == 'LOAD_CONST' and + lambda_instructions[2].opname in op_map and + lambda_instructions[3].opname == 'RETURN_VALUE'): - # For lambda functions, try to reconstruct the lambda expression - if isinstance(code_obj, ast.Constant) and hasattr(code_obj.value, 'co_code'): - lambda_code = code_obj.value - - # Simple lambda reconstruction - try to extract the basic pattern - # This is a simplified approach for common lambda patterns - # Get the lambda's bytecode instructions - lambda_instructions = list(dis.get_instructions(lambda_code)) - - # For simple lambdas like "lambda y: y * 2" - # Look for pattern: LOAD_FAST, LOAD_CONST, BINARY_OP, RETURN_VALUE - if (len(lambda_instructions) == 4 and - lambda_instructions[0].opname == 'LOAD_FAST' and - lambda_instructions[1].opname == 'LOAD_CONST' and - lambda_instructions[2].opname.startswith('BINARY_') and - lambda_instructions[3].opname == 'RETURN_VALUE'): - - param_name = lambda_instructions[0].argval - const_value = lambda_instructions[1].argval - op_name = lambda_instructions[2].opname - - # Map binary operations - op_map = { - 'BINARY_MULTIPLY': ast.Mult(), - 'BINARY_ADD': ast.Add(), - 'BINARY_SUBTRACT': ast.Sub(), - 'BINARY_TRUE_DIVIDE': ast.Div(), - 'BINARY_FLOOR_DIVIDE': ast.FloorDiv(), - 'BINARY_MODULO': ast.Mod(), - 'BINARY_POWER': ast.Pow(), - } - - if op_name in op_map: - # Create lambda AST: lambda param: param op constant - lambda_ast = ast.Lambda( - args=ast.arguments( - posonlyargs=[], - args=[ast.arg(arg=param_name, annotation=None)], - vararg=None, - kwonlyargs=[], - kw_defaults=[], - kwarg=None, - defaults=[] - ), - body=ast.BinOp( - left=ast.Name(id=param_name, ctx=ast.Load()), - op=op_map[op_name], 
- right=ast.Constant(value=const_value) - ) - ) - new_stack = new_stack + [lambda_ast] - return replace(state, stack=new_stack) - - return state + param_name = lambda_instructions[0].argval + const_value = lambda_instructions[1].argval + op_name = lambda_instructions[2].opname + + # Create lambda AST: lambda param: param op constant + lambda_ast = ast.Lambda( + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg=param_name, annotation=None)], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[] + ), + body=ast.BinOp( + left=ast.Name(id=param_name, ctx=ast.Load()), + op=op_map[op_name], + right=ast.Constant(value=const_value) + ) + ) + new_stack = state.stack[:-2] + [lambda_ast] + return replace(state, stack=new_stack) + else: + raise NotImplementedError("Complex lambda reconstruction not implemented yet.") @register_handler('GET_ITER') @@ -729,33 +682,27 @@ def handle_get_iter(state: ReconstructionState, instr: dis.Instruction) -> Recon @register_handler('LOAD_ATTR') def handle_load_attr(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # LOAD_ATTR loads an attribute from the object on top of stack - if state.stack: - obj = ensure_ast(state.stack[-1]) - attr_name = instr.argval - new_stack = state.stack[:-1] - - # Create attribute access AST - attr_node = ast.Attribute(value=obj, attr=attr_name, ctx=ast.Load()) - new_stack = new_stack + [attr_node] - return replace(state, stack=new_stack) + obj = ensure_ast(state.stack[-1]) + attr_name = instr.argval + new_stack = state.stack[:-1] - return state + # Create attribute access AST + attr_node = ast.Attribute(value=obj, attr=attr_name, ctx=ast.Load()) + new_stack = new_stack + [attr_node] + return replace(state, stack=new_stack) @register_handler('BINARY_SUBSCR') def handle_binary_subscr(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # BINARY_SUBSCR implements obj[index] - pops index and obj from stack - if len(state.stack) >= 2: - 
index = ensure_ast(state.stack[-1]) # Index is on top - obj = ensure_ast(state.stack[-2]) # Object is below index - new_stack = state.stack[:-2] - - # Create subscript access AST - subscr_node = ast.Subscript(value=obj, slice=index, ctx=ast.Load()) - new_stack = new_stack + [subscr_node] - return replace(state, stack=new_stack) + index = ensure_ast(state.stack[-1]) # Index is on top + obj = ensure_ast(state.stack[-2]) # Object is below index + new_stack = state.stack[:-2] - return state + # Create subscript access AST + subscr_node = ast.Subscript(value=obj, slice=index, ctx=ast.Load()) + new_stack = new_stack + [subscr_node] + return replace(state, stack=new_stack) # ============================================================================ @@ -765,63 +712,54 @@ def handle_binary_subscr(state: ReconstructionState, instr: dis.Instruction) -> @register_handler('BUILD_TUPLE') def handle_build_tuple(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: tuple_size: int = instr.arg - if len(state.stack) >= tuple_size: - # Pop elements for the tuple - elements = [ensure_ast(elem) for elem in state.stack[-tuple_size:]] if tuple_size > 0 else [] - new_stack = state.stack[:-tuple_size] if tuple_size > 0 else state.stack - - # Create tuple AST - tuple_node = ast.Tuple(elts=elements, ctx=ast.Load()) - new_stack = new_stack + [tuple_node] - return replace(state, stack=new_stack) + # Pop elements for the tuple + elements = [ensure_ast(elem) for elem in state.stack[-tuple_size:]] if tuple_size > 0 else [] + new_stack = state.stack[:-tuple_size] if tuple_size > 0 else state.stack - return state + # Create tuple AST + tuple_node = ast.Tuple(elts=elements, ctx=ast.Load()) + new_stack = new_stack + [tuple_node] + return replace(state, stack=new_stack) @register_handler('BUILD_LIST') def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: list_size = instr.arg - if len(state.stack) >= list_size: - # Pop elements for 
the list - elements = [ensure_ast(elem) for elem in state.stack[-list_size:]] if list_size > 0 else [] - new_stack = state.stack[:-list_size] if list_size > 0 else state.stack - - # Create list AST - list_node = ast.List(elts=elements, ctx=ast.Load()) - new_stack = new_stack + [list_node] - return replace(state, stack=new_stack) + # Pop elements for the list + elements = [ensure_ast(elem) for elem in state.stack[-list_size:]] if list_size > 0 else [] + new_stack = state.stack[:-list_size] if list_size > 0 else state.stack - return state + # Create list AST + list_node = ast.List(elts=elements, ctx=ast.Load()) + new_stack = new_stack + [list_node] + return replace(state, stack=new_stack) @register_handler('LIST_EXTEND') def handle_list_extend(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # LIST_EXTEND extends the list at TOS-1 with the iterable at TOS - if len(state.stack) >= 2: - iterable = ensure_ast(state.stack[-1]) - list_obj = state.stack[-2] # This should be a list from BUILD_LIST - new_stack = state.stack[:-2] - - # If the list is empty and we're extending with a tuple/iterable, - # we can convert this to a simple list of the iterable's elements - if isinstance(list_obj, ast.List) and len(list_obj.elts) == 0: - # If extending with a constant tuple, expand it to list elements - if isinstance(iterable, ast.Constant) and isinstance(iterable.value, tuple): - elements = [ast.Constant(value=elem) for elem in iterable.value] - list_node = ast.List(elts=elements, ctx=ast.Load()) - new_stack = new_stack + [list_node] - return replace(state, stack=new_stack) - - # Fallback: create a list from the iterable using list() constructor - list_call = ast.Call( - func=ast.Name(id='list', ctx=ast.Load()), - args=[iterable], - keywords=[] - ) - new_stack = new_stack + [list_call] - return replace(state, stack=new_stack) + iterable = ensure_ast(state.stack[-1]) + list_obj = state.stack[-2] # This should be a list from BUILD_LIST + new_stack = 
state.stack[:-2] - return state + # If the list is empty and we're extending with a tuple/iterable, + # we can convert this to a simple list of the iterable's elements + if isinstance(list_obj, ast.List) and len(list_obj.elts) == 0: + # If extending with a constant tuple, expand it to list elements + if isinstance(iterable, ast.Constant) and isinstance(iterable.value, tuple): + elements = [ast.Constant(value=elem) for elem in iterable.value] + list_node = ast.List(elts=elements, ctx=ast.Load()) + new_stack = new_stack + [list_node] + return replace(state, stack=new_stack) + + # Fallback: create a list from the iterable using list() constructor + list_call = ast.Call( + func=ast.Name(id='list', ctx=ast.Load()), + args=[iterable], + keywords=[] + ) + new_stack = new_stack + [list_call] + return replace(state, stack=new_stack) @register_handler('BUILD_CONST_KEY_MAP') @@ -829,25 +767,22 @@ def handle_build_const_key_map(state: ReconstructionState, instr: dis.Instructio # BUILD_CONST_KEY_MAP builds a dictionary with constant keys # The keys are in a tuple on TOS, values are on the stack below map_size = instr.arg - if len(state.stack) >= map_size + 1: - # Pop the keys tuple and values - keys_tuple = state.stack[-1] - values = [ensure_ast(val) for val in state.stack[-map_size-1:-1]] - new_stack = state.stack[:-map_size-1] - - # Extract keys from the constant tuple - if isinstance(keys_tuple, ast.Constant) and isinstance(keys_tuple.value, tuple): - keys = [ast.Constant(value=key) for key in keys_tuple.value] - else: - # Fallback if keys are not in expected format - keys = [ast.Constant(value=f'key_{i}') for i in range(len(values))] - - # Create dictionary AST - dict_node = ast.Dict(keys=keys, values=values) - new_stack = new_stack + [dict_node] - return replace(state, stack=new_stack) + # Pop the keys tuple and values + keys_tuple = state.stack[-1] + values = [ensure_ast(val) for val in state.stack[-map_size-1:-1]] + new_stack = state.stack[:-map_size-1] - return state + 
# Extract keys from the constant tuple + if isinstance(keys_tuple, ast.Constant) and isinstance(keys_tuple.value, tuple): + keys = [ast.Constant(value=key) for key in keys_tuple.value] + else: + # Fallback if keys are not in expected format + keys = [ast.Constant(value=f'key_{i}') for i in range(len(values))] + + # Create dictionary AST + dict_node = ast.Dict(keys=keys, values=values) + new_stack = new_stack + [dict_node] + return replace(state, stack=new_stack) # ============================================================================ @@ -856,75 +791,68 @@ def handle_build_const_key_map(state: ReconstructionState, instr: dis.Instructio @register_handler('COMPARE_OP') def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - - # Map comparison operation codes to AST operators - op_map = { - '<': ast.Lt(), - '<=': ast.LtE(), - '>': ast.Gt(), - '>=': ast.GtE(), - '==': ast.Eq(), - '!=': ast.NotEq(), - 'in': ast.In(), - 'not in': ast.NotIn(), - 'is': ast.Is(), - 'is not': ast.IsNot(), - } - - op_name = instr.argval - if op_name in op_map: - compare_node = ast.Compare( - left=left, - ops=[op_map[op_name]], - comparators=[right] - ) - new_stack = state.stack[:-2] + [compare_node] - return replace(state, stack=new_stack) + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) - return state - - -@register_handler('CONTAINS_OP') -def handle_contains_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) # Container - left = ensure_ast(state.stack[-2]) # Item to check - - # instr.arg determines if it's 'in' (0) or 'not in' (1) - op = ast.NotIn() if instr.arg else ast.In() - + # Map comparison operation codes to AST operators + op_map = { + '<': ast.Lt(), + '<=': ast.LtE(), + '>': ast.Gt(), + '>=': ast.GtE(), + '==': 
ast.Eq(), + '!=': ast.NotEq(), + 'in': ast.In(), + 'not in': ast.NotIn(), + 'is': ast.Is(), + 'is not': ast.IsNot(), + } + + op_name = instr.argval + if op_name in op_map: compare_node = ast.Compare( left=left, - ops=[op], + ops=[op_map[op_name]], comparators=[right] ) new_stack = state.stack[:-2] + [compare_node] return replace(state, stack=new_stack) + else: + raise TypeError(f"Unsupported comparison operation: {op_name}") + + +@register_handler('CONTAINS_OP') +def handle_contains_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + right = ensure_ast(state.stack[-1]) # Container + left = ensure_ast(state.stack[-2]) # Item to check - return state + # instr.arg determines if it's 'in' (0) or 'not in' (1) + op = ast.NotIn() if instr.arg else ast.In() + + compare_node = ast.Compare( + left=left, + ops=[op], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) @register_handler('IS_OP') def handle_is_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - if len(state.stack) >= 2: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - - # instr.arg determines if it's 'is' (0) or 'is not' (1) - op = ast.IsNot() if instr.arg else ast.Is() - - compare_node = ast.Compare( - left=left, - ops=[op], - comparators=[right] - ) - new_stack = state.stack[:-2] + [compare_node] - return replace(state, stack=new_stack) + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) - return state + # instr.arg determines if it's 'is' (0) or 'is not' (1) + op = ast.IsNot() if instr.arg else ast.Is() + + compare_node = ast.Compare( + left=left, + ops=[op], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) # ============================================================================ @@ -935,42 +863,39 @@ def handle_is_op(state: ReconstructionState, instr: dis.Instruction) 
-> Reconstr def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false # In comprehensions, this is used for filter conditions - if state.stack: - condition = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + condition = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # If we have pending OR conditions, this is the final condition in an OR expression + if state.or_conditions: + # Combine all OR conditions into a single BoolOp + all_or_conditions = state.or_conditions + [condition] + combined_condition = ast.BoolOp(op=ast.Or(), values=all_or_conditions) - # If we have pending OR conditions, this is the final condition in an OR expression - if state.or_conditions: - # Combine all OR conditions into a single BoolOp - all_or_conditions = state.or_conditions + [condition] - combined_condition = ast.BoolOp(op=ast.Or(), values=all_or_conditions) - - # Add the combined condition to the loop and clear OR conditions - if state.loops: - updated_loop = replace( - state.loops[-1], - conditions=state.loops[-1].conditions + [combined_condition] - ) - new_loops = state.loops[:-1] + [updated_loop] - return replace(state, stack=new_stack, loops=new_loops, or_conditions=[]) - else: - new_pending = state.pending_conditions + [combined_condition] - return replace(state, stack=new_stack, pending_conditions=new_pending, or_conditions=[]) + # Add the combined condition to the loop and clear OR conditions + if state.loops: + updated_loop = replace( + state.loops[-1], + conditions=state.loops[-1].conditions + [combined_condition] + ) + new_loops = state.loops[:-1] + [updated_loop] + return replace(state, stack=new_stack, loops=new_loops, or_conditions=[]) else: - # Regular condition - add to the most recent loop - if state.loops: - updated_loop = replace( - state.loops[-1], - conditions=state.loops[-1].conditions + [condition] - ) - new_loops = 
state.loops[:-1] + [updated_loop] - return replace(state, stack=new_stack, loops=new_loops) - else: - # If no loops yet, add to pending conditions - new_pending = state.pending_conditions + [condition] - return replace(state, stack=new_stack, pending_conditions=new_pending) - - return state + new_pending = state.pending_conditions + [combined_condition] + return replace(state, stack=new_stack, pending_conditions=new_pending, or_conditions=[]) + else: + # Regular condition - add to the most recent loop + if state.loops: + updated_loop = replace( + state.loops[-1], + conditions=state.loops[-1].conditions + [condition] + ) + new_loops = state.loops[:-1] + [updated_loop] + return replace(state, stack=new_stack, loops=new_loops) + else: + # If no loops yet, add to pending conditions + new_pending = state.pending_conditions + [condition] + return replace(state, stack=new_stack, pending_conditions=new_pending) @register_handler('POP_JUMP_IF_TRUE') @@ -979,35 +904,32 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) # This can be: # 1. Part of an OR expression (jump to YIELD_VALUE) # 2. 
A negated condition like "not x % 2" (jump back to loop start) - if state.stack: - condition = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + condition = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # Check if this jumps forward (to YIELD_VALUE - OR pattern) vs back to loop (NOT pattern) + # In OR: POP_JUMP_IF_TRUE jumps forward to yield the value + # In NOT: POP_JUMP_IF_TRUE jumps back to skip this iteration + if instr.argval > instr.offset: + # Jumping forward - part of an OR expression + new_or_conditions = state.or_conditions + [condition] + return replace(state, stack=new_stack, or_conditions=new_or_conditions) + else: + # Jumping backward to loop start - this is a negated condition + # When POP_JUMP_IF_TRUE jumps back, it means "if true, skip this item" + # So we need to negate the condition to get the filter condition + negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) - # Check if this jumps forward (to YIELD_VALUE - OR pattern) vs back to loop (NOT pattern) - # In OR: POP_JUMP_IF_TRUE jumps forward to yield the value - # In NOT: POP_JUMP_IF_TRUE jumps back to skip this iteration - if instr.argval > instr.offset: - # Jumping forward - part of an OR expression - new_or_conditions = state.or_conditions + [condition] - return replace(state, stack=new_stack, or_conditions=new_or_conditions) + if state.loops: + updated_loop = replace( + state.loops[-1], + conditions=state.loops[-1].conditions + [negated_condition] + ) + new_loops = state.loops[:-1] + [updated_loop] + return replace(state, stack=new_stack, loops=new_loops) else: - # Jumping backward to loop start - this is a negated condition - # When POP_JUMP_IF_TRUE jumps back, it means "if true, skip this item" - # So we need to negate the condition to get the filter condition - negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) - - if state.loops: - updated_loop = replace( - state.loops[-1], - conditions=state.loops[-1].conditions + 
[negated_condition] - ) - new_loops = state.loops[:-1] + [updated_loop] - return replace(state, stack=new_stack, loops=new_loops) - else: - new_pending = state.pending_conditions + [negated_condition] - return replace(state, stack=new_stack, pending_conditions=new_pending) - - return state + new_pending = state.pending_conditions + [negated_condition] + return replace(state, stack=new_stack, pending_conditions=new_pending) # ============================================================================ @@ -1019,20 +941,17 @@ def handle_unpack_sequence(state: ReconstructionState, instr: dis.Instruction) - # UNPACK_SEQUENCE unpacks a sequence into multiple values # arg is the number of values to unpack unpack_count = instr.arg - if state.stack: - sequence = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - - # For tuple unpacking in comprehensions, we typically see patterns like: - # ((k, v) for k, v in items) where items is unpacked into k and v - # Create placeholder variables for the unpacked values - for i in range(unpack_count): - var_name = f'_unpack_{i}' - new_stack = new_stack + [ast.Name(id=var_name, ctx=ast.Load())] - - return replace(state, stack=new_stack) + sequence = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] - return state + # For tuple unpacking in comprehensions, we typically see patterns like: + # ((k, v) for k, v in items) where items is unpacked into k and v + # Create placeholder variables for the unpacked values + for i in range(unpack_count): + var_name = f'_unpack_{i}' + new_stack = new_stack + [ast.Name(id=var_name, ctx=ast.Load())] + + return replace(state, stack=new_stack) # ============================================================================ @@ -1070,7 +989,7 @@ def handle_extended_arg(state: ReconstructionState, instr: dis.Instruction) -> R @functools.singledispatch def ensure_ast(value) -> ast.AST: """Ensure value is an AST node""" - return ast.Name(id=str(value), ctx=ast.Load()) + raise 
TypeError(f"Cannot convert {type(value)} to AST node") @ensure_ast.register @@ -1102,11 +1021,31 @@ def _ensure_ast_tuple(value: tuple) -> ast.Tuple: return ast.Tuple(elts=[ensure_ast(v) for v in value], ctx=ast.Load()) +@ensure_ast.register(type(iter((1,)))) +def _ensure_ast_tuple_iterator(value: Iterator) -> ast.AST: + return ensure_ast(tuple(value.__reduce__()[1][0])) + + @ensure_ast.register def _ensure_ast_list(value: list) -> ast.List: return ast.List(elts=[ensure_ast(v) for v in value], ctx=ast.Load()) +@ensure_ast.register(type(iter([1]))) +def _ensure_ast_list_iterator(value: Iterator) -> ast.AST: + return ensure_ast(list(value.__reduce__()[1][0])) + + +@ensure_ast.register +def _ensure_ast_set(value: set) -> ast.Set: + return ast.Set(elts=[ensure_ast(v) for v in value]) + + +@ensure_ast.register(type(iter({1}))) +def _ensure_ast_set_iterator(value: Iterator) -> ast.AST: + return ensure_ast(set(value.__reduce__()[1][0])) + + @ensure_ast.register def _ensure_ast_dict(value: dict) -> ast.Dict: return ast.Dict( @@ -1115,9 +1054,14 @@ def _ensure_ast_dict(value: dict) -> ast.Dict: ) +@ensure_ast.register(type(iter({1: 2}))) +def _ensure_ast_dict_iterator(value: Iterator) -> ast.AST: + # TODO figure out how to handle dict iterators + raise TypeError("dict key iterator not yet supported") + + @ensure_ast.register def _ensure_ast_range(value: range) -> ast.Call: - """Convert range to AST Call""" return ast.Call( func=ast.Name(id='range', ctx=ast.Load()), args=[ensure_ast(value.start), ensure_ast(value.stop), ensure_ast(value.step)], @@ -1126,10 +1070,7 @@ def _ensure_ast_range(value: range) -> ast.Call: @ensure_ast.register(type(iter(range(1)))) -@ensure_ast.register(type(iter([1]))) -@ensure_ast.register(type(iter((1,)))) -def _ensure_ast_iterator(value: Iterator) -> ast.AST: - """Convert iterator to AST - special handling for iterators""" +def _ensure_ast_range_iterator(value: Iterator) -> ast.AST: return ensure_ast(value.__reduce__()[1][0]) @@ -1187,16 
+1128,54 @@ def build_comprehension_ast(state: ReconstructionState) -> ast.AST: # MAIN RECONSTRUCTION FUNCTION # ============================================================================ -def reconstruct(genexpr: GeneratorType) -> ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp: +def reconstruct(genexpr: GeneratorType) -> ast.GeneratorExp: """ - Reconstructs an AST from a generator expression's bytecode. + Reconstruct an AST from a generator expression's bytecode. + + This function analyzes the bytecode of a generator object and reconstructs + an abstract syntax tree (AST) that represents the original comprehension + expression. The reconstruction process simulates the Python VM's execution + of the bytecode, building AST nodes instead of executing operations. + + The reconstruction handles complex comprehension features including: + - Multiple nested loops + - Filter conditions (if clauses) + - Complex expressions in the yield/result part + - Tuple unpacking in loop variables + - Various operators and function calls Args: - genexpr (GeneratorType): The generator object to analyze. - + genexpr (GeneratorType): The generator object to analyze. Must be + a freshly created generator that has not been iterated yet + (in 'GEN_CREATED' state). + Returns: - An AST node representing the generator expression. - Can be ast.GeneratorExp, ast.ListComp, ast.SetComp, or ast.DictComp. + ast.GeneratorExp: An AST node representing the reconstructed comprehension. + The specific type depends on the original comprehension: + + Raises: + AssertionError: If the input is not a generator or if the generator + has already been started (not in 'GEN_CREATED' state). 
+ + Example: + >>> # Generator expression + >>> g = (x * 2 for x in range(10) if x % 2 == 0) + >>> ast_node = reconstruct(g) + >>> isinstance(ast_node, ast.GeneratorExp) + True + + >>> # The reconstructed AST can be compiled and evaluated + >>> import ast + >>> code = compile(ast.Expression(body=ast_node), '', 'eval') + >>> result = eval(code) + >>> list(result) + [0, 4, 8, 12, 16] + + Note: + The reconstruction is based on bytecode analysis and may not perfectly + preserve the original source code formatting or variable names in all + cases. However, the semantic behavior of the reconstructed AST should + match the original comprehension. """ assert inspect.isgenerator(genexpr), "Input must be a generator expression" assert inspect.getgeneratorstate(genexpr) == 'GEN_CREATED', "Generator must be in created state" diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index 079c47ea..d60cdfbf 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -1,6 +1,7 @@ import ast import pytest import dis +import inspect from types import GeneratorType from typing import Any, Union @@ -26,8 +27,21 @@ def compile_and_eval(node: ast.AST, globals_dict: dict = None) -> Any: def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict = None): """Assert that a reconstructed AST produces the same results as the original generator.""" - # Evaluate both to lists for comparison + assert inspect.isgenerator(genexpr), "Input must be a generator" + assert inspect.getgeneratorstate(genexpr) == 'GEN_CREATED', "Generator must not be consumed" + + # Save current globals to restore later + curr_globals = globals().copy() + globals().update(globals_dict or {}) + + # Materialize original generator to list for comparison original_list = list(genexpr) + + # Clean up globals to avoid pollution + for key in globals_dict or {}: + if key not in curr_globals: + del globals()[key] + 
globals().update(curr_globals) # Compile and evaluate the reconstructed AST reconstructed_gen = compile_and_eval(reconstructed_ast, globals_dict) @@ -137,6 +151,73 @@ def test_arithmetic_expressions(genexpr): assert_ast_equivalent(genexpr, ast_node) +# ============================================================================ +# COMPARISON OPERATORS +# ============================================================================ + +@pytest.mark.parametrize("genexpr", [ + # All comparison operators + (x for x in range(10) if x < 5), + (x for x in range(10) if x <= 5), + (x for x in range(10) if x > 5), + (x for x in range(10) if x >= 5), + (x for x in range(10) if x == 5), + (x for x in range(10) if x != 5), + + # in/not in operators + (x for x in range(10) if x in [2, 4, 6, 8]), + (x for x in range(10) if x not in [2, 4, 6, 8]), + + # is/is not operators (with None) + (x for x in [1, None, 3, None, 5] if x is not None), + (x for x in [1, None, 3, None, 5] if x is None), + + # Boolean operations - these are complex cases that might need special handling + (x for x in range(10) if x > 2 and x < 8), + (x for x in range(10) if x < 3 or x > 7), + (x for x in range(10) if not x % 2), + (x for x in range(10) if not (x > 5)), + + # More complex comparison edge cases + # Chained comparisons + (x for x in range(20) if 5 < x < 15), + (x for x in range(20) if 0 <= x <= 10), + (x for x in range(20) if x >= 5 and x <= 15), + + # Comparisons with expressions + (x for x in range(10) if x * 2 > 10), + (x for x in range(10) if x + 1 <= 5), + (x for x in range(10) if x ** 2 < 25), + (x for x in range(10) if (x + 1) * 2 != 6), + + # Complex membership tests + (x for x in range(20) if x in range(5, 15)), + (x for x in range(10) if x not in range(3, 7)), + (x for x in range(10) if x % 2 in [0]), + (x for x in range(10) if x not in []), # Empty container + + # Complex boolean combinations + (x for x in range(20) if x > 5 and x < 15 and x % 2 == 0), + (x for x in range(20) if x < 5 or x 
> 15 or x == 10), + (x for x in range(20) if not (x > 5 and x < 15)), + (x for x in range(20) if not (x < 5 or x > 15)), + + # Mixed comparison and boolean operations + (x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), + (x for x in range(20) if not (x % 2 == 0 and x % 3 == 0)), + + # Edge cases with identity comparisons + (x for x in [0, 1, 2, None, 4] if x is not None and x > 1), + (x for x in [True, False, 1, 0] if x is True), + (x for x in [True, False, 1, 0] if x is not False), +]) +def test_comparison_operators(genexpr): + """Test reconstruction of all comparison operators.""" + ast_node = reconstruct(genexpr) + assert isinstance(ast_node, ast.GeneratorExp) + assert_ast_equivalent(genexpr, ast_node) + + # ============================================================================ # FILTERED GENERATOR TESTS # ============================================================================ @@ -208,7 +289,7 @@ def test_filtered_generators(genexpr): ((x, y) for x in range(3) for y in range(3)), (x + y for x in range(3) for y in range(3)), (x * y for x in range(1, 4) for y in range(1, 4)), - + # Nested with filters ((x, y) for x in range(5) for y in range(5) if x < y), (x + y for x in range(5) if x % 2 == 0 for y in range(5) if y % 2 == 1), @@ -248,6 +329,11 @@ def test_filtered_generators(genexpr): # Mixed range types ((x, y) for x in range(-2, 2) for y in range(0, 4, 2)), (x * y for x in range(5, 0, -1) for y in range(1, 6)), + + # Dependent nested loops + ((x, y) for x in range(3) for y in range(x, 3)), + (x + y for x in range(3) for y in range(x + 1, 3)), + (x * y * z for x in range(3) for y in range(x + 1, x + 3) for z in range(y, y + 3)), ]) def test_nested_loops(genexpr): """Test reconstruction of generators with nested loops.""" @@ -264,54 +350,58 @@ def test_nested_loops(genexpr): # DIFFERENT COMPREHENSION TYPES # ============================================================================ 
-@pytest.mark.parametrize("comprehension,expected_type", [ +@pytest.mark.parametrize("genexpr", [ # List comprehensions - ([x for x in range(5)], ast.ListComp), - ([x * 2 for x in range(5)], ast.ListComp), - ([x for x in range(10) if x % 2 == 0], ast.ListComp), + (x_ for x_ in [x for x in range(5)]), + (x_ for x_ in [x * 2 for x in range(5)]), + (x_ for x_ in [x for x in range(10) if x % 2 == 0]), # Set comprehensions - ({x for x in range(5)}, ast.SetComp), - ({x * 2 for x in range(5)}, ast.SetComp), - ({x for x in range(10) if x % 2 == 0}, ast.SetComp), + (x_ for x_ in {x for x in range(5)}), + (x_ for x_ in {x * 2 for x in range(5)}), + (x_ for x_ in {x for x in range(10) if x % 2 == 0}), # Dict comprehensions - ({x: x**2 for x in range(5)}, ast.DictComp), - ({x: x*2 for x in range(5) if x % 2 == 0}, ast.DictComp), - ({str(x): x for x in range(5)}, ast.DictComp), + pytest.param((x_ for x_ in {x: x**2 for x in range(5)}), marks=pytest.mark.xfail(reason="Dict comprehensions not yet supported")), + pytest.param((x_ for x_ in {x: x*2 for x in range(5) if x % 2 == 0}), marks=pytest.mark.xfail(reason="Dict comprehensions not yet supported")), + pytest.param((x_ for x_ in {str(x): x for x in range(5)}), marks=pytest.mark.xfail(reason="Dict comprehensions not yet supported")), ]) -def test_different_comprehension_types(comprehension, expected_type): +def test_different_comprehension_types(genexpr): """Test reconstruction of different comprehension types.""" - # Convert to generator for reconstruction - if isinstance(comprehension, list): - genexpr = (x for x in comprehension) - elif isinstance(comprehension, set): - genexpr = (x for x in comprehension) - elif isinstance(comprehension, dict): - genexpr = ((k, v) for k, v in comprehension.items()) - else: - genexpr = comprehension - - # Note: The actual implementation would need to detect the comprehension type - # from the bytecode. This test assumes it can do that. 
ast_node = reconstruct(genexpr) - - # For now, we'll check if it's at least a comprehension - assert isinstance(ast_node, (ast.GeneratorExp, ast.ListComp, ast.SetComp, ast.DictComp)) + assert_ast_equivalent(genexpr, ast_node) # ============================================================================ -# EDGE CASES AND COMPLEX SCENARIOS +# GENERATOR EXPRESSION WITH GLOBALS # ============================================================================ @pytest.mark.parametrize("genexpr,globals_dict", [ + # Using constants + ((x + a for x in range(5)), {'a': 10}), + ((data[i] for i in range(2)), {'data': [3, 4]}), + # Using global functions ((abs(x) for x in range(-5, 5)), {'abs': abs}), ((len(s) for s in ["a", "ab", "abc"]), {'len': len}), ((max(x, 5) for x in range(10)), {'max': max}), ((min(x, 5) for x in range(10)), {'min': min}), ((round(x / 3, 2) for x in range(10)), {'round': round}), +]) +def test_variable_lookup(genexpr, globals_dict): + """Test reconstruction of expressions with globals.""" + ast_node = reconstruct(genexpr) + assert isinstance(ast_node, ast.GeneratorExp) + # Need to provide the same globals for evaluation + assert_ast_equivalent(genexpr, ast_node, globals_dict) + + +# ============================================================================ +# EDGE CASES AND COMPLEX SCENARIOS +# ============================================================================ + +@pytest.mark.parametrize("genexpr,globals_dict", [ # Using lambdas and functions (((lambda y: y * 2)(x) for x in range(5)), {}), (((lambda y: y + 1)(x) for x in range(5)), {}), @@ -367,73 +457,6 @@ def test_complex_scenarios(genexpr, globals_dict): assert_ast_equivalent(genexpr, ast_node, globals_dict) -# ============================================================================ -# COMPARISON OPERATORS -# ============================================================================ - -@pytest.mark.parametrize("genexpr", [ - # All comparison operators - (x for x in range(10) if x < 
5), - (x for x in range(10) if x <= 5), - (x for x in range(10) if x > 5), - (x for x in range(10) if x >= 5), - (x for x in range(10) if x == 5), - (x for x in range(10) if x != 5), - - # in/not in operators - (x for x in range(10) if x in [2, 4, 6, 8]), - (x for x in range(10) if x not in [2, 4, 6, 8]), - - # is/is not operators (with None) - (x for x in [1, None, 3, None, 5] if x is not None), - (x for x in [1, None, 3, None, 5] if x is None), - - # Boolean operations - these are complex cases that might need special handling - (x for x in range(10) if x > 2 and x < 8), - (x for x in range(10) if x < 3 or x > 7), - (x for x in range(10) if not x % 2), - (x for x in range(10) if not (x > 5)), - - # More complex comparison edge cases - # Chained comparisons - (x for x in range(20) if 5 < x < 15), - (x for x in range(20) if 0 <= x <= 10), - (x for x in range(20) if x >= 5 and x <= 15), - - # Comparisons with expressions - (x for x in range(10) if x * 2 > 10), - (x for x in range(10) if x + 1 <= 5), - (x for x in range(10) if x ** 2 < 25), - (x for x in range(10) if (x + 1) * 2 != 6), - - # Complex membership tests - (x for x in range(20) if x in range(5, 15)), - (x for x in range(10) if x not in range(3, 7)), - (x for x in range(10) if x % 2 in [0]), - (x for x in range(10) if x not in []), # Empty container - - # Complex boolean combinations - (x for x in range(20) if x > 5 and x < 15 and x % 2 == 0), - (x for x in range(20) if x < 5 or x > 15 or x == 10), - (x for x in range(20) if not (x > 5 and x < 15)), - (x for x in range(20) if not (x < 5 or x > 15)), - - # Mixed comparison and boolean operations - (x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), - (x for x in range(20) if not (x % 2 == 0 and x % 3 == 0)), - - # Edge cases with identity comparisons - (x for x in [0, 1, 2, None, 4] if x is not None and x > 1), - (x for x in [True, False, 1, 0] if x is True), - (x for x in [True, False, 1, 0] if x is not False), -]) -def 
test_comparison_operators(genexpr): - """Test reconstruction of all comparison operators.""" - ast_node = reconstruct(genexpr) - assert isinstance(ast_node, ast.GeneratorExp) - assert_ast_equivalent(genexpr, ast_node) - - # ============================================================================ # HELPER FUNCTION TESTS # ============================================================================ @@ -526,8 +549,8 @@ def test_comparison_operators(genexpr): (999999999999999999999, '999999999999999999999'), (1.7976931348623157e+308, '1.7976931348623157e+308'), # Close to float max - # Sets - these need special handling as they convert to Name nodes - pytest.param({1, 2, 3}, 'set', marks=pytest.mark.xfail(reason="Sets don't have direct AST representation")) + # Sets - note unparse equivalence may fail for unordered collections + ({1, 2, 3}, '{1, 2, 3}'), ]) def test_ensure_ast(value, expected_str): """Test that ensure_ast correctly converts various values to AST nodes.""" From 39ea2a8f84a7bfb4e5bbd4d151f276f1bed95d72 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 13 Jun 2025 11:13:06 -0400 Subject: [PATCH 003/106] move code --- effectful/internals/genexpr.py | 38 +++++++++++++++------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 3615b2b0..3e97d2d3 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -224,6 +224,23 @@ def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> Reco return state +@register_handler('YIELD_VALUE') +def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # YIELD_VALUE pops a value from the stack and yields it + # This is the expression part of the generator + expression = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + return replace(state, stack=new_stack, expression=expression) + + +@register_handler('RETURN_VALUE') +def 
handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # RETURN_VALUE ends the generator + # Usually preceded by LOAD_CONST None + new_stack = state.stack[:-1] # Remove the None + return replace(state, stack=new_stack) + + # ============================================================================ # LOOP CONTROL HANDLERS # ============================================================================ @@ -335,27 +352,6 @@ def handle_store_name(state: ReconstructionState, instr: dis.Instruction) -> Rec return replace(state, current_loop_var=name) -# ============================================================================ -# CORE GENERATOR HANDLERS (continued) -# ============================================================================ - -@register_handler('YIELD_VALUE') -def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # YIELD_VALUE pops a value from the stack and yields it - # This is the expression part of the generator - expression = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - return replace(state, stack=new_stack, expression=expression) - - -@register_handler('RETURN_VALUE') -def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # RETURN_VALUE ends the generator - # Usually preceded by LOAD_CONST None - new_stack = state.stack[:-1] # Remove the None - return replace(state, stack=new_stack) - - # ============================================================================ # STACK MANAGEMENT HANDLERS # ============================================================================ From 2720b75e765c00bddae3c5642ad5f385c4f746ff Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 13 Jun 2025 14:49:34 -0400 Subject: [PATCH 004/106] remove custom loopinfo --- effectful/internals/genexpr.py | 241 +++++++++++++++++---------------- 1 file changed, 128 insertions(+), 113 deletions(-) diff --git 
a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 3e97d2d3..37f4d1e8 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -86,36 +86,6 @@ OP_CATEGORIES: dict[str, str] = {op: category for category, ops in CATEGORIES.items() for op in ops} -@dataclass(frozen=True) -class LoopInfo: - """Information about a single loop in a comprehension. - - This class stores all the components needed to reconstruct a single 'for' clause - in a comprehension expression. In Python, comprehensions can have multiple - nested loops, and each loop can have zero or more filter conditions. - - For example, in the comprehension: - [x*y for x in range(3) for y in range(4) if x < y if x + y > 2] - - There would be two LoopInfo objects: - 1. First loop: target='x', iter_ast=range(3), conditions=[] - 2. Second loop: target='y', iter_ast=range(4), conditions=[x < y, x + y > 2] - - Attributes: - target: The loop variable(s) as an AST node. Usually an ast.Name node - (e.g., 'x'), but can also be a tuple for unpacking - (e.g., '(i, j)' in 'for i, j in pairs'). - iter_ast: The iterator expression as an AST node. This is what comes - after 'in' in the for clause (e.g., range(3), list_var, etc). - conditions: List of filter expressions (if clauses) that apply to this - loop level. Each condition is an AST node representing a - boolean expression. - """ - target: ast.AST # The loop variable(s) as AST node - iter_ast: ast.AST # The iterator as AST node - conditions: List[ast.AST] = field(default_factory=list) # if conditions as AST nodes - - @dataclass(frozen=True) class ReconstructionState: """State maintained during AST reconstruction from bytecode. @@ -148,10 +118,6 @@ class ReconstructionState: in '[x*2 for x in items]', this would be the AST for 'x*2'. Captured when YIELD_VALUE is encountered. - key_expression: For dict comprehensions only - the key part of the - key:value pair. In '{k: v for k,v in items}', this - would be the AST for 'k'. 
- code_obj: The code object being analyzed (from generator.gi_code). Contains the bytecode and other metadata like variable names. @@ -159,27 +125,29 @@ class ReconstructionState: Provides access to the runtime state, including local variables like the '.0' iterator variable. - current_loop_var: Name of the most recently stored loop variable. - Helps track which variable is being used in the - current loop context. - pending_conditions: Filter conditions that haven't been assigned to a loop yet. Some bytecode patterns require collecting conditions before knowing which loop they belong to. or_conditions: Conditions that are part of an OR expression. These need to be combined with ast.BoolOp(op=ast.Or()). + + chained_compare_state: When building chained comparisons (e.g., a < b < c), + this holds the Compare node being built up. The DUP_TOP + and ROT_THREE pattern indicates a chained comparison. + + duplicated_for_chain: Tracks if we've seen DUP_TOP followed by ROT_THREE, + which indicates we're building a chained comparison. 
""" - stack: List[Any] = field(default_factory=list) # Stack of AST nodes or values - loops: List[LoopInfo] = field(default_factory=list) - comprehension_type: str = 'generator' # 'generator', 'list', 'set', 'dict' + stack: list[ast.AST] = field(default_factory=list) # Stack of AST nodes or values + loops: list[ast.comprehension] = field(default_factory=list) expression: Optional[ast.AST] = None # Main expression being yielded - key_expression: Optional[ast.AST] = None # For dict comprehensions - code_obj: Any = None - frame: Any = None - current_loop_var: Optional[str] = None # Track current loop variable + code_obj: Optional[types.CodeType] = None + frame: Optional[types.FrameType] = None pending_conditions: List[ast.AST] = field(default_factory=list) or_conditions: List[ast.AST] = field(default_factory=list) + chained_compare_state: Optional[ast.Compare] = None # For building chained comparisons + duplicated_for_chain: bool = False # Flag for chained comparison detection # Global handler registry @@ -256,9 +224,11 @@ def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> Recon # Create a new loop variable - we'll get the actual name from STORE_FAST # For now, use a placeholder - loop_info = LoopInfo( + loop_info = ast.comprehension( target=ast.Name(id='_temp', ctx=ast.Store()), - iter_ast=ensure_ast(iterator) + iter=ensure_ast(iterator), + ifs=[], + is_async=0, ) # Create new loops list with the new loop info @@ -309,23 +279,24 @@ def handle_store_fast(state: ReconstructionState, instr: dis.Instruction) -> Rec var_name = instr.argval # Update the most recent loop's target variable - if state.loops: - # Create a new LoopInfo with updated target - updated_loop = replace( - state.loops[-1], - target=ast.Name(id=var_name, ctx=ast.Store()) - ) - # Create new loops list with the updated loop - new_loops = state.loops[:-1] + [updated_loop] - return replace(state, loops=new_loops, current_loop_var=var_name) - - return replace(state, 
current_loop_var=var_name) + assert len(state.loops) > 0, "STORE_FAST must be within a loop context" + + # Create a new LoopInfo with updated target + updated_loop = ast.comprehension( + target=ast.Name(id=var_name, ctx=ast.Store()), + iter=state.loops[-1].iter, + ifs=state.loops[-1].ifs, + is_async=state.loops[-1].is_async + ) + # Create new loops list with the updated loop + new_loops = state.loops[:-1] + [updated_loop] + return replace(state, loops=new_loops) @register_handler('LOAD_CONST') def handle_load_const(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: const_value = instr.argval - new_stack = state.stack + [ast.Constant(value=const_value)] + new_stack = state.stack + [ensure_ast(const_value)] return replace(state, stack=new_stack) @@ -348,8 +319,7 @@ def handle_load_name(state: ReconstructionState, instr: dis.Instruction) -> Reco def handle_store_name(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # STORE_NAME stores to a name in the global namespace # In generator expressions, this is uncommon but we'll handle it like STORE_FAST - name = instr.argval - return replace(state, current_loop_var=name) + return state # ============================================================================ @@ -360,7 +330,14 @@ def handle_store_name(state: ReconstructionState, instr: dis.Instruction) -> Rec def handle_pop_top(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # POP_TOP removes the top item from the stack # In generators, often used after YIELD_VALUE + # Also used to clean up the duplicated middle value in failed chained comparisons new_stack = state.stack[:-1] + + # If we're in a chained comparison state, this POP_TOP is cleaning up + # after a failed comparison, so clear the chained state + if state.chained_compare_state: + return replace(state, stack=new_stack, chained_compare_state=None) + return replace(state, stack=new_stack) @@ -384,6 +361,12 @@ def 
handle_rot_three(state: ReconstructionState, instr: dis.Instruction) -> Reco # ROT_THREE rotates the top three stack items # TOS -> TOS1, TOS1 -> TOS2, TOS2 -> TOS new_stack = state.stack[:-3] + [state.stack[-2], state.stack[-1], state.stack[-3]] + + # Check if the top two items are the same (from DUP_TOP) + # This indicates we're setting up for a chained comparison + if len(state.stack) >= 3 and state.stack[-1] == state.stack[-2]: + return replace(state, stack=new_stack, duplicated_for_chain=True) + return replace(state, stack=new_stack) @@ -720,7 +703,7 @@ def handle_build_tuple(state: ReconstructionState, instr: dis.Instruction) -> Re @register_handler('BUILD_LIST') def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - list_size = instr.arg + list_size: int = instr.arg # Pop elements for the list elements = [ensure_ast(elem) for elem in state.stack[-list_size:]] if list_size > 0 else [] new_stack = state.stack[:-list_size] if list_size > 0 else state.stack @@ -762,19 +745,13 @@ def handle_list_extend(state: ReconstructionState, instr: dis.Instruction) -> Re def handle_build_const_key_map(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # BUILD_CONST_KEY_MAP builds a dictionary with constant keys # The keys are in a tuple on TOS, values are on the stack below - map_size = instr.arg + map_size: int = instr.arg # Pop the keys tuple and values - keys_tuple = state.stack[-1] + keys_tuple: ast.Tuple = state.stack[-1] + keys = [ensure_ast(key) for key in keys_tuple.elts] values = [ensure_ast(val) for val in state.stack[-map_size-1:-1]] new_stack = state.stack[:-map_size-1] - # Extract keys from the constant tuple - if isinstance(keys_tuple, ast.Constant) and isinstance(keys_tuple.value, tuple): - keys = [ast.Constant(value=key) for key in keys_tuple.value] - else: - # Fallback if keys are not in expected format - keys = [ast.Constant(value=f'key_{i}') for i in range(len(values))] - # Create 
dictionary AST dict_node = ast.Dict(keys=keys, values=values) new_stack = new_stack + [dict_node] @@ -803,9 +780,48 @@ def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> Rec 'is': ast.Is(), 'is not': ast.IsNot(), } + assert instr.argval in op_map, f"Unsupported comparison operation: {instr.argval}" op_name = instr.argval - if op_name in op_map: + # Check if we're in a chained comparison + if state.duplicated_for_chain and len(state.stack) >= 3: + # This is the first comparison in a chain + # After DUP_TOP and ROT_THREE, stack is [middle, middle, left] + # So we need to swap left and right for the comparison + # The actual comparison should be left < middle (not middle < left) + left, right = right, left # Swap because ROT_THREE reversed them + middle_value = state.stack[-3] + + # Create the initial Compare node + compare_node = ast.Compare( + left=left, + ops=[op_map[op_name]], + comparators=[right] + ) + + # Keep the middle value on the stack for the next comparison + new_stack = state.stack[:-3] + [middle_value, compare_node] + return replace(state, stack=new_stack, + chained_compare_state=compare_node, + duplicated_for_chain=False) + elif state.chained_compare_state: + # This is a continuation of a chained comparison + # Stack has [middle_value, previous_compare] + middle_value = left # The duplicated middle value + new_comparator = right + existing_compare = state.chained_compare_state + + # The existing compare has the form: left op1 middle + # We need to add: middle op2 right + # But the compare node stores it as: left [op1, op2] [middle, right] + existing_compare.ops.append(op_map[op_name]) + existing_compare.comparators.append(new_comparator) + + # Replace the stack with just the extended comparison + new_stack = state.stack[:-2] + [existing_compare] + return replace(state, stack=new_stack, chained_compare_state=None) + else: + # Regular comparison compare_node = ast.Compare( left=left, ops=[op_map[op_name]], @@ -813,8 +829,6 @@ 
def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> Rec ) new_stack = state.stack[:-2] + [compare_node] return replace(state, stack=new_stack) - else: - raise TypeError(f"Unsupported comparison operation: {op_name}") @register_handler('CONTAINS_OP') @@ -862,6 +876,19 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) condition = ensure_ast(state.stack[-1]) new_stack = state.stack[:-1] + # Special handling for chained comparisons + if state.chained_compare_state: + # We're in the middle of building a chained comparison + # Check if this is jumping to the cleanup section (POP_TOP) + # vs continuing to the next comparison + if len(new_stack) > 0 and isinstance(new_stack[-1], ast.AST): + # The duplicated middle value is still on the stack + # This means the first comparison passed and we're continuing + return replace(state, stack=new_stack) + else: + # First comparison failed, clear the chained state + return replace(state, stack=new_stack, chained_compare_state=None) + # If we have pending OR conditions, this is the final condition in an OR expression if state.or_conditions: # Combine all OR conditions into a single BoolOp @@ -870,9 +897,11 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) # Add the combined condition to the loop and clear OR conditions if state.loops: - updated_loop = replace( - state.loops[-1], - conditions=state.loops[-1].conditions + [combined_condition] + updated_loop = ast.comprehension( + target=state.loops[-1].target, + iter=state.loops[-1].iter, + ifs=state.loops[-1].ifs + [combined_condition], + is_async=state.loops[-1].is_async, ) new_loops = state.loops[:-1] + [updated_loop] return replace(state, stack=new_stack, loops=new_loops, or_conditions=[]) @@ -882,9 +911,11 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) else: # Regular condition - add to the most recent loop if state.loops: - updated_loop = replace( - 
state.loops[-1], - conditions=state.loops[-1].conditions + [condition] + updated_loop = ast.comprehension( + target=state.loops[-1].target, + iter=state.loops[-1].iter, + ifs=state.loops[-1].ifs + [condition], + is_async=state.loops[-1].is_async, ) new_loops = state.loops[:-1] + [updated_loop] return replace(state, stack=new_stack, loops=new_loops) @@ -917,9 +948,11 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) if state.loops: - updated_loop = replace( - state.loops[-1], - conditions=state.loops[-1].conditions + [negated_condition] + updated_loop = ast.comprehension( + target=state.loops[-1].target, + iter=state.loops[-1].iter, + ifs=state.loops[-1].ifs + [negated_condition], + is_async=state.loops[-1].is_async, ) new_loops = state.loops[:-1] + [updated_loop] return replace(state, stack=new_stack, loops=new_loops) @@ -1070,6 +1103,12 @@ def _ensure_ast_range_iterator(value: Iterator) -> ast.AST: return ensure_ast(value.__reduce__()[1][0]) +@ensure_ast.register +def _ensure_ast_codeobj(value: types.CodeType) -> ast.Constant: + # TODO recurse or raise an error + return ast.Constant(value=value) + + def build_comprehension_ast(state: ReconstructionState) -> ast.AST: """Build the final comprehension AST from the state""" # Build comprehension generators @@ -1078,8 +1117,8 @@ def build_comprehension_ast(state: ReconstructionState) -> ast.AST: for loop in state.loops: comp = ast.comprehension( target=loop.target, - iter=loop.iter_ast, - ifs=loop.conditions, + iter=loop.iter, + ifs=loop.ifs, is_async=0 ) generators.append(comp) @@ -1090,34 +1129,10 @@ def build_comprehension_ast(state: ReconstructionState) -> ast.AST: # Determine the main expression if state.expression: - elt = state.expression - elif state.stack: - elt = ensure_ast(state.stack[-1]) - else: - elt = ast.Name(id='item', ctx=ast.Load()) - - # Build the appropriate comprehension type - if 
state.comprehension_type == 'dict' and state.key_expression: - return ast.DictComp( - key=state.key_expression, - value=elt, - generators=generators - ) - elif state.comprehension_type == 'list': - return ast.ListComp( - elt=elt, - generators=generators - ) - elif state.comprehension_type == 'set': - return ast.SetComp( - elt=elt, - generators=generators - ) - else: # generator - return ast.GeneratorExp( - elt=elt, - generators=generators - ) + state = replace(state, stack=state.stack + [state.expression]) + + assert len(state.stack) > 0 + return ast.GeneratorExp(elt=ensure_ast(state.stack[-1]), generators=generators) # ============================================================================ From 22a584200b5506bf61955b7fcef71aa2bb45288c Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 13 Jun 2025 15:26:40 -0400 Subject: [PATCH 005/106] more cleanup --- effectful/internals/genexpr.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 37f4d1e8..94349edd 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -139,11 +139,11 @@ class ReconstructionState: duplicated_for_chain: Tracks if we've seen DUP_TOP followed by ROT_THREE, which indicates we're building a chained comparison. 
""" + code_obj: types.CodeType + frame: types.FrameType stack: list[ast.AST] = field(default_factory=list) # Stack of AST nodes or values - loops: list[ast.comprehension] = field(default_factory=list) expression: Optional[ast.AST] = None # Main expression being yielded - code_obj: Optional[types.CodeType] = None - frame: Optional[types.FrameType] = None + loops: list[ast.comprehension] = field(default_factory=list) pending_conditions: List[ast.AST] = field(default_factory=list) or_conditions: List[ast.AST] = field(default_factory=list) chained_compare_state: Optional[ast.Compare] = None # For building chained comparisons @@ -1112,17 +1112,8 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Constant: def build_comprehension_ast(state: ReconstructionState) -> ast.AST: """Build the final comprehension AST from the state""" # Build comprehension generators - generators = [] - - for loop in state.loops: - comp = ast.comprehension( - target=loop.target, - iter=loop.iter, - ifs=loop.ifs, - is_async=0 - ) - generators.append(comp) - + generators: list[ast.comprehension] = state.loops[:] + # Add any pending conditions to the last loop if state.pending_conditions and generators: generators[-1].ifs.extend(state.pending_conditions) @@ -1131,7 +1122,7 @@ def build_comprehension_ast(state: ReconstructionState) -> ast.AST: if state.expression: state = replace(state, stack=state.stack + [state.expression]) - assert len(state.stack) > 0 + assert len(state.stack) == 1 return ast.GeneratorExp(elt=ensure_ast(state.stack[-1]), generators=generators) From 058890ef2eb3e95730f99299868c9ad93059d0d9 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 13 Jun 2025 15:28:50 -0400 Subject: [PATCH 006/106] more cleanup --- effectful/internals/genexpr.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 94349edd..768ed57b 100644 --- a/effectful/internals/genexpr.py +++ 
b/effectful/internals/genexpr.py @@ -1109,23 +1109,6 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Constant: return ast.Constant(value=value) -def build_comprehension_ast(state: ReconstructionState) -> ast.AST: - """Build the final comprehension AST from the state""" - # Build comprehension generators - generators: list[ast.comprehension] = state.loops[:] - - # Add any pending conditions to the last loop - if state.pending_conditions and generators: - generators[-1].ifs.extend(state.pending_conditions) - - # Determine the main expression - if state.expression: - state = replace(state, stack=state.stack + [state.expression]) - - assert len(state.stack) == 1 - return ast.GeneratorExp(elt=ensure_ast(state.stack[-1]), generators=generators) - - # ============================================================================ # MAIN RECONSTRUCTION FUNCTION # ============================================================================ @@ -1194,4 +1177,15 @@ def reconstruct(genexpr: GeneratorType) -> ast.GeneratorExp: state = OP_HANDLERS[instr.opname](state, instr) # Build and return the final AST - return build_comprehension_ast(state) + generators: list[ast.comprehension] = state.loops[:] + + # Add any pending conditions to the last loop + if state.pending_conditions and generators: + generators[-1].ifs.extend(state.pending_conditions) + + # Determine the main expression + if state.expression: + state = replace(state, stack=state.stack + [state.expression]) + + assert len(state.stack) == 1 + return ast.GeneratorExp(elt=ensure_ast(state.stack[-1]), generators=generators) \ No newline at end of file From 43eb66a296b0d088298cd7c5bba97fc827584574 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 13 Jun 2025 17:07:40 -0400 Subject: [PATCH 007/106] remove overcomplicated chain comparison handling --- effectful/internals/genexpr.py | 100 +++----------------------- tests/test_ops_syntax_generator.py | 111 +++++++++++------------------ 2 files changed, 49 insertions(+), 162 
deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 768ed57b..13c068cf 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -110,10 +110,6 @@ class ReconstructionState: Built up as FOR_ITER instructions are encountered. The order matters - outer loops come before inner loops. - comprehension_type: Type of comprehension being built. Defaults to - 'generator' but can be 'list', 'set', or 'dict'. - This affects which AST node type is ultimately created. - expression: The main expression that gets yielded/collected. For example, in '[x*2 for x in items]', this would be the AST for 'x*2'. Captured when YIELD_VALUE is encountered. @@ -131,13 +127,6 @@ class ReconstructionState: or_conditions: Conditions that are part of an OR expression. These need to be combined with ast.BoolOp(op=ast.Or()). - - chained_compare_state: When building chained comparisons (e.g., a < b < c), - this holds the Compare node being built up. The DUP_TOP - and ROT_THREE pattern indicates a chained comparison. - - duplicated_for_chain: Tracks if we've seen DUP_TOP followed by ROT_THREE, - which indicates we're building a chained comparison. 
""" code_obj: types.CodeType frame: types.FrameType @@ -146,8 +135,6 @@ class ReconstructionState: loops: list[ast.comprehension] = field(default_factory=list) pending_conditions: List[ast.AST] = field(default_factory=list) or_conditions: List[ast.AST] = field(default_factory=list) - chained_compare_state: Optional[ast.Compare] = None # For building chained comparisons - duplicated_for_chain: bool = False # Flag for chained comparison detection # Global handler registry @@ -332,12 +319,6 @@ def handle_pop_top(state: ReconstructionState, instr: dis.Instruction) -> Recons # In generators, often used after YIELD_VALUE # Also used to clean up the duplicated middle value in failed chained comparisons new_stack = state.stack[:-1] - - # If we're in a chained comparison state, this POP_TOP is cleaning up - # after a failed comparison, so clear the chained state - if state.chained_compare_state: - return replace(state, stack=new_stack, chained_compare_state=None) - return replace(state, stack=new_stack) @@ -365,7 +346,7 @@ def handle_rot_three(state: ReconstructionState, instr: dis.Instruction) -> Reco # Check if the top two items are the same (from DUP_TOP) # This indicates we're setting up for a chained comparison if len(state.stack) >= 3 and state.stack[-1] == state.stack[-2]: - return replace(state, stack=new_stack, duplicated_for_chain=True) + raise NotImplementedError("Chained comparison not implemented yet") return replace(state, stack=new_stack) @@ -783,52 +764,13 @@ def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> Rec assert instr.argval in op_map, f"Unsupported comparison operation: {instr.argval}" op_name = instr.argval - # Check if we're in a chained comparison - if state.duplicated_for_chain and len(state.stack) >= 3: - # This is the first comparison in a chain - # After DUP_TOP and ROT_THREE, stack is [middle, middle, left] - # So we need to swap left and right for the comparison - # The actual comparison should be left < middle 
(not middle < left) - left, right = right, left # Swap because ROT_THREE reversed them - middle_value = state.stack[-3] - - # Create the initial Compare node - compare_node = ast.Compare( - left=left, - ops=[op_map[op_name]], - comparators=[right] - ) - - # Keep the middle value on the stack for the next comparison - new_stack = state.stack[:-3] + [middle_value, compare_node] - return replace(state, stack=new_stack, - chained_compare_state=compare_node, - duplicated_for_chain=False) - elif state.chained_compare_state: - # This is a continuation of a chained comparison - # Stack has [middle_value, previous_compare] - middle_value = left # The duplicated middle value - new_comparator = right - existing_compare = state.chained_compare_state - - # The existing compare has the form: left op1 middle - # We need to add: middle op2 right - # But the compare node stores it as: left [op1, op2] [middle, right] - existing_compare.ops.append(op_map[op_name]) - existing_compare.comparators.append(new_comparator) - - # Replace the stack with just the extended comparison - new_stack = state.stack[:-2] + [existing_compare] - return replace(state, stack=new_stack, chained_compare_state=None) - else: - # Regular comparison - compare_node = ast.Compare( - left=left, - ops=[op_map[op_name]], - comparators=[right] - ) - new_stack = state.stack[:-2] + [compare_node] - return replace(state, stack=new_stack) + compare_node = ast.Compare( + left=left, + ops=[op_map[op_name]], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) @register_handler('CONTAINS_OP') @@ -876,19 +818,6 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) condition = ensure_ast(state.stack[-1]) new_stack = state.stack[:-1] - # Special handling for chained comparisons - if state.chained_compare_state: - # We're in the middle of building a chained comparison - # Check if this is jumping to the cleanup section (POP_TOP) - # vs 
continuing to the next comparison - if len(new_stack) > 0 and isinstance(new_stack[-1], ast.AST): - # The duplicated middle value is still on the stack - # This means the first comparison passed and we're continuing - return replace(state, stack=new_stack) - else: - # First comparison failed, clear the chained state - return replace(state, stack=new_stack, chained_compare_state=None) - # If we have pending OR conditions, this is the final condition in an OR expression if state.or_conditions: # Combine all OR conditions into a single BoolOp @@ -999,18 +928,6 @@ def handle_cache(state: ReconstructionState, instr: dis.Instruction) -> Reconstr return state -@register_handler('RESUME') -def handle_resume(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # RESUME is used for resuming generators, no effect on AST reconstruction - return state - - -@register_handler('EXTENDED_ARG') -def handle_extended_arg(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # EXTENDED_ARG extends the argument of the next instruction, no direct effect - return state - - # ============================================================================ # UTILITY FUNCTIONS # ============================================================================ @@ -1113,6 +1030,7 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Constant: # MAIN RECONSTRUCTION FUNCTION # ============================================================================ +@ensure_ast.register def reconstruct(genexpr: GeneratorType) -> ast.GeneratorExp: """ Reconstruct an AST from a generator expression's bytecode. 
diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index d60cdfbf..948dff86 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -30,6 +30,16 @@ def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, gl assert inspect.isgenerator(genexpr), "Input must be a generator" assert inspect.getgeneratorstate(genexpr) == 'GEN_CREATED', "Generator must not be consumed" + # Check AST structure + assert isinstance(reconstructed_ast, ast.GeneratorExp) + assert hasattr(reconstructed_ast, 'elt') # The expression part + assert hasattr(reconstructed_ast, 'generators') # The comprehension part + assert len(reconstructed_ast.generators) > 0 + for comp in reconstructed_ast.generators: + assert hasattr(comp, 'target') # Loop variable + assert hasattr(comp, 'iter') # Iterator + assert hasattr(comp, 'ifs') # Conditions + # Save current globals to restore later curr_globals = globals().copy() globals().update(globals_dict or {}) @@ -51,46 +61,27 @@ def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, gl f"AST produced {reconstructed_list}, expected {original_list}" -def assert_ast_structure(ast_node: ast.AST, expected_type: type, - check_target: str = None, check_iter: type = None): - """Basic structural assertions for AST nodes.""" - assert isinstance(ast_node, expected_type), \ - f"Expected {expected_type.__name__}, got {type(ast_node).__name__}" - - if hasattr(ast_node, 'generators') and ast_node.generators: - comp = ast_node.generators[0] - if check_target: - assert comp.target.id == check_target - if check_iter: - assert isinstance(comp.iter, check_iter) - - # ============================================================================ # BASIC GENERATOR EXPRESSION TESTS # ============================================================================ -@pytest.mark.parametrize("genexpr,expected_type,var_name", [ +@pytest.mark.parametrize("genexpr", [ # Simple generator 
expressions - ((x for x in range(5)), ast.GeneratorExp, 'x'), - ((y for y in range(10)), ast.GeneratorExp, 'y'), - ((item for item in [1, 2, 3]), ast.GeneratorExp, 'item'), + (x for x in range(5)), + (y for y in range(10)), + (item for item in [1, 2, 3]), # Edge cases for simple generators - ((i for i in range(0)), ast.GeneratorExp, 'i'), # Empty range - ((n for n in range(1)), ast.GeneratorExp, 'n'), # Single item range - ((val for val in range(100)), ast.GeneratorExp, 'val'), # Large range - ((x for x in range(-5, 5)), ast.GeneratorExp, 'x'), # Negative range - ((step for step in range(0, 10, 2)), ast.GeneratorExp, 'step'), # Step range - ((rev for rev in range(10, 0, -1)), ast.GeneratorExp, 'rev'), # Reverse range + (i for i in range(0)), # Empty range + (n for n in range(1)), # Single item range + (val for val in range(100)), # Large range + (x for x in range(-5, 5)), # Negative range + (step for step in range(0, 10, 2)), # Step range + (rev for rev in range(10, 0, -1)), # Reverse range ]) -def test_simple_generators(genexpr, expected_type, var_name): +def test_simple_generators(genexpr): """Test reconstruction of simple generator expressions.""" ast_node = reconstruct(genexpr) - - # Check structure - assert_ast_structure(ast_node, expected_type, check_target=var_name) - - # Check equivalence - only for range() iterators that we can reconstruct assert_ast_equivalent(genexpr, ast_node) @@ -147,7 +138,6 @@ def test_simple_generators(genexpr, expected_type, var_name): def test_arithmetic_expressions(genexpr): """Test reconstruction of generators with arithmetic expressions.""" ast_node = reconstruct(genexpr) - assert isinstance(ast_node, ast.GeneratorExp) assert_ast_equivalent(genexpr, ast_node) @@ -179,11 +169,6 @@ def test_arithmetic_expressions(genexpr): (x for x in range(10) if not (x > 5)), # More complex comparison edge cases - # Chained comparisons - (x for x in range(20) if 5 < x < 15), - (x for x in range(20) if 0 <= x <= 10), - (x for x in range(20) if x 
>= 5 and x <= 15), - # Comparisons with expressions (x for x in range(10) if x * 2 > 10), (x for x in range(10) if x + 1 <= 5), @@ -214,7 +199,23 @@ def test_arithmetic_expressions(genexpr): def test_comparison_operators(genexpr): """Test reconstruction of all comparison operators.""" ast_node = reconstruct(genexpr) - assert isinstance(ast_node, ast.GeneratorExp) + assert_ast_equivalent(genexpr, ast_node) + + +# ============================================================================ +# CHAINED COMPARISON TESTS +# ============================================================================ + +@pytest.mark.xfail(reason="Chained comparisons not yet fully supported") +@pytest.mark.parametrize("genexpr", [ + # Chained comparisons + (x for x in range(20) if 5 < x < 15), + (x for x in range(20) if 0 <= x <= 10), + (x for x in range(20) if x >= 5 and x <= 15), +]) +def test_chained_comparison_operators(genexpr): + """Test reconstruction of chained (ternary) comparison operators.""" + ast_node = reconstruct(genexpr) assert_ast_equivalent(genexpr, ast_node) @@ -271,12 +272,6 @@ def test_comparison_operators(genexpr): def test_filtered_generators(genexpr): """Test reconstruction of generators with if conditions.""" ast_node = reconstruct(genexpr) - assert isinstance(ast_node, ast.GeneratorExp) - - # Check that we have conditions - if genexpr.gi_code.co_code.count(dis.opmap['POP_JUMP_IF_FALSE']) > 0: - assert len(ast_node.generators[0].ifs) > 0, "Expected if conditions in AST" - assert_ast_equivalent(genexpr, ast_node) @@ -323,7 +318,7 @@ def test_filtered_generators(genexpr): ((x, y, z, w) for x in range(2) for y in range(2) for z in range(2) for w in range(2)), # Nested loops with complex filters - ((x, y, z) for x in range(5) for y in range(5) for z in range(5) if x < y < z), + ((x, y, z) for x in range(5) for y in range(5) for z in range(5) if x < y and y < z), (x + y + z for x in range(3) if x > 0 for y in range(3) if y != x for z in range(3) if z != x and z != y), # 
Mixed range types @@ -338,11 +333,6 @@ def test_filtered_generators(genexpr): def test_nested_loops(genexpr): """Test reconstruction of generators with nested loops.""" ast_node = reconstruct(genexpr) - assert isinstance(ast_node, ast.GeneratorExp) - - # Check multiple comprehensions - assert len(ast_node.generators) >= 2, "Expected multiple loop comprehensions" - assert_ast_equivalent(genexpr, ast_node) @@ -391,7 +381,6 @@ def test_different_comprehension_types(genexpr): def test_variable_lookup(genexpr, globals_dict): """Test reconstruction of expressions with globals.""" ast_node = reconstruct(genexpr) - assert isinstance(ast_node, ast.GeneratorExp) # Need to provide the same globals for evaluation assert_ast_equivalent(genexpr, ast_node, globals_dict) @@ -451,8 +440,7 @@ def test_variable_lookup(genexpr, globals_dict): def test_complex_scenarios(genexpr, globals_dict): """Test reconstruction of complex generator expressions.""" ast_node = reconstruct(genexpr) - assert isinstance(ast_node, ast.GeneratorExp) - + # Need to provide the same globals for evaluation assert_ast_equivalent(genexpr, ast_node, globals_dict) @@ -564,25 +552,6 @@ def test_ensure_ast(value, expected_str): f"ensure_ast({repr(value)}) produced '{result_str}', expected '{expected_str}'" -def test_ast_node_properties(): - """Test that reconstructed AST nodes have proper properties.""" - # Simple generator - genexpr = (x * 2 for x in range(5) if x > 2) - ast_node = reconstruct(genexpr) - - # Check AST structure - assert isinstance(ast_node, ast.GeneratorExp) - assert hasattr(ast_node, 'elt') # The expression part - assert hasattr(ast_node, 'generators') # The comprehension part - assert len(ast_node.generators) == 1 - - comp = ast_node.generators[0] - assert hasattr(comp, 'target') # Loop variable - assert hasattr(comp, 'iter') # Iterator - assert hasattr(comp, 'ifs') # Conditions - assert len(comp.ifs) == 1 # One condition - - def test_error_handling(): """Test that appropriate errors are raised 
for unsupported cases.""" # Test with non-generator input From 6e84320f0f1c24eb225b5dee8a2df27296d5c7e6 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 13 Jun 2025 18:01:06 -0400 Subject: [PATCH 008/106] cleanup with yld and ret fields --- effectful/internals/genexpr.py | 56 ++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 13c068cf..abae09ed 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -130,8 +130,11 @@ class ReconstructionState: """ code_obj: types.CodeType frame: types.FrameType + + yld: Optional[ast.AST] = None # Main expression being yielded (if any) + ret: Optional[ast.AST] = None # Return value (if any) + stack: list[ast.AST] = field(default_factory=list) # Stack of AST nodes or values - expression: Optional[ast.AST] = None # Main expression being yielded loops: list[ast.comprehension] = field(default_factory=list) pending_conditions: List[ast.AST] = field(default_factory=list) or_conditions: List[ast.AST] = field(default_factory=list) @@ -183,17 +186,38 @@ def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> Reco def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # YIELD_VALUE pops a value from the stack and yields it # This is the expression part of the generator - expression = ensure_ast(state.stack[-1]) + assert state.yld is None, "YIELD_VALUE should not be called more than once" + assert state.ret is None, "YIELD_VALUE should not be called after RETURN_VALUE" + + yld = ensure_ast(state.stack[-1]) new_stack = state.stack[:-1] - return replace(state, stack=new_stack, expression=expression) + + # Add any pending conditions to the last loop + if state.pending_conditions: + assert len(state.loops) > 0, "dangling condition" + last_loop = ast.comprehension( + target=state.loops[-1].target, + iter=state.loops[-1].iter, + 
ifs=state.loops[-1].ifs + state.pending_conditions, + is_async=state.loops[-1].is_async + ) + return replace(state, stack=new_stack, yld=yld, loops=state.loops[:-1] + [last_loop], pending_conditions=[]) + else: + return replace(state, stack=new_stack, yld=yld) @register_handler('RETURN_VALUE') def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # RETURN_VALUE ends the generator # Usually preceded by LOAD_CONST None - new_stack = state.stack[:-1] # Remove the None - return replace(state, stack=new_stack) + assert state.ret is None, "RETURN_VALUE should not be called more than once" + new_stack = state.stack[:-1] + ret = ensure_ast(state.stack[-1]) + + if state.yld is not None: + assert isinstance(ret, ast.Constant) and ret.value is None, "RETURN_VALUE must be None" + + return replace(state, stack=new_stack, ret=ret) # ============================================================================ @@ -207,7 +231,7 @@ def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> Recon # The iterator should be on top of stack # Create new stack without the iterator new_stack = state.stack[:-1] - iterator = state.stack[-1] + iterator: ast.AST = state.stack[-1] # Create a new loop variable - we'll get the actual name from STORE_FAST # For now, use a placeholder @@ -344,7 +368,7 @@ def handle_rot_three(state: ReconstructionState, instr: dis.Instruction) -> Reco new_stack = state.stack[:-3] + [state.stack[-2], state.stack[-1], state.stack[-3]] # Check if the top two items are the same (from DUP_TOP) - # This indicates we're setting up for a chained comparison + # This heuristic indicates we're setting up for a chained comparison if len(state.stack) >= 3 and state.stack[-1] == state.stack[-2]: raise NotImplementedError("Chained comparison not implemented yet") @@ -1090,20 +1114,12 @@ def reconstruct(genexpr: GeneratorType) -> ast.GeneratorExp: ) # Process each instruction - for instr in 
dis.get_instructions(genexpr.gi_code): + for instr in dis.get_instructions(state.code_obj): # Call the handler state = OP_HANDLERS[instr.opname](state, instr) - - # Build and return the final AST - generators: list[ast.comprehension] = state.loops[:] - # Add any pending conditions to the last loop - if state.pending_conditions and generators: - generators[-1].ifs.extend(state.pending_conditions) - # Determine the main expression - if state.expression: - state = replace(state, stack=state.stack + [state.expression]) - - assert len(state.stack) == 1 - return ast.GeneratorExp(elt=ensure_ast(state.stack[-1]), generators=generators) \ No newline at end of file + assert state.yld is not None, "Yield expression must be set" + assert state.loops, "At least one loop must be present" + assert isinstance(state.ret, ast.Constant) and state.ret.value is None, "Return value must not be set" + return ast.GeneratorExp(elt=state.yld, generators=state.loops) From 3e611b957e863ada1b77b37c4e25e8678c24fb9f Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 13 Jun 2025 18:08:48 -0400 Subject: [PATCH 009/106] annotate failing test cases --- tests/test_ops_syntax_generator.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index 948dff86..3876a087 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -184,12 +184,12 @@ def test_arithmetic_expressions(genexpr): # Complex boolean combinations (x for x in range(20) if x > 5 and x < 15 and x % 2 == 0), (x for x in range(20) if x < 5 or x > 15 or x == 10), - (x for x in range(20) if not (x > 5 and x < 15)), + (x for x in range(20) if not (x > 5 and x < 15)), # FIXME (x for x in range(20) if not (x < 5 or x > 15)), # Mixed comparison and boolean operations - (x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), - (x for x in range(20) if not (x % 2 == 0 and x % 3 == 0)), + (x for x 
in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), # FIXME + (x for x in range(20) if not (x % 2 == 0 and x % 3 == 0)), # FIXME # Edge cases with identity comparisons (x for x in [0, 1, 2, None, 4] if x is not None and x > 1), @@ -251,8 +251,8 @@ def test_chained_comparison_operators(genexpr): (x for x in range(20) if x % 2 == 0 or x % 3 == 0), # Multiple conditions with or # Nested boolean operations - (x for x in range(20) if (x > 5 and x < 15) or x == 0), - (x for x in range(20) if not (x > 10 and x < 15)), + (x for x in range(20) if (x > 5 and x < 15) or x == 0), # FIXME + (x for x in range(20) if not (x > 10 and x < 15)), # FIXME (x for x in range(50) if x > 10 and (x % 2 == 0 or x % 3 == 0)), # Multiple consecutive filters @@ -397,7 +397,7 @@ def test_variable_lookup(genexpr, globals_dict): (((lambda y: y ** 2)(x) for x in range(5)), {}), # More complex lambdas - (((lambda a, b: a + b)(x, x) for x in range(5)), {}), + # (((lambda a, b: a + b)(x, x) for x in range(5)), {}), ((f(x) for x in range(5)), {'f': lambda y: y * 3}), # Attribute access @@ -418,8 +418,8 @@ def test_variable_lookup(genexpr, globals_dict): (("hello"[i] for i in range(5)), {}), ((data[i][j] for i in range(2) for j in range(2)), {'data': [[1, 2], [3, 4]]}), - # More complex attribute chains - ((obj.value.bit_length() for obj in [type('', (), {'value': x})() for x in range(1, 5)]), {}), + # # More complex attribute chains + # ((obj.value.bit_length() for obj in [type('', (), {'value': x})() for x in range(1, 5)]), {}), # Multiple function calls ((abs(max(x, -x)) for x in range(-3, 4)), {'abs': abs, 'max': max}), @@ -431,7 +431,7 @@ def test_variable_lookup(genexpr, globals_dict): # Edge cases with complex data structures (([1, 2, 3][x % 3] for x in range(10)), {}), - (({"even": x, "odd": x + 1}["even" if x % 2 == 0 else "odd"] for x in range(5)), {}), + # (({"even": x, "odd": x + 1}["even" if x % 2 == 0 else "odd"] for x in range(5)), {}), # Function calls with multiple 
arguments ((pow(x, 2, 10) for x in range(5)), {'pow': pow}), From edee48e318a1b33e5d58d27b2aa12594b11d68f5 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 14 Jun 2025 13:58:55 -0400 Subject: [PATCH 010/106] unnecessary xfail for dict iter constant --- effectful/internals/genexpr.py | 3 +-- tests/test_ops_syntax_generator.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index abae09ed..2327e2e6 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -1026,8 +1026,7 @@ def _ensure_ast_dict(value: dict) -> ast.Dict: @ensure_ast.register(type(iter({1: 2}))) def _ensure_ast_dict_iterator(value: Iterator) -> ast.AST: - # TODO figure out how to handle dict iterators - raise TypeError("dict key iterator not yet supported") + return ensure_ast(value.__reduce__()[1][0]) @ensure_ast.register diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index 3876a087..0c2ef1a4 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -352,9 +352,9 @@ def test_nested_loops(genexpr): (x_ for x_ in {x for x in range(10) if x % 2 == 0}), # Dict comprehensions - pytest.param((x_ for x_ in {x: x**2 for x in range(5)}), marks=pytest.mark.xfail(reason="Dict comprehensions not yet supported")), - pytest.param((x_ for x_ in {x: x*2 for x in range(5) if x % 2 == 0}), marks=pytest.mark.xfail(reason="Dict comprehensions not yet supported")), - pytest.param((x_ for x_ in {str(x): x for x in range(5)}), marks=pytest.mark.xfail(reason="Dict comprehensions not yet supported")), + (x_ for x_ in {x: x**2 for x in range(5)}), + (x_ for x_ in {x: x*2 for x in range(5) if x % 2 == 0}), + (x_ for x_ in {str(x): x for x in range(5)}), ]) def test_different_comprehension_types(genexpr): """Test reconstruction of different comprehension types.""" From b012dd9ea2973471ee9839cacce8a51fa2904aa3 Mon Sep 17 00:00:00 2001 From: Eli 
Date: Sat, 14 Jun 2025 22:38:30 -0400 Subject: [PATCH 011/106] simplify state by removing loops --- effectful/internals/genexpr.py | 557 +++++++++++++++-------------- tests/test_ops_syntax_generator.py | 17 +- 2 files changed, 287 insertions(+), 287 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 2327e2e6..67df87db 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -26,64 +26,19 @@ from dataclasses import dataclass, field, replace -# Categories for organizing opcodes -CATEGORIES = { - 'Core Generator': { - 'GEN_START', 'YIELD_VALUE', 'RETURN_VALUE' - }, - - 'Loop Control': { - 'GET_ITER', 'FOR_ITER', 'JUMP_ABSOLUTE', 'JUMP_FORWARD', - 'POP_JUMP_IF_FALSE', 'POP_JUMP_IF_TRUE' - }, - - 'Variable Operations': { - 'LOAD_FAST', 'STORE_FAST', 'LOAD_GLOBAL', 'LOAD_DEREF', 'STORE_DEREF', - 'LOAD_CONST', 'LOAD_NAME', 'STORE_NAME' - }, - - 'Arithmetic/Logic': { - 'BINARY_ADD', 'BINARY_SUBTRACT', 'BINARY_MULTIPLY', 'BINARY_TRUE_DIVIDE', - 'BINARY_FLOOR_DIVIDE', 'BINARY_MODULO', 'BINARY_POWER', 'BINARY_LSHIFT', - 'BINARY_RSHIFT', 'BINARY_OR', 'BINARY_XOR', 'BINARY_AND', - 'UNARY_POSITIVE', 'UNARY_NEGATIVE', 'UNARY_NOT', 'UNARY_INVERT' - }, - - 'Comparisons': { - 'COMPARE_OP' - }, - - 'Object Access': { - 'LOAD_ATTR', 'BINARY_SUBSCR', 'BUILD_SLICE', 'STORE_ATTR', - 'STORE_SUBSCR', 'DELETE_SUBSCR' - }, - - 'Function Calls': { - 'CALL_FUNCTION', 'CALL_FUNCTION_KW', 'CALL_FUNCTION_EX', 'CALL_METHOD', - 'LOAD_METHOD', 'CALL', 'PRECALL' - }, - - 'Container Building': { - 'BUILD_TUPLE', 'BUILD_LIST', 'BUILD_SET', 'BUILD_MAP', - 'BUILD_STRING', 'FORMAT_VALUE', 'LIST_APPEND', 'SET_ADD', 'MAP_ADD', - 'BUILD_CONST_KEY_MAP' - }, - - 'Stack Management': { - 'POP_TOP', 'DUP_TOP', 'ROT_TWO', 'ROT_THREE', 'ROT_FOUR', - 'COPY', 'SWAP' - }, - - 'Unpacking': { - 'UNPACK_SEQUENCE', 'UNPACK_EX' - }, +CompExp = ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp - 'Other': { - 'NOP', 'EXTENDED_ARG', 'CACHE', 
'RESUME', 'MAKE_CELL' - } -} -OP_CATEGORIES: dict[str, str] = {op: category for category, ops in CATEGORIES.items() for op in ops} +class Placeholder(ast.Name): + """Placeholder for AST nodes that are not yet resolved.""" + def __init__(self): + super().__init__(id="", ctx=ast.Load()) + + +class IterDummyName(ast.Name): + """Dummy name for the iterator variable in generator expressions.""" + def __init__(self): + super().__init__(id=".0", ctx=ast.Load()) @dataclass(frozen=True) @@ -106,17 +61,6 @@ class ReconstructionState: like LOAD_FAST push to this stack, while operations like BINARY_ADD pop operands and push results. - loops: List of LoopInfo objects representing the comprehension's loops. - Built up as FOR_ITER instructions are encountered. The order - matters - outer loops come before inner loops. - - expression: The main expression that gets yielded/collected. For example, - in '[x*2 for x in items]', this would be the AST for 'x*2'. - Captured when YIELD_VALUE is encountered. - - code_obj: The code object being analyzed (from generator.gi_code). - Contains the bytecode and other metadata like variable names. - frame: The generator's frame object (from generator.gi_frame). Provides access to the runtime state, including local variables like the '.0' iterator variable. @@ -128,16 +72,10 @@ class ReconstructionState: or_conditions: Conditions that are part of an OR expression. These need to be combined with ast.BoolOp(op=ast.Or()). 
""" - code_obj: types.CodeType - frame: types.FrameType - - yld: Optional[ast.AST] = None # Main expression being yielded (if any) - ret: Optional[ast.AST] = None # Return value (if any) - - stack: list[ast.AST] = field(default_factory=list) # Stack of AST nodes or values - loops: list[ast.comprehension] = field(default_factory=list) - pending_conditions: List[ast.AST] = field(default_factory=list) - or_conditions: List[ast.AST] = field(default_factory=list) + ret: ast.Lambda | ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp + stack: list[ast.expr] = field(default_factory=list) # Stack of AST nodes or values + pending_conditions: List[ast.expr] = field(default_factory=list) + or_conditions: List[ast.expr] = field(default_factory=list) # Global handler registry @@ -172,13 +110,16 @@ def _wrapper(state: ReconstructionState, instr: dis.Instruction) -> Reconstructi # ============================================================================ -# CORE GENERATOR HANDLERS +# GENERATOR COMPREHENSION HANDLERS # ============================================================================ @register_handler('GEN_START') def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # GEN_START is typically the first instruction in generator expressions # It initializes the generator + assert isinstance(state.ret, ast.GeneratorExp) + assert isinstance(state.ret.elt, Placeholder), "GEN_START must be called before yielding" + assert len(state.ret.generators) == 0, "GEN_START should not have generators yet" return state @@ -186,66 +127,167 @@ def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> Reco def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # YIELD_VALUE pops a value from the stack and yields it # This is the expression part of the generator - assert state.yld is None, "YIELD_VALUE should not be called more than once" - assert state.ret is None, 
"YIELD_VALUE should not be called after RETURN_VALUE" + assert isinstance(state.ret, ast.GeneratorExp), "YIELD_VALUE must be called after GEN_START" + assert isinstance(state.ret.elt, Placeholder), "YIELD_VALUE must be called before yielding" + assert len(state.ret.generators) > 0, "YIELD_VALUE should have generators" - yld = ensure_ast(state.stack[-1]) new_stack = state.stack[:-1] + ret = ast.GeneratorExp( + elt=ensure_ast(state.stack[-1]), + generators=state.ret.generators, + ) + return replace(state, stack=new_stack, ret=ret) - # Add any pending conditions to the last loop - if state.pending_conditions: - assert len(state.loops) > 0, "dangling condition" - last_loop = ast.comprehension( - target=state.loops[-1].target, - iter=state.loops[-1].iter, - ifs=state.loops[-1].ifs + state.pending_conditions, - is_async=state.loops[-1].is_async - ) - return replace(state, stack=new_stack, yld=yld, loops=state.loops[:-1] + [last_loop], pending_conditions=[]) - else: - return replace(state, stack=new_stack, yld=yld) +# ============================================================================ +# LIST COMPREHENSION HANDLERS +# ============================================================================ -@register_handler('RETURN_VALUE') -def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # RETURN_VALUE ends the generator - # Usually preceded by LOAD_CONST None - assert state.ret is None, "RETURN_VALUE should not be called more than once" +@register_handler('BUILD_LIST') +def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + list_size: int = instr.arg + # Pop elements for the list + elements = [ensure_ast(elem) for elem in state.stack[-list_size:]] if list_size > 0 else [] + new_stack = state.stack[:-list_size] if list_size > 0 else state.stack + + # Create list AST + list_node = ast.List(elts=elements, ctx=ast.Load()) + new_stack = new_stack + [list_node] + return 
replace(state, stack=new_stack) + + +@register_handler('LIST_APPEND') +def handle_list_append(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + assert isinstance(state.ret, ast.ListComp), "LIST_APPEND must be called within a ListComp context" + new_stack = state.stack[:-1] + new_ret = ast.ListComp( + elt=ensure_ast(state.stack[-1]), + generators=state.ret.generators, + ) + return replace(state, stack=new_stack, ret=new_ret) + + +# ============================================================================ +# SET COMPREHENSION HANDLERS +# ============================================================================ + +@register_handler('BUILD_SET') +def handle_build_set(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + raise NotImplementedError("BUILD_SET not implemented yet") # TODO + + +@register_handler('SET_ADD') +def handle_set_add(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + assert isinstance(state.ret, ast.SetComp), "SET_ADD must be called after BUILD_SET" new_stack = state.stack[:-1] - ret = ensure_ast(state.stack[-1]) + new_ret = ast.SetComp( + elt=ensure_ast(state.stack[-1]), + generators=state.ret.generators, + ) + return replace(state, stack=new_stack, ret=new_ret) - if state.yld is not None: - assert isinstance(ret, ast.Constant) and ret.value is None, "RETURN_VALUE must be None" - return replace(state, stack=new_stack, ret=ret) +# ============================================================================ +# DICT COMPREHENSION HANDLERS +# ============================================================================ + +@register_handler('BUILD_MAP') +def handle_build_map(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + raise NotImplementedError("BUILD_MAP not implemented yet") # TODO + + +@register_handler('MAP_ADD') +def handle_map_add(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + assert 
isinstance(state.ret, ast.DictComp), "MAP_ADD must be called after BUILD_MAP" + new_stack = state.stack[:-2] + new_ret = ast.DictComp( + key=ensure_ast(state.stack[-2]), + value=ensure_ast(state.stack[-1]), + generators=state.ret.generators, + ) + return replace(state, stack=new_stack, ret=new_ret) # ============================================================================ # LOOP CONTROL HANDLERS # ============================================================================ +@register_handler('RETURN_VALUE') +def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # RETURN_VALUE ends the generator + # Usually preceded by LOAD_CONST None + new_stack = state.stack[:-1] + + # Add any pending conditions to the last loop + if isinstance(state.ret, CompExp) and state.pending_conditions: + assert len(state.ret.generators) > 0, "dangling condition" + last_loop = ast.comprehension( + target=state.ret.generators[-1].target, + iter=state.ret.generators[-1].iter, + ifs=state.ret.generators[-1].ifs + state.pending_conditions, + is_async=state.ret.generators[-1].is_async + ) + if isinstance(state.ret, ast.DictComp): + new_ret = ast.DictComp( + key=state.ret.key, + value=state.ret.value, + generators=state.ret.generators[:-1] + [last_loop], + ) + else: + new_ret = type(state.ret)( + elt=state.ret.elt, + generators=state.ret.generators[:-1] + [last_loop], + ) + return replace(state, stack=new_stack, ret=new_ret, pending_conditions=[]) + else: + return replace(state, stack=new_stack) + + @register_handler('FOR_ITER') def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # FOR_ITER pops an iterator from the stack and pushes the next item # If the iterator is exhausted, it jumps to the target instruction + assert len(state.stack) > 0, "FOR_ITER must have an iterator on the stack" + assert isinstance(state.ret, CompExp), "FOR_ITER must be called within a comprehension context" + # The iterator 
should be on top of stack # Create new stack without the iterator new_stack = state.stack[:-1] - iterator: ast.AST = state.stack[-1] + iterator: ast.expr = state.stack[-1] # Create a new loop variable - we'll get the actual name from STORE_FAST # For now, use a placeholder loop_info = ast.comprehension( - target=ast.Name(id='_temp', ctx=ast.Store()), + target=Placeholder(), iter=ensure_ast(iterator), ifs=[], is_async=0, ) # Create new loops list with the new loop info - new_loops = state.loops + [loop_info] - - return replace(state, stack=new_stack, loops=new_loops) + new_ret: ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp + if isinstance(state.ret, ast.DictComp): + # If it's a DictComp, we need to ensure the loop is added to the dict comprehension + new_ret = ast.DictComp( + key=state.ret.key, + value=state.ret.value, + generators=state.ret.generators + [loop_info], + ) + else: + new_ret = type(state.ret)( + elt=state.ret.elt, + generators=state.ret.generators + [loop_info], + ) + + return replace(state, stack=new_stack, ret=new_ret) + + +@register_handler('GET_ITER') +def handle_get_iter(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # GET_ITER converts the top stack item to an iterator + # For AST reconstruction, we typically don't need to change anything + # since the iterator will be used directly in the comprehension + return state @register_handler('JUMP_ABSOLUTE') @@ -270,14 +312,9 @@ def handle_jump_forward(state: ReconstructionState, instr: dis.Instruction) -> R def handle_load_fast(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: var_name: str = instr.argval - # Special handling for .0 variable (the iterator) - if var_name[0] == '.': - # This is loading the iterator passed to the generator - # We need to reconstruct what it represents - if not state.frame or var_name not in state.frame.f_locals: - raise ValueError(f"Iterator variable '{var_name}' not found in frame locals.") - - 
new_stack = state.stack + [ensure_ast(state.frame.f_locals[var_name])] + if var_name == '.0': + # Special handling for .0 variable (the iterator) + new_stack = state.stack + [IterDummyName()] else: # Regular variable load new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] @@ -287,21 +324,35 @@ def handle_load_fast(state: ReconstructionState, instr: dis.Instruction) -> Reco @register_handler('STORE_FAST') def handle_store_fast(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + assert isinstance(state.ret, CompExp), "STORE_FAST must be called within a comprehension context" var_name = instr.argval # Update the most recent loop's target variable - assert len(state.loops) > 0, "STORE_FAST must be within a loop context" + assert len(state.ret.generators) > 0, "STORE_FAST must be within a loop context" # Create a new LoopInfo with updated target updated_loop = ast.comprehension( target=ast.Name(id=var_name, ctx=ast.Store()), - iter=state.loops[-1].iter, - ifs=state.loops[-1].ifs, - is_async=state.loops[-1].is_async + iter=state.ret.generators[-1].iter, + ifs=state.ret.generators[-1].ifs, + is_async=state.ret.generators[-1].is_async ) + + # Update the last loop in the generators list + if isinstance(state.ret, ast.DictComp): + new_ret = ast.DictComp( + key=state.ret.key, + value=state.ret.value, + generators=state.ret.generators[:-1] + [updated_loop], + ) + else: + new_ret = type(state.ret)( + elt=state.ret.elt, + generators=state.ret.generators[:-1] + [updated_loop], + ) + # Create new loops list with the updated loop - new_loops = state.loops[:-1] + [updated_loop] - return replace(state, loops=new_loops) + return replace(state, ret=new_ret) @register_handler('LOAD_CONST') @@ -326,13 +377,6 @@ def handle_load_name(state: ReconstructionState, instr: dis.Instruction) -> Reco return replace(state, stack=new_stack) -@register_handler('STORE_NAME') -def handle_store_name(state: ReconstructionState, instr: dis.Instruction) -> 
ReconstructionState: - # STORE_NAME stores to a name in the global namespace - # In generator expressions, this is uncommon but we'll handle it like STORE_FAST - return state - - # ============================================================================ # STACK MANAGEMENT HANDLERS # ============================================================================ @@ -583,80 +627,14 @@ def handle_call(state: ReconstructionState, instr: dis.Instruction) -> Reconstru return replace(state, stack=new_stack) -@register_handler('PRECALL') -def handle_precall(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # PRECALL is used to prepare for a function call (Python 3.11+) - # Usually followed by CALL, so we don't need to do much here - return state - - @register_handler('MAKE_FUNCTION') def handle_make_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # MAKE_FUNCTION creates a function from code object and name on stack - # For lambda functions, we need to reconstruct the lambda expression - code_obj: ast.Constant = state.stack[-2] # Code object - lambda_code: types.CodeType = code_obj.value - - # For lambda functions, try to reconstruct the lambda expression - lambda_code = code_obj.value - - # Simple lambda reconstruction - try to extract the basic pattern - # This is a simplified approach for common lambda patterns - # Get the lambda's bytecode instructions - lambda_instructions = list(dis.get_instructions(lambda_code)) + # For lambda functions, we need to recurse into the body - # Map binary operations - op_map = { - 'BINARY_MULTIPLY': ast.Mult(), - 'BINARY_ADD': ast.Add(), - 'BINARY_SUBTRACT': ast.Sub(), - 'BINARY_TRUE_DIVIDE': ast.Div(), - 'BINARY_FLOOR_DIVIDE': ast.FloorDiv(), - 'BINARY_MODULO': ast.Mod(), - 'BINARY_POWER': ast.Pow(), - } - - # For simple lambdas like "lambda y: y * 2" - # Look for pattern: LOAD_FAST, LOAD_CONST, BINARY_OP, RETURN_VALUE - if (len(lambda_instructions) == 4 and - 
lambda_instructions[0].opname == 'LOAD_FAST' and - lambda_instructions[1].opname == 'LOAD_CONST' and - lambda_instructions[2].opname in op_map and - lambda_instructions[3].opname == 'RETURN_VALUE'): - - param_name = lambda_instructions[0].argval - const_value = lambda_instructions[1].argval - op_name = lambda_instructions[2].opname - - # Create lambda AST: lambda param: param op constant - lambda_ast = ast.Lambda( - args=ast.arguments( - posonlyargs=[], - args=[ast.arg(arg=param_name, annotation=None)], - vararg=None, - kwonlyargs=[], - kw_defaults=[], - kwarg=None, - defaults=[] - ), - body=ast.BinOp( - left=ast.Name(id=param_name, ctx=ast.Load()), - op=op_map[op_name], - right=ast.Constant(value=const_value) - ) - ) - new_stack = state.stack[:-2] + [lambda_ast] - return replace(state, stack=new_stack) - else: - raise NotImplementedError("Complex lambda reconstruction not implemented yet.") + assert instr.arg == 0, "MAKE_FUNCTION with non-zero flags not allowed." - -@register_handler('GET_ITER') -def handle_get_iter(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # GET_ITER converts the top stack item to an iterator - # For AST reconstruction, we typically don't need to change anything - # since the iterator will be used directly in the comprehension - return state + raise NotImplementedError("Complex lambda reconstruction not implemented yet.") # ============================================================================ @@ -690,7 +668,7 @@ def handle_binary_subscr(state: ReconstructionState, instr: dis.Instruction) -> # ============================================================================ -# CONTAINER BUILDING HANDLERS +# OTHER CONTAINER BUILDING HANDLERS # ============================================================================ @register_handler('BUILD_TUPLE') @@ -706,19 +684,6 @@ def handle_build_tuple(state: ReconstructionState, instr: dis.Instruction) -> Re return replace(state, stack=new_stack) 
-@register_handler('BUILD_LIST') -def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - list_size: int = instr.arg - # Pop elements for the list - elements = [ensure_ast(elem) for elem in state.stack[-list_size:]] if list_size > 0 else [] - new_stack = state.stack[:-list_size] if list_size > 0 else state.stack - - # Create list AST - list_node = ast.List(elts=elements, ctx=ast.Load()) - new_stack = new_stack + [list_node] - return replace(state, stack=new_stack) - - @register_handler('LIST_EXTEND') def handle_list_extend(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # LIST_EXTEND extends the list at TOS-1 with the iterable at TOS @@ -849,29 +814,49 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) combined_condition = ast.BoolOp(op=ast.Or(), values=all_or_conditions) # Add the combined condition to the loop and clear OR conditions - if state.loops: + if isinstance(state.ret, CompExp) and state.ret.generators: updated_loop = ast.comprehension( - target=state.loops[-1].target, - iter=state.loops[-1].iter, - ifs=state.loops[-1].ifs + [combined_condition], - is_async=state.loops[-1].is_async, + target=state.ret.generators[-1].target, + iter=state.ret.generators[-1].iter, + ifs=state.ret.generators[-1].ifs + [combined_condition], + is_async=state.ret.generators[-1].is_async, ) - new_loops = state.loops[:-1] + [updated_loop] - return replace(state, stack=new_stack, loops=new_loops, or_conditions=[]) + if isinstance(state.ret, ast.DictComp): + new_ret = ast.DictComp( + key=state.ret.key, + value=state.ret.value, + generators=state.ret.generators[:-1] + [updated_loop], + ) + else: + new_ret = type(state.ret)( + elt=state.ret.elt, + generators=state.ret.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, ret=new_ret, or_conditions=[]) else: new_pending = state.pending_conditions + [combined_condition] return replace(state, stack=new_stack, 
pending_conditions=new_pending, or_conditions=[]) else: # Regular condition - add to the most recent loop - if state.loops: + if isinstance(state.ret, CompExp) and state.ret.generators: updated_loop = ast.comprehension( - target=state.loops[-1].target, - iter=state.loops[-1].iter, - ifs=state.loops[-1].ifs + [condition], - is_async=state.loops[-1].is_async, + target=state.ret.generators[-1].target, + iter=state.ret.generators[-1].iter, + ifs=state.ret.generators[-1].ifs + [condition], + is_async=state.ret.generators[-1].is_async, ) - new_loops = state.loops[:-1] + [updated_loop] - return replace(state, stack=new_stack, loops=new_loops) + if isinstance(state.ret, ast.DictComp): + new_ret = ast.DictComp( + key=state.ret.key, + value=state.ret.value, + generators=state.ret.generators[:-1] + [updated_loop], + ) + else: + new_ret = type(state.ret)( + elt=state.ret.elt, + generators=state.ret.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, ret=new_ret) else: # If no loops yet, add to pending conditions new_pending = state.pending_conditions + [condition] @@ -900,15 +885,25 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) # So we need to negate the condition to get the filter condition negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) - if state.loops: + if isinstance(state.ret, CompExp) and state.ret.generators: updated_loop = ast.comprehension( - target=state.loops[-1].target, - iter=state.loops[-1].iter, - ifs=state.loops[-1].ifs + [negated_condition], - is_async=state.loops[-1].is_async, + target=state.ret.generators[-1].target, + iter=state.ret.generators[-1].iter, + ifs=state.ret.generators[-1].ifs + [negated_condition], + is_async=state.ret.generators[-1].is_async, ) - new_loops = state.loops[:-1] + [updated_loop] - return replace(state, stack=new_stack, loops=new_loops) + if isinstance(state.ret, ast.DictComp): + new_ret = ast.DictComp( + key=state.ret.key, + value=state.ret.value, + 
generators=state.ret.generators[:-1] + [updated_loop], + ) + else: + new_ret = type(state.ret)( + elt=state.ret.elt, + generators=state.ret.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, ret=new_ret) else: new_pending = state.pending_conditions + [negated_condition] return replace(state, stack=new_stack, pending_conditions=new_pending) @@ -936,38 +931,23 @@ def handle_unpack_sequence(state: ReconstructionState, instr: dis.Instruction) - return replace(state, stack=new_stack) -# ============================================================================ -# SIMPLE/UTILITY OPCODE HANDLERS -# ============================================================================ - -@register_handler('NOP') -def handle_nop(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # NOP does nothing - return state - - -@register_handler('CACHE') -def handle_cache(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # CACHE is used for optimization caching, no effect on AST - return state - - # ============================================================================ # UTILITY FUNCTIONS # ============================================================================ @functools.singledispatch -def ensure_ast(value) -> ast.AST: +def ensure_ast(value) -> ast.expr: """Ensure value is an AST node""" raise TypeError(f"Cannot convert {type(value)} to AST node") @ensure_ast.register -def _ensure_ast_ast(value: ast.AST) -> ast.AST: +def _ensure_ast_ast(value: ast.expr) -> ast.expr: """If already an AST node, return it as is""" return value +@ensure_ast.register(types.FunctionType) @ensure_ast.register(int) @ensure_ast.register(float) @ensure_ast.register(str) @@ -992,7 +972,7 @@ def _ensure_ast_tuple(value: tuple) -> ast.Tuple: @ensure_ast.register(type(iter((1,)))) -def _ensure_ast_tuple_iterator(value: Iterator) -> ast.AST: +def _ensure_ast_tuple_iterator(value: Iterator) -> ast.expr: return 
ensure_ast(tuple(value.__reduce__()[1][0])) @@ -1002,7 +982,7 @@ def _ensure_ast_list(value: list) -> ast.List: @ensure_ast.register(type(iter([1]))) -def _ensure_ast_list_iterator(value: Iterator) -> ast.AST: +def _ensure_ast_list_iterator(value: Iterator) -> ast.expr: return ensure_ast(list(value.__reduce__()[1][0])) @@ -1012,7 +992,7 @@ def _ensure_ast_set(value: set) -> ast.Set: @ensure_ast.register(type(iter({1}))) -def _ensure_ast_set_iterator(value: Iterator) -> ast.AST: +def _ensure_ast_set_iterator(value: Iterator) -> ast.expr: return ensure_ast(set(value.__reduce__()[1][0])) @@ -1025,7 +1005,7 @@ def _ensure_ast_dict(value: dict) -> ast.Dict: @ensure_ast.register(type(iter({1: 2}))) -def _ensure_ast_dict_iterator(value: Iterator) -> ast.AST: +def _ensure_ast_dict_iterator(value: Iterator) -> ast.expr: return ensure_ast(value.__reduce__()[1][0]) @@ -1039,14 +1019,51 @@ def _ensure_ast_range(value: range) -> ast.Call: @ensure_ast.register(type(iter(range(1)))) -def _ensure_ast_range_iterator(value: Iterator) -> ast.AST: +def _ensure_ast_range_iterator(value: Iterator) -> ast.expr: return ensure_ast(value.__reduce__()[1][0]) @ensure_ast.register -def _ensure_ast_codeobj(value: types.CodeType) -> ast.Constant: - # TODO recurse or raise an error - return ast.Constant(value=value) +def _ensure_ast_codeobj(value: types.CodeType) -> CompExp | ast.Lambda: + # Determine return type based on the first instruction + ret: CompExp | ast.Lambda + instructions = list(dis.get_instructions(value)) + if instructions[0].opname == 'GEN_START' and instructions[1].opname == 'LOAD_FAST' and instructions[1].argval == '.0': + ret = ast.GeneratorExp(elt=Placeholder(), generators=[]) + elif instructions[0].opname == 'BUILD_LIST' and instructions[1].opname == 'LOAD_FAST' and instructions[1].argval == '.0': + ret = ast.ListComp(elt=Placeholder(), generators=[]) + elif instructions[0].opname == 'BUILD_SET' and instructions[1].opname == 'LOAD_FAST' and instructions[1].argval == '.0': + 
ret = ast.SetComp(elt=Placeholder(), generators=[]) + elif instructions[0].opname == 'BUILD_MAP' and instructions[1].opname == 'LOAD_FAST' and instructions[1].argval == '.0': + ret = ast.DictComp(key=Placeholder(), value=Placeholder(), generators=[]) + elif instructions[0].opname in {'BUILD_LIST', 'BUILD_SET', 'BUILD_MAP'}: + raise NotImplementedError("Unpacking construction not implemented yet") + elif instructions[-1].opname == 'RETURN_VALUE': + # not a comprehension, assume it's a lambda + ret = ast.Lambda( + args=ast.arguments( + posonlyargs=[], + args=[], + vararg=None, + kwonlyargs=[], + kwarg=None, + defaults=[], + kw_defaults=[], + ), + body=Placeholder() + ) + else: + raise TypeError("Code type from unsupported source") + + # Symbolic execution to reconstruct the AST + state = ReconstructionState(ret=ret) + for instr in instructions: + state = OP_HANDLERS[instr.opname](state, instr) + + # Check postconditions + assert not any(isinstance(x, Placeholder) for x in ast.walk(state.ret)), "Return value must not contain placeholders" + assert isinstance(state.ret, ast.Lambda) or len(state.ret.generators) > 0, "Return value must have generators if not a lambda" + return state.ret # ============================================================================ @@ -1104,21 +1121,9 @@ def reconstruct(genexpr: GeneratorType) -> ast.GeneratorExp: match the original comprehension. 
""" assert inspect.isgenerator(genexpr), "Input must be a generator expression" - assert inspect.getgeneratorstate(genexpr) == 'GEN_CREATED', "Generator must be in created state" - - # Initialize reconstruction state - state = ReconstructionState( - code_obj=genexpr.gi_code, - frame=genexpr.gi_frame - ) - - # Process each instruction - for instr in dis.get_instructions(state.code_obj): - # Call the handler - state = OP_HANDLERS[instr.opname](state, instr) - - # Determine the main expression - assert state.yld is not None, "Yield expression must be set" - assert state.loops, "At least one loop must be present" - assert isinstance(state.ret, ast.Constant) and state.ret.value is None, "Return value must not be set" - return ast.GeneratorExp(elt=state.yld, generators=state.loops) + assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must be in created state" + genexpr_ast: ast.GeneratorExp = ensure_ast(genexpr.gi_code) + assert isinstance(genexpr_ast.generators[0].iter, IterDummyName) + assert len([x for x in ast.walk(genexpr_ast) if isinstance(x, IterDummyName)]) == 1 + genexpr_ast.generators[0].iter = ensure_ast(genexpr.gi_frame.f_locals['.0']) + return genexpr_ast diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index 0c2ef1a4..c016cd47 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -341,20 +341,15 @@ def test_nested_loops(genexpr): # ============================================================================ @pytest.mark.parametrize("genexpr", [ - # List comprehensions + # Comprehensions as iterator constants (x_ for x_ in [x for x in range(5)]), - (x_ for x_ in [x * 2 for x in range(5)]), - (x_ for x_ in [x for x in range(10) if x % 2 == 0]), - - # Set comprehensions (x_ for x_ in {x for x in range(5)}), - (x_ for x_ in {x * 2 for x in range(5)}), - (x_ for x_ in {x for x in range(10) if x % 2 == 0}), - - # Dict comprehensions (x_ for x_ in {x: x**2 for x in 
range(5)}), - (x_ for x_ in {x: x*2 for x in range(5) if x % 2 == 0}), - (x_ for x_ in {str(x): x for x in range(5)}), + + # Comprehensions as yield expressions + ([y for y in range(x + 1)] for x in range(3)), + ({y for y in range(x + 1)} for x in range(3)), + ({y: y**2 for y in range(x + 1)} for x in range(3)), ]) def test_different_comprehension_types(genexpr): """Test reconstruction of different comprehension types.""" From 6e180846094b428f1450b82cc10c9d980e4cd645 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 14 Jun 2025 22:51:30 -0400 Subject: [PATCH 012/106] simplify and remove an unused handler --- effectful/internals/genexpr.py | 52 ++++------------------------------ 1 file changed, 6 insertions(+), 46 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 67df87db..787f4d4a 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -216,31 +216,10 @@ def handle_map_add(state: ReconstructionState, instr: dis.Instruction) -> Recons def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # RETURN_VALUE ends the generator # Usually preceded by LOAD_CONST None - new_stack = state.stack[:-1] - - # Add any pending conditions to the last loop - if isinstance(state.ret, CompExp) and state.pending_conditions: - assert len(state.ret.generators) > 0, "dangling condition" - last_loop = ast.comprehension( - target=state.ret.generators[-1].target, - iter=state.ret.generators[-1].iter, - ifs=state.ret.generators[-1].ifs + state.pending_conditions, - is_async=state.ret.generators[-1].is_async - ) - if isinstance(state.ret, ast.DictComp): - new_ret = ast.DictComp( - key=state.ret.key, - value=state.ret.value, - generators=state.ret.generators[:-1] + [last_loop], - ) - else: - new_ret = type(state.ret)( - elt=state.ret.elt, - generators=state.ret.generators[:-1] + [last_loop], - ) - return replace(state, stack=new_stack, ret=new_ret, pending_conditions=[]) - else: - 
return replace(state, stack=new_stack) + if isinstance(state.ret, CompExp): + return replace(state, stack=state.stack[:-1]) + elif isinstance(state.ret, ast.Lambda): + raise NotImplementedError("Lambda reconstruction not implemented yet") @register_handler('FOR_ITER') @@ -610,31 +589,12 @@ def handle_call_method(state: ReconstructionState, instr: dis.Instruction) -> Re return replace(state, stack=new_stack) -@register_handler('CALL') -def handle_call(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # CALL is the newer unified call instruction (Python 3.11+) - # Similar to CALL_FUNCTION but with a different calling convention - arg_count = instr.arg - - # Pop arguments and function - args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] - func = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[:-arg_count - 1] - - # Create function call AST - call_node = ast.Call(func=func, args=args, keywords=[]) - new_stack = new_stack + [call_node] - return replace(state, stack=new_stack) - - @register_handler('MAKE_FUNCTION') def handle_make_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # MAKE_FUNCTION creates a function from code object and name on stack - # For lambda functions, we need to recurse into the body - - assert instr.arg == 0, "MAKE_FUNCTION with non-zero flags not allowed." + assert instr.arg == 0, "MAKE_FUNCTION with defaults or annotations not allowed." 
- raise NotImplementedError("Complex lambda reconstruction not implemented yet.") + raise NotImplementedError("Lambda reconstruction not implemented yet") # ============================================================================ From 4220b87b488c44da5151b7a1335df8ecf7c28b82 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 14 Jun 2025 22:59:44 -0400 Subject: [PATCH 013/106] nit --- effectful/internals/genexpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 787f4d4a..6d3f62fc 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -32,7 +32,7 @@ class Placeholder(ast.Name): """Placeholder for AST nodes that are not yet resolved.""" def __init__(self): - super().__init__(id="", ctx=ast.Load()) + super().__init__(id=".PLACEHOLDER", ctx=ast.Load()) class IterDummyName(ast.Name): From 8fcefd80a79047356d9d804322d626047d8abd48 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 14 Jun 2025 23:31:05 -0400 Subject: [PATCH 014/106] lambdatype --- effectful/internals/genexpr.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 6d3f62fc..a01fb936 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -21,8 +21,7 @@ import inspect import types import typing -from types import GeneratorType, FunctionType -from typing import Callable, Any, List, Dict, Iterator, Optional, Union +from collections.abc import Callable, Iterator from dataclasses import dataclass, field, replace @@ -61,10 +60,6 @@ class ReconstructionState: like LOAD_FAST push to this stack, while operations like BINARY_ADD pop operands and push results. - frame: The generator's frame object (from generator.gi_frame). - Provides access to the runtime state, including local variables - like the '.0' iterator variable. 
- pending_conditions: Filter conditions that haven't been assigned to a loop yet. Some bytecode patterns require collecting conditions before knowing which loop they belong to. @@ -74,8 +69,8 @@ class ReconstructionState: """ ret: ast.Lambda | ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp stack: list[ast.expr] = field(default_factory=list) # Stack of AST nodes or values - pending_conditions: List[ast.expr] = field(default_factory=list) - or_conditions: List[ast.expr] = field(default_factory=list) + pending_conditions: list[ast.expr] = field(default_factory=list) + or_conditions: list[ast.expr] = field(default_factory=list) # Global handler registry @@ -907,7 +902,6 @@ def _ensure_ast_ast(value: ast.expr) -> ast.expr: return value -@ensure_ast.register(types.FunctionType) @ensure_ast.register(int) @ensure_ast.register(float) @ensure_ast.register(str) @@ -1031,7 +1025,13 @@ def _ensure_ast_codeobj(value: types.CodeType) -> CompExp | ast.Lambda: # ============================================================================ @ensure_ast.register -def reconstruct(genexpr: GeneratorType) -> ast.GeneratorExp: +def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: + assert inspect.isfunction(value), "Input must be a lambda function" + raise NotImplementedError("Lambda reconstruction not implemented yet") + + +@ensure_ast.register +def reconstruct(genexpr: types.GeneratorType) -> ast.GeneratorExp: """ Reconstruct an AST from a generator expression's bytecode. 
From 8ceb4299189deeeea9e1b75c73987b7bb24dfc57 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 14 Jun 2025 23:45:44 -0400 Subject: [PATCH 015/106] list_to_tuple --- effectful/internals/genexpr.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index a01fb936..903c78dd 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -639,6 +639,18 @@ def handle_build_tuple(state: ReconstructionState, instr: dis.Instruction) -> Re return replace(state, stack=new_stack) +@register_handler('LIST_TO_TUPLE') +def handle_list_to_tuple(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # LIST_TO_TUPLE converts a list on the stack to a tuple + list_obj = ensure_ast(state.stack[-1]) + assert isinstance(list_obj, ast.List), "Expected a list for LIST_TO_TUPLE" + + # Create tuple AST from the list's elements + tuple_node = ast.Tuple(elts=list_obj.elts, ctx=ast.Load()) + new_stack = state.stack[:-1] + [tuple_node] + return replace(state, stack=new_stack) + + @register_handler('LIST_EXTEND') def handle_list_extend(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # LIST_EXTEND extends the list at TOS-1 with the iterable at TOS From 6bd713dbf3538926425496c9d7d95f5bcdd2759b Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 14 Jun 2025 23:49:10 -0400 Subject: [PATCH 016/106] opmap --- effectful/internals/genexpr.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 903c78dd..e04eaa88 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -92,6 +92,9 @@ def register_handler(opname: str, handler = None): if handler is None: return functools.partial(register_handler, opname) + if opname not in dis.opmap: + raise ValueError(f"Invalid operation name: '{opname}'") + if opname in OP_HANDLERS: raise ValueError(f"Handler for '{opname}' already 
exists.") From 425f0cc1c44ea7e30c5c27a6fbe5e4327b68dea6 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 14 Jun 2025 23:49:53 -0400 Subject: [PATCH 017/106] opmap --- effectful/internals/genexpr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index e04eaa88..1594887d 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -92,8 +92,7 @@ def register_handler(opname: str, handler = None): if handler is None: return functools.partial(register_handler, opname) - if opname not in dis.opmap: - raise ValueError(f"Invalid operation name: '{opname}'") + assert opname not in dis.opmap, f"Invalid operation name: '{opname}'" if opname in OP_HANDLERS: raise ValueError(f"Handler for '{opname}' already exists.") From bad6f3a4c160995a89c5d4c8b8b178aa06ecbc97 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 14 Jun 2025 23:50:02 -0400 Subject: [PATCH 018/106] opmap --- effectful/internals/genexpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 1594887d..a41d4bf1 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -92,7 +92,7 @@ def register_handler(opname: str, handler = None): if handler is None: return functools.partial(register_handler, opname) - assert opname not in dis.opmap, f"Invalid operation name: '{opname}'" + assert opname in dis.opmap, f"Invalid operation name: '{opname}'" if opname in OP_HANDLERS: raise ValueError(f"Handler for '{opname}' already exists.") From a333e72b273608737dd7bb43a5dbac482e2ebb13 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 15 Jun 2025 00:31:53 -0400 Subject: [PATCH 019/106] move unpack, remove incorrect compare ops --- effectful/internals/genexpr.py | 46 ++++++++++++++-------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 
a41d4bf1..18ed5699 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -280,6 +280,24 @@ def handle_jump_forward(state: ReconstructionState, instr: dis.Instruction) -> R return state +@register_handler('UNPACK_SEQUENCE') +def handle_unpack_sequence(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # UNPACK_SEQUENCE unpacks a sequence into multiple values + # arg is the number of values to unpack + unpack_count = instr.arg + sequence = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # For tuple unpacking in comprehensions, we typically see patterns like: + # ((k, v) for k, v in items) where items is unpacked into k and v + # Create placeholder variables for the unpacked values + for i in range(unpack_count): + var_name = f'_unpack_{i}' + new_stack = new_stack + [ast.Name(id=var_name, ctx=ast.Load())] + + return replace(state, stack=new_stack) + + # ============================================================================ # VARIABLE OPERATIONS HANDLERS # ============================================================================ @@ -714,12 +732,8 @@ def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> Rec '>=': ast.GtE(), '==': ast.Eq(), '!=': ast.NotEq(), - 'in': ast.In(), - 'not in': ast.NotIn(), - 'is': ast.Is(), - 'is not': ast.IsNot(), } - assert instr.argval in op_map, f"Unsupported comparison operation: {instr.argval}" + assert instr.argval in dis.cmp_op, f"Unsupported comparison operation: {instr.argval}" op_name = instr.argval compare_node = ast.Compare( @@ -878,28 +892,6 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) return replace(state, stack=new_stack, pending_conditions=new_pending) -# ============================================================================ -# UNPACKING HANDLERS -# ============================================================================ - -@register_handler('UNPACK_SEQUENCE') -def 
handle_unpack_sequence(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - # UNPACK_SEQUENCE unpacks a sequence into multiple values - # arg is the number of values to unpack - unpack_count = instr.arg - sequence = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - - # For tuple unpacking in comprehensions, we typically see patterns like: - # ((k, v) for k, v in items) where items is unpacked into k and v - # Create placeholder variables for the unpacked values - for i in range(unpack_count): - var_name = f'_unpack_{i}' - new_stack = new_stack + [ast.Name(id=var_name, ctx=ast.Load())] - - return replace(state, stack=new_stack) - - # ============================================================================ # UTILITY FUNCTIONS # ============================================================================ From abd5f3ae0c10995d7b2fa0752947f6fe7fa942d8 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 16 Jun 2025 09:05:41 -0400 Subject: [PATCH 020/106] ret -> result --- effectful/internals/genexpr.py | 173 +++++++++++++++++---------------- 1 file changed, 87 insertions(+), 86 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 18ed5699..ef68bb96 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -67,8 +67,9 @@ class ReconstructionState: or_conditions: Conditions that are part of an OR expression. These need to be combined with ast.BoolOp(op=ast.Or()). 
""" - ret: ast.Lambda | ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp - stack: list[ast.expr] = field(default_factory=list) # Stack of AST nodes or values + result: CompExp | ast.Lambda | Placeholder = field(default_factory=Placeholder) + stack: list[ast.expr] = field(default_factory=list) + pending_conditions: list[ast.expr] = field(default_factory=list) or_conditions: list[ast.expr] = field(default_factory=list) @@ -114,9 +115,9 @@ def _wrapper(state: ReconstructionState, instr: dis.Instruction) -> Reconstructi def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # GEN_START is typically the first instruction in generator expressions # It initializes the generator - assert isinstance(state.ret, ast.GeneratorExp) - assert isinstance(state.ret.elt, Placeholder), "GEN_START must be called before yielding" - assert len(state.ret.generators) == 0, "GEN_START should not have generators yet" + assert isinstance(state.result, ast.GeneratorExp) + assert isinstance(state.result.elt, Placeholder), "GEN_START must be called before yielding" + assert len(state.result.generators) == 0, "GEN_START should not have generators yet" return state @@ -124,16 +125,16 @@ def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> Reco def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # YIELD_VALUE pops a value from the stack and yields it # This is the expression part of the generator - assert isinstance(state.ret, ast.GeneratorExp), "YIELD_VALUE must be called after GEN_START" - assert isinstance(state.ret.elt, Placeholder), "YIELD_VALUE must be called before yielding" - assert len(state.ret.generators) > 0, "YIELD_VALUE should have generators" + assert isinstance(state.result, ast.GeneratorExp), "YIELD_VALUE must be called after GEN_START" + assert isinstance(state.result.elt, Placeholder), "YIELD_VALUE must be called before yielding" + assert 
len(state.result.generators) > 0, "YIELD_VALUE should have generators" new_stack = state.stack[:-1] ret = ast.GeneratorExp( elt=ensure_ast(state.stack[-1]), - generators=state.ret.generators, + generators=state.result.generators, ) - return replace(state, stack=new_stack, ret=ret) + return replace(state, stack=new_stack, result=ret) # ============================================================================ @@ -155,13 +156,13 @@ def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> Rec @register_handler('LIST_APPEND') def handle_list_append(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert isinstance(state.ret, ast.ListComp), "LIST_APPEND must be called within a ListComp context" + assert isinstance(state.result, ast.ListComp), "LIST_APPEND must be called within a ListComp context" new_stack = state.stack[:-1] new_ret = ast.ListComp( elt=ensure_ast(state.stack[-1]), - generators=state.ret.generators, + generators=state.result.generators, ) - return replace(state, stack=new_stack, ret=new_ret) + return replace(state, stack=new_stack, result=new_ret) # ============================================================================ @@ -175,13 +176,13 @@ def handle_build_set(state: ReconstructionState, instr: dis.Instruction) -> Reco @register_handler('SET_ADD') def handle_set_add(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert isinstance(state.ret, ast.SetComp), "SET_ADD must be called after BUILD_SET" + assert isinstance(state.result, ast.SetComp), "SET_ADD must be called after BUILD_SET" new_stack = state.stack[:-1] new_ret = ast.SetComp( elt=ensure_ast(state.stack[-1]), - generators=state.ret.generators, + generators=state.result.generators, ) - return replace(state, stack=new_stack, ret=new_ret) + return replace(state, stack=new_stack, result=new_ret) # ============================================================================ @@ -195,14 +196,14 @@ def 
handle_build_map(state: ReconstructionState, instr: dis.Instruction) -> Reco @register_handler('MAP_ADD') def handle_map_add(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert isinstance(state.ret, ast.DictComp), "MAP_ADD must be called after BUILD_MAP" + assert isinstance(state.result, ast.DictComp), "MAP_ADD must be called after BUILD_MAP" new_stack = state.stack[:-2] new_ret = ast.DictComp( key=ensure_ast(state.stack[-2]), value=ensure_ast(state.stack[-1]), - generators=state.ret.generators, + generators=state.result.generators, ) - return replace(state, stack=new_stack, ret=new_ret) + return replace(state, stack=new_stack, result=new_ret) # ============================================================================ @@ -213,9 +214,9 @@ def handle_map_add(state: ReconstructionState, instr: dis.Instruction) -> Recons def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # RETURN_VALUE ends the generator # Usually preceded by LOAD_CONST None - if isinstance(state.ret, CompExp): + if isinstance(state.result, CompExp): return replace(state, stack=state.stack[:-1]) - elif isinstance(state.ret, ast.Lambda): + elif isinstance(state.result, ast.Lambda): raise NotImplementedError("Lambda reconstruction not implemented yet") @@ -224,7 +225,7 @@ def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> Recon # FOR_ITER pops an iterator from the stack and pushes the next item # If the iterator is exhausted, it jumps to the target instruction assert len(state.stack) > 0, "FOR_ITER must have an iterator on the stack" - assert isinstance(state.ret, CompExp), "FOR_ITER must be called within a comprehension context" + assert isinstance(state.result, CompExp), "FOR_ITER must be called within a comprehension context" # The iterator should be on top of stack # Create new stack without the iterator @@ -242,20 +243,20 @@ def handle_for_iter(state: ReconstructionState, instr: 
dis.Instruction) -> Recon # Create new loops list with the new loop info new_ret: ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp - if isinstance(state.ret, ast.DictComp): + if isinstance(state.result, ast.DictComp): # If it's a DictComp, we need to ensure the loop is added to the dict comprehension new_ret = ast.DictComp( - key=state.ret.key, - value=state.ret.value, - generators=state.ret.generators + [loop_info], + key=state.result.key, + value=state.result.value, + generators=state.result.generators + [loop_info], ) else: - new_ret = type(state.ret)( - elt=state.ret.elt, - generators=state.ret.generators + [loop_info], + new_ret = type(state.result)( + elt=state.result.elt, + generators=state.result.generators + [loop_info], ) - return replace(state, stack=new_stack, ret=new_ret) + return replace(state, stack=new_stack, result=new_ret) @register_handler('GET_ITER') @@ -318,35 +319,35 @@ def handle_load_fast(state: ReconstructionState, instr: dis.Instruction) -> Reco @register_handler('STORE_FAST') def handle_store_fast(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert isinstance(state.ret, CompExp), "STORE_FAST must be called within a comprehension context" + assert isinstance(state.result, CompExp), "STORE_FAST must be called within a comprehension context" var_name = instr.argval # Update the most recent loop's target variable - assert len(state.ret.generators) > 0, "STORE_FAST must be within a loop context" + assert len(state.result.generators) > 0, "STORE_FAST must be within a loop context" # Create a new LoopInfo with updated target updated_loop = ast.comprehension( target=ast.Name(id=var_name, ctx=ast.Store()), - iter=state.ret.generators[-1].iter, - ifs=state.ret.generators[-1].ifs, - is_async=state.ret.generators[-1].is_async + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs, + is_async=state.result.generators[-1].is_async ) # Update the last loop in the generators list - if 
isinstance(state.ret, ast.DictComp): + if isinstance(state.result, ast.DictComp): new_ret = ast.DictComp( - key=state.ret.key, - value=state.ret.value, - generators=state.ret.generators[:-1] + [updated_loop], + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], ) else: - new_ret = type(state.ret)( - elt=state.ret.elt, - generators=state.ret.generators[:-1] + [updated_loop], + new_ret = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], ) # Create new loops list with the updated loop - return replace(state, ret=new_ret) + return replace(state, result=new_ret) @register_handler('LOAD_CONST') @@ -797,49 +798,49 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) combined_condition = ast.BoolOp(op=ast.Or(), values=all_or_conditions) # Add the combined condition to the loop and clear OR conditions - if isinstance(state.ret, CompExp) and state.ret.generators: + if isinstance(state.result, CompExp) and state.result.generators: updated_loop = ast.comprehension( - target=state.ret.generators[-1].target, - iter=state.ret.generators[-1].iter, - ifs=state.ret.generators[-1].ifs + [combined_condition], - is_async=state.ret.generators[-1].is_async, + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [combined_condition], + is_async=state.result.generators[-1].is_async, ) - if isinstance(state.ret, ast.DictComp): + if isinstance(state.result, ast.DictComp): new_ret = ast.DictComp( - key=state.ret.key, - value=state.ret.value, - generators=state.ret.generators[:-1] + [updated_loop], + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], ) else: - new_ret = type(state.ret)( - elt=state.ret.elt, - generators=state.ret.generators[:-1] + [updated_loop], + new_ret = type(state.result)( + elt=state.result.elt, + 
generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, stack=new_stack, ret=new_ret, or_conditions=[]) + return replace(state, stack=new_stack, result=new_ret, or_conditions=[]) else: new_pending = state.pending_conditions + [combined_condition] return replace(state, stack=new_stack, pending_conditions=new_pending, or_conditions=[]) else: # Regular condition - add to the most recent loop - if isinstance(state.ret, CompExp) and state.ret.generators: + if isinstance(state.result, CompExp) and state.result.generators: updated_loop = ast.comprehension( - target=state.ret.generators[-1].target, - iter=state.ret.generators[-1].iter, - ifs=state.ret.generators[-1].ifs + [condition], - is_async=state.ret.generators[-1].is_async, + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [condition], + is_async=state.result.generators[-1].is_async, ) - if isinstance(state.ret, ast.DictComp): + if isinstance(state.result, ast.DictComp): new_ret = ast.DictComp( - key=state.ret.key, - value=state.ret.value, - generators=state.ret.generators[:-1] + [updated_loop], + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], ) else: - new_ret = type(state.ret)( - elt=state.ret.elt, - generators=state.ret.generators[:-1] + [updated_loop], + new_ret = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, stack=new_stack, ret=new_ret) + return replace(state, stack=new_stack, result=new_ret) else: # If no loops yet, add to pending conditions new_pending = state.pending_conditions + [condition] @@ -868,25 +869,25 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) # So we need to negate the condition to get the filter condition negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) - if isinstance(state.ret, CompExp) and 
state.ret.generators: + if isinstance(state.result, CompExp) and state.result.generators: updated_loop = ast.comprehension( - target=state.ret.generators[-1].target, - iter=state.ret.generators[-1].iter, - ifs=state.ret.generators[-1].ifs + [negated_condition], - is_async=state.ret.generators[-1].is_async, + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [negated_condition], + is_async=state.result.generators[-1].is_async, ) - if isinstance(state.ret, ast.DictComp): + if isinstance(state.result, ast.DictComp): new_ret = ast.DictComp( - key=state.ret.key, - value=state.ret.value, - generators=state.ret.generators[:-1] + [updated_loop], + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], ) else: - new_ret = type(state.ret)( - elt=state.ret.elt, - generators=state.ret.generators[:-1] + [updated_loop], + new_ret = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, stack=new_stack, ret=new_ret) + return replace(state, stack=new_stack, result=new_ret) else: new_pending = state.pending_conditions + [negated_condition] return replace(state, stack=new_stack, pending_conditions=new_pending) @@ -1016,14 +1017,14 @@ def _ensure_ast_codeobj(value: types.CodeType) -> CompExp | ast.Lambda: raise TypeError("Code type from unsupported source") # Symbolic execution to reconstruct the AST - state = ReconstructionState(ret=ret) + state = ReconstructionState(result=ret) for instr in instructions: state = OP_HANDLERS[instr.opname](state, instr) # Check postconditions - assert not any(isinstance(x, Placeholder) for x in ast.walk(state.ret)), "Return value must not contain placeholders" - assert isinstance(state.ret, ast.Lambda) or len(state.ret.generators) > 0, "Return value must have generators if not a lambda" - return state.ret + assert not any(isinstance(x, Placeholder) 
for x in ast.walk(state.result)), "Return value must not contain placeholders" + assert isinstance(state.result, ast.Lambda) or len(state.result.generators) > 0, "Return value must have generators if not a lambda" + return state.result # ============================================================================ From bdb4b2761dd169f8f3fcf855ab905ce53c9827ac Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 16 Jun 2025 09:12:14 -0400 Subject: [PATCH 021/106] ret -> result --- effectful/internals/genexpr.py | 109 +++++---------------------------- 1 file changed, 17 insertions(+), 92 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index ef68bb96..c757d9fa 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -426,100 +426,25 @@ def handle_rot_four(state: ReconstructionState, instr: dis.Instruction) -> Recon # ARITHMETIC/LOGIC HANDLERS # ============================================================================ -@register_handler('BINARY_ADD') -def handle_binary_add(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +def handle_binop(op: ast.operator, state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: right = ensure_ast(state.stack[-1]) left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Add(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_SUBTRACT') -def handle_binary_subtract(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Sub(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_MULTIPLY') -def handle_binary_multiply(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - 
new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Mult(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_TRUE_DIVIDE') -def handle_binary_true_divide(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Div(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_FLOOR_DIVIDE') -def handle_binary_floor_divide(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.FloorDiv(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_MODULO') -def handle_binary_modulo(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Mod(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_POWER') -def handle_binary_power(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.Pow(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_LSHIFT') -def handle_binary_lshift(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.LShift(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_RSHIFT') -def handle_binary_rshift(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - 
right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.RShift(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_OR') -def handle_binary_or(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitOr(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_XOR') -def handle_binary_xor(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitXor(), right=right)] - return replace(state, stack=new_stack) - - -@register_handler('BINARY_AND') -def handle_binary_and(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=ast.BitAnd(), right=right)] - return replace(state, stack=new_stack) + new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=op, right=right)] + return replace(state, stack=new_stack) + + +handler_binop_add = register_handler('BINARY_ADD', functools.partial(handle_binop, ast.Add())) +handler_binop_subtract = register_handler('BINARY_SUBTRACT', functools.partial(handle_binop, ast.Sub())) +handler_binop_multiply = register_handler('BINARY_MULTIPLY', functools.partial(handle_binop, ast.Mult())) +handler_binop_true_divide = register_handler('BINARY_TRUE_DIVIDE', functools.partial(handle_binop, ast.Div())) +handler_binop_floor_divide = register_handler('BINARY_FLOOR_DIVIDE', functools.partial(handle_binop, ast.FloorDiv())) +handler_binop_modulo = register_handler('BINARY_MODULO', functools.partial(handle_binop, ast.Mod())) +handler_binop_power = 
register_handler('BINARY_POWER', functools.partial(handle_binop, ast.Pow())) +handler_binop_lshift = register_handler('BINARY_LSHIFT', functools.partial(handle_binop, ast.LShift())) +handler_binop_rshift = register_handler('BINARY_RSHIFT', functools.partial(handle_binop, ast.RShift())) +handler_binop_or = register_handler('BINARY_OR', functools.partial(handle_binop, ast.BitOr())) +handler_binop_xor = register_handler('BINARY_XOR', functools.partial(handle_binop, ast.BitXor())) +handler_binop_and = register_handler('BINARY_AND', functools.partial(handle_binop, ast.BitAnd())) # ============================================================================ From 554c8364803dd1bb259b7e994dab3cb919802fbe Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 16 Jun 2025 09:14:45 -0400 Subject: [PATCH 022/106] abstract unary and binary ops --- effectful/internals/genexpr.py | 30 +++++++----------------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index c757d9fa..dd5a3a13 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -423,7 +423,7 @@ def handle_rot_four(state: ReconstructionState, instr: dis.Instruction) -> Recon # ============================================================================ -# ARITHMETIC/LOGIC HANDLERS +# BINARY ARITHMETIC/LOGIC OPERATION HANDLERS # ============================================================================ def handle_binop(op: ast.operator, state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: @@ -451,32 +451,16 @@ def handle_binop(op: ast.operator, state: ReconstructionState, instr: dis.Instru # UNARY OPERATION HANDLERS # ============================================================================ -@register_handler('UNARY_NEGATIVE') -def handle_unary_negative(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +def handle_unary_op(op: ast.unaryop, state: ReconstructionState, 
instr: dis.Instruction) -> ReconstructionState: operand = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.USub(), operand=operand)] + new_stack = state.stack[:-1] + [ast.UnaryOp(op=op, operand=operand)] return replace(state, stack=new_stack) -@register_handler('UNARY_POSITIVE') -def handle_unary_positive(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - operand = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.UAdd(), operand=operand)] - return replace(state, stack=new_stack) - - -@register_handler('UNARY_INVERT') -def handle_unary_invert(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - operand = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.Invert(), operand=operand)] - return replace(state, stack=new_stack) - - -@register_handler('UNARY_NOT') -def handle_unary_not(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - operand = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + [ast.UnaryOp(op=ast.Not(), operand=operand)] - return replace(state, stack=new_stack) +handle_unary_negative = register_handler('UNARY_NEGATIVE', functools.partial(handle_unary_op, ast.USub())) +handle_unary_positive = register_handler('UNARY_POSITIVE', functools.partial(handle_unary_op, ast.UAdd())) +handle_unary_invert = register_handler('UNARY_INVERT', functools.partial(handle_unary_op, ast.Invert())) +handle_unary_not = register_handler('UNARY_NOT', functools.partial(handle_unary_op, ast.Not())) # ============================================================================ From e6db8fb8d32ea1d102dbbccc2aeb65fe506a9031 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 16 Jun 2025 09:16:39 -0400 Subject: [PATCH 023/106] reorder --- effectful/internals/genexpr.py | 128 ++++++++++++++++----------------- 1 file changed, 64 insertions(+), 64 deletions(-) diff --git a/effectful/internals/genexpr.py 
b/effectful/internals/genexpr.py index dd5a3a13..68b949fc 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -463,6 +463,70 @@ def handle_unary_op(op: ast.unaryop, state: ReconstructionState, instr: dis.Inst handle_unary_not = register_handler('UNARY_NOT', functools.partial(handle_unary_op, ast.Not())) +# ============================================================================ +# COMPARISON OPERATION HANDLERS +# ============================================================================ + +@register_handler('COMPARE_OP') +def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + + # Map comparison operation codes to AST operators + op_map = { + '<': ast.Lt(), + '<=': ast.LtE(), + '>': ast.Gt(), + '>=': ast.GtE(), + '==': ast.Eq(), + '!=': ast.NotEq(), + } + assert instr.argval in dis.cmp_op, f"Unsupported comparison operation: {instr.argval}" + + op_name = instr.argval + compare_node = ast.Compare( + left=left, + ops=[op_map[op_name]], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) + + +@register_handler('CONTAINS_OP') +def handle_contains_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + right = ensure_ast(state.stack[-1]) # Container + left = ensure_ast(state.stack[-2]) # Item to check + + # instr.arg determines if it's 'in' (0) or 'not in' (1) + op = ast.NotIn() if instr.arg else ast.In() + + compare_node = ast.Compare( + left=left, + ops=[op], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) + + +@register_handler('IS_OP') +def handle_is_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + + # instr.arg determines if it's 'is' (0) or 'is not' (1) 
+ op = ast.IsNot() if instr.arg else ast.Is() + + compare_node = ast.Compare( + left=left, + ops=[op], + comparators=[right] + ) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) + + # ============================================================================ # FUNCTION CALL HANDLERS # ============================================================================ @@ -625,70 +689,6 @@ def handle_build_const_key_map(state: ReconstructionState, instr: dis.Instructio return replace(state, stack=new_stack) -# ============================================================================ -# COMPARISON HANDLERS -# ============================================================================ - -@register_handler('COMPARE_OP') -def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - - # Map comparison operation codes to AST operators - op_map = { - '<': ast.Lt(), - '<=': ast.LtE(), - '>': ast.Gt(), - '>=': ast.GtE(), - '==': ast.Eq(), - '!=': ast.NotEq(), - } - assert instr.argval in dis.cmp_op, f"Unsupported comparison operation: {instr.argval}" - - op_name = instr.argval - compare_node = ast.Compare( - left=left, - ops=[op_map[op_name]], - comparators=[right] - ) - new_stack = state.stack[:-2] + [compare_node] - return replace(state, stack=new_stack) - - -@register_handler('CONTAINS_OP') -def handle_contains_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) # Container - left = ensure_ast(state.stack[-2]) # Item to check - - # instr.arg determines if it's 'in' (0) or 'not in' (1) - op = ast.NotIn() if instr.arg else ast.In() - - compare_node = ast.Compare( - left=left, - ops=[op], - comparators=[right] - ) - new_stack = state.stack[:-2] + [compare_node] - return replace(state, stack=new_stack) - - -@register_handler('IS_OP') -def handle_is_op(state: 
ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - - # instr.arg determines if it's 'is' (0) or 'is not' (1) - op = ast.IsNot() if instr.arg else ast.Is() - - compare_node = ast.Compare( - left=left, - ops=[op], - comparators=[right] - ) - new_stack = state.stack[:-2] + [compare_node] - return replace(state, stack=new_stack) - - # ============================================================================ # CONDITIONAL JUMP HANDLERS # ============================================================================ From d109a4d8e6abe3cf50ff35d62c23f3ed9e2b8be5 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 16 Jun 2025 09:30:49 -0400 Subject: [PATCH 024/106] cmp --- effectful/internals/genexpr.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 68b949fc..f82e034a 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -467,26 +467,28 @@ def handle_unary_op(op: ast.unaryop, state: ReconstructionState, instr: dis.Inst # COMPARISON OPERATION HANDLERS # ============================================================================ +CMP_OPMAP: dict[str, ast.cmpop] = { + '<': ast.Lt(), + '<=': ast.LtE(), + '>': ast.Gt(), + '>=': ast.GtE(), + '==': ast.Eq(), + '!=': ast.NotEq(), +} + + @register_handler('COMPARE_OP') def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + assert dis.cmp_op[instr.arg] == instr.argval, f"Unsupported comparison operation: {instr.argval}" + right = ensure_ast(state.stack[-1]) left = ensure_ast(state.stack[-2]) # Map comparison operation codes to AST operators - op_map = { - '<': ast.Lt(), - '<=': ast.LtE(), - '>': ast.Gt(), - '>=': ast.GtE(), - '==': ast.Eq(), - '!=': ast.NotEq(), - } - assert instr.argval in dis.cmp_op, f"Unsupported comparison operation: {instr.argval}" - 
op_name = instr.argval compare_node = ast.Compare( left=left, - ops=[op_map[op_name]], + ops=[CMP_OPMAP[op_name]], comparators=[right] ) new_stack = state.stack[:-2] + [compare_node] From 0862c8417c20406192a42af8968812d6f812a221 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 09:25:41 -0400 Subject: [PATCH 025/106] nits --- effectful/internals/genexpr.py | 45 +++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index f82e034a..c3a18246 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -55,23 +55,28 @@ class ReconstructionState: that represent those operations. Attributes: + result: The current comprehension expression being built. Initially + a placeholder, it gets updated as the bytecode is processed. + It can be a GeneratorExp, ListComp, SetComp, DictComp, or + a Lambda for lambda expressions. + stack: Simulates the Python VM's value stack. Contains AST nodes or values that would be on the stack during execution. Operations like LOAD_FAST push to this stack, while operations like BINARY_ADD pop operands and push results. - pending_conditions: Filter conditions that haven't been assigned to + _pending_conditions: Filter conditions that haven't been assigned to a loop yet. Some bytecode patterns require collecting conditions before knowing which loop they belong to. - or_conditions: Conditions that are part of an OR expression. These + _or_conditions: Conditions that are part of an OR expression. These need to be combined with ast.BoolOp(op=ast.Or()). 
""" result: CompExp | ast.Lambda | Placeholder = field(default_factory=Placeholder) stack: list[ast.expr] = field(default_factory=list) - pending_conditions: list[ast.expr] = field(default_factory=list) - or_conditions: list[ast.expr] = field(default_factory=list) + _pending_conditions: list[ast.expr] = field(default_factory=list) + _or_conditions: list[ast.expr] = field(default_factory=list) # Global handler registry @@ -703,9 +708,9 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) new_stack = state.stack[:-1] # If we have pending OR conditions, this is the final condition in an OR expression - if state.or_conditions: + if state._or_conditions: # Combine all OR conditions into a single BoolOp - all_or_conditions = state.or_conditions + [condition] + all_or_conditions = state._or_conditions + [condition] combined_condition = ast.BoolOp(op=ast.Or(), values=all_or_conditions) # Add the combined condition to the loop and clear OR conditions @@ -727,10 +732,10 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) elt=state.result.elt, generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, stack=new_stack, result=new_ret, or_conditions=[]) + return replace(state, stack=new_stack, result=new_ret, _or_conditions=[]) else: - new_pending = state.pending_conditions + [combined_condition] - return replace(state, stack=new_stack, pending_conditions=new_pending, or_conditions=[]) + new_pending = state._pending_conditions + [combined_condition] + return replace(state, stack=new_stack, _pending_conditions=new_pending, _or_conditions=[]) else: # Regular condition - add to the most recent loop if isinstance(state.result, CompExp) and state.result.generators: @@ -754,8 +759,8 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) return replace(state, stack=new_stack, result=new_ret) else: # If no loops yet, add to pending conditions - new_pending = 
state.pending_conditions + [condition] - return replace(state, stack=new_stack, pending_conditions=new_pending) + new_pending = state._pending_conditions + [condition] + return replace(state, stack=new_stack, _pending_conditions=new_pending) @register_handler('POP_JUMP_IF_TRUE') @@ -772,8 +777,8 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) # In NOT: POP_JUMP_IF_TRUE jumps back to skip this iteration if instr.argval > instr.offset: # Jumping forward - part of an OR expression - new_or_conditions = state.or_conditions + [condition] - return replace(state, stack=new_stack, or_conditions=new_or_conditions) + new_or_conditions = state._or_conditions + [condition] + return replace(state, stack=new_stack, _or_conditions=new_or_conditions) else: # Jumping backward to loop start - this is a negated condition # When POP_JUMP_IF_TRUE jumps back, it means "if true, skip this item" @@ -800,8 +805,8 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) ) return replace(state, stack=new_stack, result=new_ret) else: - new_pending = state.pending_conditions + [negated_condition] - return replace(state, stack=new_stack, pending_conditions=new_pending) + new_pending = state._pending_conditions + [negated_condition] + return replace(state, stack=new_stack, _pending_conditions=new_pending) # ============================================================================ @@ -844,7 +849,7 @@ def _ensure_ast_tuple(value: tuple) -> ast.Tuple: @ensure_ast.register(type(iter((1,)))) -def _ensure_ast_tuple_iterator(value: Iterator) -> ast.expr: +def _ensure_ast_tuple_iterator(value: Iterator) -> ast.Tuple: return ensure_ast(tuple(value.__reduce__()[1][0])) @@ -854,7 +859,7 @@ def _ensure_ast_list(value: list) -> ast.List: @ensure_ast.register(type(iter([1]))) -def _ensure_ast_list_iterator(value: Iterator) -> ast.expr: +def _ensure_ast_list_iterator(value: Iterator) -> ast.List: return ensure_ast(list(value.__reduce__()[1][0])) 
@@ -864,7 +869,7 @@ def _ensure_ast_set(value: set) -> ast.Set: @ensure_ast.register(type(iter({1}))) -def _ensure_ast_set_iterator(value: Iterator) -> ast.expr: +def _ensure_ast_set_iterator(value: Iterator) -> ast.Set: return ensure_ast(set(value.__reduce__()[1][0])) @@ -877,7 +882,7 @@ def _ensure_ast_dict(value: dict) -> ast.Dict: @ensure_ast.register(type(iter({1: 2}))) -def _ensure_ast_dict_iterator(value: Iterator) -> ast.expr: +def _ensure_ast_dict_iterator(value: Iterator) -> ast.Dict: return ensure_ast(value.__reduce__()[1][0]) @@ -891,7 +896,7 @@ def _ensure_ast_range(value: range) -> ast.Call: @ensure_ast.register(type(iter(range(1)))) -def _ensure_ast_range_iterator(value: Iterator) -> ast.expr: +def _ensure_ast_range_iterator(value: Iterator) -> ast.Call: return ensure_ast(value.__reduce__()[1][0]) From 81b37f734baa2b2b41e4dad7487f6aa2b5aff1d6 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 11:57:11 -0400 Subject: [PATCH 026/106] build_map --- effectful/internals/genexpr.py | 251 ++++++++++++++------------------- 1 file changed, 106 insertions(+), 145 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index c3a18246..4562ed98 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -64,20 +64,10 @@ class ReconstructionState: values that would be on the stack during execution. Operations like LOAD_FAST push to this stack, while operations like BINARY_ADD pop operands and push results. - - _pending_conditions: Filter conditions that haven't been assigned to - a loop yet. Some bytecode patterns require collecting - conditions before knowing which loop they belong to. - - _or_conditions: Conditions that are part of an OR expression. These - need to be combined with ast.BoolOp(op=ast.Or()). 
""" result: CompExp | ast.Lambda | Placeholder = field(default_factory=Placeholder) stack: list[ast.expr] = field(default_factory=list) - _pending_conditions: list[ast.expr] = field(default_factory=list) - _or_conditions: list[ast.expr] = field(default_factory=list) - # Global handler registry OpHandler = Callable[[ReconstructionState, dis.Instruction], ReconstructionState] @@ -120,10 +110,8 @@ def _wrapper(state: ReconstructionState, instr: dis.Instruction) -> Reconstructi def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # GEN_START is typically the first instruction in generator expressions # It initializes the generator - assert isinstance(state.result, ast.GeneratorExp) - assert isinstance(state.result.elt, Placeholder), "GEN_START must be called before yielding" - assert len(state.result.generators) == 0, "GEN_START should not have generators yet" - return state + assert isinstance(state.result, Placeholder), "GEN_START must be the first instruction" + return replace(state, result=ast.GeneratorExp(elt=Placeholder(), generators=[])) @register_handler('YIELD_VALUE') @@ -148,15 +136,22 @@ def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> Re @register_handler('BUILD_LIST') def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - list_size: int = instr.arg - # Pop elements for the list - elements = [ensure_ast(elem) for elem in state.stack[-list_size:]] if list_size > 0 else [] - new_stack = state.stack[:-list_size] if list_size > 0 else state.stack - - # Create list AST - list_node = ast.List(elts=elements, ctx=ast.Load()) - new_stack = new_stack + [list_node] - return replace(state, stack=new_stack) + if isinstance(state.result, Placeholder) and len(state.stack) == 0: + # This BUILD_LIST is the start of a list comprehension + # Initialize the result as a ListComp with a placeholder element + ret = ast.ListComp(elt=Placeholder(), generators=[]) + 
new_stack = state.stack + [ret] + return replace(state, stack=new_stack, result=ret) + else: + size: int = instr.arg + # Pop elements for the list + elements = [ensure_ast(elem) for elem in state.stack[-size:]] if size > 0 else [] + new_stack = state.stack[:-size] if size > 0 else state.stack + + # Create list AST + elt_node = ast.List(elts=elements, ctx=ast.Load()) + new_stack = new_stack + [elt_node] + return replace(state, stack=new_stack) @register_handler('LIST_APPEND') @@ -176,7 +171,22 @@ def handle_list_append(state: ReconstructionState, instr: dis.Instruction) -> Re @register_handler('BUILD_SET') def handle_build_set(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - raise NotImplementedError("BUILD_SET not implemented yet") # TODO + if isinstance(state.result, Placeholder) and len(state.stack) == 0: + # This BUILD_SET is the start of a list comprehension + # Initialize the result as a ListComp with a placeholder element + ret = ast.SetComp(elt=Placeholder(), generators=[]) + new_stack = state.stack + [ret] + return replace(state, stack=new_stack, result=ret) + else: + size: int = instr.arg + # Pop elements for the set + elements = [ensure_ast(elem) for elem in state.stack[-size:]] if size > 0 else [] + new_stack = state.stack[:-size] if size > 0 else state.stack + + # Create set AST + elt_node = ast.Set(elts=elements) + new_stack = new_stack + [elt_node] + return replace(state, stack=new_stack) @register_handler('SET_ADD') @@ -196,7 +206,23 @@ def handle_set_add(state: ReconstructionState, instr: dis.Instruction) -> Recons @register_handler('BUILD_MAP') def handle_build_map(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - raise NotImplementedError("BUILD_MAP not implemented yet") # TODO + if isinstance(state.result, Placeholder) and len(state.stack) == 0: + # This is the start of a comprehension + # Initialize the result with a placeholder element + ret = ast.DictComp(key=Placeholder(), 
value=Placeholder(), generators=[]) + new_stack = state.stack + [ret] + return replace(state, stack=new_stack, result=ret) + else: + size: int = instr.arg + # Pop key-value pairs for the dict + keys = [ensure_ast(state.stack[-2*i-2]) for i in range(size)] + values = [ensure_ast(state.stack[-2*i-1]) for i in range(size)] + new_stack = state.stack[:-2*size] if size > 0 else state.stack + + # Create dict AST + dict_node = ast.Dict(keys=keys, values=values) + new_stack = new_stack + [dict_node] + return replace(state, stack=new_stack) @register_handler('MAP_ADD') @@ -221,8 +247,10 @@ def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> R # Usually preceded by LOAD_CONST None if isinstance(state.result, CompExp): return replace(state, stack=state.stack[:-1]) - elif isinstance(state.result, ast.Lambda): - raise NotImplementedError("Lambda reconstruction not implemented yet") + elif isinstance(state.result, Placeholder): + return replace(state, stack=state.stack[:-1], result=ensure_ast(state.stack[-1])) + else: + raise TypeError("Unexpected RETURN_VALUE in reconstruction") @register_handler('FOR_ITER') @@ -706,61 +734,32 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) # In comprehensions, this is used for filter conditions condition = ensure_ast(state.stack[-1]) new_stack = state.stack[:-1] - - # If we have pending OR conditions, this is the final condition in an OR expression - if state._or_conditions: - # Combine all OR conditions into a single BoolOp - all_or_conditions = state._or_conditions + [condition] - combined_condition = ast.BoolOp(op=ast.Or(), values=all_or_conditions) - - # Add the combined condition to the loop and clear OR conditions - if isinstance(state.result, CompExp) and state.result.generators: - updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs + [combined_condition], - 
is_async=state.result.generators[-1].is_async, + + if instr.argval < instr.offset: + # Jumping backward to loop start - this is a condition + # When POP_JUMP_IF_FALSE jumps back, it means "if false, skip this item" + # So we need to negate the condition to get the filter condition + assert isinstance(state.result, CompExp) and state.result.generators + updated_loop = ast.comprehension( + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [condition], + is_async=state.result.generators[-1].is_async, + ) + if isinstance(state.result, ast.DictComp): + new_ret = ast.DictComp( + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], ) - if isinstance(state.result, ast.DictComp): - new_ret = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - else: - new_ret = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_ret, _or_conditions=[]) else: - new_pending = state._pending_conditions + [combined_condition] - return replace(state, stack=new_stack, _pending_conditions=new_pending, _or_conditions=[]) - else: - # Regular condition - add to the most recent loop - if isinstance(state.result, CompExp) and state.result.generators: - updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs + [condition], - is_async=state.result.generators[-1].is_async, + new_ret = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], ) - if isinstance(state.result, ast.DictComp): - new_ret = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - else: - new_ret = 
type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_ret) - else: - # If no loops yet, add to pending conditions - new_pending = state._pending_conditions + [condition] - return replace(state, stack=new_stack, _pending_conditions=new_pending) + return replace(state, stack=new_stack, result=new_ret) + else: + raise NotImplementedError("POP_JUMP_IF_FALSE jumping forward not implemented yet") @register_handler('POP_JUMP_IF_TRUE') @@ -771,42 +770,34 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) # 2. A negated condition like "not x % 2" (jump back to loop start) condition = ensure_ast(state.stack[-1]) new_stack = state.stack[:-1] - - # Check if this jumps forward (to YIELD_VALUE - OR pattern) vs back to loop (NOT pattern) - # In OR: POP_JUMP_IF_TRUE jumps forward to yield the value - # In NOT: POP_JUMP_IF_TRUE jumps back to skip this iteration - if instr.argval > instr.offset: - # Jumping forward - part of an OR expression - new_or_conditions = state._or_conditions + [condition] - return replace(state, stack=new_stack, _or_conditions=new_or_conditions) - else: + + if instr.argval < instr.offset: # Jumping backward to loop start - this is a negated condition - # When POP_JUMP_IF_TRUE jumps back, it means "if true, skip this item" + # When POP_JUMP_IF_TRUE jumps back, it means "if false, skip this item" # So we need to negate the condition to get the filter condition - negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) - - if isinstance(state.result, CompExp) and state.result.generators: - updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs + [negated_condition], - is_async=state.result.generators[-1].is_async, + assert isinstance(state.result, CompExp) and state.result.generators + # negate the condition + 
condition = ast.UnaryOp(op=ast.Not(), operand=condition) + updated_loop = ast.comprehension( + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [condition], + is_async=state.result.generators[-1].is_async, + ) + if isinstance(state.result, ast.DictComp): + new_ret = ast.DictComp( + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], ) - if isinstance(state.result, ast.DictComp): - new_ret = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - else: - new_ret = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_ret) else: - new_pending = state._pending_conditions + [negated_condition] - return replace(state, stack=new_stack, _pending_conditions=new_pending) + new_ret = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, result=new_ret) + else: + raise NotImplementedError("POP_JUMP_IF_TRUE jumping forward not implemented yet") # ============================================================================ @@ -901,45 +892,15 @@ def _ensure_ast_range_iterator(value: Iterator) -> ast.Call: @ensure_ast.register -def _ensure_ast_codeobj(value: types.CodeType) -> CompExp | ast.Lambda: - # Determine return type based on the first instruction - ret: CompExp | ast.Lambda - instructions = list(dis.get_instructions(value)) - if instructions[0].opname == 'GEN_START' and instructions[1].opname == 'LOAD_FAST' and instructions[1].argval == '.0': - ret = ast.GeneratorExp(elt=Placeholder(), generators=[]) - elif instructions[0].opname == 'BUILD_LIST' and instructions[1].opname == 'LOAD_FAST' and instructions[1].argval == '.0': - ret = ast.ListComp(elt=Placeholder(), 
generators=[]) - elif instructions[0].opname == 'BUILD_SET' and instructions[1].opname == 'LOAD_FAST' and instructions[1].argval == '.0': - ret = ast.SetComp(elt=Placeholder(), generators=[]) - elif instructions[0].opname == 'BUILD_MAP' and instructions[1].opname == 'LOAD_FAST' and instructions[1].argval == '.0': - ret = ast.DictComp(key=Placeholder(), value=Placeholder(), generators=[]) - elif instructions[0].opname in {'BUILD_LIST', 'BUILD_SET', 'BUILD_MAP'}: - raise NotImplementedError("Unpacking construction not implemented yet") - elif instructions[-1].opname == 'RETURN_VALUE': - # not a comprehension, assume it's a lambda - ret = ast.Lambda( - args=ast.arguments( - posonlyargs=[], - args=[], - vararg=None, - kwonlyargs=[], - kwarg=None, - defaults=[], - kw_defaults=[], - ), - body=Placeholder() - ) - else: - raise TypeError("Code type from unsupported source") - +def _ensure_ast_codeobj(value: types.CodeType) -> ast.expr: # Symbolic execution to reconstruct the AST - state = ReconstructionState(result=ret) - for instr in instructions: + state = ReconstructionState() + for instr in dis.get_instructions(value): state = OP_HANDLERS[instr.opname](state, instr) # Check postconditions assert not any(isinstance(x, Placeholder) for x in ast.walk(state.result)), "Return value must not contain placeholders" - assert isinstance(state.result, ast.Lambda) or len(state.result.generators) > 0, "Return value must have generators if not a lambda" + assert not isinstance(state.result, CompExp) or len(state.result.generators) > 0, "Return value must have generators if not a lambda" return state.result From 1e0bd1bac7dc73c533dd814d18a6de2c232e9fb6 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 12:38:16 -0400 Subject: [PATCH 027/106] nested comprehension with lambda --- effectful/internals/genexpr.py | 56 ++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py 
index 4562ed98..18823bb1 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -34,12 +34,30 @@ def __init__(self): super().__init__(id=".PLACEHOLDER", ctx=ast.Load()) -class IterDummyName(ast.Name): +class DummyIterName(ast.Name): """Dummy name for the iterator variable in generator expressions.""" def __init__(self): super().__init__(id=".0", ctx=ast.Load()) +class CompLambda(ast.Lambda): + """Placeholder AST node representing a lambda function used in comprehensions.""" + def __init__(self, body: CompExp): + assert sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 + assert len(body.generators) > 0 + assert isinstance(body.generators[0].iter, DummyIterName) + super().__init__( + args=ast.arguments( + posonlyargs=[ast.arg(DummyIterName().id)], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[] + ), + body=body + ) + + @dataclass(frozen=True) class ReconstructionState: """State maintained during AST reconstruction from bytecode. @@ -342,7 +360,7 @@ def handle_load_fast(state: ReconstructionState, instr: dis.Instruction) -> Reco if var_name == '.0': # Special handling for .0 variable (the iterator) - new_stack = state.stack + [IterDummyName()] + new_stack = state.stack + [DummyIterName()] else: # Regular variable load new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] @@ -569,16 +587,22 @@ def handle_is_op(state: ReconstructionState, instr: dis.Instruction) -> Reconstr @register_handler('CALL_FUNCTION') def handle_call_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # CALL_FUNCTION pops function and arguments from stack - arg_count = instr.arg + arg_count: int = instr.arg # Pop arguments and function args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] func = ensure_ast(state.stack[-arg_count - 1]) new_stack = state.stack[:-arg_count - 1] - - # Create function call AST - call_node = ast.Call(func=func, args=args, keywords=[]) 
- new_stack = new_stack + [call_node] - return replace(state, stack=new_stack) + + if isinstance(func, CompLambda): + assert len(args) == 1 + comp_body: CompExp = func.body + comp_body.generators[0].iter = args[0] + return replace(state, stack=new_stack + [comp_body]) + else: + # Create function call AST + call_node = ast.Call(func=func, args=args, keywords=[]) + new_stack = new_stack + [call_node] + return replace(state, stack=new_stack) @register_handler('LOAD_METHOD') @@ -617,8 +641,15 @@ def handle_call_method(state: ReconstructionState, instr: dis.Instruction) -> Re def handle_make_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # MAKE_FUNCTION creates a function from code object and name on stack assert instr.arg == 0, "MAKE_FUNCTION with defaults or annotations not allowed." + assert isinstance(state.stack[-1], ast.Constant) and isinstance(state.stack[-1].value, str), "Function name must be a constant string." + body: ast.expr = state.stack[-2] + name: str = state.stack[-1].value - raise NotImplementedError("Lambda reconstruction not implemented yet") + if isinstance(body, CompExp) and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1: + new_stack = state.stack[:-2] + [CompLambda(body)] + return replace(state, stack=new_stack) + else: + raise NotImplementedError("Lambda reconstruction not implemented yet") # ============================================================================ @@ -967,7 +998,8 @@ def reconstruct(genexpr: types.GeneratorType) -> ast.GeneratorExp: assert inspect.isgenerator(genexpr), "Input must be a generator expression" assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must be in created state" genexpr_ast: ast.GeneratorExp = ensure_ast(genexpr.gi_code) - assert isinstance(genexpr_ast.generators[0].iter, IterDummyName) - assert len([x for x in ast.walk(genexpr_ast) if isinstance(x, IterDummyName)]) == 1 - genexpr_ast.generators[0].iter = 
ensure_ast(genexpr.gi_frame.f_locals['.0']) + geniter_ast: ast.expr = ensure_ast(genexpr.gi_frame.f_locals['.0']) + assert isinstance(genexpr_ast.generators[0].iter, DummyIterName) + assert len([x for x in ast.walk(genexpr_ast) if isinstance(x, DummyIterName)]) == 1 + genexpr_ast.generators[0].iter = geniter_ast return genexpr_ast From 7560626d8a90ff63b68ca6ba8d4bd5365c3df928 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 12:43:04 -0400 Subject: [PATCH 028/106] inline --- effectful/internals/genexpr.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 18823bb1..f58882fd 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -16,6 +16,7 @@ """ import ast +import copy import dis import functools import inspect @@ -57,6 +58,11 @@ def __init__(self, body: CompExp): body=body ) + def inline(self, iterator: ast.expr) -> CompExp: + res: CompExp = copy.copy(self.body) + res.generators[0].iter = iterator + return res + @dataclass(frozen=True) class ReconstructionState: @@ -595,9 +601,7 @@ def handle_call_function(state: ReconstructionState, instr: dis.Instruction) -> if isinstance(func, CompLambda): assert len(args) == 1 - comp_body: CompExp = func.body - comp_body.generators[0].iter = args[0] - return replace(state, stack=new_stack + [comp_body]) + return replace(state, stack=new_stack + [func.inline(args[0])]) else: # Create function call AST call_node = ast.Call(func=func, args=args, keywords=[]) @@ -999,7 +1003,4 @@ def reconstruct(genexpr: types.GeneratorType) -> ast.GeneratorExp: assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must be in created state" genexpr_ast: ast.GeneratorExp = ensure_ast(genexpr.gi_code) geniter_ast: ast.expr = ensure_ast(genexpr.gi_frame.f_locals['.0']) - assert isinstance(genexpr_ast.generators[0].iter, DummyIterName) - assert len([x for x in ast.walk(genexpr_ast) if 
isinstance(x, DummyIterName)]) == 1 - genexpr_ast.generators[0].iter = geniter_ast - return genexpr_ast + return CompLambda(genexpr_ast).inline(geniter_ast) From bd4a634e0541bdabd5322e3d060ed997b052b95a Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 12:58:36 -0400 Subject: [PATCH 029/106] lambda --- effectful/internals/genexpr.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index f58882fd..d3ee426b 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -945,8 +945,18 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.expr: @ensure_ast.register def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: - assert inspect.isfunction(value), "Input must be a lambda function" - raise NotImplementedError("Lambda reconstruction not implemented yet") + assert inspect.isfunction(value) and value.__name__.endswith(""), "Input must be a lambda function" + + code: types.CodeType = value.__code__ + body: ast.expr = ensure_ast(code) + args = ast.arguments( + posonlyargs=[ast.arg(arg=arg) for arg in code.co_varnames[:code.co_posonlyargcount]], + args=[ast.arg(arg=arg) for arg in code.co_varnames[code.co_posonlyargcount:code.co_argcount]], + kwonlyargs=[ast.arg(arg=arg) for arg in code.co_varnames[code.co_argcount:code.co_argcount + code.co_kwonlyargcount]], + kw_defaults=[], + defaults=[], + ) + return ast.Lambda(args=args, body=body) @ensure_ast.register From 123b09c4a628ee5a489ad73bda8bfb11a1cc1007 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 13:31:54 -0400 Subject: [PATCH 030/106] nested comprehensions --- effectful/internals/genexpr.py | 64 ++++++++++++++++++++++++++++-- tests/test_ops_syntax_generator.py | 25 ++++++++++++ 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index d3ee426b..7296c21d 100644 --- a/effectful/internals/genexpr.py 
+++ b/effectful/internals/genexpr.py @@ -271,7 +271,7 @@ def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> R # Usually preceded by LOAD_CONST None if isinstance(state.result, CompExp): return replace(state, stack=state.stack[:-1]) - elif isinstance(state.result, Placeholder): + elif isinstance(state.result, Placeholder) and len(state.stack) == 1: return replace(state, stack=state.stack[:-1], result=ensure_ast(state.stack[-1])) else: raise TypeError("Unexpected RETURN_VALUE in reconstruction") @@ -429,6 +429,56 @@ def handle_load_name(state: ReconstructionState, instr: dis.Instruction) -> Reco return replace(state, stack=new_stack) +@register_handler('STORE_DEREF') +def handle_store_deref(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # STORE_DEREF stores a value into a closure variable + assert isinstance(state.result, CompExp), "STORE_DEREF must be called within a comprehension context" + var_name = instr.argval + + # Update the most recent loop's target variable + assert len(state.result.generators) > 0, "STORE_DEREF must be within a loop context" + + # Create a new LoopInfo with updated target + updated_loop = ast.comprehension( + target=ast.Name(id=var_name, ctx=ast.Store()), + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs, + is_async=state.result.generators[-1].is_async + ) + + # Update the last loop in the generators list + if isinstance(state.result, ast.DictComp): + new_ret = ast.DictComp( + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], + ) + else: + new_ret = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], + ) + + # Create new loops list with the updated loop + return replace(state, result=new_ret) + + +@register_handler('LOAD_DEREF') +def handle_load_deref(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # LOAD_DEREF 
loads a value from a closure variable + var_name = instr.argval + new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] + return replace(state, stack=new_stack) + + +@register_handler('LOAD_CLOSURE') +def handle_load_closure(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + # LOAD_CLOSURE loads a closure variable + var_name = instr.argval + new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] + return replace(state, stack=new_stack) + + # ============================================================================ # STACK MANAGEMENT HANDLERS # ============================================================================ @@ -644,14 +694,20 @@ def handle_call_method(state: ReconstructionState, instr: dis.Instruction) -> Re @register_handler('MAKE_FUNCTION') def handle_make_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: # MAKE_FUNCTION creates a function from code object and name on stack - assert instr.arg == 0, "MAKE_FUNCTION with defaults or annotations not allowed." assert isinstance(state.stack[-1], ast.Constant) and isinstance(state.stack[-1].value, str), "Function name must be a constant string." 
+ if instr.argrepr == 'closure': + # This is a closure, remove the environment tuple from the stack for AST purposes + new_stack = state.stack[:-3] + elif instr.argrepr == '': + new_stack = state.stack[:-2] + else: + raise NotImplementedError("MAKE_FUNCTION with defaults or annotations not implemented.") + body: ast.expr = state.stack[-2] name: str = state.stack[-1].value if isinstance(body, CompExp) and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1: - new_stack = state.stack[:-2] + [CompLambda(body)] - return replace(state, stack=new_stack) + return replace(state, stack=new_stack + [CompLambda(body)]) else: raise NotImplementedError("Lambda reconstruction not implemented yet") diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index c016cd47..98efeb31 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -336,6 +336,31 @@ def test_nested_loops(genexpr): assert_ast_equivalent(genexpr, ast_node) +# =========================================================================== +# NESTED COMPREHENSIONS +# =========================================================================== + +@pytest.mark.parametrize("genexpr", [ + ([x for x in range(i)] for i in range(5)), + ({x: x**2 for x in range(i)} for i in range(5)), + ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), + + # Nested comprehensions with filters inside + ([x for x in range(i)] for i in range(5) if i > 0), + ([x for x in range(i) if x < i] for i in range(5) if i > 0), + ([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5)), + ([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5) if i > 0), + + # nesting on both sides + ([y for y in range(x)] for x in (x_ + 1 for x_ in range(5))), + ([y for y in range(x)] for x in (x_ + 1 for x_ in range(5))), +]) +def test_nested_comprehensions(genexpr): + """Test reconstruction of nested comprehensions.""" + ast_node = 
reconstruct(genexpr) + assert_ast_equivalent(genexpr, ast_node) + + # ============================================================================ # DIFFERENT COMPREHENSION TYPES # ============================================================================ From 07e880c6e873c021c4ecf9e88f1947cf4227e522 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 13:35:12 -0400 Subject: [PATCH 031/106] test case --- tests/test_ops_syntax_generator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index 98efeb31..1547b37a 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -345,6 +345,9 @@ def test_nested_loops(genexpr): ({x: x**2 for x in range(i)} for i in range(5)), ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), + # function call + (sum(x for x in range(i + 1)) for i in range(3)), + # Nested comprehensions with filters inside ([x for x in range(i)] for i in range(5) if i > 0), ([x for x in range(i) if x < i] for i in range(5) if i > 0), From b08ee5f551da03d328f95460a4ffae31d96de07e Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 13:40:43 -0400 Subject: [PATCH 032/106] test case --- tests/test_ops_syntax_generator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index 1547b37a..fe2e4795 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -56,7 +56,6 @@ def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, gl # Compile and evaluate the reconstructed AST reconstructed_gen = compile_and_eval(reconstructed_ast, globals_dict) reconstructed_list = list(reconstructed_gen) - assert reconstructed_list == original_list, \ f"AST produced {reconstructed_list}, expected {original_list}" @@ -345,8 +344,12 @@ def test_nested_loops(genexpr): ({x: x**2 for x in range(i)} for i in 
range(5)), ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), - # function call + # aggregation function call (sum(x for x in range(i + 1)) for i in range(3)), + (max(x for x in range(i + 1)) for i in range(3)), + + # map + (list(map(abs, (x + 1 for x in range(i + 1)))) for i in range(3)), # Nested comprehensions with filters inside ([x for x in range(i)] for i in range(5) if i > 0), From 942aff07137af8067cf9c495ef99e8df11eab8db Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 13:46:07 -0400 Subject: [PATCH 033/106] test case --- effectful/internals/genexpr.py | 4 ++-- tests/test_ops_syntax_generator.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 7296c21d..3ecf42a2 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -850,7 +850,7 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) ) return replace(state, stack=new_stack, result=new_ret) else: - raise NotImplementedError("POP_JUMP_IF_FALSE jumping forward not implemented yet") + raise NotImplementedError("Lazy and+or behavior not implemented yet") @register_handler('POP_JUMP_IF_TRUE') @@ -888,7 +888,7 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) ) return replace(state, stack=new_stack, result=new_ret) else: - raise NotImplementedError("POP_JUMP_IF_TRUE jumping forward not implemented yet") + raise NotImplementedError("Lazy and+or behavior not implemented yet") # ============================================================================ diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index fe2e4795..24a178b2 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -350,6 +350,7 @@ def test_nested_loops(genexpr): # map (list(map(abs, (x + 1 for x in range(i + 1)))) for i in range(3)), + (list(enumerate(x + 1 for x in range(i + 
1))) for i in range(3)), # Nested comprehensions with filters inside ([x for x in range(i)] for i in range(5) if i > 0), From c5f9ad050074258dbad3b4084da5003c26f57eaa Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 13:56:05 -0400 Subject: [PATCH 034/106] test nits --- tests/test_ops_syntax_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index 24a178b2..aec7b5dd 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -25,10 +25,10 @@ def compile_and_eval(node: ast.AST, globals_dict: dict = None) -> Any: return eval(code, globals_dict) -def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict = None): +def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict | None = None): """Assert that a reconstructed AST produces the same results as the original generator.""" assert inspect.isgenerator(genexpr), "Input must be a generator" - assert inspect.getgeneratorstate(genexpr) == 'GEN_CREATED', "Generator must not be consumed" + assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must not be consumed" # Check AST structure assert isinstance(reconstructed_ast, ast.GeneratorExp) From 921e45106d89fe4abc17840923d4c051bdd936e9 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 14:02:04 -0400 Subject: [PATCH 035/106] postcondition --- effectful/internals/genexpr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 3ecf42a2..ab56c46d 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -1069,4 +1069,6 @@ def reconstruct(genexpr: types.GeneratorType) -> ast.GeneratorExp: assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must be in created state" genexpr_ast: ast.GeneratorExp = 
ensure_ast(genexpr.gi_code) geniter_ast: ast.expr = ensure_ast(genexpr.gi_frame.f_locals['.0']) - return CompLambda(genexpr_ast).inline(geniter_ast) + result = CompLambda(genexpr_ast).inline(geniter_ast) + assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must stay in created state" + return result From 2eb4ee6599bf1950db396717ffd37811822cc69b Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 14:02:25 -0400 Subject: [PATCH 036/106] postcondition --- tests/test_ops_syntax_generator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index aec7b5dd..d7502dfa 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -27,9 +27,6 @@ def compile_and_eval(node: ast.AST, globals_dict: dict = None) -> Any: def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict | None = None): """Assert that a reconstructed AST produces the same results as the original generator.""" - assert inspect.isgenerator(genexpr), "Input must be a generator" - assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must not be consumed" - # Check AST structure assert isinstance(reconstructed_ast, ast.GeneratorExp) assert hasattr(reconstructed_ast, 'elt') # The expression part From d10e204586010bd98527504b45c07410e30bd3e5 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 14:22:58 -0400 Subject: [PATCH 037/106] move behavior out of reconstruct body --- effectful/internals/genexpr.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index ab56c46d..6033d31f 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -995,10 +995,6 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.expr: return state.result -# 
============================================================================ -# MAIN RECONSTRUCTION FUNCTION -# ============================================================================ - @ensure_ast.register def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: assert inspect.isfunction(value) and value.__name__.endswith(""), "Input must be a lambda function" @@ -1016,7 +1012,21 @@ def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: @ensure_ast.register -def reconstruct(genexpr: types.GeneratorType) -> ast.GeneratorExp: +def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: + assert inspect.isgenerator(genexpr), "Input must be a generator expression" + assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must be in created state" + genexpr_ast: ast.GeneratorExp = ensure_ast(genexpr.gi_code) + geniter_ast: ast.expr = ensure_ast(genexpr.gi_frame.f_locals['.0']) + result = CompLambda(genexpr_ast).inline(geniter_ast) + assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must stay in created state" + return result + + +# ============================================================================ +# MAIN RECONSTRUCTION FUNCTION +# ============================================================================ + +def reconstruct(genexpr: types.GeneratorType[object, None, None]) -> ast.GeneratorExp: """ Reconstruct an AST from a generator expression's bytecode. @@ -1065,10 +1075,4 @@ def reconstruct(genexpr: types.GeneratorType) -> ast.GeneratorExp: cases. However, the semantic behavior of the reconstructed AST should match the original comprehension. 
""" - assert inspect.isgenerator(genexpr), "Input must be a generator expression" - assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must be in created state" - genexpr_ast: ast.GeneratorExp = ensure_ast(genexpr.gi_code) - geniter_ast: ast.expr = ensure_ast(genexpr.gi_frame.f_locals['.0']) - result = CompLambda(genexpr_ast).inline(geniter_ast) - assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must stay in created state" - return result + return ensure_ast(genexpr) From 3516788e976b3fd5e8702c2a26ddb450f3f02a18 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 14:47:05 -0400 Subject: [PATCH 038/106] wrap --- effectful/internals/genexpr.py | 11 ++++++----- tests/test_ops_syntax_generator.py | 10 +++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 6033d31f..ea59ec6c 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -22,7 +22,7 @@ import inspect import types import typing -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterator, Generator from dataclasses import dataclass, field, replace @@ -1026,7 +1026,7 @@ def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: # MAIN RECONSTRUCTION FUNCTION # ============================================================================ -def reconstruct(genexpr: types.GeneratorType[object, None, None]) -> ast.GeneratorExp: +def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: """ Reconstruct an AST from a generator expression's bytecode. @@ -1043,8 +1043,8 @@ def reconstruct(genexpr: types.GeneratorType[object, None, None]) -> ast.Generat - Various operators and function calls Args: - genexpr (GeneratorType): The generator object to analyze. 
Must be - a freshly created generator that has not been iterated yet + genexpr (Generator[object, None, None]): The generator object to analyze. + Must be a freshly created generator that has not been iterated yet (in 'GEN_CREATED' state). Returns: @@ -1075,4 +1075,5 @@ def reconstruct(genexpr: types.GeneratorType[object, None, None]) -> ast.Generat cases. However, the semantic behavior of the reconstructed AST should match the original comprehension. """ - return ensure_ast(genexpr) + assert inspect.isgenerator(genexpr), "Input must be a generator expression" + return ast.fix_missing_locations(ast.Expression(ensure_ast(genexpr))) diff --git a/tests/test_ops_syntax_generator.py b/tests/test_ops_syntax_generator.py index d7502dfa..a724f998 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_ops_syntax_generator.py @@ -28,11 +28,11 @@ def compile_and_eval(node: ast.AST, globals_dict: dict = None) -> Any: def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict | None = None): """Assert that a reconstructed AST produces the same results as the original generator.""" # Check AST structure - assert isinstance(reconstructed_ast, ast.GeneratorExp) - assert hasattr(reconstructed_ast, 'elt') # The expression part - assert hasattr(reconstructed_ast, 'generators') # The comprehension part - assert len(reconstructed_ast.generators) > 0 - for comp in reconstructed_ast.generators: + assert isinstance(reconstructed_ast, ast.Expression) + assert hasattr(reconstructed_ast.body, 'elt') # The expression part + assert hasattr(reconstructed_ast.body, 'generators') # The comprehension part + assert len(reconstructed_ast.body.generators) > 0 + for comp in reconstructed_ast.body.generators: assert hasattr(comp, 'target') # Loop variable assert hasattr(comp, 'iter') # Iterator assert hasattr(comp, 'ifs') # Conditions From f727ed7bf1d636e77289a88c0fcdd08455268f4d Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 14:47:47 -0400 
Subject: [PATCH 039/106] wrap --- effectful/internals/genexpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index ea59ec6c..e2e1968e 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -1064,7 +1064,7 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: >>> # The reconstructed AST can be compiled and evaluated >>> import ast - >>> code = compile(ast.Expression(body=ast_node), '', 'eval') + >>> code = compile(ast_node, '', 'eval') >>> result = eval(code) >>> list(result) [0, 4, 8, 12, 16] From 6f2511e77bf2dd94586fec08be60d0bfd3e53a66 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 14:48:24 -0400 Subject: [PATCH 040/106] wrap --- effectful/internals/genexpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index e2e1968e..673ad760 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -1059,7 +1059,7 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: >>> # Generator expression >>> g = (x * 2 for x in range(10) if x % 2 == 0) >>> ast_node = reconstruct(g) - >>> isinstance(ast_node, ast.GeneratorExp) + >>> isinstance(ast_node, ast.Expression) True >>> # The reconstructed AST can be compiled and evaluated From cda21b7f542783ee1ee47db3660e09694e24ade5 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 14:50:31 -0400 Subject: [PATCH 041/106] doc --- effectful/internals/genexpr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 673ad760..4494c18b 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -1048,8 +1048,7 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: (in 'GEN_CREATED' state). 
Returns: - ast.GeneratorExp: An AST node representing the reconstructed comprehension. - The specific type depends on the original comprehension: + ast.Expression: An AST node representing the reconstructed comprehension. Raises: AssertionError: If the input is not a generator or if the generator From 22154d8cddf188f4edcabec92dd4807ae9c3bf45 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 14:50:47 -0400 Subject: [PATCH 042/106] doc --- effectful/internals/genexpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/genexpr.py b/effectful/internals/genexpr.py index 4494c18b..2988d7a2 100644 --- a/effectful/internals/genexpr.py +++ b/effectful/internals/genexpr.py @@ -12,7 +12,7 @@ Example: >>> g = (x * 2 for x in range(10) if x % 2 == 0) >>> ast_node = reconstruct(g) - >>> # ast_node is now an ast.GeneratorExp representing the original expression + >>> # ast_node is now an ast.Expression representing the original expression """ import ast From b9b3c82d053a3304e0a3d5239b4628915735795f Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 15:40:21 -0400 Subject: [PATCH 043/106] rename module --- effectful/internals/{genexpr.py => disassembler.py} | 0 ...ops_syntax_generator.py => test_internals_disassembler.py} | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename effectful/internals/{genexpr.py => disassembler.py} (100%) rename tests/{test_ops_syntax_generator.py => test_internals_disassembler.py} (99%) diff --git a/effectful/internals/genexpr.py b/effectful/internals/disassembler.py similarity index 100% rename from effectful/internals/genexpr.py rename to effectful/internals/disassembler.py diff --git a/tests/test_ops_syntax_generator.py b/tests/test_internals_disassembler.py similarity index 99% rename from tests/test_ops_syntax_generator.py rename to tests/test_internals_disassembler.py index a724f998..c99d079d 100644 --- a/tests/test_ops_syntax_generator.py +++ b/tests/test_internals_disassembler.py @@ 
-5,7 +5,7 @@ from types import GeneratorType from typing import Any, Union -from effectful.internals.genexpr import reconstruct +from effectful.internals.disassembler import reconstruct def compile_and_eval(node: ast.AST, globals_dict: dict = None) -> Any: @@ -566,7 +566,7 @@ def test_complex_scenarios(genexpr, globals_dict): ]) def test_ensure_ast(value, expected_str): """Test that ensure_ast correctly converts various values to AST nodes.""" - from effectful.internals.genexpr import ensure_ast + from effectful.internals.disassembler import ensure_ast result = ensure_ast(value) From 209aae33a95194d09fcc8c3443a3ded040ee4215 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 15:47:29 -0400 Subject: [PATCH 044/106] lint --- effectful/internals/disassembler.py | 3 +-- tests/test_internals_disassembler.py | 9 ++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 2988d7a2..0bfa2850 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -22,10 +22,9 @@ import inspect import types import typing -from collections.abc import Callable, Iterator, Generator +from collections.abc import Callable, Generator, Iterator from dataclasses import dataclass, field, replace - CompExp = ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index c99d079d..c928c6b2 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -1,14 +1,13 @@ import ast -import pytest -import dis -import inspect from types import GeneratorType -from typing import Any, Union +from typing import Any + +import pytest from effectful.internals.disassembler import reconstruct -def compile_and_eval(node: ast.AST, globals_dict: dict = None) -> Any: +def compile_and_eval(node: ast.expr | ast.Expression, globals_dict: dict | None = None) -> Any: 
"""Compile an AST node and evaluate it.""" if globals_dict is None: globals_dict = {} From a02c58e1edd6259bc6b7cdd7c54a5d0ab970149a Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 15:57:01 -0400 Subject: [PATCH 045/106] copy --- effectful/internals/disassembler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 0bfa2850..a7024461 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -58,7 +58,7 @@ def __init__(self, body: CompExp): ) def inline(self, iterator: ast.expr) -> CompExp: - res: CompExp = copy.copy(self.body) + res: CompExp = copy.deepcopy(self.body) res.generators[0].iter = iterator return res @@ -119,7 +119,7 @@ def register_handler(opname: str, handler = None): @functools.wraps(handler) def _wrapper(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: assert instr.opname == opname, f"Handler for '{opname}' called with wrong instruction" - return handler(state, instr) + return handler(copy.deepcopy(state), instr) OP_HANDLERS[opname] = _wrapper return _wrapper From 1746ab136a9439e177c28fe5688b6d187e5be6a7 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 15:58:07 -0400 Subject: [PATCH 046/106] no copy --- effectful/internals/disassembler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index a7024461..a06e004b 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -119,7 +119,7 @@ def register_handler(opname: str, handler = None): @functools.wraps(handler) def _wrapper(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: assert instr.opname == opname, f"Handler for '{opname}' called with wrong instruction" - return handler(copy.deepcopy(state), instr) + return handler(state, instr) OP_HANDLERS[opname] = _wrapper return 
_wrapper From cbe5b40219be436629d94d9cb288af6f09c85e89 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 18 Jun 2025 09:09:25 -0400 Subject: [PATCH 047/106] xfail --- tests/test_internals_disassembler.py | 33 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index c928c6b2..429a3db7 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -158,10 +158,10 @@ def test_arithmetic_expressions(genexpr): (x for x in [1, None, 3, None, 5] if x is None), # Boolean operations - these are complex cases that might need special handling - (x for x in range(10) if x > 2 and x < 8), - (x for x in range(10) if x < 3 or x > 7), (x for x in range(10) if not x % 2), (x for x in range(10) if not (x > 5)), + (x for x in range(10) if x > 2 and x < 8), + pytest.param((x for x in range(10) if x < 3 or x > 7), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), # More complex comparison edge cases # Comparisons with expressions @@ -177,14 +177,14 @@ def test_arithmetic_expressions(genexpr): (x for x in range(10) if x not in []), # Empty container # Complex boolean combinations - (x for x in range(20) if x > 5 and x < 15 and x % 2 == 0), - (x for x in range(20) if x < 5 or x > 15 or x == 10), - (x for x in range(20) if not (x > 5 and x < 15)), # FIXME (x for x in range(20) if not (x < 5 or x > 15)), + (x for x in range(20) if x > 5 and x < 15 and x % 2 == 0), + pytest.param((x for x in range(20) if x < 5 or x > 15 or x == 10), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), + pytest.param((x for x in range(20) if not (x > 5 and x < 15)), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), # Mixed comparison and boolean operations - (x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), # FIXME - (x for x in range(20) if not (x % 2 == 0 
and x % 3 == 0)), # FIXME + pytest.param((x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), + pytest.param((x for x in range(20) if not (x % 2 == 0 and x % 3 == 0)), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), # Edge cases with identity comparisons (x for x in [0, 1, 2, None, 4] if x is not None and x > 1), @@ -206,7 +206,6 @@ def test_comparison_operators(genexpr): # Chained comparisons (x for x in range(20) if 5 < x < 15), (x for x in range(20) if 0 <= x <= 10), - (x for x in range(20) if x >= 5 and x <= 15), ]) def test_chained_comparison_operators(genexpr): """Test reconstruction of chained (ternary) comparison operators.""" @@ -234,21 +233,21 @@ def test_chained_comparison_operators(genexpr): (x ** 2 for x in range(10) if x > 3), # Boolean operations in filters - (x for x in range(10) if x > 2 and x < 8), - (x for x in range(10) if x < 3 or x > 7), (x for x in range(10) if not x % 2), + (x for x in range(10) if x > 2 and x < 8), + pytest.param((x for x in range(10) if x < 3 or x > 7), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), # More complex filter edge cases (x for x in range(50) if x % 7 == 0), # Different modulo (x for x in range(10) if x >= 0), # Always true condition (x for x in range(10) if x < 0), # Always false condition (x for x in range(20) if x % 2 == 0 and x % 3 == 0), # Multiple conditions with and - (x for x in range(20) if x % 2 == 0 or x % 3 == 0), # Multiple conditions with or + pytest.param((x for x in range(20) if x % 2 == 0 or x % 3 == 0), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), # Multiple conditions with or # Nested boolean operations - (x for x in range(20) if (x > 5 and x < 15) or x == 0), # FIXME - (x for x in range(20) if not (x > 10 and x < 15)), # FIXME - (x for x in range(50) if x > 10 and (x % 2 == 0 or x % 3 == 0)), + 
pytest.param((x for x in range(20) if (x > 5 and x < 15) or x == 0), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), + pytest.param((x for x in range(20) if not (x > 10 and x < 15)), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), + pytest.param((x for x in range(50) if x > 10 and (x % 2 == 0 or x % 3 == 0)), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), # Multiple consecutive filters (x for x in range(100) if x > 20 if x < 80 if x % 10 == 0), @@ -415,9 +414,9 @@ def test_variable_lookup(genexpr, globals_dict): @pytest.mark.parametrize("genexpr,globals_dict", [ # Using lambdas and functions - (((lambda y: y * 2)(x) for x in range(5)), {}), - (((lambda y: y + 1)(x) for x in range(5)), {}), - (((lambda y: y ** 2)(x) for x in range(5)), {}), + pytest.param(((lambda y: y * 2)(x) for x in range(5)), {}, marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), + pytest.param(((lambda y: y + 1)(x) for x in range(5)), {}, marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), + pytest.param(((lambda y: y ** 2)(x) for x in range(5)), {}, marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), # More complex lambdas # (((lambda a, b: a + b)(x, x) for x in range(5)), {}), From 1354734cb5db9c61bcf31c981959824d8175847a Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 18 Jun 2025 18:31:35 -0400 Subject: [PATCH 048/106] format --- effectful/internals/disassembler.py | 657 +++++++++++++------- tests/test_internals_disassembler.py | 897 ++++++++++++++------------- 2 files changed, 909 insertions(+), 645 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index a06e004b..cfa3f9c5 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -30,18 +30,21 @@ class Placeholder(ast.Name): """Placeholder for AST nodes that are not yet resolved.""" + def 
__init__(self): super().__init__(id=".PLACEHOLDER", ctx=ast.Load()) class DummyIterName(ast.Name): """Dummy name for the iterator variable in generator expressions.""" + def __init__(self): super().__init__(id=".0", ctx=ast.Load()) class CompLambda(ast.Lambda): """Placeholder AST node representing a lambda function used in comprehensions.""" + def __init__(self, body: CompExp): assert sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 assert len(body.generators) > 0 @@ -52,9 +55,9 @@ def __init__(self, body: CompExp): args=[], kwonlyargs=[], kw_defaults=[], - defaults=[] + defaults=[], ), - body=body + body=body, ) def inline(self, iterator: ast.expr) -> CompExp: @@ -66,17 +69,17 @@ def inline(self, iterator: ast.expr) -> CompExp: @dataclass(frozen=True) class ReconstructionState: """State maintained during AST reconstruction from bytecode. - + This class tracks all the information needed while processing bytecode instructions to reconstruct the original comprehension's AST. It acts as the working memory during the reconstruction process, maintaining both the evaluation stack state and the high-level comprehension structure being built. - + The reconstruction process works by simulating the Python VM's execution of the bytecode, but instead of executing operations, it builds AST nodes that represent those operations. - + Attributes: result: The current comprehension expression being built. Initially a placeholder, it gets updated as the bytecode is processed. @@ -88,6 +91,7 @@ class ReconstructionState: like LOAD_FAST push to this stack, while operations like BINARY_ADD pop operands and push results. """ + result: CompExp | ast.Lambda | Placeholder = field(default_factory=Placeholder) stack: list[ast.expr] = field(default_factory=list) @@ -99,28 +103,32 @@ class ReconstructionState: @typing.overload -def register_handler(opname: str) -> Callable[[OpHandler], OpHandler]: - ... 
+def register_handler(opname: str) -> Callable[[OpHandler], OpHandler]: ... + + +@typing.overload +def register_handler(opname: str, handler: OpHandler) -> OpHandler: ... -@typing.overload -def register_handler(opname: str, handler: OpHandler) -> OpHandler: - ... -def register_handler(opname: str, handler = None): +def register_handler(opname: str, handler=None): """Register a handler for a specific operation name""" if handler is None: return functools.partial(register_handler, opname) - + assert opname in dis.opmap, f"Invalid operation name: '{opname}'" if opname in OP_HANDLERS: raise ValueError(f"Handler for '{opname}' already exists.") @functools.wraps(handler) - def _wrapper(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert instr.opname == opname, f"Handler for '{opname}' called with wrong instruction" + def _wrapper( + state: ReconstructionState, instr: dis.Instruction + ) -> ReconstructionState: + assert ( + instr.opname == opname + ), f"Handler for '{opname}' called with wrong instruction" return handler(state, instr) - + OP_HANDLERS[opname] = _wrapper return _wrapper @@ -129,20 +137,31 @@ def _wrapper(state: ReconstructionState, instr: dis.Instruction) -> Reconstructi # GENERATOR COMPREHENSION HANDLERS # ============================================================================ -@register_handler('GEN_START') -def handle_gen_start(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("GEN_START") +def handle_gen_start( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # GEN_START is typically the first instruction in generator expressions # It initializes the generator - assert isinstance(state.result, Placeholder), "GEN_START must be the first instruction" + assert isinstance( + state.result, Placeholder + ), "GEN_START must be the first instruction" return replace(state, result=ast.GeneratorExp(elt=Placeholder(), generators=[])) 
-@register_handler('YIELD_VALUE') -def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("YIELD_VALUE") +def handle_yield_value( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # YIELD_VALUE pops a value from the stack and yields it # This is the expression part of the generator - assert isinstance(state.result, ast.GeneratorExp), "YIELD_VALUE must be called after GEN_START" - assert isinstance(state.result.elt, Placeholder), "YIELD_VALUE must be called before yielding" + assert isinstance( + state.result, ast.GeneratorExp + ), "YIELD_VALUE must be called after GEN_START" + assert isinstance( + state.result.elt, Placeholder + ), "YIELD_VALUE must be called before yielding" assert len(state.result.generators) > 0, "YIELD_VALUE should have generators" new_stack = state.stack[:-1] @@ -157,8 +176,11 @@ def handle_yield_value(state: ReconstructionState, instr: dis.Instruction) -> Re # LIST COMPREHENSION HANDLERS # ============================================================================ -@register_handler('BUILD_LIST') -def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("BUILD_LIST") +def handle_build_list( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: if isinstance(state.result, Placeholder) and len(state.stack) == 0: # This BUILD_LIST is the start of a list comprehension # Initialize the result as a ListComp with a placeholder element @@ -168,18 +190,24 @@ def handle_build_list(state: ReconstructionState, instr: dis.Instruction) -> Rec else: size: int = instr.arg # Pop elements for the list - elements = [ensure_ast(elem) for elem in state.stack[-size:]] if size > 0 else [] + elements = ( + [ensure_ast(elem) for elem in state.stack[-size:]] if size > 0 else [] + ) new_stack = state.stack[:-size] if size > 0 else state.stack - + # Create list AST elt_node = 
ast.List(elts=elements, ctx=ast.Load()) new_stack = new_stack + [elt_node] return replace(state, stack=new_stack) -@register_handler('LIST_APPEND') -def handle_list_append(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert isinstance(state.result, ast.ListComp), "LIST_APPEND must be called within a ListComp context" +@register_handler("LIST_APPEND") +def handle_list_append( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert isinstance( + state.result, ast.ListComp + ), "LIST_APPEND must be called within a ListComp context" new_stack = state.stack[:-1] new_ret = ast.ListComp( elt=ensure_ast(state.stack[-1]), @@ -192,8 +220,11 @@ def handle_list_append(state: ReconstructionState, instr: dis.Instruction) -> Re # SET COMPREHENSION HANDLERS # ============================================================================ -@register_handler('BUILD_SET') -def handle_build_set(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("BUILD_SET") +def handle_build_set( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: if isinstance(state.result, Placeholder) and len(state.stack) == 0: # This BUILD_SET is the start of a list comprehension # Initialize the result as a ListComp with a placeholder element @@ -203,18 +234,24 @@ def handle_build_set(state: ReconstructionState, instr: dis.Instruction) -> Reco else: size: int = instr.arg # Pop elements for the set - elements = [ensure_ast(elem) for elem in state.stack[-size:]] if size > 0 else [] + elements = ( + [ensure_ast(elem) for elem in state.stack[-size:]] if size > 0 else [] + ) new_stack = state.stack[:-size] if size > 0 else state.stack - + # Create set AST elt_node = ast.Set(elts=elements) new_stack = new_stack + [elt_node] return replace(state, stack=new_stack) -@register_handler('SET_ADD') -def handle_set_add(state: ReconstructionState, instr: dis.Instruction) -> 
ReconstructionState: - assert isinstance(state.result, ast.SetComp), "SET_ADD must be called after BUILD_SET" +@register_handler("SET_ADD") +def handle_set_add( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert isinstance( + state.result, ast.SetComp + ), "SET_ADD must be called after BUILD_SET" new_stack = state.stack[:-1] new_ret = ast.SetComp( elt=ensure_ast(state.stack[-1]), @@ -227,8 +264,11 @@ def handle_set_add(state: ReconstructionState, instr: dis.Instruction) -> Recons # DICT COMPREHENSION HANDLERS # ============================================================================ -@register_handler('BUILD_MAP') -def handle_build_map(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("BUILD_MAP") +def handle_build_map( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: if isinstance(state.result, Placeholder) and len(state.stack) == 0: # This is the start of a comprehension # Initialize the result with a placeholder element @@ -238,9 +278,9 @@ def handle_build_map(state: ReconstructionState, instr: dis.Instruction) -> Reco else: size: int = instr.arg # Pop key-value pairs for the dict - keys = [ensure_ast(state.stack[-2*i-2]) for i in range(size)] - values = [ensure_ast(state.stack[-2*i-1]) for i in range(size)] - new_stack = state.stack[:-2*size] if size > 0 else state.stack + keys = [ensure_ast(state.stack[-2 * i - 2]) for i in range(size)] + values = [ensure_ast(state.stack[-2 * i - 1]) for i in range(size)] + new_stack = state.stack[: -2 * size] if size > 0 else state.stack # Create dict AST dict_node = ast.Dict(keys=keys, values=values) @@ -248,9 +288,13 @@ def handle_build_map(state: ReconstructionState, instr: dis.Instruction) -> Reco return replace(state, stack=new_stack) -@register_handler('MAP_ADD') -def handle_map_add(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert isinstance(state.result, 
ast.DictComp), "MAP_ADD must be called after BUILD_MAP" +@register_handler("MAP_ADD") +def handle_map_add( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert isinstance( + state.result, ast.DictComp + ), "MAP_ADD must be called after BUILD_MAP" new_stack = state.stack[:-2] new_ret = ast.DictComp( key=ensure_ast(state.stack[-2]), @@ -264,30 +308,39 @@ def handle_map_add(state: ReconstructionState, instr: dis.Instruction) -> Recons # LOOP CONTROL HANDLERS # ============================================================================ -@register_handler('RETURN_VALUE') -def handle_return_value(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("RETURN_VALUE") +def handle_return_value( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # RETURN_VALUE ends the generator # Usually preceded by LOAD_CONST None if isinstance(state.result, CompExp): return replace(state, stack=state.stack[:-1]) elif isinstance(state.result, Placeholder) and len(state.stack) == 1: - return replace(state, stack=state.stack[:-1], result=ensure_ast(state.stack[-1])) + return replace( + state, stack=state.stack[:-1], result=ensure_ast(state.stack[-1]) + ) else: raise TypeError("Unexpected RETURN_VALUE in reconstruction") -@register_handler('FOR_ITER') -def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("FOR_ITER") +def handle_for_iter( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # FOR_ITER pops an iterator from the stack and pushes the next item # If the iterator is exhausted, it jumps to the target instruction assert len(state.stack) > 0, "FOR_ITER must have an iterator on the stack" - assert isinstance(state.result, CompExp), "FOR_ITER must be called within a comprehension context" + assert isinstance( + state.result, CompExp + ), "FOR_ITER must be called within a comprehension context" 
# The iterator should be on top of stack # Create new stack without the iterator new_stack = state.stack[:-1] iterator: ast.expr = state.stack[-1] - + # Create a new loop variable - we'll get the actual name from STORE_FAST # For now, use a placeholder loop_info = ast.comprehension( @@ -296,7 +349,7 @@ def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> Recon ifs=[], is_async=0, ) - + # Create new loops list with the new loop info new_ret: ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp if isinstance(state.result, ast.DictComp): @@ -315,43 +368,51 @@ def handle_for_iter(state: ReconstructionState, instr: dis.Instruction) -> Recon return replace(state, stack=new_stack, result=new_ret) -@register_handler('GET_ITER') -def handle_get_iter(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("GET_ITER") +def handle_get_iter( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # GET_ITER converts the top stack item to an iterator # For AST reconstruction, we typically don't need to change anything # since the iterator will be used directly in the comprehension return state -@register_handler('JUMP_ABSOLUTE') -def handle_jump_absolute(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("JUMP_ABSOLUTE") +def handle_jump_absolute( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # JUMP_ABSOLUTE is used to jump back to the beginning of a loop # In generator expressions, this typically indicates the end of the loop body return state -@register_handler('JUMP_FORWARD') -def handle_jump_forward(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("JUMP_FORWARD") +def handle_jump_forward( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # JUMP_FORWARD is used to jump forward in the code # In generator expressions, this is often 
used to skip code in conditional logic return state -@register_handler('UNPACK_SEQUENCE') -def handle_unpack_sequence(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("UNPACK_SEQUENCE") +def handle_unpack_sequence( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # UNPACK_SEQUENCE unpacks a sequence into multiple values # arg is the number of values to unpack unpack_count = instr.arg - sequence = ensure_ast(state.stack[-1]) + sequence = ensure_ast(state.stack[-1]) # noqa: F841 new_stack = state.stack[:-1] - + # For tuple unpacking in comprehensions, we typically see patterns like: # ((k, v) for k, v in items) where items is unpacked into k and v # Create placeholder variables for the unpacked values for i in range(unpack_count): - var_name = f'_unpack_{i}' + var_name = f"_unpack_{i}" new_stack = new_stack + [ast.Name(id=var_name, ctx=ast.Load())] - + return replace(state, stack=new_stack) @@ -359,25 +420,32 @@ def handle_unpack_sequence(state: ReconstructionState, instr: dis.Instruction) - # VARIABLE OPERATIONS HANDLERS # ============================================================================ -@register_handler('LOAD_FAST') -def handle_load_fast(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("LOAD_FAST") +def handle_load_fast( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: var_name: str = instr.argval - - if var_name == '.0': + + if var_name == ".0": # Special handling for .0 variable (the iterator) new_stack = state.stack + [DummyIterName()] else: # Regular variable load new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] - + return replace(state, stack=new_stack) -@register_handler('STORE_FAST') -def handle_store_fast(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert isinstance(state.result, CompExp), "STORE_FAST must be called within a comprehension 
context" +@register_handler("STORE_FAST") +def handle_store_fast( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert isinstance( + state.result, CompExp + ), "STORE_FAST must be called within a comprehension context" var_name = instr.argval - + # Update the most recent loop's target variable assert len(state.result.generators) > 0, "STORE_FAST must be within a loop context" @@ -386,7 +454,7 @@ def handle_store_fast(state: ReconstructionState, instr: dis.Instruction) -> Rec target=ast.Name(id=var_name, ctx=ast.Store()), iter=state.result.generators[-1].iter, ifs=state.result.generators[-1].ifs, - is_async=state.result.generators[-1].is_async + is_async=state.result.generators[-1].is_async, ) # Update the last loop in the generators list @@ -406,34 +474,44 @@ def handle_store_fast(state: ReconstructionState, instr: dis.Instruction) -> Rec return replace(state, result=new_ret) -@register_handler('LOAD_CONST') -def handle_load_const(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("LOAD_CONST") +def handle_load_const( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: const_value = instr.argval new_stack = state.stack + [ensure_ast(const_value)] return replace(state, stack=new_stack) -@register_handler('LOAD_GLOBAL') -def handle_load_global(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("LOAD_GLOBAL") +def handle_load_global( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: global_name = instr.argval new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load())] return replace(state, stack=new_stack) -@register_handler('LOAD_NAME') -def handle_load_name(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("LOAD_NAME") +def handle_load_name( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # LOAD_NAME is 
similar to LOAD_GLOBAL but for names in the global namespace name = instr.argval new_stack = state.stack + [ast.Name(id=name, ctx=ast.Load())] return replace(state, stack=new_stack) -@register_handler('STORE_DEREF') -def handle_store_deref(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("STORE_DEREF") +def handle_store_deref( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # STORE_DEREF stores a value into a closure variable - assert isinstance(state.result, CompExp), "STORE_DEREF must be called within a comprehension context" + assert isinstance( + state.result, CompExp + ), "STORE_DEREF must be called within a comprehension context" var_name = instr.argval - + # Update the most recent loop's target variable assert len(state.result.generators) > 0, "STORE_DEREF must be within a loop context" @@ -442,7 +520,7 @@ def handle_store_deref(state: ReconstructionState, instr: dis.Instruction) -> Re target=ast.Name(id=var_name, ctx=ast.Store()), iter=state.result.generators[-1].iter, ifs=state.result.generators[-1].ifs, - is_async=state.result.generators[-1].is_async + is_async=state.result.generators[-1].is_async, ) # Update the last loop in the generators list @@ -462,16 +540,20 @@ def handle_store_deref(state: ReconstructionState, instr: dis.Instruction) -> Re return replace(state, result=new_ret) -@register_handler('LOAD_DEREF') -def handle_load_deref(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("LOAD_DEREF") +def handle_load_deref( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # LOAD_DEREF loads a value from a closure variable var_name = instr.argval new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] return replace(state, stack=new_stack) -@register_handler('LOAD_CLOSURE') -def handle_load_closure(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: 
+@register_handler("LOAD_CLOSURE") +def handle_load_closure( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # LOAD_CLOSURE loads a closure variable var_name = instr.argval new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] @@ -482,8 +564,11 @@ def handle_load_closure(state: ReconstructionState, instr: dis.Instruction) -> R # STACK MANAGEMENT HANDLERS # ============================================================================ -@register_handler('POP_TOP') -def handle_pop_top(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("POP_TOP") +def handle_pop_top( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # POP_TOP removes the top item from the stack # In generators, often used after YIELD_VALUE # Also used to clean up the duplicated middle value in failed chained comparisons @@ -491,40 +576,53 @@ def handle_pop_top(state: ReconstructionState, instr: dis.Instruction) -> Recons return replace(state, stack=new_stack) -@register_handler('DUP_TOP') -def handle_dup_top(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("DUP_TOP") +def handle_dup_top( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # DUP_TOP duplicates the top stack item top_item = state.stack[-1] new_stack = state.stack + [top_item] return replace(state, stack=new_stack) -@register_handler('ROT_TWO') -def handle_rot_two(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("ROT_TWO") +def handle_rot_two( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # ROT_TWO swaps the top two stack items new_stack = state.stack[:-2] + [state.stack[-1], state.stack[-2]] return replace(state, stack=new_stack) -@register_handler('ROT_THREE') -def handle_rot_three(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: 
+@register_handler("ROT_THREE") +def handle_rot_three( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # ROT_THREE rotates the top three stack items # TOS -> TOS1, TOS1 -> TOS2, TOS2 -> TOS new_stack = state.stack[:-3] + [state.stack[-2], state.stack[-1], state.stack[-3]] - + # Check if the top two items are the same (from DUP_TOP) # This heuristic indicates we're setting up for a chained comparison if len(state.stack) >= 3 and state.stack[-1] == state.stack[-2]: raise NotImplementedError("Chained comparison not implemented yet") - + return replace(state, stack=new_stack) -@register_handler('ROT_FOUR') -def handle_rot_four(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("ROT_FOUR") +def handle_rot_four( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # ROT_FOUR rotates the top four stack items # TOS -> TOS1, TOS1 -> TOS2, TOS2 -> TOS3, TOS3 -> TOS - new_stack = state.stack[:-4] + [state.stack[-2], state.stack[-1], state.stack[-4], state.stack[-3]] + new_stack = state.stack[:-4] + [ + state.stack[-2], + state.stack[-1], + state.stack[-4], + state.stack[-3], + ] return replace(state, stack=new_stack) @@ -532,41 +630,79 @@ def handle_rot_four(state: ReconstructionState, instr: dis.Instruction) -> Recon # BINARY ARITHMETIC/LOGIC OPERATION HANDLERS # ============================================================================ -def handle_binop(op: ast.operator, state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +def handle_binop( + op: ast.operator, state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: right = ensure_ast(state.stack[-1]) left = ensure_ast(state.stack[-2]) new_stack = state.stack[:-2] + [ast.BinOp(left=left, op=op, right=right)] return replace(state, stack=new_stack) -handler_binop_add = register_handler('BINARY_ADD', functools.partial(handle_binop, ast.Add())) -handler_binop_subtract = 
register_handler('BINARY_SUBTRACT', functools.partial(handle_binop, ast.Sub())) -handler_binop_multiply = register_handler('BINARY_MULTIPLY', functools.partial(handle_binop, ast.Mult())) -handler_binop_true_divide = register_handler('BINARY_TRUE_DIVIDE', functools.partial(handle_binop, ast.Div())) -handler_binop_floor_divide = register_handler('BINARY_FLOOR_DIVIDE', functools.partial(handle_binop, ast.FloorDiv())) -handler_binop_modulo = register_handler('BINARY_MODULO', functools.partial(handle_binop, ast.Mod())) -handler_binop_power = register_handler('BINARY_POWER', functools.partial(handle_binop, ast.Pow())) -handler_binop_lshift = register_handler('BINARY_LSHIFT', functools.partial(handle_binop, ast.LShift())) -handler_binop_rshift = register_handler('BINARY_RSHIFT', functools.partial(handle_binop, ast.RShift())) -handler_binop_or = register_handler('BINARY_OR', functools.partial(handle_binop, ast.BitOr())) -handler_binop_xor = register_handler('BINARY_XOR', functools.partial(handle_binop, ast.BitXor())) -handler_binop_and = register_handler('BINARY_AND', functools.partial(handle_binop, ast.BitAnd())) +handler_binop_add = register_handler( + "BINARY_ADD", functools.partial(handle_binop, ast.Add()) +) +handler_binop_subtract = register_handler( + "BINARY_SUBTRACT", functools.partial(handle_binop, ast.Sub()) +) +handler_binop_multiply = register_handler( + "BINARY_MULTIPLY", functools.partial(handle_binop, ast.Mult()) +) +handler_binop_true_divide = register_handler( + "BINARY_TRUE_DIVIDE", functools.partial(handle_binop, ast.Div()) +) +handler_binop_floor_divide = register_handler( + "BINARY_FLOOR_DIVIDE", functools.partial(handle_binop, ast.FloorDiv()) +) +handler_binop_modulo = register_handler( + "BINARY_MODULO", functools.partial(handle_binop, ast.Mod()) +) +handler_binop_power = register_handler( + "BINARY_POWER", functools.partial(handle_binop, ast.Pow()) +) +handler_binop_lshift = register_handler( + "BINARY_LSHIFT", functools.partial(handle_binop, 
ast.LShift()) +) +handler_binop_rshift = register_handler( + "BINARY_RSHIFT", functools.partial(handle_binop, ast.RShift()) +) +handler_binop_or = register_handler( + "BINARY_OR", functools.partial(handle_binop, ast.BitOr()) +) +handler_binop_xor = register_handler( + "BINARY_XOR", functools.partial(handle_binop, ast.BitXor()) +) +handler_binop_and = register_handler( + "BINARY_AND", functools.partial(handle_binop, ast.BitAnd()) +) # ============================================================================ # UNARY OPERATION HANDLERS # ============================================================================ -def handle_unary_op(op: ast.unaryop, state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +def handle_unary_op( + op: ast.unaryop, state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: operand = ensure_ast(state.stack[-1]) new_stack = state.stack[:-1] + [ast.UnaryOp(op=op, operand=operand)] return replace(state, stack=new_stack) -handle_unary_negative = register_handler('UNARY_NEGATIVE', functools.partial(handle_unary_op, ast.USub())) -handle_unary_positive = register_handler('UNARY_POSITIVE', functools.partial(handle_unary_op, ast.UAdd())) -handle_unary_invert = register_handler('UNARY_INVERT', functools.partial(handle_unary_op, ast.Invert())) -handle_unary_not = register_handler('UNARY_NOT', functools.partial(handle_unary_op, ast.Not())) +handle_unary_negative = register_handler( + "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()) +) +handle_unary_positive = register_handler( + "UNARY_POSITIVE", functools.partial(handle_unary_op, ast.UAdd()) +) +handle_unary_invert = register_handler( + "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()) +) +handle_unary_not = register_handler( + "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()) +) # ============================================================================ @@ -574,63 +710,59 @@ def handle_unary_op(op: ast.unaryop, 
state: ReconstructionState, instr: dis.Inst # ============================================================================ CMP_OPMAP: dict[str, ast.cmpop] = { - '<': ast.Lt(), - '<=': ast.LtE(), - '>': ast.Gt(), - '>=': ast.GtE(), - '==': ast.Eq(), - '!=': ast.NotEq(), + "<": ast.Lt(), + "<=": ast.LtE(), + ">": ast.Gt(), + ">=": ast.GtE(), + "==": ast.Eq(), + "!=": ast.NotEq(), } -@register_handler('COMPARE_OP') -def handle_compare_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: - assert dis.cmp_op[instr.arg] == instr.argval, f"Unsupported comparison operation: {instr.argval}" +@register_handler("COMPARE_OP") +def handle_compare_op( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert ( + dis.cmp_op[instr.arg] == instr.argval + ), f"Unsupported comparison operation: {instr.argval}" right = ensure_ast(state.stack[-1]) left = ensure_ast(state.stack[-2]) - + # Map comparison operation codes to AST operators op_name = instr.argval - compare_node = ast.Compare( - left=left, - ops=[CMP_OPMAP[op_name]], - comparators=[right] - ) + compare_node = ast.Compare(left=left, ops=[CMP_OPMAP[op_name]], comparators=[right]) new_stack = state.stack[:-2] + [compare_node] return replace(state, stack=new_stack) -@register_handler('CONTAINS_OP') -def handle_contains_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("CONTAINS_OP") +def handle_contains_op( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: right = ensure_ast(state.stack[-1]) # Container - left = ensure_ast(state.stack[-2]) # Item to check - + left = ensure_ast(state.stack[-2]) # Item to check + # instr.arg determines if it's 'in' (0) or 'not in' (1) op = ast.NotIn() if instr.arg else ast.In() - - compare_node = ast.Compare( - left=left, - ops=[op], - comparators=[right] - ) + + compare_node = ast.Compare(left=left, ops=[op], comparators=[right]) new_stack = state.stack[:-2] + 
[compare_node] return replace(state, stack=new_stack) -@register_handler('IS_OP') -def handle_is_op(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("IS_OP") +def handle_is_op( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: right = ensure_ast(state.stack[-1]) left = ensure_ast(state.stack[-2]) - + # instr.arg determines if it's 'is' (0) or 'is not' (1) op = ast.IsNot() if instr.arg else ast.Is() - - compare_node = ast.Compare( - left=left, - ops=[op], - comparators=[right] - ) + + compare_node = ast.Compare(left=left, ops=[op], comparators=[right]) new_stack = state.stack[:-2] + [compare_node] return replace(state, stack=new_stack) @@ -639,14 +771,19 @@ def handle_is_op(state: ReconstructionState, instr: dis.Instruction) -> Reconstr # FUNCTION CALL HANDLERS # ============================================================================ -@register_handler('CALL_FUNCTION') -def handle_call_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("CALL_FUNCTION") +def handle_call_function( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # CALL_FUNCTION pops function and arguments from stack arg_count: int = instr.arg # Pop arguments and function - args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] + args = ( + [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] + ) func = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[:-arg_count - 1] + new_stack = state.stack[: -arg_count - 1] if isinstance(func, CompLambda): assert len(args) == 1 @@ -656,85 +793,110 @@ def handle_call_function(state: ReconstructionState, instr: dis.Instruction) -> call_node = ast.Call(func=func, args=args, keywords=[]) new_stack = new_stack + [call_node] return replace(state, stack=new_stack) - -@register_handler('LOAD_METHOD') -def handle_load_method(state: 
ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("LOAD_METHOD") +def handle_load_method( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # LOAD_METHOD loads a method from an object # It pushes the bound method and the object (for the method call) obj = ensure_ast(state.stack[-1]) method_name = instr.argval new_stack = state.stack[:-1] - + # Create method access as an attribute method_attr = ast.Attribute(value=obj, attr=method_name, ctx=ast.Load()) - + # For LOAD_METHOD, we push both the method and the object # But for AST purposes, we just need the method attribute new_stack = new_stack + [method_attr] return replace(state, stack=new_stack) -@register_handler('CALL_METHOD') -def handle_call_method(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("CALL_METHOD") +def handle_call_method( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # CALL_METHOD calls a method - similar to CALL_FUNCTION but for methods arg_count = instr.arg # Pop arguments and method - args = [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] + args = ( + [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] + ) method = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[:-arg_count - 1] - + new_stack = state.stack[: -arg_count - 1] + # Create method call AST call_node = ast.Call(func=method, args=args, keywords=[]) new_stack = new_stack + [call_node] return replace(state, stack=new_stack) -@register_handler('MAKE_FUNCTION') -def handle_make_function(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("MAKE_FUNCTION") +def handle_make_function( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # MAKE_FUNCTION creates a function from code object and name on stack - assert isinstance(state.stack[-1], 
ast.Constant) and isinstance(state.stack[-1].value, str), "Function name must be a constant string." - if instr.argrepr == 'closure': + assert isinstance(state.stack[-1], ast.Constant) and isinstance( + state.stack[-1].value, str + ), "Function name must be a constant string." + if instr.argrepr == "closure": # This is a closure, remove the environment tuple from the stack for AST purposes new_stack = state.stack[:-3] - elif instr.argrepr == '': + elif instr.argrepr == "": new_stack = state.stack[:-2] else: - raise NotImplementedError("MAKE_FUNCTION with defaults or annotations not implemented.") + raise NotImplementedError( + "MAKE_FUNCTION with defaults or annotations not implemented." + ) body: ast.expr = state.stack[-2] name: str = state.stack[-1].value - if isinstance(body, CompExp) and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1: + assert any( + name.endswith(suffix) + for suffix in ("", "", "", "", "") + ), f"Expected a comprehension or lambda function, got '{name}'" + + if ( + isinstance(body, CompExp) + and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 + ): return replace(state, stack=new_stack + [CompLambda(body)]) else: raise NotImplementedError("Lambda reconstruction not implemented yet") # ============================================================================ -# OBJECT ACCESS HANDLERS +# OBJECT ACCESS HANDLERS # ============================================================================ -@register_handler('LOAD_ATTR') -def handle_load_attr(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("LOAD_ATTR") +def handle_load_attr( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # LOAD_ATTR loads an attribute from the object on top of stack obj = ensure_ast(state.stack[-1]) attr_name = instr.argval new_stack = state.stack[:-1] - + # Create attribute access AST attr_node = ast.Attribute(value=obj, attr=attr_name, ctx=ast.Load()) 
new_stack = new_stack + [attr_node] return replace(state, stack=new_stack) -@register_handler('BINARY_SUBSCR') -def handle_binary_subscr(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("BINARY_SUBSCR") +def handle_binary_subscr( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # BINARY_SUBSCR implements obj[index] - pops index and obj from stack index = ensure_ast(state.stack[-1]) # Index is on top - obj = ensure_ast(state.stack[-2]) # Object is below index + obj = ensure_ast(state.stack[-2]) # Object is below index new_stack = state.stack[:-2] - + # Create subscript access AST subscr_node = ast.Subscript(value=obj, slice=index, ctx=ast.Load()) new_stack = new_stack + [subscr_node] @@ -745,38 +907,49 @@ def handle_binary_subscr(state: ReconstructionState, instr: dis.Instruction) -> # OTHER CONTAINER BUILDING HANDLERS # ============================================================================ -@register_handler('BUILD_TUPLE') -def handle_build_tuple(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("BUILD_TUPLE") +def handle_build_tuple( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: tuple_size: int = instr.arg # Pop elements for the tuple - elements = [ensure_ast(elem) for elem in state.stack[-tuple_size:]] if tuple_size > 0 else [] + elements = ( + [ensure_ast(elem) for elem in state.stack[-tuple_size:]] + if tuple_size > 0 + else [] + ) new_stack = state.stack[:-tuple_size] if tuple_size > 0 else state.stack - + # Create tuple AST tuple_node = ast.Tuple(elts=elements, ctx=ast.Load()) new_stack = new_stack + [tuple_node] return replace(state, stack=new_stack) -@register_handler('LIST_TO_TUPLE') -def handle_list_to_tuple(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("LIST_TO_TUPLE") +def handle_list_to_tuple( + state: ReconstructionState, instr: 
dis.Instruction +) -> ReconstructionState: # LIST_TO_TUPLE converts a list on the stack to a tuple list_obj = ensure_ast(state.stack[-1]) assert isinstance(list_obj, ast.List), "Expected a list for LIST_TO_TUPLE" - + # Create tuple AST from the list's elements tuple_node = ast.Tuple(elts=list_obj.elts, ctx=ast.Load()) new_stack = state.stack[:-1] + [tuple_node] return replace(state, stack=new_stack) -@register_handler('LIST_EXTEND') -def handle_list_extend(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("LIST_EXTEND") +def handle_list_extend( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # LIST_EXTEND extends the list at TOS-1 with the iterable at TOS iterable = ensure_ast(state.stack[-1]) list_obj = state.stack[-2] # This should be a list from BUILD_LIST new_stack = state.stack[:-2] - + # If the list is empty and we're extending with a tuple/iterable, # we can convert this to a simple list of the iterable's elements if isinstance(list_obj, ast.List) and len(list_obj.elts) == 0: @@ -786,28 +959,28 @@ def handle_list_extend(state: ReconstructionState, instr: dis.Instruction) -> Re list_node = ast.List(elts=elements, ctx=ast.Load()) new_stack = new_stack + [list_node] return replace(state, stack=new_stack) - + # Fallback: create a list from the iterable using list() constructor list_call = ast.Call( - func=ast.Name(id='list', ctx=ast.Load()), - args=[iterable], - keywords=[] + func=ast.Name(id="list", ctx=ast.Load()), args=[iterable], keywords=[] ) new_stack = new_stack + [list_call] return replace(state, stack=new_stack) -@register_handler('BUILD_CONST_KEY_MAP') -def handle_build_const_key_map(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("BUILD_CONST_KEY_MAP") +def handle_build_const_key_map( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # BUILD_CONST_KEY_MAP builds a dictionary with constant keys # The 
keys are in a tuple on TOS, values are on the stack below map_size: int = instr.arg # Pop the keys tuple and values keys_tuple: ast.Tuple = state.stack[-1] keys = [ensure_ast(key) for key in keys_tuple.elts] - values = [ensure_ast(val) for val in state.stack[-map_size-1:-1]] - new_stack = state.stack[:-map_size-1] - + values = [ensure_ast(val) for val in state.stack[-map_size - 1 : -1]] + new_stack = state.stack[: -map_size - 1] + # Create dictionary AST dict_node = ast.Dict(keys=keys, values=values) new_stack = new_stack + [dict_node] @@ -818,8 +991,11 @@ def handle_build_const_key_map(state: ReconstructionState, instr: dis.Instructio # CONDITIONAL JUMP HANDLERS # ============================================================================ -@register_handler('POP_JUMP_IF_FALSE') -def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: + +@register_handler("POP_JUMP_IF_FALSE") +def handle_pop_jump_if_false( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false # In comprehensions, this is used for filter conditions condition = ensure_ast(state.stack[-1]) @@ -852,8 +1028,10 @@ def handle_pop_jump_if_false(state: ReconstructionState, instr: dis.Instruction) raise NotImplementedError("Lazy and+or behavior not implemented yet") -@register_handler('POP_JUMP_IF_TRUE') -def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) -> ReconstructionState: +@register_handler("POP_JUMP_IF_TRUE") +def handle_pop_jump_if_true( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: # POP_JUMP_IF_TRUE pops a value from the stack and jumps if it's true # This can be: # 1. 
Part of an OR expression (jump to YIELD_VALUE) @@ -894,6 +1072,7 @@ def handle_pop_jump_if_true(state: ReconstructionState, instr: dis.Instruction) # UTILITY FUNCTIONS # ============================================================================ + @functools.singledispatch def ensure_ast(value) -> ast.expr: """Ensure value is an AST node""" @@ -920,10 +1099,9 @@ def _ensure_ast_constant(value) -> ast.Constant: @ensure_ast.register def _ensure_ast_tuple(value: tuple) -> ast.Tuple: """Convert tuple to AST - special handling for dict items""" - if len(value) > 0 and value[0] == 'dict_item': + if len(value) > 0 and value[0] == "dict_item": return ast.Tuple( - elts=[ensure_ast(value[1]), ensure_ast(value[2])], - ctx=ast.Load() + elts=[ensure_ast(value[1]), ensure_ast(value[2])], ctx=ast.Load() ) else: return ast.Tuple(elts=[ensure_ast(v) for v in value], ctx=ast.Load()) @@ -958,7 +1136,7 @@ def _ensure_ast_set_iterator(value: Iterator) -> ast.Set: def _ensure_ast_dict(value: dict) -> ast.Dict: return ast.Dict( keys=[ensure_ast(k) for k in value.keys()], - values=[ensure_ast(v) for v in value.values()] + values=[ensure_ast(v) for v in value.values()], ) @@ -970,9 +1148,9 @@ def _ensure_ast_dict_iterator(value: Iterator) -> ast.Dict: @ensure_ast.register def _ensure_ast_range(value: range) -> ast.Call: return ast.Call( - func=ast.Name(id='range', ctx=ast.Load()), + func=ast.Name(id="range", ctx=ast.Load()), args=[ensure_ast(value.start), ensure_ast(value.stop), ensure_ast(value.step)], - keywords=[] + keywords=[], ) @@ -989,21 +1167,37 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.expr: state = OP_HANDLERS[instr.opname](state, instr) # Check postconditions - assert not any(isinstance(x, Placeholder) for x in ast.walk(state.result)), "Return value must not contain placeholders" - assert not isinstance(state.result, CompExp) or len(state.result.generators) > 0, "Return value must have generators if not a lambda" + assert not any( + isinstance(x, Placeholder) for 
x in ast.walk(state.result) + ), "Return value must not contain placeholders" + assert ( + not isinstance(state.result, CompExp) or len(state.result.generators) > 0 + ), "Return value must have generators if not a lambda" return state.result @ensure_ast.register def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: - assert inspect.isfunction(value) and value.__name__.endswith(""), "Input must be a lambda function" + assert inspect.isfunction(value) and value.__name__.endswith( + "" + ), "Input must be a lambda function" code: types.CodeType = value.__code__ body: ast.expr = ensure_ast(code) args = ast.arguments( - posonlyargs=[ast.arg(arg=arg) for arg in code.co_varnames[:code.co_posonlyargcount]], - args=[ast.arg(arg=arg) for arg in code.co_varnames[code.co_posonlyargcount:code.co_argcount]], - kwonlyargs=[ast.arg(arg=arg) for arg in code.co_varnames[code.co_argcount:code.co_argcount + code.co_kwonlyargcount]], + posonlyargs=[ + ast.arg(arg=arg) for arg in code.co_varnames[: code.co_posonlyargcount] + ], + args=[ + ast.arg(arg=arg) + for arg in code.co_varnames[code.co_posonlyargcount : code.co_argcount] + ], + kwonlyargs=[ + ast.arg(arg=arg) + for arg in code.co_varnames[ + code.co_argcount : code.co_argcount + code.co_kwonlyargcount + ] + ], kw_defaults=[], defaults=[], ) @@ -1013,11 +1207,15 @@ def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: @ensure_ast.register def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: assert inspect.isgenerator(genexpr), "Input must be a generator expression" - assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must be in created state" + assert ( + inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED + ), "Generator must be in created state" genexpr_ast: ast.GeneratorExp = ensure_ast(genexpr.gi_code) - geniter_ast: ast.expr = ensure_ast(genexpr.gi_frame.f_locals['.0']) + geniter_ast: ast.expr = ensure_ast(genexpr.gi_frame.f_locals[".0"]) result = 
CompLambda(genexpr_ast).inline(geniter_ast) - assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, "Generator must stay in created state" + assert ( + inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED + ), "Generator must stay in created state" return result @@ -1025,48 +1223,49 @@ def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: # MAIN RECONSTRUCTION FUNCTION # ============================================================================ + def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: """ Reconstruct an AST from a generator expression's bytecode. - + This function analyzes the bytecode of a generator object and reconstructs an abstract syntax tree (AST) that represents the original comprehension expression. The reconstruction process simulates the Python VM's execution of the bytecode, building AST nodes instead of executing operations. - + The reconstruction handles complex comprehension features including: - Multiple nested loops - Filter conditions (if clauses) - Complex expressions in the yield/result part - Tuple unpacking in loop variables - Various operators and function calls - + Args: genexpr (Generator[object, None, None]): The generator object to analyze. Must be a freshly created generator that has not been iterated yet (in 'GEN_CREATED' state). - + Returns: ast.Expression: An AST node representing the reconstructed comprehension. - + Raises: AssertionError: If the input is not a generator or if the generator has already been started (not in 'GEN_CREATED' state). 
- + Example: >>> # Generator expression >>> g = (x * 2 for x in range(10) if x % 2 == 0) >>> ast_node = reconstruct(g) >>> isinstance(ast_node, ast.Expression) True - + >>> # The reconstructed AST can be compiled and evaluated >>> import ast >>> code = compile(ast_node, '', 'eval') >>> result = eval(code) >>> list(result) [0, 4, 8, 12, 16] - + Note: The reconstruction is based on bytecode analysis and may not perfectly preserve the original source code formatting or variable names in all diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 429a3db7..dd5e2c2e 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -7,34 +7,38 @@ from effectful.internals.disassembler import reconstruct -def compile_and_eval(node: ast.expr | ast.Expression, globals_dict: dict | None = None) -> Any: +def compile_and_eval( + node: ast.expr | ast.Expression, globals_dict: dict | None = None +) -> Any: """Compile an AST node and evaluate it.""" if globals_dict is None: globals_dict = {} - + # Wrap in an Expression node if needed if not isinstance(node, ast.Expression): node = ast.Expression(body=node) - + # Fix location info ast.fix_missing_locations(node) - + # Compile and evaluate - code = compile(node, '', 'eval') + code = compile(node, "", "eval") return eval(code, globals_dict) -def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict | None = None): +def assert_ast_equivalent( + genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict | None = None +): """Assert that a reconstructed AST produces the same results as the original generator.""" # Check AST structure assert isinstance(reconstructed_ast, ast.Expression) - assert hasattr(reconstructed_ast.body, 'elt') # The expression part - assert hasattr(reconstructed_ast.body, 'generators') # The comprehension part + assert hasattr(reconstructed_ast.body, "elt") # The expression part + assert 
hasattr(reconstructed_ast.body, "generators") # The comprehension part assert len(reconstructed_ast.body.generators) > 0 for comp in reconstructed_ast.body.generators: - assert hasattr(comp, 'target') # Loop variable - assert hasattr(comp, 'iter') # Iterator - assert hasattr(comp, 'ifs') # Conditions + assert hasattr(comp, "target") # Loop variable + assert hasattr(comp, "iter") # Iterator + assert hasattr(comp, "ifs") # Conditions # Save current globals to restore later curr_globals = globals().copy() @@ -48,32 +52,36 @@ def assert_ast_equivalent(genexpr: GeneratorType, reconstructed_ast: ast.AST, gl if key not in curr_globals: del globals()[key] globals().update(curr_globals) - + # Compile and evaluate the reconstructed AST reconstructed_gen = compile_and_eval(reconstructed_ast, globals_dict) reconstructed_list = list(reconstructed_gen) - assert reconstructed_list == original_list, \ - f"AST produced {reconstructed_list}, expected {original_list}" + assert ( + reconstructed_list == original_list + ), f"AST produced {reconstructed_list}, expected {original_list}" # ============================================================================ # BASIC GENERATOR EXPRESSION TESTS # ============================================================================ -@pytest.mark.parametrize("genexpr", [ - # Simple generator expressions - (x for x in range(5)), - (y for y in range(10)), - (item for item in [1, 2, 3]), - - # Edge cases for simple generators - (i for i in range(0)), # Empty range - (n for n in range(1)), # Single item range - (val for val in range(100)), # Large range - (x for x in range(-5, 5)), # Negative range - (step for step in range(0, 10, 2)), # Step range - (rev for rev in range(10, 0, -1)), # Reverse range -]) + +@pytest.mark.parametrize( + "genexpr", + [ + # Simple generator expressions + (x for x in range(5)), + (y for y in range(10)), + (item for item in [1, 2, 3]), + # Edge cases for simple generators + (i for i in range(0)), # Empty range + (n for n 
in range(1)), # Single item range + (val for val in range(100)), # Large range + (x for x in range(-5, 5)), # Negative range + (step for step in range(0, 10, 2)), # Step range + (rev for rev in range(10, 0, -1)), # Reverse range + ], +) def test_simple_generators(genexpr): """Test reconstruction of simple generator expressions.""" ast_node = reconstruct(genexpr) @@ -84,52 +92,50 @@ def test_simple_generators(genexpr): # ARITHMETIC AND EXPRESSION TESTS # ============================================================================ -@pytest.mark.parametrize("genexpr", [ - # Basic arithmetic operations - (x * 2 for x in range(5)), - (x + 1 for x in range(5)), - (x - 1 for x in range(5)), - (x ** 2 for x in range(5)), - (x % 2 for x in range(10)), - (x / 2 for x in range(1, 6)), - (x // 2 for x in range(10)), - - # Complex expressions - (x * 2 + 1 for x in range(5)), - ((x + 1) * (x - 1) for x in range(5)), - (x ** 2 + 2 * x + 1 for x in range(5)), - - # Unary operations - (-x for x in range(5)), - (+x for x in range(-5, 5)), - (~x for x in range(5)), - - # More complex arithmetic edge cases - (x ** 3 for x in range(1, 5)), # Higher powers - (x * x * x for x in range(5)), # Repeated multiplication - (x + x + x for x in range(5)), # Repeated addition - (x - x + 1 for x in range(5)), # Operations that might simplify - (x / x for x in range(1, 5)), # Division by self - (x % (x + 1) for x in range(1, 10)), # Modulo with expression - - # Nested arithmetic expressions - ((x + 1) ** 2 for x in range(5)), - ((x * 2 + 3) * (x - 1) for x in range(5)), - (x * (x + 1) * (x + 2) for x in range(5)), - - # Mixed operations with precedence - (x + y * 2 for x in range(3) for y in range(3)), - (x * 2 + y / 3 for x in range(1, 4) for y in range(1, 4)), - ((x + y) * (x - y) for x in range(1, 4) for y in range(1, 4)), - - # Edge cases with zero and one - (x * 0 for x in range(5)), - (x * 1 for x in range(5)), - (x + 0 for x in range(5)), - (x ** 1 for x in range(5)), - (0 + x for x in 
range(5)), - (1 * x for x in range(5)), -]) + +@pytest.mark.parametrize( + "genexpr", + [ + # Basic arithmetic operations + (x * 2 for x in range(5)), + (x + 1 for x in range(5)), + (x - 1 for x in range(5)), + (x**2 for x in range(5)), + (x % 2 for x in range(10)), + (x / 2 for x in range(1, 6)), + (x // 2 for x in range(10)), + # Complex expressions + (x * 2 + 1 for x in range(5)), + ((x + 1) * (x - 1) for x in range(5)), + (x**2 + 2 * x + 1 for x in range(5)), + # Unary operations + (-x for x in range(5)), + (+x for x in range(-5, 5)), + (~x for x in range(5)), + # More complex arithmetic edge cases + (x**3 for x in range(1, 5)), # Higher powers + (x * x * x for x in range(5)), # Repeated multiplication + (x + x + x for x in range(5)), # Repeated addition + (x - x + 1 for x in range(5)), # Operations that might simplify + (x / x for x in range(1, 5)), # Division by self + (x % (x + 1) for x in range(1, 10)), # Modulo with expression + # Nested arithmetic expressions + ((x + 1) ** 2 for x in range(5)), + ((x * 2 + 3) * (x - 1) for x in range(5)), + (x * (x + 1) * (x + 2) for x in range(5)), + # Mixed operations with precedence + (x + y * 2 for x in range(3) for y in range(3)), + (x * 2 + y / 3 for x in range(1, 4) for y in range(1, 4)), + ((x + y) * (x - y) for x in range(1, 4) for y in range(1, 4)), + # Edge cases with zero and one + (x * 0 for x in range(5)), + (x * 1 for x in range(5)), + (x + 0 for x in range(5)), + (x**1 for x in range(5)), + (0 + x for x in range(5)), + (1 * x for x in range(5)), + ], +) def test_arithmetic_expressions(genexpr): """Test reconstruction of generators with arithmetic expressions.""" ast_node = reconstruct(genexpr) @@ -140,57 +146,68 @@ def test_arithmetic_expressions(genexpr): # COMPARISON OPERATORS # ============================================================================ -@pytest.mark.parametrize("genexpr", [ - # All comparison operators - (x for x in range(10) if x < 5), - (x for x in range(10) if x <= 5), - (x for x in 
range(10) if x > 5), - (x for x in range(10) if x >= 5), - (x for x in range(10) if x == 5), - (x for x in range(10) if x != 5), - - # in/not in operators - (x for x in range(10) if x in [2, 4, 6, 8]), - (x for x in range(10) if x not in [2, 4, 6, 8]), - - # is/is not operators (with None) - (x for x in [1, None, 3, None, 5] if x is not None), - (x for x in [1, None, 3, None, 5] if x is None), - - # Boolean operations - these are complex cases that might need special handling - (x for x in range(10) if not x % 2), - (x for x in range(10) if not (x > 5)), - (x for x in range(10) if x > 2 and x < 8), - pytest.param((x for x in range(10) if x < 3 or x > 7), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), - - # More complex comparison edge cases - # Comparisons with expressions - (x for x in range(10) if x * 2 > 10), - (x for x in range(10) if x + 1 <= 5), - (x for x in range(10) if x ** 2 < 25), - (x for x in range(10) if (x + 1) * 2 != 6), - - # Complex membership tests - (x for x in range(20) if x in range(5, 15)), - (x for x in range(10) if x not in range(3, 7)), - (x for x in range(10) if x % 2 in [0]), - (x for x in range(10) if x not in []), # Empty container - - # Complex boolean combinations - (x for x in range(20) if not (x < 5 or x > 15)), - (x for x in range(20) if x > 5 and x < 15 and x % 2 == 0), - pytest.param((x for x in range(20) if x < 5 or x > 15 or x == 10), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), - pytest.param((x for x in range(20) if not (x > 5 and x < 15)), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), - - # Mixed comparison and boolean operations - pytest.param((x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), - pytest.param((x for x in range(20) if not (x % 2 == 0 and x % 3 == 0)), marks=pytest.mark.xfail(reason="Lambda reconstruction not 
implemented yet")), - - # Edge cases with identity comparisons - (x for x in [0, 1, 2, None, 4] if x is not None and x > 1), - (x for x in [True, False, 1, 0] if x is True), - (x for x in [True, False, 1, 0] if x is not False), -]) + +@pytest.mark.parametrize( + "genexpr", + [ + # All comparison operators + (x for x in range(10) if x < 5), + (x for x in range(10) if x <= 5), + (x for x in range(10) if x > 5), + (x for x in range(10) if x >= 5), + (x for x in range(10) if x == 5), + (x for x in range(10) if x != 5), + # in/not in operators + (x for x in range(10) if x in [2, 4, 6, 8]), + (x for x in range(10) if x not in [2, 4, 6, 8]), + # is/is not operators (with None) + (x for x in [1, None, 3, None, 5] if x is not None), + (x for x in [1, None, 3, None, 5] if x is None), + # Boolean operations - these are complex cases that might need special handling + (x for x in range(10) if not x % 2), + (x for x in range(10) if not (x > 5)), + (x for x in range(10) if x > 2 and x < 8), + pytest.param( + (x for x in range(10) if x < 3 or x > 7), + marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), + ), + # More complex comparison edge cases + # Comparisons with expressions + (x for x in range(10) if x * 2 > 10), + (x for x in range(10) if x + 1 <= 5), + (x for x in range(10) if x**2 < 25), + (x for x in range(10) if (x + 1) * 2 != 6), + # Complex membership tests + (x for x in range(20) if x in range(5, 15)), + (x for x in range(10) if x not in range(3, 7)), + (x for x in range(10) if x % 2 in [0]), + (x for x in range(10) if x not in []), # Empty container + # Complex boolean combinations + (x for x in range(20) if not (x < 5 or x > 15)), + (x for x in range(20) if x > 5 and x < 15 and x % 2 == 0), + pytest.param( + (x for x in range(20) if x < 5 or x > 15 or x == 10), + marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), + ), + pytest.param( + (x for x in range(20) if not (x > 5 and x < 15)), + 
marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), + ), + # Mixed comparison and boolean operations + pytest.param( + (x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), + marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), + ), + pytest.param( + (x for x in range(20) if not (x % 2 == 0 and x % 3 == 0)), + marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), + ), + # Edge cases with identity comparisons + (x for x in [0, 1, 2, None, 4] if x is not None and x > 1), + (x for x in [True, False, 1, 0] if x is True), + (x for x in [True, False, 1, 0] if x is not False), + ], +) def test_comparison_operators(genexpr): """Test reconstruction of all comparison operators.""" ast_node = reconstruct(genexpr) @@ -201,12 +218,16 @@ def test_comparison_operators(genexpr): # CHAINED COMPARISON TESTS # ============================================================================ + @pytest.mark.xfail(reason="Chained comparisons not yet fully supported") -@pytest.mark.parametrize("genexpr", [ - # Chained comparisons - (x for x in range(20) if 5 < x < 15), - (x for x in range(20) if 0 <= x <= 10), -]) +@pytest.mark.parametrize( + "genexpr", + [ + # Chained comparisons + (x for x in range(20) if 5 < x < 15), + (x for x in range(20) if 0 <= x <= 10), + ], +) def test_chained_comparison_operators(genexpr): """Test reconstruction of chained (ternary) comparison operators.""" ast_node = reconstruct(genexpr) @@ -217,52 +238,65 @@ def test_chained_comparison_operators(genexpr): # FILTERED GENERATOR TESTS # ============================================================================ -@pytest.mark.parametrize("genexpr", [ - # Simple filters - (x for x in range(10) if x % 2 == 0), - (x for x in range(10) if x > 5), - (x for x in range(10) if x < 5), - (x for x in range(10) if x != 5), - - # Complex filters - (x for x in range(20) if x % 2 == 0 if x % 3 == 0), - (x for x in range(100) if x > 10 if 
x < 90 if x % 5 == 0), - - # Filters with expressions - (x * 2 for x in range(10) if x % 2 == 0), - (x ** 2 for x in range(10) if x > 3), - - # Boolean operations in filters - (x for x in range(10) if not x % 2), - (x for x in range(10) if x > 2 and x < 8), - pytest.param((x for x in range(10) if x < 3 or x > 7), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), - - # More complex filter edge cases - (x for x in range(50) if x % 7 == 0), # Different modulo - (x for x in range(10) if x >= 0), # Always true condition - (x for x in range(10) if x < 0), # Always false condition - (x for x in range(20) if x % 2 == 0 and x % 3 == 0), # Multiple conditions with and - pytest.param((x for x in range(20) if x % 2 == 0 or x % 3 == 0), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), # Multiple conditions with or - - # Nested boolean operations - pytest.param((x for x in range(20) if (x > 5 and x < 15) or x == 0), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), - pytest.param((x for x in range(20) if not (x > 10 and x < 15)), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), - pytest.param((x for x in range(50) if x > 10 and (x % 2 == 0 or x % 3 == 0)), marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet")), - - # Multiple consecutive filters - (x for x in range(100) if x > 20 if x < 80 if x % 10 == 0), - (x for x in range(50) if x % 2 == 0 if x % 3 != 0 if x > 10), - - # Filters with complex expressions - (x + 1 for x in range(20) if (x * 2) % 3 == 0), - (x ** 2 for x in range(10) if x * (x + 1) > 10), - (x / 2 for x in range(1, 20) if x % (x // 2 + 1) == 0), - - # Edge cases with truthiness - (x for x in range(10) if x), # Truthy filter - (x for x in range(-5, 5) if not x), # Falsy filter - (x for x in range(10) if bool(x % 2)), # Explicit bool conversion -]) + +@pytest.mark.parametrize( + "genexpr", + [ + # Simple filters + (x for x in range(10) if x % 2 == 
0), + (x for x in range(10) if x > 5), + (x for x in range(10) if x < 5), + (x for x in range(10) if x != 5), + # Complex filters + (x for x in range(20) if x % 2 == 0 if x % 3 == 0), + (x for x in range(100) if x > 10 if x < 90 if x % 5 == 0), + # Filters with expressions + (x * 2 for x in range(10) if x % 2 == 0), + (x**2 for x in range(10) if x > 3), + # Boolean operations in filters + (x for x in range(10) if not x % 2), + (x for x in range(10) if x > 2 and x < 8), + pytest.param( + (x for x in range(10) if x < 3 or x > 7), + marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet"), + ), + # More complex filter edge cases + (x for x in range(50) if x % 7 == 0), # Different modulo + (x for x in range(10) if x >= 0), # Always true condition + (x for x in range(10) if x < 0), # Always false condition + ( + x for x in range(20) if x % 2 == 0 and x % 3 == 0 + ), # Multiple conditions with and + pytest.param( + (x for x in range(20) if x % 2 == 0 or x % 3 == 0), + marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet"), + ), # Multiple conditions with or + # Nested boolean operations + pytest.param( + (x for x in range(20) if (x > 5 and x < 15) or x == 0), + marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet"), + ), + pytest.param( + (x for x in range(20) if not (x > 10 and x < 15)), + marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet"), + ), + pytest.param( + (x for x in range(50) if x > 10 and (x % 2 == 0 or x % 3 == 0)), + marks=pytest.mark.xfail(reason="Lazy conjunctions not implemented yet"), + ), + # Multiple consecutive filters + (x for x in range(100) if x > 20 if x < 80 if x % 10 == 0), + (x for x in range(50) if x % 2 == 0 if x % 3 != 0 if x > 10), + # Filters with complex expressions + (x + 1 for x in range(20) if (x * 2) % 3 == 0), + (x**2 for x in range(10) if x * (x + 1) > 10), + (x / 2 for x in range(1, 20) if x % (x // 2 + 1) == 0), + # Edge cases with truthiness + (x for x in 
range(10) if x), # Truthy filter + (x for x in range(-5, 5) if not x), # Falsy filter + (x for x in range(10) if bool(x % 2)), # Explicit bool conversion + ], +) def test_filtered_generators(genexpr): """Test reconstruction of generators with if conditions.""" ast_node = reconstruct(genexpr) @@ -273,57 +307,68 @@ def test_filtered_generators(genexpr): # NESTED LOOP TESTS # ============================================================================ -@pytest.mark.parametrize("genexpr", [ - # Basic nested loops - ((x, y) for x in range(3) for y in range(3)), - (x + y for x in range(3) for y in range(3)), - (x * y for x in range(1, 4) for y in range(1, 4)), - - # Nested with filters - ((x, y) for x in range(5) for y in range(5) if x < y), - (x + y for x in range(5) if x % 2 == 0 for y in range(5) if y % 2 == 1), - - # Triple nested - ((x, y, z) for x in range(2) for y in range(2) for z in range(2)), - - # More complex nested loop edge cases - # Different sized ranges - ((x, y) for x in range(2) for y in range(5)), - ((x, y) for x in range(10) for y in range(2)), - - # Asymmetric operations - (x - y for x in range(5) for y in range(3)), - (x / (y + 1) for x in range(1, 6) for y in range(3)), - (x ** y for x in range(1, 4) for y in range(3)), - - # Complex expressions with multiple variables - (x * y + x for x in range(3) for y in range(3)), - (x + y + x * y for x in range(1, 4) for y in range(1, 4)), - ((x + y) ** 2 for x in range(3) for y in range(3)), - - # Filters on different loop levels - ((x, y) for x in range(10) if x % 2 == 0 for y in range(10) if y % 3 == 0), - (x * y for x in range(5) for y in range(5) if x != y), - (x + y for x in range(5) for y in range(5) if x + y < 5), - - # Triple and quadruple nested with various patterns - (x + y + z for x in range(2) for y in range(2) for z in range(2)), - (x * y * z for x in range(1, 3) for y in range(1, 3) for z in range(1, 3)), - ((x, y, z, w) for x in range(2) for y in range(2) for z in range(2) for w in 
range(2)), - - # Nested loops with complex filters - ((x, y, z) for x in range(5) for y in range(5) for z in range(5) if x < y and y < z), - (x + y + z for x in range(3) if x > 0 for y in range(3) if y != x for z in range(3) if z != x and z != y), - - # Mixed range types - ((x, y) for x in range(-2, 2) for y in range(0, 4, 2)), - (x * y for x in range(5, 0, -1) for y in range(1, 6)), - - # Dependent nested loops - ((x, y) for x in range(3) for y in range(x, 3)), - (x + y for x in range(3) for y in range(x + 1, 3)), - (x * y * z for x in range(3) for y in range(x + 1, x + 3) for z in range(y, y + 3)), -]) + +@pytest.mark.parametrize( + "genexpr", + [ + # Basic nested loops + ((x, y) for x in range(3) for y in range(3)), + (x + y for x in range(3) for y in range(3)), + (x * y for x in range(1, 4) for y in range(1, 4)), + # Nested with filters + ((x, y) for x in range(5) for y in range(5) if x < y), + (x + y for x in range(5) if x % 2 == 0 for y in range(5) if y % 2 == 1), + # Triple nested + ((x, y, z) for x in range(2) for y in range(2) for z in range(2)), + # More complex nested loop edge cases + # Different sized ranges + ((x, y) for x in range(2) for y in range(5)), + ((x, y) for x in range(10) for y in range(2)), + # Asymmetric operations + (x - y for x in range(5) for y in range(3)), + (x / (y + 1) for x in range(1, 6) for y in range(3)), + (x**y for x in range(1, 4) for y in range(3)), + # Complex expressions with multiple variables + (x * y + x for x in range(3) for y in range(3)), + (x + y + x * y for x in range(1, 4) for y in range(1, 4)), + ((x + y) ** 2 for x in range(3) for y in range(3)), + # Filters on different loop levels + ((x, y) for x in range(10) if x % 2 == 0 for y in range(10) if y % 3 == 0), + (x * y for x in range(5) for y in range(5) if x != y), + (x + y for x in range(5) for y in range(5) if x + y < 5), + # Triple and quadruple nested with various patterns + (x + y + z for x in range(2) for y in range(2) for z in range(2)), + (x * y * z for 
x in range(1, 3) for y in range(1, 3) for z in range(1, 3)), + ( + (x, y, z, w) + for x in range(2) + for y in range(2) + for z in range(2) + for w in range(2) + ), + # Nested loops with complex filters + ( + (x, y, z) + for x in range(5) + for y in range(5) + for z in range(5) + if x < y and y < z + ), + (x + y for x in range(3) if x > 0 for y in range(3)), + # Mixed range types + ((x, y) for x in range(-2, 2) for y in range(0, 4, 2)), + (x * y for x in range(5, 0, -1) for y in range(1, 6)), + # Dependent nested loops + ((x, y) for x in range(3) for y in range(x, 3)), + (x + y for x in range(3) for y in range(x + 1, 3)), + ( + x * y * z + for x in range(3) + for y in range(x + 1, x + 3) + for z in range(y, y + 3) + ), + ], +) def test_nested_loops(genexpr): """Test reconstruction of generators with nested loops.""" ast_node = reconstruct(genexpr) @@ -334,29 +379,33 @@ def test_nested_loops(genexpr): # NESTED COMPREHENSIONS # =========================================================================== -@pytest.mark.parametrize("genexpr", [ - ([x for x in range(i)] for i in range(5)), - ({x: x**2 for x in range(i)} for i in range(5)), - ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), - - # aggregation function call - (sum(x for x in range(i + 1)) for i in range(3)), - (max(x for x in range(i + 1)) for i in range(3)), - - # map - (list(map(abs, (x + 1 for x in range(i + 1)))) for i in range(3)), - (list(enumerate(x + 1 for x in range(i + 1))) for i in range(3)), - - # Nested comprehensions with filters inside - ([x for x in range(i)] for i in range(5) if i > 0), - ([x for x in range(i) if x < i] for i in range(5) if i > 0), - ([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5)), - ([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5) if i > 0), - - # nesting on both sides - ([y for y in range(x)] for x in (x_ + 1 for x_ in range(5))), - ([y for y in range(x)] for x in (x_ + 1 for x_ in range(5))), 
-]) + +@pytest.mark.parametrize( + "genexpr", + [ + ([x for x in range(i)] for i in range(5)), + ({x: x**2 for x in range(i)} for i in range(5)), + ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), + # aggregation function call + (sum(x for x in range(i + 1)) for i in range(3)), + (max(x for x in range(i + 1)) for i in range(3)), + # map + (list(map(abs, (x + 1 for x in range(i + 1)))) for i in range(3)), + (list(enumerate(x + 1 for x in range(i + 1))) for i in range(3)), + # Nested comprehensions with filters inside + ([x for x in range(i)] for i in range(5) if i > 0), + ([x for x in range(i) if x < i] for i in range(5) if i > 0), + ([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5)), + ( + [[x for x in range(i + j) if x < i + j] for j in range(i)] + for i in range(5) + if i > 0 + ), + # nesting on both sides + ([y for y in range(x)] for x in (x_ + 1 for x_ in range(5))), + ([y for y in range(x)] for x in (x_ + 1 for x_ in range(5))), + ], +) def test_nested_comprehensions(genexpr): """Test reconstruction of nested comprehensions.""" ast_node = reconstruct(genexpr) @@ -367,17 +416,20 @@ def test_nested_comprehensions(genexpr): # DIFFERENT COMPREHENSION TYPES # ============================================================================ -@pytest.mark.parametrize("genexpr", [ - # Comprehensions as iterator constants - (x_ for x_ in [x for x in range(5)]), - (x_ for x_ in {x for x in range(5)}), - (x_ for x_ in {x: x**2 for x in range(5)}), - - # Comprehensions as yield expressions - ([y for y in range(x + 1)] for x in range(3)), - ({y for y in range(x + 1)} for x in range(3)), - ({y: y**2 for y in range(x + 1)} for x in range(3)), -]) + +@pytest.mark.parametrize( + "genexpr", + [ + # Comprehensions as iterator constants + (x_ for x_ in [x for x in range(5)]), + (x_ for x_ in {x for x in range(5)}), + (x_ for x_ in {x: x**2 for x in range(5)}), + # Comprehensions as yield expressions + ([y for y in range(x + 1)] for x in 
range(3)), + ({y for y in range(x + 1)} for x in range(3)), + ({y: y**2 for y in range(x + 1)} for x in range(3)), + ], +) def test_different_comprehension_types(genexpr): """Test reconstruction of different comprehension types.""" ast_node = reconstruct(genexpr) @@ -388,22 +440,25 @@ def test_different_comprehension_types(genexpr): # GENERATOR EXPRESSION WITH GLOBALS # ============================================================================ -@pytest.mark.parametrize("genexpr,globals_dict", [ - # Using constants - ((x + a for x in range(5)), {'a': 10}), - ((data[i] for i in range(2)), {'data': [3, 4]}), - - # Using global functions - ((abs(x) for x in range(-5, 5)), {'abs': abs}), - ((len(s) for s in ["a", "ab", "abc"]), {'len': len}), - ((max(x, 5) for x in range(10)), {'max': max}), - ((min(x, 5) for x in range(10)), {'min': min}), - ((round(x / 3, 2) for x in range(10)), {'round': round}), -]) + +@pytest.mark.parametrize( + "genexpr,globals_dict", + [ + # Using constants + ((x + a for x in range(5)), {"a": 10}), # noqa: F821 + ((data[i] for i in range(2)), {"data": [3, 4]}), # noqa: F821 + # Using global functions + ((abs(x) for x in range(-5, 5)), {"abs": abs}), + ((len(s) for s in ["a", "ab", "abc"]), {"len": len}), + ((max(x, 5) for x in range(10)), {"max": max}), + ((min(x, 5) for x in range(10)), {"min": min}), + ((round(x / 3, 2) for x in range(10)), {"round": round}), + ], +) def test_variable_lookup(genexpr, globals_dict): """Test reconstruction of expressions with globals.""" ast_node = reconstruct(genexpr) - + # Need to provide the same globals for evaluation assert_ast_equivalent(genexpr, ast_node, globals_dict) @@ -412,53 +467,63 @@ def test_variable_lookup(genexpr, globals_dict): # EDGE CASES AND COMPLEX SCENARIOS # ============================================================================ -@pytest.mark.parametrize("genexpr,globals_dict", [ - # Using lambdas and functions - pytest.param(((lambda y: y * 2)(x) for x in range(5)), {}, 
marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), - pytest.param(((lambda y: y + 1)(x) for x in range(5)), {}, marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), - pytest.param(((lambda y: y ** 2)(x) for x in range(5)), {}, marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet")), - - # More complex lambdas - # (((lambda a, b: a + b)(x, x) for x in range(5)), {}), - ((f(x) for x in range(5)), {'f': lambda y: y * 3}), - - # Attribute access - ((x.real for x in [1+2j, 3+4j, 5+6j]), {}), - ((x.imag for x in [1+2j, 3+4j, 5+6j]), {}), - ((x.conjugate() for x in [1+2j, 3+4j, 5+6j]), {}), - - # Method calls - ((s.upper() for s in ["hello", "world"]), {}), - ((s.lower() for s in ["HELLO", "WORLD"]), {}), - ((s.strip() for s in [" hello ", " world "]), {}), - ((x.bit_length() for x in range(1, 10)), {}), - ((str(x).zfill(3) for x in range(10)), {'str': str}), - - # Subscript operations - (([10, 20, 30][i] for i in range(3)), {}), - (({'a': 1, 'b': 2, 'c': 3}[k] for k in ['a', 'b', 'c']), {}), - (("hello"[i] for i in range(5)), {}), - ((data[i][j] for i in range(2) for j in range(2)), {'data': [[1, 2], [3, 4]]}), - - # # More complex attribute chains - # ((obj.value.bit_length() for obj in [type('', (), {'value': x})() for x in range(1, 5)]), {}), - - # Multiple function calls - ((abs(max(x, -x)) for x in range(-3, 4)), {'abs': abs, 'max': max}), - ((len(str(x)) for x in range(100, 110)), {'len': len, 'str': str}), - - # Mixed operations - ((abs(x) + len(str(x)) for x in range(-10, 10)), {'abs': abs, 'len': len, 'str': str}), - ((s.upper().lower() for s in ["Hello", "World"]), {}), - - # Edge cases with complex data structures - (([1, 2, 3][x % 3] for x in range(10)), {}), - # (({"even": x, "odd": x + 1}["even" if x % 2 == 0 else "odd"] for x in range(5)), {}), - - # Function calls with multiple arguments - ((pow(x, 2, 10) for x in range(5)), {'pow': pow}), - ((divmod(x, 3) for x in range(10)), 
{'divmod': divmod}), -]) + +@pytest.mark.parametrize( + "genexpr,globals_dict", + [ + # Using lambdas and functions + pytest.param( + ((lambda y: y * 2)(x) for x in range(5)), + {}, + marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), + ), + pytest.param( + ((lambda y: y + 1)(x) for x in range(5)), + {}, + marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), + ), + pytest.param( + ((lambda y: y**2)(x) for x in range(5)), + {}, + marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), + ), + # More complex lambdas + # (((lambda a, b: a + b)(x, x) for x in range(5)), {}), + ((f(x) for x in range(5)), {"f": lambda y: y * 3}), # noqa: F821 + # Attribute access + ((x.real for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), + ((x.imag for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), + ((x.conjugate() for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), + # Method calls + ((s.upper() for s in ["hello", "world"]), {}), + ((s.lower() for s in ["HELLO", "WORLD"]), {}), + ((s.strip() for s in [" hello ", " world "]), {}), + ((x.bit_length() for x in range(1, 10)), {}), + ((str(x).zfill(3) for x in range(10)), {"str": str}), + # Subscript operations + (([10, 20, 30][i] for i in range(3)), {}), + (({"a": 1, "b": 2, "c": 3}[k] for k in ["a", "b", "c"]), {}), + (("hello"[i] for i in range(5)), {}), + ((data[i][j] for i in range(2) for j in range(2)), {"data": [[1, 2], [3, 4]]}), # noqa: F821 + # # More complex attribute chains + # ((obj.value.bit_length() for obj in [type('', (), {'value': x})() for x in range(1, 5)]), {}), + # Multiple function calls + ((abs(max(x, -x)) for x in range(-3, 4)), {"abs": abs, "max": max}), + ((len(str(x)) for x in range(100, 110)), {"len": len, "str": str}), + # Mixed operations + ( + (abs(x) + len(str(x)) for x in range(-10, 10)), + {"abs": abs, "len": len, "str": str}, + ), + ((s.upper().lower() for s in ["Hello", "World"]), {}), + # Edge cases with complex data structures + (([1, 2, 3][x % 3] for x in 
range(10)), {}), + # (({"even": x, "odd": x + 1}["even" if x % 2 == 0 else "odd"] for x in range(5)), {}), + # Function calls with multiple arguments + ((pow(x, 2, 10) for x in range(5)), {"pow": pow}), + ((divmod(x, 3) for x in range(10)), {"divmod": divmod}), + ], +) def test_complex_scenarios(genexpr, globals_dict): """Test reconstruction of complex generator expressions.""" ast_node = reconstruct(genexpr) @@ -471,107 +536,107 @@ def test_complex_scenarios(genexpr, globals_dict): # HELPER FUNCTION TESTS # ============================================================================ -@pytest.mark.parametrize("value,expected_str", [ - # AST nodes should be returned as-is - (ast.Name(id='x', ctx=ast.Load()), 'x'), - (ast.Constant(value=42), '42'), - (ast.List(elts=[], ctx=ast.Load()), '[]'), - (ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=2)), '1 + 2'), - - # Constants should become ast.Constant nodes - (42, '42'), - (3.14, '3.14'), - (-42, '-42'), - (-3.14, '-3.14'), - ('hello', "'hello'"), - ("", "''"), - (b'bytes', "b'bytes'"), - (b'', "b''"), - (True, 'True'), - (False, 'False'), - (None, 'None'), - - # Complex numbers - (1+2j, '(1+2j)'), - (0+1j, '1j'), - (3+0j, '(3+0j)'), - (-1-2j, '(-1-2j)'), - - # Tuples should become ast.Tuple nodes - ((), '()'), - ((1,), '(1,)'), - ((1, 2), '(1, 2)'), - (('a', 'b', 'c'), "('a', 'b', 'c')"), - - # Special dict_item tuples - (('dict_item', 'key', 'value'), "('key', 'value')"), - (('dict_item', 42, 'answer'), "(42, 'answer')"), - - # Nested tuples - ((1, (2, 3)), '(1, (2, 3))'), - (((1, 2), (3, 4)), '((1, 2), (3, 4))'), - ((1, 2, (3, (4, 5))), '(1, 2, (3, (4, 5)))'), - - # Lists should become ast.List nodes - ([1, 2, 3], '[1, 2, 3]'), - (['hello', 'world'], "['hello', 'world']"), - ([True, False, None], '[True, False, None]'), - - # Nested lists - ([[1, 2], [3, 4]], '[[1, 2], [3, 4]]'), - ([1, [2, [3, 4]], 5], '[1, [2, [3, 4]], 5]'), - - # Mixed nested structures - ([(1, 2), (3, 4)], '[(1, 2), 
(3, 4)]'), - (([1, 2], [3, 4]), '([1, 2], [3, 4])'), - - # Dicts should become ast.Dict nodes - ({'a': 1}, "{'a': 1}"), - ({'x': 10, 'y': 20}, "{'x': 10, 'y': 20}"), - ({1: 'one', 2: 'two'}, "{1: 'one', 2: 'two'}"), - - # Nested dicts - ({'a': {'b': 1}}, "{'a': {'b': 1}}"), - ({'nums': [1, 2, 3], 'strs': ['a', 'b']}, "{'nums': [1, 2, 3], 'strs': ['a', 'b']}"), - - # Range objects - (range(5), 'range(0, 5, 1)'), - (range(1, 10), 'range(1, 10, 1)'), - (range(0, 10, 2), 'range(0, 10, 2)'), - (range(10, 0, -1), 'range(10, 0, -1)'), - (range(-5, 5), 'range(-5, 5, 1)'), - - # Empty collections - ([], '[]'), - ((), '()'), - ({}, '{}'), - - # Complex nested structures - ([1, [2, 3], 4], '[1, [2, 3], 4]'), - ({'a': [1, 2], 'b': {'c': 3}}, "{'a': [1, 2], 'b': {'c': 3}}"), - ([(1, {'a': [2, 3]}), ({'b': 4}, 5)], "[(1, {'a': [2, 3]}), ({'b': 4}, 5)]"), - - # Edge cases with special values - ([None, True, False, 0, ''], "[None, True, False, 0, '']"), - ({'': 'empty', None: 'none', 0: 'zero'}, "{'': 'empty', None: 'none', 0: 'zero'}"), - - # Large numbers - (999999999999999999999, '999999999999999999999'), - (1.7976931348623157e+308, '1.7976931348623157e+308'), # Close to float max - - # Sets - note unparse equivalence may fail for unordered collections - ({1, 2, 3}, '{1, 2, 3}'), -]) + +@pytest.mark.parametrize( + "value,expected_str", + [ + # AST nodes should be returned as-is + (ast.Name(id="x", ctx=ast.Load()), "x"), + (ast.Constant(value=42), "42"), + (ast.List(elts=[], ctx=ast.Load()), "[]"), + ( + ast.BinOp( + left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=2) + ), + "1 + 2", + ), + # Constants should become ast.Constant nodes + (42, "42"), + (3.14, "3.14"), + (-42, "-42"), + (-3.14, "-3.14"), + ("hello", "'hello'"), + ("", "''"), + (b"bytes", "b'bytes'"), + (b"", "b''"), + (True, "True"), + (False, "False"), + (None, "None"), + # Complex numbers + (1 + 2j, "(1+2j)"), + (0 + 1j, "1j"), + (3 + 0j, "(3+0j)"), + (-1 - 2j, "(-1-2j)"), + # Tuples should 
become ast.Tuple nodes + ((), "()"), + ((1,), "(1,)"), + ((1, 2), "(1, 2)"), + (("a", "b", "c"), "('a', 'b', 'c')"), + # Special dict_item tuples + (("dict_item", "key", "value"), "('key', 'value')"), + (("dict_item", 42, "answer"), "(42, 'answer')"), + # Nested tuples + ((1, (2, 3)), "(1, (2, 3))"), + (((1, 2), (3, 4)), "((1, 2), (3, 4))"), + ((1, 2, (3, (4, 5))), "(1, 2, (3, (4, 5)))"), + # Lists should become ast.List nodes + ([1, 2, 3], "[1, 2, 3]"), + (["hello", "world"], "['hello', 'world']"), + ([True, False, None], "[True, False, None]"), + # Nested lists + ([[1, 2], [3, 4]], "[[1, 2], [3, 4]]"), + ([1, [2, [3, 4]], 5], "[1, [2, [3, 4]], 5]"), + # Mixed nested structures + ([(1, 2), (3, 4)], "[(1, 2), (3, 4)]"), + (([1, 2], [3, 4]), "([1, 2], [3, 4])"), + # Dicts should become ast.Dict nodes + ({"a": 1}, "{'a': 1}"), + ({"x": 10, "y": 20}, "{'x': 10, 'y': 20}"), + ({1: "one", 2: "two"}, "{1: 'one', 2: 'two'}"), + # Nested dicts + ({"a": {"b": 1}}, "{'a': {'b': 1}}"), + ( + {"nums": [1, 2, 3], "strs": ["a", "b"]}, + "{'nums': [1, 2, 3], 'strs': ['a', 'b']}", + ), + # Range objects + (range(5), "range(0, 5, 1)"), + (range(1, 10), "range(1, 10, 1)"), + (range(0, 10, 2), "range(0, 10, 2)"), + (range(10, 0, -1), "range(10, 0, -1)"), + (range(-5, 5), "range(-5, 5, 1)"), + # Empty collections + ([], "[]"), + ((), "()"), + ({}, "{}"), + # Complex nested structures + ([1, [2, 3], 4], "[1, [2, 3], 4]"), + ({"a": [1, 2], "b": {"c": 3}}, "{'a': [1, 2], 'b': {'c': 3}}"), + ([(1, {"a": [2, 3]}), ({"b": 4}, 5)], "[(1, {'a': [2, 3]}), ({'b': 4}, 5)]"), + # Edge cases with special values + ([None, True, False, 0, ""], "[None, True, False, 0, '']"), + ( + {"": "empty", None: "none", 0: "zero"}, + "{'': 'empty', None: 'none', 0: 'zero'}", + ), + # Large numbers + (999999999999999999999, "999999999999999999999"), + (1.7976931348623157e308, "1.7976931348623157e+308"), # Close to float max + # Sets - note unparse equivalence may fail for unordered collections + ({1, 2, 3}, "{1, 
2, 3}"), + ], +) def test_ensure_ast(value, expected_str): """Test that ensure_ast correctly converts various values to AST nodes.""" from effectful.internals.disassembler import ensure_ast - + result = ensure_ast(value) # Compare the unparsed strings result_str = ast.unparse(result) - assert result_str == expected_str, \ - f"ensure_ast({repr(value)}) produced '{result_str}', expected '{expected_str}'" + assert ( + result_str == expected_str + ), f"ensure_ast({repr(value)}) produced '{result_str}', expected '{expected_str}'" def test_error_handling(): @@ -579,7 +644,7 @@ def test_error_handling(): # Test with non-generator input with pytest.raises(AssertionError): reconstruct([1, 2, 3]) # Not a generator - + # Test with consumed generator gen = (x for x in range(5)) list(gen) # Consume it From 7566cc4a99349af849f7cc1784675155fc9f0915 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Jun 2025 10:59:21 -0400 Subject: [PATCH 049/106] appease mypy --- effectful/internals/disassembler.py | 80 +++++++++++++++++----------- tests/test_internals_disassembler.py | 8 +-- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index cfa3f9c5..5e0a2fbe 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -61,6 +61,7 @@ def __init__(self, body: CompExp): ) def inline(self, iterator: ast.expr) -> CompExp: + assert isinstance(self.body, CompExp) res: CompExp = copy.deepcopy(self.body) res.generators[0].iter = iterator return res @@ -188,6 +189,7 @@ def handle_build_list( new_stack = state.stack + [ret] return replace(state, stack=new_stack, result=ret) else: + assert instr.arg is not None size: int = instr.arg # Pop elements for the list elements = ( @@ -232,6 +234,7 @@ def handle_build_set( new_stack = state.stack + [ret] return replace(state, stack=new_stack, result=ret) else: + assert instr.arg is not None size: int = instr.arg # Pop elements for the set 
elements = ( @@ -276,9 +279,12 @@ def handle_build_map( new_stack = state.stack + [ret] return replace(state, stack=new_stack, result=ret) else: + assert instr.arg is not None size: int = instr.arg # Pop key-value pairs for the dict - keys = [ensure_ast(state.stack[-2 * i - 2]) for i in range(size)] + keys: list[ast.expr | None] = [ + ensure_ast(state.stack[-2 * i - 2]) for i in range(size) + ] values = [ensure_ast(state.stack[-2 * i - 1]) for i in range(size)] new_stack = state.stack[: -2 * size] if size > 0 else state.stack @@ -318,9 +324,9 @@ def handle_return_value( if isinstance(state.result, CompExp): return replace(state, stack=state.stack[:-1]) elif isinstance(state.result, Placeholder) and len(state.stack) == 1: - return replace( - state, stack=state.stack[:-1], result=ensure_ast(state.stack[-1]) - ) + new_result = ensure_ast(state.stack[-1]) + assert isinstance(new_result, CompExp | ast.Lambda) + return replace(state, stack=state.stack[:-1], result=new_result) else: raise TypeError("Unexpected RETURN_VALUE in reconstruction") @@ -402,7 +408,8 @@ def handle_unpack_sequence( ) -> ReconstructionState: # UNPACK_SEQUENCE unpacks a sequence into multiple values # arg is the number of values to unpack - unpack_count = instr.arg + assert instr.arg is not None + unpack_count: int = instr.arg sequence = ensure_ast(state.stack[-1]) # noqa: F841 new_stack = state.stack[:-1] @@ -459,19 +466,18 @@ def handle_store_fast( # Update the last loop in the generators list if isinstance(state.result, ast.DictComp): - new_ret = ast.DictComp( + new_dict: ast.DictComp = ast.DictComp( key=state.result.key, value=state.result.value, generators=state.result.generators[:-1] + [updated_loop], ) + return replace(state, result=new_dict) else: - new_ret = type(state.result)( + new_comp: ast.GeneratorExp | ast.ListComp | ast.SetComp = type(state.result)( elt=state.result.elt, generators=state.result.generators[:-1] + [updated_loop], ) - - # Create new loops list with the updated loop - 
return replace(state, result=new_ret) + return replace(state, result=new_comp) @register_handler("LOAD_CONST") @@ -525,19 +531,18 @@ def handle_store_deref( # Update the last loop in the generators list if isinstance(state.result, ast.DictComp): - new_ret = ast.DictComp( + new_dict: ast.DictComp = ast.DictComp( key=state.result.key, value=state.result.value, generators=state.result.generators[:-1] + [updated_loop], ) + return replace(state, result=new_dict) else: - new_ret = type(state.result)( + new_comp: ast.GeneratorExp | ast.ListComp | ast.SetComp = type(state.result)( elt=state.result.elt, generators=state.result.generators[:-1] + [updated_loop], ) - - # Create new loops list with the updated loop - return replace(state, result=new_ret) + return replace(state, result=new_comp) @register_handler("LOAD_DEREF") @@ -724,7 +729,7 @@ def handle_compare_op( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert ( - dis.cmp_op[instr.arg] == instr.argval + instr.arg is not None and dis.cmp_op[instr.arg] == instr.argval ), f"Unsupported comparison operation: {instr.argval}" right = ensure_ast(state.stack[-1]) @@ -777,6 +782,7 @@ def handle_call_function( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # CALL_FUNCTION pops function and arguments from stack + assert instr.arg is not None arg_count: int = instr.arg # Pop arguments and function args = ( @@ -819,7 +825,8 @@ def handle_call_method( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # CALL_METHOD calls a method - similar to CALL_FUNCTION but for methods - arg_count = instr.arg + assert instr.arg is not None + arg_count: int = instr.arg # Pop arguments and method args = ( [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] @@ -912,6 +919,7 @@ def handle_binary_subscr( def handle_build_tuple( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: + assert instr.arg is not None 
tuple_size: int = instr.arg # Pop elements for the tuple elements = ( @@ -955,7 +963,9 @@ def handle_list_extend( if isinstance(list_obj, ast.List) and len(list_obj.elts) == 0: # If extending with a constant tuple, expand it to list elements if isinstance(iterable, ast.Constant) and isinstance(iterable.value, tuple): - elements = [ast.Constant(value=elem) for elem in iterable.value] + elements: list[ast.expr] = [ + ast.Constant(value=elem) for elem in iterable.value + ] list_node = ast.List(elts=elements, ctx=ast.Load()) new_stack = new_stack + [list_node] return replace(state, stack=new_stack) @@ -974,10 +984,12 @@ def handle_build_const_key_map( ) -> ReconstructionState: # BUILD_CONST_KEY_MAP builds a dictionary with constant keys # The keys are in a tuple on TOS, values are on the stack below + assert instr.arg is not None + assert isinstance(state.stack[-1], ast.Tuple), "Expected a tuple of keys" map_size: int = instr.arg # Pop the keys tuple and values keys_tuple: ast.Tuple = state.stack[-1] - keys = [ensure_ast(key) for key in keys_tuple.elts] + keys: list[ast.expr | None] = [ensure_ast(key) for key in keys_tuple.elts] values = [ensure_ast(val) for val in state.stack[-map_size - 1 : -1]] new_stack = state.stack[: -map_size - 1] @@ -1013,17 +1025,18 @@ def handle_pop_jump_if_false( is_async=state.result.generators[-1].is_async, ) if isinstance(state.result, ast.DictComp): - new_ret = ast.DictComp( + new_dict = ast.DictComp( key=state.result.key, value=state.result.value, generators=state.result.generators[:-1] + [updated_loop], ) + return replace(state, stack=new_stack, result=new_dict) else: - new_ret = type(state.result)( + new_comp = type(state.result)( elt=state.result.elt, generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, stack=new_stack, result=new_ret) + return replace(state, stack=new_stack, result=new_comp) else: raise NotImplementedError("Lazy and+or behavior not implemented yet") @@ -1053,17 +1066,18 @@ def 
handle_pop_jump_if_true( is_async=state.result.generators[-1].is_async, ) if isinstance(state.result, ast.DictComp): - new_ret = ast.DictComp( + new_dict = ast.DictComp( key=state.result.key, value=state.result.value, generators=state.result.generators[:-1] + [updated_loop], ) + return replace(state, stack=new_stack, result=new_dict) else: - new_ret = type(state.result)( + new_comp = type(state.result)( elt=state.result.elt, generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, stack=new_stack, result=new_ret) + return replace(state, stack=new_stack, result=new_comp) else: raise NotImplementedError("Lazy and+or behavior not implemented yet") @@ -1109,7 +1123,7 @@ def _ensure_ast_tuple(value: tuple) -> ast.Tuple: @ensure_ast.register(type(iter((1,)))) def _ensure_ast_tuple_iterator(value: Iterator) -> ast.Tuple: - return ensure_ast(tuple(value.__reduce__()[1][0])) + return ensure_ast(tuple(value.__reduce__()[1][0])) # type: ignore @ensure_ast.register @@ -1119,7 +1133,7 @@ def _ensure_ast_list(value: list) -> ast.List: @ensure_ast.register(type(iter([1]))) def _ensure_ast_list_iterator(value: Iterator) -> ast.List: - return ensure_ast(list(value.__reduce__()[1][0])) + return ensure_ast(list(value.__reduce__()[1][0])) # type: ignore @ensure_ast.register @@ -1129,7 +1143,7 @@ def _ensure_ast_set(value: set) -> ast.Set: @ensure_ast.register(type(iter({1}))) def _ensure_ast_set_iterator(value: Iterator) -> ast.Set: - return ensure_ast(set(value.__reduce__()[1][0])) + return ensure_ast(set(value.__reduce__()[1][0])) # type: ignore @ensure_ast.register @@ -1141,7 +1155,7 @@ def _ensure_ast_dict(value: dict) -> ast.Dict: @ensure_ast.register(type(iter({1: 2}))) -def _ensure_ast_dict_iterator(value: Iterator) -> ast.Dict: +def _ensure_ast_dict_iterator(value: Iterator) -> ast.expr: return ensure_ast(value.__reduce__()[1][0]) @@ -1156,7 +1170,7 @@ def _ensure_ast_range(value: range) -> ast.Call: @ensure_ast.register(type(iter(range(1)))) def 
_ensure_ast_range_iterator(value: Iterator) -> ast.Call: - return ensure_ast(value.__reduce__()[1][0]) + return ensure_ast(value.__reduce__()[1][0]) # type: ignore @ensure_ast.register @@ -1210,9 +1224,11 @@ def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: assert ( inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED ), "Generator must be in created state" - genexpr_ast: ast.GeneratorExp = ensure_ast(genexpr.gi_code) - geniter_ast: ast.expr = ensure_ast(genexpr.gi_frame.f_locals[".0"]) + genexpr_ast = ensure_ast(genexpr.gi_code) + assert isinstance(genexpr_ast, ast.GeneratorExp) + geniter_ast = ensure_ast(genexpr.gi_frame.f_locals[".0"]) result = CompLambda(genexpr_ast).inline(geniter_ast) + assert isinstance(result, ast.GeneratorExp) assert ( inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED ), "Generator must stay in created state" diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index dd5e2c2e..ad2723d7 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -445,8 +445,8 @@ def test_different_comprehension_types(genexpr): "genexpr,globals_dict", [ # Using constants - ((x + a for x in range(5)), {"a": 10}), # noqa: F821 - ((data[i] for i in range(2)), {"data": [3, 4]}), # noqa: F821 + ((x + a for x in range(5)), {"a": 10}), # type: ignore # noqa: F821 + ((data[i] for i in range(2)), {"data": [3, 4]}), # type: ignore # noqa: F821 # Using global functions ((abs(x) for x in range(-5, 5)), {"abs": abs}), ((len(s) for s in ["a", "ab", "abc"]), {"len": len}), @@ -489,7 +489,7 @@ def test_variable_lookup(genexpr, globals_dict): ), # More complex lambdas # (((lambda a, b: a + b)(x, x) for x in range(5)), {}), - ((f(x) for x in range(5)), {"f": lambda y: y * 3}), # noqa: F821 + ((f(x) for x in range(5)), {"f": lambda y: y * 3}), # type: ignore # noqa: F821 # Attribute access ((x.real for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), ((x.imag for x in [1 + 2j, 3 + 
4j, 5 + 6j]), {}), @@ -504,7 +504,7 @@ def test_variable_lookup(genexpr, globals_dict): (([10, 20, 30][i] for i in range(3)), {}), (({"a": 1, "b": 2, "c": 3}[k] for k in ["a", "b", "c"]), {}), (("hello"[i] for i in range(5)), {}), - ((data[i][j] for i in range(2) for j in range(2)), {"data": [[1, 2], [3, 4]]}), # noqa: F821 + ((data[i][j] for i in range(2) for j in range(2)), {"data": [[1, 2], [3, 4]]}), # type: ignore # noqa: F821 # # More complex attribute chains # ((obj.value.bit_length() for obj in [type('', (), {'value': x})() for x in range(1, 5)]), {}), # Multiple function calls From 7991776332062f5b45818d4bb7dc07b573309556 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Jun 2025 12:24:59 -0400 Subject: [PATCH 050/106] format 313 --- effectful/internals/disassembler.py | 96 ++++++++++++++-------------- tests/test_internals_disassembler.py | 12 ++-- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 5e0a2fbe..636983e3 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -125,9 +125,9 @@ def register_handler(opname: str, handler=None): def _wrapper( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert ( - instr.opname == opname - ), f"Handler for '{opname}' called with wrong instruction" + assert instr.opname == opname, ( + f"Handler for '{opname}' called with wrong instruction" + ) return handler(state, instr) OP_HANDLERS[opname] = _wrapper @@ -145,9 +145,9 @@ def handle_gen_start( ) -> ReconstructionState: # GEN_START is typically the first instruction in generator expressions # It initializes the generator - assert isinstance( - state.result, Placeholder - ), "GEN_START must be the first instruction" + assert isinstance(state.result, Placeholder), ( + "GEN_START must be the first instruction" + ) return replace(state, result=ast.GeneratorExp(elt=Placeholder(), generators=[])) @@ -157,12 
+157,12 @@ def handle_yield_value( ) -> ReconstructionState: # YIELD_VALUE pops a value from the stack and yields it # This is the expression part of the generator - assert isinstance( - state.result, ast.GeneratorExp - ), "YIELD_VALUE must be called after GEN_START" - assert isinstance( - state.result.elt, Placeholder - ), "YIELD_VALUE must be called before yielding" + assert isinstance(state.result, ast.GeneratorExp), ( + "YIELD_VALUE must be called after GEN_START" + ) + assert isinstance(state.result.elt, Placeholder), ( + "YIELD_VALUE must be called before yielding" + ) assert len(state.result.generators) > 0, "YIELD_VALUE should have generators" new_stack = state.stack[:-1] @@ -207,9 +207,9 @@ def handle_build_list( def handle_list_append( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance( - state.result, ast.ListComp - ), "LIST_APPEND must be called within a ListComp context" + assert isinstance(state.result, ast.ListComp), ( + "LIST_APPEND must be called within a ListComp context" + ) new_stack = state.stack[:-1] new_ret = ast.ListComp( elt=ensure_ast(state.stack[-1]), @@ -252,9 +252,9 @@ def handle_build_set( def handle_set_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance( - state.result, ast.SetComp - ), "SET_ADD must be called after BUILD_SET" + assert isinstance(state.result, ast.SetComp), ( + "SET_ADD must be called after BUILD_SET" + ) new_stack = state.stack[:-1] new_ret = ast.SetComp( elt=ensure_ast(state.stack[-1]), @@ -298,9 +298,9 @@ def handle_build_map( def handle_map_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance( - state.result, ast.DictComp - ), "MAP_ADD must be called after BUILD_MAP" + assert isinstance(state.result, ast.DictComp), ( + "MAP_ADD must be called after BUILD_MAP" + ) new_stack = state.stack[:-2] new_ret = ast.DictComp( key=ensure_ast(state.stack[-2]), @@ -338,9 +338,9 @@ def 
handle_for_iter( # FOR_ITER pops an iterator from the stack and pushes the next item # If the iterator is exhausted, it jumps to the target instruction assert len(state.stack) > 0, "FOR_ITER must have an iterator on the stack" - assert isinstance( - state.result, CompExp - ), "FOR_ITER must be called within a comprehension context" + assert isinstance(state.result, CompExp), ( + "FOR_ITER must be called within a comprehension context" + ) # The iterator should be on top of stack # Create new stack without the iterator @@ -448,9 +448,9 @@ def handle_load_fast( def handle_store_fast( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance( - state.result, CompExp - ), "STORE_FAST must be called within a comprehension context" + assert isinstance(state.result, CompExp), ( + "STORE_FAST must be called within a comprehension context" + ) var_name = instr.argval # Update the most recent loop's target variable @@ -513,9 +513,9 @@ def handle_store_deref( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # STORE_DEREF stores a value into a closure variable - assert isinstance( - state.result, CompExp - ), "STORE_DEREF must be called within a comprehension context" + assert isinstance(state.result, CompExp), ( + "STORE_DEREF must be called within a comprehension context" + ) var_name = instr.argval # Update the most recent loop's target variable @@ -728,9 +728,9 @@ def handle_unary_op( def handle_compare_op( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert ( - instr.arg is not None and dis.cmp_op[instr.arg] == instr.argval - ), f"Unsupported comparison operation: {instr.argval}" + assert instr.arg is not None and dis.cmp_op[instr.arg] == instr.argval, ( + f"Unsupported comparison operation: {instr.argval}" + ) right = ensure_ast(state.stack[-1]) left = ensure_ast(state.stack[-2]) @@ -1181,20 +1181,20 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.expr: state = 
OP_HANDLERS[instr.opname](state, instr) # Check postconditions - assert not any( - isinstance(x, Placeholder) for x in ast.walk(state.result) - ), "Return value must not contain placeholders" - assert ( - not isinstance(state.result, CompExp) or len(state.result.generators) > 0 - ), "Return value must have generators if not a lambda" + assert not any(isinstance(x, Placeholder) for x in ast.walk(state.result)), ( + "Return value must not contain placeholders" + ) + assert not isinstance(state.result, CompExp) or len(state.result.generators) > 0, ( + "Return value must have generators if not a lambda" + ) return state.result @ensure_ast.register def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: - assert inspect.isfunction(value) and value.__name__.endswith( - "" - ), "Input must be a lambda function" + assert inspect.isfunction(value) and value.__name__.endswith(""), ( + "Input must be a lambda function" + ) code: types.CodeType = value.__code__ body: ast.expr = ensure_ast(code) @@ -1221,17 +1221,17 @@ def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: @ensure_ast.register def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: assert inspect.isgenerator(genexpr), "Input must be a generator expression" - assert ( - inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED - ), "Generator must be in created state" + assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, ( + "Generator must be in created state" + ) genexpr_ast = ensure_ast(genexpr.gi_code) assert isinstance(genexpr_ast, ast.GeneratorExp) geniter_ast = ensure_ast(genexpr.gi_frame.f_locals[".0"]) result = CompLambda(genexpr_ast).inline(geniter_ast) assert isinstance(result, ast.GeneratorExp) - assert ( - inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED - ), "Generator must stay in created state" + assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, ( + "Generator must stay in created state" + ) return result diff --git 
a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index ad2723d7..f217fe72 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -56,9 +56,9 @@ def assert_ast_equivalent( # Compile and evaluate the reconstructed AST reconstructed_gen = compile_and_eval(reconstructed_ast, globals_dict) reconstructed_list = list(reconstructed_gen) - assert ( - reconstructed_list == original_list - ), f"AST produced {reconstructed_list}, expected {original_list}" + assert reconstructed_list == original_list, ( + f"AST produced {reconstructed_list}, expected {original_list}" + ) # ============================================================================ @@ -634,9 +634,9 @@ def test_ensure_ast(value, expected_str): # Compare the unparsed strings result_str = ast.unparse(result) - assert ( - result_str == expected_str - ), f"ensure_ast({repr(value)}) produced '{result_str}', expected '{expected_str}'" + assert result_str == expected_str, ( + f"ensure_ast({repr(value)}) produced '{result_str}', expected '{expected_str}'" + ) def test_error_handling(): From d27954d1e57c22502e18ce6ea881af173d5e792a Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Jun 2025 14:58:05 -0400 Subject: [PATCH 051/106] add first pass at some 3.13 ops --- effectful/internals/disassembler.py | 772 ++++++++++++++++++++++++++-- 1 file changed, 724 insertions(+), 48 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 636983e3..c3398e45 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -20,10 +20,12 @@ import dis import functools import inspect +import sys import types import typing from collections.abc import Callable, Generator, Iterator from dataclasses import dataclass, field, replace +from enum import Enum CompExp = ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp @@ -31,15 +33,19 @@ class Placeholder(ast.Name): """Placeholder for AST 
nodes that are not yet resolved.""" - def __init__(self): - super().__init__(id=".PLACEHOLDER", ctx=ast.Load()) + def __init__(self, id=".PLACEHOLDER", ctx=None): + if ctx is None: + ctx = ast.Load() + super().__init__(id=id, ctx=ctx) class DummyIterName(ast.Name): """Dummy name for the iterator variable in generator expressions.""" - def __init__(self): - super().__init__(id=".0", ctx=ast.Load()) + def __init__(self, id=".0", ctx=None): + if ctx is None: + ctx = ast.Load() + super().__init__(id=id, ctx=ctx) class CompLambda(ast.Lambda): @@ -97,6 +103,12 @@ class ReconstructionState: stack: list[ast.expr] = field(default_factory=list) +# Python version enum for version-specific handling +class PythonVersion(Enum): + PY_310 = 10 + PY_313 = 13 + + # Global handler registry OpHandler = Callable[[ReconstructionState, dis.Instruction], ReconstructionState] @@ -104,22 +116,39 @@ class ReconstructionState: @typing.overload -def register_handler(opname: str) -> Callable[[OpHandler], OpHandler]: ... +def register_handler( + opname: str, *, version: PythonVersion = PythonVersion(sys.version_info.minor) +) -> Callable[[OpHandler], OpHandler]: ... @typing.overload -def register_handler(opname: str, handler: OpHandler) -> OpHandler: ... - - -def register_handler(opname: str, handler=None): - """Register a handler for a specific operation name""" +def register_handler( + opname: str, + handler: OpHandler, + *, + version: PythonVersion = PythonVersion(sys.version_info.minor), +) -> OpHandler: ... 
+ + +def register_handler( + opname: str, + handler=None, + *, + version: PythonVersion = PythonVersion(sys.version_info.minor), +): + """Register a handler for a specific operation name and optional version""" if handler is None: - return functools.partial(register_handler, opname) + return functools.partial(register_handler, opname, version=version) + + # Skip registration if version doesn't match current version + if version != PythonVersion(sys.version_info.minor): + return handler + # Only check opmap if the version matches (or no version specified) assert opname in dis.opmap, f"Invalid operation name: '{opname}'" if opname in OP_HANDLERS: - raise ValueError(f"Handler for '{opname}' already exists.") + raise ValueError(f"Handler for '{opname}' (version {version}) already exists.") @functools.wraps(handler) def _wrapper( @@ -139,11 +168,11 @@ def _wrapper( # ============================================================================ -@register_handler("GEN_START") +@register_handler("GEN_START", version=PythonVersion.PY_310) def handle_gen_start( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # GEN_START is typically the first instruction in generator expressions + # GEN_START is the first instruction in generator expressions in Python 3.10 # It initializes the generator assert isinstance(state.result, Placeholder), ( "GEN_START must be the first instruction" @@ -151,6 +180,18 @@ def handle_gen_start( return replace(state, result=ast.GeneratorExp(elt=Placeholder(), generators=[])) +@register_handler("RETURN_GENERATOR", version=PythonVersion.PY_313) +def handle_return_generator( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # RETURN_GENERATOR is the first instruction in generator expressions in Python 3.13+ + # It initializes the generator + assert isinstance(state.result, Placeholder), ( + "RETURN_GENERATOR must be the first instruction" + ) + return replace(state, 
result=ast.GeneratorExp(elt=Placeholder(), generators=[])) + + @register_handler("YIELD_VALUE") def handle_yield_value( state: ReconstructionState, instr: dis.Instruction @@ -158,7 +199,7 @@ def handle_yield_value( # YIELD_VALUE pops a value from the stack and yields it # This is the expression part of the generator assert isinstance(state.result, ast.GeneratorExp), ( - "YIELD_VALUE must be called after GEN_START" + "YIELD_VALUE must be called after RETURN_GENERATOR" ) assert isinstance(state.result.elt, Placeholder), ( "YIELD_VALUE must be called before yielding" @@ -178,8 +219,9 @@ def handle_yield_value( # ============================================================================ -@register_handler("BUILD_LIST") -def handle_build_list( +# Python 3.10 version +@register_handler("BUILD_LIST", version=PythonVersion.PY_310) +def handle_build_list_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: if isinstance(state.result, Placeholder) and len(state.stack) == 0: @@ -203,8 +245,45 @@ def handle_build_list( return replace(state, stack=new_stack) -@register_handler("LIST_APPEND") -def handle_list_append( +# Python 3.13 version +@register_handler("BUILD_LIST", version=PythonVersion.PY_313) +def handle_build_list( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert instr.arg is not None + size: int = instr.arg + + if size == 0: + # BUILD_LIST with 0 elements can mean two things: + # 1. Start of a list comprehension (if we're at the start or in a nested context) + # 2. 
Creating an empty list + + # Check if this looks like the start of a list comprehension pattern + # In nested comprehensions, BUILD_LIST(0) starts a new list comprehension + if isinstance(state.result, Placeholder) and len(state.stack) == 0: + # This is the start of a standalone list comprehension + ret = ast.ListComp(elt=Placeholder(), generators=[]) + new_stack = state.stack + [ret] + return replace(state, stack=new_stack, result=ret) + else: + # This might be a nested list comprehension or just an empty list + # For nested comprehensions, we create a ListComp and put it on the stack + # without changing the main result (which is the outer generator) + ret = ast.ListComp(elt=Placeholder(), generators=[]) + new_stack = state.stack + [ret] + return replace(state, stack=new_stack) + else: + # BUILD_LIST with elements - create a regular list + elements = [ensure_ast(elem) for elem in state.stack[-size:]] + new_stack = state.stack[:-size] + elt_node = ast.List(elts=elements, ctx=ast.Load()) + new_stack = new_stack + [elt_node] + return replace(state, stack=new_stack) + + +# Python 3.10 version +@register_handler("LIST_APPEND", version=PythonVersion.PY_310) +def handle_list_append_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert isinstance(state.result, ast.ListComp), ( @@ -218,6 +297,46 @@ def handle_list_append( return replace(state, stack=new_stack, result=new_ret) +# Python 3.13 version +@register_handler("LIST_APPEND", version=PythonVersion.PY_313) +def handle_list_append( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LIST_APPEND appends to a list comprehension + # The list comprehension might be the main result or on the stack (for nested comprehensions) + + if isinstance(state.result, ast.ListComp): + # Main result is a list comprehension + new_stack = state.stack[:-1] + new_ret = ast.ListComp( + elt=ensure_ast(state.stack[-1]), + generators=state.result.generators, + ) + return 
replace(state, stack=new_stack, result=new_ret) + elif len(state.stack) >= 2 and isinstance(state.stack[-2], ast.ListComp): + # There's a list comprehension on the stack (nested case) + # LIST_APPEND with argument 2 means append to the list 2 positions down + assert instr.arg == 2, f"Expected LIST_APPEND with arg 2, got {instr.arg}" + + list_comp = state.stack[-2] + element = ensure_ast(state.stack[-1]) + + # Update the list comprehension with the element + new_list_comp = ast.ListComp( + elt=element, + generators=list_comp.generators, + ) + + # Replace the list comprehension on the stack + new_stack = state.stack[:-2] + [new_list_comp] + return replace(state, stack=new_stack) + else: + raise AssertionError( + f"LIST_APPEND must be called within a ListComp context. " + f"State result: {type(state.result)}, Stack: {[type(x) for x in state.stack]}" + ) + + # ============================================================================ # SET COMPREHENSION HANDLERS # ============================================================================ @@ -384,7 +503,16 @@ def handle_get_iter( return state -@register_handler("JUMP_ABSOLUTE") +@register_handler("JUMP_BACKWARD", version=PythonVersion.PY_313) +def handle_jump_backward( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # JUMP_BACKWARD is used to jump back to the beginning of a loop (replaces JUMP_ABSOLUTE in 3.13) + # In generator expressions, this typically indicates the end of the loop body + return state + + +@register_handler("JUMP_ABSOLUTE", version=PythonVersion.PY_310) def handle_jump_absolute( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -423,6 +551,66 @@ def handle_unpack_sequence( return replace(state, stack=new_stack) +@register_handler("RESUME", version=PythonVersion.PY_313) +def handle_resume( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # RESUME is used for resuming execution after yield/await - 
mostly no-op for AST reconstruction + return state + + +@register_handler("END_FOR", version=PythonVersion.PY_313) +def handle_end_for( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # END_FOR marks the end of a for loop - no action needed for AST reconstruction + return state + + +@register_handler("RETURN_CONST", version=PythonVersion.PY_313) +def handle_return_const( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # RETURN_CONST returns a constant value (replaces some LOAD_CONST + RETURN_VALUE patterns) + # Similar to RETURN_VALUE but with a constant + if isinstance(state.result, CompExp): + return state + elif isinstance(state.result, Placeholder) and len(state.stack) == 1: + new_result = ensure_ast(state.stack[-1]) + assert isinstance(new_result, CompExp | ast.Lambda) + return replace(state, stack=state.stack[:-1], result=new_result) + else: + # For generators, this typically ends the generator with None + return state + + +@register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_313) +def handle_call_intrinsic_1( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # CALL_INTRINSIC_1 calls an intrinsic function with one argument + # For generator expressions, this is often used for exception handling + # We can generally ignore this for AST reconstruction + return state + + +@register_handler("CALL_INTRINSIC_2", version=PythonVersion.PY_313) +def handle_call_intrinsic_2( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # CALL_INTRINSIC_2 calls an intrinsic function with two arguments + # We can generally ignore this for AST reconstruction + return state + + +@register_handler("RERAISE", version=PythonVersion.PY_313) +def handle_reraise( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # RERAISE re-raises an exception - generally ignore for AST reconstruction + return state + + # 
============================================================================ # VARIABLE OPERATIONS HANDLERS # ============================================================================ @@ -444,6 +632,53 @@ def handle_load_fast( return replace(state, stack=new_stack) +@register_handler("LOAD_FAST_AND_CLEAR", version=PythonVersion.PY_313) +def handle_load_fast_and_clear( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_FAST_AND_CLEAR pushes a local variable onto the stack and clears it + # For AST reconstruction, we treat this the same as LOAD_FAST + var_name: str = instr.argval + + if var_name == ".0": + # Special handling for .0 variable (the iterator) + new_stack = state.stack + [DummyIterName()] + else: + # Regular variable load + new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] + + return replace(state, stack=new_stack) + + +@register_handler("LOAD_FAST_LOAD_FAST", version=PythonVersion.PY_313) +def handle_load_fast_load_fast( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_FAST_LOAD_FAST loads two variables (optimization in Python 3.13) + # The instruction argument contains both variable names + if isinstance(instr.argval, tuple): + var1, var2 = instr.argval + else: + # Fallback: assume both names are the same + var1 = var2 = instr.argval + + new_stack = state.stack + + # Load first variable + if var1 == ".0": + new_stack = new_stack + [DummyIterName()] + else: + new_stack = new_stack + [ast.Name(id=var1, ctx=ast.Load())] + + # Load second variable + if var2 == ".0": + new_stack = new_stack + [DummyIterName()] + else: + new_stack = new_stack + [ast.Name(id=var2, ctx=ast.Load())] + + return replace(state, stack=new_stack) + + @register_handler("STORE_FAST") def handle_store_fast( state: ReconstructionState, instr: dis.Instruction @@ -480,6 +715,60 @@ def handle_store_fast( return replace(state, result=new_comp) +@register_handler("STORE_FAST_LOAD_FAST", 
version=PythonVersion.PY_313) +def handle_store_fast_load_fast( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # STORE_FAST_LOAD_FAST stores and then loads the same variable (optimization) + # The instruction has two names: store_name and load_name + # In Python 3.13, this is often used for loop variables + + # First handle the store part + assert isinstance(state.result, CompExp), ( + "STORE_FAST_LOAD_FAST must be called within a comprehension context" + ) + + # In Python 3.13, the instruction argument contains both names + # argval should be a tuple (store_name, load_name) + if isinstance(instr.argval, tuple): + store_name, load_name = instr.argval + else: + # Fallback: assume both names are the same + store_name = load_name = instr.argval + + # Update the most recent loop's target variable + assert len(state.result.generators) > 0, ( + "STORE_FAST_LOAD_FAST must be within a loop context" + ) + + # Create a new LoopInfo with updated target + updated_loop = ast.comprehension( + target=ast.Name(id=store_name, ctx=ast.Store()), + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs, + is_async=state.result.generators[-1].is_async, + ) + + # Update the last loop in the generators list + if isinstance(state.result, ast.DictComp): + new_dict: ast.DictComp = ast.DictComp( + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], + ) + new_state = replace(state, result=new_dict) + else: + new_comp: ast.GeneratorExp | ast.ListComp | ast.SetComp = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], + ) + new_state = replace(state, result=new_comp) + + # Now handle the load part - push the variable onto the stack + new_stack = new_state.stack + [ast.Name(id=load_name, ctx=ast.Load())] + return replace(new_state, stack=new_stack) + + @register_handler("LOAD_CONST") def handle_load_const( state: 
ReconstructionState, instr: dis.Instruction @@ -581,7 +870,7 @@ def handle_pop_top( return replace(state, stack=new_stack) -@register_handler("DUP_TOP") +@register_handler("DUP_TOP", version=PythonVersion.PY_310) def handle_dup_top( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -591,7 +880,7 @@ def handle_dup_top( return replace(state, stack=new_stack) -@register_handler("ROT_TWO") +@register_handler("ROT_TWO", version=PythonVersion.PY_310) def handle_rot_two( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -600,7 +889,7 @@ def handle_rot_two( return replace(state, stack=new_stack) -@register_handler("ROT_THREE") +@register_handler("ROT_THREE", version=PythonVersion.PY_310) def handle_rot_three( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -616,7 +905,7 @@ def handle_rot_three( return replace(state, stack=new_stack) -@register_handler("ROT_FOUR") +@register_handler("ROT_FOUR", version=PythonVersion.PY_310) def handle_rot_four( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -631,6 +920,59 @@ def handle_rot_four( return replace(state, stack=new_stack) +# Python 3.13 replacement for stack manipulation +@register_handler("SWAP", version=PythonVersion.PY_313) +def handle_swap( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # SWAP exchanges the top two stack items (replaces ROT_TWO in many cases) + assert instr.arg is not None + depth = instr.arg + stack_size = len(state.stack) + + if depth > stack_size: + # Not enough items on stack - this might be a pattern where some items were optimized away + # For AST reconstruction, we can often ignore certain stack manipulations + return state + + if depth == 2 and stack_size >= 2: + # Equivalent to ROT_TWO + new_stack = state.stack[:-2] + [state.stack[-1], state.stack[-2]] + return replace(state, stack=new_stack) + elif depth <= stack_size: + # For other depths, 
swap TOS with the item at specified depth + idx = stack_size - depth + new_stack = state.stack.copy() + new_stack[-1], new_stack[idx] = new_stack[idx], new_stack[-1] + return replace(state, stack=new_stack) + else: + # Edge case - not enough items, just return unchanged + return state + + +@register_handler("COPY", version=PythonVersion.PY_313) +def handle_copy( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # COPY duplicates the item at the specified depth (replaces DUP_TOP in many cases) + assert instr.arg is not None + depth = instr.arg + if depth == 1: + # Equivalent to DUP_TOP + top_item = state.stack[-1] + new_stack = state.stack + [top_item] + return replace(state, stack=new_stack) + else: + # Copy the item at specified depth to top of stack + stack_size = len(state.stack) + if depth > stack_size: + raise ValueError(f"COPY depth {depth} exceeds stack size {stack_size}") + idx = stack_size - depth + copied_item = state.stack[idx] + new_stack = state.stack + [copied_item] + return replace(state, stack=new_stack) + + # ============================================================================ # BINARY ARITHMETIC/LOGIC OPERATION HANDLERS # ============================================================================ @@ -645,41 +987,98 @@ def handle_binop( return replace(state, stack=new_stack) +# Python 3.13 BINARY_OP handler +@register_handler("BINARY_OP", version=PythonVersion.PY_313) +def handle_binary_op( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # BINARY_OP in Python 3.13 consolidates all binary operations + # The operation type is determined by the instruction argument + assert instr.arg is not None + + # Map argument values to AST operators based on Python 3.13 implementation + op_map = { + 0: ast.Add(), # + + 1: ast.BitAnd(), # & + 2: ast.FloorDiv(), # // + 3: ast.LShift(), # << + 5: ast.Mult(), # * + 6: ast.Mod(), # % + 7: ast.BitOr(), # | (guessing based on pattern) + 8: 
ast.Pow(), # ** + 9: ast.RShift(), # >> + 10: ast.Sub(), # - + 11: ast.Div(), # / + 12: ast.BitXor(), # ^ + } + + op = op_map.get(instr.arg) + if op is None: + raise NotImplementedError(f"Unknown binary operation: {instr.arg}") + + return handle_binop(op, state, instr) + + +# Legacy binary operation handlers (for Python 3.10 compatibility) handler_binop_add = register_handler( - "BINARY_ADD", functools.partial(handle_binop, ast.Add()) + "BINARY_ADD", + functools.partial(handle_binop, ast.Add()), + version=PythonVersion.PY_310, ) handler_binop_subtract = register_handler( - "BINARY_SUBTRACT", functools.partial(handle_binop, ast.Sub()) + "BINARY_SUBTRACT", + functools.partial(handle_binop, ast.Sub()), + version=PythonVersion.PY_310, ) handler_binop_multiply = register_handler( - "BINARY_MULTIPLY", functools.partial(handle_binop, ast.Mult()) + "BINARY_MULTIPLY", + functools.partial(handle_binop, ast.Mult()), + version=PythonVersion.PY_310, ) handler_binop_true_divide = register_handler( - "BINARY_TRUE_DIVIDE", functools.partial(handle_binop, ast.Div()) + "BINARY_TRUE_DIVIDE", + functools.partial(handle_binop, ast.Div()), + version=PythonVersion.PY_310, ) handler_binop_floor_divide = register_handler( - "BINARY_FLOOR_DIVIDE", functools.partial(handle_binop, ast.FloorDiv()) + "BINARY_FLOOR_DIVIDE", + functools.partial(handle_binop, ast.FloorDiv()), + version=PythonVersion.PY_310, ) handler_binop_modulo = register_handler( - "BINARY_MODULO", functools.partial(handle_binop, ast.Mod()) + "BINARY_MODULO", + functools.partial(handle_binop, ast.Mod()), + version=PythonVersion.PY_310, ) handler_binop_power = register_handler( - "BINARY_POWER", functools.partial(handle_binop, ast.Pow()) + "BINARY_POWER", + functools.partial(handle_binop, ast.Pow()), + version=PythonVersion.PY_310, ) handler_binop_lshift = register_handler( - "BINARY_LSHIFT", functools.partial(handle_binop, ast.LShift()) + "BINARY_LSHIFT", + functools.partial(handle_binop, ast.LShift()), + 
version=PythonVersion.PY_310, ) handler_binop_rshift = register_handler( - "BINARY_RSHIFT", functools.partial(handle_binop, ast.RShift()) + "BINARY_RSHIFT", + functools.partial(handle_binop, ast.RShift()), + version=PythonVersion.PY_310, ) handler_binop_or = register_handler( - "BINARY_OR", functools.partial(handle_binop, ast.BitOr()) + "BINARY_OR", + functools.partial(handle_binop, ast.BitOr()), + version=PythonVersion.PY_310, ) handler_binop_xor = register_handler( - "BINARY_XOR", functools.partial(handle_binop, ast.BitXor()) + "BINARY_XOR", + functools.partial(handle_binop, ast.BitXor()), + version=PythonVersion.PY_310, ) handler_binop_and = register_handler( - "BINARY_AND", functools.partial(handle_binop, ast.BitAnd()) + "BINARY_AND", + functools.partial(handle_binop, ast.BitAnd()), + version=PythonVersion.PY_310, ) @@ -700,7 +1099,9 @@ def handle_unary_op( "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()) ) handle_unary_positive = register_handler( - "UNARY_POSITIVE", functools.partial(handle_unary_op, ast.UAdd()) + "UNARY_POSITIVE", + functools.partial(handle_unary_op, ast.UAdd()), + version=PythonVersion.PY_310, ) handle_unary_invert = register_handler( "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()) @@ -724,8 +1125,9 @@ def handle_unary_op( } -@register_handler("COMPARE_OP") -def handle_compare_op( +# Python 3.10 version +@register_handler("COMPARE_OP", version=PythonVersion.PY_310) +def handle_compare_op_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert instr.arg is not None and dis.cmp_op[instr.arg] == instr.argval, ( @@ -742,6 +1144,28 @@ def handle_compare_op( return replace(state, stack=new_stack) +# Python 3.13 version +@register_handler("COMPARE_OP", version=PythonVersion.PY_313) +def handle_compare_op( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # In Python 3.13, COMPARE_OP arguments have changed + # The operation is directly in argval, not 
indexed through cmp_op + assert instr.arg is not None + + right = ensure_ast(state.stack[-1]) + left = ensure_ast(state.stack[-2]) + + # Map comparison operation codes to AST operators + op_name = instr.argval + if op_name not in CMP_OPMAP: + raise NotImplementedError(f"Unsupported comparison operation: {op_name}") + + compare_node = ast.Compare(left=left, ops=[CMP_OPMAP[op_name]], comparators=[right]) + new_stack = state.stack[:-2] + [compare_node] + return replace(state, stack=new_stack) + + @register_handler("CONTAINS_OP") def handle_contains_op( state: ReconstructionState, instr: dis.Instruction @@ -777,7 +1201,31 @@ def handle_is_op( # ============================================================================ -@register_handler("CALL_FUNCTION") +@register_handler("CALL", version=PythonVersion.PY_313) +def handle_call( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # CALL pops function and arguments from stack (replaces CALL_FUNCTION in Python 3.13) + assert instr.arg is not None + arg_count: int = instr.arg + # Pop arguments and function + args = ( + [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] + ) + func = ensure_ast(state.stack[-arg_count - 1]) + new_stack = state.stack[: -arg_count - 1] + + if isinstance(func, CompLambda): + assert len(args) == 1 + return replace(state, stack=new_stack + [func.inline(args[0])]) + else: + # Create function call AST + call_node = ast.Call(func=func, args=args, keywords=[]) + new_stack = new_stack + [call_node] + return replace(state, stack=new_stack) + + +@register_handler("CALL_FUNCTION", version=PythonVersion.PY_310) def handle_call_function( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -820,7 +1268,7 @@ def handle_load_method( return replace(state, stack=new_stack) -@register_handler("CALL_METHOD") +@register_handler("CALL_METHOD", version=PythonVersion.PY_310) def handle_call_method( state: ReconstructionState, instr: 
dis.Instruction ) -> ReconstructionState: @@ -840,8 +1288,9 @@ def handle_call_method( return replace(state, stack=new_stack) -@register_handler("MAKE_FUNCTION") -def handle_make_function( +# Python 3.10 version +@register_handler("MAKE_FUNCTION", version=PythonVersion.PY_310) +def handle_make_function_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # MAKE_FUNCTION creates a function from code object and name on stack @@ -875,6 +1324,65 @@ def handle_make_function( raise NotImplementedError("Lambda reconstruction not implemented yet") +# Python 3.13 version +@register_handler("MAKE_FUNCTION", version=PythonVersion.PY_313) +def handle_make_function( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # MAKE_FUNCTION creates a function from code object and name on stack + # In Python 3.13, function attributes are set using SET_FUNCTION_ATTRIBUTE + # The stack layout is simpler: just code object and name + + assert isinstance(state.stack[-1], ast.Constant) and isinstance( + state.stack[-1].value, str + ), "Function name must be a constant string." 
+ + # In Python 3.13, MAKE_FUNCTION typically just has code object and name + # Additional attributes like closures are handled by SET_FUNCTION_ATTRIBUTE + + # Check if there are any flags that indicate special handling needed + flags = instr.arg or 0 + + body: ast.expr + if flags & 0x08: # Closure flag + # This is a closure, remove the closure tuple from the stack + new_stack = state.stack[:-3] + body = state.stack[-3] + else: + # Simple function without closure + new_stack = state.stack[:-2] + body = state.stack[-2] + + name: str = state.stack[-1].value + + assert any( + name.endswith(suffix) + for suffix in ("", "", "", "", "") + ), f"Expected a comprehension or lambda function, got '{name}'" + + if ( + isinstance(body, CompExp) + and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 + ): + return replace(state, stack=new_stack + [CompLambda(body)]) + else: + raise NotImplementedError("Lambda reconstruction not implemented yet") + + +@register_handler("SET_FUNCTION_ATTRIBUTE", version=PythonVersion.PY_313) +def handle_set_function_attribute( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # SET_FUNCTION_ATTRIBUTE sets an attribute on a function object + # In Python 3.13, this is used instead of flags in MAKE_FUNCTION + # For AST reconstruction, we typically don't need to track function attributes + # Just pop the attribute value and leave the function on the stack + + # Pop the attribute value but keep the function + new_stack = state.stack[:-1] + return replace(state, stack=new_stack) + + # ============================================================================ # OBJECT ACCESS HANDLERS # ============================================================================ @@ -935,7 +1443,7 @@ def handle_build_tuple( return replace(state, stack=new_stack) -@register_handler("LIST_TO_TUPLE") +@register_handler("LIST_TO_TUPLE", version=PythonVersion.PY_310) def handle_list_to_tuple( state: ReconstructionState, instr: 
dis.Instruction ) -> ReconstructionState: @@ -1004,8 +1512,9 @@ def handle_build_const_key_map( # ============================================================================ -@register_handler("POP_JUMP_IF_FALSE") -def handle_pop_jump_if_false( +# Python 3.10 version +@register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_310) +def handle_pop_jump_if_false_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false @@ -1041,8 +1550,49 @@ def handle_pop_jump_if_false( raise NotImplementedError("Lazy and+or behavior not implemented yet") -@register_handler("POP_JUMP_IF_TRUE") -def handle_pop_jump_if_true( +# Python 3.13 version +@register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_313) +def handle_pop_jump_if_false( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false + # In comprehensions, this is used for filter conditions + condition = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + if isinstance(state.result, CompExp) and state.result.generators: + # In Python 3.13, when POP_JUMP_IF_FALSE jumps forward to the yield, + # it means "if condition is False, then yield the item" + # So we need to negate the condition: we want items where NOT condition + negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) + + updated_loop = ast.comprehension( + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [negated_condition], + is_async=state.result.generators[-1].is_async, + ) + if isinstance(state.result, ast.DictComp): + new_dict = ast.DictComp( + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, result=new_dict) + else: + new_comp = type(state.result)( + 
elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, result=new_comp) + else: + # Not in a comprehension context - might be boolean logic + raise NotImplementedError("Lazy and+or behavior not implemented yet") + + +# Python 3.10 version +@register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_310) +def handle_pop_jump_if_true_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # POP_JUMP_IF_TRUE pops a value from the stack and jumps if it's true @@ -1082,6 +1632,129 @@ def handle_pop_jump_if_true( raise NotImplementedError("Lazy and+or behavior not implemented yet") +# Python 3.13 version +@register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_313) +def handle_pop_jump_if_true( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # POP_JUMP_IF_TRUE pops a value from the stack and jumps if it's true + # In Python 3.13, this is used for filter conditions where True means continue + condition = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # In Python 3.13, if we have a comprehension and generators, this is likely a filter + if isinstance(state.result, CompExp) and state.result.generators: + # For POP_JUMP_IF_TRUE in filters, we want the condition to be true to continue + # So we add the condition directly (no negation needed) + updated_loop = ast.comprehension( + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [condition], + is_async=state.result.generators[-1].is_async, + ) + if isinstance(state.result, ast.DictComp): + new_dict = ast.DictComp( + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, result=new_dict) + else: + new_comp = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + 
[updated_loop], + ) + return replace(state, stack=new_stack, result=new_comp) + else: + # Not in a comprehension context - might be boolean logic + raise NotImplementedError("Lazy and+or behavior not implemented yet") + + +@register_handler("TO_BOOL", version=PythonVersion.PY_313) +def handle_to_bool( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # TO_BOOL converts the top stack item to a boolean + # For AST reconstruction, we typically don't need an explicit bool() call + # since the boolean context is usually handled by the conditional jump that follows + # However, for some cases we might need to preserve the explicit conversion + + # For now, leave the value as-is since the jump instruction will handle the boolean logic + return state + + +@register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_313) +def handle_pop_jump_if_none( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # POP_JUMP_IF_NONE pops a value and jumps if it's None + condition = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + if isinstance(state.result, CompExp) and state.result.generators: + # Create "x is None" condition + none_condition = ast.Compare( + left=condition, ops=[ast.Is()], comparators=[ast.Constant(value=None)] + ) + updated_loop = ast.comprehension( + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [none_condition], + is_async=state.result.generators[-1].is_async, + ) + if isinstance(state.result, ast.DictComp): + new_dict = ast.DictComp( + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, result=new_dict) + else: + new_comp = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, result=new_comp) + else: + raise 
NotImplementedError("Lazy and+or behavior not implemented yet") + + +@register_handler("POP_JUMP_IF_NOT_NONE", version=PythonVersion.PY_313) +def handle_pop_jump_if_not_none( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # POP_JUMP_IF_NOT_NONE pops a value and jumps if it's not None + condition = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + if isinstance(state.result, CompExp) and state.result.generators: + # Create "x is not None" condition + not_none_condition = ast.Compare( + left=condition, ops=[ast.IsNot()], comparators=[ast.Constant(value=None)] + ) + updated_loop = ast.comprehension( + target=state.result.generators[-1].target, + iter=state.result.generators[-1].iter, + ifs=state.result.generators[-1].ifs + [not_none_condition], + is_async=state.result.generators[-1].is_async, + ) + if isinstance(state.result, ast.DictComp): + new_dict = ast.DictComp( + key=state.result.key, + value=state.result.value, + generators=state.result.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, result=new_dict) + else: + new_comp = type(state.result)( + elt=state.result.elt, + generators=state.result.generators[:-1] + [updated_loop], + ) + return replace(state, stack=new_stack, result=new_comp) + else: + raise NotImplementedError("Lazy and+or behavior not implemented yet") + + # ============================================================================ # UTILITY FUNCTIONS # ============================================================================ @@ -1178,6 +1851,9 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.expr: # Symbolic execution to reconstruct the AST state = ReconstructionState() for instr in dis.get_instructions(value): + if instr.opname not in OP_HANDLERS: + raise KeyError(f"No handler found for opcode '{instr.opname}'") + state = OP_HANDLERS[instr.opname](state, instr) # Check postconditions From a19c3ae87e49c18f6f071437aa88dee5dd7160f8 Mon Sep 17 00:00:00 2001 From: 
Eli Date: Thu, 19 Jun 2025 15:42:36 -0400 Subject: [PATCH 052/106] test nested genexpr with materialize --- tests/test_internals_disassembler.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index f217fe72..787a9e8f 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -3,6 +3,7 @@ from typing import Any import pytest +import tree from effectful.internals.disassembler import reconstruct @@ -26,6 +27,20 @@ def compile_and_eval( return eval(code, globals_dict) +def materialize(genexpr: GeneratorType) -> tree.Structure: + """Materialize a nested generator expression to a nested list.""" + + def _materialize(genexpr): + if isinstance(genexpr, GeneratorType): + return tree.map_structure(_materialize, list(genexpr)) + elif tree.is_nested(genexpr): + return tree.map_structure(_materialize, genexpr) + else: + return genexpr + + return _materialize(genexpr) + + def assert_ast_equivalent( genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict | None = None ): @@ -45,7 +60,7 @@ def assert_ast_equivalent( globals().update(globals_dict or {}) # Materialize original generator to list for comparison - original_list = list(genexpr) + original_list = materialize(genexpr) # Clean up globals to avoid pollution for key in globals_dict or {}: @@ -55,7 +70,7 @@ def assert_ast_equivalent( # Compile and evaluate the reconstructed AST reconstructed_gen = compile_and_eval(reconstructed_ast, globals_dict) - reconstructed_list = list(reconstructed_gen) + reconstructed_list = materialize(reconstructed_gen) assert reconstructed_list == original_list, ( f"AST produced {reconstructed_list}, expected {original_list}" ) @@ -383,8 +398,11 @@ def test_nested_loops(genexpr): @pytest.mark.parametrize( "genexpr", [ + ((x for x in range(i)) for i in range(5)), + ([x for x in range(i)] for i in range(5)), ([x for x in range(i)] 
for i in range(5)), ({x: x**2 for x in range(i)} for i in range(5)), + (((x for x in range(i + j)) for j in range(i)) for i in range(5)), ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), # aggregation function call (sum(x for x in range(i + 1)) for i in range(3)), From e51d44e64573fe77b41382720c0b0dce0239ea5f Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Jun 2025 15:47:07 -0400 Subject: [PATCH 053/106] make_function --- effectful/internals/disassembler.py | 35 +++++------------------------ 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index c3398e45..eee27d62 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1329,36 +1329,13 @@ def handle_make_function_310( def handle_make_function( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # MAKE_FUNCTION creates a function from code object and name on stack - # In Python 3.13, function attributes are set using SET_FUNCTION_ATTRIBUTE - # The stack layout is simpler: just code object and name - - assert isinstance(state.stack[-1], ast.Constant) and isinstance( - state.stack[-1].value, str - ), "Function name must be a constant string." - - # In Python 3.13, MAKE_FUNCTION typically just has code object and name - # Additional attributes like closures are handled by SET_FUNCTION_ATTRIBUTE - - # Check if there are any flags that indicate special handling needed - flags = instr.arg or 0 - - body: ast.expr - if flags & 0x08: # Closure flag - # This is a closure, remove the closure tuple from the stack - new_stack = state.stack[:-3] - body = state.stack[-3] - else: - # Simple function without closure - new_stack = state.stack[:-2] - body = state.stack[-2] + # MAKE_FUNCTION in Python 3.13 is simplified: it only takes a code object from the stack + # and creates a function from it. No flags, no extra attributes on the stack. 
+ # All extra attributes are handled by separate SET_FUNCTION_ATTRIBUTE instructions. - name: str = state.stack[-1].value - - assert any( - name.endswith(suffix) - for suffix in ("", "", "", "", "") - ), f"Expected a comprehension or lambda function, got '{name}'" + # Pop the code object from the stack (it's the only thing expected) + body: ast.expr = state.stack[-1] + new_stack = state.stack[:-1] if ( isinstance(body, CompExp) From 7ba5bf5b301d193117c009d20992aae587e28960 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Jun 2025 16:03:33 -0400 Subject: [PATCH 054/106] add more null ops --- effectful/internals/disassembler.py | 22 ++++++++++++++++++++++ tests/test_internals_disassembler.py | 4 +++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index eee27d62..02cf4ad4 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -797,6 +797,28 @@ def handle_load_name( return replace(state, stack=new_stack) +@register_handler("MAKE_CELL", version=PythonVersion.PY_313) +def handle_make_cell( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # MAKE_CELL creates a new cell in slot i for closure variables + # This is used when variables from outer scopes are captured by inner scopes + # For AST reconstruction purposes, this is just a variable scoping mechanism + # that we can ignore since the AST doesn't track low-level closure details + return state + + +@register_handler("COPY_FREE_VARS", version=PythonVersion.PY_313) +def handle_copy_free_vars( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # COPY_FREE_VARS copies n free (closure) variables from the closure into the frame + # This removes the need for special code on the caller's side when calling closures + # For AST reconstruction purposes, this is just a variable scoping mechanism + # that we can ignore since the AST doesn't 
track runtime variable management + return state + + @register_handler("STORE_DEREF") def handle_store_deref( state: ReconstructionState, instr: dis.Instruction diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 787a9e8f..9d57f1d8 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -398,11 +398,13 @@ def test_nested_loops(genexpr): @pytest.mark.parametrize( "genexpr", [ + # nested generators ((x for x in range(i)) for i in range(5)), + (((x for x in range(i + j)) for j in range(i)) for i in range(5)), + # nested non-generators ([x for x in range(i)] for i in range(5)), ([x for x in range(i)] for i in range(5)), ({x: x**2 for x in range(i)} for i in range(5)), - (((x for x in range(i + j)) for j in range(i)) for i in range(5)), ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), # aggregation function call (sum(x for x in range(i + 1)) for i in range(3)), From eee829e10578a33cda5436e893e23a68f11a74e9 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Jun 2025 16:15:28 -0400 Subject: [PATCH 055/106] reconsolidate --- effectful/internals/disassembler.py | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 02cf4ad4..7957618e 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1147,12 +1147,11 @@ def handle_unary_op( } -# Python 3.10 version -@register_handler("COMPARE_OP", version=PythonVersion.PY_310) -def handle_compare_op_310( +@register_handler("COMPARE_OP") +def handle_compare_op( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert instr.arg is not None and dis.cmp_op[instr.arg] == instr.argval, ( + assert instr.arg is not None and instr.argval in dis.cmp_op, ( f"Unsupported comparison operation: {instr.argval}" ) @@ -1166,28 +1165,6 @@ def handle_compare_op_310( 
return replace(state, stack=new_stack) -# Python 3.13 version -@register_handler("COMPARE_OP", version=PythonVersion.PY_313) -def handle_compare_op( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # In Python 3.13, COMPARE_OP arguments have changed - # The operation is directly in argval, not indexed through cmp_op - assert instr.arg is not None - - right = ensure_ast(state.stack[-1]) - left = ensure_ast(state.stack[-2]) - - # Map comparison operation codes to AST operators - op_name = instr.argval - if op_name not in CMP_OPMAP: - raise NotImplementedError(f"Unsupported comparison operation: {op_name}") - - compare_node = ast.Compare(left=left, ops=[CMP_OPMAP[op_name]], comparators=[right]) - new_stack = state.stack[:-2] + [compare_node] - return replace(state, stack=new_stack) - - @register_handler("CONTAINS_OP") def handle_contains_op( state: ReconstructionState, instr: dis.Instruction From ab7ef48e9bd4a7ad7d11993c6b4ce0e70e7dc6de Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Jun 2025 16:40:00 -0400 Subject: [PATCH 056/106] copy support --- effectful/internals/disassembler.py | 41 +++++++++++++----------- tests/test_internals_disassembler.py | 47 ++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 7957618e..81dcae63 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -33,18 +33,14 @@ class Placeholder(ast.Name): """Placeholder for AST nodes that are not yet resolved.""" - def __init__(self, id=".PLACEHOLDER", ctx=None): - if ctx is None: - ctx = ast.Load() + def __init__(self, id=".PLACEHOLDER", ctx=ast.Load()): super().__init__(id=id, ctx=ctx) class DummyIterName(ast.Name): """Dummy name for the iterator variable in generator expressions.""" - def __init__(self, id=".0", ctx=None): - if ctx is None: - ctx = ast.Load() + def __init__(self, id=".0", ctx=ast.Load()): 
super().__init__(id=id, ctx=ctx) @@ -52,19 +48,28 @@ class CompLambda(ast.Lambda): """Placeholder AST node representing a lambda function used in comprehensions.""" def __init__(self, body: CompExp): + assert isinstance(body, CompExp) assert sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 assert len(body.generators) > 0 assert isinstance(body.generators[0].iter, DummyIterName) - super().__init__( - args=ast.arguments( - posonlyargs=[ast.arg(DummyIterName().id)], - args=[], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=body, + args = ast.arguments( + posonlyargs=[ast.arg(DummyIterName().id)], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], ) + super().__init__(args=args, body=body) + + def __copy__(self): + """Support copy.copy operation.""" + assert isinstance(self.body, CompExp) + return CompLambda(self.body) + + def __deepcopy__(self, memo): + """Support copy.deepcopy operation.""" + assert isinstance(self.body, CompExp) + return CompLambda(copy.deepcopy(self.body, memo)) def inline(self, iterator: ast.expr) -> CompExp: assert isinstance(self.body, CompExp) @@ -1318,7 +1323,7 @@ def handle_make_function_310( isinstance(body, CompExp) and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 ): - return replace(state, stack=new_stack + [CompLambda(body)]) + return replace(state, stack=new_stack + [CompLambda(body=body)]) else: raise NotImplementedError("Lambda reconstruction not implemented yet") @@ -1340,7 +1345,7 @@ def handle_make_function( isinstance(body, CompExp) and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 ): - return replace(state, stack=new_stack + [CompLambda(body)]) + return replace(state, stack=new_stack + [CompLambda(body=body)]) else: raise NotImplementedError("Lambda reconstruction not implemented yet") @@ -1879,7 +1884,7 @@ def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: genexpr_ast = ensure_ast(genexpr.gi_code) assert 
isinstance(genexpr_ast, ast.GeneratorExp) geniter_ast = ensure_ast(genexpr.gi_frame.f_locals[".0"]) - result = CompLambda(genexpr_ast).inline(geniter_ast) + result = CompLambda(body=genexpr_ast).inline(geniter_ast) assert isinstance(result, ast.GeneratorExp) assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, ( "Generator must stay in created state" diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 9d57f1d8..ae35004a 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -670,3 +670,50 @@ def test_error_handling(): list(gen) # Consume it with pytest.raises(AssertionError): reconstruct(gen) + + +def test_comp_lambda_copy(): + """Test that CompLambda is compatible with copy.copy and copy.deepcopy.""" + import copy + + from effectful.internals.disassembler import CompLambda, DummyIterName + + # Create a test generator expression AST + genexpr_ast = ast.GeneratorExp( + elt=ast.Name(id="x", ctx=ast.Load()), + generators=[ + ast.comprehension( + target=ast.Name(id="x", ctx=ast.Store()), + iter=DummyIterName(), + is_async=0, + ) + ], + ) + + # Create a CompLambda instance + comp_lambda = CompLambda(genexpr_ast) + + # Test copy.copy + copied = copy.copy(comp_lambda) + assert isinstance(copied, CompLambda) + assert ast.unparse(copied.body) == ast.unparse(comp_lambda.body) + assert copied.body is comp_lambda.body # Shallow copy shares the body + + # Test copy.deepcopy + deep_copied = copy.deepcopy(comp_lambda) + assert isinstance(deep_copied, CompLambda) + assert ast.unparse(deep_copied.body) == ast.unparse(comp_lambda.body) + assert deep_copied.body is not comp_lambda.body # Deep copy creates new body + + # Test that deep copied version works the same way + iterator = ast.Call( + func=ast.Name(id="range", ctx=ast.Load()), + args=[ast.Constant(value=5)], + keywords=[], + ) + + original_result = comp_lambda.inline(iterator) + deep_copied_result = deep_copied.inline(iterator) + 
+ assert ast.unparse(original_result) == ast.unparse(deep_copied_result) + assert type(original_result) == type(deep_copied_result) From 72f6187e2214b8637ce05d80056f486ff60781cf Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Jun 2025 18:49:03 -0400 Subject: [PATCH 057/106] fix nested generators, add stack_effect postcondition --- effectful/internals/disassembler.py | 76 ++++++++++++++++++++-------- tests/test_internals_disassembler.py | 8 +-- 2 files changed, 60 insertions(+), 24 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 81dcae63..9ac09a2f 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -44,6 +44,13 @@ def __init__(self, id=".0", ctx=ast.Load()): super().__init__(id=id, ctx=ctx) +class Null(ast.Constant): + """Placeholder for NULL values generated in bytecode.""" + + def __init__(self, value=None): + super().__init__(value=value) + + class CompLambda(ast.Lambda): """Placeholder AST node representing a lambda function used in comprehensions.""" @@ -162,7 +169,22 @@ def _wrapper( assert instr.opname == opname, ( f"Handler for '{opname}' called with wrong instruction" ) - return handler(state, instr) + + new_state = handler(state, instr) + + # post-condition: check stack effect + expected_stack_effect = dis.stack_effect(instr.opcode, instr.arg) + actual_stack_effect = len(new_state.stack) - len(state.stack) + if not (len(state.stack) == len(new_state.stack) == 0): + assert len(state.stack) + expected_stack_effect >= 0, ( + f"Handler for '{opname}' would result in negative stack size" + ) + assert actual_stack_effect == expected_stack_effect, ( + f"Handler for '{opname}' has incorrect stack effect: " + f"expected {expected_stack_effect}, got {actual_stack_effect}" + ) + + return new_state OP_HANDLERS[opname] = _wrapper return _wrapper @@ -194,10 +216,12 @@ def handle_return_generator( assert isinstance(state.result, Placeholder), ( "RETURN_GENERATOR must be 
the first instruction" ) - return replace(state, result=ast.GeneratorExp(elt=Placeholder(), generators=[])) + new_result = ast.GeneratorExp(elt=Placeholder(), generators=[]) + new_stack = state.stack + [new_result] + return replace(state, result=new_result, stack=new_stack) -@register_handler("YIELD_VALUE") +@register_handler("YIELD_VALUE", version=PythonVersion.PY_313) def handle_yield_value( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -211,7 +235,7 @@ def handle_yield_value( ) assert len(state.result.generators) > 0, "YIELD_VALUE should have generators" - new_stack = state.stack[:-1] + new_stack = state.stack # [:-1] ret = ast.GeneratorExp( elt=ensure_ast(state.stack[-1]), generators=state.result.generators, @@ -467,8 +491,6 @@ def handle_for_iter( ) # The iterator should be on top of stack - # Create new stack without the iterator - new_stack = state.stack[:-1] iterator: ast.expr = state.stack[-1] # Create a new loop variable - we'll get the actual name from STORE_FAST @@ -495,6 +517,7 @@ def handle_for_iter( generators=state.result.generators + [loop_info], ) + new_stack = state.stack + [loop_info.target] return replace(state, stack=new_stack, result=new_ret) @@ -569,7 +592,8 @@ def handle_end_for( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # END_FOR marks the end of a for loop - no action needed for AST reconstruction - return state + new_stack = state.stack[:-1] + return replace(state, stack=new_stack) @register_handler("RETURN_CONST", version=PythonVersion.PY_313) @@ -696,6 +720,8 @@ def handle_store_fast( # Update the most recent loop's target variable assert len(state.result.generators) > 0, "STORE_FAST must be within a loop context" + new_stack = state.stack[:-1] + # Create a new LoopInfo with updated target updated_loop = ast.comprehension( target=ast.Name(id=var_name, ctx=ast.Store()), @@ -711,13 +737,13 @@ def handle_store_fast( value=state.result.value, 
generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, result=new_dict) + return replace(state, stack=new_stack, result=new_dict) else: new_comp: ast.GeneratorExp | ast.ListComp | ast.SetComp = type(state.result)( elt=state.result.elt, generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, result=new_comp) + return replace(state, stack=new_stack, result=new_comp) @register_handler("STORE_FAST_LOAD_FAST", version=PythonVersion.PY_313) @@ -770,7 +796,7 @@ def handle_store_fast_load_fast( new_state = replace(state, result=new_comp) # Now handle the load part - push the variable onto the stack - new_stack = new_state.stack + [ast.Name(id=load_name, ctx=ast.Load())] + new_stack = new_state.stack[:-1] + [ast.Name(id=load_name, ctx=ast.Load())] return replace(new_state, stack=new_stack) @@ -788,7 +814,11 @@ def handle_load_global( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: global_name = instr.argval - new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load())] + + if instr.argrepr.endswith(" + NULL"): + new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load()), Null()] + else: + new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load())] return replace(state, stack=new_stack) @@ -852,13 +882,13 @@ def handle_store_deref( value=state.result.value, generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, result=new_dict) + return replace(state, stack=state.stack[:-1], result=new_dict) else: new_comp: ast.GeneratorExp | ast.ListComp | ast.SetComp = type(state.result)( elt=state.result.elt, generators=state.result.generators[:-1] + [updated_loop], ) - return replace(state, result=new_comp) + return replace(state, stack=state.stack[:-1], result=new_comp) @register_handler("LOAD_DEREF") @@ -1212,13 +1242,17 @@ def handle_call( # CALL pops function and arguments from stack (replaces CALL_FUNCTION in Python 3.13) assert instr.arg is 
not None arg_count: int = instr.arg + + func = ensure_ast(state.stack[-arg_count - 2]) + # Pop arguments and function args = ( [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] ) - func = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[: -arg_count - 1] + if not isinstance(state.stack[-arg_count - 1], Null): + args = [ensure_ast(state.stack[-arg_count - 1])] + args + new_stack = state.stack[: -arg_count - 2] if isinstance(func, CompLambda): assert len(args) == 1 return replace(state, stack=new_stack + [func.inline(args[0])]) @@ -1253,7 +1287,7 @@ def handle_call_function( return replace(state, stack=new_stack) -@register_handler("LOAD_METHOD") +@register_handler("LOAD_METHOD", version=PythonVersion.PY_310) def handle_load_method( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1360,7 +1394,7 @@ def handle_set_function_attribute( # Just pop the attribute value and leave the function on the stack # Pop the attribute value but keep the function - new_stack = state.stack[:-1] + new_stack = state.stack[:-2] + [state.stack[-1]] # Keep the function on top return replace(state, stack=new_stack) @@ -1369,18 +1403,20 @@ def handle_set_function_attribute( # ============================================================================ -@register_handler("LOAD_ATTR") +@register_handler("LOAD_ATTR", version=PythonVersion.PY_313) def handle_load_attr( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # LOAD_ATTR loads an attribute from the object on top of stack obj = ensure_ast(state.stack[-1]) attr_name = instr.argval - new_stack = state.stack[:-1] # Create attribute access AST attr_node = ast.Attribute(value=obj, attr=attr_name, ctx=ast.Load()) - new_stack = new_stack + [attr_node] + if instr.argrepr.endswith(" + NULL|self"): + new_stack = state.stack[:-1] + [attr_node, Null()] + else: + new_stack = state.stack[:-1] + [attr_node] return replace(state, stack=new_stack) diff 
--git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index ae35004a..d0ab3b43 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -139,9 +139,9 @@ def test_simple_generators(genexpr): ((x * 2 + 3) * (x - 1) for x in range(5)), (x * (x + 1) * (x + 2) for x in range(5)), # Mixed operations with precedence - (x + y * 2 for x in range(3) for y in range(3)), - (x * 2 + y / 3 for x in range(1, 4) for y in range(1, 4)), - ((x + y) * (x - y) for x in range(1, 4) for y in range(1, 4)), + (x + 3 * 2 for x in range(3)), + (x * 2 + 9 / 3 for x in range(1, 4)), + ((x + 2) * (x - 2) for x in range(1, 4)), # Edge cases with zero and one (x * 0 for x in range(5)), (x * 1 for x in range(5)), @@ -399,7 +399,7 @@ def test_nested_loops(genexpr): "genexpr", [ # nested generators - ((x for x in range(i)) for i in range(5)), + ((x for x in range(i + 1)) for i in range(5)), (((x for x in range(i + j)) for j in range(i)) for i in range(5)), # nested non-generators ([x for x in range(i)] for i in range(5)), From dfa27fe487448d5f35b1db851740746bd51fcb00 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 01:56:50 -0400 Subject: [PATCH 058/106] stash changes --- effectful/internals/disassembler.py | 157 ++++++++++++++++++--------- tests/test_internals_disassembler.py | 3 +- 2 files changed, 105 insertions(+), 55 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 9ac09a2f..b4501520 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -283,24 +283,10 @@ def handle_build_list( size: int = instr.arg if size == 0: - # BUILD_LIST with 0 elements can mean two things: - # 1. Start of a list comprehension (if we're at the start or in a nested context) - # 2. 
Creating an empty list - # Check if this looks like the start of a list comprehension pattern - # In nested comprehensions, BUILD_LIST(0) starts a new list comprehension - if isinstance(state.result, Placeholder) and len(state.stack) == 0: - # This is the start of a standalone list comprehension - ret = ast.ListComp(elt=Placeholder(), generators=[]) - new_stack = state.stack + [ret] - return replace(state, stack=new_stack, result=ret) - else: - # This might be a nested list comprehension or just an empty list - # For nested comprehensions, we create a ListComp and put it on the stack - # without changing the main result (which is the outer generator) - ret = ast.ListComp(elt=Placeholder(), generators=[]) - new_stack = state.stack + [ret] - return replace(state, stack=new_stack) + # In nested comprehensions, BUILD_LIST(0) starts a new list comprehe + new_stack = state.stack + [ast.ListComp(elt=Placeholder(), generators=[])] + return replace(state, stack=new_stack) else: # BUILD_LIST with elements - create a regular list elements = [ensure_ast(elem) for elem in state.stack[-size:]] @@ -334,36 +320,16 @@ def handle_list_append( # LIST_APPEND appends to a list comprehension # The list comprehension might be the main result or on the stack (for nested comprehensions) - if isinstance(state.result, ast.ListComp): - # Main result is a list comprehension - new_stack = state.stack[:-1] - new_ret = ast.ListComp( - elt=ensure_ast(state.stack[-1]), - generators=state.result.generators, - ) - return replace(state, stack=new_stack, result=new_ret) - elif len(state.stack) >= 2 and isinstance(state.stack[-2], ast.ListComp): - # There's a list comprehension on the stack (nested case) - # LIST_APPEND with argument 2 means append to the list 2 positions down - assert instr.arg == 2, f"Expected LIST_APPEND with arg 2, got {instr.arg}" - - list_comp = state.stack[-2] - element = ensure_ast(state.stack[-1]) - - # Update the list comprehension with the element - new_list_comp = 
ast.ListComp( - elt=element, - generators=list_comp.generators, - ) + comp: ast.ListComp = state.stack[-instr.argval - 1] + assert isinstance(comp.elt, Placeholder) - # Replace the list comprehension on the stack - new_stack = state.stack[:-2] + [new_list_comp] - return replace(state, stack=new_stack) - else: - raise AssertionError( - f"LIST_APPEND must be called within a ListComp context. " - f"State result: {type(state.result)}, Stack: {[type(x) for x in state.stack]}" - ) + new_elt = state.stack[-1] + new_stack = state.stack[:-1] + new_stack[-instr.argval] = ast.ListComp( + elt=new_elt, + generators=comp.generators, + ) + return replace(state, stack=new_stack) # ============================================================================ @@ -371,8 +337,8 @@ def handle_list_append( # ============================================================================ -@register_handler("BUILD_SET") -def handle_build_set( +@register_handler("BUILD_SET", version=PythonVersion.PY_310) +def handle_build_set_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: if isinstance(state.result, Placeholder) and len(state.stack) == 0: @@ -396,8 +362,27 @@ def handle_build_set( return replace(state, stack=new_stack) -@register_handler("SET_ADD") -def handle_set_add( +# Python 3.13 version +@register_handler("BUILD_SET", version=PythonVersion.PY_313) +def handle_build_set( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert instr.arg is not None + size: int = instr.arg + + if size == 0: + new_stack = state.stack + [ast.SetComp(elt=Placeholder(), generators=[])] + return replace(state, stack=new_stack) + else: + elements = [ensure_ast(elem) for elem in state.stack[-size:]] + new_stack = state.stack[:-size] + elt_node = ast.Set(elts=elements) + new_stack = new_stack + [elt_node] + return replace(state, stack=new_stack) + + +@register_handler("SET_ADD", version=PythonVersion.PY_310) +def handle_set_add_310( state: 
ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert isinstance(state.result, ast.SetComp), ( @@ -411,13 +396,30 @@ def handle_set_add( return replace(state, stack=new_stack, result=new_ret) +# Python 3.13 version +@register_handler("SET_ADD", version=PythonVersion.PY_313) +def handle_set_add( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + comp: ast.SetComp = state.stack[-instr.argval - 1] + assert isinstance(comp.elt, Placeholder) + + new_elt = state.stack[-1] + new_stack = state.stack[:-1] + new_stack[-instr.argval] = ast.SetComp( + elt=new_elt, + generators=comp.generators, + ) + return replace(state, stack=new_stack) + + # ============================================================================ # DICT COMPREHENSION HANDLERS # ============================================================================ -@register_handler("BUILD_MAP") -def handle_build_map( +@register_handler("BUILD_MAP", version=PythonVersion.PY_310) +def handle_build_map_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: if isinstance(state.result, Placeholder) and len(state.stack) == 0: @@ -442,8 +444,37 @@ def handle_build_map( return replace(state, stack=new_stack) -@register_handler("MAP_ADD") -def handle_map_add( +# Python 3.13 version +@register_handler("BUILD_MAP", version=PythonVersion.PY_313) +def handle_build_map( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert instr.arg is not None + size: int = instr.arg + + if size == 0: + new_stack = state.stack + [ + ast.DictComp(key=Placeholder(), value=Placeholder(), generators=[]) + ] + return replace(state, stack=new_stack) + else: + assert instr.arg is not None + size: int = instr.arg + # Pop key-value pairs for the dict + keys: list[ast.expr | None] = [ + ensure_ast(state.stack[-2 * i - 2]) for i in range(size) + ] + values = [ensure_ast(state.stack[-2 * i - 1]) for i in range(size)] + new_stack = 
state.stack[: -2 * size] if size > 0 else state.stack + + # Create dict AST + dict_node = ast.Dict(keys=keys, values=values) + new_stack = new_stack + [dict_node] + return replace(state, stack=new_stack) + + +@register_handler("MAP_ADD", version=PythonVersion.PY_310) +def handle_map_add_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert isinstance(state.result, ast.DictComp), ( @@ -458,6 +489,24 @@ def handle_map_add( return replace(state, stack=new_stack, result=new_ret) +# Python 3.13 version +@register_handler("MAP_ADD", version=PythonVersion.PY_313) +def handle_map_add( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + comp: ast.DictComp = state.stack[-instr.argval - 2] + assert isinstance(comp.key, Placeholder) + assert isinstance(comp.value, Placeholder) + + new_stack = state.stack[:-2] + new_stack[-instr.argval] = ast.DictComp( + key=ensure_ast(state.stack[-2]), + value=ensure_ast(state.stack[-1]), + generators=comp.generators, + ) + return replace(state, stack=new_stack) + + # ============================================================================ # LOOP CONTROL HANDLERS # ============================================================================ diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index d0ab3b43..c441dbc7 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -400,10 +400,11 @@ def test_nested_loops(genexpr): [ # nested generators ((x for x in range(i + 1)) for i in range(5)), + ((x for j in range(i) for x in range(j)) for i in range(5)), (((x for x in range(i + j)) for j in range(i)) for i in range(5)), # nested non-generators ([x for x in range(i)] for i in range(5)), - ([x for x in range(i)] for i in range(5)), + ([x for j in range(i) for x in range(j)] for i in range(5)), ({x: x**2 for x in range(i)} for i in range(5)), ([[x for x in range(i + j)] for j in range(i)] for i in 
range(5)), # aggregation function call From 10001ea9d84f7eaad2381489b3c086c7fd5b7b8a Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 12:22:32 -0400 Subject: [PATCH 059/106] all tests pass for Python 3.13 --- effectful/internals/disassembler.py | 176 ++++++++++++++------------- tests/test_internals_disassembler.py | 47 ++++--- 2 files changed, 119 insertions(+), 104 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index b4501520..0104f499 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -175,7 +175,9 @@ def _wrapper( # post-condition: check stack effect expected_stack_effect = dis.stack_effect(instr.opcode, instr.arg) actual_stack_effect = len(new_state.stack) - len(state.stack) - if not (len(state.stack) == len(new_state.stack) == 0): + if not ( + len(state.stack) == len(new_state.stack) == 0 or instr.opname == "END_FOR" + ): assert len(state.stack) + expected_stack_effect >= 0, ( f"Handler for '{opname}' would result in negative stack size" ) @@ -285,8 +287,9 @@ def handle_build_list( if size == 0: # Check if this looks like the start of a list comprehension pattern # In nested comprehensions, BUILD_LIST(0) starts a new list comprehe - new_stack = state.stack + [ast.ListComp(elt=Placeholder(), generators=[])] - return replace(state, stack=new_stack) + new_ret = ast.ListComp(elt=Placeholder(), generators=[]) + new_stack = state.stack + [state.result] + return replace(state, stack=new_stack, result=new_ret) else: # BUILD_LIST with elements - create a regular list elements = [ensure_ast(elem) for elem in state.stack[-size:]] @@ -317,19 +320,19 @@ def handle_list_append_310( def handle_list_append( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # LIST_APPEND appends to a list comprehension - # The list comprehension might be the main result or on the stack (for nested comprehensions) + assert isinstance(state.result, ast.ListComp) + 
assert isinstance(state.result.elt, Placeholder) - comp: ast.ListComp = state.stack[-instr.argval - 1] - assert isinstance(comp.elt, Placeholder) + # add the body to the comprehension + comp: ast.ListComp = copy.deepcopy(state.result) + comp.elt = state.stack[-1] - new_elt = state.stack[-1] + # swap the return value + prev_result: CompExp = state.stack[-instr.argval - 1] new_stack = state.stack[:-1] - new_stack[-instr.argval] = ast.ListComp( - elt=new_elt, - generators=comp.generators, - ) - return replace(state, stack=new_stack) + new_stack[-instr.argval] = comp + + return replace(state, stack=new_stack, result=prev_result) # ============================================================================ @@ -371,8 +374,9 @@ def handle_build_set( size: int = instr.arg if size == 0: - new_stack = state.stack + [ast.SetComp(elt=Placeholder(), generators=[])] - return replace(state, stack=new_stack) + new_result = ast.SetComp(elt=Placeholder(), generators=[]) + new_stack = state.stack + [state.result] + return replace(state, stack=new_stack, result=new_result) else: elements = [ensure_ast(elem) for elem in state.stack[-size:]] new_stack = state.stack[:-size] @@ -401,16 +405,19 @@ def handle_set_add_310( def handle_set_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - comp: ast.SetComp = state.stack[-instr.argval - 1] - assert isinstance(comp.elt, Placeholder) + assert isinstance(state.result, ast.SetComp) + assert isinstance(state.result.elt, Placeholder) - new_elt = state.stack[-1] + # add the body to the comprehension + comp: ast.SetComp = copy.deepcopy(state.result) + comp.elt = state.stack[-1] + + # swap the return value + prev_result: CompExp = state.stack[-instr.argval - 1] new_stack = state.stack[:-1] - new_stack[-instr.argval] = ast.SetComp( - elt=new_elt, - generators=comp.generators, - ) - return replace(state, stack=new_stack) + new_stack[-instr.argval] = comp + + return replace(state, stack=new_stack, result=prev_result) # 
============================================================================ @@ -453,13 +460,10 @@ def handle_build_map( size: int = instr.arg if size == 0: - new_stack = state.stack + [ - ast.DictComp(key=Placeholder(), value=Placeholder(), generators=[]) - ] - return replace(state, stack=new_stack) + new_stack = state.stack + [state.result] + new_result = ast.DictComp(key=Placeholder(), value=Placeholder(), generators=[]) + return replace(state, stack=new_stack, result=new_result) else: - assert instr.arg is not None - size: int = instr.arg # Pop key-value pairs for the dict keys: list[ast.expr | None] = [ ensure_ast(state.stack[-2 * i - 2]) for i in range(size) @@ -494,17 +498,21 @@ def handle_map_add_310( def handle_map_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - comp: ast.DictComp = state.stack[-instr.argval - 2] - assert isinstance(comp.key, Placeholder) - assert isinstance(comp.value, Placeholder) + assert isinstance(state.result, ast.DictComp) + assert isinstance(state.result.key, Placeholder) + assert isinstance(state.result.value, Placeholder) + # add the body to the comprehension + comp: ast.DictComp = copy.deepcopy(state.result) + comp.key = state.stack[-2] + comp.value = state.stack[-1] + + # swap the return value + prev_result: CompExp = state.stack[-instr.argval - 2] new_stack = state.stack[:-2] - new_stack[-instr.argval] = ast.DictComp( - key=ensure_ast(state.stack[-2]), - value=ensure_ast(state.stack[-1]), - generators=comp.generators, - ) - return replace(state, stack=new_stack) + new_stack[-instr.argval] = comp + + return replace(state, stack=new_stack, result=prev_result) # ============================================================================ @@ -552,21 +560,12 @@ def handle_for_iter( ) # Create new loops list with the new loop info - new_ret: ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp - if isinstance(state.result, ast.DictComp): - # If it's a DictComp, we need to ensure the loop 
is added to the dict comprehension - new_ret = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators + [loop_info], - ) - else: - new_ret = type(state.result)( - elt=state.result.elt, - generators=state.result.generators + [loop_info], - ) + assert isinstance(state.result, CompExp) + new_ret = copy.deepcopy(state.result) + new_ret.generators = new_ret.generators + [loop_info] new_stack = state.stack + [loop_info.target] + assert isinstance(new_ret, CompExp) return replace(state, stack=new_stack, result=new_ret) @@ -641,7 +640,7 @@ def handle_end_for( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # END_FOR marks the end of a for loop - no action needed for AST reconstruction - new_stack = state.stack[:-1] + new_stack = state.stack # [:-1] return replace(state, stack=new_stack) @@ -757,7 +756,7 @@ def handle_load_fast_load_fast( return replace(state, stack=new_stack) -@register_handler("STORE_FAST") +@register_handler("STORE_FAST", version=PythonVersion.PY_313) def handle_store_fast( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -766,33 +765,22 @@ def handle_store_fast( ) var_name = instr.argval - # Update the most recent loop's target variable - assert len(state.result.generators) > 0, "STORE_FAST must be within a loop context" - - new_stack = state.stack[:-1] - - # Create a new LoopInfo with updated target - updated_loop = ast.comprehension( - target=ast.Name(id=var_name, ctx=ast.Store()), - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs, - is_async=state.result.generators[-1].is_async, - ) - - # Update the last loop in the generators list - if isinstance(state.result, ast.DictComp): - new_dict: ast.DictComp = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_dict) + if not state.stack or ( + 
isinstance(state.stack[-1], ast.Name) and state.stack[-1].id == var_name + ): + # If the variable is already on the stack, we can skip adding it again + # This is common in nested comprehensions where the same variable is reused + return replace(state, stack=state.stack[:-1]) else: - new_comp: ast.GeneratorExp | ast.ListComp | ast.SetComp = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], + # Update the most recent loop's target variable + assert len(state.result.generators) > 0, ( + "STORE_FAST must be within a loop context" ) - return replace(state, stack=new_stack, result=new_comp) + + new_stack = state.stack[:-1] + new_result: CompExp = copy.deepcopy(state.result) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) + return replace(state, stack=new_stack, result=new_result) @register_handler("STORE_FAST_LOAD_FAST", version=PythonVersion.PY_313) @@ -810,11 +798,8 @@ def handle_store_fast_load_fast( # In Python 3.13, the instruction argument contains both names # argval should be a tuple (store_name, load_name) - if isinstance(instr.argval, tuple): - store_name, load_name = instr.argval - else: - # Fallback: assume both names are the same - store_name = load_name = instr.argval + assert isinstance(instr.argval, tuple) + store_name, load_name = instr.argval # Update the most recent loop's target variable assert len(state.result.generators) > 0, ( @@ -1523,8 +1508,8 @@ def handle_list_to_tuple( return replace(state, stack=new_stack) -@register_handler("LIST_EXTEND") -def handle_list_extend( +@register_handler("LIST_EXTEND", version=PythonVersion.PY_310) +def handle_list_extend_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # LIST_EXTEND extends the list at TOS-1 with the iterable at TOS @@ -1552,6 +1537,27 @@ def handle_list_extend( return replace(state, stack=new_stack) +@register_handler("LIST_EXTEND", version=PythonVersion.PY_313) +def 
handle_list_extend( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LIST_EXTEND extends the list at TOS-1 with the iterable at TOS + # initially recognized as list comp + + # The list being extended is actually in state.result instead of the stack + # because it was initially recognized as a list comprehension in BUILD_LIST, + # while the actual result expression is in the stack where the list "should be" + # and needs to be put back into the state result slot + assert isinstance(state.result, ast.ListComp) and not state.result.generators + assert isinstance(state.stack[-1], ast.Tuple | ast.List) + prev_result = state.stack[-instr.argval - 1] + + list_obj = ast.List(elts=[ensure_ast(e) for e in state.stack[-1].elts]) + new_stack = state.stack[:-2] + [list_obj] + + return replace(state, stack=new_stack, result=prev_result) + + @register_handler("BUILD_CONST_KEY_MAP") def handle_build_const_key_map( state: ReconstructionState, instr: dis.Instruction diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index c441dbc7..301aab87 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -334,7 +334,8 @@ def test_filtered_generators(genexpr): ((x, y) for x in range(5) for y in range(5) if x < y), (x + y for x in range(5) if x % 2 == 0 for y in range(5) if y % 2 == 1), # Triple nested - ((x, y, z) for x in range(2) for y in range(2) for z in range(2)), + (x + y + z for x in range(2) for y in range(3) for z in range(4)), + ((x, y, z) for x in range(2) for y in range(3) for z in range(4)), # More complex nested loop edge cases # Different sized ranges ((x, y) for x in range(2) for y in range(5)), @@ -402,29 +403,21 @@ def test_nested_loops(genexpr): ((x for x in range(i + 1)) for i in range(5)), ((x for j in range(i) for x in range(j)) for i in range(5)), (((x for x in range(i + j)) for j in range(i)) for i in range(5)), - # nested non-generators - ([x for x in 
range(i)] for i in range(5)), - ([x for j in range(i) for x in range(j)] for i in range(5)), - ({x: x**2 for x in range(i)} for i in range(5)), - ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), + # nested generators with filters + ((x for x in range(i)) for i in range(5) if i > 0), + ((x for x in range(i) if x < i) for i in range(5) if i > 0), + (((x for x in range(i + j) if x < i + j) for j in range(i)) for i in range(5)), # aggregation function call (sum(x for x in range(i + 1)) for i in range(3)), (max(x for x in range(i + 1)) for i in range(3)), + (dict((x, x + 1) for x in range(i + 1)) for i in range(3)), + (set(x for x in range(i + 1)) for i in range(3)), # map (list(map(abs, (x + 1 for x in range(i + 1)))) for i in range(3)), (list(enumerate(x + 1 for x in range(i + 1))) for i in range(3)), - # Nested comprehensions with filters inside - ([x for x in range(i)] for i in range(5) if i > 0), - ([x for x in range(i) if x < i] for i in range(5) if i > 0), - ([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5)), - ( - [[x for x in range(i + j) if x < i + j] for j in range(i)] - for i in range(5) - if i > 0 - ), # nesting on both sides - ([y for y in range(x)] for x in (x_ + 1 for x_ in range(5))), - ([y for y in range(x)] for x in (x_ + 1 for x_ in range(5))), + ((y for y in range(x)) for x in (x_ + 1 for x_ in range(5))), + ((y for y in range(x)) for x in (x_ + 1 for x_ in range(5))), ], ) def test_nested_comprehensions(genexpr): @@ -446,9 +439,23 @@ def test_nested_comprehensions(genexpr): (x_ for x_ in {x for x in range(5)}), (x_ for x_ in {x: x**2 for x in range(5)}), # Comprehensions as yield expressions - ([y for y in range(x + 1)] for x in range(3)), - ({y for y in range(x + 1)} for x in range(3)), + ([y * 2 for y in range(x + 1)] for x in range(3)), + ({y + 3 for y in range(x + 1)} for x in range(3)), ({y: y**2 for y in range(x + 1)} for x in range(3)), + # nested non-generators + ([x for x in range(i)] for i in 
range(5)), + ([x for j in range(i) for x in range(j)] for i in range(5)), + ({x: x**2 for x in range(i)} for i in range(5)), + ([[x for x in range(i + j)] for j in range(i)] for i in range(5)), + # Nested comprehensions with filters inside + ([x for x in range(i)] for i in range(5) if i > 0), + ([x for x in range(i) if x < i] for i in range(5) if i > 0), + ([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5)), + ( + [[x for x in range(i + j) if x < i + j] for j in range(i)] + for i in range(5) + if i > 0 + ), ], ) def test_different_comprehension_types(genexpr): @@ -522,6 +529,7 @@ def test_variable_lookup(genexpr, globals_dict): ((x.bit_length() for x in range(1, 10)), {}), ((str(x).zfill(3) for x in range(10)), {"str": str}), # Subscript operations + (((10, 20, 30)[i] for i in range(3)), {}), (([10, 20, 30][i] for i in range(3)), {}), (({"a": 1, "b": 2, "c": 3}[k] for k in ["a", "b", "c"]), {}), (("hello"[i] for i in range(5)), {}), @@ -538,6 +546,7 @@ def test_variable_lookup(genexpr, globals_dict): ), ((s.upper().lower() for s in ["Hello", "World"]), {}), # Edge cases with complex data structures + (((1, 2, 3)[x % 3] for x in range(10)), {}), (([1, 2, 3][x % 3] for x in range(10)), {}), # (({"even": x, "odd": x + 1}["even" if x % 2 == 0 else "odd"] for x in range(5)), {}), # Function calls with multiple arguments From a1cd1dc420bb1244b545d691c87f67aede156b7e Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 13:13:09 -0400 Subject: [PATCH 060/106] all tests pass on both versions --- effectful/internals/disassembler.py | 174 ++++++++++++++++++--------- tests/test_internals_disassembler.py | 1 + 2 files changed, 115 insertions(+), 60 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 0104f499..9b2c1425 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -223,6 +223,27 @@ def handle_return_generator( return replace(state, 
result=new_result, stack=new_stack) +@register_handler("YIELD_VALUE", version=PythonVersion.PY_310) +def handle_yield_value_310( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # YIELD_VALUE pops a value from the stack and yields it + # This is the expression part of the generator + assert isinstance(state.result, ast.GeneratorExp), ( + "YIELD_VALUE must be called after RETURN_GENERATOR" + ) + assert isinstance(state.result.elt, Placeholder), ( + "YIELD_VALUE must be called before yielding" + ) + assert len(state.result.generators) > 0, "YIELD_VALUE should have generators" + + ret = ast.GeneratorExp( + elt=ensure_ast(state.stack[-1]), + generators=state.result.generators, + ) + return replace(state, result=ret) + + @register_handler("YIELD_VALUE", version=PythonVersion.PY_313) def handle_yield_value( state: ReconstructionState, instr: dis.Instruction @@ -650,15 +671,11 @@ def handle_return_const( ) -> ReconstructionState: # RETURN_CONST returns a constant value (replaces some LOAD_CONST + RETURN_VALUE patterns) # Similar to RETURN_VALUE but with a constant - if isinstance(state.result, CompExp): - return state - elif isinstance(state.result, Placeholder) and len(state.stack) == 1: - new_result = ensure_ast(state.stack[-1]) - assert isinstance(new_result, CompExp | ast.Lambda) - return replace(state, stack=state.stack[:-1], result=new_result) - else: + if isinstance(state.result, ast.GeneratorExp): # For generators, this typically ends the generator with None return state + else: + raise TypeError("Unexpected RETURN_CONST in reconstruction") @register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_313) @@ -666,18 +683,14 @@ def handle_call_intrinsic_1( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # CALL_INTRINSIC_1 calls an intrinsic function with one argument - # For generator expressions, this is often used for exception handling - # We can generally ignore this for AST reconstruction - return 
state - - -@register_handler("CALL_INTRINSIC_2", version=PythonVersion.PY_313) -def handle_call_intrinsic_2( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # CALL_INTRINSIC_2 calls an intrinsic function with two arguments - # We can generally ignore this for AST reconstruction - return state + if instr.argrepr == "INTRINSIC_STOPITERATION_ERROR": + return state + elif instr.argrepr == "INTRINSIC_UNARY_POSITIVE": + assert len(state.stack) > 0 + new_val = ast.UnaryOp(op=ast.UAdd(), operand=state.stack[-1]) + return replace(state, stack=state.stack[:-1] + [new_val]) + else: + raise TypeError(f"Unsupported generator intrinsic operation: {instr.argrepr}") @register_handler("RERAISE", version=PythonVersion.PY_313) @@ -685,6 +698,7 @@ def handle_reraise( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # RERAISE re-raises an exception - generally ignore for AST reconstruction + assert not state.stack # in generator expressions, we shouldn't have a stack here return state @@ -756,6 +770,22 @@ def handle_load_fast_load_fast( return replace(state, stack=new_stack) +@register_handler("STORE_FAST", version=PythonVersion.PY_310) +def handle_store_fast_310( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + assert isinstance(state.result, CompExp), ( + "STORE_FAST must be called within a comprehension context" + ) + var_name = instr.argval + assert len(state.result.generators) > 0, "STORE_FAST must be within a loop context" + + new_stack = state.stack[:-1] + new_result: CompExp = copy.deepcopy(state.result) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) + return replace(state, stack=new_stack, result=new_result) + + @register_handler("STORE_FAST", version=PythonVersion.PY_313) def handle_store_fast( state: ReconstructionState, instr: dis.Instruction @@ -888,8 +918,8 @@ def handle_copy_free_vars( return state -@register_handler("STORE_DEREF") -def 
handle_store_deref( +@register_handler("STORE_DEREF", version=PythonVersion.PY_310) +def handle_store_deref_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # STORE_DEREF stores a value into a closure variable @@ -901,28 +931,38 @@ def handle_store_deref( # Update the most recent loop's target variable assert len(state.result.generators) > 0, "STORE_DEREF must be within a loop context" - # Create a new LoopInfo with updated target - updated_loop = ast.comprehension( - target=ast.Name(id=var_name, ctx=ast.Store()), - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs, - is_async=state.result.generators[-1].is_async, + new_stack = state.stack[:-1] + new_result: CompExp = copy.deepcopy(state.result) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) + return replace(state, stack=new_stack, result=new_result) + + +@register_handler("STORE_DEREF", version=PythonVersion.PY_313) +def handle_store_deref( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # STORE_DEREF stores a value into a closure variable + assert isinstance(state.result, CompExp), ( + "STORE_DEREF must be called within a comprehension context" ) + var_name = instr.argval - # Update the last loop in the generators list - if isinstance(state.result, ast.DictComp): - new_dict: ast.DictComp = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=state.stack[:-1], result=new_dict) + if not state.stack or ( + isinstance(state.stack[-1], ast.Name) and state.stack[-1].id == var_name + ): + # If the variable is already on the stack, we can skip adding it again + # This is common in nested comprehensions where the same variable is reused + return replace(state, stack=state.stack[:-1]) else: - new_comp: ast.GeneratorExp | ast.ListComp | ast.SetComp = type(state.result)( - elt=state.result.elt, - 
generators=state.result.generators[:-1] + [updated_loop], + # Update the most recent loop's target variable + assert len(state.result.generators) > 0, ( + "STORE_DEREF must be within a loop context" ) - return replace(state, stack=state.stack[:-1], result=new_comp) + + new_stack = state.stack[:-1] + new_result: CompExp = copy.deepcopy(state.result) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) + return replace(state, stack=new_stack, result=new_result) @register_handler("LOAD_DEREF") @@ -1321,25 +1361,6 @@ def handle_call_function( return replace(state, stack=new_stack) -@register_handler("LOAD_METHOD", version=PythonVersion.PY_310) -def handle_load_method( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LOAD_METHOD loads a method from an object - # It pushes the bound method and the object (for the method call) - obj = ensure_ast(state.stack[-1]) - method_name = instr.argval - new_stack = state.stack[:-1] - - # Create method access as an attribute - method_attr = ast.Attribute(value=obj, attr=method_name, ctx=ast.Load()) - - # For LOAD_METHOD, we push both the method and the object - # But for AST purposes, we just need the method attribute - new_stack = new_stack + [method_attr] - return replace(state, stack=new_stack) - - @register_handler("CALL_METHOD", version=PythonVersion.PY_310) def handle_call_method( state: ReconstructionState, instr: dis.Instruction @@ -1351,8 +1372,8 @@ def handle_call_method( args = ( [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] ) - method = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[: -arg_count - 1] + method = ensure_ast(state.stack[-arg_count - 2]) + new_stack = state.stack[: -arg_count - 2] # Create method call AST call_node = ast.Call(func=method, args=args, keywords=[]) @@ -1437,6 +1458,39 @@ def handle_set_function_attribute( # ============================================================================ 
+@register_handler("LOAD_METHOD", version=PythonVersion.PY_310) +def handle_load_method( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_METHOD loads a method from an object + # It pushes the bound method and the object (for the method call) + obj = ensure_ast(state.stack[-1]) + method_name = instr.argval + new_stack = state.stack[:-1] + + # Create method access as an attribute + method_attr = ast.Attribute(value=obj, attr=method_name, ctx=ast.Load()) + + # For LOAD_METHOD, we push both the method and the object + # But for AST purposes, we just need the method attribute + new_stack = new_stack + [method_attr, obj] + return replace(state, stack=new_stack) + + +@register_handler("LOAD_ATTR", version=PythonVersion.PY_310) +def handle_load_attr_310( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_ATTR loads an attribute from the object on top of stack + obj = ensure_ast(state.stack[-1]) + attr_name = instr.argval + + # Create attribute access AST + attr_node = ast.Attribute(value=obj, attr=attr_name, ctx=ast.Load()) + new_stack = state.stack[:-1] + [attr_node] + return replace(state, stack=new_stack) + + @register_handler("LOAD_ATTR", version=PythonVersion.PY_313) def handle_load_attr( state: ReconstructionState, instr: dis.Instruction diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 301aab87..e7904c85 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -695,6 +695,7 @@ def test_comp_lambda_copy(): ast.comprehension( target=ast.Name(id="x", ctx=ast.Store()), iter=DummyIterName(), + ifs=[], is_async=0, ) ], From 4fef01b3fbd26ab16155bacfd82af2b96e59fb34 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 13:44:08 -0400 Subject: [PATCH 061/106] more consolidation --- effectful/internals/disassembler.py | 401 +++++++--------------------- 1 file changed, 100 insertions(+), 301 deletions(-) diff 
--git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 9b2c1425..f306fb4b 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -223,28 +223,7 @@ def handle_return_generator( return replace(state, result=new_result, stack=new_stack) -@register_handler("YIELD_VALUE", version=PythonVersion.PY_310) -def handle_yield_value_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # YIELD_VALUE pops a value from the stack and yields it - # This is the expression part of the generator - assert isinstance(state.result, ast.GeneratorExp), ( - "YIELD_VALUE must be called after RETURN_GENERATOR" - ) - assert isinstance(state.result.elt, Placeholder), ( - "YIELD_VALUE must be called before yielding" - ) - assert len(state.result.generators) > 0, "YIELD_VALUE should have generators" - - ret = ast.GeneratorExp( - elt=ensure_ast(state.stack[-1]), - generators=state.result.generators, - ) - return replace(state, result=ret) - - -@register_handler("YIELD_VALUE", version=PythonVersion.PY_313) +@register_handler("YIELD_VALUE") def handle_yield_value( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -258,12 +237,11 @@ def handle_yield_value( ) assert len(state.result.generators) > 0, "YIELD_VALUE should have generators" - new_stack = state.stack # [:-1] ret = ast.GeneratorExp( elt=ensure_ast(state.stack[-1]), generators=state.result.generators, ) - return replace(state, stack=new_stack, result=ret) + return replace(state, result=ret) # ============================================================================ @@ -271,34 +249,7 @@ def handle_yield_value( # ============================================================================ -# Python 3.10 version -@register_handler("BUILD_LIST", version=PythonVersion.PY_310) -def handle_build_list_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - if isinstance(state.result, 
Placeholder) and len(state.stack) == 0: - # This BUILD_LIST is the start of a list comprehension - # Initialize the result as a ListComp with a placeholder element - ret = ast.ListComp(elt=Placeholder(), generators=[]) - new_stack = state.stack + [ret] - return replace(state, stack=new_stack, result=ret) - else: - assert instr.arg is not None - size: int = instr.arg - # Pop elements for the list - elements = ( - [ensure_ast(elem) for elem in state.stack[-size:]] if size > 0 else [] - ) - new_stack = state.stack[:-size] if size > 0 else state.stack - - # Create list AST - elt_node = ast.List(elts=elements, ctx=ast.Load()) - new_stack = new_stack + [elt_node] - return replace(state, stack=new_stack) - - -# Python 3.13 version -@register_handler("BUILD_LIST", version=PythonVersion.PY_313) +@register_handler("BUILD_LIST") def handle_build_list( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -361,33 +312,7 @@ def handle_list_append( # ============================================================================ -@register_handler("BUILD_SET", version=PythonVersion.PY_310) -def handle_build_set_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - if isinstance(state.result, Placeholder) and len(state.stack) == 0: - # This BUILD_SET is the start of a list comprehension - # Initialize the result as a ListComp with a placeholder element - ret = ast.SetComp(elt=Placeholder(), generators=[]) - new_stack = state.stack + [ret] - return replace(state, stack=new_stack, result=ret) - else: - assert instr.arg is not None - size: int = instr.arg - # Pop elements for the set - elements = ( - [ensure_ast(elem) for elem in state.stack[-size:]] if size > 0 else [] - ) - new_stack = state.stack[:-size] if size > 0 else state.stack - - # Create set AST - elt_node = ast.Set(elts=elements) - new_stack = new_stack + [elt_node] - return replace(state, stack=new_stack) - - -# Python 3.13 version -@register_handler("BUILD_SET", 
version=PythonVersion.PY_313) +@register_handler("BUILD_SET") def handle_build_set( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -446,34 +371,7 @@ def handle_set_add( # ============================================================================ -@register_handler("BUILD_MAP", version=PythonVersion.PY_310) -def handle_build_map_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - if isinstance(state.result, Placeholder) and len(state.stack) == 0: - # This is the start of a comprehension - # Initialize the result with a placeholder element - ret = ast.DictComp(key=Placeholder(), value=Placeholder(), generators=[]) - new_stack = state.stack + [ret] - return replace(state, stack=new_stack, result=ret) - else: - assert instr.arg is not None - size: int = instr.arg - # Pop key-value pairs for the dict - keys: list[ast.expr | None] = [ - ensure_ast(state.stack[-2 * i - 2]) for i in range(size) - ] - values = [ensure_ast(state.stack[-2 * i - 1]) for i in range(size)] - new_stack = state.stack[: -2 * size] if size > 0 else state.stack - - # Create dict AST - dict_node = ast.Dict(keys=keys, values=values) - new_stack = new_stack + [dict_node] - return replace(state, stack=new_stack) - - -# Python 3.13 version -@register_handler("BUILD_MAP", version=PythonVersion.PY_313) +@register_handler("BUILD_MAP") def handle_build_map( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -678,21 +576,6 @@ def handle_return_const( raise TypeError("Unexpected RETURN_CONST in reconstruction") -@register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_313) -def handle_call_intrinsic_1( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # CALL_INTRINSIC_1 calls an intrinsic function with one argument - if instr.argrepr == "INTRINSIC_STOPITERATION_ERROR": - return state - elif instr.argrepr == "INTRINSIC_UNARY_POSITIVE": - assert len(state.stack) > 0 - new_val = 
ast.UnaryOp(op=ast.UAdd(), operand=state.stack[-1]) - return replace(state, stack=state.stack[:-1] + [new_val]) - else: - raise TypeError(f"Unsupported generator intrinsic operation: {instr.argrepr}") - - @register_handler("RERAISE", version=PythonVersion.PY_313) def handle_reraise( state: ReconstructionState, instr: dis.Instruction @@ -723,6 +606,58 @@ def handle_load_fast( return replace(state, stack=new_stack) +@register_handler("LOAD_DEREF") +def handle_load_deref( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_DEREF loads a value from a closure variable + var_name = instr.argval + new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] + return replace(state, stack=new_stack) + + +@register_handler("LOAD_CLOSURE") +def handle_load_closure( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_CLOSURE loads a closure variable + var_name = instr.argval + new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] + return replace(state, stack=new_stack) + + +@register_handler("LOAD_CONST") +def handle_load_const( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + const_value = instr.argval + new_stack = state.stack + [ensure_ast(const_value)] + return replace(state, stack=new_stack) + + +@register_handler("LOAD_GLOBAL") +def handle_load_global( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + global_name = instr.argval + + if instr.argrepr.endswith(" + NULL"): + new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load()), Null()] + else: + new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load())] + return replace(state, stack=new_stack) + + +@register_handler("LOAD_NAME") +def handle_load_name( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_NAME is similar to LOAD_GLOBAL but for names in the global namespace + name = instr.argval + new_stack = 
state.stack + [ast.Name(id=name, ctx=ast.Load())] + return replace(state, stack=new_stack) + + @register_handler("LOAD_FAST_AND_CLEAR", version=PythonVersion.PY_313) def handle_load_fast_and_clear( state: ReconstructionState, instr: dis.Instruction @@ -770,15 +705,21 @@ def handle_load_fast_load_fast( return replace(state, stack=new_stack) -@register_handler("STORE_FAST", version=PythonVersion.PY_310) -def handle_store_fast_310( +@register_handler("STORE_FAST") +def handle_store_fast( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance(state.result, CompExp), ( + assert isinstance(state.result, CompExp) and state.result.generators, ( "STORE_FAST must be called within a comprehension context" ) var_name = instr.argval - assert len(state.result.generators) > 0, "STORE_FAST must be within a loop context" + + if not state.stack or ( + isinstance(state.stack[-1], ast.Name) and state.stack[-1].id == var_name + ): + # If the variable is already on the stack, we can skip adding it again + # This is common in nested comprehensions where the same variable is reused + return replace(state, stack=state.stack[:-1]) new_stack = state.stack[:-1] new_result: CompExp = copy.deepcopy(state.result) @@ -786,12 +727,13 @@ def handle_store_fast_310( return replace(state, stack=new_stack, result=new_result) -@register_handler("STORE_FAST", version=PythonVersion.PY_313) -def handle_store_fast( +@register_handler("STORE_DEREF") +def handle_store_deref( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance(state.result, CompExp), ( - "STORE_FAST must be called within a comprehension context" + # STORE_DEREF stores a value into a closure variable + assert isinstance(state.result, CompExp) and state.result.generators, ( + "STORE_DEREF must be called within a comprehension context" ) var_name = instr.argval @@ -801,16 +743,11 @@ def handle_store_fast( # If the variable is already on the stack, we can skip 
adding it again # This is common in nested comprehensions where the same variable is reused return replace(state, stack=state.stack[:-1]) - else: - # Update the most recent loop's target variable - assert len(state.result.generators) > 0, ( - "STORE_FAST must be within a loop context" - ) - new_stack = state.stack[:-1] - new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) - return replace(state, stack=new_stack, result=new_result) + new_stack = state.stack[:-1] + new_result: CompExp = copy.deepcopy(state.result) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) + return replace(state, stack=new_stack, result=new_result) @register_handler("STORE_FAST_LOAD_FAST", version=PythonVersion.PY_313) @@ -822,7 +759,7 @@ def handle_store_fast_load_fast( # In Python 3.13, this is often used for loop variables # First handle the store part - assert isinstance(state.result, CompExp), ( + assert isinstance(state.result, CompExp) and state.result.generators, ( "STORE_FAST_LOAD_FAST must be called within a comprehension context" ) @@ -831,69 +768,10 @@ def handle_store_fast_load_fast( assert isinstance(instr.argval, tuple) store_name, load_name = instr.argval - # Update the most recent loop's target variable - assert len(state.result.generators) > 0, ( - "STORE_FAST_LOAD_FAST must be within a loop context" - ) - - # Create a new LoopInfo with updated target - updated_loop = ast.comprehension( - target=ast.Name(id=store_name, ctx=ast.Store()), - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs, - is_async=state.result.generators[-1].is_async, - ) - - # Update the last loop in the generators list - if isinstance(state.result, ast.DictComp): - new_dict: ast.DictComp = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - new_state = replace(state, result=new_dict) - else: - 
new_comp: ast.GeneratorExp | ast.ListComp | ast.SetComp = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - new_state = replace(state, result=new_comp) - - # Now handle the load part - push the variable onto the stack - new_stack = new_state.stack[:-1] + [ast.Name(id=load_name, ctx=ast.Load())] - return replace(new_state, stack=new_stack) - - -@register_handler("LOAD_CONST") -def handle_load_const( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - const_value = instr.argval - new_stack = state.stack + [ensure_ast(const_value)] - return replace(state, stack=new_stack) - - -@register_handler("LOAD_GLOBAL") -def handle_load_global( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - global_name = instr.argval - - if instr.argrepr.endswith(" + NULL"): - new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load()), Null()] - else: - new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load())] - return replace(state, stack=new_stack) - - -@register_handler("LOAD_NAME") -def handle_load_name( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LOAD_NAME is similar to LOAD_GLOBAL but for names in the global namespace - name = instr.argval - new_stack = state.stack + [ast.Name(id=name, ctx=ast.Load())] - return replace(state, stack=new_stack) + new_stack = state.stack[:-1] + [ast.Name(id=load_name, ctx=ast.Load())] + new_result: CompExp = copy.deepcopy(state.result) + new_result.generators[-1].target = ast.Name(id=store_name, ctx=ast.Store()) + return replace(state, stack=new_stack, result=new_result) @register_handler("MAKE_CELL", version=PythonVersion.PY_313) @@ -918,73 +796,6 @@ def handle_copy_free_vars( return state -@register_handler("STORE_DEREF", version=PythonVersion.PY_310) -def handle_store_deref_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # STORE_DEREF 
stores a value into a closure variable - assert isinstance(state.result, CompExp), ( - "STORE_DEREF must be called within a comprehension context" - ) - var_name = instr.argval - - # Update the most recent loop's target variable - assert len(state.result.generators) > 0, "STORE_DEREF must be within a loop context" - - new_stack = state.stack[:-1] - new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) - return replace(state, stack=new_stack, result=new_result) - - -@register_handler("STORE_DEREF", version=PythonVersion.PY_313) -def handle_store_deref( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # STORE_DEREF stores a value into a closure variable - assert isinstance(state.result, CompExp), ( - "STORE_DEREF must be called within a comprehension context" - ) - var_name = instr.argval - - if not state.stack or ( - isinstance(state.stack[-1], ast.Name) and state.stack[-1].id == var_name - ): - # If the variable is already on the stack, we can skip adding it again - # This is common in nested comprehensions where the same variable is reused - return replace(state, stack=state.stack[:-1]) - else: - # Update the most recent loop's target variable - assert len(state.result.generators) > 0, ( - "STORE_DEREF must be within a loop context" - ) - - new_stack = state.stack[:-1] - new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) - return replace(state, stack=new_stack, result=new_result) - - -@register_handler("LOAD_DEREF") -def handle_load_deref( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LOAD_DEREF loads a value from a closure variable - var_name = instr.argval - new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] - return replace(state, stack=new_stack) - - -@register_handler("LOAD_CLOSURE") -def handle_load_closure( - state: 
ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LOAD_CLOSURE loads a closure variable - var_name = instr.argval - new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] - return replace(state, stack=new_stack) - - # ============================================================================ # STACK MANAGEMENT HANDLERS # ============================================================================ @@ -1242,6 +1053,21 @@ def handle_unary_op( ) +@register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_313) +def handle_call_intrinsic_1( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # CALL_INTRINSIC_1 calls an intrinsic function with one argument + if instr.argrepr == "INTRINSIC_STOPITERATION_ERROR": + return state + elif instr.argrepr == "INTRINSIC_UNARY_POSITIVE": + assert len(state.stack) > 0 + new_val = ast.UnaryOp(op=ast.UAdd(), operand=state.stack[-1]) + return replace(state, stack=state.stack[:-1] + [new_val]) + else: + raise TypeError(f"Unsupported generator intrinsic operation: {instr.argrepr}") + + # ============================================================================ # COMPARISON OPERATION HANDLERS # ============================================================================ @@ -1562,36 +1388,7 @@ def handle_list_to_tuple( return replace(state, stack=new_stack) -@register_handler("LIST_EXTEND", version=PythonVersion.PY_310) -def handle_list_extend_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LIST_EXTEND extends the list at TOS-1 with the iterable at TOS - iterable = ensure_ast(state.stack[-1]) - list_obj = state.stack[-2] # This should be a list from BUILD_LIST - new_stack = state.stack[:-2] - - # If the list is empty and we're extending with a tuple/iterable, - # we can convert this to a simple list of the iterable's elements - if isinstance(list_obj, ast.List) and len(list_obj.elts) == 0: - # If extending with a constant 
tuple, expand it to list elements - if isinstance(iterable, ast.Constant) and isinstance(iterable.value, tuple): - elements: list[ast.expr] = [ - ast.Constant(value=elem) for elem in iterable.value - ] - list_node = ast.List(elts=elements, ctx=ast.Load()) - new_stack = new_stack + [list_node] - return replace(state, stack=new_stack) - - # Fallback: create a list from the iterable using list() constructor - list_call = ast.Call( - func=ast.Name(id="list", ctx=ast.Load()), args=[iterable], keywords=[] - ) - new_stack = new_stack + [list_call] - return replace(state, stack=new_stack) - - -@register_handler("LIST_EXTEND", version=PythonVersion.PY_313) +@register_handler("LIST_EXTEND") def handle_list_extend( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1606,7 +1403,9 @@ def handle_list_extend( assert isinstance(state.stack[-1], ast.Tuple | ast.List) prev_result = state.stack[-instr.argval - 1] - list_obj = ast.List(elts=[ensure_ast(e) for e in state.stack[-1].elts]) + list_obj = ast.List( + elts=[ensure_ast(e) for e in state.stack[-1].elts], ctx=ast.Load() + ) new_stack = state.stack[:-2] + [list_obj] return replace(state, stack=new_stack, result=prev_result) From 7999916103530ec679d974ae88d0f2d257a43c70 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 13:54:25 -0400 Subject: [PATCH 062/106] more consolidation of comprehension bodies --- effectful/internals/disassembler.py | 156 ++++++++++------------------ 1 file changed, 53 insertions(+), 103 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index f306fb4b..1ea825fc 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -271,24 +271,7 @@ def handle_build_list( return replace(state, stack=new_stack) -# Python 3.10 version -@register_handler("LIST_APPEND", version=PythonVersion.PY_310) -def handle_list_append_310( - state: ReconstructionState, instr: dis.Instruction -) -> 
ReconstructionState: - assert isinstance(state.result, ast.ListComp), ( - "LIST_APPEND must be called within a ListComp context" - ) - new_stack = state.stack[:-1] - new_ret = ast.ListComp( - elt=ensure_ast(state.stack[-1]), - generators=state.result.generators, - ) - return replace(state, stack=new_stack, result=new_ret) - - -# Python 3.13 version -@register_handler("LIST_APPEND", version=PythonVersion.PY_313) +@register_handler("LIST_APPEND") def handle_list_append( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -331,23 +314,7 @@ def handle_build_set( return replace(state, stack=new_stack) -@register_handler("SET_ADD", version=PythonVersion.PY_310) -def handle_set_add_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - assert isinstance(state.result, ast.SetComp), ( - "SET_ADD must be called after BUILD_SET" - ) - new_stack = state.stack[:-1] - new_ret = ast.SetComp( - elt=ensure_ast(state.stack[-1]), - generators=state.result.generators, - ) - return replace(state, stack=new_stack, result=new_ret) - - -# Python 3.13 version -@register_handler("SET_ADD", version=PythonVersion.PY_313) +@register_handler("SET_ADD") def handle_set_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -396,24 +363,7 @@ def handle_build_map( return replace(state, stack=new_stack) -@register_handler("MAP_ADD", version=PythonVersion.PY_310) -def handle_map_add_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - assert isinstance(state.result, ast.DictComp), ( - "MAP_ADD must be called after BUILD_MAP" - ) - new_stack = state.stack[:-2] - new_ret = ast.DictComp( - key=ensure_ast(state.stack[-2]), - value=ensure_ast(state.stack[-1]), - generators=state.result.generators, - ) - return replace(state, stack=new_stack, result=new_ret) - - -# Python 3.13 version -@register_handler("MAP_ADD", version=PythonVersion.PY_313) +@register_handler("MAP_ADD") def 
handle_map_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -447,10 +397,10 @@ def handle_return_value( # Usually preceded by LOAD_CONST None if isinstance(state.result, CompExp): return replace(state, stack=state.stack[:-1]) - elif isinstance(state.result, Placeholder) and len(state.stack) == 1: - new_result = ensure_ast(state.stack[-1]) + elif isinstance(state.result, Placeholder): + new_result = ensure_ast(state.stack[0]) assert isinstance(new_result, CompExp | ast.Lambda) - return replace(state, stack=state.stack[:-1], result=new_result) + return replace(state, stack=state.stack[1:], result=new_result) else: raise TypeError("Unexpected RETURN_VALUE in reconstruction") @@ -658,53 +608,6 @@ def handle_load_name( return replace(state, stack=new_stack) -@register_handler("LOAD_FAST_AND_CLEAR", version=PythonVersion.PY_313) -def handle_load_fast_and_clear( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LOAD_FAST_AND_CLEAR pushes a local variable onto the stack and clears it - # For AST reconstruction, we treat this the same as LOAD_FAST - var_name: str = instr.argval - - if var_name == ".0": - # Special handling for .0 variable (the iterator) - new_stack = state.stack + [DummyIterName()] - else: - # Regular variable load - new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] - - return replace(state, stack=new_stack) - - -@register_handler("LOAD_FAST_LOAD_FAST", version=PythonVersion.PY_313) -def handle_load_fast_load_fast( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LOAD_FAST_LOAD_FAST loads two variables (optimization in Python 3.13) - # The instruction argument contains both variable names - if isinstance(instr.argval, tuple): - var1, var2 = instr.argval - else: - # Fallback: assume both names are the same - var1 = var2 = instr.argval - - new_stack = state.stack - - # Load first variable - if var1 == ".0": - new_stack = new_stack + 
[DummyIterName()] - else: - new_stack = new_stack + [ast.Name(id=var1, ctx=ast.Load())] - - # Load second variable - if var2 == ".0": - new_stack = new_stack + [DummyIterName()] - else: - new_stack = new_stack + [ast.Name(id=var2, ctx=ast.Load())] - - return replace(state, stack=new_stack) - - @register_handler("STORE_FAST") def handle_store_fast( state: ReconstructionState, instr: dis.Instruction @@ -774,6 +677,53 @@ def handle_store_fast_load_fast( return replace(state, stack=new_stack, result=new_result) +@register_handler("LOAD_FAST_AND_CLEAR", version=PythonVersion.PY_313) +def handle_load_fast_and_clear( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_FAST_AND_CLEAR pushes a local variable onto the stack and clears it + # For AST reconstruction, we treat this the same as LOAD_FAST + var_name: str = instr.argval + + if var_name == ".0": + # Special handling for .0 variable (the iterator) + new_stack = state.stack + [DummyIterName()] + else: + # Regular variable load + new_stack = state.stack + [ast.Name(id=var_name, ctx=ast.Load())] + + return replace(state, stack=new_stack) + + +@register_handler("LOAD_FAST_LOAD_FAST", version=PythonVersion.PY_313) +def handle_load_fast_load_fast( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LOAD_FAST_LOAD_FAST loads two variables (optimization in Python 3.13) + # The instruction argument contains both variable names + if isinstance(instr.argval, tuple): + var1, var2 = instr.argval + else: + # Fallback: assume both names are the same + var1 = var2 = instr.argval + + new_stack = state.stack + + # Load first variable + if var1 == ".0": + new_stack = new_stack + [DummyIterName()] + else: + new_stack = new_stack + [ast.Name(id=var1, ctx=ast.Load())] + + # Load second variable + if var2 == ".0": + new_stack = new_stack + [DummyIterName()] + else: + new_stack = new_stack + [ast.Name(id=var2, ctx=ast.Load())] + + return replace(state, stack=new_stack) + 
+ @register_handler("MAKE_CELL", version=PythonVersion.PY_313) def handle_make_cell( state: ReconstructionState, instr: dis.Instruction From 2e59776ea26e39d7d742537bf89c90b046b45818 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 13:57:38 -0400 Subject: [PATCH 063/106] shuffle --- effectful/internals/disassembler.py | 58 ++++++++++++++--------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 1ea825fc..3fc0f9ce 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -448,12 +448,12 @@ def handle_get_iter( return state -@register_handler("JUMP_BACKWARD", version=PythonVersion.PY_313) -def handle_jump_backward( +@register_handler("JUMP_FORWARD") +def handle_jump_forward( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # JUMP_BACKWARD is used to jump back to the beginning of a loop (replaces JUMP_ABSOLUTE in 3.13) - # In generator expressions, this typically indicates the end of the loop body + # JUMP_FORWARD is used to jump forward in the code + # In generator expressions, this is often used to skip code in conditional logic return state @@ -466,36 +466,15 @@ def handle_jump_absolute( return state -@register_handler("JUMP_FORWARD") -def handle_jump_forward( +@register_handler("JUMP_BACKWARD", version=PythonVersion.PY_313) +def handle_jump_backward( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # JUMP_FORWARD is used to jump forward in the code - # In generator expressions, this is often used to skip code in conditional logic + # JUMP_BACKWARD is used to jump back to the beginning of a loop (replaces JUMP_ABSOLUTE in 3.13) + # In generator expressions, this typically indicates the end of the loop body return state -@register_handler("UNPACK_SEQUENCE") -def handle_unpack_sequence( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - 
# UNPACK_SEQUENCE unpacks a sequence into multiple values - # arg is the number of values to unpack - assert instr.arg is not None - unpack_count: int = instr.arg - sequence = ensure_ast(state.stack[-1]) # noqa: F841 - new_stack = state.stack[:-1] - - # For tuple unpacking in comprehensions, we typically see patterns like: - # ((k, v) for k, v in items) where items is unpacked into k and v - # Create placeholder variables for the unpacked values - for i in range(unpack_count): - var_name = f"_unpack_{i}" - new_stack = new_stack + [ast.Name(id=var_name, ctx=ast.Load())] - - return replace(state, stack=new_stack) - - @register_handler("RESUME", version=PythonVersion.PY_313) def handle_resume( state: ReconstructionState, instr: dis.Instruction @@ -1304,6 +1283,27 @@ def handle_binary_subscr( # ============================================================================ +@register_handler("UNPACK_SEQUENCE") +def handle_unpack_sequence( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # UNPACK_SEQUENCE unpacks a sequence into multiple values + # arg is the number of values to unpack + assert instr.arg is not None + unpack_count: int = instr.arg + sequence = ensure_ast(state.stack[-1]) # noqa: F841 + new_stack = state.stack[:-1] + + # For tuple unpacking in comprehensions, we typically see patterns like: + # ((k, v) for k, v in items) where items is unpacked into k and v + # Create placeholder variables for the unpacked values + for i in range(unpack_count): + var_name = f"_unpack_{i}" + new_stack = new_stack + [ast.Name(id=var_name, ctx=ast.Load())] + + return replace(state, stack=new_stack) + + @register_handler("BUILD_TUPLE") def handle_build_tuple( state: ReconstructionState, instr: dis.Instruction From df42bf9bb895cb47674f00e45d6cedfb7a662b7a Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 14:19:55 -0400 Subject: [PATCH 064/106] more ops --- effectful/internals/disassembler.py | 102 +++++++++++++++++++-------- 
tests/test_internals_disassembler.py | 1 + 2 files changed, 75 insertions(+), 28 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 3fc0f9ce..f65a228f 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -982,6 +982,20 @@ def handle_unary_op( ) +@register_handler("LIST_TO_TUPLE", version=PythonVersion.PY_310) +def handle_list_to_tuple( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # LIST_TO_TUPLE converts a list on the stack to a tuple + list_obj = ensure_ast(state.stack[-1]) + assert isinstance(list_obj, ast.List), "Expected a list for LIST_TO_TUPLE" + + # Create tuple AST from the list's elements + tuple_node = ast.Tuple(elts=list_obj.elts, ctx=ast.Load()) + new_stack = state.stack[:-1] + [tuple_node] + return replace(state, stack=new_stack) + + @register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_313) def handle_call_intrinsic_1( state: ReconstructionState, instr: dis.Instruction @@ -989,6 +1003,12 @@ def handle_call_intrinsic_1( # CALL_INTRINSIC_1 calls an intrinsic function with one argument if instr.argrepr == "INTRINSIC_STOPITERATION_ERROR": return state + elif instr.argrepr == "INTRINSIC_LIST_TO_TUPLE": + assert isinstance(state.stack[-1], ast.List), ( + "Expected a list for LIST_TO_TUPLE" + ) + tuple_node = ast.Tuple(elts=state.stack[-1].elts, ctx=ast.Load()) + return replace(state, stack=state.stack[:-1] + [tuple_node]) elif instr.argrepr == "INTRINSIC_UNARY_POSITIVE": assert len(state.stack) > 0 new_val = ast.UnaryOp(op=ast.UAdd(), operand=state.stack[-1]) @@ -1324,17 +1344,24 @@ def handle_build_tuple( return replace(state, stack=new_stack) -@register_handler("LIST_TO_TUPLE", version=PythonVersion.PY_310) -def handle_list_to_tuple( +@register_handler("BUILD_CONST_KEY_MAP") +def handle_build_const_key_map( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # LIST_TO_TUPLE converts a list 
on the stack to a tuple - list_obj = ensure_ast(state.stack[-1]) - assert isinstance(list_obj, ast.List), "Expected a list for LIST_TO_TUPLE" + # BUILD_CONST_KEY_MAP builds a dictionary with constant keys + # The keys are in a tuple on TOS, values are on the stack below + assert instr.arg is not None + assert isinstance(state.stack[-1], ast.Tuple), "Expected a tuple of keys" + map_size: int = instr.arg + # Pop the keys tuple and values + keys_tuple: ast.Tuple = state.stack[-1] + keys: list[ast.expr | None] = [ensure_ast(key) for key in keys_tuple.elts] + values = [ensure_ast(val) for val in state.stack[-map_size - 1 : -1]] + new_stack = state.stack[: -map_size - 1] - # Create tuple AST from the list's elements - tuple_node = ast.Tuple(elts=list_obj.elts, ctx=ast.Load()) - new_stack = state.stack[:-1] + [tuple_node] + # Create dictionary AST + dict_node = ast.Dict(keys=keys, values=values) + new_stack = new_stack + [dict_node] return replace(state, stack=new_stack) @@ -1353,33 +1380,51 @@ def handle_list_extend( assert isinstance(state.stack[-1], ast.Tuple | ast.List) prev_result = state.stack[-instr.argval - 1] - list_obj = ast.List( + new_val = ast.List( elts=[ensure_ast(e) for e in state.stack[-1].elts], ctx=ast.Load() ) - new_stack = state.stack[:-2] + [list_obj] + new_stack = state.stack[:-2] + [new_val] return replace(state, stack=new_stack, result=prev_result) -@register_handler("BUILD_CONST_KEY_MAP") -def handle_build_const_key_map( +@register_handler("SET_UPDATE") +def handle_set_update( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # BUILD_CONST_KEY_MAP builds a dictionary with constant keys - # The keys are in a tuple on TOS, values are on the stack below - assert instr.arg is not None - assert isinstance(state.stack[-1], ast.Tuple), "Expected a tuple of keys" - map_size: int = instr.arg - # Pop the keys tuple and values - keys_tuple: ast.Tuple = state.stack[-1] - keys: list[ast.expr | None] = [ensure_ast(key) for key in 
keys_tuple.elts] - values = [ensure_ast(val) for val in state.stack[-map_size - 1 : -1]] - new_stack = state.stack[: -map_size - 1] + # The set being extended is actually in state.result instead of the stack + # because it was initially recognized as a list comprehension in BUILD_SET, + # while the actual result expression is in the stack where the set "should be" + # and needs to be put back into the state result slot + assert isinstance(state.result, ast.SetComp) and not state.result.generators + assert isinstance(state.stack[-1], ast.Tuple | ast.List | ast.Set) + prev_result = state.stack[-instr.argval - 1] - # Create dictionary AST - dict_node = ast.Dict(keys=keys, values=values) - new_stack = new_stack + [dict_node] - return replace(state, stack=new_stack) + new_val = ast.Set(elts=[ensure_ast(e) for e in state.stack[-1].elts]) + new_stack = state.stack[:-2] + [new_val] + + return replace(state, stack=new_stack, result=prev_result) + + +@register_handler("DICT_UPDATE") +def handle_dict_update( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # The dict being extended is actually in state.result instead of the stack + # because it was initially recognized as a list comprehension in BUILD_MAP, + # while the actual result expression is in the stack where the dict "should be" + # and needs to be put back into the state result slot + assert isinstance(state.result, ast.DictComp) and not state.result.generators + assert isinstance(state.stack[-1], ast.Dict) + prev_result = state.stack[-instr.argval - 1] + + new_val = ast.Dict( + keys=[ensure_ast(e) for e in state.stack[-1].keys], + values=[ensure_ast(e) for e in state.stack[-1].values], + ) + new_stack = state.stack[:-2] + [new_val] + + return replace(state, stack=new_stack, result=prev_result) # ============================================================================ @@ -1684,8 +1729,9 @@ def _ensure_ast_list_iterator(value: Iterator) -> ast.List: return 
ensure_ast(list(value.__reduce__()[1][0])) # type: ignore -@ensure_ast.register -def _ensure_ast_set(value: set) -> ast.Set: +@ensure_ast.register(set) +@ensure_ast.register(frozenset) +def _ensure_ast_set(value: set | frozenset) -> ast.Set: return ast.Set(elts=[ensure_ast(v) for v in value]) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index e7904c85..8e14499f 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -548,6 +548,7 @@ def test_variable_lookup(genexpr, globals_dict): # Edge cases with complex data structures (((1, 2, 3)[x % 3] for x in range(10)), {}), (([1, 2, 3][x % 3] for x in range(10)), {}), + (({1, 2, 3} for x in range(10)), {}), # (({"even": x, "odd": x + 1}["even" if x % 2 == 0 else "odd"] for x in range(5)), {}), # Function calls with multiple arguments ((pow(x, 2, 10) for x in range(5)), {"pow": pow}), From fc448e8933b6da2c6ae5876cfeb244c70d9b5f25 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 14:24:46 -0400 Subject: [PATCH 065/106] decorator --- effectful/internals/disassembler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index f65a228f..147bfc53 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -189,7 +189,7 @@ def _wrapper( return new_state OP_HANDLERS[opname] = _wrapper - return _wrapper + return handler # return the original handler for multiple decorator usage # ============================================================================ From 3b545a28a0ffcdfa78fde6552e0f400ab5b790e0 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 20 Jun 2025 14:37:45 -0400 Subject: [PATCH 066/106] clean up conditionals --- effectful/internals/disassembler.py | 163 ++++++---------------------- 1 file changed, 32 insertions(+), 131 deletions(-) diff --git a/effectful/internals/disassembler.py 
b/effectful/internals/disassembler.py index 147bfc53..3df4aca6 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1017,6 +1017,19 @@ def handle_call_intrinsic_1( raise TypeError(f"Unsupported generator intrinsic operation: {instr.argrepr}") +@register_handler("TO_BOOL", version=PythonVersion.PY_313) +def handle_to_bool( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # TO_BOOL converts the top stack item to a boolean + # For AST reconstruction, we typically don't need an explicit bool() call + # since the boolean context is usually handled by the conditional jump that follows + # However, for some cases we might need to preserve the explicit conversion + + # For now, leave the value as-is since the jump instruction will handle the boolean logic + return state + + # ============================================================================ # COMPARISON OPERATION HANDLERS # ============================================================================ @@ -1445,27 +1458,10 @@ def handle_pop_jump_if_false_310( if instr.argval < instr.offset: # Jumping backward to loop start - this is a condition # When POP_JUMP_IF_FALSE jumps back, it means "if false, skip this item" - # So we need to negate the condition to get the filter condition assert isinstance(state.result, CompExp) and state.result.generators - updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs + [condition], - is_async=state.result.generators[-1].is_async, - ) - if isinstance(state.result, ast.DictComp): - new_dict = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_dict) - else: - new_comp = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) 
- return replace(state, stack=new_stack, result=new_comp) + new_result = copy.deepcopy(state.result) + new_result.generators[-1].ifs.append(condition) + return replace(state, stack=new_stack, result=new_result) else: raise NotImplementedError("Lazy and+or behavior not implemented yet") @@ -1485,26 +1481,9 @@ def handle_pop_jump_if_false( # it means "if condition is False, then yield the item" # So we need to negate the condition: we want items where NOT condition negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) - - updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs + [negated_condition], - is_async=state.result.generators[-1].is_async, - ) - if isinstance(state.result, ast.DictComp): - new_dict = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_dict) - else: - new_comp = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_comp) + new_result = copy.deepcopy(state.result) + new_result.generators[-1].ifs.append(negated_condition) + return replace(state, stack=new_stack, result=new_result) else: # Not in a comprehension context - might be boolean logic raise NotImplementedError("Lazy and+or behavior not implemented yet") @@ -1527,27 +1506,10 @@ def handle_pop_jump_if_true_310( # When POP_JUMP_IF_TRUE jumps back, it means "if false, skip this item" # So we need to negate the condition to get the filter condition assert isinstance(state.result, CompExp) and state.result.generators - # negate the condition - condition = ast.UnaryOp(op=ast.Not(), operand=condition) - updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - 
ifs=state.result.generators[-1].ifs + [condition], - is_async=state.result.generators[-1].is_async, - ) - if isinstance(state.result, ast.DictComp): - new_dict = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_dict) - else: - new_comp = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_comp) + negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) + new_result = copy.deepcopy(state.result) + new_result.generators[-1].ifs.append(negated_condition) + return replace(state, stack=new_stack, result=new_result) else: raise NotImplementedError("Lazy and+or behavior not implemented yet") @@ -1566,43 +1528,14 @@ def handle_pop_jump_if_true( if isinstance(state.result, CompExp) and state.result.generators: # For POP_JUMP_IF_TRUE in filters, we want the condition to be true to continue # So we add the condition directly (no negation needed) - updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs + [condition], - is_async=state.result.generators[-1].is_async, - ) - if isinstance(state.result, ast.DictComp): - new_dict = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_dict) - else: - new_comp = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_comp) + new_result = copy.deepcopy(state.result) + new_result.generators[-1].ifs.append(condition) + return replace(state, stack=new_stack, result=new_result) else: # Not in a comprehension context - might be boolean logic raise 
NotImplementedError("Lazy and+or behavior not implemented yet") -@register_handler("TO_BOOL", version=PythonVersion.PY_313) -def handle_to_bool( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # TO_BOOL converts the top stack item to a boolean - # For AST reconstruction, we typically don't need an explicit bool() call - # since the boolean context is usually handled by the conditional jump that follows - # However, for some cases we might need to preserve the explicit conversion - - # For now, leave the value as-is since the jump instruction will handle the boolean logic - return state - - @register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_313) def handle_pop_jump_if_none( state: ReconstructionState, instr: dis.Instruction @@ -1616,25 +1549,9 @@ def handle_pop_jump_if_none( none_condition = ast.Compare( left=condition, ops=[ast.Is()], comparators=[ast.Constant(value=None)] ) - updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs + [none_condition], - is_async=state.result.generators[-1].is_async, - ) - if isinstance(state.result, ast.DictComp): - new_dict = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_dict) - else: - new_comp = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_comp) + new_result = copy.deepcopy(state.result) + new_result.generators[-1].ifs.append(none_condition) + return replace(state, stack=new_stack, result=new_result) else: raise NotImplementedError("Lazy and+or behavior not implemented yet") @@ -1652,25 +1569,9 @@ def handle_pop_jump_if_not_none( not_none_condition = ast.Compare( left=condition, ops=[ast.IsNot()], comparators=[ast.Constant(value=None)] ) 
- updated_loop = ast.comprehension( - target=state.result.generators[-1].target, - iter=state.result.generators[-1].iter, - ifs=state.result.generators[-1].ifs + [not_none_condition], - is_async=state.result.generators[-1].is_async, - ) - if isinstance(state.result, ast.DictComp): - new_dict = ast.DictComp( - key=state.result.key, - value=state.result.value, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_dict) - else: - new_comp = type(state.result)( - elt=state.result.elt, - generators=state.result.generators[:-1] + [updated_loop], - ) - return replace(state, stack=new_stack, result=new_comp) + new_result = copy.deepcopy(state.result) + new_result.generators[-1].ifs.append(not_none_condition) + return replace(state, stack=new_stack, result=new_result) else: raise NotImplementedError("Lazy and+or behavior not implemented yet") From fc6a48c40e56289de09fd24d9871c5e320ddeb8b Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 21 Jun 2025 13:10:15 -0400 Subject: [PATCH 067/106] Support lambda reconstruction --- effectful/internals/disassembler.py | 136 +++++++++++++++++---------- tests/test_internals_disassembler.py | 21 +---- 2 files changed, 91 insertions(+), 66 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 3df4aca6..269edf4d 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -111,7 +111,7 @@ class ReconstructionState: BINARY_ADD pop operands and push results. 
""" - result: CompExp | ast.Lambda | Placeholder = field(default_factory=Placeholder) + result: ast.expr = field(default_factory=Placeholder) stack: list[ast.expr] = field(default_factory=list) @@ -399,7 +399,7 @@ def handle_return_value( return replace(state, stack=state.stack[:-1]) elif isinstance(state.result, Placeholder): new_result = ensure_ast(state.stack[0]) - assert isinstance(new_result, CompExp | ast.Lambda) + assert isinstance(new_result, ast.expr) return replace(state, stack=state.stack[1:], result=new_result) else: raise TypeError("Unexpected RETURN_VALUE in reconstruction") @@ -844,6 +844,13 @@ def handle_copy( return replace(state, stack=new_stack) +@register_handler("PUSH_NULL", version=PythonVersion.PY_313) +def handle_push_null( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + return replace(state, stack=state.stack + [Null()]) + + # ============================================================================ # BINARY ARITHMETIC/LOGIC OPERATION HANDLERS # ============================================================================ @@ -1175,6 +1182,7 @@ def handle_make_function_310( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # MAKE_FUNCTION creates a function from code object and name on stack + assert isinstance(state.stack[-2], ast.Lambda | CompLambda) assert isinstance(state.stack[-1], ast.Constant) and isinstance( state.stack[-1].value, str ), "Function name must be a constant string." @@ -1188,21 +1196,16 @@ def handle_make_function_310( "MAKE_FUNCTION with defaults or annotations not implemented." 
) - body: ast.expr = state.stack[-2] + # Pop the function object and name from the stack + # Conversion from CodeType to ast.Lambda should have happened already + func: ast.Lambda | CompLambda = state.stack[-2] name: str = state.stack[-1].value assert any( name.endswith(suffix) for suffix in ("", "", "", "", "") ), f"Expected a comprehension or lambda function, got '{name}'" - - if ( - isinstance(body, CompExp) - and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 - ): - return replace(state, stack=new_stack + [CompLambda(body=body)]) - else: - raise NotImplementedError("Lambda reconstruction not implemented yet") + return replace(state, stack=new_stack + [func]) # Python 3.13 version @@ -1214,17 +1217,12 @@ def handle_make_function( # and creates a function from it. No flags, no extra attributes on the stack. # All extra attributes are handled by separate SET_FUNCTION_ATTRIBUTE instructions. - # Pop the code object from the stack (it's the only thing expected) - body: ast.expr = state.stack[-1] - new_stack = state.stack[:-1] - - if ( - isinstance(body, CompExp) - and sum(1 for x in ast.walk(body) if isinstance(x, DummyIterName)) == 1 - ): - return replace(state, stack=new_stack + [CompLambda(body=body)]) - else: - raise NotImplementedError("Lambda reconstruction not implemented yet") + # Pop the function object from the stack (it's the only thing expected) + # Conversion from CodeType to ast.Lambda should have happened already + assert isinstance(state.stack[-1], ast.Lambda | CompLambda), ( + "Expected a function object (Lambda or CompLambda) on the stack." 
+ ) + return state @register_handler("SET_FUNCTION_ATTRIBUTE", version=PythonVersion.PY_313) @@ -1669,7 +1667,26 @@ def _ensure_ast_range_iterator(value: Iterator) -> ast.Call: @ensure_ast.register -def _ensure_ast_codeobj(value: types.CodeType) -> ast.expr: +def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: + assert inspect.iscode(value), "Input must be a code object" + + name: str = value.co_name.split(".")[-1] + + # Check preconditions + if name in {"", "", "", ""}: + assert name == "" or sys.version_info < (3, 13) + assert name != "" or value.co_flags & inspect.CO_GENERATOR + assert value.co_flags & inspect.CO_NEWLOCALS + assert value.co_argcount == 1 + assert value.co_kwonlyargcount == value.co_posonlyargcount == 0 + assert DummyIterName().id in value.co_varnames + elif name == "": + assert not value.co_flags & inspect.CO_GENERATOR + assert value.co_flags & inspect.CO_NEWLOCALS + assert DummyIterName().id not in value.co_varnames + else: + raise TypeError(f"Unsupported code object type: {value.co_name}") + # Symbolic execution to reconstruct the AST state = ReconstructionState() for instr in dis.get_instructions(value): @@ -1682,10 +1699,47 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.expr: assert not any(isinstance(x, Placeholder) for x in ast.walk(state.result)), ( "Return value must not contain placeholders" ) - assert not isinstance(state.result, CompExp) or len(state.result.generators) > 0, ( - "Return value must have generators if not a lambda" - ) - return state.result + assert all( + x.generators for x in ast.walk(state.result) if isinstance(x, CompExp) + ), "Return value must have generators if not a lambda" + + if name == "": + assert isinstance(state.result, ast.expr) + args = ast.arguments( + posonlyargs=[ + ast.arg(arg=arg) + for arg in value.co_varnames[: value.co_posonlyargcount] + ], + args=[ + ast.arg(arg=arg) + for arg in value.co_varnames[ + value.co_posonlyargcount : value.co_argcount + ] + ], + 
kwonlyargs=[ + ast.arg(arg=arg) + for arg in value.co_varnames[ + value.co_argcount : value.co_argcount + value.co_kwonlyargcount + ] + ], + kw_defaults=[], + defaults=[], + ) + return ast.Lambda(args=args, body=state.result) + elif name == "": + assert isinstance(state.result, ast.GeneratorExp) + return CompLambda(body=state.result) + elif name == "" and sys.version_info < (3, 13): + assert isinstance(state.result, ast.DictComp) + return CompLambda(body=state.result) + elif name == "" and sys.version_info < (3, 13): + assert isinstance(state.result, ast.ListComp) + return CompLambda(body=state.result) + elif name == "" and sys.version_info < (3, 13): + assert isinstance(state.result, ast.SetComp) + return CompLambda(body=state.result) + else: + raise TypeError(f"Unsupported code object type: {value.co_name}") @ensure_ast.register @@ -1693,27 +1747,11 @@ def _ensure_ast_lambda(value: types.LambdaType) -> ast.Lambda: assert inspect.isfunction(value) and value.__name__.endswith(""), ( "Input must be a lambda function" ) - code: types.CodeType = value.__code__ - body: ast.expr = ensure_ast(code) - args = ast.arguments( - posonlyargs=[ - ast.arg(arg=arg) for arg in code.co_varnames[: code.co_posonlyargcount] - ], - args=[ - ast.arg(arg=arg) - for arg in code.co_varnames[code.co_posonlyargcount : code.co_argcount] - ], - kwonlyargs=[ - ast.arg(arg=arg) - for arg in code.co_varnames[ - code.co_argcount : code.co_argcount + code.co_kwonlyargcount - ] - ], - kw_defaults=[], - defaults=[], - ) - return ast.Lambda(args=args, body=body) + result = ensure_ast(code) + assert isinstance(result, ast.Lambda), "Lambda body must be an AST Lambda node" + assert not isinstance(result, CompLambda), "Lambda must not be a CompLambda" + return result @ensure_ast.register @@ -1723,9 +1761,9 @@ def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: "Generator must be in created state" ) genexpr_ast = ensure_ast(genexpr.gi_code) - assert isinstance(genexpr_ast, 
ast.GeneratorExp) + assert isinstance(genexpr_ast, CompLambda) geniter_ast = ensure_ast(genexpr.gi_frame.f_locals[".0"]) - result = CompLambda(body=genexpr_ast).inline(geniter_ast) + result = genexpr_ast.inline(geniter_ast) assert isinstance(result, ast.GeneratorExp) assert inspect.getgeneratorstate(genexpr) == inspect.GEN_CREATED, ( "Generator must stay in created state" diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 8e14499f..2ec68d3f 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -500,23 +500,10 @@ def test_variable_lookup(genexpr, globals_dict): "genexpr,globals_dict", [ # Using lambdas and functions - pytest.param( - ((lambda y: y * 2)(x) for x in range(5)), - {}, - marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), - ), - pytest.param( - ((lambda y: y + 1)(x) for x in range(5)), - {}, - marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), - ), - pytest.param( - ((lambda y: y**2)(x) for x in range(5)), - {}, - marks=pytest.mark.xfail(reason="Lambda reconstruction not implemented yet"), - ), - # More complex lambdas - # (((lambda a, b: a + b)(x, x) for x in range(5)), {}), + (((lambda y: y * 2)(x) for x in range(5)), {}), + (((lambda y: y + 1)(x) for x in range(5)), {}), + (((lambda y: y**2)(x) for x in range(5)), {}), + (((lambda a, b: a + b)(x, x) for x in range(5)), {}), ((f(x) for x in range(5)), {"f": lambda y: y * 3}), # type: ignore # noqa: F821 # Attribute access ((x.real for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), From f23c07e483c943520742598e7eb2608eb0b345d4 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 21 Jun 2025 13:13:31 -0400 Subject: [PATCH 068/106] add one more test case --- tests/test_internals_disassembler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 2ec68d3f..facab839 100644 --- 
a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -504,6 +504,7 @@ def test_variable_lookup(genexpr, globals_dict): (((lambda y: y + 1)(x) for x in range(5)), {}), (((lambda y: y**2)(x) for x in range(5)), {}), (((lambda a, b: a + b)(x, x) for x in range(5)), {}), + (((lambda: (x for x in range(i)))() for i in range(3)), {}), ((f(x) for x in range(5)), {"f": lambda y: y * 3}), # type: ignore # noqa: F821 # Attribute access ((x.real for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), From d65b9b7d14eaad2de36287070bfcd25cd35db56a Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 21 Jun 2025 13:43:34 -0400 Subject: [PATCH 069/106] cleanup --- effectful/internals/disassembler.py | 52 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 269edf4d..3150ad7e 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1690,21 +1690,29 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: # Symbolic execution to reconstruct the AST state = ReconstructionState() for instr in dis.get_instructions(value): - if instr.opname not in OP_HANDLERS: - raise KeyError(f"No handler found for opcode '{instr.opname}'") - state = OP_HANDLERS[instr.opname](state, instr) + result: ast.expr = state.result # Check postconditions - assert not any(isinstance(x, Placeholder) for x in ast.walk(state.result)), ( - "Return value must not contain placeholders" + assert not any( + isinstance(x, Placeholder | Null | CompLambda) for x in ast.walk(result) + ), "Final return value must not contain temporary nodes" + assert not any(x.arg == ".0" for x in ast.walk(result) if isinstance(x, ast.arg)), ( + "Final return value must not contain .0 argument" + ) + assert not any( + isinstance(x, ast.Name) and x.id == ".0" + for x in ast.walk(result) + if not isinstance(x, DummyIterName) + ), "Final return value must 
not contain .0 names" + assert sum(1 for x in ast.walk(result) if isinstance(x, DummyIterName)) <= 1, ( + "Final return value must contain at most 1 dummy iterator names" + ) + assert all(x.generators for x in ast.walk(result) if isinstance(x, CompExp)), ( + "Return value must have generators if not a lambda" ) - assert all( - x.generators for x in ast.walk(state.result) if isinstance(x, CompExp) - ), "Return value must have generators if not a lambda" - if name == "": - assert isinstance(state.result, ast.expr) + if name == "" and isinstance(result, ast.expr): args = ast.arguments( posonlyargs=[ ast.arg(arg=arg) @@ -1725,21 +1733,17 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: kw_defaults=[], defaults=[], ) - return ast.Lambda(args=args, body=state.result) - elif name == "": - assert isinstance(state.result, ast.GeneratorExp) - return CompLambda(body=state.result) - elif name == "" and sys.version_info < (3, 13): - assert isinstance(state.result, ast.DictComp) - return CompLambda(body=state.result) - elif name == "" and sys.version_info < (3, 13): - assert isinstance(state.result, ast.ListComp) - return CompLambda(body=state.result) - elif name == "" and sys.version_info < (3, 13): - assert isinstance(state.result, ast.SetComp) - return CompLambda(body=state.result) + return ast.Lambda(args=args, body=result) + elif name == "" and isinstance(result, ast.GeneratorExp): + return CompLambda(body=result) + elif name == "" and isinstance(result, ast.DictComp): + return CompLambda(body=result) + elif name == "" and isinstance(result, ast.ListComp): + return CompLambda(body=result) + elif name == "" and isinstance(result, ast.SetComp): + return CompLambda(body=result) else: - raise TypeError(f"Unsupported code object type: {value.co_name}") + raise TypeError(f"Invalid result for type {name}: {result}") @ensure_ast.register From 075e9423b8dedac02ef6453f58ae437dbfeec351 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 21 Jun 2025 13:58:42 
-0400 Subject: [PATCH 070/106] stmt postcondition --- effectful/internals/disassembler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 3150ad7e..70dd3aa1 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1694,6 +1694,9 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: result: ast.expr = state.result # Check postconditions + assert not any(isinstance(x, ast.stmt) for x in ast.walk(result)), ( + "Final return value must not contain statement nodes" + ) assert not any( isinstance(x, Placeholder | Null | CompLambda) for x in ast.walk(result) ), "Final return value must not contain temporary nodes" From 8cf3d443ed2e84cb0087f72585e9d5440d82ab14 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 21 Jun 2025 14:32:06 -0400 Subject: [PATCH 071/106] nit --- effectful/internals/disassembler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 70dd3aa1..c12a6c90 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -18,6 +18,7 @@ import ast import copy import dis +import enum import functools import inspect import sys @@ -25,7 +26,6 @@ import typing from collections.abc import Callable, Generator, Iterator from dataclasses import dataclass, field, replace -from enum import Enum CompExp = ast.GeneratorExp | ast.ListComp | ast.SetComp | ast.DictComp @@ -116,7 +116,7 @@ class ReconstructionState: # Python version enum for version-specific handling -class PythonVersion(Enum): +class PythonVersion(enum.Enum): PY_310 = 10 PY_313 = 13 From 7f4682b7b70cacbe2cc5437bba08c10903c908f3 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 25 Jul 2025 09:12:56 -0400 Subject: [PATCH 072/106] explicit version required --- effectful/internals/disassembler.py | 189 ++++++++++++++++++++++------ 1 file changed, 
153 insertions(+), 36 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index c12a6c90..ba600339 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -14,6 +14,35 @@ >>> ast_node = reconstruct(g) >>> # ast_node is now an ast.Expression representing the original expression """ +# Summary of Python 3.12 Bytecode Differences, by Claude Opus +# +# Based on my analysis of the disassembler module and Python documentation, here are the key bytecode differences in Python 3.12 relative to 3.10 and 3.13: +# +# Operations Missing in Current 3.12 Implementation: +# +# 1. BINARY_OP - In Python 3.12, individual binary operations (BINARY_ADD, BINARY_SUBTRACT, etc.) were replaced by a single BINARY_OP instruction with different argument values. The module currently only handles this for 3.13. +# 2. Jump instruction changes - Python 3.12 still uses some older jump instructions but with different offset calculations compared to 3.10. +# 3. Stack manipulation - Python 3.12 is in a transitional state: +# - Still has ROT_TWO, ROT_THREE, ROT_FOUR (like 3.10) +# - Doesn't have SWAP/COPY yet (introduced in 3.13) +# 4. CALL changes - Python 3.12 has different CALL behavior than both 3.10 (CALL_FUNCTION) and 3.13 (unified CALL). +# 5. Other missing operations for 3.12: +# - END_FOR (introduced in 3.12, currently only handled for 3.13) +# - CALL_INTRINSIC_1/CALL_INTRINSIC_2 (new in 3.12) +# - TO_BOOL (introduced in 3.12, currently only handled for 3.13) +# - Different MAKE_FUNCTION behavior +# +# Key Operations to Add for Python 3.12 Support: +# +# 1. Binary operations - Need BINARY_OP handler for 3.12 +# 2. Jump instructions - May need adjustments for 3.12-specific behavior +# 3. Stack operations - Keep using ROT_* operations for 3.12 +# 4. Call operations - Need 3.12-specific CALL handling +# 5. Intrinsic operations - Add CALL_INTRINSIC_1 for 3.12 +# 6. Boolean conversion - Add TO_BOOL for 3.12 +# 7. 
Loop control - Add END_FOR for 3.12 +# +# The module is well-structured with version-specific handlers, so adding 3.12 support involves registering appropriate handlers with version=PythonVersion.PY_312 for operations that differ between versions. import ast import copy @@ -118,6 +147,7 @@ class ReconstructionState: # Python version enum for version-specific handling class PythonVersion(enum.Enum): PY_310 = 10 + PY_312 = 12 PY_313 = 13 @@ -129,7 +159,7 @@ class PythonVersion(enum.Enum): @typing.overload def register_handler( - opname: str, *, version: PythonVersion = PythonVersion(sys.version_info.minor) + opname: str, *, version: PythonVersion ) -> Callable[[OpHandler], OpHandler]: ... @@ -138,7 +168,7 @@ def register_handler( opname: str, handler: OpHandler, *, - version: PythonVersion = PythonVersion(sys.version_info.minor), + version: PythonVersion, ) -> OpHandler: ... @@ -146,7 +176,7 @@ def register_handler( opname: str, handler=None, *, - version: PythonVersion = PythonVersion(sys.version_info.minor), + version: PythonVersion, ): """Register a handler for a specific operation name and optional version""" if handler is None: @@ -223,7 +253,9 @@ def handle_return_generator( return replace(state, result=new_result, stack=new_stack) -@register_handler("YIELD_VALUE") +@register_handler("YIELD_VALUE", version=PythonVersion.PY_310) +@register_handler("YIELD_VALUE", version=PythonVersion.PY_312) +@register_handler("YIELD_VALUE", version=PythonVersion.PY_313) def handle_yield_value( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -249,7 +281,9 @@ def handle_yield_value( # ============================================================================ -@register_handler("BUILD_LIST") +@register_handler("BUILD_LIST", version=PythonVersion.PY_310) +@register_handler("BUILD_LIST", version=PythonVersion.PY_312) +@register_handler("BUILD_LIST", version=PythonVersion.PY_313) def handle_build_list( state: ReconstructionState, instr: 
dis.Instruction ) -> ReconstructionState: @@ -271,7 +305,9 @@ def handle_build_list( return replace(state, stack=new_stack) -@register_handler("LIST_APPEND") +@register_handler("LIST_APPEND", version=PythonVersion.PY_310) +@register_handler("LIST_APPEND", version=PythonVersion.PY_312) +@register_handler("LIST_APPEND", version=PythonVersion.PY_313) def handle_list_append( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -295,7 +331,9 @@ def handle_list_append( # ============================================================================ -@register_handler("BUILD_SET") +@register_handler("BUILD_SET", version=PythonVersion.PY_310) +@register_handler("BUILD_SET", version=PythonVersion.PY_312) +@register_handler("BUILD_SET", version=PythonVersion.PY_313) def handle_build_set( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -314,7 +352,9 @@ def handle_build_set( return replace(state, stack=new_stack) -@register_handler("SET_ADD") +@register_handler("SET_ADD", version=PythonVersion.PY_310) +@register_handler("SET_ADD", version=PythonVersion.PY_312) +@register_handler("SET_ADD", version=PythonVersion.PY_313) def handle_set_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -338,7 +378,9 @@ def handle_set_add( # ============================================================================ -@register_handler("BUILD_MAP") +@register_handler("BUILD_MAP", version=PythonVersion.PY_310) +@register_handler("BUILD_MAP", version=PythonVersion.PY_312) +@register_handler("BUILD_MAP", version=PythonVersion.PY_313) def handle_build_map( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -363,7 +405,9 @@ def handle_build_map( return replace(state, stack=new_stack) -@register_handler("MAP_ADD") +@register_handler("MAP_ADD", version=PythonVersion.PY_310) +@register_handler("MAP_ADD", version=PythonVersion.PY_312) +@register_handler("MAP_ADD", 
version=PythonVersion.PY_313) def handle_map_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -389,7 +433,9 @@ def handle_map_add( # ============================================================================ -@register_handler("RETURN_VALUE") +@register_handler("RETURN_VALUE", version=PythonVersion.PY_310) +@register_handler("RETURN_VALUE", version=PythonVersion.PY_312) +@register_handler("RETURN_VALUE", version=PythonVersion.PY_313) def handle_return_value( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -405,7 +451,9 @@ def handle_return_value( raise TypeError("Unexpected RETURN_VALUE in reconstruction") -@register_handler("FOR_ITER") +@register_handler("FOR_ITER", version=PythonVersion.PY_310) +@register_handler("FOR_ITER", version=PythonVersion.PY_312) +@register_handler("FOR_ITER", version=PythonVersion.PY_313) def handle_for_iter( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -438,7 +486,9 @@ def handle_for_iter( return replace(state, stack=new_stack, result=new_ret) -@register_handler("GET_ITER") +@register_handler("GET_ITER", version=PythonVersion.PY_310) +@register_handler("GET_ITER", version=PythonVersion.PY_312) +@register_handler("GET_ITER", version=PythonVersion.PY_313) def handle_get_iter( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -448,7 +498,9 @@ def handle_get_iter( return state -@register_handler("JUMP_FORWARD") +@register_handler("JUMP_FORWARD", version=PythonVersion.PY_310) +@register_handler("JUMP_FORWARD", version=PythonVersion.PY_312) +@register_handler("JUMP_FORWARD", version=PythonVersion.PY_313) def handle_jump_forward( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -519,7 +571,9 @@ def handle_reraise( # ============================================================================ -@register_handler("LOAD_FAST") +@register_handler("LOAD_FAST", 
version=PythonVersion.PY_310) +@register_handler("LOAD_FAST", version=PythonVersion.PY_312) +@register_handler("LOAD_FAST", version=PythonVersion.PY_313) def handle_load_fast( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -535,7 +589,9 @@ def handle_load_fast( return replace(state, stack=new_stack) -@register_handler("LOAD_DEREF") +@register_handler("LOAD_DEREF", version=PythonVersion.PY_310) +@register_handler("LOAD_DEREF", version=PythonVersion.PY_312) +@register_handler("LOAD_DEREF", version=PythonVersion.PY_313) def handle_load_deref( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -545,7 +601,9 @@ def handle_load_deref( return replace(state, stack=new_stack) -@register_handler("LOAD_CLOSURE") +@register_handler("LOAD_CLOSURE", version=PythonVersion.PY_310) +@register_handler("LOAD_CLOSURE", version=PythonVersion.PY_312) +@register_handler("LOAD_CLOSURE", version=PythonVersion.PY_313) def handle_load_closure( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -555,7 +613,9 @@ def handle_load_closure( return replace(state, stack=new_stack) -@register_handler("LOAD_CONST") +@register_handler("LOAD_CONST", version=PythonVersion.PY_310) +@register_handler("LOAD_CONST", version=PythonVersion.PY_312) +@register_handler("LOAD_CONST", version=PythonVersion.PY_313) def handle_load_const( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -564,7 +624,9 @@ def handle_load_const( return replace(state, stack=new_stack) -@register_handler("LOAD_GLOBAL") +@register_handler("LOAD_GLOBAL", version=PythonVersion.PY_310) +@register_handler("LOAD_GLOBAL", version=PythonVersion.PY_312) +@register_handler("LOAD_GLOBAL", version=PythonVersion.PY_313) def handle_load_global( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -577,7 +639,9 @@ def handle_load_global( return replace(state, stack=new_stack) -@register_handler("LOAD_NAME") 
+@register_handler("LOAD_NAME", version=PythonVersion.PY_310) +@register_handler("LOAD_NAME", version=PythonVersion.PY_312) +@register_handler("LOAD_NAME", version=PythonVersion.PY_313) def handle_load_name( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -587,7 +651,9 @@ def handle_load_name( return replace(state, stack=new_stack) -@register_handler("STORE_FAST") +@register_handler("STORE_FAST", version=PythonVersion.PY_310) +@register_handler("STORE_FAST", version=PythonVersion.PY_312) +@register_handler("STORE_FAST", version=PythonVersion.PY_313) def handle_store_fast( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -609,7 +675,9 @@ def handle_store_fast( return replace(state, stack=new_stack, result=new_result) -@register_handler("STORE_DEREF") +@register_handler("STORE_DEREF", version=PythonVersion.PY_310) +@register_handler("STORE_DEREF", version=PythonVersion.PY_312) +@register_handler("STORE_DEREF", version=PythonVersion.PY_313) def handle_store_deref( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -730,7 +798,9 @@ def handle_copy_free_vars( # ============================================================================ -@register_handler("POP_TOP") +@register_handler("POP_TOP", version=PythonVersion.PY_310) +@register_handler("POP_TOP", version=PythonVersion.PY_312) +@register_handler("POP_TOP", version=PythonVersion.PY_313) def handle_pop_top( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -974,7 +1044,16 @@ def handle_unary_op( handle_unary_negative = register_handler( - "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()) + "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()), + version=PythonVersion.PY_310 +) +handle_unary_negative = register_handler( + "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()), + version=PythonVersion.PY_312 +) +handle_unary_negative = register_handler( + 
"UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()), + version=PythonVersion.PY_313 ) handle_unary_positive = register_handler( "UNARY_POSITIVE", @@ -982,10 +1061,28 @@ def handle_unary_op( version=PythonVersion.PY_310, ) handle_unary_invert = register_handler( - "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()) + "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()), + version=PythonVersion.PY_310 +) +handle_unary_invert = register_handler( + "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()), + version=PythonVersion.PY_312 +) +handle_unary_invert = register_handler( + "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()), + version=PythonVersion.PY_313 +) +handle_unary_not = register_handler( + "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()), + version=PythonVersion.PY_310 +) +handle_unary_not = register_handler( + "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()), + version=PythonVersion.PY_312 ) handle_unary_not = register_handler( - "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()) + "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()), + version=PythonVersion.PY_313 ) @@ -1051,7 +1148,9 @@ def handle_to_bool( } -@register_handler("COMPARE_OP") +@register_handler("COMPARE_OP", version=PythonVersion.PY_310) +@register_handler("COMPARE_OP", version=PythonVersion.PY_312) +@register_handler("COMPARE_OP", version=PythonVersion.PY_313) def handle_compare_op( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1069,7 +1168,9 @@ def handle_compare_op( return replace(state, stack=new_stack) -@register_handler("CONTAINS_OP") +@register_handler("CONTAINS_OP", version=PythonVersion.PY_310) +@register_handler("CONTAINS_OP", version=PythonVersion.PY_312) +@register_handler("CONTAINS_OP", version=PythonVersion.PY_313) def handle_contains_op( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1084,7 +1185,9 @@ def 
handle_contains_op( return replace(state, stack=new_stack) -@register_handler("IS_OP") +@register_handler("IS_OP", version=PythonVersion.PY_310) +@register_handler("IS_OP", version=PythonVersion.PY_312) +@register_handler("IS_OP", version=PythonVersion.PY_313) def handle_is_op( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1294,7 +1397,9 @@ def handle_load_attr( return replace(state, stack=new_stack) -@register_handler("BINARY_SUBSCR") +@register_handler("BINARY_SUBSCR", version=PythonVersion.PY_310) +@register_handler("BINARY_SUBSCR", version=PythonVersion.PY_312) +@register_handler("BINARY_SUBSCR", version=PythonVersion.PY_313) def handle_binary_subscr( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1314,7 +1419,9 @@ def handle_binary_subscr( # ============================================================================ -@register_handler("UNPACK_SEQUENCE") +@register_handler("UNPACK_SEQUENCE", version=PythonVersion.PY_310) +@register_handler("UNPACK_SEQUENCE", version=PythonVersion.PY_312) +@register_handler("UNPACK_SEQUENCE", version=PythonVersion.PY_313) def handle_unpack_sequence( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1335,7 +1442,9 @@ def handle_unpack_sequence( return replace(state, stack=new_stack) -@register_handler("BUILD_TUPLE") +@register_handler("BUILD_TUPLE", version=PythonVersion.PY_310) +@register_handler("BUILD_TUPLE", version=PythonVersion.PY_312) +@register_handler("BUILD_TUPLE", version=PythonVersion.PY_313) def handle_build_tuple( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1355,7 +1464,9 @@ def handle_build_tuple( return replace(state, stack=new_stack) -@register_handler("BUILD_CONST_KEY_MAP") +@register_handler("BUILD_CONST_KEY_MAP", version=PythonVersion.PY_310) +@register_handler("BUILD_CONST_KEY_MAP", version=PythonVersion.PY_312) +@register_handler("BUILD_CONST_KEY_MAP", 
version=PythonVersion.PY_313) def handle_build_const_key_map( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1376,7 +1487,9 @@ def handle_build_const_key_map( return replace(state, stack=new_stack) -@register_handler("LIST_EXTEND") +@register_handler("LIST_EXTEND", version=PythonVersion.PY_310) +@register_handler("LIST_EXTEND", version=PythonVersion.PY_312) +@register_handler("LIST_EXTEND", version=PythonVersion.PY_313) def handle_list_extend( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1399,7 +1512,9 @@ def handle_list_extend( return replace(state, stack=new_stack, result=prev_result) -@register_handler("SET_UPDATE") +@register_handler("SET_UPDATE", version=PythonVersion.PY_310) +@register_handler("SET_UPDATE", version=PythonVersion.PY_312) +@register_handler("SET_UPDATE", version=PythonVersion.PY_313) def handle_set_update( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: @@ -1417,7 +1532,9 @@ def handle_set_update( return replace(state, stack=new_stack, result=prev_result) -@register_handler("DICT_UPDATE") +@register_handler("DICT_UPDATE", version=PythonVersion.PY_310) +@register_handler("DICT_UPDATE", version=PythonVersion.PY_312) +@register_handler("DICT_UPDATE", version=PythonVersion.PY_313) def handle_dict_update( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: From c94cbd8d9be678a6e9500ae2e00f2c58aff9daf5 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 16:05:47 -0400 Subject: [PATCH 073/106] add 312 support --- effectful/internals/disassembler.py | 196 +++++++++++++++++++-------- tests/test_internals_disassembler.py | 3 +- 2 files changed, 145 insertions(+), 54 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index ba600339..eb2ea4c9 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -14,35 +14,6 @@ >>> ast_node = reconstruct(g) 
>>> # ast_node is now an ast.Expression representing the original expression """ -# Summary of Python 3.12 Bytecode Differences, by Claude Opus -# -# Based on my analysis of the disassembler module and Python documentation, here are the key bytecode differences in Python 3.12 relative to 3.10 and 3.13: -# -# Operations Missing in Current 3.12 Implementation: -# -# 1. BINARY_OP - In Python 3.12, individual binary operations (BINARY_ADD, BINARY_SUBTRACT, etc.) were replaced by a single BINARY_OP instruction with different argument values. The module currently only handles this for 3.13. -# 2. Jump instruction changes - Python 3.12 still uses some older jump instructions but with different offset calculations compared to 3.10. -# 3. Stack manipulation - Python 3.12 is in a transitional state: -# - Still has ROT_TWO, ROT_THREE, ROT_FOUR (like 3.10) -# - Doesn't have SWAP/COPY yet (introduced in 3.13) -# 4. CALL changes - Python 3.12 has different CALL behavior than both 3.10 (CALL_FUNCTION) and 3.13 (unified CALL). -# 5. Other missing operations for 3.12: -# - END_FOR (introduced in 3.12, currently only handled for 3.13) -# - CALL_INTRINSIC_1/CALL_INTRINSIC_2 (new in 3.12) -# - TO_BOOL (introduced in 3.12, currently only handled for 3.13) -# - Different MAKE_FUNCTION behavior -# -# Key Operations to Add for Python 3.12 Support: -# -# 1. Binary operations - Need BINARY_OP handler for 3.12 -# 2. Jump instructions - May need adjustments for 3.12-specific behavior -# 3. Stack operations - Keep using ROT_* operations for 3.12 -# 4. Call operations - Need 3.12-specific CALL handling -# 5. Intrinsic operations - Add CALL_INTRINSIC_1 for 3.12 -# 6. Boolean conversion - Add TO_BOOL for 3.12 -# 7. Loop control - Add END_FOR for 3.12 -# -# The module is well-structured with version-specific handlers, so adding 3.12 support involves registering appropriate handlers with version=PythonVersion.PY_312 for operations that differ between versions. 
import ast import copy @@ -239,6 +210,19 @@ def handle_gen_start( return replace(state, result=ast.GeneratorExp(elt=Placeholder(), generators=[])) +@register_handler("RETURN_GENERATOR", version=PythonVersion.PY_312) +def handle_return_generator_312( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # RETURN_GENERATOR is the first instruction in generator expressions in Python 3.13+ + # It initializes the generator + assert isinstance(state.result, Placeholder), ( + "RETURN_GENERATOR must be the first instruction" + ) + new_result = ast.GeneratorExp(elt=Placeholder(), generators=[]) + return replace(state, result=new_result) + + @register_handler("RETURN_GENERATOR", version=PythonVersion.PY_313) def handle_return_generator( state: ReconstructionState, instr: dis.Instruction @@ -518,6 +502,7 @@ def handle_jump_absolute( return state +@register_handler("JUMP_BACKWARD", version=PythonVersion.PY_312) @register_handler("JUMP_BACKWARD", version=PythonVersion.PY_313) def handle_jump_backward( state: ReconstructionState, instr: dis.Instruction @@ -527,6 +512,7 @@ def handle_jump_backward( return state +@register_handler("RESUME", version=PythonVersion.PY_312) @register_handler("RESUME", version=PythonVersion.PY_313) def handle_resume( state: ReconstructionState, instr: dis.Instruction @@ -535,15 +521,25 @@ def handle_resume( return state +@register_handler("END_FOR", version=PythonVersion.PY_312) +def handle_end_for_312( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # END_FOR marks the end of a for loop, followed by POP_TOP (in 3.12) + new_stack = state.stack[:-1] + return replace(state, stack=new_stack) + + @register_handler("END_FOR", version=PythonVersion.PY_313) def handle_end_for( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # END_FOR marks the end of a for loop - no action needed for AST reconstruction - new_stack = state.stack # [:-1] + new_stack = state.stack return 
replace(state, stack=new_stack) +@register_handler("RETURN_CONST", version=PythonVersion.PY_312) @register_handler("RETURN_CONST", version=PythonVersion.PY_313) def handle_return_const( state: ReconstructionState, instr: dis.Instruction @@ -557,6 +553,7 @@ def handle_return_const( raise TypeError("Unexpected RETURN_CONST in reconstruction") +@register_handler("RERAISE", version=PythonVersion.PY_312) @register_handler("RERAISE", version=PythonVersion.PY_313) def handle_reraise( state: ReconstructionState, instr: dis.Instruction @@ -634,6 +631,8 @@ def handle_load_global( if instr.argrepr.endswith(" + NULL"): new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load()), Null()] + elif instr.argrepr.startswith("NULL + "): + new_stack = state.stack + [Null(), ast.Name(id=global_name, ctx=ast.Load())] else: new_stack = state.stack + [ast.Name(id=global_name, ctx=ast.Load())] return replace(state, stack=new_stack) @@ -724,6 +723,7 @@ def handle_store_fast_load_fast( return replace(state, stack=new_stack, result=new_result) +@register_handler("LOAD_FAST_AND_CLEAR", version=PythonVersion.PY_312) @register_handler("LOAD_FAST_AND_CLEAR", version=PythonVersion.PY_313) def handle_load_fast_and_clear( state: ReconstructionState, instr: dis.Instruction @@ -771,6 +771,7 @@ def handle_load_fast_load_fast( return replace(state, stack=new_stack) +@register_handler("MAKE_CELL", version=PythonVersion.PY_312) @register_handler("MAKE_CELL", version=PythonVersion.PY_313) def handle_make_cell( state: ReconstructionState, instr: dis.Instruction @@ -782,6 +783,7 @@ def handle_make_cell( return state +@register_handler("COPY_FREE_VARS", version=PythonVersion.PY_312) @register_handler("COPY_FREE_VARS", version=PythonVersion.PY_313) def handle_copy_free_vars( state: ReconstructionState, instr: dis.Instruction @@ -862,6 +864,7 @@ def handle_rot_four( # Python 3.13 replacement for stack manipulation +@register_handler("SWAP", version=PythonVersion.PY_312) @register_handler("SWAP", 
version=PythonVersion.PY_313) def handle_swap( state: ReconstructionState, instr: dis.Instruction @@ -891,6 +894,7 @@ def handle_swap( return state +@register_handler("COPY", version=PythonVersion.PY_312) @register_handler("COPY", version=PythonVersion.PY_313) def handle_copy( state: ReconstructionState, instr: dis.Instruction @@ -914,6 +918,7 @@ def handle_copy( return replace(state, stack=new_stack) +@register_handler("PUSH_NULL", version=PythonVersion.PY_312) @register_handler("PUSH_NULL", version=PythonVersion.PY_313) def handle_push_null( state: ReconstructionState, instr: dis.Instruction @@ -936,6 +941,7 @@ def handle_binop( # Python 3.13 BINARY_OP handler +@register_handler("BINARY_OP", version=PythonVersion.PY_312) @register_handler("BINARY_OP", version=PythonVersion.PY_313) def handle_binary_op( state: ReconstructionState, instr: dis.Instruction @@ -1044,16 +1050,19 @@ def handle_unary_op( handle_unary_negative = register_handler( - "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()), - version=PythonVersion.PY_310 + "UNARY_NEGATIVE", + functools.partial(handle_unary_op, ast.USub()), + version=PythonVersion.PY_310, ) handle_unary_negative = register_handler( - "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()), - version=PythonVersion.PY_312 + "UNARY_NEGATIVE", + functools.partial(handle_unary_op, ast.USub()), + version=PythonVersion.PY_312, ) handle_unary_negative = register_handler( - "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()), - version=PythonVersion.PY_313 + "UNARY_NEGATIVE", + functools.partial(handle_unary_op, ast.USub()), + version=PythonVersion.PY_313, ) handle_unary_positive = register_handler( "UNARY_POSITIVE", @@ -1061,28 +1070,34 @@ def handle_unary_op( version=PythonVersion.PY_310, ) handle_unary_invert = register_handler( - "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()), - version=PythonVersion.PY_310 + "UNARY_INVERT", + functools.partial(handle_unary_op, ast.Invert()), + 
version=PythonVersion.PY_310, ) handle_unary_invert = register_handler( - "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()), - version=PythonVersion.PY_312 + "UNARY_INVERT", + functools.partial(handle_unary_op, ast.Invert()), + version=PythonVersion.PY_312, ) handle_unary_invert = register_handler( - "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()), - version=PythonVersion.PY_313 + "UNARY_INVERT", + functools.partial(handle_unary_op, ast.Invert()), + version=PythonVersion.PY_313, ) handle_unary_not = register_handler( - "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()), - version=PythonVersion.PY_310 + "UNARY_NOT", + functools.partial(handle_unary_op, ast.Not()), + version=PythonVersion.PY_310, ) handle_unary_not = register_handler( - "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()), - version=PythonVersion.PY_312 + "UNARY_NOT", + functools.partial(handle_unary_op, ast.Not()), + version=PythonVersion.PY_312, ) handle_unary_not = register_handler( - "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()), - version=PythonVersion.PY_313 + "UNARY_NOT", + functools.partial(handle_unary_op, ast.Not()), + version=PythonVersion.PY_313, ) @@ -1100,14 +1115,13 @@ def handle_list_to_tuple( return replace(state, stack=new_stack) +@register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_312) @register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_313) def handle_call_intrinsic_1( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # CALL_INTRINSIC_1 calls an intrinsic function with one argument - if instr.argrepr == "INTRINSIC_STOPITERATION_ERROR": - return state - elif instr.argrepr == "INTRINSIC_LIST_TO_TUPLE": + if instr.argrepr == "INTRINSIC_LIST_TO_TUPLE": assert isinstance(state.stack[-1], ast.List), ( "Expected a list for LIST_TO_TUPLE" ) @@ -1117,6 +1131,8 @@ def handle_call_intrinsic_1( assert len(state.stack) > 0 new_val = ast.UnaryOp(op=ast.UAdd(), operand=state.stack[-1]) 
return replace(state, stack=state.stack[:-1] + [new_val]) + elif instr.argrepr == "INTRINSIC_STOPITERATION_ERROR": + return state else: raise TypeError(f"Unsupported generator intrinsic operation: {instr.argrepr}") @@ -1207,6 +1223,47 @@ def handle_is_op( # ============================================================================ +@register_handler("CALL", version=PythonVersion.PY_312) +def handle_call_312( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # CALL in Python 3.12 handles both function and method calls + # Stack layout: [..., callable or self, callable or NULL] + assert instr.arg is not None + arg_count: int = instr.arg + + # Check if this is a method call (no NULL on top) + if isinstance(state.stack[-arg_count - 2], Null): + # Regular function call: [..., NULL, callable, *args] + func = ensure_ast(state.stack[-arg_count - 1]) + args = ( + [ensure_ast(arg) for arg in state.stack[-arg_count:]] + if arg_count > 0 + else [] + ) + new_stack = state.stack[: -arg_count - 2] + else: + # Method call: [..., callable, self, *args] + func = ensure_ast(state.stack[-arg_count - 2]) + self_arg = ensure_ast(state.stack[-arg_count - 1]) + remaining_args = ( + [ensure_ast(arg) for arg in state.stack[-arg_count:]] + if arg_count > 0 + else [] + ) + args = [self_arg] + remaining_args + new_stack = state.stack[: -arg_count - 2] + + if isinstance(func, CompLambda): + assert len(args) == 1 + return replace(state, stack=new_stack + [func.inline(args[0])]) + else: + # Create function call AST + call_node = ast.Call(func=func, args=args, keywords=[]) + new_stack = new_stack + [call_node] + return replace(state, stack=new_stack) + + @register_handler("CALL", version=PythonVersion.PY_313) def handle_call( state: ReconstructionState, instr: dis.Instruction @@ -1311,6 +1368,32 @@ def handle_make_function_310( return replace(state, stack=new_stack + [func]) +@register_handler("MAKE_FUNCTION", version=PythonVersion.PY_312) +def 
handle_make_function_312( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # MAKE_FUNCTION in Python 3.12 uses flags to determine stack consumption + # Unlike 3.10, no qualified name on stack + # Unlike 3.13, uses flags instead of SET_FUNCTION_ATTRIBUTE + assert instr.arg is not None + assert isinstance(state.stack[-1], ast.Lambda | CompLambda), ( + "Expected a function object (Lambda or CompLambda) on the stack." + ) + if instr.argrepr == "closure": + # This is a closure, remove the environment tuple from the stack for AST purposes + new_stack = state.stack[:-2] + elif instr.argrepr == "": + new_stack = state.stack[:-1] + else: + raise NotImplementedError( + "MAKE_FUNCTION with defaults or annotations not implemented." + ) + + # For comprehensions, we only care about the function object + func = state.stack[-1] + return replace(state, stack=new_stack + [func]) + + # Python 3.13 version @register_handler("MAKE_FUNCTION", version=PythonVersion.PY_313) def handle_make_function( @@ -1380,6 +1463,7 @@ def handle_load_attr_310( return replace(state, stack=new_stack) +@register_handler("LOAD_ATTR", version=PythonVersion.PY_312) @register_handler("LOAD_ATTR", version=PythonVersion.PY_313) def handle_load_attr( state: ReconstructionState, instr: dis.Instruction @@ -1392,6 +1476,8 @@ def handle_load_attr( attr_node = ast.Attribute(value=obj, attr=attr_name, ctx=ast.Load()) if instr.argrepr.endswith(" + NULL|self"): new_stack = state.stack[:-1] + [attr_node, Null()] + elif instr.argrepr.startswith("NULL|self + "): + new_stack = state.stack[:-1] + [Null(), attr_node] else: new_stack = state.stack[:-1] + [attr_node] return replace(state, stack=new_stack) @@ -1581,7 +1667,8 @@ def handle_pop_jump_if_false_310( raise NotImplementedError("Lazy and+or behavior not implemented yet") -# Python 3.13 version +# Python 3.12+ version +@register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_FALSE", 
version=PythonVersion.PY_313) def handle_pop_jump_if_false( state: ReconstructionState, instr: dis.Instruction @@ -1629,7 +1716,8 @@ def handle_pop_jump_if_true_310( raise NotImplementedError("Lazy and+or behavior not implemented yet") -# Python 3.13 version +# Python 3.12+ version +@register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_313) def handle_pop_jump_if_true( state: ReconstructionState, instr: dis.Instruction @@ -1651,6 +1739,7 @@ def handle_pop_jump_if_true( raise NotImplementedError("Lazy and+or behavior not implemented yet") +@register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_313) def handle_pop_jump_if_none( state: ReconstructionState, instr: dis.Instruction @@ -1671,6 +1760,7 @@ def handle_pop_jump_if_none( raise NotImplementedError("Lazy and+or behavior not implemented yet") +@register_handler("POP_JUMP_IF_NOT_NONE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_NOT_NONE", version=PythonVersion.PY_313) def handle_pop_jump_if_not_none( state: ReconstructionState, instr: dis.Instruction diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index facab839..cc95a195 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -367,8 +367,9 @@ def test_filtered_generators(genexpr): (x, y, z) for x in range(5) for y in range(5) + if x < y for z in range(5) - if x < y and y < z + if y < z ), (x + y for x in range(3) if x > 0 for y in range(3)), # Mixed range types From 355443287e35f4751edb6fef0ac78cec88d5c084 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 16:09:05 -0400 Subject: [PATCH 074/106] remove 310 support --- effectful/internals/disassembler.py | 354 ---------------------------- 1 file changed, 354 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py 
index eb2ea4c9..120a17e2 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -117,7 +117,6 @@ class ReconstructionState: # Python version enum for version-specific handling class PythonVersion(enum.Enum): - PY_310 = 10 PY_312 = 12 PY_313 = 13 @@ -198,18 +197,6 @@ def _wrapper( # ============================================================================ -@register_handler("GEN_START", version=PythonVersion.PY_310) -def handle_gen_start( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # GEN_START is the first instruction in generator expressions in Python 3.10 - # It initializes the generator - assert isinstance(state.result, Placeholder), ( - "GEN_START must be the first instruction" - ) - return replace(state, result=ast.GeneratorExp(elt=Placeholder(), generators=[])) - - @register_handler("RETURN_GENERATOR", version=PythonVersion.PY_312) def handle_return_generator_312( state: ReconstructionState, instr: dis.Instruction @@ -237,7 +224,6 @@ def handle_return_generator( return replace(state, result=new_result, stack=new_stack) -@register_handler("YIELD_VALUE", version=PythonVersion.PY_310) @register_handler("YIELD_VALUE", version=PythonVersion.PY_312) @register_handler("YIELD_VALUE", version=PythonVersion.PY_313) def handle_yield_value( @@ -265,7 +251,6 @@ def handle_yield_value( # ============================================================================ -@register_handler("BUILD_LIST", version=PythonVersion.PY_310) @register_handler("BUILD_LIST", version=PythonVersion.PY_312) @register_handler("BUILD_LIST", version=PythonVersion.PY_313) def handle_build_list( @@ -289,7 +274,6 @@ def handle_build_list( return replace(state, stack=new_stack) -@register_handler("LIST_APPEND", version=PythonVersion.PY_310) @register_handler("LIST_APPEND", version=PythonVersion.PY_312) @register_handler("LIST_APPEND", version=PythonVersion.PY_313) def handle_list_append( @@ -315,7 +299,6 @@ def 
handle_list_append( # ============================================================================ -@register_handler("BUILD_SET", version=PythonVersion.PY_310) @register_handler("BUILD_SET", version=PythonVersion.PY_312) @register_handler("BUILD_SET", version=PythonVersion.PY_313) def handle_build_set( @@ -336,7 +319,6 @@ def handle_build_set( return replace(state, stack=new_stack) -@register_handler("SET_ADD", version=PythonVersion.PY_310) @register_handler("SET_ADD", version=PythonVersion.PY_312) @register_handler("SET_ADD", version=PythonVersion.PY_313) def handle_set_add( @@ -362,7 +344,6 @@ def handle_set_add( # ============================================================================ -@register_handler("BUILD_MAP", version=PythonVersion.PY_310) @register_handler("BUILD_MAP", version=PythonVersion.PY_312) @register_handler("BUILD_MAP", version=PythonVersion.PY_313) def handle_build_map( @@ -389,7 +370,6 @@ def handle_build_map( return replace(state, stack=new_stack) -@register_handler("MAP_ADD", version=PythonVersion.PY_310) @register_handler("MAP_ADD", version=PythonVersion.PY_312) @register_handler("MAP_ADD", version=PythonVersion.PY_313) def handle_map_add( @@ -417,7 +397,6 @@ def handle_map_add( # ============================================================================ -@register_handler("RETURN_VALUE", version=PythonVersion.PY_310) @register_handler("RETURN_VALUE", version=PythonVersion.PY_312) @register_handler("RETURN_VALUE", version=PythonVersion.PY_313) def handle_return_value( @@ -435,7 +414,6 @@ def handle_return_value( raise TypeError("Unexpected RETURN_VALUE in reconstruction") -@register_handler("FOR_ITER", version=PythonVersion.PY_310) @register_handler("FOR_ITER", version=PythonVersion.PY_312) @register_handler("FOR_ITER", version=PythonVersion.PY_313) def handle_for_iter( @@ -470,7 +448,6 @@ def handle_for_iter( return replace(state, stack=new_stack, result=new_ret) -@register_handler("GET_ITER", version=PythonVersion.PY_310) 
@register_handler("GET_ITER", version=PythonVersion.PY_312) @register_handler("GET_ITER", version=PythonVersion.PY_313) def handle_get_iter( @@ -482,7 +459,6 @@ def handle_get_iter( return state -@register_handler("JUMP_FORWARD", version=PythonVersion.PY_310) @register_handler("JUMP_FORWARD", version=PythonVersion.PY_312) @register_handler("JUMP_FORWARD", version=PythonVersion.PY_313) def handle_jump_forward( @@ -493,15 +469,6 @@ def handle_jump_forward( return state -@register_handler("JUMP_ABSOLUTE", version=PythonVersion.PY_310) -def handle_jump_absolute( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # JUMP_ABSOLUTE is used to jump back to the beginning of a loop - # In generator expressions, this typically indicates the end of the loop body - return state - - @register_handler("JUMP_BACKWARD", version=PythonVersion.PY_312) @register_handler("JUMP_BACKWARD", version=PythonVersion.PY_313) def handle_jump_backward( @@ -568,7 +535,6 @@ def handle_reraise( # ============================================================================ -@register_handler("LOAD_FAST", version=PythonVersion.PY_310) @register_handler("LOAD_FAST", version=PythonVersion.PY_312) @register_handler("LOAD_FAST", version=PythonVersion.PY_313) def handle_load_fast( @@ -586,7 +552,6 @@ def handle_load_fast( return replace(state, stack=new_stack) -@register_handler("LOAD_DEREF", version=PythonVersion.PY_310) @register_handler("LOAD_DEREF", version=PythonVersion.PY_312) @register_handler("LOAD_DEREF", version=PythonVersion.PY_313) def handle_load_deref( @@ -598,7 +563,6 @@ def handle_load_deref( return replace(state, stack=new_stack) -@register_handler("LOAD_CLOSURE", version=PythonVersion.PY_310) @register_handler("LOAD_CLOSURE", version=PythonVersion.PY_312) @register_handler("LOAD_CLOSURE", version=PythonVersion.PY_313) def handle_load_closure( @@ -610,7 +574,6 @@ def handle_load_closure( return replace(state, stack=new_stack) -@register_handler("LOAD_CONST", 
version=PythonVersion.PY_310) @register_handler("LOAD_CONST", version=PythonVersion.PY_312) @register_handler("LOAD_CONST", version=PythonVersion.PY_313) def handle_load_const( @@ -621,7 +584,6 @@ def handle_load_const( return replace(state, stack=new_stack) -@register_handler("LOAD_GLOBAL", version=PythonVersion.PY_310) @register_handler("LOAD_GLOBAL", version=PythonVersion.PY_312) @register_handler("LOAD_GLOBAL", version=PythonVersion.PY_313) def handle_load_global( @@ -638,7 +600,6 @@ def handle_load_global( return replace(state, stack=new_stack) -@register_handler("LOAD_NAME", version=PythonVersion.PY_310) @register_handler("LOAD_NAME", version=PythonVersion.PY_312) @register_handler("LOAD_NAME", version=PythonVersion.PY_313) def handle_load_name( @@ -650,7 +611,6 @@ def handle_load_name( return replace(state, stack=new_stack) -@register_handler("STORE_FAST", version=PythonVersion.PY_310) @register_handler("STORE_FAST", version=PythonVersion.PY_312) @register_handler("STORE_FAST", version=PythonVersion.PY_313) def handle_store_fast( @@ -674,7 +634,6 @@ def handle_store_fast( return replace(state, stack=new_stack, result=new_result) -@register_handler("STORE_DEREF", version=PythonVersion.PY_310) @register_handler("STORE_DEREF", version=PythonVersion.PY_312) @register_handler("STORE_DEREF", version=PythonVersion.PY_313) def handle_store_deref( @@ -800,7 +759,6 @@ def handle_copy_free_vars( # ============================================================================ -@register_handler("POP_TOP", version=PythonVersion.PY_310) @register_handler("POP_TOP", version=PythonVersion.PY_312) @register_handler("POP_TOP", version=PythonVersion.PY_313) def handle_pop_top( @@ -813,56 +771,6 @@ def handle_pop_top( return replace(state, stack=new_stack) -@register_handler("DUP_TOP", version=PythonVersion.PY_310) -def handle_dup_top( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # DUP_TOP duplicates the top stack item - top_item = 
state.stack[-1] - new_stack = state.stack + [top_item] - return replace(state, stack=new_stack) - - -@register_handler("ROT_TWO", version=PythonVersion.PY_310) -def handle_rot_two( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # ROT_TWO swaps the top two stack items - new_stack = state.stack[:-2] + [state.stack[-1], state.stack[-2]] - return replace(state, stack=new_stack) - - -@register_handler("ROT_THREE", version=PythonVersion.PY_310) -def handle_rot_three( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # ROT_THREE rotates the top three stack items - # TOS -> TOS1, TOS1 -> TOS2, TOS2 -> TOS - new_stack = state.stack[:-3] + [state.stack[-2], state.stack[-1], state.stack[-3]] - - # Check if the top two items are the same (from DUP_TOP) - # This heuristic indicates we're setting up for a chained comparison - if len(state.stack) >= 3 and state.stack[-1] == state.stack[-2]: - raise NotImplementedError("Chained comparison not implemented yet") - - return replace(state, stack=new_stack) - - -@register_handler("ROT_FOUR", version=PythonVersion.PY_310) -def handle_rot_four( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # ROT_FOUR rotates the top four stack items - # TOS -> TOS1, TOS1 -> TOS2, TOS2 -> TOS3, TOS3 -> TOS - new_stack = state.stack[:-4] + [ - state.stack[-2], - state.stack[-1], - state.stack[-4], - state.stack[-3], - ] - return replace(state, stack=new_stack) - - # Python 3.13 replacement for stack manipulation @register_handler("SWAP", version=PythonVersion.PY_312) @register_handler("SWAP", version=PythonVersion.PY_313) @@ -973,69 +881,6 @@ def handle_binary_op( return handle_binop(op, state, instr) -# Legacy binary operation handlers (for Python 3.10 compatibility) -handler_binop_add = register_handler( - "BINARY_ADD", - functools.partial(handle_binop, ast.Add()), - version=PythonVersion.PY_310, -) -handler_binop_subtract = register_handler( - 
"BINARY_SUBTRACT", - functools.partial(handle_binop, ast.Sub()), - version=PythonVersion.PY_310, -) -handler_binop_multiply = register_handler( - "BINARY_MULTIPLY", - functools.partial(handle_binop, ast.Mult()), - version=PythonVersion.PY_310, -) -handler_binop_true_divide = register_handler( - "BINARY_TRUE_DIVIDE", - functools.partial(handle_binop, ast.Div()), - version=PythonVersion.PY_310, -) -handler_binop_floor_divide = register_handler( - "BINARY_FLOOR_DIVIDE", - functools.partial(handle_binop, ast.FloorDiv()), - version=PythonVersion.PY_310, -) -handler_binop_modulo = register_handler( - "BINARY_MODULO", - functools.partial(handle_binop, ast.Mod()), - version=PythonVersion.PY_310, -) -handler_binop_power = register_handler( - "BINARY_POWER", - functools.partial(handle_binop, ast.Pow()), - version=PythonVersion.PY_310, -) -handler_binop_lshift = register_handler( - "BINARY_LSHIFT", - functools.partial(handle_binop, ast.LShift()), - version=PythonVersion.PY_310, -) -handler_binop_rshift = register_handler( - "BINARY_RSHIFT", - functools.partial(handle_binop, ast.RShift()), - version=PythonVersion.PY_310, -) -handler_binop_or = register_handler( - "BINARY_OR", - functools.partial(handle_binop, ast.BitOr()), - version=PythonVersion.PY_310, -) -handler_binop_xor = register_handler( - "BINARY_XOR", - functools.partial(handle_binop, ast.BitXor()), - version=PythonVersion.PY_310, -) -handler_binop_and = register_handler( - "BINARY_AND", - functools.partial(handle_binop, ast.BitAnd()), - version=PythonVersion.PY_310, -) - - # ============================================================================ # UNARY OPERATION HANDLERS # ============================================================================ @@ -1049,11 +894,6 @@ def handle_unary_op( return replace(state, stack=new_stack) -handle_unary_negative = register_handler( - "UNARY_NEGATIVE", - functools.partial(handle_unary_op, ast.USub()), - version=PythonVersion.PY_310, -) handle_unary_negative = 
register_handler( "UNARY_NEGATIVE", functools.partial(handle_unary_op, ast.USub()), @@ -1064,16 +904,6 @@ def handle_unary_op( functools.partial(handle_unary_op, ast.USub()), version=PythonVersion.PY_313, ) -handle_unary_positive = register_handler( - "UNARY_POSITIVE", - functools.partial(handle_unary_op, ast.UAdd()), - version=PythonVersion.PY_310, -) -handle_unary_invert = register_handler( - "UNARY_INVERT", - functools.partial(handle_unary_op, ast.Invert()), - version=PythonVersion.PY_310, -) handle_unary_invert = register_handler( "UNARY_INVERT", functools.partial(handle_unary_op, ast.Invert()), @@ -1084,11 +914,6 @@ def handle_unary_op( functools.partial(handle_unary_op, ast.Invert()), version=PythonVersion.PY_313, ) -handle_unary_not = register_handler( - "UNARY_NOT", - functools.partial(handle_unary_op, ast.Not()), - version=PythonVersion.PY_310, -) handle_unary_not = register_handler( "UNARY_NOT", functools.partial(handle_unary_op, ast.Not()), @@ -1101,20 +926,6 @@ def handle_unary_op( ) -@register_handler("LIST_TO_TUPLE", version=PythonVersion.PY_310) -def handle_list_to_tuple( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LIST_TO_TUPLE converts a list on the stack to a tuple - list_obj = ensure_ast(state.stack[-1]) - assert isinstance(list_obj, ast.List), "Expected a list for LIST_TO_TUPLE" - - # Create tuple AST from the list's elements - tuple_node = ast.Tuple(elts=list_obj.elts, ctx=ast.Load()) - new_stack = state.stack[:-1] + [tuple_node] - return replace(state, stack=new_stack) - - @register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_312) @register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_313) def handle_call_intrinsic_1( @@ -1164,7 +975,6 @@ def handle_to_bool( } -@register_handler("COMPARE_OP", version=PythonVersion.PY_310) @register_handler("COMPARE_OP", version=PythonVersion.PY_312) @register_handler("COMPARE_OP", version=PythonVersion.PY_313) def handle_compare_op( @@ -1184,7 +994,6 @@ 
def handle_compare_op( return replace(state, stack=new_stack) -@register_handler("CONTAINS_OP", version=PythonVersion.PY_310) @register_handler("CONTAINS_OP", version=PythonVersion.PY_312) @register_handler("CONTAINS_OP", version=PythonVersion.PY_313) def handle_contains_op( @@ -1201,7 +1010,6 @@ def handle_contains_op( return replace(state, stack=new_stack) -@register_handler("IS_OP", version=PythonVersion.PY_310) @register_handler("IS_OP", version=PythonVersion.PY_312) @register_handler("IS_OP", version=PythonVersion.PY_313) def handle_is_op( @@ -1292,82 +1100,6 @@ def handle_call( return replace(state, stack=new_stack) -@register_handler("CALL_FUNCTION", version=PythonVersion.PY_310) -def handle_call_function( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # CALL_FUNCTION pops function and arguments from stack - assert instr.arg is not None - arg_count: int = instr.arg - # Pop arguments and function - args = ( - [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] - ) - func = ensure_ast(state.stack[-arg_count - 1]) - new_stack = state.stack[: -arg_count - 1] - - if isinstance(func, CompLambda): - assert len(args) == 1 - return replace(state, stack=new_stack + [func.inline(args[0])]) - else: - # Create function call AST - call_node = ast.Call(func=func, args=args, keywords=[]) - new_stack = new_stack + [call_node] - return replace(state, stack=new_stack) - - -@register_handler("CALL_METHOD", version=PythonVersion.PY_310) -def handle_call_method( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # CALL_METHOD calls a method - similar to CALL_FUNCTION but for methods - assert instr.arg is not None - arg_count: int = instr.arg - # Pop arguments and method - args = ( - [ensure_ast(arg) for arg in state.stack[-arg_count:]] if arg_count > 0 else [] - ) - method = ensure_ast(state.stack[-arg_count - 2]) - new_stack = state.stack[: -arg_count - 2] - - # Create method call AST - 
call_node = ast.Call(func=method, args=args, keywords=[]) - new_stack = new_stack + [call_node] - return replace(state, stack=new_stack) - - -# Python 3.10 version -@register_handler("MAKE_FUNCTION", version=PythonVersion.PY_310) -def handle_make_function_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # MAKE_FUNCTION creates a function from code object and name on stack - assert isinstance(state.stack[-2], ast.Lambda | CompLambda) - assert isinstance(state.stack[-1], ast.Constant) and isinstance( - state.stack[-1].value, str - ), "Function name must be a constant string." - if instr.argrepr == "closure": - # This is a closure, remove the environment tuple from the stack for AST purposes - new_stack = state.stack[:-3] - elif instr.argrepr == "": - new_stack = state.stack[:-2] - else: - raise NotImplementedError( - "MAKE_FUNCTION with defaults or annotations not implemented." - ) - - # Pop the function object and name from the stack - # Conversion from CodeType to ast.Lambda should have happened already - func: ast.Lambda | CompLambda = state.stack[-2] - name: str = state.stack[-1].value - - assert any( - name.endswith(suffix) - for suffix in ("", "", "", "", "") - ), f"Expected a comprehension or lambda function, got '{name}'" - return replace(state, stack=new_stack + [func]) - - @register_handler("MAKE_FUNCTION", version=PythonVersion.PY_312) def handle_make_function_312( state: ReconstructionState, instr: dis.Instruction @@ -1430,39 +1162,6 @@ def handle_set_function_attribute( # ============================================================================ -@register_handler("LOAD_METHOD", version=PythonVersion.PY_310) -def handle_load_method( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LOAD_METHOD loads a method from an object - # It pushes the bound method and the object (for the method call) - obj = ensure_ast(state.stack[-1]) - method_name = instr.argval - new_stack = 
state.stack[:-1] - - # Create method access as an attribute - method_attr = ast.Attribute(value=obj, attr=method_name, ctx=ast.Load()) - - # For LOAD_METHOD, we push both the method and the object - # But for AST purposes, we just need the method attribute - new_stack = new_stack + [method_attr, obj] - return replace(state, stack=new_stack) - - -@register_handler("LOAD_ATTR", version=PythonVersion.PY_310) -def handle_load_attr_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # LOAD_ATTR loads an attribute from the object on top of stack - obj = ensure_ast(state.stack[-1]) - attr_name = instr.argval - - # Create attribute access AST - attr_node = ast.Attribute(value=obj, attr=attr_name, ctx=ast.Load()) - new_stack = state.stack[:-1] + [attr_node] - return replace(state, stack=new_stack) - - @register_handler("LOAD_ATTR", version=PythonVersion.PY_312) @register_handler("LOAD_ATTR", version=PythonVersion.PY_313) def handle_load_attr( @@ -1483,7 +1182,6 @@ def handle_load_attr( return replace(state, stack=new_stack) -@register_handler("BINARY_SUBSCR", version=PythonVersion.PY_310) @register_handler("BINARY_SUBSCR", version=PythonVersion.PY_312) @register_handler("BINARY_SUBSCR", version=PythonVersion.PY_313) def handle_binary_subscr( @@ -1505,7 +1203,6 @@ def handle_binary_subscr( # ============================================================================ -@register_handler("UNPACK_SEQUENCE", version=PythonVersion.PY_310) @register_handler("UNPACK_SEQUENCE", version=PythonVersion.PY_312) @register_handler("UNPACK_SEQUENCE", version=PythonVersion.PY_313) def handle_unpack_sequence( @@ -1528,7 +1225,6 @@ def handle_unpack_sequence( return replace(state, stack=new_stack) -@register_handler("BUILD_TUPLE", version=PythonVersion.PY_310) @register_handler("BUILD_TUPLE", version=PythonVersion.PY_312) @register_handler("BUILD_TUPLE", version=PythonVersion.PY_313) def handle_build_tuple( @@ -1550,7 +1246,6 @@ def handle_build_tuple( 
return replace(state, stack=new_stack) -@register_handler("BUILD_CONST_KEY_MAP", version=PythonVersion.PY_310) @register_handler("BUILD_CONST_KEY_MAP", version=PythonVersion.PY_312) @register_handler("BUILD_CONST_KEY_MAP", version=PythonVersion.PY_313) def handle_build_const_key_map( @@ -1573,7 +1268,6 @@ def handle_build_const_key_map( return replace(state, stack=new_stack) -@register_handler("LIST_EXTEND", version=PythonVersion.PY_310) @register_handler("LIST_EXTEND", version=PythonVersion.PY_312) @register_handler("LIST_EXTEND", version=PythonVersion.PY_313) def handle_list_extend( @@ -1598,7 +1292,6 @@ def handle_list_extend( return replace(state, stack=new_stack, result=prev_result) -@register_handler("SET_UPDATE", version=PythonVersion.PY_310) @register_handler("SET_UPDATE", version=PythonVersion.PY_312) @register_handler("SET_UPDATE", version=PythonVersion.PY_313) def handle_set_update( @@ -1618,7 +1311,6 @@ def handle_set_update( return replace(state, stack=new_stack, result=prev_result) -@register_handler("DICT_UPDATE", version=PythonVersion.PY_310) @register_handler("DICT_UPDATE", version=PythonVersion.PY_312) @register_handler("DICT_UPDATE", version=PythonVersion.PY_313) def handle_dict_update( @@ -1646,27 +1338,6 @@ def handle_dict_update( # ============================================================================ -# Python 3.10 version -@register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_310) -def handle_pop_jump_if_false_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false - # In comprehensions, this is used for filter conditions - condition = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - - if instr.argval < instr.offset: - # Jumping backward to loop start - this is a condition - # When POP_JUMP_IF_FALSE jumps back, it means "if false, skip this item" - assert isinstance(state.result, CompExp) and 
state.result.generators - new_result = copy.deepcopy(state.result) - new_result.generators[-1].ifs.append(condition) - return replace(state, stack=new_stack, result=new_result) - else: - raise NotImplementedError("Lazy and+or behavior not implemented yet") - - # Python 3.12+ version @register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_313) @@ -1691,31 +1362,6 @@ def handle_pop_jump_if_false( raise NotImplementedError("Lazy and+or behavior not implemented yet") -# Python 3.10 version -@register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_310) -def handle_pop_jump_if_true_310( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # POP_JUMP_IF_TRUE pops a value from the stack and jumps if it's true - # This can be: - # 1. Part of an OR expression (jump to YIELD_VALUE) - # 2. A negated condition like "not x % 2" (jump back to loop start) - condition = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - - if instr.argval < instr.offset: - # Jumping backward to loop start - this is a negated condition - # When POP_JUMP_IF_TRUE jumps back, it means "if false, skip this item" - # So we need to negate the condition to get the filter condition - assert isinstance(state.result, CompExp) and state.result.generators - negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) - new_result = copy.deepcopy(state.result) - new_result.generators[-1].ifs.append(negated_condition) - return replace(state, stack=new_stack, result=new_result) - else: - raise NotImplementedError("Lazy and+or behavior not implemented yet") - - # Python 3.12+ version @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_313) From 4b6d5e0c8289726140deee4520898cd068e4ba75 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 17:21:00 -0400 Subject: [PATCH 075/106] add code to state --- 
effectful/internals/disassembler.py | 33 +++++++++++++++-------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 120a17e2..c0dfadcc 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -16,6 +16,7 @@ """ import ast +import collections import copy import dis import enum @@ -111,9 +112,17 @@ class ReconstructionState: BINARY_ADD pop operands and push results. """ + code: types.CodeType result: ast.expr = field(default_factory=Placeholder) stack: list[ast.expr] = field(default_factory=list) + @property + def instructions(self) -> collections.OrderedDict[int, dis.Instruction]: + """Get the bytecode instructions for the current code object.""" + return collections.OrderedDict( + (instr.offset, instr) for instr in dis.get_instructions(self.code) + ) + # Python version enum for version-specific handling class PythonVersion(enum.Enum): @@ -526,7 +535,6 @@ def handle_reraise( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # RERAISE re-raises an exception - generally ignore for AST reconstruction - assert not state.stack # in generator expressions, we shouldn't have a stack here return state @@ -787,19 +795,12 @@ def handle_swap( # For AST reconstruction, we can often ignore certain stack manipulations return state - if depth == 2 and stack_size >= 2: - # Equivalent to ROT_TWO - new_stack = state.stack[:-2] + [state.stack[-1], state.stack[-2]] - return replace(state, stack=new_stack) - elif depth <= stack_size: - # For other depths, swap TOS with the item at specified depth - idx = stack_size - depth - new_stack = state.stack.copy() - new_stack[-1], new_stack[idx] = new_stack[idx], new_stack[-1] - return replace(state, stack=new_stack) - else: - # Edge case - not enough items, just return unchanged - return state + # For other depths, swap TOS with the item at specified depth + assert depth <= stack_size, 
f"SWAP depth {depth} exceeds stack size {stack_size}" + idx = stack_size - depth + new_stack = state.stack.copy() + new_stack[-1], new_stack[idx] = new_stack[idx], new_stack[-1] + return replace(state, stack=new_stack) @register_handler("COPY", version=PythonVersion.PY_312) @@ -1541,8 +1542,8 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: raise TypeError(f"Unsupported code object type: {value.co_name}") # Symbolic execution to reconstruct the AST - state = ReconstructionState() - for instr in dis.get_instructions(value): + state = ReconstructionState(code=value) + for instr in state.instructions.values(): state = OP_HANDLERS[instr.opname](state, instr) result: ast.expr = state.result From 4bfcb59fdce00f04420a7e718bc871b449298f54 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 17:46:23 -0400 Subject: [PATCH 076/106] compress pop_jump_if --- effectful/internals/disassembler.py | 87 ++++++++++++----------------- 1 file changed, 36 insertions(+), 51 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index c0dfadcc..2db512f1 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1339,24 +1339,20 @@ def handle_dict_update( # ============================================================================ -# Python 3.12+ version -@register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_312) -@register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_313) -def handle_pop_jump_if_false( - state: ReconstructionState, instr: dis.Instruction +def _handle_pop_jump_if( + f_condition: Callable[[ast.expr], ast.expr], + state: ReconstructionState, + instr: dis.Instruction, ) -> ReconstructionState: - # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false - # In comprehensions, this is used for filter conditions - condition = ensure_ast(state.stack[-1]) + # Generic handler for POP_JUMP_IF_* instructions + # Pops a value from the 
stack and jumps if the condition is met + condition = f_condition(ensure_ast(state.stack[-1])) new_stack = state.stack[:-1] if isinstance(state.result, CompExp) and state.result.generators: - # In Python 3.13, when POP_JUMP_IF_FALSE jumps forward to the yield, - # it means "if condition is False, then yield the item" - # So we need to negate the condition: we want items where NOT condition - negated_condition = ast.UnaryOp(op=ast.Not(), operand=condition) + # In comprehensions, we add the condition to the last generator's ifs new_result = copy.deepcopy(state.result) - new_result.generators[-1].ifs.append(negated_condition) + new_result.generators[-1].ifs.append(condition) return replace(state, stack=new_stack, result=new_result) else: # Not in a comprehension context - might be boolean logic @@ -1371,19 +1367,20 @@ def handle_pop_jump_if_true( ) -> ReconstructionState: # POP_JUMP_IF_TRUE pops a value from the stack and jumps if it's true # In Python 3.13, this is used for filter conditions where True means continue - condition = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] + return _handle_pop_jump_if(lambda c: c, state, instr) - # In Python 3.13, if we have a comprehension and generators, this is likely a filter - if isinstance(state.result, CompExp) and state.result.generators: - # For POP_JUMP_IF_TRUE in filters, we want the condition to be true to continue - # So we add the condition directly (no negation needed) - new_result = copy.deepcopy(state.result) - new_result.generators[-1].ifs.append(condition) - return replace(state, stack=new_stack, result=new_result) - else: - # Not in a comprehension context - might be boolean logic - raise NotImplementedError("Lazy and+or behavior not implemented yet") + +# Python 3.12+ version +@register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_312) +@register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_313) +def handle_pop_jump_if_false( + state: ReconstructionState, instr: dis.Instruction 
+) -> ReconstructionState: + # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false + # In comprehensions, this is used for filter conditions + return _handle_pop_jump_if( + lambda c: ast.UnaryOp(op=ast.Not(), operand=c), state, instr + ) @register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_312) @@ -1392,19 +1389,13 @@ def handle_pop_jump_if_none( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # POP_JUMP_IF_NONE pops a value and jumps if it's None - condition = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - - if isinstance(state.result, CompExp) and state.result.generators: - # Create "x is None" condition - none_condition = ast.Compare( - left=condition, ops=[ast.Is()], comparators=[ast.Constant(value=None)] - ) - new_result = copy.deepcopy(state.result) - new_result.generators[-1].ifs.append(none_condition) - return replace(state, stack=new_stack, result=new_result) - else: - raise NotImplementedError("Lazy and+or behavior not implemented yet") + return _handle_pop_jump_if( + lambda c: ast.Compare( + left=c, ops=[ast.Is()], comparators=[ast.Constant(value=None)] + ), + state, + instr, + ) @register_handler("POP_JUMP_IF_NOT_NONE", version=PythonVersion.PY_312) @@ -1413,19 +1404,13 @@ def handle_pop_jump_if_not_none( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # POP_JUMP_IF_NOT_NONE pops a value and jumps if it's not None - condition = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - - if isinstance(state.result, CompExp) and state.result.generators: - # Create "x is not None" condition - not_none_condition = ast.Compare( - left=condition, ops=[ast.IsNot()], comparators=[ast.Constant(value=None)] - ) - new_result = copy.deepcopy(state.result) - new_result.generators[-1].ifs.append(not_none_condition) - return replace(state, stack=new_stack, result=new_result) - else: - raise NotImplementedError("Lazy and+or behavior not implemented yet") + return 
_handle_pop_jump_if( + lambda c: ast.Compare( + left=c, ops=[ast.IsNot()], comparators=[ast.Constant(value=None)] + ), + state, + instr, + ) # ============================================================================ From 0d084ae0879c643d1074a95ef351c61f647095e0 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 17:48:35 -0400 Subject: [PATCH 077/106] comment --- effectful/internals/disassembler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 2db512f1..55047dcd 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1359,7 +1359,6 @@ def _handle_pop_jump_if( raise NotImplementedError("Lazy and+or behavior not implemented yet") -# Python 3.12+ version @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_313) def handle_pop_jump_if_true( @@ -1370,7 +1369,6 @@ def handle_pop_jump_if_true( return _handle_pop_jump_if(lambda c: c, state, instr) -# Python 3.12+ version @register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_313) def handle_pop_jump_if_false( From 13c7973baa0a1c4bad72b5469ea704257d8fec27 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 18:03:58 -0400 Subject: [PATCH 078/106] isolate weird behavior in a separate test --- effectful/internals/disassembler.py | 23 +++++--------- tests/test_internals_disassembler.py | 45 ++++++++++++++++------------ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 55047dcd..62d7a269 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -808,23 +808,16 @@ def handle_swap( def handle_copy( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # COPY duplicates the item at the 
specified depth (replaces DUP_TOP in many cases) + # COPY duplicates the item at the specified depth assert instr.arg is not None depth = instr.arg - if depth == 1: - # Equivalent to DUP_TOP - top_item = state.stack[-1] - new_stack = state.stack + [top_item] - return replace(state, stack=new_stack) - else: - # Copy the item at specified depth to top of stack - stack_size = len(state.stack) - if depth > stack_size: - raise ValueError(f"COPY depth {depth} exceeds stack size {stack_size}") - idx = stack_size - depth - copied_item = state.stack[idx] - new_stack = state.stack + [copied_item] - return replace(state, stack=new_stack) + stack_size = len(state.stack) + if depth > stack_size: + raise ValueError(f"COPY depth {depth} exceeds stack size {stack_size}") + idx = stack_size - depth + copied_item = state.stack[idx] + new_stack = state.stack + [copied_item] + return replace(state, stack=new_stack) @register_handler("PUSH_NULL", version=PythonVersion.PY_312) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index cc95a195..a1b9e26e 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -1,4 +1,6 @@ import ast +import contextlib +import sys from types import GeneratorType from typing import Any @@ -363,14 +365,7 @@ def test_filtered_generators(genexpr): for w in range(2) ), # Nested loops with complex filters - ( - (x, y, z) - for x in range(5) - for y in range(5) - if x < y - for z in range(5) - if y < z - ), + ((x, y) for x in range(5) if x > 1 for y in range(5) if x < y), (x + y for x in range(3) if x > 0 for y in range(3)), # Mixed range types ((x, y) for x in range(-2, 2) for y in range(0, 4, 2)), @@ -378,12 +373,6 @@ def test_filtered_generators(genexpr): # Dependent nested loops ((x, y) for x in range(3) for y in range(x, 3)), (x + y for x in range(3) for y in range(x + 1, 3)), - ( - x * y * z - for x in range(3) - for y in range(x + 1, x + 3) - for z in range(y, y + 3) - ), ], ) def 
test_nested_loops(genexpr): @@ -427,6 +416,29 @@ def test_nested_comprehensions(genexpr): assert_ast_equivalent(genexpr, ast_node) +def test_nested_comprehensions_multiline_fail(): + """Illustrate bug in dis for multiline comprehensions""" + xs1 = (x for x in range(5) if x > 1) + xs2 = ( + x + for x in range(5) # comment to avoid linter warning + if x > 1 + ) + + ast_node_1 = reconstruct(xs1) + assert_ast_equivalent(xs1, ast_node_1) + + with ( + contextlib.nullcontext() + if sys.version_info[:2] > (3, 12) + else pytest.xfail( + reason="Multiline comprehensions are not handled correctly by disassembler" + ) + ): + ast_node_2 = reconstruct(xs2) + assert_ast_equivalent(xs2, ast_node_2) + + # ============================================================================ # DIFFERENT COMPREHENSION TYPES # ============================================================================ @@ -452,11 +464,6 @@ def test_nested_comprehensions(genexpr): ([x for x in range(i)] for i in range(5) if i > 0), ([x for x in range(i) if x < i] for i in range(5) if i > 0), ([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5)), - ( - [[x for x in range(i + j) if x < i + j] for j in range(i)] - for i in range(5) - if i > 0 - ), ], ) def test_different_comprehension_types(genexpr): From 47be4f04c9fc4adf6e76ce4fe45360f2cdadc0f1 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 18:24:02 -0400 Subject: [PATCH 079/106] add test illustrating bug --- tests/test_internals_disassembler.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index a1b9e26e..f6758f87 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -1,6 +1,5 @@ import ast -import contextlib -import sys +import dis from types import GeneratorType from typing import Any @@ -416,27 +415,22 @@ def test_nested_comprehensions(genexpr): 
assert_ast_equivalent(genexpr, ast_node) +@pytest.mark.xfail(reason="bug in builtin module dis breaks this test in Python 3.12") def test_nested_comprehensions_multiline_fail(): """Illustrate bug in dis for multiline comprehensions""" + # this part works - dis.dis correctly reconstructs the generator expression xs1 = (x for x in range(5) if x > 1) + assert any(i.opname == "POP_JUMP_IF_TRUE" for i in dis.get_instructions(xs1)) + assert_ast_equivalent(xs1, reconstruct(xs1)) + + # this part fails - dis.dis incorrectly negates the filter expression x > 1 xs2 = ( x - for x in range(5) # comment to avoid linter warning + for x in range(5) # comment to avoid reformatting if x > 1 ) - - ast_node_1 = reconstruct(xs1) - assert_ast_equivalent(xs1, ast_node_1) - - with ( - contextlib.nullcontext() - if sys.version_info[:2] > (3, 12) - else pytest.xfail( - reason="Multiline comprehensions are not handled correctly by disassembler" - ) - ): - ast_node_2 = reconstruct(xs2) - assert_ast_equivalent(xs2, ast_node_2) + assert_ast_equivalent(xs2, reconstruct(xs2)) + assert any(i.opname == "POP_JUMP_IF_TRUE" for i in dis.get_instructions(xs2)) # ============================================================================ From d3a271f67dd7261cc44fe3cc929f274c0ec5c45e Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 18:51:36 -0400 Subject: [PATCH 080/106] doc --- effectful/internals/disassembler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 62d7a269..1f93ea23 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -101,6 +101,9 @@ class ReconstructionState: that represent those operations. Attributes: + code: The compiled code object from which the bytecode is being processed. + This is typically obtained from a generator function or comprehension. + result: The current comprehension expression being built. 
Initially a placeholder, it gets updated as the bytecode is processed. It can be a GeneratorExp, ListComp, SetComp, DictComp, or From e35f62e4e7a501febe47bfa4035b1b632b97c27f Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 18:53:05 -0400 Subject: [PATCH 081/106] comment --- effectful/internals/disassembler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 1f93ea23..82871286 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -845,17 +845,17 @@ def handle_binop( return replace(state, stack=new_stack) -# Python 3.13 BINARY_OP handler +# Python 3.12+ BINARY_OP handler @register_handler("BINARY_OP", version=PythonVersion.PY_312) @register_handler("BINARY_OP", version=PythonVersion.PY_313) def handle_binary_op( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # BINARY_OP in Python 3.13 consolidates all binary operations + # BINARY_OP in Python 3.12+ consolidates all binary operations # The operation type is determined by the instruction argument assert instr.arg is not None - # Map argument values to AST operators based on Python 3.13 implementation + # Map argument values to AST operators based on Python 3.12+ implementation op_map = { 0: ast.Add(), # + 1: ast.BitAnd(), # & From 4c68357c4b3d670e9770c77a4c120c989f783998 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 19:07:34 -0400 Subject: [PATCH 082/106] slices --- effectful/internals/disassembler.py | 49 ++++++++++++++++++++++++++++ tests/test_internals_disassembler.py | 7 ++++ 2 files changed, 56 insertions(+) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 82871286..314d20a3 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1195,6 +1195,24 @@ def handle_binary_subscr( return replace(state, stack=new_stack) 
+@register_handler("BINARY_SLICE", version=PythonVersion.PY_312) +@register_handler("BINARY_SLICE", version=PythonVersion.PY_313) +def handle_binary_slice( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # BINARY_SLICE implements obj[start:end] - pops start, end, and obj from stack + end = ensure_ast(state.stack[-1]) + start = ensure_ast(state.stack[-2]) + container = ensure_ast(state.stack[-3]) # Object is below start and end + sliced = ast.Subscript( + value=container, + slice=ast.Slice(lower=start, upper=end, step=None), + ctx=ast.Load(), + ) + new_stack = state.stack[:-3] + [sliced] + return replace(state, stack=new_stack) + + # ============================================================================ # OTHER CONTAINER BUILDING HANDLERS # ============================================================================ @@ -1243,6 +1261,37 @@ def handle_build_tuple( return replace(state, stack=new_stack) +@register_handler("BUILD_SLICE", version=PythonVersion.PY_312) +@register_handler("BUILD_SLICE", version=PythonVersion.PY_313) +def handle_build_slice( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # BUILD_SLICE creates a slice object from the top of the stack + # The number of elements to pop is determined by the instruction argument + assert instr.arg is not None + slice_size: int = instr.arg + + if slice_size == 2: + # Slice with start and end: [start, end] + end = ensure_ast(state.stack[-1]) + start = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + slice_node = ast.Slice(lower=start, upper=end, step=None) + elif slice_size == 3: + # Slice with start, end, and step: [start, end, step] + step = ensure_ast(state.stack[-1]) + end = ensure_ast(state.stack[-2]) + start = ensure_ast(state.stack[-3]) + new_stack = state.stack[:-3] + slice_node = ast.Slice(lower=start, upper=end, step=step) + else: + raise ValueError(f"Unsupported slice size: {slice_size}") + + # Create slice AST + 
new_stack = new_stack + [slice_node] + return replace(state, stack=new_stack) + + @register_handler("BUILD_CONST_KEY_MAP", version=PythonVersion.PY_312) @register_handler("BUILD_CONST_KEY_MAP", version=PythonVersion.PY_313) def handle_build_const_key_map( diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index f6758f87..6ad98dbb 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -512,6 +512,13 @@ def test_variable_lookup(genexpr, globals_dict): ((x.real for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), ((x.imag for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), ((x.conjugate() for x in [1 + 2j, 3 + 4j, 5 + 6j]), {}), + # slicing and indexing + ((s[:2] for s in ["hello", "world"]), {}), + ((s[1:3] for s in ["hello", "world"]), {}), + ((s[-1] for s in ["hello", "world"]), {}), + ((s[0:3] for s in ["hello", "world"]), {}), + ((s[::-1] for s in ["hello", "world"]), {}), + ((s[1:2:] for s in ["hello", "world"]), {}), # Method calls ((s.upper() for s in ["hello", "world"]), {}), ((s.lower() for s in ["HELLO", "WORLD"]), {}), From 89b33d5b183d6581f6bc6631ba3d16e498fc5de9 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 19:25:42 -0400 Subject: [PATCH 083/106] start fstring --- effectful/internals/disassembler.py | 80 ++++++++++++++++++++++++++++ tests/test_internals_disassembler.py | 4 ++ 2 files changed, 84 insertions(+) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 314d20a3..6ef9775d 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1379,6 +1379,86 @@ def handle_dict_update( return replace(state, stack=new_stack, result=prev_result) +@register_handler("BUILD_STRING", version=PythonVersion.PY_312) +@register_handler("BUILD_STRING", version=PythonVersion.PY_313) +def handle_build_string( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # BUILD_STRING creates a string from the 
top of the stack + # The number of elements to pop is determined by the instruction argument + assert instr.arg is not None + string_size: int = instr.arg + + if string_size == 0: + # Empty string case + new_stack = state.stack + [ast.Constant(value="")] + else: + # Pop elements for the string + elements = ( + [ensure_ast(elem) for elem in state.stack[-string_size:]] + if string_size > 0 + else [] + ) + new_stack = state.stack[:-string_size] + + # Create concatenated string AST + concat_node = ast.JoinedStr(values=elements) + new_stack = new_stack + [concat_node] + + return replace(state, stack=new_stack) + + +@register_handler("FORMAT_VALUE", version=PythonVersion.PY_312) +def handle_format_value( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # FORMAT_VALUE formats a string with a value + # Pops the value and the format string from the stack + assert len(state.stack) >= 1, "Not enough items on stack for FORMAT_VALUE" + value = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # Create formatted string AST + formatted_node = ast.FormattedValue(value=value, conversion=-1, format_spec=None) + new_stack = new_stack + [formatted_node] + return replace(state, stack=new_stack) + + +@register_handler("FORMAT_SIMPLE", version=PythonVersion.PY_313) +def handle_format_simple( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # FORMAT_SIMPLE formats a string with a single value + # Pops the value and the format string from the stack + assert len(state.stack) >= 1, "Not enough items on stack for FORMAT_SIMPLE" + value = ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + + # Create formatted string AST + formatted_node = ast.FormattedValue(value=value, conversion=-1, format_spec=None) + new_stack = new_stack + [formatted_node] + return replace(state, stack=new_stack) + + +@register_handler("FORMAT_WITH_SPEC", version=PythonVersion.PY_313) +def handle_format_with_spec( + state: 
ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # FORMAT_WITH_SPEC formats a string with a value and a format specifier + # Pops the value, format string, and format specifier from the stack + assert len(state.stack) >= 2, "Not enough items on stack for FORMAT_WITH_SPEC" + value = ensure_ast(state.stack[-1]) + format_string = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + + # Create formatted string AST with specifier + formatted_node = ast.FormattedValue( + value=value, conversion=-1, format_spec=format_string + ) + new_stack = new_stack + [formatted_node] + return replace(state, stack=new_stack) + + # ============================================================================ # CONDITIONAL JUMP HANDLERS # ============================================================================ diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 6ad98dbb..457e7c60 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -519,6 +519,10 @@ def test_variable_lookup(genexpr, globals_dict): ((s[0:3] for s in ["hello", "world"]), {}), ((s[::-1] for s in ["hello", "world"]), {}), ((s[1:2:] for s in ["hello", "world"]), {}), + # fstrings and formatted strings + ((f"{x} is {x**2}" for x in range(5)), {}), + ((f"{x:02d}" for x in range(10)), {}), + ((f"{x:.2f}" for x in [1.2345, 2.3456, 3.4567]), {}), # Method calls ((s.upper() for s in ["hello", "world"]), {}), ((s.lower() for s in ["HELLO", "WORLD"]), {}), From ce5453653d7007178c8aa981775b305be3dde53c Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 20:22:04 -0400 Subject: [PATCH 084/106] remove tree --- tests/test_internals_disassembler.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 457e7c60..c69c803c 100644 --- a/tests/test_internals_disassembler.py +++ 
b/tests/test_internals_disassembler.py @@ -1,10 +1,10 @@ import ast +import collections.abc import dis from types import GeneratorType from typing import Any import pytest -import tree from effectful.internals.disassembler import reconstruct @@ -28,18 +28,22 @@ def compile_and_eval( return eval(code, globals_dict) -def materialize(genexpr: GeneratorType) -> tree.Structure: +def materialize(genexpr: GeneratorType) -> list: """Materialize a nested generator expression to a nested list.""" def _materialize(genexpr): - if isinstance(genexpr, GeneratorType): - return tree.map_structure(_materialize, list(genexpr)) - elif tree.is_nested(genexpr): - return tree.map_structure(_materialize, genexpr) + if isinstance(genexpr, str | bytes): + return genexpr + elif isinstance(genexpr, collections.abc.Generator): + return _materialize(list(genexpr)) + elif isinstance(genexpr, collections.abc.Sequence | collections.abc.Set): + return [_materialize(item) for item in genexpr] + elif isinstance(genexpr, collections.abc.Mapping): + return {_materialize(k): _materialize(v) for k, v in genexpr.items()} else: return genexpr - return _materialize(genexpr) + return [_materialize(x) for x in genexpr] def assert_ast_equivalent( From 9d794e4ff054ef488410256747b2161b43d0de4e Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 20:24:49 -0400 Subject: [PATCH 085/106] remove tree --- tests/test_internals_disassembler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index c69c803c..9cca4a1f 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -35,9 +35,11 @@ def _materialize(genexpr): if isinstance(genexpr, str | bytes): return genexpr elif isinstance(genexpr, collections.abc.Generator): - return _materialize(list(genexpr)) - elif isinstance(genexpr, collections.abc.Sequence | collections.abc.Set): return [_materialize(item) for item in genexpr] + 
elif isinstance(genexpr, collections.abc.Sequence): + return [_materialize(item) for item in genexpr] + elif isinstance(genexpr, collections.abc.Set): + return {_materialize(item) for item in genexpr} elif isinstance(genexpr, collections.abc.Mapping): return {_materialize(k): _materialize(v) for k, v in genexpr.items()} else: From d18631b2272ffc8b9cea6d944715f3ce4d5a5dd5 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 20:37:32 -0400 Subject: [PATCH 086/106] rename file --- effectful/internals/{disassembler.py => disassembly.py} | 0 tests/test_internals_disassembler.py | 6 +++--- 2 files changed, 3 insertions(+), 3 deletions(-) rename effectful/internals/{disassembler.py => disassembly.py} (100%) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembly.py similarity index 100% rename from effectful/internals/disassembler.py rename to effectful/internals/disassembly.py diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 9cca4a1f..20666f1b 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -6,7 +6,7 @@ import pytest -from effectful.internals.disassembler import reconstruct +from effectful.internals.disassembly import reconstruct def compile_and_eval( @@ -666,7 +666,7 @@ def test_complex_scenarios(genexpr, globals_dict): ) def test_ensure_ast(value, expected_str): """Test that ensure_ast correctly converts various values to AST nodes.""" - from effectful.internals.disassembler import ensure_ast + from effectful.internals.disassembly import ensure_ast result = ensure_ast(value) @@ -694,7 +694,7 @@ def test_comp_lambda_copy(): """Test that CompLambda is compatible with copy.copy and copy.deepcopy.""" import copy - from effectful.internals.disassembler import CompLambda, DummyIterName + from effectful.internals.disassembly import CompLambda, DummyIterName # Create a test generator expression AST genexpr_ast = ast.GeneratorExp( From 
bb5d01f2b2e1dfd81aa8724a1efd2baccb486f49 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 20:38:36 -0400 Subject: [PATCH 087/106] rename reconstruct -> disassemble --- effectful/internals/disassembly.py | 6 +++--- tests/test_internals_disassembler.py | 30 ++++++++++++++-------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 6ef9775d..7d6693fc 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -11,7 +11,7 @@ Example: >>> g = (x * 2 for x in range(10) if x % 2 == 0) - >>> ast_node = reconstruct(g) + >>> ast_node = disassemble(g) >>> # ast_node is now an ast.Expression representing the original expression """ @@ -1745,7 +1745,7 @@ def _ensure_ast_genexpr(genexpr: types.GeneratorType) -> ast.GeneratorExp: # ============================================================================ -def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: +def disassemble(genexpr: Generator[object, None, None]) -> ast.Expression: """ Reconstruct an AST from a generator expression's bytecode. 
@@ -1776,7 +1776,7 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: Example: >>> # Generator expression >>> g = (x * 2 for x in range(10) if x % 2 == 0) - >>> ast_node = reconstruct(g) + >>> ast_node = disassemble(g) >>> isinstance(ast_node, ast.Expression) True diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 20666f1b..a2f54e13 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -6,7 +6,7 @@ import pytest -from effectful.internals.disassembly import reconstruct +from effectful.internals.disassembly import disassemble def compile_and_eval( @@ -106,7 +106,7 @@ def assert_ast_equivalent( ) def test_simple_generators(genexpr): """Test reconstruction of simple generator expressions.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) assert_ast_equivalent(genexpr, ast_node) @@ -160,7 +160,7 @@ def test_simple_generators(genexpr): ) def test_arithmetic_expressions(genexpr): """Test reconstruction of generators with arithmetic expressions.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) assert_ast_equivalent(genexpr, ast_node) @@ -232,7 +232,7 @@ def test_arithmetic_expressions(genexpr): ) def test_comparison_operators(genexpr): """Test reconstruction of all comparison operators.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) assert_ast_equivalent(genexpr, ast_node) @@ -252,7 +252,7 @@ def test_comparison_operators(genexpr): ) def test_chained_comparison_operators(genexpr): """Test reconstruction of chained (ternary) comparison operators.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) assert_ast_equivalent(genexpr, ast_node) @@ -321,7 +321,7 @@ def test_chained_comparison_operators(genexpr): ) def test_filtered_generators(genexpr): """Test reconstruction of generators with if conditions.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) 
assert_ast_equivalent(genexpr, ast_node) @@ -382,7 +382,7 @@ def test_filtered_generators(genexpr): ) def test_nested_loops(genexpr): """Test reconstruction of generators with nested loops.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) assert_ast_equivalent(genexpr, ast_node) @@ -417,7 +417,7 @@ def test_nested_loops(genexpr): ) def test_nested_comprehensions(genexpr): """Test reconstruction of nested comprehensions.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) assert_ast_equivalent(genexpr, ast_node) @@ -427,7 +427,7 @@ def test_nested_comprehensions_multiline_fail(): # this part works - dis.dis correctly reconstructs the generator expression xs1 = (x for x in range(5) if x > 1) assert any(i.opname == "POP_JUMP_IF_TRUE" for i in dis.get_instructions(xs1)) - assert_ast_equivalent(xs1, reconstruct(xs1)) + assert_ast_equivalent(xs1, disassemble(xs1)) # this part fails - dis.dis incorrectly negates the filter expression x > 1 xs2 = ( @@ -435,7 +435,7 @@ def test_nested_comprehensions_multiline_fail(): for x in range(5) # comment to avoid reformatting if x > 1 ) - assert_ast_equivalent(xs2, reconstruct(xs2)) + assert_ast_equivalent(xs2, disassemble(xs2)) assert any(i.opname == "POP_JUMP_IF_TRUE" for i in dis.get_instructions(xs2)) @@ -468,7 +468,7 @@ def test_nested_comprehensions_multiline_fail(): ) def test_different_comprehension_types(genexpr): """Test reconstruction of different comprehension types.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) assert_ast_equivalent(genexpr, ast_node) @@ -493,7 +493,7 @@ def test_different_comprehension_types(genexpr): ) def test_variable_lookup(genexpr, globals_dict): """Test reconstruction of expressions with globals.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) # Need to provide the same globals for evaluation assert_ast_equivalent(genexpr, ast_node, globals_dict) @@ -564,7 +564,7 @@ def test_variable_lookup(genexpr, 
globals_dict): ) def test_complex_scenarios(genexpr, globals_dict): """Test reconstruction of complex generator expressions.""" - ast_node = reconstruct(genexpr) + ast_node = disassemble(genexpr) # Need to provide the same globals for evaluation assert_ast_equivalent(genexpr, ast_node, globals_dict) @@ -681,13 +681,13 @@ def test_error_handling(): """Test that appropriate errors are raised for unsupported cases.""" # Test with non-generator input with pytest.raises(AssertionError): - reconstruct([1, 2, 3]) # Not a generator + disassemble([1, 2, 3]) # Not a generator # Test with consumed generator gen = (x for x in range(5)) list(gen) # Consume it with pytest.raises(AssertionError): - reconstruct(gen) + disassemble(gen) def test_comp_lambda_copy(): From 5d6939ac64040949b4d536af4e0261254b2fd8b4 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 20 Aug 2025 20:49:54 -0400 Subject: [PATCH 088/106] name --- effectful/internals/disassembly.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 7d6693fc..f54f4735 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -5,7 +5,7 @@ generator expressions by analyzing their bytecode. The primary use case is to recover the original structure of generator comprehensions from their compiled form. -The only public-facing interface is the `reconstruct` function, which takes a +The only public-facing interface is the `disassemble()` function, which takes a generator object and returns an AST node representing the original comprehension. All other functions and classes in this module are internal implementation details. 
From e886dbb2bd83aa494740479dda043b77f64a44e5 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 13:29:48 -0400 Subject: [PATCH 089/106] conditional expression test --- effectful/internals/disassembly.py | 2 +- tests/test_internals_disassembler.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index f54f4735..8dd20307 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -232,7 +232,7 @@ def handle_return_generator( "RETURN_GENERATOR must be the first instruction" ) new_result = ast.GeneratorExp(elt=Placeholder(), generators=[]) - new_stack = state.stack + [new_result] + new_stack = state.stack + [Null()] return replace(state, result=new_result, stack=new_stack) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index a2f54e13..3ffa0bbc 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -472,6 +472,37 @@ def test_different_comprehension_types(genexpr): assert_ast_equivalent(genexpr, ast_node) +# ============================================================================ +# CONDITIONAL EXPRESSIONS +# ============================================================================ + + +@pytest.mark.xfail(reason="Conditional expressions not yet properly supported") +@pytest.mark.parametrize("genexpr", [ + # simple conditional expressions without nesting + ((lambda x: x if x % 2 == 0 else -x)(xi) for xi in range(5)), + ((lambda x: (x + 1) if x < 5 else (x - 1))(xi) for xi in range(10)), + ((lambda x: (x * 2) if x > 0 else (x / 2))(xi) for xi in range(-5, 5)), + ((lambda x: (x**2) if x != 0 else 1)(xi) for xi in range(-3, 4)), + # simple conditional expressions with negation + ((lambda x: (x + 10) if not (x < 5) else (x - 10))(xi) for xi in range(20)), + ((lambda x: (x * 3) if not (x % 2 == 0) else (x // 3))(xi) for xi in 
range(10)), + ((lambda x: (x**3) if not (x < 0) else (x**0.5))(xi) for xi in range(-5, 15)), + # conditional expressions with lazy test + ((lambda x: (x + 10) if (x > 5 and x < 15) else (x - 10))(xi) for xi in range(20)), + ((lambda x: (x * 3) if (x % 2 == 0 or x % 3 == 0) else (x // 3))(xi) for xi in range(10)), + ((lambda x: (x**3) if not (x < 0 or x > 10) else (x**0.5))(xi) for xi in range(-5, 15)), + # nested conditional expressions + ((lambda x: (x + 1) if x < 5 else ((x - 1) if x < 10 else (x * 2)))(xi) for xi in range(15)), + ((lambda x: (x * 2) if x % 2 == 0 else ((x // 2) if x % 3 == 0 else (x + 2)))(xi) for xi in range(10)), + ((lambda x: (x**2) if x > 0 else ((-x)**2 if x < -5 else 1))(xi) for xi in range(-10, 5)), +]) +def test_conditional_expressions_no_comprehension(genexpr): + """Test reconstruction of conditional expressions isolated from comprehension bodies.""" + ast_node = disassemble(genexpr) + assert_ast_equivalent(genexpr, ast_node) + + # ============================================================================ # GENERATOR EXPRESSION WITH GLOBALS # ============================================================================ From 05f22b36c18ed1e5e8a21794707fd5818fe3c359 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 13:30:35 -0400 Subject: [PATCH 090/106] lint --- tests/test_internals_disassembler.py | 63 +++++++++++++++++++--------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 3ffa0bbc..f4b1a3ec 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -478,25 +478,50 @@ def test_different_comprehension_types(genexpr): @pytest.mark.xfail(reason="Conditional expressions not yet properly supported") -@pytest.mark.parametrize("genexpr", [ - # simple conditional expressions without nesting - ((lambda x: x if x % 2 == 0 else -x)(xi) for xi in range(5)), - ((lambda x: (x + 1) if x < 5 else (x - 
1))(xi) for xi in range(10)), - ((lambda x: (x * 2) if x > 0 else (x / 2))(xi) for xi in range(-5, 5)), - ((lambda x: (x**2) if x != 0 else 1)(xi) for xi in range(-3, 4)), - # simple conditional expressions with negation - ((lambda x: (x + 10) if not (x < 5) else (x - 10))(xi) for xi in range(20)), - ((lambda x: (x * 3) if not (x % 2 == 0) else (x // 3))(xi) for xi in range(10)), - ((lambda x: (x**3) if not (x < 0) else (x**0.5))(xi) for xi in range(-5, 15)), - # conditional expressions with lazy test - ((lambda x: (x + 10) if (x > 5 and x < 15) else (x - 10))(xi) for xi in range(20)), - ((lambda x: (x * 3) if (x % 2 == 0 or x % 3 == 0) else (x // 3))(xi) for xi in range(10)), - ((lambda x: (x**3) if not (x < 0 or x > 10) else (x**0.5))(xi) for xi in range(-5, 15)), - # nested conditional expressions - ((lambda x: (x + 1) if x < 5 else ((x - 1) if x < 10 else (x * 2)))(xi) for xi in range(15)), - ((lambda x: (x * 2) if x % 2 == 0 else ((x // 2) if x % 3 == 0 else (x + 2)))(xi) for xi in range(10)), - ((lambda x: (x**2) if x > 0 else ((-x)**2 if x < -5 else 1))(xi) for xi in range(-10, 5)), -]) +@pytest.mark.parametrize( + "genexpr", + [ + # simple conditional expressions without nesting + ((lambda x: x if x % 2 == 0 else -x)(xi) for xi in range(5)), + ((lambda x: (x + 1) if x < 5 else (x - 1))(xi) for xi in range(10)), + ((lambda x: (x * 2) if x > 0 else (x / 2))(xi) for xi in range(-5, 5)), + ((lambda x: (x**2) if x != 0 else 1)(xi) for xi in range(-3, 4)), + # simple conditional expressions with negation + ((lambda x: (x + 10) if not (x < 5) else (x - 10))(xi) for xi in range(20)), + ((lambda x: (x * 3) if not (x % 2 == 0) else (x // 3))(xi) for xi in range(10)), + ((lambda x: (x**3) if not (x < 0) else (x**0.5))(xi) for xi in range(-5, 15)), + # conditional expressions with lazy test + ( + (lambda x: (x + 10) if (x > 5 and x < 15) else (x - 10))(xi) + for xi in range(20) + ), + ( + (lambda x: (x * 3) if (x % 2 == 0 or x % 3 == 0) else (x // 3))(xi) + for xi in 
range(10) + ), + ( + (lambda x: (x**3) if not (x < 0 or x > 10) else (x**0.5))(xi) + for xi in range(-5, 15) + ), + # nested conditional expressions + ( + (lambda x: (x + 1) if x < 5 else ((x - 1) if x < 10 else (x * 2)))(xi) + for xi in range(15) + ), + ( + ( + lambda x: (x * 2) + if x % 2 == 0 + else ((x // 2) if x % 3 == 0 else (x + 2)) + )(xi) + for xi in range(10) + ), + ( + (lambda x: (x**2) if x > 0 else ((-x) ** 2 if x < -5 else 1))(xi) + for xi in range(-10, 5) + ), + ], +) def test_conditional_expressions_no_comprehension(genexpr): """Test reconstruction of conditional expressions isolated from comprehension bodies.""" ast_node = disassemble(genexpr) From cff32f449c4b31d190b25024c9fa57ecaac36889 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 13:35:24 -0400 Subject: [PATCH 091/106] separate fstring test --- tests/test_internals_disassembler.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index f4b1a3ec..fee4b1df 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -164,6 +164,27 @@ def test_arithmetic_expressions(genexpr): assert_ast_equivalent(genexpr, ast_node) +# ============================================================================ +# FSTRING EXPRESSIONS +# ============================================================================ + + +@pytest.mark.xfail(reason="f-string expressions not yet fully supported") +@pytest.mark.parametrize( + "genexpr", + [ + # simple cases + (f"{x} is {x**2}" for x in range(5)), + (f"{x:02d}" for x in range(10)), + (f"{x:.2f}" for x in [1.2345, 2.3456, 3.4567]), + ], +) +def test_fstring_expressions(genexpr): + """Test reconstruction of generators with f-string expressions.""" + ast_node = disassemble(genexpr) + assert_ast_equivalent(genexpr, ast_node) + + # ============================================================================ # 
COMPARISON OPERATORS # ============================================================================ @@ -581,10 +602,6 @@ def test_variable_lookup(genexpr, globals_dict): ((s[0:3] for s in ["hello", "world"]), {}), ((s[::-1] for s in ["hello", "world"]), {}), ((s[1:2:] for s in ["hello", "world"]), {}), - # fstrings and formatted strings - ((f"{x} is {x**2}" for x in range(5)), {}), - ((f"{x:02d}" for x in range(10)), {}), - ((f"{x:.2f}" for x in [1.2345, 2.3456, 3.4567]), {}), # Method calls ((s.upper() for s in ["hello", "world"]), {}), ((s.lower() for s in ["HELLO", "WORLD"]), {}), From d291d8be022adee811f17fb2be1d4dea0f3f119f Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 14:40:37 -0400 Subject: [PATCH 092/106] support fstrings in python 3.13 --- effectful/internals/disassembly.py | 121 +++++++++++++++++++++------ tests/test_internals_disassembler.py | 48 +++++++++-- 2 files changed, 139 insertions(+), 30 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 8dd20307..71a5870a 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -52,6 +52,19 @@ def __init__(self, value=None): super().__init__(value=value) +class ConvertedValue(ast.expr): + """Wrapper for values that have been converted with CONVERT_VALUE.""" + + def __init__(self, value: ast.expr, conversion: int): + self.value = value + self.conversion = conversion + # Map CONVERT_VALUE args to ast.FormattedValue conversion values + # CONVERT_VALUE: 0=None, 1=str, 2=repr, 3=ascii + # ast.FormattedValue: -1=none, 115=str, 114=repr, 97=ascii + conversion_map = {0: -1, 1: 115, 2: 114, 3: 97} + self.ast_conversion = conversion_map.get(conversion, -1) + + class CompLambda(ast.Lambda): """Placeholder AST node representing a lambda function used in comprehensions.""" @@ -250,7 +263,6 @@ def handle_yield_value( "YIELD_VALUE must be called before yielding" ) assert len(state.result.generators) > 0, "YIELD_VALUE should 
have generators" - ret = ast.GeneratorExp( elt=ensure_ast(state.stack[-1]), generators=state.result.generators, @@ -923,6 +935,25 @@ def handle_unary_op( ) +@register_handler("CONVERT_VALUE", version=PythonVersion.PY_313) +def handle_convert_value( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # CONVERT_VALUE applies a conversion to the value on top of stack + # Used for f-string conversions like !r, !s, !a + # The conversion type is stored in instr.arg: + # 0 = None, 1 = str (!s), 2 = repr (!r), 3 = ascii (!a) + assert len(state.stack) > 0, "CONVERT_VALUE requires a value on stack" + assert instr.arg is not None, "CONVERT_VALUE requires conversion type" + + # Wrap the value with conversion information + value = state.stack[-1] + converted = ConvertedValue(value, instr.arg) + new_stack = state.stack[:-1] + [converted] + + return replace(state, stack=new_stack) + + @register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_312) @register_handler("CALL_INTRINSIC_1", version=PythonVersion.PY_313) def handle_call_intrinsic_1( @@ -1384,27 +1415,41 @@ def handle_dict_update( def handle_build_string( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # BUILD_STRING creates a string from the top of the stack - # The number of elements to pop is determined by the instruction argument + # BUILD_STRING concatenates strings from the stack + # For f-strings, it combines FormattedValue and Constant nodes assert instr.arg is not None string_size: int = instr.arg if string_size == 0: # Empty string case new_stack = state.stack + [ast.Constant(value="")] - else: - # Pop elements for the string - elements = ( - [ensure_ast(elem) for elem in state.stack[-string_size:]] - if string_size > 0 - else [] - ) - new_stack = state.stack[:-string_size] + return replace(state, stack=new_stack) - # Create concatenated string AST - concat_node = ast.JoinedStr(values=elements) - new_stack = new_stack + [concat_node] + # Pop 
elements for the string + elements = [ensure_ast(elem) for elem in state.stack[-string_size:]] + new_stack = state.stack[:-string_size] + + # Check if this is an f-string build (has FormattedValue nodes) + # or a regular string concatenation + if any(isinstance(elem, ast.JoinedStr) for elem in elements): + # This is an f-string - create JoinedStr + values = [] + for elem in elements: + if isinstance(elem, ast.JoinedStr): + values.extend(elem.values) + else: + values.append(elem) + concat_node = ast.JoinedStr(values=values) + elif all(isinstance(elem, ast.Constant) for elem in elements): + # This is regular string concatenation or format spec building + # If all elements are constants, we might be building a format spec + # Concatenate the constant strings + concat_str = "".join(elem.value for elem in elements) + concat_node = ast.Constant(value=concat_str) + else: + raise TypeError("Should not be here?") + new_stack = new_stack + [concat_node] return replace(state, stack=new_stack) @@ -1420,7 +1465,7 @@ def handle_format_value( # Create formatted string AST formatted_node = ast.FormattedValue(value=value, conversion=-1, format_spec=None) - new_stack = new_stack + [formatted_node] + new_stack = new_stack + [ast.JoinedStr(values=[formatted_node])] return replace(state, stack=new_stack) @@ -1431,12 +1476,22 @@ def handle_format_simple( # FORMAT_SIMPLE formats a string with a single value # Pops the value and the format string from the stack assert len(state.stack) >= 1, "Not enough items on stack for FORMAT_SIMPLE" - value = ensure_ast(state.stack[-1]) + value = state.stack[-1] new_stack = state.stack[:-1] + # Check if the value was converted + if isinstance(value, ConvertedValue): + conversion = value.ast_conversion + value = value.value + else: + conversion = -1 + value = ensure_ast(value) + # Create formatted string AST - formatted_node = ast.FormattedValue(value=value, conversion=-1, format_spec=None) - new_stack = new_stack + [formatted_node] + formatted_node = 
ast.FormattedValue( + value=value, conversion=conversion, format_spec=None + ) + new_stack = new_stack + [ast.JoinedStr(values=[formatted_node])] return replace(state, stack=new_stack) @@ -1444,18 +1499,33 @@ def handle_format_simple( def handle_format_with_spec( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # FORMAT_WITH_SPEC formats a string with a value and a format specifier - # Pops the value, format string, and format specifier from the stack + # FORMAT_WITH_SPEC formats a value with a format specifier + # Stack order in Python 3.13: format_spec on top, value below assert len(state.stack) >= 2, "Not enough items on stack for FORMAT_WITH_SPEC" - value = ensure_ast(state.stack[-1]) - format_string = ensure_ast(state.stack[-2]) + format_spec = ensure_ast(state.stack[-1]) # Format spec is on top + value = state.stack[-2] # Value is below new_stack = state.stack[:-2] + # Check if the value was converted + if isinstance(value, ConvertedValue): + conversion = value.ast_conversion + value = value.value + else: + conversion = -1 + value = ensure_ast(value) + # Create formatted string AST with specifier + # The format_spec should be wrapped in a JoinedStr if it's a simple constant + if isinstance(format_spec, ast.Constant): + format_spec_node = ast.JoinedStr(values=[format_spec]) + else: + # Already a JoinedStr from nested formatting + format_spec_node = format_spec + formatted_node = ast.FormattedValue( - value=value, conversion=-1, format_spec=format_string + value=value, conversion=conversion, format_spec=format_spec_node ) - new_stack = new_stack + [formatted_node] + new_stack = new_stack + [ast.JoinedStr(values=[formatted_node])] return replace(state, stack=new_stack) @@ -1660,7 +1730,8 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: "Final return value must not contain statement nodes" ) assert not any( - isinstance(x, Placeholder | Null | CompLambda) for x in ast.walk(result) + isinstance(x, 
Placeholder | Null | CompLambda | ConvertedValue) + for x in ast.walk(result) ), "Final return value must not contain temporary nodes" assert not any(x.arg == ".0" for x in ast.walk(result) if isinstance(x, ast.arg)), ( "Final return value must not contain .0 argument" diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index fee4b1df..bd1ab302 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -169,14 +169,52 @@ def test_arithmetic_expressions(genexpr): # ============================================================================ -@pytest.mark.xfail(reason="f-string expressions not yet fully supported") @pytest.mark.parametrize( "genexpr", [ - # simple cases - (f"{x} is {x**2}" for x in range(5)), - (f"{x:02d}" for x in range(10)), - (f"{x:.2f}" for x in [1.2345, 2.3456, 3.4567]), + # Basic f-string cases + (f"{x}" for x in range(5)), # Single value, no format + (f"{x} is {x**2}" for x in range(5)), # Multiple values + (f"{x:02d}" for x in range(10)), # Format spec + (f"{x:.2f}" for x in [1.2345, 2.3456, 3.4567]), # Float format spec + # Conversion specifiers + (f"{x!r}" for x in ["hello", "world"]), # repr conversion + (f"{x!s}" for x in [1, 2, 3]), # str conversion + (f"{x!a}" for x in ["hello\n", "world\t"]), # ascii conversion + # Conversion with format spec + (f"{x!r:>10}" for x in ["hello", "world"]), # repr with alignment + (f"{x!s:^15}" for x in [1, 2, 3]), # str with center align + # Empty and literal f-strings + ("" for x in range(3)), # Empty f-string + ("constant" for x in range(3)), # No formatting + (f"x={x}" for x in range(5)), # Literal prefix + (f"result: {x * 2}" for x in range(5)), # Literal with expression + # Complex expressions in f-strings + (f"{x + 1}" for x in range(5)), # Arithmetic + (f"{x * x}" for x in range(5)), # Multiplication + (f"{x % 2}" for x in range(10)), # Modulo + (f"{-x}" for x in range(-2, 3)), # Unary minus + # Nested formatting + 
(f"{x:0{2}d}" for x in range(5)), # Format spec with expression + (f"{x:>{3 * 2}}" for x in range(5)), # Expression in format spec + # Multiple formatted values + (f"{x} + {y} = {x + y}" for x in range(3) for y in range(3)), # Multiple vars + (f"({x}, {y})" for x in range(2) for y in range(2)), # Tuple display + # F-strings with various data types + (f"{s}" for s in ["hello", "world"]), # Strings + (f"{b}" for b in [True, False]), # Booleans + (f"{n}" for n in [None, None]), # None values + (f"{lst}" for lst in [[1, 2], [3, 4]]), # Lists + # Complex format specifications + (f"{x:+05d}" for x in range(-2, 3)), # Sign, zero pad, width + (f"{x:.2%}" for x in [0.1, 0.25, 0.333]), # Percentage format + (f"{x:.2e}" for x in [100, 1000, 10000]), # Scientific notation + (f"{x:#x}" for x in [10, 15, 255]), # Hex with prefix + (f"{x:b}" for x in [2, 7, 15]), # Binary format + # Edge cases + ("{x}" for x in range(3)), # Escaped braces + (f"{{x}} = {x}" for x in range(3)), # Mixed escaped/formatted + (f"{{{x}}}" for x in range(3)), # Brace around formatted ], ) def test_fstring_expressions(genexpr): From 48741aa424ed8bbff50b5477be6ecdd7225743eb Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 14:57:17 -0400 Subject: [PATCH 093/106] python 3.12 fstring support --- effectful/internals/disassembly.py | 67 ++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 71a5870a..49a78f04 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -1439,32 +1439,74 @@ def handle_build_string( values.extend(elem.values) else: values.append(elem) - concat_node = ast.JoinedStr(values=values) + return replace(state, stack=new_stack + [ast.JoinedStr(values=values)]) elif all(isinstance(elem, ast.Constant) for elem in elements): # This is regular string concatenation or format spec building # If all elements are constants, we might be 
building a format spec # Concatenate the constant strings - concat_str = "".join(elem.value for elem in elements) - concat_node = ast.Constant(value=concat_str) + assert all( + isinstance(elem, ast.Constant) and isinstance(elem.value, str) + for elem in elements + ) + concat_str = "".join( + elem.value + for elem in elements + if isinstance(elem, ast.Constant) and isinstance(elem.value, str) + ) + return replace(state, stack=new_stack + [ast.Constant(value=concat_str)]) else: raise TypeError("Should not be here?") - new_stack = new_stack + [concat_node] - return replace(state, stack=new_stack) - @register_handler("FORMAT_VALUE", version=PythonVersion.PY_312) def handle_format_value( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # FORMAT_VALUE formats a string with a value - # Pops the value and the format string from the stack + # FORMAT_VALUE formats a string with a value in Python 3.12 + # Flag bits: (flags & 0x03) = conversion, (flags & 0x04) = has format spec + assert instr.arg is not None, "FORMAT_VALUE requires flags argument" assert len(state.stack) >= 1, "Not enough items on stack for FORMAT_VALUE" - value = ensure_ast(state.stack[-1]) - new_stack = state.stack[:-1] - # Create formatted string AST - formatted_node = ast.FormattedValue(value=value, conversion=-1, format_spec=None) + flags = instr.arg + + # Check if there's a format specification + has_format_spec = bool(flags & 0x04) + + if has_format_spec: + # Pop format spec and value + assert len(state.stack) >= 2, ( + "FORMAT_VALUE with format spec needs 2 stack items" + ) + format_spec = ensure_ast(state.stack[-1]) + value = ensure_ast(state.stack[-2]) + new_stack = state.stack[:-2] + + # Wrap format spec in JoinedStr if it's a constant + if isinstance(format_spec, ast.Constant): + format_spec_node = ast.JoinedStr(values=[format_spec]) + else: + assert isinstance(format_spec, ast.JoinedStr) + format_spec_node = format_spec + else: + # Just pop the value + value = 
ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + format_spec_node = None + + # Determine conversion type from flags + conversion_flags = flags & 0x03 + conversion_map = { + 0: -1, # No conversion + 1: 115, # str (!s) + 2: 114, # repr (!r) + 3: 97, # ascii (!a) + } + conversion = conversion_map[conversion_flags] + + # Create formatted value AST + formatted_node = ast.FormattedValue( + value=value, conversion=conversion, format_spec=format_spec_node + ) new_stack = new_stack + [ast.JoinedStr(values=[formatted_node])] return replace(state, stack=new_stack) @@ -1520,6 +1562,7 @@ def handle_format_with_spec( format_spec_node = ast.JoinedStr(values=[format_spec]) else: # Already a JoinedStr from nested formatting + assert isinstance(format_spec, ast.JoinedStr) format_spec_node = format_spec formatted_node = ast.FormattedValue( From d1307afee288ac885eac992646564568de830bb0 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 15:43:38 -0400 Subject: [PATCH 094/106] conditional expression tests expanded --- tests/test_internals_disassembler.py | 81 +++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index bd1ab302..1e370752 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -536,7 +536,6 @@ def test_different_comprehension_types(genexpr): # ============================================================================ -@pytest.mark.xfail(reason="Conditional expressions not yet properly supported") @pytest.mark.parametrize( "genexpr", [ @@ -587,6 +586,86 @@ def test_conditional_expressions_no_comprehension(genexpr): assert_ast_equivalent(genexpr, ast_node) +@pytest.mark.parametrize( + "genexpr", + [ + # Basic conditional expressions in comprehension bodies + ((x if x % 2 == 0 else -x) for x in range(5)), + ((x * 2 if x > 0 else x / 2) for x in range(-3, 4)), + ((x**2 if x != 0 else 1) for x in range(-2, 
3)), + # Conditional expressions with filters + ((x if x % 2 == 0 else -x) for x in range(10) if x > 2), + ((x * 3 if x > 5 else x + 1) for x in range(20) if x % 3 == 0), + # Nested loops with conditional expressions + ((x + y if x > y else x - y) for x in range(3) for y in range(3)), + ( + (x * y if x != 0 and y != 0 else 0) + for x in range(-2, 3) + for y in range(-2, 3) + ), + # Multiple conditional expressions + ( + (x if x > 0 else 0) + (y if y > 0 else 0) + for x in range(-2, 3) + for y in range(-2, 3) + ), + # Conditional expressions in different parts + ([x if x > 0 else -x for x in range(i)] for i in range(1, 4)), + ((x if x % 2 == 0 else -x) for x in (y if y > 2 else y + 10 for y in range(5))), + # Complex nested conditional expressions + ((x if x > 0 else (x + 5 if x > -3 else x * 2)) for x in range(-5, 5)), + ((x * 2 if x > 0 else (x / 2 if x < 0 else 1)) for x in range(-3, 4)), + # Conditional expressions with function calls + ((abs(x) if x < 0 else x) for x in range(-3, 4)), + ((max(x, 0) if x is not None else 0) for x in [None, -1, 0, 1, 2]), + # Mixed with other complex expressions + ((x + 1 if x % 2 == 0 else x - 1) * 2 for x in range(5)), + ((x, y, x + y if x > y else x - y) for x in range(3) for y in range(3)), + ], +) +def test_conditional_expressions_simple_comprehensions(genexpr): + ast_node = disassemble(genexpr) + assert_ast_equivalent(genexpr, ast_node) + + +@pytest.mark.parametrize( + "genexpr", + [ + # Lazy boolean operations (and/or) + (x for x in range(10) if x > 2 and x < 8), + (x for x in range(10) if x < 3 or x > 7), + (x for x in range(20) if x > 5 and x % 2 == 0 and x < 15), + (x for x in range(20) if x < 5 or x > 15 or x == 10), + # Mixed and/or + (x for x in range(20) if (x > 10 and x % 2 == 0) or (x < 5 and x % 3 == 0)), + (x for x in range(20) if x > 5 and (x < 10 or x > 15)), + # Chained comparisons + (x for x in range(20) if 5 < x < 15), + (x for x in range(20) if 0 <= x <= 10), + (x for x in range(50) if 10 < x < 20 < x * 
2), + (x for x in range(10) if 0 <= x <= 5 <= x + 5), + # Mixed chained and boolean + (x for x in range(50) if 5 < x < 15 and x % 2 == 0), + (x for x in range(50) if x > 20 or 5 < x < 15), + # Complex boolean expressions in comprehension body + ((x if x > 5 and x < 15 else 0) for x in range(20)), + ((x if x < 3 or x > 17 else -x) for x in range(20)), + # Chained comparisons in conditional expressions + ((x if 5 < x < 15 else 0) for x in range(20)), + ((x * 2 if 0 <= x <= 10 else x / 2) for x in range(-5, 15)), + # Nested boolean logic + (x for x in range(100) if (x > 10 and x < 50) and (x % 3 == 0 or x % 5 == 0)), + (x for x in range(100) if not (x > 30 and x < 70)), + # Boolean expressions with function calls + (x for x in range(-10, 10) if abs(x) > 3 and x % 2 == 0), + (x for x in ["hello", "world", "test"] if len(x) > 3 and x.startswith("h")), + ], +) +def test_lazy_boolean_and_chained_comparisons(genexpr): + ast_node = disassemble(genexpr) + assert_ast_equivalent(genexpr, ast_node) + + # ============================================================================ # GENERATOR EXPRESSION WITH GLOBALS # ============================================================================ From a09f4716203c00dc412ff804e9ef8da6809aebb5 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 16:20:39 -0400 Subject: [PATCH 095/106] pre-refactor stash --- effectful/internals/disassembly.py | 13 +++++++--- tests/test_internals_disassembler.py | 37 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 49a78f04..a9add18e 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -1741,6 +1741,14 @@ def _ensure_ast_range_iterator(value: Iterator) -> ast.Call: return ensure_ast(value.__reduce__()[1][0]) # type: ignore +def _symbolic_exec(code: types.CodeType) -> ReconstructionState: + # TODO respect control flow + state = 
ReconstructionState(code=code) + for instr in state.instructions.values(): + state = OP_HANDLERS[instr.opname](state, instr) + return state + + @ensure_ast.register def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: assert inspect.iscode(value), "Input must be a code object" @@ -1763,10 +1771,7 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: raise TypeError(f"Unsupported code object type: {value.co_name}") # Symbolic execution to reconstruct the AST - state = ReconstructionState(code=value) - for instr in state.instructions.values(): - state = OP_HANDLERS[instr.opname](state, instr) - result: ast.expr = state.result + result: ast.expr = _symbolic_exec(value).result # Check postconditions assert not any(isinstance(x, ast.stmt) for x in ast.walk(result)), ( diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 1e370752..d5c4b01a 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -666,6 +666,43 @@ def test_lazy_boolean_and_chained_comparisons(genexpr): assert_ast_equivalent(genexpr, ast_node) +@pytest.mark.parametrize( + "genexpr", + [ + # Simple conditional as function argument + (max(x if x > 0 else 0) for x in range(-2, 3)), + (abs(x if x < 0 else -x) for x in range(-3, 3)), + (len(str(x) if x > 10 else "small") for x in range(15)), + # Multiple conditional arguments + ( + max(x if x > 0 else 0, y if y > 0 else 0) + for x in range(-1, 2) + for y in range(-1, 2) + ), + ( + pow(x if x != 0 else 1, y if y > 0 else 1) + for x in range(3) + for y in range(3) + ), + # Nested function calls with conditionals + (max(abs(x if x < 0 else -x), 1) for x in range(-3, 4)), + (int(str(x if x > 5 else x + 10)) for x in range(10)), + # Conditionals in keyword arguments (using dict constructor as example) + (dict(a=x if x > 0 else 0, b=x * 2 if x < 5 else x) for x in range(8)), + # Method calls with conditional arguments + ([1, 2, 3].index(x if 
x in [1, 2, 3] else 1) for x in range(5)), + ("hello".replace("l", x if isinstance(x, str) else "X") for x in ["a", 1, "b"]), + # Mixed: conditional in function call within comprehension filter + (x for x in range(20) if max(x if x > 10 else 0, 5) > 8), + # Complex nested case: conditional in function argument, function call in conditional + (max(x if len(str(x)) > 1 else x * 10) for x in range(15)), + ], +) +def test_conditional_expressions_function_arguments(genexpr): + ast_node = disassemble(genexpr) + assert_ast_equivalent(genexpr, ast_node) + + # ============================================================================ # GENERATOR EXPRESSION WITH GLOBALS # ============================================================================ From 28421073a0c80c76ad86a5d95511a9a9b2a58cc6 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 16:27:13 -0400 Subject: [PATCH 096/106] xfail --- tests/test_internals_disassembler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index d5c4b01a..493c7976 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -536,6 +536,7 @@ def test_different_comprehension_types(genexpr): # ============================================================================ +@pytest.mark.xfail(reason="Conditional expressions not yet fully supported") @pytest.mark.parametrize( "genexpr", [ @@ -586,6 +587,7 @@ def test_conditional_expressions_no_comprehension(genexpr): assert_ast_equivalent(genexpr, ast_node) +@pytest.mark.xfail(reason="Conditional expressions not yet fully supported") @pytest.mark.parametrize( "genexpr", [ @@ -628,6 +630,9 @@ def test_conditional_expressions_simple_comprehensions(genexpr): assert_ast_equivalent(genexpr, ast_node) +@pytest.mark.xfail( + reason="Lazy boolean ops and chained comparisons not yet fully supported" +) @pytest.mark.parametrize( "genexpr", [ @@ -666,6 +671,7 @@ def 
test_lazy_boolean_and_chained_comparisons(genexpr): assert_ast_equivalent(genexpr, ast_node) +@pytest.mark.xfail(reason="Conditional expressions not yet fully supported") @pytest.mark.parametrize( "genexpr", [ From 28940d4ff118fc1ba6a5df72b79324b6e1d8abd3 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 21 Aug 2025 18:15:12 -0400 Subject: [PATCH 097/106] sort of works --- effectful/internals/disassembly.py | 89 +++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 3 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index a9add18e..a577d757 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -1649,6 +1649,42 @@ def handle_pop_jump_if_not_none( ) +@register_handler("SEND", version=PythonVersion.PY_312) +@register_handler("SEND", version=PythonVersion.PY_313) +def handle_send( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + raise TypeError("SEND instruction should not appear in generator comprehensions") + + +@register_handler("JUMP_BACKWARD_NO_INTERRUPT", version=PythonVersion.PY_312) +@register_handler("JUMP_BACKWARD_NO_INTERRUPT", version=PythonVersion.PY_313) +def handle_jump_backward_no_interrupt( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + raise TypeError( + "JUMP_BACKWARD_NO_INTERRUPT instruction should not appear in generator comprehensions" + ) + + +@register_handler("JUMP", version=PythonVersion.PY_312) +@register_handler("JUMP", version=PythonVersion.PY_313) +def handle_jump( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + raise TypeError("JUMP instruction should not appear in generator comprehensions") + + +@register_handler("JUMP_NO_INTERRUPT", version=PythonVersion.PY_312) +@register_handler("JUMP_NO_INTERRUPT", version=PythonVersion.PY_313) +def handle_jump_no_interrupt( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + 
raise TypeError( + "JUMP_NO_INTERRUPT instruction should not appear in generator comprehensions" + ) + + # ============================================================================ # UTILITY FUNCTIONS # ============================================================================ @@ -1742,10 +1778,57 @@ def _ensure_ast_range_iterator(value: Iterator) -> ast.Call: def _symbolic_exec(code: types.CodeType) -> ReconstructionState: - # TODO respect control flow + """Execute bytecode symbolically, following control flow.""" state = ReconstructionState(code=code) - for instr in state.instructions.values(): - state = OP_HANDLERS[instr.opname](state, instr) + instructions = state.instructions + instrs_list = list(instructions.values()) + next_instr = { + i1.offset: i2.offset for i1, i2 in zip(instrs_list[:-1], instrs_list[1:]) + } + + loop_state = collections.Counter() + branch_state = collections.Counter() + + instr = next(iter(instructions.values())) # Start at first instruction + while instr is not None: + if instr.opname == "FOR_ITER": + # FOR_ITER has two paths: continue loop or exit when exhausted + # For reconstruction, we execute the continue path once + if loop_state[instr.offset] > 0: + # Simulate iterator exhaustion - jump to FOR_ITER target + instr = instructions[instr.jump_target] + else: + # Continue loop - execute FOR_ITER handler + state = OP_HANDLERS[instr.opname](state, instr) + loop_state[instr.offset] += 1 + instr = instructions[next_instr[instr.offset]] + elif instr.opname.startswith("POP_JUMP_IF_"): + # POP_JUMP_IF_*: conditional jump, follow the jump path once + if branch_state[instr.offset] > 0: + # Simulate not taking the jump - continue to next instruction + instr = instructions[next_instr[instr.offset]] + else: + # Take the jump - execute the POP_JUMP_IF_* handler + state = OP_HANDLERS[instr.opname](state, instr) + instr = instructions[instr.jump_target] + branch_state[instr.offset] += 1 + elif instr.opname in {"JUMP_BACKWARD", "JUMP_FORWARD"}: 
+ # JUMP_BACKWARD: loop back to FOR_ITER + state = OP_HANDLERS[instr.opname](state, instr) + instr = instructions[instr.jump_target] + elif instr.opname in {"RETURN_VALUE", "RETURN_CONST"}: + # YIELD_VALUE and RETURN_VALUE end execution + state = OP_HANDLERS[instr.opname](state, instr) + instr = None + else: + # All other operations: handle normally + state = OP_HANDLERS[instr.opname](state, instr) + instr = ( + instructions[next_instr[instr.offset]] + if instr.offset in next_instr + else None + ) + return state From 0c9e407bb0b4ad8f6fed0ac80508e302d5142008 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 22 Aug 2025 16:19:23 -0400 Subject: [PATCH 098/106] rework to symbolic interpreter --- effectful/internals/disassembly.py | 441 +++++++++++++-------------- tests/test_internals_disassembler.py | 4 +- 2 files changed, 220 insertions(+), 225 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index a577d757..e467cf54 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -117,11 +117,6 @@ class ReconstructionState: code: The compiled code object from which the bytecode is being processed. This is typically obtained from a generator function or comprehension. - result: The current comprehension expression being built. Initially - a placeholder, it gets updated as the bytecode is processed. - It can be a GeneratorExp, ListComp, SetComp, DictComp, or - a Lambda for lambda expressions. - stack: Simulates the Python VM's value stack. Contains AST nodes or values that would be on the stack during execution. 
Operations like LOAD_FAST push to this stack, while operations like @@ -129,8 +124,8 @@ class ReconstructionState: """ code: types.CodeType + stack: list[ast.expr] result: ast.expr = field(default_factory=Placeholder) - stack: list[ast.expr] = field(default_factory=list) @property def instructions(self) -> collections.OrderedDict[int, dis.Instruction]: @@ -187,29 +182,34 @@ def register_handler( if opname in OP_HANDLERS: raise ValueError(f"Handler for '{opname}' (version {version}) already exists.") + if dis.opmap[opname] in dis.hasjrel: + assert "jump" in inspect.signature(handler).parameters + @functools.wraps(handler) def _wrapper( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool | None = None ) -> ReconstructionState: assert instr.opname == opname, ( f"Handler for '{opname}' called with wrong instruction" ) - new_state = handler(state, instr) + if instr.opcode in dis.hasjrel: + assert jump is not None, f"Jump op {opname} must have jump state" + new_state = handler(state, instr, jump=jump) + else: + assert jump is None, f"Non-jump op {opname} must not have jump state" + new_state = handler(state, instr) # post-condition: check stack effect - expected_stack_effect = dis.stack_effect(instr.opcode, instr.arg) + expected_stack_effect = dis.stack_effect(instr.opcode, instr.arg, jump=jump) actual_stack_effect = len(new_state.stack) - len(state.stack) - if not ( - len(state.stack) == len(new_state.stack) == 0 or instr.opname == "END_FOR" - ): - assert len(state.stack) + expected_stack_effect >= 0, ( - f"Handler for '{opname}' would result in negative stack size" - ) - assert actual_stack_effect == expected_stack_effect, ( - f"Handler for '{opname}' has incorrect stack effect: " - f"expected {expected_stack_effect}, got {actual_stack_effect}" - ) + assert len(state.stack) + expected_stack_effect >= 0, ( + f"Handler for '{opname}' would result in negative stack size" + ) + assert actual_stack_effect 
== expected_stack_effect, ( + f"Handler for '{opname}' has incorrect stack effect: " + f"expected {expected_stack_effect}, got {actual_stack_effect}" + ) return new_state @@ -217,6 +217,68 @@ def _wrapper( return handler # return the original handler for multiple decorator usage +def _symbolic_exec(code: types.CodeType) -> ast.expr: + """Execute bytecode symbolically, following control flow.""" + state = ReconstructionState(code=code, stack=[Null()]) + instructions = state.instructions + instrs_list = list(instructions.values()) + next_instr = { + i1.offset: i2.offset for i1, i2 in zip(instrs_list[:-1], instrs_list[1:]) + } + + loop_state: collections.Counter[int] = collections.Counter() + branch_state: collections.Counter[int] = collections.Counter() + + instr = next(iter(instructions.values())) # Start at first instruction + while instr is not None: + if instr.opname in {"FOR_ITER"}: + # FOR_ITER has two paths: continue loop or exit when exhausted + # For reconstruction, we execute the continue path once + if loop_state[instr.offset] > 0: + # Simulate iterator exhaustion - jump to FOR_ITER target + state = OP_HANDLERS[instr.opname](state, instr, jump=True) + instr = instructions[instr.jump_target] + else: + # Continue loop - execute FOR_ITER handler + state = OP_HANDLERS[instr.opname](state, instr, jump=False) + loop_state[instr.offset] += 1 + instr = instructions[next_instr[instr.offset]] + elif instr.opname in { + "POP_JUMP_IF_TRUE", + "POP_JUMP_IF_FALSE", + "POP_JUMP_IF_NOT_NONE", + "POP_JUMP_IF_NONE", + }: + # POP_JUMP_IF_*: conditional jump, follow the jump path once + if branch_state[instr.offset] > 0: + # Take the jump - execute the POP_JUMP_IF_* handler + state = OP_HANDLERS[instr.opname](state, instr, jump=True) + instr = instructions[next_instr[instr.offset]] + else: + # Simulate not taking the jump - continue to next instruction + state = OP_HANDLERS[instr.opname](state, instr, jump=False) + instr = instructions[instr.jump_target] + 
branch_state[instr.offset] += 1 + elif instr.opname in {"JUMP_BACKWARD", "JUMP_FORWARD"}: + # JUMP_BACKWARD: loop back to FOR_ITER + state = OP_HANDLERS[instr.opname](state, instr, jump=True) + instr = instructions[instr.jump_target] + elif instr.opname in {"RETURN_VALUE", "RETURN_CONST"}: + # RETURN_VALUE ends execution + state = OP_HANDLERS[instr.opname](state, instr) + instr = None + else: + # All other operations: handle normally + state = OP_HANDLERS[instr.opname](state, instr) + instr = ( + instructions[next_instr[instr.offset]] + if instr.offset in next_instr + else None + ) + + return state.result + + # ============================================================================ # GENERATOR COMPREHENSION HANDLERS # ============================================================================ @@ -227,12 +289,12 @@ def handle_return_generator_312( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # RETURN_GENERATOR is the first instruction in generator expressions in Python 3.13+ - # It initializes the generator - assert isinstance(state.result, Placeholder), ( + assert isinstance(state.stack[-1], Placeholder), ( "RETURN_GENERATOR must be the first instruction" ) new_result = ast.GeneratorExp(elt=Placeholder(), generators=[]) - return replace(state, result=new_result) + new_stack = state.stack[:-1] + [new_result] + return replace(state, stack=new_stack) @register_handler("RETURN_GENERATOR", version=PythonVersion.PY_313) @@ -240,13 +302,9 @@ def handle_return_generator( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # RETURN_GENERATOR is the first instruction in generator expressions in Python 3.13+ - # It initializes the generator - assert isinstance(state.result, Placeholder), ( - "RETURN_GENERATOR must be the first instruction" + return replace( + state, stack=[ast.GeneratorExp(elt=Placeholder(), generators=[]), Null()] ) - new_result = ast.GeneratorExp(elt=Placeholder(), generators=[]) - new_stack = 
state.stack + [Null()] - return replace(state, result=new_result, stack=new_stack) @register_handler("YIELD_VALUE", version=PythonVersion.PY_312) @@ -256,18 +314,18 @@ def handle_yield_value( ) -> ReconstructionState: # YIELD_VALUE pops a value from the stack and yields it # This is the expression part of the generator - assert isinstance(state.result, ast.GeneratorExp), ( + assert isinstance(state.result, Placeholder) + new_result = copy.deepcopy(state.stack[0]) + assert isinstance(new_result, ast.GeneratorExp), ( "YIELD_VALUE must be called after RETURN_GENERATOR" ) - assert isinstance(state.result.elt, Placeholder), ( + assert isinstance(new_result.elt, Placeholder), ( "YIELD_VALUE must be called before yielding" ) - assert len(state.result.generators) > 0, "YIELD_VALUE should have generators" - ret = ast.GeneratorExp( - elt=ensure_ast(state.stack[-1]), - generators=state.result.generators, - ) - return replace(state, result=ret) + assert len(new_result.generators) > 0, "YIELD_VALUE should have generators" + new_result.elt = ensure_ast(state.stack[-1]) + new_stack = [new_result] + state.stack[1:] + return replace(state, stack=new_stack, result=new_result) # ============================================================================ @@ -287,8 +345,8 @@ def handle_build_list( # Check if this looks like the start of a list comprehension pattern # In nested comprehensions, BUILD_LIST(0) starts a new list comprehe new_ret = ast.ListComp(elt=Placeholder(), generators=[]) - new_stack = state.stack + [state.result] - return replace(state, stack=new_stack, result=new_ret) + new_stack = state.stack + [new_ret] + return replace(state, stack=new_stack) else: # BUILD_LIST with elements - create a regular list elements = [ensure_ast(elem) for elem in state.stack[-size:]] @@ -303,19 +361,18 @@ def handle_build_list( def handle_list_append( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance(state.result, ast.ListComp) - assert 
isinstance(state.result.elt, Placeholder) + assert isinstance(state.stack[-instr.argval - 1], ast.ListComp) + assert isinstance(state.stack[-instr.argval - 1].elt, Placeholder) # add the body to the comprehension - comp: ast.ListComp = copy.deepcopy(state.result) + comp: ast.ListComp = copy.deepcopy(state.stack[-instr.argval - 1]) comp.elt = state.stack[-1] # swap the return value - prev_result: CompExp = state.stack[-instr.argval - 1] new_stack = state.stack[:-1] new_stack[-instr.argval] = comp - return replace(state, stack=new_stack, result=prev_result) + return replace(state, stack=new_stack) # ============================================================================ @@ -333,8 +390,8 @@ def handle_build_set( if size == 0: new_result = ast.SetComp(elt=Placeholder(), generators=[]) - new_stack = state.stack + [state.result] - return replace(state, stack=new_stack, result=new_result) + new_stack = state.stack + [new_result] + return replace(state, stack=new_stack) else: elements = [ensure_ast(elem) for elem in state.stack[-size:]] new_stack = state.stack[:-size] @@ -348,19 +405,18 @@ def handle_build_set( def handle_set_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance(state.result, ast.SetComp) - assert isinstance(state.result.elt, Placeholder) + assert isinstance(state.stack[-instr.argval - 1], ast.SetComp) + assert isinstance(state.stack[-instr.argval - 1].elt, Placeholder) # add the body to the comprehension - comp: ast.SetComp = copy.deepcopy(state.result) + comp: ast.SetComp = copy.deepcopy(state.stack[-instr.argval - 1]) comp.elt = state.stack[-1] # swap the return value - prev_result: CompExp = state.stack[-instr.argval - 1] new_stack = state.stack[:-1] new_stack[-instr.argval] = comp - return replace(state, stack=new_stack, result=prev_result) + return replace(state, stack=new_stack) # ============================================================================ @@ -377,9 +433,9 @@ def handle_build_map( 
size: int = instr.arg if size == 0: - new_stack = state.stack + [state.result] new_result = ast.DictComp(key=Placeholder(), value=Placeholder(), generators=[]) - return replace(state, stack=new_stack, result=new_result) + new_stack = state.stack + [new_result] + return replace(state, stack=new_stack) else: # Pop key-value pairs for the dict keys: list[ast.expr | None] = [ @@ -399,21 +455,20 @@ def handle_build_map( def handle_map_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - assert isinstance(state.result, ast.DictComp) - assert isinstance(state.result.key, Placeholder) - assert isinstance(state.result.value, Placeholder) + assert isinstance(state.stack[-instr.argval - 2], ast.DictComp) + assert isinstance(state.stack[-instr.argval - 2].key, Placeholder) + assert isinstance(state.stack[-instr.argval - 2].value, Placeholder) # add the body to the comprehension - comp: ast.DictComp = copy.deepcopy(state.result) + comp: ast.DictComp = copy.deepcopy(state.stack[-instr.argval - 2]) comp.key = state.stack[-2] comp.value = state.stack[-1] # swap the return value - prev_result: CompExp = state.stack[-instr.argval - 2] new_stack = state.stack[:-2] new_stack[-instr.argval] = comp - return replace(state, stack=new_stack, result=prev_result) + return replace(state, stack=new_stack) # ============================================================================ @@ -426,29 +481,37 @@ def handle_map_add( def handle_return_value( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: - # RETURN_VALUE ends the generator - # Usually preceded by LOAD_CONST None - if isinstance(state.result, CompExp): - return replace(state, stack=state.stack[:-1]) - elif isinstance(state.result, Placeholder): - new_result = ensure_ast(state.stack[0]) - assert isinstance(new_result, ast.expr) - return replace(state, stack=state.stack[1:], result=new_result) + assert isinstance(state.result, Placeholder) + new_result = 
ensure_ast(state.stack[-1]) + new_stack = state.stack[:-1] + return replace(state, stack=new_stack, result=new_result) + + +@register_handler("RETURN_CONST", version=PythonVersion.PY_312) +@register_handler("RETURN_CONST", version=PythonVersion.PY_313) +def handle_return_const( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # RETURN_CONST returns a constant value (replaces some LOAD_CONST + RETURN_VALUE patterns) + # Similar to RETURN_VALUE but with a constant + if isinstance(state.result, Placeholder): + return replace(state, result=ensure_ast(instr.argval)) else: - raise TypeError("Unexpected RETURN_VALUE in reconstruction") + assert instr.argval is None + return state @register_handler("FOR_ITER", version=PythonVersion.PY_312) @register_handler("FOR_ITER", version=PythonVersion.PY_313) def handle_for_iter( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: # FOR_ITER pops an iterator from the stack and pushes the next item # If the iterator is exhausted, it jumps to the target instruction assert len(state.stack) > 0, "FOR_ITER must have an iterator on the stack" - assert isinstance(state.result, CompExp), ( - "FOR_ITER must be called within a comprehension context" - ) + + if jump: + return replace(state, stack=state.stack + [Null()]) # The iterator should be on top of stack iterator: ast.expr = state.stack[-1] @@ -462,14 +525,21 @@ def handle_for_iter( is_async=0, ) - # Create new loops list with the new loop info - assert isinstance(state.result, CompExp) - new_ret = copy.deepcopy(state.result) - new_ret.generators = new_ret.generators + [loop_info] + for pos, item in zip(reversed(range(len(state.stack))), reversed(state.stack)): + if isinstance(item, CompExp) and isinstance( + getattr(item, "elt", getattr(item, "key", None)), Placeholder + ): + new_result = copy.deepcopy(item) + new_result.generators.append(loop_info) + new_stack 
= ( + state.stack[:pos] + + [new_result] + + state.stack[pos + 1 :] + + [loop_info.target] + ) + return replace(state, stack=new_stack) - new_stack = state.stack + [loop_info.target] - assert isinstance(new_ret, CompExp) - return replace(state, stack=new_stack, result=new_ret) + raise TypeError("FOR_ITER did not find partial comprehension on stack") @register_handler("GET_ITER", version=PythonVersion.PY_312) @@ -486,20 +556,22 @@ def handle_get_iter( @register_handler("JUMP_FORWARD", version=PythonVersion.PY_312) @register_handler("JUMP_FORWARD", version=PythonVersion.PY_313) def handle_jump_forward( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool = True ) -> ReconstructionState: # JUMP_FORWARD is used to jump forward in the code # In generator expressions, this is often used to skip code in conditional logic + assert jump, "JUMP_FORWARD always jumps" return state @register_handler("JUMP_BACKWARD", version=PythonVersion.PY_312) @register_handler("JUMP_BACKWARD", version=PythonVersion.PY_313) def handle_jump_backward( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool = True ) -> ReconstructionState: # JUMP_BACKWARD is used to jump back to the beginning of a loop (replaces JUMP_ABSOLUTE in 3.13) # In generator expressions, this typically indicates the end of the loop body + assert jump, "JUMP_BACKWARD always jumps" return state @@ -517,7 +589,7 @@ def handle_end_for_312( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # END_FOR marks the end of a for loop, followed by POP_TOP (in 3.12) - new_stack = state.stack[:-1] + new_stack = state.stack[:-2] return replace(state, stack=new_stack) @@ -526,24 +598,10 @@ def handle_end_for( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # END_FOR marks the end of a for loop - no action needed for AST reconstruction - new_stack = 
state.stack + new_stack = state.stack[:-1] return replace(state, stack=new_stack) -@register_handler("RETURN_CONST", version=PythonVersion.PY_312) -@register_handler("RETURN_CONST", version=PythonVersion.PY_313) -def handle_return_const( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - # RETURN_CONST returns a constant value (replaces some LOAD_CONST + RETURN_VALUE patterns) - # Similar to RETURN_VALUE but with a constant - if isinstance(state.result, ast.GeneratorExp): - # For generators, this typically ends the generator with None - return state - else: - raise TypeError("Unexpected RETURN_CONST in reconstruction") - - @register_handler("RERAISE", version=PythonVersion.PY_312) @register_handler("RERAISE", version=PythonVersion.PY_313) def handle_reraise( @@ -634,51 +692,37 @@ def handle_load_name( return replace(state, stack=new_stack) -@register_handler("STORE_FAST", version=PythonVersion.PY_312) -@register_handler("STORE_FAST", version=PythonVersion.PY_313) -def handle_store_fast( - state: ReconstructionState, instr: dis.Instruction -) -> ReconstructionState: - assert isinstance(state.result, CompExp) and state.result.generators, ( - "STORE_FAST must be called within a comprehension context" - ) - var_name = instr.argval - - if not state.stack or ( - isinstance(state.stack[-1], ast.Name) and state.stack[-1].id == var_name - ): - # If the variable is already on the stack, we can skip adding it again - # This is common in nested comprehensions where the same variable is reused - return replace(state, stack=state.stack[:-1]) - - new_stack = state.stack[:-1] - new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) - return replace(state, stack=new_stack, result=new_result) - - @register_handler("STORE_DEREF", version=PythonVersion.PY_312) @register_handler("STORE_DEREF", version=PythonVersion.PY_313) def handle_store_deref( state: ReconstructionState, instr: 
dis.Instruction ) -> ReconstructionState: # STORE_DEREF stores a value into a closure variable - assert isinstance(state.result, CompExp) and state.result.generators, ( - "STORE_DEREF must be called within a comprehension context" - ) - var_name = instr.argval + # For AST reconstruction, we treat this the same as STORE_FAST + return handle_store_fast(state, instr) + - if not state.stack or ( - isinstance(state.stack[-1], ast.Name) and state.stack[-1].id == var_name - ): +@register_handler("STORE_FAST", version=PythonVersion.PY_312) +@register_handler("STORE_FAST", version=PythonVersion.PY_313) +def handle_store_fast( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + if isinstance(state.stack[-1], ast.Name) and state.stack[-1].id == instr.argval: # If the variable is already on the stack, we can skip adding it again # This is common in nested comprehensions where the same variable is reused return replace(state, stack=state.stack[:-1]) - new_stack = state.stack[:-1] - new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) - return replace(state, stack=new_stack, result=new_result) + assert isinstance(state.stack[-1], Placeholder) + for pos, item in zip(reversed(range(len(state.stack))), reversed(state.stack)): + if isinstance(item, CompExp) and item.generators[-1].target == state.stack[-1]: + new_result = copy.deepcopy(item) + new_result.generators[-1].target = ast.Name( + id=instr.argval, ctx=ast.Store() + ) + new_stack = state.stack[:pos] + [new_result] + state.stack[pos + 1 : -1] + return replace(state, stack=new_stack) + + raise TypeError("STORE_FAST did not find matching Placeholder") @register_handler("STORE_FAST_LOAD_FAST", version=PythonVersion.PY_313) @@ -689,20 +733,23 @@ def handle_store_fast_load_fast( # The instruction has two names: store_name and load_name # In Python 3.13, this is often used for loop variables - # First handle the store part - 
assert isinstance(state.result, CompExp) and state.result.generators, ( - "STORE_FAST_LOAD_FAST must be called within a comprehension context" - ) - # In Python 3.13, the instruction argument contains both names # argval should be a tuple (store_name, load_name) assert isinstance(instr.argval, tuple) store_name, load_name = instr.argval - new_stack = state.stack[:-1] + [ast.Name(id=load_name, ctx=ast.Load())] - new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=store_name, ctx=ast.Store()) - return replace(state, stack=new_stack, result=new_result) + assert isinstance(state.stack[-1], Placeholder) + for pos, item in zip(reversed(range(len(state.stack))), reversed(state.stack)): + if isinstance(item, CompExp) and item.generators[-1].target == state.stack[-1]: + new_result = copy.deepcopy(item) + new_result.generators[-1].target = ast.Name(id=store_name, ctx=ast.Store()) + new_var = ast.Name(id=load_name, ctx=ast.Load()) + new_stack = ( + state.stack[:pos] + [new_result] + state.stack[pos + 1 : -1] + [new_var] + ) + return replace(state, stack=new_stack) + + raise TypeError("STORE_FAST_LOAD_FAST did not find matching Placeholder") @register_handler("LOAD_FAST_AND_CLEAR", version=PythonVersion.PY_312) @@ -1357,16 +1404,15 @@ def handle_list_extend( # because it was initially recognized as a list comprehension in BUILD_LIST, # while the actual result expression is in the stack where the list "should be" # and needs to be put back into the state result slot - assert isinstance(state.result, ast.ListComp) and not state.result.generators assert isinstance(state.stack[-1], ast.Tuple | ast.List) - prev_result = state.stack[-instr.argval - 1] + assert isinstance(state.stack[-instr.argval - 1], ast.ListComp) new_val = ast.List( elts=[ensure_ast(e) for e in state.stack[-1].elts], ctx=ast.Load() ) new_stack = state.stack[:-2] + [new_val] - return replace(state, stack=new_stack, result=prev_result) + return replace(state, 
stack=new_stack) @register_handler("SET_UPDATE", version=PythonVersion.PY_312) @@ -1378,14 +1424,13 @@ def handle_set_update( # because it was initially recognized as a list comprehension in BUILD_SET, # while the actual result expression is in the stack where the set "should be" # and needs to be put back into the state result slot - assert isinstance(state.result, ast.SetComp) and not state.result.generators + assert isinstance(state.stack[-instr.argval - 1], ast.SetComp) assert isinstance(state.stack[-1], ast.Tuple | ast.List | ast.Set) - prev_result = state.stack[-instr.argval - 1] new_val = ast.Set(elts=[ensure_ast(e) for e in state.stack[-1].elts]) new_stack = state.stack[:-2] + [new_val] - return replace(state, stack=new_stack, result=prev_result) + return replace(state, stack=new_stack) @register_handler("DICT_UPDATE", version=PythonVersion.PY_312) @@ -1397,9 +1442,8 @@ def handle_dict_update( # because it was initially recognized as a list comprehension in BUILD_MAP, # while the actual result expression is in the stack where the dict "should be" # and needs to be put back into the state result slot - assert isinstance(state.result, ast.DictComp) and not state.result.generators + assert isinstance(state.stack[-instr.argval - 1], ast.DictComp) assert isinstance(state.stack[-1], ast.Dict) - prev_result = state.stack[-instr.argval - 1] new_val = ast.Dict( keys=[ensure_ast(e) for e in state.stack[-1].keys], @@ -1407,7 +1451,7 @@ def handle_dict_update( ) new_stack = state.stack[:-2] + [new_val] - return replace(state, stack=new_stack, result=prev_result) + return replace(state, stack=new_stack) @register_handler("BUILD_STRING", version=PythonVersion.PY_312) @@ -1581,48 +1625,52 @@ def _handle_pop_jump_if( f_condition: Callable[[ast.expr], ast.expr], state: ReconstructionState, instr: dis.Instruction, + *, + jump: bool, ) -> ReconstructionState: # Generic handler for POP_JUMP_IF_* instructions # Pops a value from the stack and jumps if the condition is met 
condition = f_condition(ensure_ast(state.stack[-1])) - new_stack = state.stack[:-1] - if isinstance(state.result, CompExp) and state.result.generators: - # In comprehensions, we add the condition to the last generator's ifs - new_result = copy.deepcopy(state.result) - new_result.generators[-1].ifs.append(condition) - return replace(state, stack=new_stack, result=new_result) - else: - # Not in a comprehension context - might be boolean logic - raise NotImplementedError("Lazy and+or behavior not implemented yet") + for pos, item in zip(reversed(range(len(state.stack))), reversed(state.stack)): + if isinstance(item, CompExp) and isinstance( + getattr(item, "elt", getattr(item, "key", None)), Placeholder + ): + new_result = copy.deepcopy(item) + new_result.generators[-1].ifs.append(condition) + new_stack = state.stack[:pos] + [new_result] + state.stack[pos + 1 : -1] + return replace(state, stack=new_stack) + + # Not in a comprehension context - might be boolean logic + raise NotImplementedError("Lazy and+or behavior not implemented yet") @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_313) def handle_pop_jump_if_true( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: # POP_JUMP_IF_TRUE pops a value from the stack and jumps if it's true # In Python 3.13, this is used for filter conditions where True means continue - return _handle_pop_jump_if(lambda c: c, state, instr) + return _handle_pop_jump_if(lambda c: c, state, instr, jump=jump) @register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_313) def handle_pop_jump_if_false( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: # POP_JUMP_IF_FALSE pops a value from the stack 
and jumps if it's false # In comprehensions, this is used for filter conditions return _handle_pop_jump_if( - lambda c: ast.UnaryOp(op=ast.Not(), operand=c), state, instr + lambda c: ast.UnaryOp(op=ast.Not(), operand=c), state, instr, jump=jump ) @register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_313) def handle_pop_jump_if_none( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: # POP_JUMP_IF_NONE pops a value and jumps if it's None return _handle_pop_jump_if( @@ -1631,13 +1679,14 @@ def handle_pop_jump_if_none( ), state, instr, + jump=jump, ) @register_handler("POP_JUMP_IF_NOT_NONE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_NOT_NONE", version=PythonVersion.PY_313) def handle_pop_jump_if_not_none( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: # POP_JUMP_IF_NOT_NONE pops a value and jumps if it's not None return _handle_pop_jump_if( @@ -1646,13 +1695,14 @@ def handle_pop_jump_if_not_none( ), state, instr, + jump=jump, ) @register_handler("SEND", version=PythonVersion.PY_312) @register_handler("SEND", version=PythonVersion.PY_313) def handle_send( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: raise TypeError("SEND instruction should not appear in generator comprehensions") @@ -1660,7 +1710,7 @@ def handle_send( @register_handler("JUMP_BACKWARD_NO_INTERRUPT", version=PythonVersion.PY_312) @register_handler("JUMP_BACKWARD_NO_INTERRUPT", version=PythonVersion.PY_313) def handle_jump_backward_no_interrupt( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: raise TypeError( 
"JUMP_BACKWARD_NO_INTERRUPT instruction should not appear in generator comprehensions" @@ -1670,7 +1720,7 @@ def handle_jump_backward_no_interrupt( @register_handler("JUMP", version=PythonVersion.PY_312) @register_handler("JUMP", version=PythonVersion.PY_313) def handle_jump( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: raise TypeError("JUMP instruction should not appear in generator comprehensions") @@ -1678,7 +1728,7 @@ def handle_jump( @register_handler("JUMP_NO_INTERRUPT", version=PythonVersion.PY_312) @register_handler("JUMP_NO_INTERRUPT", version=PythonVersion.PY_313) def handle_jump_no_interrupt( - state: ReconstructionState, instr: dis.Instruction + state: ReconstructionState, instr: dis.Instruction, *, jump: bool ) -> ReconstructionState: raise TypeError( "JUMP_NO_INTERRUPT instruction should not appear in generator comprehensions" @@ -1777,61 +1827,6 @@ def _ensure_ast_range_iterator(value: Iterator) -> ast.Call: return ensure_ast(value.__reduce__()[1][0]) # type: ignore -def _symbolic_exec(code: types.CodeType) -> ReconstructionState: - """Execute bytecode symbolically, following control flow.""" - state = ReconstructionState(code=code) - instructions = state.instructions - instrs_list = list(instructions.values()) - next_instr = { - i1.offset: i2.offset for i1, i2 in zip(instrs_list[:-1], instrs_list[1:]) - } - - loop_state = collections.Counter() - branch_state = collections.Counter() - - instr = next(iter(instructions.values())) # Start at first instruction - while instr is not None: - if instr.opname == "FOR_ITER": - # FOR_ITER has two paths: continue loop or exit when exhausted - # For reconstruction, we execute the continue path once - if loop_state[instr.offset] > 0: - # Simulate iterator exhaustion - jump to FOR_ITER target - instr = instructions[instr.jump_target] - else: - # Continue loop - execute FOR_ITER handler - state = 
OP_HANDLERS[instr.opname](state, instr) - loop_state[instr.offset] += 1 - instr = instructions[next_instr[instr.offset]] - elif instr.opname.startswith("POP_JUMP_IF_"): - # POP_JUMP_IF_*: conditional jump, follow the jump path once - if branch_state[instr.offset] > 0: - # Simulate not taking the jump - continue to next instruction - instr = instructions[next_instr[instr.offset]] - else: - # Take the jump - execute the POP_JUMP_IF_* handler - state = OP_HANDLERS[instr.opname](state, instr) - instr = instructions[instr.jump_target] - branch_state[instr.offset] += 1 - elif instr.opname in {"JUMP_BACKWARD", "JUMP_FORWARD"}: - # JUMP_BACKWARD: loop back to FOR_ITER - state = OP_HANDLERS[instr.opname](state, instr) - instr = instructions[instr.jump_target] - elif instr.opname in {"RETURN_VALUE", "RETURN_CONST"}: - # YIELD_VALUE and RETURN_VALUE end execution - state = OP_HANDLERS[instr.opname](state, instr) - instr = None - else: - # All other operations: handle normally - state = OP_HANDLERS[instr.opname](state, instr) - instr = ( - instructions[next_instr[instr.offset]] - if instr.offset in next_instr - else None - ) - - return state - - @ensure_ast.register def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: assert inspect.iscode(value), "Input must be a code object" @@ -1854,7 +1849,7 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: raise TypeError(f"Unsupported code object type: {value.co_name}") # Symbolic execution to reconstruct the AST - result: ast.expr = _symbolic_exec(value).result + result: ast.expr = _symbolic_exec(value) # Check postconditions assert not any(isinstance(x, ast.stmt) for x in ast.walk(result)), ( diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 493c7976..e3ed8eb3 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -676,7 +676,7 @@ def test_lazy_boolean_and_chained_comparisons(genexpr): "genexpr", [ # 
Simple conditional as function argument - (max(x if x > 0 else 0) for x in range(-2, 3)), + (max(x if x > 0 else 0, 1) for x in range(-2, 3)), (abs(x if x < 0 else -x) for x in range(-3, 3)), (len(str(x) if x > 10 else "small") for x in range(15)), # Multiple conditional arguments @@ -701,7 +701,7 @@ def test_lazy_boolean_and_chained_comparisons(genexpr): # Mixed: conditional in function call within comprehension filter (x for x in range(20) if max(x if x > 10 else 0, 5) > 8), # Complex nested case: conditional in function argument, function call in conditional - (max(x if len(str(x)) > 1 else x * 10) for x in range(15)), + (abs(x if len(str(x)) > 1 else x * 10) for x in range(15)), ], ) def test_conditional_expressions_function_arguments(genexpr): From e39dc5dc7f9835a55cbd9e93e90492050170a077 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 22 Aug 2025 16:26:52 -0400 Subject: [PATCH 099/106] 312 --- effectful/internals/disassembly.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index e467cf54..a348a1a3 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -219,7 +219,12 @@ def _wrapper( def _symbolic_exec(code: types.CodeType) -> ast.expr: """Execute bytecode symbolically, following control flow.""" - state = ReconstructionState(code=code, stack=[Null()]) + state = ReconstructionState( + code=code, + stack=[Null(), Null()] + if PythonVersion(sys.version_info.minor) == PythonVersion.PY_312 + else [Null()], + ) instructions = state.instructions instrs_list = list(instructions.values()) next_instr = { @@ -237,7 +242,7 @@ def _symbolic_exec(code: types.CodeType) -> ast.expr: if loop_state[instr.offset] > 0: # Simulate iterator exhaustion - jump to FOR_ITER target state = OP_HANDLERS[instr.opname](state, instr, jump=True) - instr = instructions[instr.jump_target] + instr = instructions[instr.argval] else: # Continue loop 
- execute FOR_ITER handler state = OP_HANDLERS[instr.opname](state, instr, jump=False) @@ -257,12 +262,12 @@ def _symbolic_exec(code: types.CodeType) -> ast.expr: else: # Simulate not taking the jump - continue to next instruction state = OP_HANDLERS[instr.opname](state, instr, jump=False) - instr = instructions[instr.jump_target] + instr = instructions[instr.argval] branch_state[instr.offset] += 1 elif instr.opname in {"JUMP_BACKWARD", "JUMP_FORWARD"}: # JUMP_BACKWARD: loop back to FOR_ITER state = OP_HANDLERS[instr.opname](state, instr, jump=True) - instr = instructions[instr.jump_target] + instr = instructions[instr.argval] elif instr.opname in {"RETURN_VALUE", "RETURN_CONST"}: # RETURN_VALUE ends execution state = OP_HANDLERS[instr.opname](state, instr) @@ -289,12 +294,11 @@ def handle_return_generator_312( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # RETURN_GENERATOR is the first instruction in generator expressions in Python 3.13+ - assert isinstance(state.stack[-1], Placeholder), ( + assert len(state.stack) == 2 and all(isinstance(x, Null) for x in state.stack), ( "RETURN_GENERATOR must be the first instruction" ) new_result = ast.GeneratorExp(elt=Placeholder(), generators=[]) - new_stack = state.stack[:-1] + [new_result] - return replace(state, stack=new_stack) + return replace(state, stack=[new_result, Null()]) @register_handler("RETURN_GENERATOR", version=PythonVersion.PY_313) @@ -302,6 +306,9 @@ def handle_return_generator( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # RETURN_GENERATOR is the first instruction in generator expressions in Python 3.13+ + assert len(state.stack) == 1 and isinstance(state.stack[0], Null), ( + "RETURN_GENERATOR must be the first instruction" + ) return replace( state, stack=[ast.GeneratorExp(elt=Placeholder(), generators=[]), Null()] ) From fab278b2cac2eecbbecab2f5971d5998636a1309 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 22 Aug 2025 16:29:09 -0400 
Subject: [PATCH 100/106] sanity check --- effectful/internals/disassembly.py | 1 + 1 file changed, 1 insertion(+) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index a348a1a3..1ed51257 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -195,6 +195,7 @@ def _wrapper( if instr.opcode in dis.hasjrel: assert jump is not None, f"Jump op {opname} must have jump state" + assert instr.argval == getattr(instr, "jump_target", instr.argval) new_state = handler(state, instr, jump=jump) else: assert jump is None, f"Non-jump op {opname} must not have jump state" From 6894cc97d8797a966ded50eb152e97e0bd309c10 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 22 Aug 2025 16:32:20 -0400 Subject: [PATCH 101/106] constants --- effectful/internals/disassembly.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 1ed51257..512ba467 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -218,6 +218,17 @@ def _wrapper( return handler # return the original handler for multiple decorator usage +LOOP_OPS = {"FOR_ITER"} +BRANCH_OPS = { + "POP_JUMP_IF_TRUE", + "POP_JUMP_IF_FALSE", + "POP_JUMP_IF_NOT_NONE", + "POP_JUMP_IF_NONE", +} +RETURN_OPS = {"RETURN_VALUE", "RETURN_CONST"} +JUMP_OPS = {dis.opname[d] for d in dis.hasjrel} - LOOP_OPS - BRANCH_OPS - RETURN_OPS + + def _symbolic_exec(code: types.CodeType) -> ast.expr: """Execute bytecode symbolically, following control flow.""" state = ReconstructionState( @@ -237,7 +248,7 @@ def _symbolic_exec(code: types.CodeType) -> ast.expr: instr = next(iter(instructions.values())) # Start at first instruction while instr is not None: - if instr.opname in {"FOR_ITER"}: + if instr.opname in LOOP_OPS: # FOR_ITER has two paths: continue loop or exit when exhausted # For reconstruction, we execute the continue path once if 
loop_state[instr.offset] > 0: @@ -249,12 +260,7 @@ def _symbolic_exec(code: types.CodeType) -> ast.expr: state = OP_HANDLERS[instr.opname](state, instr, jump=False) loop_state[instr.offset] += 1 instr = instructions[next_instr[instr.offset]] - elif instr.opname in { - "POP_JUMP_IF_TRUE", - "POP_JUMP_IF_FALSE", - "POP_JUMP_IF_NOT_NONE", - "POP_JUMP_IF_NONE", - }: + elif instr.opname in BRANCH_OPS: # POP_JUMP_IF_*: conditional jump, follow the jump path once if branch_state[instr.offset] > 0: # Take the jump - execute the POP_JUMP_IF_* handler @@ -265,11 +271,11 @@ def _symbolic_exec(code: types.CodeType) -> ast.expr: state = OP_HANDLERS[instr.opname](state, instr, jump=False) instr = instructions[instr.argval] branch_state[instr.offset] += 1 - elif instr.opname in {"JUMP_BACKWARD", "JUMP_FORWARD"}: + elif instr.opname in JUMP_OPS: # JUMP_BACKWARD: loop back to FOR_ITER state = OP_HANDLERS[instr.opname](state, instr, jump=True) instr = instructions[instr.argval] - elif instr.opname in {"RETURN_VALUE", "RETURN_CONST"}: + elif instr.opname in RETURN_OPS: # RETURN_VALUE ends execution state = OP_HANDLERS[instr.opname](state, instr) instr = None From dcf395d45590baadde1384d04b2d597ef5da597b Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 25 Aug 2025 12:33:58 -0400 Subject: [PATCH 102/106] move logic into handler --- effectful/internals/disassembly.py | 151 ++++++++++++++--------------- 1 file changed, 73 insertions(+), 78 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 512ba467..b2a0a904 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -17,6 +17,7 @@ import ast import collections +import collections.abc import copy import dis import enum @@ -124,9 +125,15 @@ class ReconstructionState: """ code: types.CodeType - stack: list[ast.expr] + instruction: dis.Instruction + + stack: list[ast.expr] = field(default_factory=list) result: ast.expr = field(default_factory=Placeholder) + 
loops: dict[int, int] = field(default_factory=collections.Counter) + branches: dict[int, int] = field(default_factory=collections.Counter) + finished: bool = field(default=False) + @property def instructions(self) -> collections.OrderedDict[int, dis.Instruction]: """Get the bytecode instructions for the current code object.""" @@ -134,6 +141,11 @@ def instructions(self) -> collections.OrderedDict[int, dis.Instruction]: (instr.offset, instr) for instr in dis.get_instructions(self.code) ) + @property + def next_instructions(self) -> collections.abc.Mapping[int, dis.Instruction]: + instrs_list = list(self.instructions.values()) + return {i1.offset: i2 for i1, i2 in zip(instrs_list[:-1], instrs_list[1:])} + # Python version enum for version-specific handling class PythonVersion(enum.Enum): @@ -183,23 +195,58 @@ def register_handler( raise ValueError(f"Handler for '{opname}' (version {version}) already exists.") if dis.opmap[opname] in dis.hasjrel: - assert "jump" in inspect.signature(handler).parameters + assert opname in LOOP_OPS | BRANCH_OPS | JUMP_OPS + else: + assert opname not in LOOP_OPS | BRANCH_OPS | JUMP_OPS @functools.wraps(handler) def _wrapper( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool | None = None + state: ReconstructionState, + instr: dis.Instruction, ) -> ReconstructionState: assert instr.opname == opname, ( f"Handler for '{opname}' called with wrong instruction" ) + assert not state.finished, "Cannot process instruction on finished state" - if instr.opcode in dis.hasjrel: - assert jump is not None, f"Jump op {opname} must have jump state" - assert instr.argval == getattr(instr, "jump_target", instr.argval) - new_state = handler(state, instr, jump=jump) + new_state = handler(state, instr) + + jump: bool | None + if instr.opname in LOOP_OPS: + if state.loops[instr.offset] > 0: + new_state = replace( + new_state, instruction=state.instructions[instr.argval] + ) + jump = True + else: + new_state = replace( + new_state, 
instruction=state.next_instructions[instr.offset] + ) + new_state.loops[instr.offset] += 1 + jump = False + elif instr.opname in BRANCH_OPS: + if state.branches[instr.offset] > 0: + new_state = replace( + new_state, instruction=state.next_instructions[instr.offset] + ) + jump = False + else: + new_state = replace( + new_state, instruction=state.instructions[instr.argval] + ) + new_state.branches[instr.offset] += 1 + jump = True + elif instr.opname in JUMP_OPS: + new_state = replace(new_state, instruction=state.instructions[instr.argval]) + jump = True + elif instr.opname not in RETURN_OPS and instr.offset in state.next_instructions: + new_state = replace( + new_state, instruction=state.next_instructions[instr.offset] + ) + jump = None else: - assert jump is None, f"Non-jump op {opname} must not have jump state" - new_state = handler(state, instr) + new_state = replace(new_state, finished=True) + jump = None # post-condition: check stack effect expected_stack_effect = dis.stack_effect(instr.opcode, instr.arg, jump=jump) @@ -233,60 +280,14 @@ def _symbolic_exec(code: types.CodeType) -> ast.expr: """Execute bytecode symbolically, following control flow.""" state = ReconstructionState( code=code, + instruction=next(iter(dis.get_instructions(code))), stack=[Null(), Null()] if PythonVersion(sys.version_info.minor) == PythonVersion.PY_312 else [Null()], ) - instructions = state.instructions - instrs_list = list(instructions.values()) - next_instr = { - i1.offset: i2.offset for i1, i2 in zip(instrs_list[:-1], instrs_list[1:]) - } - loop_state: collections.Counter[int] = collections.Counter() - branch_state: collections.Counter[int] = collections.Counter() - - instr = next(iter(instructions.values())) # Start at first instruction - while instr is not None: - if instr.opname in LOOP_OPS: - # FOR_ITER has two paths: continue loop or exit when exhausted - # For reconstruction, we execute the continue path once - if loop_state[instr.offset] > 0: - # Simulate iterator exhaustion 
- jump to FOR_ITER target - state = OP_HANDLERS[instr.opname](state, instr, jump=True) - instr = instructions[instr.argval] - else: - # Continue loop - execute FOR_ITER handler - state = OP_HANDLERS[instr.opname](state, instr, jump=False) - loop_state[instr.offset] += 1 - instr = instructions[next_instr[instr.offset]] - elif instr.opname in BRANCH_OPS: - # POP_JUMP_IF_*: conditional jump, follow the jump path once - if branch_state[instr.offset] > 0: - # Take the jump - execute the POP_JUMP_IF_* handler - state = OP_HANDLERS[instr.opname](state, instr, jump=True) - instr = instructions[next_instr[instr.offset]] - else: - # Simulate not taking the jump - continue to next instruction - state = OP_HANDLERS[instr.opname](state, instr, jump=False) - instr = instructions[instr.argval] - branch_state[instr.offset] += 1 - elif instr.opname in JUMP_OPS: - # JUMP_BACKWARD: loop back to FOR_ITER - state = OP_HANDLERS[instr.opname](state, instr, jump=True) - instr = instructions[instr.argval] - elif instr.opname in RETURN_OPS: - # RETURN_VALUE ends execution - state = OP_HANDLERS[instr.opname](state, instr) - instr = None - else: - # All other operations: handle normally - state = OP_HANDLERS[instr.opname](state, instr) - instr = ( - instructions[next_instr[instr.offset]] - if instr.offset in next_instr - else None - ) + while not state.finished: + state = OP_HANDLERS[state.instruction.opname](state, state.instruction) return state.result @@ -518,13 +519,13 @@ def handle_return_const( @register_handler("FOR_ITER", version=PythonVersion.PY_312) @register_handler("FOR_ITER", version=PythonVersion.PY_313) def handle_for_iter( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # FOR_ITER pops an iterator from the stack and pushes the next item # If the iterator is exhausted, it jumps to the target instruction assert len(state.stack) > 0, "FOR_ITER must have an iterator on the stack" - 
if jump: + if state.loops[instr.offset] > 0: return replace(state, stack=state.stack + [Null()]) # The iterator should be on top of stack @@ -570,22 +571,20 @@ def handle_get_iter( @register_handler("JUMP_FORWARD", version=PythonVersion.PY_312) @register_handler("JUMP_FORWARD", version=PythonVersion.PY_313) def handle_jump_forward( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool = True + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # JUMP_FORWARD is used to jump forward in the code # In generator expressions, this is often used to skip code in conditional logic - assert jump, "JUMP_FORWARD always jumps" return state @register_handler("JUMP_BACKWARD", version=PythonVersion.PY_312) @register_handler("JUMP_BACKWARD", version=PythonVersion.PY_313) def handle_jump_backward( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool = True + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # JUMP_BACKWARD is used to jump back to the beginning of a loop (replaces JUMP_ABSOLUTE in 3.13) # In generator expressions, this typically indicates the end of the loop body - assert jump, "JUMP_BACKWARD always jumps" return state @@ -1639,8 +1638,6 @@ def _handle_pop_jump_if( f_condition: Callable[[ast.expr], ast.expr], state: ReconstructionState, instr: dis.Instruction, - *, - jump: bool, ) -> ReconstructionState: # Generic handler for POP_JUMP_IF_* instructions # Pops a value from the stack and jumps if the condition is met @@ -1662,29 +1659,29 @@ def _handle_pop_jump_if( @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_313) def handle_pop_jump_if_true( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # POP_JUMP_IF_TRUE pops a value from the stack and jumps if it's true # In Python 3.13, this is used for filter 
conditions where True means continue - return _handle_pop_jump_if(lambda c: c, state, instr, jump=jump) + return _handle_pop_jump_if(lambda c: c, state, instr) @register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_FALSE", version=PythonVersion.PY_313) def handle_pop_jump_if_false( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # POP_JUMP_IF_FALSE pops a value from the stack and jumps if it's false # In comprehensions, this is used for filter conditions return _handle_pop_jump_if( - lambda c: ast.UnaryOp(op=ast.Not(), operand=c), state, instr, jump=jump + lambda c: ast.UnaryOp(op=ast.Not(), operand=c), state, instr ) @register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_NONE", version=PythonVersion.PY_313) def handle_pop_jump_if_none( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # POP_JUMP_IF_NONE pops a value and jumps if it's None return _handle_pop_jump_if( @@ -1693,14 +1690,13 @@ def handle_pop_jump_if_none( ), state, instr, - jump=jump, ) @register_handler("POP_JUMP_IF_NOT_NONE", version=PythonVersion.PY_312) @register_handler("POP_JUMP_IF_NOT_NONE", version=PythonVersion.PY_313) def handle_pop_jump_if_not_none( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # POP_JUMP_IF_NOT_NONE pops a value and jumps if it's not None return _handle_pop_jump_if( @@ -1709,14 +1705,13 @@ def handle_pop_jump_if_not_none( ), state, instr, - jump=jump, ) @register_handler("SEND", version=PythonVersion.PY_312) @register_handler("SEND", version=PythonVersion.PY_313) def handle_send( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: 
dis.Instruction ) -> ReconstructionState: raise TypeError("SEND instruction should not appear in generator comprehensions") @@ -1724,7 +1719,7 @@ def handle_send( @register_handler("JUMP_BACKWARD_NO_INTERRUPT", version=PythonVersion.PY_312) @register_handler("JUMP_BACKWARD_NO_INTERRUPT", version=PythonVersion.PY_313) def handle_jump_backward_no_interrupt( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: raise TypeError( "JUMP_BACKWARD_NO_INTERRUPT instruction should not appear in generator comprehensions" @@ -1734,7 +1729,7 @@ def handle_jump_backward_no_interrupt( @register_handler("JUMP", version=PythonVersion.PY_312) @register_handler("JUMP", version=PythonVersion.PY_313) def handle_jump( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: raise TypeError("JUMP instruction should not appear in generator comprehensions") @@ -1742,7 +1737,7 @@ def handle_jump( @register_handler("JUMP_NO_INTERRUPT", version=PythonVersion.PY_312) @register_handler("JUMP_NO_INTERRUPT", version=PythonVersion.PY_313) def handle_jump_no_interrupt( - state: ReconstructionState, instr: dis.Instruction, *, jump: bool + state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: raise TypeError( "JUMP_NO_INTERRUPT instruction should not appear in generator comprehensions" From 2dfb7ba7516c13bd65856b6a1650422eff3ed0ce Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 25 Aug 2025 12:49:40 -0400 Subject: [PATCH 103/106] comment --- effectful/internals/disassembly.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index b2a0a904..08d3ae58 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -211,7 +211,7 @@ def _wrapper( new_state = handler(state, instr) - jump: 
bool | None + jump: bool | None # argument to dis.stack_effect if instr.opname in LOOP_OPS: if state.loops[instr.offset] > 0: new_state = replace( From 7605d65463e4270720727078c284ef47f3ccb600 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 4 Sep 2025 13:39:42 -0400 Subject: [PATCH 104/106] partially support conditionals --- effectful/internals/disassembly.py | 172 ++++++++++++++++++++------- tests/test_internals_disassembler.py | 18 ++- 2 files changed, 145 insertions(+), 45 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 08d3ae58..1021c08b 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -46,6 +46,13 @@ def __init__(self, id=".0", ctx=ast.Load()): super().__init__(id=id, ctx=ctx) +class Skipped(ast.Name): + """Placeholder for skipped branches in if-expressions.""" + + def __init__(self, id: str, ctx=ast.Load()): + super().__init__(id=id, ctx=ctx) + + class Null(ast.Constant): """Placeholder for NULL values generated in bytecode.""" @@ -100,6 +107,20 @@ def inline(self, iterator: ast.expr) -> CompExp: return res +class ReplacePlaceholder(ast.NodeTransformer): + def __init__(self, value: ast.expr): + self.value = value + self._done = False + super().__init__() + + def visit(self, node): + if isinstance(node, Placeholder) and not self._done: + self._done = True + return self.value + else: + return self.generic_visit(node) + + @dataclass(frozen=True) class ReconstructionState: """State maintained during AST reconstruction from bytecode. 
@@ -131,9 +152,10 @@ class ReconstructionState: result: ast.expr = field(default_factory=Placeholder) loops: dict[int, int] = field(default_factory=collections.Counter) - branches: dict[int, int] = field(default_factory=collections.Counter) finished: bool = field(default=False) + branches: dict[int, int] = field(default_factory=collections.Counter) + @property def instructions(self) -> collections.OrderedDict[int, dis.Instruction]: """Get the bytecode instructions for the current code object.""" @@ -146,6 +168,24 @@ def next_instructions(self) -> collections.abc.Mapping[int, dis.Instruction]: instrs_list = list(self.instructions.values()) return {i1.offset: i2 for i1, i2 in zip(instrs_list[:-1], instrs_list[1:])} + @property + def is_filter(self) -> bool: + """Check if an instruction is a filter clause in a comprehension""" + return ( + self.instruction.opname in BRANCH_OPS + and self.next_instructions[self.instruction.offset].opname + == "JUMP_BACKWARD" + and self.instructions[ + self.next_instructions[self.instruction.offset].argval + ].opname + in LOOP_OPS + ) + + @property + def is_branch(self) -> bool: + """Check if an instruction is a branch in an if-expression""" + return self.instruction.opname in BRANCH_OPS and not self.is_filter + # Python version enum for version-specific handling class PythonVersion(enum.Enum): @@ -225,7 +265,7 @@ def _wrapper( new_state.loops[instr.offset] += 1 jump = False elif instr.opname in BRANCH_OPS: - if state.branches[instr.offset] > 0: + if state.branches.get(instr.offset, 0): new_state = replace( new_state, instruction=state.next_instructions[instr.offset] ) @@ -234,7 +274,6 @@ def _wrapper( new_state = replace( new_state, instruction=state.instructions[instr.argval] ) - new_state.branches[instr.offset] += 1 jump = True elif instr.opname in JUMP_OPS: new_state = replace(new_state, instruction=state.instructions[instr.argval]) @@ -278,18 +317,53 @@ def _wrapper( def _symbolic_exec(code: types.CodeType) -> ast.expr: """Execute 
bytecode symbolically, following control flow.""" - state = ReconstructionState( - code=code, - instruction=next(iter(dis.get_instructions(code))), - stack=[Null(), Null()] - if PythonVersion(sys.version_info.minor) == PythonVersion.PY_312 - else [Null()], + continuations: list[ReconstructionState] = [ + ReconstructionState( + code=code, + instruction=next(iter(dis.get_instructions(code))), + stack=[Placeholder(), Placeholder()] + if PythonVersion(sys.version_info.minor) == PythonVersion.PY_312 + else [Placeholder()], + ) + ] + + results: list[ast.expr] = [] + + while continuations: + state = continuations.pop() + while not state.finished: + if state.is_branch and not state.branches.get(state.instruction.offset, 0): + continuations.append( + replace( + state, branches=state.branches | {state.instruction.offset: 1} + ) + ) + state = OP_HANDLERS[state.instruction.opname](state, state.instruction) + results.append(state.result) + + assert results, "No results from symbolic execution" + return functools.reduce( + lambda a, b: _MergeBranches(a).visit(b), reversed(results[:-1]), results[-1] ) - while not state.finished: - state = OP_HANDLERS[state.instruction.opname](state, state.instruction) - return state.result +class _MergeBranches(ast.NodeTransformer): + def __init__(self, node_with_orelse: ast.expr): + self._orelses = { + n.body.id: n.orelse + for n in ast.walk(node_with_orelse) + if isinstance(n, ast.IfExp) and isinstance(n.body, Skipped) + } + assert self._orelses, "No orelse branches to merge" + super().__init__() + + def visit_IfExp(self, node: ast.IfExp): + if isinstance(node.orelse, Skipped) and node.orelse.id in self._orelses: + return ast.IfExp( + test=node.test, body=node.body, orelse=self._orelses[node.orelse.id] + ) + else: + return self.generic_visit(node) # ============================================================================ @@ -302,9 +376,9 @@ def handle_return_generator_312( state: ReconstructionState, instr: dis.Instruction ) -> 
ReconstructionState: # RETURN_GENERATOR is the first instruction in generator expressions in Python 3.13+ - assert len(state.stack) == 2 and all(isinstance(x, Null) for x in state.stack), ( - "RETURN_GENERATOR must be the first instruction" - ) + assert len(state.stack) == 2 and all( + isinstance(x, Null | Placeholder) for x in state.stack + ), "RETURN_GENERATOR must be the first instruction" new_result = ast.GeneratorExp(elt=Placeholder(), generators=[]) return replace(state, stack=[new_result, Null()]) @@ -314,7 +388,7 @@ def handle_return_generator( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: # RETURN_GENERATOR is the first instruction in generator expressions in Python 3.13+ - assert len(state.stack) == 1 and isinstance(state.stack[0], Null), ( + assert len(state.stack) == 1 and isinstance(state.stack[0], Null | Placeholder), ( "RETURN_GENERATOR must be the first instruction" ) return replace( @@ -334,11 +408,11 @@ def handle_yield_value( assert isinstance(new_result, ast.GeneratorExp), ( "YIELD_VALUE must be called after RETURN_GENERATOR" ) - assert isinstance(new_result.elt, Placeholder), ( - "YIELD_VALUE must be called before yielding" - ) assert len(new_result.generators) > 0, "YIELD_VALUE should have generators" - new_result.elt = ensure_ast(state.stack[-1]) + assert any(isinstance(x, Placeholder) for x in ast.walk(new_result.elt)) + new_result.elt = ReplacePlaceholder(ensure_ast(state.stack[-1])).visit( + new_result.elt + ) new_stack = [new_result] + state.stack[1:] return replace(state, stack=new_stack, result=new_result) @@ -377,11 +451,11 @@ def handle_list_append( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert isinstance(state.stack[-instr.argval - 1], ast.ListComp) - assert isinstance(state.stack[-instr.argval - 1].elt, Placeholder) # add the body to the comprehension comp: ast.ListComp = copy.deepcopy(state.stack[-instr.argval - 1]) - comp.elt = state.stack[-1] + assert 
any(isinstance(x, Placeholder) for x in ast.walk(comp.elt)) + comp.elt = ReplacePlaceholder(state.stack[-1]).visit(comp.elt) # swap the return value new_stack = state.stack[:-1] @@ -421,11 +495,11 @@ def handle_set_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert isinstance(state.stack[-instr.argval - 1], ast.SetComp) - assert isinstance(state.stack[-instr.argval - 1].elt, Placeholder) # add the body to the comprehension comp: ast.SetComp = copy.deepcopy(state.stack[-instr.argval - 1]) - comp.elt = state.stack[-1] + assert any(isinstance(x, Placeholder) for x in ast.walk(comp.elt)) + comp.elt = ReplacePlaceholder(state.stack[-1]).visit(comp.elt) # swap the return value new_stack = state.stack[:-1] @@ -471,13 +545,13 @@ def handle_map_add( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert isinstance(state.stack[-instr.argval - 2], ast.DictComp) - assert isinstance(state.stack[-instr.argval - 2].key, Placeholder) - assert isinstance(state.stack[-instr.argval - 2].value, Placeholder) # add the body to the comprehension comp: ast.DictComp = copy.deepcopy(state.stack[-instr.argval - 2]) - comp.key = state.stack[-2] - comp.value = state.stack[-1] + assert any(isinstance(x, Placeholder) for x in ast.walk(comp.key)) + assert any(isinstance(x, Placeholder) for x in ast.walk(comp.value)) + comp.key = ReplacePlaceholder(state.stack[-2]).visit(comp.key) + comp.value = ReplacePlaceholder(state.stack[-1]).visit(comp.value) # swap the return value new_stack = state.stack[:-2] @@ -497,7 +571,8 @@ def handle_return_value( state: ReconstructionState, instr: dis.Instruction ) -> ReconstructionState: assert isinstance(state.result, Placeholder) - new_result = ensure_ast(state.stack[-1]) + assert len(state.stack) == 2 + new_result = ReplacePlaceholder(ensure_ast(state.stack[-1])).visit(state.stack[-2]) new_stack = state.stack[:-1] return replace(state, stack=new_stack, result=new_result) @@ -1643,17 +1718,32 @@ 
def _handle_pop_jump_if( # Pops a value from the stack and jumps if the condition is met condition = f_condition(ensure_ast(state.stack[-1])) - for pos, item in zip(reversed(range(len(state.stack))), reversed(state.stack)): - if isinstance(item, CompExp) and isinstance( - getattr(item, "elt", getattr(item, "key", None)), Placeholder - ): - new_result = copy.deepcopy(item) - new_result.generators[-1].ifs.append(condition) - new_stack = state.stack[:pos] + [new_result] + state.stack[pos + 1 : -1] - return replace(state, stack=new_stack) - - # Not in a comprehension context - might be boolean logic - raise NotImplementedError("Lazy and+or behavior not implemented yet") + if state.is_filter: + for pos, item in zip(reversed(range(len(state.stack))), reversed(state.stack)): + if isinstance(item, CompExp) and isinstance( + getattr(item, "elt", getattr(item, "key", None)), Placeholder + ): + new_result = copy.deepcopy(item) + new_result.generators[-1].ifs.append(condition) + new_stack = state.stack[:pos] + [new_result] + state.stack[pos + 1 : -1] + return replace(state, stack=new_stack) + raise TypeError("No comprehension context found for filter condition") + else: + for pos, item in zip(reversed(range(len(state.stack))), reversed(state.stack)): + if any(isinstance(x, Placeholder) for x in ast.walk(item)): + body: Skipped | Placeholder + orelse: Skipped | Placeholder + if state.branches.get(instr.offset, 0): + # we don't jump, so we're in the orelse branch + body, orelse = Skipped(id=f".SKIPPED_{instr.offset}"), Placeholder() + else: + # we jump, so we're in the body branch + body, orelse = Placeholder(), Skipped(id=f".SKIPPED_{instr.offset}") + new_ifexp = ast.IfExp(test=condition, body=body, orelse=orelse) + new_result = ReplacePlaceholder(new_ifexp).visit(copy.deepcopy(item)) + new_stack = state.stack[:pos] + [new_result] + state.stack[pos + 1 : -1] + return replace(state, stack=new_stack) + raise TypeError("No placeholder found for conditional expression") 
@register_handler("POP_JUMP_IF_TRUE", version=PythonVersion.PY_312) @@ -1865,7 +1955,7 @@ def _ensure_ast_codeobj(value: types.CodeType) -> ast.Lambda | CompLambda: "Final return value must not contain statement nodes" ) assert not any( - isinstance(x, Placeholder | Null | CompLambda | ConvertedValue) + isinstance(x, Placeholder | Skipped | Null | CompLambda | ConvertedValue) for x in ast.walk(result) ), "Final return value must not contain temporary nodes" assert not any(x.arg == ".0" for x in ast.walk(result) if isinstance(x, ast.arg)), ( diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index e3ed8eb3..b85bde2b 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -536,7 +536,6 @@ def test_different_comprehension_types(genexpr): # ============================================================================ -@pytest.mark.xfail(reason="Conditional expressions not yet fully supported") @pytest.mark.parametrize( "genexpr", [ @@ -562,6 +561,18 @@ def test_different_comprehension_types(genexpr): (lambda x: (x**3) if not (x < 0 or x > 10) else (x**0.5))(xi) for xi in range(-5, 15) ), + ], +) +def test_conditional_expressions_simple_no_comprehension(genexpr): + """Test reconstruction of simple conditional expressions isolated from comprehension bodies.""" + ast_node = disassemble(genexpr) + assert_ast_equivalent(genexpr, ast_node) + + +@pytest.mark.xfail(reason="Nested conditional expressions not yet fully supported") +@pytest.mark.parametrize( + "genexpr", + [ # nested conditional expressions ( (lambda x: (x + 1) if x < 5 else ((x - 1) if x < 10 else (x * 2)))(xi) @@ -581,13 +592,12 @@ def test_different_comprehension_types(genexpr): ), ], ) -def test_conditional_expressions_no_comprehension(genexpr): - """Test reconstruction of conditional expressions isolated from comprehension bodies.""" +def test_conditional_expressions_nested_no_comprehension(genexpr): + """Test reconstruction of 
nested conditional expressions isolated from comprehension bodies.""" ast_node = disassemble(genexpr) assert_ast_equivalent(genexpr, ast_node) -@pytest.mark.xfail(reason="Conditional expressions not yet fully supported") @pytest.mark.parametrize( "genexpr", [ From 3ca82eed86873f6e0275518bf45328d1480091fe Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 4 Sep 2025 13:43:07 -0400 Subject: [PATCH 105/106] call_kw handler --- effectful/internals/disassembly.py | 43 ++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 1021c08b..46d0c41c 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -1263,6 +1263,49 @@ def handle_call( return replace(state, stack=new_stack) +@register_handler("CALL_KW", version=PythonVersion.PY_313) +def handle_call_kw( + state: ReconstructionState, instr: dis.Instruction +) -> ReconstructionState: + # CALL_KW pops function, arguments, and keyword names from stack + assert instr.arg is not None + arg_count: int = instr.arg + + func = ensure_ast(state.stack[-arg_count - 3]) + kw_names = state.stack[-1] + assert isinstance(kw_names, ast.Tuple), "Expected a tuple of keyword names" + + # Pop arguments, function, and keyword names + args = ( + [ensure_ast(arg) for arg in state.stack[-arg_count - 2 : -1]] + if arg_count > 0 + else [] + ) + if not isinstance(state.stack[-arg_count - 3], Null): + args = [ensure_ast(state.stack[-arg_count - 3])] + args + + keywords = [] + for i, kw in enumerate(reversed(kw_names.elts)): + kw_name = ( + kw.s if isinstance(kw, ast.Constant) and isinstance(kw.s, str) else None + ) + if kw_name is None: + raise TypeError("Keyword names must be strings") + kw_value = ensure_ast(state.stack[-1 - i]) + keywords.append(ast.keyword(arg=kw_name, value=kw_value)) + keywords.reverse() + + new_stack = state.stack[: -arg_count - 3] + if isinstance(func, CompLambda): + assert len(args) == 1 and 
len(keywords) == 0 + return replace(state, stack=new_stack + [func.inline(args[0])]) + else: + # Create function call AST + call_node = ast.Call(func=func, args=args, keywords=keywords) + new_stack = new_stack + [call_node] + return replace(state, stack=new_stack) + + @register_handler("MAKE_FUNCTION", version=PythonVersion.PY_312) def handle_make_function_312( state: ReconstructionState, instr: dis.Instruction From e962796b85b19644ab27e16c66b37f9e7245f203 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 4 Sep 2025 13:56:38 -0400 Subject: [PATCH 106/106] nit --- effectful/internals/disassembly.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/effectful/internals/disassembly.py b/effectful/internals/disassembly.py index 46d0c41c..b4a00ae4 100644 --- a/effectful/internals/disassembly.py +++ b/effectful/internals/disassembly.py @@ -349,10 +349,12 @@ def _symbolic_exec(code: types.CodeType) -> ast.expr: class _MergeBranches(ast.NodeTransformer): def __init__(self, node_with_orelse: ast.expr): - self._orelses = { + self._orelses: dict[str, ast.expr] = { n.body.id: n.orelse for n in ast.walk(node_with_orelse) - if isinstance(n, ast.IfExp) and isinstance(n.body, Skipped) + if isinstance(n, ast.IfExp) + and isinstance(n.body, Skipped) + and not isinstance(n.orelse, Skipped | Placeholder) } assert self._orelses, "No orelse branches to merge" super().__init__()