diff --git a/pnc/src/Binder.pn b/pnc/src/Binder.pn index 7bfcbca..54ad357 100644 --- a/pnc/src/Binder.pn +++ b/pnc/src/Binder.pn @@ -630,12 +630,12 @@ class Binder( options match { case FieldOptions.TypeAndExpression(fieldType, expression) => // TODO: this expression needs to be added to the symbol's constructor body - val expr = exprBinder.bindConversionExpr(expression, fieldType, scope) + val expr = exprBinder.check(expression, fieldType, scope) setSymbolType(symbol, fieldType) case FieldOptions.TypeOnly(fieldType) => setSymbolType(symbol, fieldType) case FieldOptions.ExpressionOnly(expression) => - val expr = exprBinder.bind(expression, scope) + val expr = exprBinder.infer(expression, scope) // TODO: this expression needs to be added to the symbol's constructor body val returnType = getType(expr) setSymbolType(symbol, returnType) @@ -657,9 +657,9 @@ class Binder( case Option.Some(expression) => val boundExpr = returnType match { - case Option.None => exprBinder.bind(expression, methodScope) + case Option.None => exprBinder.infer(expression, methodScope) case Option.Some(toType) => - exprBinder.bindConversionExpr(expression, toType, methodScope) + exprBinder.check(expression, toType, methodScope) } functionBodies = functionBodies.put(symbol, boundExpr) diff --git a/pnc/src/ExprBinder.pn b/pnc/src/ExprBinder.pn index f5d681f..bae37d8 100644 --- a/pnc/src/ExprBinder.pn +++ b/pnc/src/ExprBinder.pn @@ -98,6 +98,79 @@ class ExprBinder( * same symbol (assuming no polymorphism in this example). */ + def isSubtype(subType: Type, superType: Type): bool = { + // Handle Error types + if (subType == Type.Error || superType == Type.Error) { + return false + } + + // Base cases + if (subType == superType) { + return true + } + if (subType == binder.neverType) { + return true + } + if (superType == binder.anyType) { + return true + } + + Tuple2(subType, superType) match { + // Function subtyping + case Tuple2( + Type.Function(_, subParams, subReturn), + Type.Function(_, superParams, superReturn) + ) => + if (subParams.length != superParams.length) { + return false + } + val paramsAreSubtypes = Tuple2(subParams, superParams) match { + case Tuple2( + List.Cons(subParam, subTail), + List.Cons(superParam, superTail) + ) => + isSubtype(superParam.typ, subParam.typ) && // Contravariant + isSubtypeList(subTail, superTail) + case Tuple2(List.Nil, List.Nil) => true + case _ => false + } + paramsAreSubtypes && isSubtype(subReturn, superReturn) // Covariant + + // Array subtyping + case Tuple2( + Type.Class(_, _, "Array", List.Cons(subElemType, List.Nil), _), + Type.Class(_, _, "Array", List.Cons(superElemType, List.Nil), _) + ) => + isSubtype(subElemType, superElemType) + + case _ => false + } + } + + def isSubtypeList( + value: List[BoundParameter], + value1: List[BoundParameter] + ): bool = { + Tuple2(value, value1) match { + case Tuple2( + List.Cons(head1, tail1), + List.Cons(head2, tail2) + ) => + isSubtype(head1.typ, head2.typ) && isSubtypeList(tail1, tail2) + case Tuple2(List.Nil, List.Nil) => true + case _ => false + } + } + + def subsume(expr: BoundExpression, expectedType: Type): BoundExpression = { + val exprType = binder.getType(expr) + if (isSubtype(exprType, expectedType)) { + expr + } else { + bindConversion(expr, expectedType, false) + } + } + def bindUnaryOperator(token: SyntaxToken): UnaryOperatorKind = { token.kind match { case SyntaxKind.BangToken => UnaryOperatorKind.LogicalNegation @@ -138,46 +211,320 @@ class ExprBinder( } } - def bind(expr: Expression, scope: Scope): BoundExpression = { + def check( + expr: Expression, + expectedType: Type, + scope: Scope + ): BoundExpression = { expr match { case node: Expression.ArrayCreation => - bindArrayCreationExpression(node, scope) + checkArrayCreation(node, expectedType, scope) case node: Expression.Assignment => - bindAssignmentExpression(node, scope) + checkAssignment(node, expectedType, scope) + case node: Expression.Binary => checkBinary(node, expectedType, scope) + case node: Expression.Block => checkBlock(node, expectedType, scope) + case node: Expression.Call => checkCall(node, expectedType, scope) + case node: Expression.Cast => checkCast(node, expectedType, scope) + case node: Expression.For => checkFor(node, expectedType, scope) + case node: Expression.Group => checkGroup(node, expectedType, scope) + case node: Expression.IdentifierName => + checkIdentifierName(node, expectedType, scope) + case node: Expression.If => checkIf(node, expectedType, scope) + case node: Expression.Is => checkIs(node, expectedType, scope) + case node: Expression.Literal => checkLiteral(node, expectedType, scope) + case node: Expression.MemberAccess => + checkMemberAccess(node, expectedType, scope) + case node: Expression.Match => checkMatch(node, expectedType, scope) + case node: Expression.New => checkNew(node, expectedType, scope) + case node: Expression.Unary => checkUnary(node, expectedType, scope) + case node: Expression.Unit => checkUnit(node, expectedType, scope) + case node: Expression.While => checkWhile(node, expectedType, scope) + } + } + + def checkArrayCreation( + expr: Expression.ArrayCreation, + expectedType: Type, + scope: Scope + ): BoundExpression = ??? + + def checkAssignment( + expr: Expression.Assignment, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val lhs = bindLHS(expr.left, scope) + val rhs = infer(expr.right, scope) + Tuple2(lhs, rhs) match { + case Tuple2(Result.Error(error), _) => + error + case Tuple2(_, BoundExpression.Error(_)) => + rhs + case Tuple2(Result.Success(lhs), rhs) => + getLHSType(lhs) match { + case Type.Error(message) => BoundExpression.Error(message) + case lhsType => + BoundExpression.Assignment( + AstUtils.locationOfExpression(expr), + lhs, + bindConversion(rhs, lhsType, false) + ) + } + + case _ => + val location = AstUtils.locationOfExpression(expr.left) + diagnosticBag.reportExpressionIsNotAssignable(location) + BoundExpression.Error("Expression is not assignable: " + location) + } + } + + def checkBinary( + expr: Expression.Binary, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferBinary(expr, scope) + subsume(inferred, expectedType) + } + + def checkBlock( + node: Expression.Block, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val block = scope.newBlock() + val statements = bindStatements(node.block.statements, block) + val expr = node.block.expression match { + case Option.None => + BoundExpression.Unit(TextLocationFactory.empty()) + case Option.Some(value) => check(value, expectedType, block) + } + + if (expr == BoundExpression.Error) { + expr + } else { + BoundExpression.Block(statements, expr) + } + } + + def checkCall( + expr: Expression.Call, + expectedType: Type, + scope: Scope + ): BoundExpression = { + inferCall(expr, scope) match { + case Result.Error(value) => value + case Result.Success(value) => + val inferred = convertLHSToExpression(value) + subsume(inferred, expectedType) + } + } + + def checkCast( + expr: Expression.Cast, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferCast(expr, scope) + subsume(inferred, expectedType) + } + + def checkFor( + expr: Expression.For, + expectedType: Type, + scope: Scope + ): BoundExpression = { + // For expressions always return unit, so just infer and subsume + val inferred = inferForExpression(expr, scope) + subsume(inferred, expectedType) + } + + def checkGroup( + expr: Expression.Group, + expectedType: Type, + scope: Scope + ): BoundExpression = + check(expr.expression, expectedType, scope) + + def checkIdentifierName( + expr: Expression.IdentifierName, + expectedType: Type, + scope: Scope + ): BoundExpression = { + inferIdentifierName(expr, scope) match { + case Result.Error(value) => value + case Result.Success(value) => + subsume(value, expectedType) + } + } + + def checkIf( + expr: Expression.If, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val cond = check(expr.condition, binder.boolType, scope) + val thenBranch = check(expr.thenExpr, expectedType, scope) + val elseBranch = expr.elseExpr match { + case Option.None => + // If there's no else branch, we assume it's a unit type + Option.None + case Option.Some(elseExpr) => + Option.Some(check(elseExpr.expression, expectedType, scope)) + } + + BoundExpression.If( + AstUtils.locationOfExpression(expr), + cond, + thenBranch, + elseBranch, + expectedType + ) + } + + def checkIs( + expr: Expression.Is, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferIsExpression(expr, scope) + subsume(inferred, expectedType) + } + + def checkLiteral( + expr: Expression.Literal, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferLiteral(expr, scope) + subsume(inferred, expectedType) + } + + def checkMemberAccess( + expr: Expression.MemberAccess, + expectedType: Type, + scope: Scope + ): BoundExpression = { + inferMemberAccess(expr, scope) match { + case Result.Error(value) => value + case Result.Success(value) => + subsume(value, expectedType) + } + } + + def checkMatch( + expr: Expression.Match, + expectedType: Type, + scope: Scope + ): BoundExpression = { + + // Bind the expression being matched against + val matchedExpr = infer(expr.expression, scope) + + matchedExpr match { + case error: BoundExpression.Error => error + case _ => + // Bind all the match cases + checkMatchCases( + expr.cases.head, + expr.cases.tail, + expectedType, + scope + ) match { + case Result.Error(value) => value + case Result.Success(boundCases) => + val location = AstUtils.locationOfExpression(expr) + + BoundExpression.Match( + location, + expectedType, + matchedExpr, + boundCases + ) + } + } + } + + def checkNew( + expr: Expression.New, + expectedType: Type, + scope: Scope + ): BoundExpression = { + inferNew(expr, scope) match { + case Result.Error(value) => value + case Result.Success(value) => + val inferred = convertLHSToExpression(value) + subsume(inferred, expectedType) + } + } + + def checkUnary( + expr: Expression.Unary, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferUnary(expr, scope) + subsume(inferred, expectedType) + } + + def checkUnit( + expr: Expression.Unit, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferUnit(expr, scope) + subsume(inferred, expectedType) + } + + def checkWhile( + expr: Expression.While, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferWhileExpression(expr, scope) + subsume(inferred, expectedType) + } + + def infer(expr: Expression, scope: Scope): BoundExpression = { + expr match { + case node: Expression.ArrayCreation => + inferArrayCreationExpression(node, scope) + case node: Expression.Assignment => + inferAssignmentExpression(node, scope) case node: Expression.Binary => - bindBinaryExpression(node, scope) - case node: Expression.Block => bindBlockExpression(node, scope) + inferBinary(node, scope) + case node: Expression.Block => inferBlock(node, scope) case node: Expression.Call => - bindCallExpression(node, scope) match { + inferCall(node, scope) match { case Result.Error(value) => value case Result.Success(value) => convertLHSToExpression(value) } - case node: Expression.Cast => bindCast(node, scope) - case node: Expression.For => bindForExpression(node, scope) - case node: Expression.Group => bindGroupExpression(node, scope) + case node: Expression.Cast => inferCast(node, scope) + case node: Expression.For => inferForExpression(node, scope) + case node: Expression.Group => inferGroup(node, scope) case node: Expression.IdentifierName => - bindIdentifierName(node, scope) match { + inferIdentifierName(node, scope) match { case Result.Error(value) => value case Result.Success(value) => value } - case node: Expression.If => bindIf(node, scope) - case node: Expression.Is => bindIsExpression(node, scope) + case node: Expression.If => inferIf(node, scope) + case node: Expression.Is => inferIsExpression(node, scope) case node: Expression.Literal => - bindLiteralExpression(node, scope) + inferLiteral(node, scope) case node: Expression.MemberAccess => - bindMemberAccessExpression(node, scope) match { + inferMemberAccess(node, scope) match { case Result.Error(value) => value case Result.Success(value) => value } - case node: Expression.Match => bindMatchExpression(node, scope) + case node: Expression.Match => inferMatchExpression(node, scope) case node: Expression.New => - bindNewExpression(node, scope) match { + inferNew(node, scope) match { case Result.Error(value) => value case Result.Success(value) => convertLHSToExpression(value) } - case node: Expression.Unary => bindUnaryExpression(node, scope) - case node: Expression.Unit => bindUnitExpression(node, scope) - case node: Expression.While => bindWhileExpression(node, scope) + case node: Expression.Unary => inferUnary(node, scope) + case node: Expression.Unit => inferUnit(node, scope) + case node: Expression.While => inferWhileExpression(node, scope) } } @@ -187,25 +534,25 @@ class ExprBinder( ): Result[BoundExpression.Error, BoundLeftHandSide] = { expr match { case node: Expression.IdentifierName => - bindIdentifierName(node, scope) match { + inferIdentifierName(node, scope) match { case Result.Error(value) => Result.Error(value) case Result.Success(value) => val location = AstUtils.locationOfExpression(node) Result.Success(BoundLeftHandSide.Variable(location, value.symbol)) } case node: Expression.MemberAccess => - bindMemberAccessExpression(node, scope) match { + inferMemberAccess(node, scope) match { case Result.Error(value) => Result.Error(value) case Result.Success(value) => Result.Success(BoundLeftHandSide.MemberAccess(value)) } case node: Expression.Call => - bindCallExpression(node, scope) match { + inferCall(node, scope) match { case Result.Error(value) => Result.Error(value) case Result.Success(value) => Result.Success(value) } case node: Expression.New => - bindNewExpression(node, scope) match { + inferNew(node, scope) match { case Result.Error(value) => Result.Error(value) case Result.Success(value) => Result.Success(value) } @@ -226,7 +573,7 @@ class ExprBinder( toType: Type, scope: Scope ): BoundExpression = { - val bound = bind(expr, scope) + val bound = infer(expr, scope) if (bound == BoundExpression.Error) { bound } else { @@ -289,7 +636,7 @@ class ExprBinder( BoundStatement.Error } - def bindArrayCreationExpression( + def inferArrayCreationExpression( node: Expression.ArrayCreation, scope: Scope ): BoundExpression = { @@ -302,7 +649,7 @@ class ExprBinder( // Bind the array size expression val sizeExpr = node.arrayRank match { case Option.Some(rankExpr) => - bindConversion(bind(rankExpr, scope), binder.intType, false) + bindConversion(infer(rankExpr, scope), binder.intType, false) case Option.None => // Default to size 0 if no size provided BoundExpression.Int(TextLocationFactory.empty(), 0) @@ -339,41 +686,32 @@ class ExprBinder( } - def bindAssignmentExpression( + def inferAssignmentExpression( node: Expression.Assignment, scope: Scope ): BoundExpression = { - val lhs = bindLHS(node.left, scope) - val rhs = bind(node.right, scope) - Tuple2(lhs, rhs) match { - case Tuple2(Result.Error(error), _) => - error - case Tuple2(_, BoundExpression.Error(_)) => - rhs - case Tuple2(Result.Success(lhs), rhs) => + bindLHS(node.left, scope) match { + case Result.Error(value) => value + case Result.Success(lhs) => getLHSType(lhs) match { case Type.Error(message) => BoundExpression.Error(message) case lhsType => + val rhs = check(node.right, lhsType, scope) BoundExpression.Assignment( AstUtils.locationOfExpression(node), lhs, - bindConversion(rhs, lhsType, false) + rhs ) } - - case _ => - val location = AstUtils.locationOfExpression(node.left) - diagnosticBag.reportExpressionIsNotAssignable(location) - BoundExpression.Error("Expression is not assignable: " + location) } } - def bindBinaryExpression( + def inferBinary( node: Expression.Binary, scope: Scope ): BoundExpression = { - val left = bind(node.left, scope) - val right = bind(node.right, scope) + val left = infer(node.left, scope) + val right = infer(node.right, scope) Tuple2(left, right) match { case Tuple2(BoundExpression.Error(_), _) => @@ -416,7 +754,7 @@ class ExprBinder( } } - def bindBlockExpression( + def inferBlock( node: Expression.Block, scope: Scope ): BoundExpression = { @@ -425,7 +763,7 @@ class ExprBinder( val expr = node.block.expression match { case Option.None => BoundExpression.Unit(TextLocationFactory.empty()) - case Option.Some(value) => bind(value, block) + case Option.Some(value) => infer(value, block) } if (expr == BoundExpression.Error) { @@ -455,7 +793,7 @@ class ExprBinder( } } - def bindCallExpression( + def inferCall( node: Expression.Call, scope: Scope ): Result[ @@ -902,11 +1240,11 @@ class ExprBinder( def findConstructor(symbol: Symbol): Option[Symbol] = symbol.lookupMember(".ctor") - def bindCast( + def inferCast( cast: Expression.Cast, scope: Scope ): BoundExpression = { - val expr = bind(cast.expression, scope) + val expr = infer(cast.expression, scope) val typ = binder.bindTypeName(cast.typ, scope) expr match { @@ -921,11 +1259,11 @@ class ExprBinder( } } - def bindIsExpression( + def inferIsExpression( isExpr: Expression.Is, scope: Scope ): BoundExpression = { - val expr = bind(isExpr.expression, scope) + val expr = infer(isExpr.expression, scope) val typ = binder.bindTypeName(isExpr.typ, scope) val location = AstUtils.locationOfExpression(isExpr) @@ -988,18 +1326,18 @@ class ExprBinder( list match { case List.Nil => List.Nil case List.Cons(head, tail) => - val expr = bind(head, scope) + val expr = infer(head, scope) List.Cons(expr, bindExpressions(tail, scope)) } - def bindForExpression( + def inferForExpression( node: Expression.For, scope: Scope ): BoundExpression = { val blockScope = scope.newBlock() - val lowerBound = bindConversionExpr(node.fromExpr, binder.intType, scope) - val upperBound = bindConversionExpr(node.toExpr, binder.intType, scope) + val lowerBound = check(node.fromExpr, binder.intType, scope) + val upperBound = check(node.toExpr, binder.intType, scope) val identifier = node.identifier blockScope.defineLocal(identifier.text, identifier.location) match { @@ -1008,7 +1346,7 @@ class ExprBinder( case Either.Right(variable) => binder.setSymbolType(variable, binder.intType) - val body = bind(node.body, blockScope) + val body = check(node.body, binder.unitType, blockScope) BoundExpression.For( node.forKeyword.location, @@ -1020,13 +1358,13 @@ class ExprBinder( } } - def bindGroupExpression( + def inferGroup( node: Expression.Group, scope: Scope ): BoundExpression = - bind(node.expression, scope) + infer(node.expression, scope) - def bindIdentifierName( + def inferIdentifierName( node: Expression.IdentifierName, scope: Scope ): Result[BoundExpression.Error, BoundExpression.Variable] = { @@ -1064,9 +1402,9 @@ class ExprBinder( } } - def bindIf(node: Expression.If, scope: Scope): BoundExpression = { + def inferIf(node: Expression.If, scope: Scope): BoundExpression = { val cond = bindConversionExpr(node.condition, binder.boolType, scope) - val thenExpr = bind(node.thenExpr, scope) + val thenExpr = infer(node.thenExpr, scope) Tuple2(cond, thenExpr) match { case Tuple2(BoundExpression.Error(_), _) => cond @@ -1102,7 +1440,7 @@ class ExprBinder( } } - def bindLiteralExpression( + def inferLiteral( node: Expression.Literal, scope: Scope ): BoundExpression = { @@ -1160,7 +1498,7 @@ class ExprBinder( } } - def bindMemberAccessExpression( + def inferMemberAccess( node: Expression.MemberAccess, scope: Scope ): Result[BoundExpression.Error, BoundExpression.MemberAccess] = { @@ -1609,8 +1947,46 @@ class ExprBinder( } } - def bindMatchCase( + def inferMatchCase( + matchCase: MatchCaseSyntax, + scope: Scope + ): Result[BoundExpression.Error, BoundMatchCase] = { + // Create a new scope for this case to allow pattern variables + val caseScope = scope.newBlock() + + bindPattern(matchCase.pattern, caseScope) match { + case Result.Error(error) => Result.Error(error) + case Result.Success(boundPattern) => + // Bind the statements and expression within the case scope + val statements = bindStatements(matchCase.block.statements, caseScope) + val resultExpr = matchCase.block.expression match { + case Option.None => + BoundExpression.Unit(TextLocationFactory.empty()) + case Option.Some(expr) => infer(expr, caseScope) + } + + resultExpr match { + case error: BoundExpression.Error => Result.Error(error) + case _ => + val caseResult = if (statements.isEmpty) { + resultExpr + } else { + BoundExpression.Block(statements, resultExpr) + } + Result.Success( + BoundMatchCase( + matchCase.caseKeyword.location, + boundPattern, + caseResult + ) + ) + } + } + } + + def checkMatchCase( matchCase: MatchCaseSyntax, + expectedType: Type, scope: Scope ): Result[BoundExpression.Error, BoundMatchCase] = { // Create a new scope for this case to allow pattern variables @@ -1624,7 +2000,7 @@ class ExprBinder( val resultExpr = matchCase.block.expression match { case Option.None => BoundExpression.Unit(TextLocationFactory.empty()) - case Option.Some(expr) => bind(expr, caseScope) + case Option.Some(expr) => check(expr, expectedType, caseScope) } resultExpr match { @@ -1646,12 +2022,35 @@ class ExprBinder( } } - def bindMatchCases( + def checkMatchCases( + head: MatchCaseSyntax, + tail: List[MatchCaseSyntax], + expectedType: Type, + scope: Scope + ): Result[BoundExpression.Error, NonEmptyList[BoundMatchCase]] = { + checkMatchCase(head, expectedType, scope) match { + case Result.Error(expr) => + Result.Error(expr) + case Result.Success(boundCase) => + tail match { + case List.Nil => + Result.Success(NonEmptyList(boundCase, List.Nil)) + case List.Cons(head, tail) => + checkMatchCases(head, tail, expectedType, scope) match { + case Result.Error(expr) => Result.Error(expr) + case Result.Success(tailCases) => + Result.Success(NonEmptyList(boundCase, tailCases.toList())) + } + } + } + } + + def inferMatchCases( head: MatchCaseSyntax, tail: List[MatchCaseSyntax], scope: Scope ): Result[BoundExpression.Error, NonEmptyList[BoundMatchCase]] = { - bindMatchCase(head, scope) match { + inferMatchCase(head, scope) match { case Result.Error(expr) => Result.Error(expr) case Result.Success(boundCase) => @@ -1659,7 +2058,7 @@ class ExprBinder( case List.Nil => Result.Success(NonEmptyList(boundCase, List.Nil)) case List.Cons(head, tail) => - bindMatchCases(head, tail, scope) match { + inferMatchCases(head, tail, scope) match { case Result.Error(expr) => Result.Error(expr) case Result.Success(tailCases) => Result.Success(NonEmptyList(boundCase, tailCases.toList())) @@ -1681,18 +2080,18 @@ class ExprBinder( } } - def bindMatchExpression( + def inferMatchExpression( node: Expression.Match, scope: Scope ): BoundExpression = { // Bind the expression being matched against - val matchedExpr = bind(node.expression, scope) + val matchedExpr = infer(node.expression, scope) matchedExpr match { case error: BoundExpression.Error => error case _ => // Bind all the match cases - bindMatchCases(node.cases.head, node.cases.tail, scope) match { + inferMatchCases(node.cases.head, node.cases.tail, scope) match { case Result.Error(value) => value case Result.Success(boundCases) => // Calculate the result type from all cases @@ -1710,7 +2109,7 @@ class ExprBinder( } } - def bindNewExpression( + def inferNew( node: Expression.New, scope: Scope ): Result[BoundExpression.Error, BoundLeftHandSide] = { @@ -1927,12 +2326,12 @@ class ExprBinder( } } - def bindUnaryExpression( + def inferUnary( node: Expression.Unary, scope: Scope ): BoundExpression = { val op = bindUnaryOperator(node.operator) - bind(node.expression, scope) match { + infer(node.expression, scope) match { case error: BoundExpression.Error => error case operand => binder.getType(operand) match { @@ -1960,7 +2359,7 @@ class ExprBinder( } - def bindUnitExpression( + def inferUnit( node: Expression.Unit, scope: Scope ): BoundExpression = @@ -1968,12 +2367,12 @@ class ExprBinder( node.openParen.location.merge(node.closeParen.location) ) - def bindWhileExpression( + def inferWhileExpression( node: Expression.While, scope: Scope ): BoundExpression = { - val cond = bind(node.condition, scope) - val body = bind(node.body, scope) + val cond = infer(node.condition, scope) + val body = infer(node.body, scope) new BoundExpression.While(node.whileKeyword.location, cond, body) } @@ -2010,11 +2409,11 @@ class ExprBinder( ) BoundStatement.Error case Either.Right(symbol) => - // symbol was created so lets bind it - val expr = bind(statement.expression, scope) statement.typeAnnotation match { case Option.None => + // no type annotation so lets use type inference + val expr = infer(statement.expression, scope) // no type annotation, so we need to infer the type from the expr val typ = binder.getType(expr) binder.setSymbolType(symbol, typ) @@ -2022,11 +2421,9 @@ class ExprBinder( case Option.Some(value) => val annotatedType = binder.bindTypeName(value.typ, scope) + val boundExpr = check(statement.expression, annotatedType, scope) binder.setSymbolType(symbol, annotatedType) - // make sure we can convert the expression - val boundExpr = bindConversion(expr, annotatedType, false) - BoundStatement.VariableDeclaration( symbol, false, @@ -2052,7 +2449,7 @@ class ExprBinder( statement: StatementSyntax.ExpressionStatement, scope: Scope ): BoundStatement = { - val expr = bind(statement.expression, scope) + val expr = infer(statement.expression, scope) BoundStatement.ExpressionStatement(expr) } diff --git a/pnc/src/Lowered.pn b/pnc/src/Lowered.pn index 56af22e..172a39f 100644 --- a/pnc/src/Lowered.pn +++ b/pnc/src/Lowered.pn @@ -639,7 +639,80 @@ class ExpressionLowerer(symbol: Symbol, binder: Binder) { def lowerForExpression( expr: BoundExpression.For, context: LoweredBlock - ): LoweredBlock = ??? + ): LoweredBlock = { + + // convert for expression to while expression first + /* + * for (; ; ) + * + * + * to + * + * + * while () + * + * + */ + val upperBound = createTemporary() + val setupUpperBound = BoundExpression.Assignment( + upperBound.location, + BoundLeftHandSide.Variable(upperBound.location, upperBound), + expr.upperBound + ) + val initializer = BoundExpression.Assignment( + expr.variable.location, + BoundLeftHandSide.Variable(expr.variable.location, expr.variable), + expr.lowerBound + ) + val condition = BoundExpression.Binary( + expr.location, + BoundExpression.Variable(expr.variable.location, expr.variable, None), + BinaryOperatorKind.LessThanOrEqual, // TODO: should this be LessThan? + BoundExpression.Variable(upperBound.location, upperBound, None), + binder.boolType + ) + val variableType = binder.getSymbolType(expr.variable) + val iterator = BoundExpression.Assignment( + expr.variable.location, + BoundLeftHandSide.Variable(expr.variable.location, expr.variable), + BoundExpression.Binary( + expr.location, + BoundExpression.Variable(expr.variable.location, expr.variable, None), + BinaryOperatorKind.Plus, + BoundExpression.Int(expr.location, 1), + variableType + ) + ) + + val block = BoundExpression.Block( + List.Cons( + // setup upper bound so it doesnt need to be recalculated every loop + BoundStatement.ExpressionStatement(setupUpperBound), + List.Cons( + // initialize variable to the lower bound + BoundStatement.ExpressionStatement(initializer), + List.Nil + ) + ), + // the while loop + BoundExpression.While( + expr.location, + // test our variable against the upper bound + condition, + BoundExpression.Block( + List.Cons( + // the body of the for loop + BoundStatement.ExpressionStatement(expr.body), + List.Nil + ), + // increment the loop variable + iterator + ) + ) + ) + + lowerExpression(block, context) + } def lowerIfExpression( expr: BoundExpression.If, diff --git a/pncs/src/main/scala/Binder.scala b/pncs/src/main/scala/Binder.scala index 9f85753..bddd413 100644 --- a/pncs/src/main/scala/Binder.scala +++ b/pncs/src/main/scala/Binder.scala @@ -630,12 +630,12 @@ case class Binder( options match { case FieldOptions.TypeAndExpression(fieldType, expression) => // TODO: this expression needs to be added to the symbol's constructor body - val expr = exprBinder.bindConversionExpr(expression, fieldType, scope) + val expr = exprBinder.check(expression, fieldType, scope) setSymbolType(symbol, fieldType) case FieldOptions.TypeOnly(fieldType) => setSymbolType(symbol, fieldType) case FieldOptions.ExpressionOnly(expression) => - val expr = exprBinder.bind(expression, scope) + val expr = exprBinder.infer(expression, scope) // TODO: this expression needs to be added to the symbol's constructor body val returnType = getType(expr) setSymbolType(symbol, returnType) @@ -657,9 +657,9 @@ case class Binder( case Option.Some(expression) => val boundExpr = returnType match { - case Option.None => exprBinder.bind(expression, methodScope) + case Option.None => exprBinder.infer(expression, methodScope) case Option.Some(toType) => - exprBinder.bindConversionExpr(expression, toType, methodScope) + exprBinder.check(expression, toType, methodScope) } functionBodies = functionBodies.put(symbol, boundExpr) diff --git a/pncs/src/main/scala/ExprBinder.scala b/pncs/src/main/scala/ExprBinder.scala index b1ac546..f2a9c34 100644 --- a/pncs/src/main/scala/ExprBinder.scala +++ b/pncs/src/main/scala/ExprBinder.scala @@ -98,6 +98,79 @@ case class ExprBinder( * same symbol (assuming no polymorphism in this example). */ + def isSubtype(subType: Type, superType: Type): bool = { + // Handle Error types + if (subType == Type.Error || superType == Type.Error) { + return false + } + + // Base cases + if (subType == superType) { + return true + } + if (subType == binder.neverType) { + return true + } + if (superType == binder.anyType) { + return true + } + + Tuple2(subType, superType) match { + // Function subtyping + case Tuple2( + Type.Function(_, subParams, subReturn), + Type.Function(_, superParams, superReturn) + ) => + if (subParams.length != superParams.length) { + return false + } + val paramsAreSubtypes = Tuple2(subParams, superParams) match { + case Tuple2( + List.Cons(subParam, subTail), + List.Cons(superParam, superTail) + ) => + isSubtype(superParam.typ, subParam.typ) && // Contravariant + isSubtypeList(subTail, superTail) + case Tuple2(List.Nil, List.Nil) => true + case _ => false + } + paramsAreSubtypes && isSubtype(subReturn, superReturn) // Covariant + + // Array subtyping + case Tuple2( + Type.Class(_, _, "Array", List.Cons(subElemType, List.Nil), _), + Type.Class(_, _, "Array", List.Cons(superElemType, List.Nil), _) + ) => + isSubtype(subElemType, superElemType) + + case _ => false + } + } + + def isSubtypeList( + value: List[BoundParameter], + value1: List[BoundParameter] + ): bool = { + Tuple2(value, value1) match { + case Tuple2( + List.Cons(head1, tail1), + List.Cons(head2, tail2) + ) => + isSubtype(head1.typ, head2.typ) && isSubtypeList(tail1, tail2) + case Tuple2(List.Nil, List.Nil) => true + case _ => false + } + } + + def subsume(expr: BoundExpression, expectedType: Type): BoundExpression = { + val exprType = binder.getType(expr) + if (isSubtype(exprType, expectedType)) { + expr + } else { + bindConversion(expr, expectedType, false) + } + } + def bindUnaryOperator(token: SyntaxToken): UnaryOperatorKind = { token.kind match { case SyntaxKind.BangToken => UnaryOperatorKind.LogicalNegation @@ -138,46 +211,320 @@ case class ExprBinder( } } - def bind(expr: Expression, scope: Scope): BoundExpression = { + def check( + expr: Expression, + expectedType: Type, + scope: Scope + ): BoundExpression = { expr match { case node: Expression.ArrayCreation => - bindArrayCreationExpression(node, scope) + checkArrayCreation(node, expectedType, scope) case node: Expression.Assignment => - bindAssignmentExpression(node, scope) + checkAssignment(node, expectedType, scope) + case node: Expression.Binary => checkBinary(node, expectedType, scope) + case node: Expression.Block => checkBlock(node, expectedType, scope) + case node: Expression.Call => checkCall(node, expectedType, scope) + case node: Expression.Cast => checkCast(node, expectedType, scope) + case node: Expression.For => checkFor(node, expectedType, scope) + case node: Expression.Group => checkGroup(node, expectedType, scope) + case node: Expression.IdentifierName => + checkIdentifierName(node, expectedType, scope) + case node: Expression.If => checkIf(node, expectedType, scope) + case node: Expression.Is => checkIs(node, expectedType, scope) + case node: Expression.Literal => checkLiteral(node, expectedType, scope) + case node: Expression.MemberAccess => + checkMemberAccess(node, expectedType, scope) + case node: Expression.Match => checkMatch(node, expectedType, scope) + case node: Expression.New => checkNew(node, expectedType, scope) + case node: Expression.Unary => checkUnary(node, expectedType, scope) + case node: Expression.Unit => checkUnit(node, expectedType, scope) + case node: Expression.While => checkWhile(node, expectedType, scope) + } + } + + def checkArrayCreation( + expr: Expression.ArrayCreation, + expectedType: Type, + scope: Scope + ): BoundExpression = ??? + + def checkAssignment( + expr: Expression.Assignment, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val lhs = bindLHS(expr.left, scope) + val rhs = infer(expr.right, scope) + Tuple2(lhs, rhs) match { + case Tuple2(Result.Error(error), _) => + error + case Tuple2(_, BoundExpression.Error(_)) => + rhs + case Tuple2(Result.Success(lhs), rhs) => + getLHSType(lhs) match { + case Type.Error(message) => BoundExpression.Error(message) + case lhsType => + BoundExpression.Assignment( + AstUtils.locationOfExpression(expr), + lhs, + bindConversion(rhs, lhsType, false) + ) + } + + case _ => + val location = AstUtils.locationOfExpression(expr.left) + diagnosticBag.reportExpressionIsNotAssignable(location) + BoundExpression.Error("Expression is not assignable: " + location) + } + } + + def checkBinary( + expr: Expression.Binary, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferBinary(expr, scope) + subsume(inferred, expectedType) + } + + def checkBlock( + node: Expression.Block, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val block = scope.newBlock() + val statements = bindStatements(node.block.statements, block) + val expr = node.block.expression match { + case Option.None => + BoundExpression.Unit(TextLocationFactory.empty()) + case Option.Some(value) => check(value, expectedType, block) + } + + if (expr == BoundExpression.Error) { + expr + } else { + BoundExpression.Block(statements, expr) + } + } + + def checkCall( + expr: Expression.Call, + expectedType: Type, + scope: Scope + ): BoundExpression = { + inferCall(expr, scope) match { + case Result.Error(value) => value + case Result.Success(value) => + val inferred = convertLHSToExpression(value) + subsume(inferred, expectedType) + } + } + + def checkCast( + expr: Expression.Cast, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferCast(expr, scope) + subsume(inferred, expectedType) + } + + def checkFor( + expr: Expression.For, + expectedType: Type, + scope: Scope + ): BoundExpression = { + // For expressions always return unit, so just infer and subsume + val inferred = inferForExpression(expr, scope) + subsume(inferred, expectedType) + } + + def checkGroup( + expr: Expression.Group, + expectedType: Type, + scope: Scope + ): BoundExpression = + check(expr.expression, expectedType, scope) + + def checkIdentifierName( + expr: Expression.IdentifierName, + expectedType: Type, + scope: Scope + ): BoundExpression = { + inferIdentifierName(expr, scope) match { + case Result.Error(value) => value + case Result.Success(value) => + subsume(value, expectedType) + } + } + + def checkIf( + expr: Expression.If, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val cond = check(expr.condition, binder.boolType, scope) + val thenBranch = check(expr.thenExpr, expectedType, scope) + val elseBranch = expr.elseExpr match { + case Option.None => + // If there's no else branch, we assume it's a unit type + Option.None + case Option.Some(elseExpr) => + Option.Some(check(elseExpr.expression, expectedType, scope)) + } + + BoundExpression.If( + AstUtils.locationOfExpression(expr), + cond, + thenBranch, + elseBranch, + expectedType + ) + } + + def checkIs( + expr: Expression.Is, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferIsExpression(expr, scope) + subsume(inferred, expectedType) + } + + def checkLiteral( + expr: Expression.Literal, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferLiteral(expr, scope) + subsume(inferred, expectedType) + } + + def checkMemberAccess( + expr: Expression.MemberAccess, + expectedType: Type, + scope: Scope + ): BoundExpression = { + inferMemberAccess(expr, scope) match { + case Result.Error(value) => value + case Result.Success(value) => + subsume(value, expectedType) + } + } + + def checkMatch( + expr: Expression.Match, + expectedType: Type, + scope: Scope + ): BoundExpression = { + + // Bind the expression being matched against + val matchedExpr = infer(expr.expression, scope) + + matchedExpr match { + case error: BoundExpression.Error => error + case _ => + // Bind all the match cases + checkMatchCases( + expr.cases.head, + expr.cases.tail, + expectedType, + scope + ) match { + case Result.Error(value) => value + case Result.Success(boundCases) => + val location = AstUtils.locationOfExpression(expr) + + BoundExpression.Match( + location, + expectedType, + matchedExpr, + boundCases + ) + } + } + } + + def checkNew( + expr: Expression.New, + expectedType: Type, + scope: Scope + ): BoundExpression = { + inferNew(expr, scope) match { + case Result.Error(value) => value + case Result.Success(value) => + val inferred = convertLHSToExpression(value) + subsume(inferred, expectedType) + } + } + + def checkUnary( + expr: Expression.Unary, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferUnary(expr, scope) + subsume(inferred, expectedType) + } + + def checkUnit( + expr: Expression.Unit, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferUnit(expr, scope) + subsume(inferred, expectedType) + } + + def checkWhile( + expr: Expression.While, + expectedType: Type, + scope: Scope + ): BoundExpression = { + val inferred = inferWhileExpression(expr, scope) + subsume(inferred, expectedType) + } + + def infer(expr: Expression, scope: Scope): BoundExpression = { + expr match { + case node: Expression.ArrayCreation => + inferArrayCreationExpression(node, scope) + case node: Expression.Assignment => + inferAssignmentExpression(node, scope) case node: Expression.Binary => - bindBinaryExpression(node, scope) - case node: Expression.Block => bindBlockExpression(node, scope) + inferBinary(node, scope) + case node: Expression.Block => inferBlock(node, scope) case node: Expression.Call => - bindCallExpression(node, scope) match { + inferCall(node, scope) match { case Result.Error(value) => value case Result.Success(value) => convertLHSToExpression(value) } - case node: Expression.Cast => bindCast(node, scope) - case node: Expression.For => bindForExpression(node, scope) - case node: Expression.Group => bindGroupExpression(node, scope) + case node: Expression.Cast => inferCast(node, scope) + case node: Expression.For => inferForExpression(node, scope) + case node: Expression.Group => inferGroup(node, scope) case node: Expression.IdentifierName => - bindIdentifierName(node, scope) match { + inferIdentifierName(node, scope) match { case Result.Error(value) => value case Result.Success(value) => value } - case node: Expression.If => bindIf(node, scope) - case node: Expression.Is => bindIsExpression(node, scope) + case node: Expression.If => inferIf(node, scope) + case node: Expression.Is => inferIsExpression(node, scope) case node: Expression.Literal => - bindLiteralExpression(node, scope) + inferLiteral(node, scope) case node: Expression.MemberAccess => - bindMemberAccessExpression(node, scope) match { + inferMemberAccess(node, scope) match { case Result.Error(value) => value case Result.Success(value) => value } - case node: Expression.Match => bindMatchExpression(node, scope) + case node: Expression.Match => inferMatchExpression(node, scope) case node: Expression.New => - bindNewExpression(node, scope) match { + inferNew(node, scope) match { case Result.Error(value) => value case Result.Success(value) => convertLHSToExpression(value) } - case node: Expression.Unary => bindUnaryExpression(node, scope) - case node: Expression.Unit => bindUnitExpression(node, scope) - case node: Expression.While => bindWhileExpression(node, scope) + case node: Expression.Unary => inferUnary(node, scope) + case node: Expression.Unit => inferUnit(node, scope) + case node: Expression.While => inferWhileExpression(node, scope) } } @@ -187,25 +534,25 @@ case class ExprBinder( ): Result[BoundExpression.Error, BoundLeftHandSide] = { expr match { case node: Expression.IdentifierName => - bindIdentifierName(node, scope) match { + inferIdentifierName(node, scope) match { case Result.Error(value) => Result.Error(value) case Result.Success(value) => val location = AstUtils.locationOfExpression(node) Result.Success(BoundLeftHandSide.Variable(location, value.symbol)) } case node: Expression.MemberAccess => - bindMemberAccessExpression(node, scope) match { + inferMemberAccess(node, scope) match { case Result.Error(value) => Result.Error(value) case Result.Success(value) => Result.Success(BoundLeftHandSide.MemberAccess(value)) } case node: Expression.Call => - bindCallExpression(node, scope) match { + inferCall(node, scope) match { case Result.Error(value) => Result.Error(value) case Result.Success(value) => Result.Success(value) } case node: Expression.New => - bindNewExpression(node, scope) match { + inferNew(node, scope) match { case Result.Error(value) => Result.Error(value) case Result.Success(value) => Result.Success(value) } @@ -226,7 +573,7 @@ case class ExprBinder( toType: Type, scope: Scope ): BoundExpression = { - val bound = bind(expr, scope) + val bound = infer(expr, scope) if (bound == BoundExpression.Error) { bound } else { @@ -289,7 +636,7 @@ case class ExprBinder( BoundStatement.Error } - def bindArrayCreationExpression( + def inferArrayCreationExpression( node: Expression.ArrayCreation, scope: Scope ): BoundExpression = { @@ -302,7 +649,7 @@ case class ExprBinder( // Bind the array size expression val sizeExpr = node.arrayRank match { case Option.Some(rankExpr) => - bindConversion(bind(rankExpr, scope), binder.intType, false) + bindConversion(infer(rankExpr, scope), binder.intType, false) case Option.None => // Default to size 0 if no size provided BoundExpression.Int(TextLocationFactory.empty(), 0) @@ -339,41 +686,32 @@ case class ExprBinder( } - def bindAssignmentExpression( + def inferAssignmentExpression( node: Expression.Assignment, scope: Scope ): BoundExpression = { - val lhs = bindLHS(node.left, scope) - val rhs = bind(node.right, scope) - Tuple2(lhs, rhs) match { - case Tuple2(Result.Error(error), _) => - error - case Tuple2(_, BoundExpression.Error(_)) => - rhs - case Tuple2(Result.Success(lhs), rhs) => + bindLHS(node.left, scope) match { + case Result.Error(value) => value + case Result.Success(lhs) => getLHSType(lhs) match { case Type.Error(message) => BoundExpression.Error(message) case lhsType => + val rhs = check(node.right, lhsType, scope) BoundExpression.Assignment( AstUtils.locationOfExpression(node), lhs, - bindConversion(rhs, lhsType, false) + rhs ) } - - case _ => - val location = AstUtils.locationOfExpression(node.left) - diagnosticBag.reportExpressionIsNotAssignable(location) - BoundExpression.Error("Expression is not assignable: " + location) } } - def bindBinaryExpression( + def inferBinary( node: Expression.Binary, scope: Scope ): BoundExpression = { - val left = bind(node.left, scope) - val right = bind(node.right, scope) + val left = infer(node.left, scope) + val right = infer(node.right, scope) Tuple2(left, right) match { case Tuple2(BoundExpression.Error(_), _) => @@ -416,7 +754,7 @@ case class ExprBinder( } } - def bindBlockExpression( + def inferBlock( node: Expression.Block, scope: Scope ): BoundExpression = { @@ -425,7 +763,7 @@ case class ExprBinder( val expr = node.block.expression match { case Option.None => BoundExpression.Unit(TextLocationFactory.empty()) - case Option.Some(value) => bind(value, block) + case Option.Some(value) => infer(value, block) } if (expr == BoundExpression.Error) { @@ -455,7 +793,7 @@ case class ExprBinder( } } - def bindCallExpression( + def inferCall( node: Expression.Call, scope: Scope ): Result[ @@ -902,11 +1240,11 @@ case class ExprBinder( def findConstructor(symbol: Symbol): Option[Symbol] = symbol.lookupMember(".ctor") - def bindCast( + def inferCast( cast: Expression.Cast, scope: Scope ): BoundExpression = { - val expr = bind(cast.expression, scope) + val expr = infer(cast.expression, scope) val typ = binder.bindTypeName(cast.typ, scope) expr match { @@ -921,11 +1259,11 @@ case class ExprBinder( } } - def bindIsExpression( + def inferIsExpression( isExpr: Expression.Is, scope: Scope ): BoundExpression = { - val expr = bind(isExpr.expression, scope) + val expr = infer(isExpr.expression, scope) val typ = binder.bindTypeName(isExpr.typ, scope) val location = AstUtils.locationOfExpression(isExpr) @@ -988,18 +1326,18 @@ case class ExprBinder( list match { case List.Nil => List.Nil case List.Cons(head, tail) => - val expr = bind(head, scope) + val expr = infer(head, scope) List.Cons(expr, bindExpressions(tail, scope)) } - def bindForExpression( + def inferForExpression( node: Expression.For, scope: Scope ): BoundExpression = { val blockScope = scope.newBlock() - val lowerBound = bindConversionExpr(node.fromExpr, binder.intType, scope) - val upperBound = bindConversionExpr(node.toExpr, binder.intType, scope) + val lowerBound = check(node.fromExpr, binder.intType, scope) + val upperBound = check(node.toExpr, binder.intType, scope) val identifier = node.identifier blockScope.defineLocal(identifier.text, identifier.location) match { @@ -1008,7 +1346,7 @@ case class ExprBinder( case Either.Right(variable) => binder.setSymbolType(variable, binder.intType) - val body = bind(node.body, blockScope) + val body = check(node.body, binder.unitType, blockScope) BoundExpression.For( node.forKeyword.location, @@ -1020,13 +1358,13 @@ case class ExprBinder( } } - def bindGroupExpression( + def inferGroup( node: Expression.Group, scope: Scope ): BoundExpression = - bind(node.expression, scope) + infer(node.expression, scope) - def bindIdentifierName( + def inferIdentifierName( node: Expression.IdentifierName, scope: Scope ): Result[BoundExpression.Error, BoundExpression.Variable] = { @@ -1064,9 +1402,9 @@ case class ExprBinder( } } - def bindIf(node: Expression.If, scope: Scope): BoundExpression = { + def inferIf(node: Expression.If, scope: Scope): BoundExpression = { val cond = bindConversionExpr(node.condition, binder.boolType, scope) - val thenExpr = bind(node.thenExpr, scope) + val thenExpr = infer(node.thenExpr, scope) Tuple2(cond, thenExpr) match { case Tuple2(BoundExpression.Error(_), _) => cond @@ -1102,7 +1440,7 @@ case class ExprBinder( } } - def bindLiteralExpression( + def inferLiteral( node: Expression.Literal, scope: Scope ): BoundExpression = { @@ -1160,7 +1498,7 @@ case class ExprBinder( } } - def bindMemberAccessExpression( + def inferMemberAccess( node: Expression.MemberAccess, scope: Scope ): Result[BoundExpression.Error, BoundExpression.MemberAccess] = { @@ -1609,8 +1947,46 @@ case class ExprBinder( } } - def bindMatchCase( + def inferMatchCase( + matchCase: MatchCaseSyntax, + scope: Scope + ): Result[BoundExpression.Error, BoundMatchCase] = { + // Create a new scope for this case to allow pattern variables + val caseScope = scope.newBlock() + + bindPattern(matchCase.pattern, caseScope) match { + case Result.Error(error) => Result.Error(error) + case Result.Success(boundPattern) => + // Bind the statements and expression within the case scope + val statements = bindStatements(matchCase.block.statements, caseScope) + val resultExpr = matchCase.block.expression match { + case Option.None => + BoundExpression.Unit(TextLocationFactory.empty()) + case Option.Some(expr) => infer(expr, caseScope) + } + + resultExpr match { + case error: BoundExpression.Error => Result.Error(error) + case _ => + val caseResult = if (statements.isEmpty) { + resultExpr + } else { + BoundExpression.Block(statements, resultExpr) + } + Result.Success( + BoundMatchCase( + matchCase.caseKeyword.location, + boundPattern, + caseResult + ) + ) + } + } + } + + def checkMatchCase( matchCase: MatchCaseSyntax, + expectedType: Type, scope: Scope ): Result[BoundExpression.Error, BoundMatchCase] = { // Create a new scope for this case to allow pattern variables @@ -1624,7 +2000,7 @@ case class ExprBinder( val resultExpr = matchCase.block.expression match { case Option.None => BoundExpression.Unit(TextLocationFactory.empty()) - case Option.Some(expr) => bind(expr, caseScope) + case Option.Some(expr) => check(expr, expectedType, caseScope) } resultExpr match { @@ -1646,12 +2022,35 @@ case class ExprBinder( } } - def bindMatchCases( + def checkMatchCases( + head: MatchCaseSyntax, + tail: List[MatchCaseSyntax], + expectedType: Type, + scope: Scope + ): Result[BoundExpression.Error, NonEmptyList[BoundMatchCase]] = { + checkMatchCase(head, expectedType, scope) match { + case Result.Error(expr) => + Result.Error(expr) + case Result.Success(boundCase) => + tail match { + case List.Nil => + Result.Success(NonEmptyList(boundCase, List.Nil)) + case List.Cons(head, tail) => + checkMatchCases(head, tail, expectedType, scope) match { + case Result.Error(expr) => Result.Error(expr) + case Result.Success(tailCases) => + Result.Success(NonEmptyList(boundCase, tailCases.toList())) + } + } + } + } + + def inferMatchCases( head: MatchCaseSyntax, tail: List[MatchCaseSyntax], scope: Scope ): Result[BoundExpression.Error, NonEmptyList[BoundMatchCase]] = { - bindMatchCase(head, scope) match { + inferMatchCase(head, scope) match { case Result.Error(expr) => Result.Error(expr) case Result.Success(boundCase) => @@ -1659,7 +2058,7 @@ case class ExprBinder( case List.Nil => Result.Success(NonEmptyList(boundCase, List.Nil)) case List.Cons(head, tail) => - bindMatchCases(head, tail, scope) match { + inferMatchCases(head, tail, scope) match { case Result.Error(expr) => Result.Error(expr) case Result.Success(tailCases) => Result.Success(NonEmptyList(boundCase, tailCases.toList())) @@ -1681,18 +2080,18 @@ case class ExprBinder( } } - def bindMatchExpression( + def inferMatchExpression( node: Expression.Match, scope: Scope ): BoundExpression = { // Bind the expression being matched against - val matchedExpr = bind(node.expression, scope) + val matchedExpr = infer(node.expression, scope) matchedExpr match { case error: BoundExpression.Error => error case _ => // Bind all the match cases - bindMatchCases(node.cases.head, node.cases.tail, scope) match { + inferMatchCases(node.cases.head, node.cases.tail, scope) match { case Result.Error(value) => value case Result.Success(boundCases) => // Calculate the result type from all cases @@ -1710,7 +2109,7 @@ case class ExprBinder( } } - def bindNewExpression( + def inferNew( node: Expression.New, scope: Scope ): Result[BoundExpression.Error, BoundLeftHandSide] = { @@ -1927,12 +2326,12 @@ case class ExprBinder( } } - def bindUnaryExpression( + def inferUnary( node: Expression.Unary, scope: Scope ): BoundExpression = { val op = bindUnaryOperator(node.operator) - bind(node.expression, scope) match { + infer(node.expression, scope) match { case error: BoundExpression.Error => error case operand => binder.getType(operand) match { @@ -1960,7 +2359,7 @@ case class ExprBinder( } - def bindUnitExpression( + def inferUnit( node: Expression.Unit, scope: Scope ): BoundExpression = @@ -1968,12 +2367,12 @@ case class ExprBinder( node.openParen.location.merge(node.closeParen.location) ) - def bindWhileExpression( + def inferWhileExpression( node: Expression.While, scope: Scope ): BoundExpression = { - val cond = bind(node.condition, scope) - val body = bind(node.body, scope) + val cond = infer(node.condition, scope) + val body = infer(node.body, scope) new BoundExpression.While(node.whileKeyword.location, cond, body) } @@ -2010,11 +2409,11 @@ case class ExprBinder( ) BoundStatement.Error case Either.Right(symbol) => - // symbol was created so lets bind it - val expr = bind(statement.expression, scope) statement.typeAnnotation match { case Option.None => + // no type annotation so lets use type inference + val expr = infer(statement.expression, scope) // no type annotation, so we need to infer the type from the expr val typ = binder.getType(expr) binder.setSymbolType(symbol, typ) @@ -2022,11 +2421,9 @@ case class ExprBinder( case Option.Some(value) => val annotatedType = binder.bindTypeName(value.typ, scope) + val boundExpr = check(statement.expression, annotatedType, scope) binder.setSymbolType(symbol, annotatedType) - // make sure we can convert the expression - val boundExpr = bindConversion(expr, annotatedType, false) - BoundStatement.VariableDeclaration( symbol, false, @@ -2052,7 +2449,7 @@ case class ExprBinder( statement: StatementSyntax.ExpressionStatement, scope: Scope ): BoundStatement = { - val expr = bind(statement.expression, scope) + val expr = infer(statement.expression, scope) BoundStatement.ExpressionStatement(expr) } diff --git a/pncs/src/main/scala/Lowered.scala b/pncs/src/main/scala/Lowered.scala index f9c9a14..fcf4660 100644 --- a/pncs/src/main/scala/Lowered.scala +++ b/pncs/src/main/scala/Lowered.scala @@ -639,7 +639,80 @@ class ExpressionLowerer(symbol: Symbol, binder: Binder) { def lowerForExpression( expr: BoundExpression.For, context: LoweredBlock - ): LoweredBlock = ??? + ): LoweredBlock = { + + // convert for expression to while expression first + /* + * for (; ; ) + * + * + * to + * + * + * while () + * + * + */ + val upperBound = createTemporary() + val setupUpperBound = BoundExpression.Assignment( + upperBound.location, + BoundLeftHandSide.Variable(upperBound.location, upperBound), + expr.upperBound + ) + val initializer = BoundExpression.Assignment( + expr.variable.location, + BoundLeftHandSide.Variable(expr.variable.location, expr.variable), + expr.lowerBound + ) + val condition = BoundExpression.Binary( + expr.location, + BoundExpression.Variable(expr.variable.location, expr.variable, None), + BinaryOperatorKind.LessThanOrEqual, // TODO: should this be LessThan? + BoundExpression.Variable(upperBound.location, upperBound, None), + binder.boolType + ) + val variableType = binder.getSymbolType(expr.variable) + val iterator = BoundExpression.Assignment( + expr.variable.location, + BoundLeftHandSide.Variable(expr.variable.location, expr.variable), + BoundExpression.Binary( + expr.location, + BoundExpression.Variable(expr.variable.location, expr.variable, None), + BinaryOperatorKind.Plus, + BoundExpression.Int(expr.location, 1), + variableType + ) + ) + + val block = BoundExpression.Block( + List.Cons( + // setup upper bound so it doesnt need to be recalculated every loop + BoundStatement.ExpressionStatement(setupUpperBound), + List.Cons( + // initialize variable to the lower bound + BoundStatement.ExpressionStatement(initializer), + List.Nil + ) + ), + // the while loop + BoundExpression.While( + expr.location, + // test our variable against the upper bound + condition, + BoundExpression.Block( + List.Cons( + // the body of the for loop + BoundStatement.ExpressionStatement(expr.body), + List.Nil + ), + // increment the loop variable + iterator + ) + ) + ) + + lowerExpression(block, context) + } def lowerIfExpression( expr: BoundExpression.If, diff --git a/project/plugins.sbt b/project/plugins.sbt index d44d007..bd5dff4 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,2 +1,3 @@ addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.11.1") +addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.4.3") diff --git a/test/src/test/scala/TestHelpers.scala b/test/src/test/scala/TestHelpers.scala index 59b1dd3..9dc6743 100644 --- a/test/src/test/scala/TestHelpers.scala +++ b/test/src/test/scala/TestHelpers.scala @@ -455,7 +455,7 @@ object TestHelpers { } } - def assertExprTypeWithSetup( + def assertInferExprTypeWithSetup( setup: string, expression: string, expectedType: string @@ -466,6 +466,19 @@ object TestHelpers { assertSymbolType(comp, symbol, expectedType) } + def assertCheckExprTypeWithSetup( + setup: string, + expression: string, + expectedType: string + ): Unit = { + val comp = mkCompilation( + setup + "\n\nval typeTestSymbol: " + expectedType + " = " + expression + ) + val program = assertSome(comp.root.lookup("$Program")) + val symbol = assertSome(program.lookup("typeTestSymbol")) + assertSymbolType(comp, symbol, expectedType) + } + def assertAssignableToWithSetup( setup: string, expression: string, @@ -476,12 +489,21 @@ object TestHelpers { ) } - def assertExprTypeTest(expression: string, expectedType: string): Unit = { - assertExprType(expression, expectedType) + def assertInferExprType(expression: string, expectedType: string): Unit = { + val comp = mkCompilation("val x = " + expression) + val symbols = enumNonBuiltinSymbols(comp) + + assertProgramSymbol(symbols) + + val x = assertSymbol(symbols, SymbolKind.Field, "x") + assertSymbolType(comp, x, expectedType) + + assertMainSymbol(symbols) + // assertNoSymbols(symbols) } - def assertExprType(expression: string, expectedType: string): Unit = { - val comp = mkCompilation("val x = " + expression) + def assertCheckExprType(expression: string, expectedType: string): Unit = { + val comp = mkCompilation("val x: " + expectedType + " = " + expression) val symbols = enumNonBuiltinSymbols(comp) assertProgramSymbol(symbols) diff --git a/test/src/test/scala/TypeTests.scala b/test/src/test/scala/TypeTests.scala index f6f8e71..f592e4b 100644 --- a/test/src/test/scala/TypeTests.scala +++ b/test/src/test/scala/TypeTests.scala @@ -6,137 +6,238 @@ import org.scalatest.matchers.should.Matchers class TypeTests extends AnyFunSpec with Matchers { describe("Type checker") { - it("should handle primitive types") { - assertExprTypeTest("12", "int") - assertExprTypeTest("0", "int") - assertExprTypeTest("true", "bool") - assertExprTypeTest("false", "bool") - assertExprTypeTest("\"hello\"", "string") - assertExprTypeTest("'a'", "char") - assertExprTypeTest("()", "unit") + it("should infer primitive types") { + assertInferExprType("12", "int") + assertInferExprType("0", "int") + assertInferExprType("true", "bool") + assertInferExprType("false", "bool") + assertInferExprType("\"hello\"", "string") + assertInferExprType("'a'", "char") + assertInferExprType("()", "unit") } - it("should handle binary expressions") { - assertExprTypeTest("1 + 2", "int") - assertExprTypeTest("1 - 2", "int") - assertExprTypeTest("1 * 2", "int") - assertExprTypeTest("1 / 2", "int") - assertExprTypeTest("1 % 2", "int") - assertExprTypeTest("1 == 2", "bool") - assertExprTypeTest("1 != 2", "bool") - assertExprTypeTest("1 < 2", "bool") - assertExprTypeTest("1 <= 2", "bool") - assertExprTypeTest("1 > 2", "bool") - assertExprTypeTest("1 >= 2", "bool") - assertExprTypeTest("true && false", "bool") - assertExprTypeTest("true || false", "bool") + it("should check primitive types") { + assertCheckExprType("12", "int") + assertCheckExprType("0", "int") + assertCheckExprType("true", "bool") + assertCheckExprType("false", "bool") + assertCheckExprType("\"hello\"", "string") + assertCheckExprType("'a'", "char") + assertCheckExprType("()", "unit") } - it("should handle unary expressions") { - assertExprTypeTest("-1", "int") - assertExprTypeTest("+1", "int") - // FIXME: assertExprType("~7", "int") - assertExprTypeTest("!true", "bool") + it("should infer binary expressions") { + assertInferExprType("1 + 2", "int") + assertInferExprType("1 - 2", "int") + assertInferExprType("1 * 2", "int") + assertInferExprType("1 / 2", "int") + assertInferExprType("1 % 2", "int") + assertInferExprType("1 == 2", "bool") + assertInferExprType("1 != 2", "bool") + assertInferExprType("1 < 2", "bool") + assertInferExprType("1 <= 2", "bool") + assertInferExprType("1 > 2", "bool") + assertInferExprType("1 >= 2", "bool") + assertInferExprType("true && false", "bool") + assertInferExprType("true || false", "bool") } - it("should handle grouped expressions") { - assertExprTypeTest("(12)", "int") - assertExprTypeTest("(true)", "bool") + it("should check binary expressions") { + assertCheckExprType("1 + 2", "int") + assertCheckExprType("1 - 2", "int") + assertCheckExprType("1 * 2", "int") + assertCheckExprType("1 / 2", "int") + assertCheckExprType("1 % 2", "int") + assertCheckExprType("1 == 2", "bool") + assertCheckExprType("1 != 2", "bool") + assertCheckExprType("1 < 2", "bool") + assertCheckExprType("1 <= 2", "bool") + assertCheckExprType("1 > 2", "bool") + assertCheckExprType("1 >= 2", "bool") + assertCheckExprType("true && false", "bool") + assertCheckExprType("true || false", "bool") } - it("should handle int variables") { - assertExprTypeWithSetup("val x = 12", "x", "int") + it("should infer unary expressions") { + assertInferExprType("-1", "int") + assertInferExprType("+1", "int") + // FIXME: assertInferExprType("~7", "int") + assertInferExprType("!true", "bool") } - it("should handle bool variables") { - assertExprTypeWithSetup("val x = true", "x", "bool") + it("should check unary expressions") { + assertCheckExprType("-1", "int") + assertCheckExprType("+1", "int") + // FIXME: assertCheckExprType("~7", "int") + assertCheckExprType("!true", "bool") } - it("should handle string variables") { - assertExprTypeWithSetup("val x = \"hello\"", "x", "string") + it("should infer grouped expressions") { + assertInferExprType("(12)", "int") + assertInferExprType("(true)", "bool") } - it("should handle char variables") { - assertExprTypeWithSetup("val x = 'a'", "x", "char") + it("should check grouped expressions") { + assertCheckExprType("(12)", "int") + assertCheckExprType("(true)", "bool") } - it("should handle variable addition") { - assertExprTypeWithSetup("val x = 12", "x + 12", "int") + it("should infer int variables") { + assertInferExprTypeWithSetup("val x = 12", "x", "int") } - it("should handle variable equality") { - assertExprTypeWithSetup("val x = 12", "12 == x", "bool") + it("should check int variables") { + assertCheckExprTypeWithSetup("val x = 12", "x", "int") } - it("should handle variable assignment") { - assertExprTypeWithSetup("var x = 12", "x = 10", "unit") + it("should infer bool variables") { + assertInferExprTypeWithSetup("val x = true", "x", "bool") } - it("should handle conversions to any") { + it("should check bool variables") { + assertCheckExprTypeWithSetup("val x = true", "x", "bool") + } + + it("should infer string variables") { + assertInferExprTypeWithSetup("val x = \"hello\"", "x", "string") + } + + it("should check string variables") { + assertCheckExprTypeWithSetup("val x = \"hello\"", "x", "string") + } + + it("should infer char variables") { + assertInferExprTypeWithSetup("val x = 'a'", "x", "char") + } + + it("should check char variables") { + assertCheckExprTypeWithSetup("val x = 'a'", "x", "char") + } + + it("should infer variable addition") { + assertInferExprTypeWithSetup("val x = 12", "x + 12", "int") + } + + it("should check variable addition") { + assertCheckExprTypeWithSetup("val x = 12", "x + 12", "int") + } + + it("should infer variable equality") { + assertInferExprTypeWithSetup("val x = 12", "12 == x", "bool") + } + + it("should check variable equality") { + assertCheckExprTypeWithSetup("val x = 12", "12 == x", "bool") + } + + it("should infer variable assignment") { + assertInferExprTypeWithSetup("var x = 12", "x = 10", "unit") + } + + it("should check variable assignment") { + assertCheckExprTypeWithSetup("var x = 12", "x = 10", "unit") + } + + it("should infer conversions to any") { assertAssignableTo("12", "any") assertAssignableTo("true", "any") assertAssignableTo("\"hello\"", "any") assertAssignableTo("'a'", "any") } - it("should handle function calls") { - assertExprTypeTest("println(12)", "unit") - assertExprTypeTest("print(12)", "unit") - assertExprTypeTest("print(\"hello\")", "unit") - assertExprTypeTest("print('a')", "unit") - assertExprTypeTest("print(true)", "unit") - assertExprTypeTest("print(12 + 12)", "unit") - assertExprTypeTest("print(12 == 12)", "unit") - assertExprTypeTest("print(12 < 12)", "unit") - assertExprTypeTest("print(true && false)", "unit") - assertExprTypeTest("print(true || false)", "unit") - } - - it("should handle casts") { - assertExprTypeTest("'a' as char", "char") - assertExprTypeTest("12 as char", "char") - assertExprTypeTest("'a' as int", "int") - assertExprTypeTest("12 as int", "int") - assertExprTypeTest("'a' as string", "string") - assertExprTypeTest("12 as string", "string") - assertExprTypeTest("\"hello\" as string", "string") - assertExprTypeTest("true as string", "string") - } - - it("should handle is expressions") { - assertExprTypeTest("12 is int", "bool") - assertExprTypeTest("'a' is char", "bool") - assertExprTypeTest("\"hello\" is string", "bool") - assertExprTypeTest("true is bool", "bool") - assertExprTypeTest("12 is bool", "bool") // should return false at runtime - } - - it("should handle block expressions") { - assertExprTypeTest("{ 1 }", "int") - assertExprTypeTest("{ true }", "bool") - assertExprTypeTest("{ }", "unit") - assertExprTypeTest( + it("should infer function calls") { + assertInferExprType("println(12)", "unit") + assertInferExprType("print(12)", "unit") + assertInferExprType("print(\"hello\")", "unit") + assertInferExprType("print('a')", "unit") + assertInferExprType("print(true)", "unit") + assertInferExprType("print(12 + 12)", "unit") + assertInferExprType("print(12 == 12)", "unit") + assertInferExprType("print(12 < 12)", "unit") + assertInferExprType("print(true && false)", "unit") + assertInferExprType("print(true || false)", "unit") + } + + it("should check function calls") { + assertCheckExprType("println(12)", "unit") + assertCheckExprType("print(12)", "unit") + assertCheckExprType("print(\"hello\")", "unit") + assertCheckExprType("print('a')", "unit") + assertCheckExprType("print(true)", "unit") + assertCheckExprType("print(12 + 12)", "unit") + assertCheckExprType("print(12 == 12)", "unit") + assertCheckExprType("print(12 < 12)", "unit") + assertCheckExprType("print(true && false)", "unit") + assertCheckExprType("print(true || false)", "unit") + } + + it("should infer casts") { + assertInferExprType("'a' as char", "char") + assertInferExprType("12 as char", "char") + assertInferExprType("'a' as int", "int") + assertInferExprType("12 as int", "int") + assertInferExprType("'a' as string", "string") + assertInferExprType("12 as string", "string") + assertInferExprType("\"hello\" as string", "string") + assertInferExprType("true as string", "string") + } + + it("should check casts") { + assertCheckExprType("'a' as char", "char") + assertCheckExprType("12 as char", "char") + assertCheckExprType("'a' as int", "int") + assertCheckExprType("12 as int", "int") + assertCheckExprType("'a' as string", "string") + assertCheckExprType("12 as string", "string") + assertCheckExprType("\"hello\" as string", "string") + assertCheckExprType("true as string", "string") + } + + it("should infer is expressions") { + assertInferExprType("12 is int", "bool") + assertInferExprType("'a' is char", "bool") + assertInferExprType("\"hello\" is string", "bool") + assertInferExprType("true is bool", "bool") + assertInferExprType( + "12 is bool", + "bool" + ) // should return false at runtime + } + + it("should check is expressions") { + assertCheckExprType("12 is int", "bool") + assertCheckExprType("'a' is char", "bool") + assertCheckExprType("\"hello\" is string", "bool") + assertCheckExprType("true is bool", "bool") + assertCheckExprType("12 is bool", "bool") + } + + it("should infer block expressions") { + assertInferExprType("{ 1 }", "int") + assertInferExprType("{ true }", "bool") + assertInferExprType("{ }", "unit") + assertInferExprType( "{\n" + " 1\n" + " 2\n" + "}", "int" ) - assertExprTypeTest( + assertInferExprType( "{\n" + "true\n" + "false\n" + "}", "bool" ) - assertExprTypeTest( + assertInferExprType( "{\n" + " 1\n" + " true\n" + "}", "bool" ) - assertExprTypeTest( + assertInferExprType( "{\n" + " true\n" + " 1\n" + @@ -145,83 +246,296 @@ class TypeTests extends AnyFunSpec with Matchers { ) } - it("should handle if expressions") { - assertExprTypeTest("if (true) 1 else 2", "int") - assertExprTypeTest("if (false) 1 else 2", "int") - assertExprTypeTest("if (true) true else false", "bool") - assertExprTypeTest("if (false) true else false", "bool") - assertExprTypeTest("if (false) true", "unit") + it("should check block expressions") { + assertCheckExprType("{ 1 }", "int") + assertCheckExprType("{ true }", "bool") + assertCheckExprType("{ }", "unit") + assertCheckExprType( + "{\n" + + " 1\n" + + " 2\n" + + "}", + "int" + ) + assertCheckExprType( + "{\n" + + "true\n" + + "false\n" + + "}", + "bool" + ) + assertCheckExprType( + "{\n" + + " 1\n" + + " true\n" + + "}", + "bool" + ) + assertCheckExprType( + "{\n" + + " true\n" + + " 1\n" + + "}", + "int" + ) + } + + it("should infer if expressions") { + assertInferExprType("if (true) 1 else 2", "int") + assertInferExprType("if (false) 1 else 2", "int") + assertInferExprType("if (true) true else false", "bool") + assertInferExprType("if (false) true else false", "bool") + assertInferExprType("if (false) true", "unit") + } + + it("should check if expressions") { + assertCheckExprType("if (true) 1 else 2", "int") + assertCheckExprType("if (false) 1 else 2", "int") + assertCheckExprType("if (true) true else false", "bool") + assertCheckExprType("if (false) true else false", "bool") + assertCheckExprType("if (false) true", "unit") + } + + it("should infer while expressions") { + assertInferExprType("while (true) 1", "unit") + assertInferExprType("while (false) 1", "unit") + } + + it("should check while expressions") { + assertCheckExprType("while (true) 1", "unit") + assertCheckExprType("while (false) 1", "unit") + } + + it("should infer for expressions with literal bounds") { + assertInferExprType("for (i <- 0 to 10) i", "unit") + assertInferExprType("for (x <- 1 to 5) x", "unit") + assertInferExprType("for (n <- 0 to 0) n", "unit") + } + + it("should check for expressions with literal bounds") { + assertCheckExprType("for (i <- 0 to 10) i", "unit") + assertCheckExprType("for (x <- 1 to 5) x", "unit") + assertCheckExprType("for (n <- 0 to 0) n", "unit") + } + + it("should infer for expressions with variable bounds") { + val setup = "val start = 0\nval end = 10" + assertInferExprTypeWithSetup(setup, "for (i <- start to end) i", "unit") + + val setup2 = "val n = 5" + assertInferExprTypeWithSetup(setup2, "for (i <- 0 to n) i", "unit") + assertInferExprTypeWithSetup(setup2, "for (i <- n to 10) i", "unit") + } + + it("should check for expressions with variable bounds") { + val setup = "val start = 0\nval end = 10" + assertCheckExprTypeWithSetup(setup, "for (i <- start to end) i", "unit") + + val setup2 = "val n = 5" + assertCheckExprTypeWithSetup(setup2, "for (i <- 0 to n) i", "unit") + assertCheckExprTypeWithSetup(setup2, "for (i <- n to 10) i", "unit") + } + + it("should infer for expressions with computed bounds") { + assertInferExprType("for (i <- 1 + 2 to 10 - 3) i", "unit") + assertInferExprType("for (i <- 0 to 5 * 2) i", "unit") + + val setup = "val n = 10" + assertInferExprTypeWithSetup(setup, "for (i <- 0 to n * 2) i", "unit") + } + + it("should check for expressions with computed bounds") { + assertCheckExprType("for (i <- 1 + 2 to 10 - 3) i", "unit") + assertCheckExprType("for (i <- 0 to 5 * 2) i", "unit") + + val setup = "val n = 10" + assertCheckExprTypeWithSetup(setup, "for (i <- 0 to n * 2) i", "unit") + } + + it("should infer for expressions with different body types") { + assertInferExprType("for (i <- 0 to 5) true", "unit") + assertInferExprType("for (i <- 0 to 5) \"hello\"", "unit") + assertInferExprType("for (i <- 0 to 5) 'a'", "unit") + assertInferExprType("for (i <- 0 to 5) ()", "unit") + } + + it("should check for expressions with different body types") { + assertCheckExprType("for (i <- 0 to 5) true", "unit") + assertCheckExprType("for (i <- 0 to 5) \"hello\"", "unit") + assertCheckExprType("for (i <- 0 to 5) 'a'", "unit") + assertCheckExprType("for (i <- 0 to 5) ()", "unit") + } + + it("should infer for expressions with loop variable usage") { + assertInferExprType("for (i <- 0 to 10) i + 1", "unit") + assertInferExprType("for (i <- 0 to 10) i * 2", "unit") + assertInferExprType("for (i <- 1 to 5) i == 3", "unit") + } + + it("should check for expressions with loop variable usage") { + assertCheckExprType("for (i <- 0 to 10) i + 1", "unit") + assertCheckExprType("for (i <- 0 to 10) i * 2", "unit") + assertCheckExprType("for (i <- 1 to 5) i == 3", "unit") + } + + it("should infer for expressions accessing outer scope") { + val setup = "val x = 10" + assertInferExprTypeWithSetup(setup, "for (i <- 0 to 5) x", "unit") + assertInferExprTypeWithSetup(setup, "for (i <- 0 to 5) i + x", "unit") + } + + it("should check for expressions accessing outer scope") { + val setup = "val x = 10" + assertCheckExprTypeWithSetup(setup, "for (i <- 0 to 5) x", "unit") + assertCheckExprTypeWithSetup(setup, "for (i <- 0 to 5) i + x", "unit") + } + + it("should infer nested for expressions") { + assertInferExprType("for (i <- 0 to 3) for (j <- 0 to 3) i + j", "unit") + assertInferExprType("for (x <- 1 to 2) for (y <- 1 to 2) x * y", "unit") + } + + it("should check nested for expressions") { + assertCheckExprType("for (i <- 0 to 3) for (j <- 0 to 3) i + j", "unit") + assertCheckExprType("for (x <- 1 to 2) for (y <- 1 to 2) x * y", "unit") } - it("should handle while expressions") { - assertExprTypeTest("while (true) 1", "unit") - assertExprTypeTest("while (false) 1", "unit") + it("should infer for expressions with block bodies") { + assertInferExprType("for (i <- 0 to 5) { i }", "unit") + assertInferExprType("for (i <- 0 to 5) { val x = i\n x * 2 }", "unit") } - it("should handle literal patterns") { - assertExprTypeTest("1 match { case 1 => 2 }", "int") - assertExprTypeTest("1 match { case 2 => 3 }", "int") - assertExprTypeTest("true match { case true => false }", "bool") - assertExprTypeTest( + it("should check for expressions with block bodies") { + assertCheckExprType("for (i <- 0 to 5) { i }", "unit") + assertCheckExprType("for (i <- 0 to 5) { val x = i\n x * 2 }", "unit") + } + + it("should infer literal patterns") { + assertInferExprType("1 match { case 1 => 2 }", "int") + assertInferExprType("1 match { case 2 => 3 }", "int") + assertInferExprType("true match { case true => false }", "bool") + assertInferExprType( + "\"hello\" match { case \"world\" => \"hi\" }", + "string" + ) + assertInferExprType("'a' match { case 'b' => 'c' }", "char") + } + + it("should check literal patterns") { + assertCheckExprType("1 match { case 1 => 2 }", "int") + assertCheckExprType("1 match { case 2 => 3 }", "int") + assertCheckExprType("true match { case true => false }", "bool") + assertCheckExprType( "\"hello\" match { case \"world\" => \"hi\" }", "string" ) - assertExprTypeTest("'a' match { case 'b' => 'c' }", "char") + assertCheckExprType("'a' match { case 'b' => 'c' }", "char") + } + + it("should infer wildcard patterns") { + assertInferExprType("1 match { case _ => 2 }", "int") + assertInferExprType("true match { case _ => false }", "bool") + assertInferExprType("\"hello\" match { case _ => \"world\" }", "string") + } + + it("should check wildcard patterns") { + assertCheckExprType("1 match { case _ => 2 }", "int") + assertCheckExprType("true match { case _ => false }", "bool") + assertCheckExprType("\"hello\" match { case _ => \"world\" }", "string") } - it("should handle wildcard patterns") { - assertExprTypeTest("1 match { case _ => 2 }", "int") - assertExprTypeTest("true match { case _ => false }", "bool") - assertExprTypeTest("\"hello\" match { case _ => \"world\" }", "string") + it("should infer multiple cases with same type") { + assertInferExprType( + "1 match { case 1 => 10\n case 2 => 20 }", + "int" + ) + assertInferExprType( + "true match { case true => 1\n case false => 0 }", + "int" + ) } - it("should handle multiple cases with same type") { - assertExprTypeTest( + it("should check multiple cases with same type") { + assertCheckExprType( "1 match { case 1 => 10\n case 2 => 20 }", "int" ) - assertExprTypeTest( + assertCheckExprType( "true match { case true => 1\n case false => 0 }", "int" ) } - it("should handle multiple cases with different types") { + it("should infer multiple cases with different types") { // When cases have different types, result should be 'any' (least upper bound) - assertExprTypeTest( + assertInferExprType( "1 match { case 1 => 42\n case 2 => \"hello\" }", "int | string" ) - assertExprTypeTest( + assertInferExprType( "1 match { case 1 => true\n case 2 => 123 }", "bool | int" ) } - it("should handle unit result cases") { - assertExprTypeTest("1 match { case x: int => () }", "unit") - assertExprTypeTest("1 match { case 1 => println(\"test\") }", "unit") +// it("should check multiple cases with different types") { +// assertCheckExprType( +// "1 match { case 1 => 42\n case 2 => \"hello\" }", +// "int | string" +// ) +// assertCheckExprType( +// "1 match { case 1 => true\n case 2 => 123 }", +// "bool | int" +// ) +// } + + it("should infer unit result cases") { + assertInferExprType("1 match { case x: int => () }", "unit") + assertInferExprType("1 match { case 1 => println(\"test\") }", "unit") } - it("should handle nested matches") { - assertExprTypeTest( + it("should check unit result cases") { + assertCheckExprType("1 match { case x: int => () }", "unit") + assertCheckExprType("1 match { case 1 => println(\"test\") }", "unit") + } + + it("should infer nested matches") { + assertInferExprType( "1 match { case x: int => x match { case y: int => y * 2 } }", "int" ) } - it("should handle match with blocks") { - assertExprTypeTest( + it("should check nested matches") { + assertCheckExprType( + "1 match { case x: int => x match { case y: int => y * 2 } }", + "int" + ) + } + + it("should infer match with blocks") { + assertInferExprType( + "1 match { case x: int => { val y = x + 1\n y * 2 } }", + "int" + ) + assertInferExprType( + "true match { case b: bool => { println(\"test\")\n b } }", + "bool" + ) + } + + it("should check match with blocks") { + assertCheckExprType( "1 match { case x: int => { val y = x + 1\n y * 2 } }", "int" ) - assertExprTypeTest( + assertCheckExprType( "true match { case b: bool => { println(\"test\")\n b } }", "bool" ) } - it("should handle methods without return type") { + it("should infer methods without return type") { val comp = mkCompilation("def foo() = 12") val symbols = enumNonBuiltinSymbols(comp) assertProgramSymbol(symbols) @@ -231,7 +545,7 @@ class TypeTests extends AnyFunSpec with Matchers { assertNoSymbols(symbols) } - it("should handle methods with parameters") { + it("should infer methods with parameters") { val comp = mkCompilation("def foo(x: int, y: int) = 12") val symbols = enumNonBuiltinSymbols(comp) assertProgramSymbol(symbols) @@ -245,64 +559,128 @@ class TypeTests extends AnyFunSpec with Matchers { assertNoSymbols(symbols) } - it("should handle simple identity generic method") { + it("should infer simple identity generic method") { + val setup = "def identity[T](x: T): T = x" + + assertInferExprTypeWithSetup(setup, "identity(42)", "int") + assertInferExprTypeWithSetup(setup, "identity(true)", "bool") + assertInferExprTypeWithSetup(setup, "identity(\"hello\")", "string") + assertInferExprTypeWithSetup(setup, "identity('a')", "char") + } + + it("should check simple identity generic method") { val setup = "def identity[T](x: T): T = x" - assertExprTypeWithSetup(setup, "identity(42)", "int") - assertExprTypeWithSetup(setup, "identity(true)", "bool") - assertExprTypeWithSetup(setup, "identity(\"hello\")", "string") - assertExprTypeWithSetup(setup, "identity('a')", "char") + assertCheckExprTypeWithSetup(setup, "identity(42)", "int") + assertCheckExprTypeWithSetup(setup, "identity(true)", "bool") + assertCheckExprTypeWithSetup(setup, "identity(\"hello\")", "string") + assertCheckExprTypeWithSetup(setup, "identity('a')", "char") } - it("should handle generic method returning concrete type") { + it("should infer generic method returning concrete type") { val setup = "def getValue[T](x: T): int = 42" // These should work because the return type is concrete - assertExprTypeWithSetup(setup, "getValue(42)", "int") - assertExprTypeWithSetup(setup, "getValue(\"hello\")", "int") - assertExprTypeWithSetup(setup, "getValue(true)", "int") + assertInferExprTypeWithSetup(setup, "getValue(42)", "int") + assertInferExprTypeWithSetup(setup, "getValue(\"hello\")", "int") + assertInferExprTypeWithSetup(setup, "getValue(true)", "int") } - it("should handle generic parameter with concrete return") { + it("should check generic method returning concrete type") { + val setup = "def getValue[T](x: T): int = 42" + + // These should work because the return type is concrete + assertCheckExprTypeWithSetup(setup, "getValue(42)", "int") + assertCheckExprTypeWithSetup(setup, "getValue(\"hello\")", "int") + assertCheckExprTypeWithSetup(setup, "getValue(true)", "int") + } + + it("should infer generic parameter with concrete return") { + // Test methods that accept generic parameters but return concrete types + val setup = "def stringify[T](x: T): string = string(x)" + + assertInferExprTypeWithSetup(setup, "stringify(42)", "string") + assertInferExprTypeWithSetup(setup, "stringify(true)", "string") + assertInferExprTypeWithSetup(setup, "stringify('a')", "string") + } + + it("should check generic parameter with concrete return") { // Test methods that accept generic parameters but return concrete types val setup = "def stringify[T](x: T): string = string(x)" - assertExprTypeWithSetup(setup, "stringify(42)", "string") - assertExprTypeWithSetup(setup, "stringify(true)", "string") - assertExprTypeWithSetup(setup, "stringify('a')", "string") + assertCheckExprTypeWithSetup(setup, "stringify(42)", "string") + assertCheckExprTypeWithSetup(setup, "stringify(true)", "string") + assertCheckExprTypeWithSetup(setup, "stringify('a')", "string") } - it("should handle generic container creation") { + it("should infer generic container creation") { val containerSetup = "class Container[T](value: T)\n" + "def wrap[T](x: T): Container[T] = new Container(x)" - assertExprTypeWithSetup(containerSetup, "wrap(42)", "Container") - assertExprTypeWithSetup(containerSetup, "wrap(true)", "Container") - assertExprTypeWithSetup( + assertInferExprTypeWithSetup(containerSetup, "wrap(42)", "Container") + assertInferExprTypeWithSetup( + containerSetup, + "wrap(true)", + "Container" + ) + assertInferExprTypeWithSetup( containerSetup, "wrap(\"test\")", "Container" ) } - it("should handle enums without args") { +// it("should check generic container creation") { +// val containerSetup = "class Container[T](value: T)\n" + +// "def wrap[T](x: T): Container[T] = new Container(x)" +// +// assertCheckExprTypeWithSetup(containerSetup, "wrap(42)", "Container") +// assertCheckExprTypeWithSetup(containerSetup, "wrap(true)", "Container") +// assertCheckExprTypeWithSetup( +// containerSetup, +// "wrap(\"test\")", +// "Container" +// ) +// } + + it("should infer enums without args") { val setup = "enum Foo {\n" + " case Bar\n" + " case Baz\n" + "}" - assertExprTypeWithSetup(setup, "Foo.Bar", "Foo.Bar") - assertExprTypeWithSetup(setup, "Foo.Baz", "Foo.Baz") + assertInferExprTypeWithSetup(setup, "Foo.Bar", "Foo.Bar") + assertInferExprTypeWithSetup(setup, "Foo.Baz", "Foo.Baz") + } + + it("should check enums without args") { + val setup = "enum Foo {\n" + + " case Bar\n" + + " case Baz\n" + + "}" + assertCheckExprTypeWithSetup(setup, "Foo.Bar", "Foo.Bar") + assertCheckExprTypeWithSetup(setup, "Foo.Baz", "Foo.Baz") + } + + it("should infer enums with args") { + val setup = "enum Foo {\n" + + " case Bar(x: int)\n" + + " case Baz(y: string)\n" + + "}" + assertInferExprTypeWithSetup(setup, "Foo.Bar(12)", "Foo.Bar") + assertInferExprTypeWithSetup(setup, "Foo.Baz(\"taco\")", "Foo.Baz") + assertInferExprTypeWithSetup(setup, "new Foo.Bar(12)", "Foo.Bar") + assertInferExprTypeWithSetup(setup, "new Foo.Baz(\"taco\")", "Foo.Baz") } - it("should handle enums with args") { + it("should check enums with args") { val setup = "enum Foo {\n" + " case Bar(x: int)\n" + " case Baz(y: string)\n" + "}" - assertExprTypeWithSetup(setup, "Foo.Bar(12)", "Foo.Bar") - assertExprTypeWithSetup(setup, "Foo.Baz(\"taco\")", "Foo.Baz") - assertExprTypeWithSetup(setup, "new Foo.Bar(12)", "Foo.Bar") - assertExprTypeWithSetup(setup, "new Foo.Baz(\"taco\")", "Foo.Baz") + assertCheckExprTypeWithSetup(setup, "Foo.Bar(12)", "Foo.Bar") + assertCheckExprTypeWithSetup(setup, "Foo.Baz(\"taco\")", "Foo.Baz") + assertCheckExprTypeWithSetup(setup, "new Foo.Bar(12)", "Foo.Bar") + assertCheckExprTypeWithSetup(setup, "new Foo.Baz(\"taco\")", "Foo.Baz") } it("should check enum assignments") { @@ -314,7 +692,7 @@ class TypeTests extends AnyFunSpec with Matchers { assertAssignableToWithSetup(setup, "Foo.Baz(\"taco\")", "Foo") } - it("should handle enums with generic type") { + it("should infer enums with generic type") { val setup = "enum Option[T] {\n" + " case Some(value: T)\n" + " case None\n" + @@ -330,7 +708,7 @@ class TypeTests extends AnyFunSpec with Matchers { assertAssignableToWithSetup(setup, "Option.None", "Option[never]") } - it("should handle list examples") { + it("should infer list examples") { val setup = "enum List[T] {\n" + " case Cons(head: T, tail: List[T])\n" + " case Nil\n" + @@ -359,265 +737,610 @@ class TypeTests extends AnyFunSpec with Matchers { assertAssignableToWithSetup(setup, "List.Nil", "List[string]") } - it("should handle classes without args") { + it("should infer classes without args") { val setup = "class Foo()" - assertExprTypeWithSetup(setup, "new Foo()", "Foo") + assertInferExprTypeWithSetup(setup, "new Foo()", "Foo") + } + + it("should check classes without args") { + val setup = "class Foo()" + assertCheckExprTypeWithSetup(setup, "new Foo()", "Foo") + } + + it("should infer classes with args") { + val setup = "class Foo(x: int, y: string)" + assertInferExprTypeWithSetup(setup, "new Foo(12, \"taco\")", "Foo") } - it("should handle classes with args") { + it("should check classes with args") { val setup = "class Foo(x: int, y: string)" - assertExprTypeWithSetup(setup, "new Foo(12, \"taco\")", "Foo") + assertCheckExprTypeWithSetup(setup, "new Foo(12, \"taco\")", "Foo") + } + + it("should infer array creation for basic types") { + assertInferExprType("new Array[int](0)", "Array") + assertInferExprType("new Array[bool](5)", "Array") + assertInferExprType("new Array[string](10)", "Array") + assertInferExprType("new Array[char](3)", "Array") + } + + it("should check array creation for basic types") { + assertAssignableTo("new Array[int](0)", "Array[int]") + assertAssignableTo("new Array[bool](5)", "Array[bool]") + assertAssignableTo("new Array[string](10)", "Array[string]") + assertAssignableTo("new Array[char](3)", "Array[char]") + } + + it("should infer array creation with computed size") { + assertInferExprTypeWithSetup( + "val n = 10", + "new Array[int](n)", + "Array" + ) + assertInferExprTypeWithSetup( + "val x = 5", + "new Array[int](x * 2)", + "Array" + ) + assertInferExprTypeWithSetup("", "new Array[int](5 + 3)", "Array") + } + + it("should check array creation with computed size") { + assertAssignableToWithSetup( + "val n = 10", + "new Array[int](n)", + "Array[int]" + ) + assertAssignableToWithSetup( + "val x = 5", + "new Array[int](x * 2)", + "Array[int]" + ) + assertAssignableToWithSetup("", "new Array[int](5 + 3)", "Array[int]") + } + + it("should infer nested array types") { + assertInferExprType("new Array[Array[int]](3)", "Array>") + assertInferExprType("new Array[Array[bool]](2)", "Array>") + } + + it("should check nested array types") { + assertAssignableTo("new Array[Array[int]](3)", "Array[Array[int]]") + assertAssignableTo("new Array[Array[bool]](2)", "Array[Array[bool]]") + } + + it("should infer array creation with class types") { + val setup = "class Foo(x: int)" + assertInferExprTypeWithSetup(setup, "new Array[Foo](5)", "Array") + } + + it("should check array creation with class types") { + val setup = "class Foo(x: int)" + assertAssignableToWithSetup(setup, "new Array[Foo](5)", "Array[Foo]") + } + + it("should infer array creation in variable declarations") { + assertInferExprType("new Array[int](0)", "Array") + assertInferExprTypeWithSetup( + "val arr = new Array[int](5)", + "arr", + "Array" + ) + } + + it("should check array creation subsumes to expected type") { + // Check mode should allow array creation to subsume to the expected type + assertAssignableTo("new Array[int](5)", "Array[int]") + assertAssignableToWithSetup("", "new Array[string](3)", "Array[string]") + } + + it("should infer array length type") { + val setup = "val array = new Array[int](0)" + assertInferExprTypeWithSetup(setup, "array.length", "int") } - it("should handle array length type") { + it("should check array length type") { val setup = "val array = new Array[int](0)" - assertExprTypeWithSetup(setup, "array.length", "int") + assertCheckExprTypeWithSetup(setup, "array.length", "int") + } + + it("should infer array apply with method call") { + val setup = "val array = new Array[int](1)" + assertInferExprTypeWithSetup(setup, "array.apply(0)", "int") } - it("should handle array apply with method call") { + it("should check array apply with method call") { val setup = "val array = new Array[int](1)" - assertExprTypeWithSetup(setup, "array.apply(0)", "int") + assertCheckExprTypeWithSetup(setup, "array.apply(0)", "int") } - it("should handle array apply with indexer syntax for basic types") { + it("should infer array apply with indexer syntax for basic types") { val setup = "val intArray = new Array[int](1)" - assertExprTypeWithSetup(setup, "intArray(0)", "int") + assertInferExprTypeWithSetup(setup, "intArray(0)", "int") val boolSetup = "val boolArray = new Array[bool](1)" - assertExprTypeWithSetup(boolSetup, "boolArray(0)", "bool") + assertInferExprTypeWithSetup(boolSetup, "boolArray(0)", "bool") val stringSetup = "val stringArray = new Array[string](1)" - assertExprTypeWithSetup(stringSetup, "stringArray(0)", "string") + assertInferExprTypeWithSetup(stringSetup, "stringArray(0)", "string") val charSetup = "val charArray = new Array[char](1)" - assertExprTypeWithSetup(charSetup, "charArray(0)", "char") + assertInferExprTypeWithSetup(charSetup, "charArray(0)", "char") } - it("should handle array indexing in expressions") { + it("should check array apply with indexer syntax for basic types") { + val setup = "val intArray = new Array[int](1)" + assertCheckExprTypeWithSetup(setup, "intArray(0)", "int") + + val boolSetup = "val boolArray = new Array[bool](1)" + assertCheckExprTypeWithSetup(boolSetup, "boolArray(0)", "bool") + + val stringSetup = "val stringArray = new Array[string](1)" + assertCheckExprTypeWithSetup(stringSetup, "stringArray(0)", "string") + + val charSetup = "val charArray = new Array[char](1)" + assertCheckExprTypeWithSetup(charSetup, "charArray(0)", "char") + } + + it("should infer array indexing in expressions") { + val setup = "val array = new Array[int](5)" + assertInferExprTypeWithSetup(setup, "array(0) + array(1)", "int") + assertInferExprTypeWithSetup(setup, "array(2) * 3", "int") + assertInferExprTypeWithSetup(setup, "array(0) == array(1)", "bool") + } + + it("should check array indexing in expressions") { val setup = "val array = new Array[int](5)" - assertExprTypeWithSetup(setup, "array(0) + array(1)", "int") - assertExprTypeWithSetup(setup, "array(2) * 3", "int") - assertExprTypeWithSetup(setup, "array(0) == array(1)", "bool") + assertCheckExprTypeWithSetup(setup, "array(0) + array(1)", "int") + assertCheckExprTypeWithSetup(setup, "array(2) * 3", "int") + assertCheckExprTypeWithSetup(setup, "array(0) == array(1)", "bool") + } + + it("should infer array indexing with computed indices") { + val setup = "val array = new Array[int](10)\nval i = 5" + assertInferExprTypeWithSetup(setup, "array(i)", "int") + assertInferExprTypeWithSetup(setup, "array(1 + 2)", "int") + assertInferExprTypeWithSetup(setup, "array(i * 2)", "int") } - it("should handle array indexing with computed indices") { + it("should check array indexing with computed indices") { val setup = "val array = new Array[int](10)\nval i = 5" - assertExprTypeWithSetup(setup, "array(i)", "int") - assertExprTypeWithSetup(setup, "array(1 + 2)", "int") - assertExprTypeWithSetup(setup, "array(i * 2)", "int") + assertCheckExprTypeWithSetup(setup, "array(i)", "int") + assertCheckExprTypeWithSetup(setup, "array(1 + 2)", "int") + assertCheckExprTypeWithSetup(setup, "array(i * 2)", "int") } - it("should handle array indexing assignment - LHS binding fix") { + it("should infer array indexing assignment - LHS binding fix") { // This tests the specific fix for "NewExpression in bindLHS" error val setup = "var array = new Array[int](5)" - assertExprTypeWithSetup(setup, "array(0) = 42", "unit") - assertExprTypeWithSetup(setup, "array(1) = array(0) + 1", "unit") + assertInferExprTypeWithSetup(setup, "array(0) = 42", "unit") + assertInferExprTypeWithSetup(setup, "array(1) = array(0) + 1", "unit") } - it("should handle primitive casts") { + it("should check array indexing assignment - LHS binding fix") { + val setup = "var array = new Array[int](5)" + assertCheckExprTypeWithSetup(setup, "array(0) = 42", "unit") + assertCheckExprTypeWithSetup(setup, "array(1) = array(0) + 1", "unit") + } + + it("should infer primitive casts") { // Basic numeric casting - assertExprTypeTest("42 as int", "int") - assertExprTypeTest("42 as bool", "bool") - assertExprTypeTest("42 as char", "char") - assertExprTypeTest("42 as string", "string") + assertInferExprType("42 as int", "int") + assertInferExprType("42 as bool", "bool") + assertInferExprType("42 as char", "char") + assertInferExprType("42 as string", "string") // Boolean casting - assertExprTypeTest("true as int", "int") - assertExprTypeTest("true as bool", "bool") - assertExprTypeTest("false as string", "string") + assertInferExprType("true as int", "int") + assertInferExprType("true as bool", "bool") + assertInferExprType("false as string", "string") // Character casting - assertExprTypeTest("'a' as int", "int") - assertExprTypeTest("'a' as char", "char") - assertExprTypeTest("'a' as string", "string") + assertInferExprType("'a' as int", "int") + assertInferExprType("'a' as char", "char") + assertInferExprType("'a' as string", "string") // String casting - assertExprTypeTest("\"hello\" as string", "string") - assertExprTypeTest("\"hello\" as any", "any") + assertInferExprType("\"hello\" as string", "string") + assertInferExprType("\"hello\" as any", "any") } - it("should handle variable casts") { + it("should check primitive casts") { + // Basic numeric casting + assertCheckExprType("42 as int", "int") + assertCheckExprType("42 as bool", "bool") + assertCheckExprType("42 as char", "char") + assertCheckExprType("42 as string", "string") + + // Boolean casting + assertCheckExprType("true as int", "int") + assertCheckExprType("true as bool", "bool") + assertCheckExprType("false as string", "string") + + // Character casting + assertCheckExprType("'a' as int", "int") + assertCheckExprType("'a' as char", "char") + assertCheckExprType("'a' as string", "string") + + // String casting + assertCheckExprType("\"hello\" as string", "string") + assertCheckExprType("\"hello\" as any", "any") + } + + it("should infer variable casts") { val setup = "val x = 42\nval flag = true\nval ch = 'a'\nval text = \"hello\"" - assertExprTypeWithSetup(setup, "x as bool", "bool") - assertExprTypeWithSetup(setup, "x as char", "char") - assertExprTypeWithSetup(setup, "x as string", "string") + assertInferExprTypeWithSetup(setup, "x as bool", "bool") + assertInferExprTypeWithSetup(setup, "x as char", "char") + assertInferExprTypeWithSetup(setup, "x as string", "string") + + assertInferExprTypeWithSetup(setup, "flag as int", "int") + assertInferExprTypeWithSetup(setup, "flag as string", "string") - assertExprTypeWithSetup(setup, "flag as int", "int") - assertExprTypeWithSetup(setup, "flag as string", "string") + assertInferExprTypeWithSetup(setup, "ch as int", "int") + assertInferExprTypeWithSetup(setup, "ch as string", "string") + + assertInferExprTypeWithSetup(setup, "text as any", "any") + } - assertExprTypeWithSetup(setup, "ch as int", "int") - assertExprTypeWithSetup(setup, "ch as string", "string") + it("should check variable casts") { + val setup = + "val x = 42\nval flag = true\nval ch = 'a'\nval text = \"hello\"" + + assertCheckExprTypeWithSetup(setup, "x as bool", "bool") + assertCheckExprTypeWithSetup(setup, "x as char", "char") + assertCheckExprTypeWithSetup(setup, "x as string", "string") + + assertCheckExprTypeWithSetup(setup, "flag as int", "int") + assertCheckExprTypeWithSetup(setup, "flag as string", "string") + + assertCheckExprTypeWithSetup(setup, "ch as int", "int") + assertCheckExprTypeWithSetup(setup, "ch as string", "string") + + assertCheckExprTypeWithSetup(setup, "text as any", "any") + } + + it("should infer cast to any type") { + assertInferExprType("42 as any", "any") + assertInferExprType("true as any", "any") + assertInferExprType("\"hello\" as any", "any") + assertInferExprType("'a' as any", "any") + assertInferExprType("() as any", "any") + } - assertExprTypeWithSetup(setup, "text as any", "any") + it("should check cast to any type") { + assertCheckExprType("42 as any", "any") + assertCheckExprType("true as any", "any") + assertCheckExprType("\"hello\" as any", "any") + assertCheckExprType("'a' as any", "any") + assertCheckExprType("() as any", "any") } - it("should handle cast to any type") { - assertExprTypeTest("42 as any", "any") - assertExprTypeTest("true as any", "any") - assertExprTypeTest("\"hello\" as any", "any") - assertExprTypeTest("'a' as any", "any") - assertExprTypeTest("() as any", "any") + it("should infer cast from any type") { + val setup = "val obj: any = 42" + + assertInferExprTypeWithSetup(setup, "obj as int", "int") + assertInferExprTypeWithSetup(setup, "obj as bool", "bool") + assertInferExprTypeWithSetup(setup, "obj as char", "char") + assertInferExprTypeWithSetup(setup, "obj as string", "string") + assertInferExprTypeWithSetup(setup, "obj as unit", "unit") } - it("should handle cast from any type") { + it("should check cast from any type") { val setup = "val obj: any = 42" - assertExprTypeWithSetup(setup, "obj as int", "int") - assertExprTypeWithSetup(setup, "obj as bool", "bool") - assertExprTypeWithSetup(setup, "obj as char", "char") - assertExprTypeWithSetup(setup, "obj as string", "string") - assertExprTypeWithSetup(setup, "obj as unit", "unit") + assertCheckExprTypeWithSetup(setup, "obj as int", "int") + assertCheckExprTypeWithSetup(setup, "obj as bool", "bool") + assertCheckExprTypeWithSetup(setup, "obj as char", "char") + assertCheckExprTypeWithSetup(setup, "obj as string", "string") + assertCheckExprTypeWithSetup(setup, "obj as unit", "unit") + } + + it("should infer expression casts") { + // Cast results of binary operations + assertInferExprType("(1 + 2) as bool", "bool") + assertInferExprType("(true && false) as int", "int") + assertInferExprType("(1 == 2) as string", "string") + + // Cast results of unary operations + assertInferExprType("(!true) as int", "int") + assertInferExprType("(-42) as bool", "bool") } - it("should handle expression casts") { + it("should check expression casts") { // Cast results of binary operations - assertExprTypeTest("(1 + 2) as bool", "bool") - assertExprTypeTest("(true && false) as int", "int") - assertExprTypeTest("(1 == 2) as string", "string") + assertCheckExprType("(1 + 2) as bool", "bool") + assertCheckExprType("(true && false) as int", "int") + assertCheckExprType("(1 == 2) as string", "string") // Cast results of unary operations - assertExprTypeTest("(!true) as int", "int") - assertExprTypeTest("(-42) as bool", "bool") + assertCheckExprType("(!true) as int", "int") + assertCheckExprType("(-42) as bool", "bool") + } + + it("should infer cast with parentheses") { + assertInferExprType("(42) as string", "string") + assertInferExprType("(true) as int", "int") + assertInferExprType("(\"hello\") as any", "any") + } + + it("should check cast with parentheses") { + assertCheckExprType("(42) as string", "string") + assertCheckExprType("(true) as int", "int") + assertCheckExprType("(\"hello\") as any", "any") } - it("should handle cast with parentheses") { - assertExprTypeTest("(42) as string", "string") - assertExprTypeTest("(true) as int", "int") - assertExprTypeTest("(\"hello\") as any", "any") + it("should infer chained operations with casts") { + val setup = "val x = 42" + + // Cast should have appropriate precedence + assertInferExprTypeWithSetup(setup, "x as bool == true", "bool") + assertInferExprTypeWithSetup(setup, "(x as bool) == true", "bool") } - it("should handle chained operations with casts") { + it("should check chained operations with casts") { val setup = "val x = 42" // Cast should have appropriate precedence - assertExprTypeWithSetup(setup, "x as bool == true", "bool") - assertExprTypeWithSetup(setup, "(x as bool) == true", "bool") + assertCheckExprTypeWithSetup(setup, "x as bool == true", "bool") + assertCheckExprTypeWithSetup(setup, "(x as bool) == true", "bool") + } + + it("should infer basic string conversions") { + // Convert literals to string + assertInferExprType("string(42)", "string") + assertInferExprType("string(0)", "string") + assertInferExprType("string(-123)", "string") + assertInferExprType("string(true)", "string") + assertInferExprType("string(false)", "string") + assertInferExprType("string('a')", "string") + assertInferExprType("string('z')", "string") + assertInferExprType("string(())", "string") } - it("should handle basic string conversions") { + it("should check basic string conversions") { // Convert literals to string - assertExprTypeTest("string(42)", "string") - assertExprTypeTest("string(0)", "string") - assertExprTypeTest("string(-123)", "string") - assertExprTypeTest("string(true)", "string") - assertExprTypeTest("string(false)", "string") - assertExprTypeTest("string('a')", "string") - assertExprTypeTest("string('z')", "string") - assertExprTypeTest("string(())", "string") + assertCheckExprType("string(42)", "string") + assertCheckExprType("string(0)", "string") + assertCheckExprType("string(-123)", "string") + assertCheckExprType("string(true)", "string") + assertCheckExprType("string(false)", "string") + assertCheckExprType("string('a')", "string") + assertCheckExprType("string('z')", "string") + assertCheckExprType("string(())", "string") + } + + it("should infer string conversions with variables") { + val intSetup = "val num = 123" + assertInferExprTypeWithSetup(intSetup, "string(num)", "string") + + val boolSetup = "val flag = true" + assertInferExprTypeWithSetup(boolSetup, "string(flag)", "string") + + val charSetup = "val ch = 'A'" + assertInferExprTypeWithSetup(charSetup, "string(ch)", "string") + + val unitSetup = "val nothing = ()" + assertInferExprTypeWithSetup(unitSetup, "string(nothing)", "string") } - it("should handle string conversions with variables") { + it("should check string conversions with variables") { val intSetup = "val num = 123" - assertExprTypeWithSetup(intSetup, "string(num)", "string") + assertCheckExprTypeWithSetup(intSetup, "string(num)", "string") val boolSetup = "val flag = true" - assertExprTypeWithSetup(boolSetup, "string(flag)", "string") + assertCheckExprTypeWithSetup(boolSetup, "string(flag)", "string") val charSetup = "val ch = 'A'" - assertExprTypeWithSetup(charSetup, "string(ch)", "string") + assertCheckExprTypeWithSetup(charSetup, "string(ch)", "string") val unitSetup = "val nothing = ()" - assertExprTypeWithSetup(unitSetup, "string(nothing)", "string") + assertCheckExprTypeWithSetup(unitSetup, "string(nothing)", "string") + } + + it("should infer string conversions with expressions") { + // Arithmetic expressions + assertInferExprType("string(1 + 2)", "string") + assertInferExprType("string(10 - 5)", "string") + assertInferExprType("string(3 * 4)", "string") + assertInferExprType("string(15 / 3)", "string") + assertInferExprType("string(17 % 5)", "string") + + // Boolean expressions + assertInferExprType("string(true && false)", "string") + assertInferExprType("string(true || false)", "string") + assertInferExprType("string(!true)", "string") + assertInferExprType("string(5 > 3)", "string") + assertInferExprType("string(5 == 5)", "string") + assertInferExprType("string(5 != 3)", "string") + + // Unary expressions + assertInferExprType("string(-42)", "string") + assertInferExprType("string(+42)", "string") } - it("should handle string conversions with expressions") { + it("should check string conversions with expressions") { // Arithmetic expressions - assertExprTypeTest("string(1 + 2)", "string") - assertExprTypeTest("string(10 - 5)", "string") - assertExprTypeTest("string(3 * 4)", "string") - assertExprTypeTest("string(15 / 3)", "string") - assertExprTypeTest("string(17 % 5)", "string") + assertCheckExprType("string(1 + 2)", "string") + assertCheckExprType("string(10 - 5)", "string") + assertCheckExprType("string(3 * 4)", "string") + assertCheckExprType("string(15 / 3)", "string") + assertCheckExprType("string(17 % 5)", "string") // Boolean expressions - assertExprTypeTest("string(true && false)", "string") - assertExprTypeTest("string(true || false)", "string") - assertExprTypeTest("string(!true)", "string") - assertExprTypeTest("string(5 > 3)", "string") - assertExprTypeTest("string(5 == 5)", "string") - assertExprTypeTest("string(5 != 3)", "string") + assertCheckExprType("string(true && false)", "string") + assertCheckExprType("string(true || false)", "string") + assertCheckExprType("string(!true)", "string") + assertCheckExprType("string(5 > 3)", "string") + assertCheckExprType("string(5 == 5)", "string") + assertCheckExprType("string(5 != 3)", "string") // Unary expressions - assertExprTypeTest("string(-42)", "string") - assertExprTypeTest("string(+42)", "string") + assertCheckExprType("string(-42)", "string") + assertCheckExprType("string(+42)", "string") + } + + it("should infer string conversions in complex expressions") { + val setup = "val x = 42\nval y = true" + + // String conversions in comparisons + assertInferExprTypeWithSetup(setup, "string(x) == \"42\"", "bool") + assertInferExprTypeWithSetup(setup, "string(y) == \"true\"", "bool") + assertInferExprTypeWithSetup(setup, "string(x) != \"0\"", "bool") + + // String conversions in arithmetic context + assertInferExprTypeWithSetup(setup, "string(x + 10)", "string") + assertInferExprTypeWithSetup(setup, "string(x * 2)", "string") } - it("should handle string conversions in complex expressions") { + it("should check string conversions in complex expressions") { val setup = "val x = 42\nval y = true" // String conversions in comparisons - assertExprTypeWithSetup(setup, "string(x) == \"42\"", "bool") - assertExprTypeWithSetup(setup, "string(y) == \"true\"", "bool") - assertExprTypeWithSetup(setup, "string(x) != \"0\"", "bool") + assertCheckExprTypeWithSetup(setup, "string(x) == \"42\"", "bool") + assertCheckExprTypeWithSetup(setup, "string(y) == \"true\"", "bool") + assertCheckExprTypeWithSetup(setup, "string(x) != \"0\"", "bool") // String conversions in arithmetic context - assertExprTypeWithSetup(setup, "string(x + 10)", "string") - assertExprTypeWithSetup(setup, "string(x * 2)", "string") + assertCheckExprTypeWithSetup(setup, "string(x + 10)", "string") + assertCheckExprTypeWithSetup(setup, "string(x * 2)", "string") } - it("should handle string conversions with method calls") { + it("should infer string conversions with method calls") { // Convert results of method calls to string - assertExprTypeTest("string(println(\"test\"))", "string") - assertExprTypeTest("string(print(42))", "string") + assertInferExprType("string(println(\"test\"))", "string") + assertInferExprType("string(print(42))", "string") } - it("should handle string conversions with control flow") { + it("should check string conversions with method calls") { + // Convert results of method calls to string + assertCheckExprType("string(println(\"test\"))", "string") + assertCheckExprType("string(print(42))", "string") + } + + it("should infer string conversions with control flow") { + // String conversions with if expressions + assertInferExprType("string(if (true) 1 else 2)", "string") + assertInferExprType("string(if (false) true else false)", "string") + + // String conversions with block expressions + assertInferExprType("string({ 42 })", "string") + assertInferExprType("string({ true })", "string") + } + + it("should check string conversions with control flow") { // String conversions with if expressions - assertExprTypeTest("string(if (true) 1 else 2)", "string") - assertExprTypeTest("string(if (false) true else false)", "string") + assertCheckExprType("string(if (true) 1 else 2)", "string") + assertCheckExprType("string(if (false) true else false)", "string") // String conversions with block expressions - assertExprTypeTest("string({ 42 })", "string") - assertExprTypeTest("string({ true })", "string") + assertCheckExprType("string({ 42 })", "string") + assertCheckExprType("string({ true })", "string") } - it("should handle nested string conversions") { + it("should infer nested string conversions") { val setup = "val x = 42" // String conversion of string (should still work) - assertExprTypeTest("string(\"hello\")", "string") - assertExprTypeWithSetup(setup, "string(string(x))", "string") + assertInferExprType("string(\"hello\")", "string") + assertInferExprTypeWithSetup(setup, "string(string(x))", "string") // String conversions in nested expressions - assertExprTypeWithSetup(setup, "string(string(x) == \"42\")", "string") + assertInferExprTypeWithSetup( + setup, + "string(string(x) == \"42\")", + "string" + ) } - it("should handle int conversions") { + it("should check nested string conversions") { + val setup = "val x = 42" + + // String conversion of string (should still work) + assertCheckExprType("string(\"hello\")", "string") + assertCheckExprTypeWithSetup(setup, "string(string(x))", "string") + + // String conversions in nested expressions + assertCheckExprTypeWithSetup( + setup, + "string(string(x) == \"42\")", + "string" + ) + } + + it("should infer int conversions") { + // Convert various types to int + assertInferExprType("int(42)", "int") + assertInferExprType("int(true)", "int") + assertInferExprType("int(false)", "int") + assertInferExprType("int('a')", "int") + + // Int conversions with variables and expressions + val setup = "val flag = true\nval ch = 'A'" + assertInferExprTypeWithSetup(setup, "int(flag)", "int") + assertInferExprTypeWithSetup(setup, "int(ch)", "int") + assertInferExprType("int(1 + 2)", "int") + } + + it("should check int conversions") { // Convert various types to int - assertExprTypeTest("int(42)", "int") - assertExprTypeTest("int(true)", "int") - assertExprTypeTest("int(false)", "int") - assertExprTypeTest("int('a')", "int") + assertCheckExprType("int(42)", "int") + assertCheckExprType("int(true)", "int") + assertCheckExprType("int(false)", "int") + assertCheckExprType("int('a')", "int") // Int conversions with variables and expressions val setup = "val flag = true\nval ch = 'A'" - assertExprTypeWithSetup(setup, "int(flag)", "int") - assertExprTypeWithSetup(setup, "int(ch)", "int") - assertExprTypeTest("int(1 + 2)", "int") + assertCheckExprTypeWithSetup(setup, "int(flag)", "int") + assertCheckExprTypeWithSetup(setup, "int(ch)", "int") + assertCheckExprType("int(1 + 2)", "int") } - it("should handle bool conversions") { + it("should infer bool conversions") { // Convert various types to bool - assertExprTypeTest("bool(true)", "bool") - assertExprTypeTest("bool(false)", "bool") - assertExprTypeTest("bool(42)", "bool") - assertExprTypeTest("bool(0)", "bool") + assertInferExprType("bool(true)", "bool") + assertInferExprType("bool(false)", "bool") + assertInferExprType("bool(42)", "bool") + assertInferExprType("bool(0)", "bool") // Bool conversions with variables and expressions val setup = "val num = 123" - assertExprTypeWithSetup(setup, "bool(num)", "bool") - assertExprTypeTest("bool(5 > 3)", "bool") + assertInferExprTypeWithSetup(setup, "bool(num)", "bool") + assertInferExprType("bool(5 > 3)", "bool") + } + + it("should check bool conversions") { + // Convert various types to bool + assertCheckExprType("bool(true)", "bool") + assertCheckExprType("bool(false)", "bool") + assertCheckExprType("bool(42)", "bool") + assertCheckExprType("bool(0)", "bool") + + // Bool conversions with variables and expressions + val setup = "val num = 123" + assertCheckExprTypeWithSetup(setup, "bool(num)", "bool") + assertCheckExprType("bool(5 > 3)", "bool") + } + + it("should infer char conversions") { + // Convert various types to char + assertInferExprType("char('a')", "char") + assertInferExprType("char(65)", "char") + + // Char conversions with variables + val setup = "val ascii = 97" + assertInferExprTypeWithSetup(setup, "char(ascii)", "char") } - it("should handle char conversions") { + it("should check char conversions") { // Convert various types to char - assertExprTypeTest("char('a')", "char") - assertExprTypeTest("char(65)", "char") + assertCheckExprType("char('a')", "char") + assertCheckExprType("char(65)", "char") // Char conversions with variables val setup = "val ascii = 97" - assertExprTypeWithSetup(setup, "char(ascii)", "char") + assertCheckExprTypeWithSetup(setup, "char(ascii)", "char") } } }