diff --git a/pnc/src/ExprBinder.pn b/pnc/src/ExprBinder.pn index 8d8dbd1..a33258a 100644 --- a/pnc/src/ExprBinder.pn +++ b/pnc/src/ExprBinder.pn @@ -1220,7 +1220,27 @@ class ExprBinder( ) ) case Option.Some(symbol) => - Result.Success(symbol) + // First check if this symbol itself is a constructor + if (symbol.kind == SymbolKind.Constructor) { + Result.Success(symbol) + } else { + // Look for the constructor inside the class/enum case + val classScope = Scope(symbol, List.Nil) + classScope.lookup(".ctor") match { + case Option.None => + diagnosticBag.reportSymbolNotFound( + identifier.location, + ".ctor" + ) + Result.Error( + BoundExpression.Error( + "Constructor not found for: " + identifier.text + ) + ) + case Option.Some(ctorSymbol) => + Result.Success(ctorSymbol) + } + } } case NameSyntax.QualifiedName( left, @@ -1241,7 +1261,27 @@ class ExprBinder( ) ) case Option.Some(symbol) => - Result.Success(symbol) + // First check if this symbol itself is a constructor + if (symbol.kind == SymbolKind.Constructor) { + Result.Success(symbol) + } else { + // Look for the constructor inside the class/enum case + val classScope = Scope(symbol, List.Nil) + classScope.lookup(".ctor") match { + case Option.None => + diagnosticBag.reportSymbolNotFound( + identifier.location, + ".ctor" + ) + Result.Error( + BoundExpression.Error( + "Constructor not found for: " + identifier.text + ) + ) + case Option.Some(ctorSymbol) => + Result.Success(ctorSymbol) + } + } } case _ => diagnosticBag.reportInternalError( @@ -1256,6 +1296,39 @@ class ExprBinder( } } + def bindPatternWithType( + pattern: PatternSyntax, + scope: Scope, + expectedType: Type + ): Result[BoundExpression.Error, BoundPattern] = { + pattern match { + case PatternSyntax.Literal(token) => + bindLiteralFromSyntaxToken(token) match { + case Result.Error(value) => Result.Error(value) + case Result.Success(literal) => + Result.Success( + BoundPattern.Literal( + literal + ) + ) + } + case PatternSyntax.Discard(_) => + Result.Success(BoundPattern.Discard) + case PatternSyntax.Identifier(identifier) => + bindIdentifierPattern(scope, identifier, expectedType) + case PatternSyntax.Type(_) => + // For type patterns, we create a wildcard pattern + Result.Success(BoundPattern.Discard) + case PatternSyntax.TypeAssertion(innerPattern, typeAnnotation) => + // Bind the inner pattern with the annotated type + val annotatedType = binder.bindTypeName(typeAnnotation.typ, scope) + bindPatternWithType(innerPattern, scope, annotatedType) + case PatternSyntax.Extract(constructorName, _, patterns, _) => + // For nested extract patterns, use the regular bindPattern + bindPattern(pattern, scope) + } + } + def bindPattern( pattern: PatternSyntax, scope: Scope @@ -1316,37 +1389,100 @@ class ExprBinder( case Result.Error(error) => Result.Error(error) case Result.Success(constructor) => - // TODO: Verify that the constructor is valid - // get its parameter types and set up each pattern with a TypeAssertion - // and set the types for all the pattern variables - - // Bind each pattern parameter - val boundPatterns = new Array[BoundPattern](patterns.length) - var i = 0 - var hasError = false - var errorResult: Option[BoundExpression.Error] = Option.None - - while (i < patterns.length && !hasError) { - bindPattern(patterns(i).pattern, scope) match { - case Result.Error(error) => - hasError = true - errorResult = Option.Some(error) - case Result.Success(pattern) => - boundPatterns(i) = pattern - i = i + 1 - } - } + // Get constructor parameter types for type assertions + getFunctionParameterTypes(constructor) match { + case Either.Left(error) => + diagnosticBag.reportNotCallable( + AstUtils.locationOfName(constructorName) + ) + Result.Error( + BoundExpression.Error( + "Constructor is not callable: " + constructor.name + ) + ) + case Either.Right(parameterTypes) => - errorResult match { - case Option.None => - Result.Success(BoundPattern.Extract(constructor, boundPatterns)) - case Option.Some(value) => - Result.Error(value) + // Verify parameter count matches + if (patterns.length != parameterTypes.length) { + diagnosticBag.reportInternalError( + AstUtils.locationOfName(constructorName), + "Pattern parameter count mismatch" + ) + Result.Error( + BoundExpression.Error("Parameter count mismatch") + ) + } else { + // Bind each pattern parameter with its expected type + val boundPatterns = new Array[BoundPattern](patterns.length) + + var i = 0 + var hasError = false + var errorResult: Option[BoundExpression.Error] = Option.None + + while (i < patterns.length && !hasError) { + bindPatternWithType( + patterns(i).pattern, + scope, + parameterTypes(i) + ) match { + case Result.Error(error) => + hasError = true + errorResult = Option.Some(error) + case Result.Success(pattern) => + boundPatterns(i) = pattern + i = i + 1 + } + } + + errorResult match { + case Option.None => + Result.Success( + BoundPattern.Extract(constructor, boundPatterns) + ) + case Option.Some(value) => + Result.Error(value) + } + } } + } } } + def getFunctionParameterTypes( + symbol: Symbol + ): Either[Type.Error, Array[Type]] = { + binder.getSymbolType(symbol) match { + case Type.Error(message) => Either.Left(Type.Error(message)) + case f: Type.Function => + val paramTypes = getParameterTypes(f.parameters) + val result = new Array[Type](paramTypes.length) + fillParameterTypes(result, 0, paramTypes) + Either.Right(result) + case gf: Type.GenericFunction => + val paramTypes = getParameterTypes(gf.parameters) + val result = new Array[Type](paramTypes.length) + fillParameterTypes(result, 0, paramTypes) + Either.Right(result) + case x => + diagnosticBag.reportNotCallable(symbol.location) + Either.Left(Type.Error("Symbol is not a function: " + symbol.name)) + } + } + + def fillParameterTypes( + array: Array[Type], + index: int, + list: List[Type] + ): unit = { + list match { + case List.Nil => () + case List.Cons(head, tail) => + array(index) = head + fillParameterTypes(array, index + 1, tail) + } + } + def bindIdentifierPattern( scope: Scope, identifier: SyntaxToken, diff --git a/pncs/src/main/scala/ExprBinder.scala b/pncs/src/main/scala/ExprBinder.scala index 2c7809f..ccd3405 100644 --- a/pncs/src/main/scala/ExprBinder.scala +++ b/pncs/src/main/scala/ExprBinder.scala @@ -1220,7 +1220,27 @@ case class ExprBinder( ) ) case Option.Some(symbol) => - Result.Success(symbol) + // First check if this symbol itself is a constructor + if (symbol.kind == SymbolKind.Constructor) { + Result.Success(symbol) + } else { + // Look for the constructor inside the class/enum case + val classScope = Scope(symbol, List.Nil) + classScope.lookup(".ctor") match { + case Option.None => + diagnosticBag.reportSymbolNotFound( + identifier.location, + ".ctor" + ) + Result.Error( + BoundExpression.Error( + "Constructor not found for: " + identifier.text + ) + ) + case Option.Some(ctorSymbol) => + Result.Success(ctorSymbol) + } + } } case NameSyntax.QualifiedName( left, @@ -1241,7 +1261,27 @@ case class ExprBinder( ) ) case Option.Some(symbol) => - Result.Success(symbol) + // First check if this symbol itself is a constructor + if (symbol.kind == SymbolKind.Constructor) { + Result.Success(symbol) + } else { + // Look for the constructor inside the class/enum case + val classScope = Scope(symbol, List.Nil) + classScope.lookup(".ctor") match { + case Option.None => + diagnosticBag.reportSymbolNotFound( + identifier.location, + ".ctor" + ) + Result.Error( + BoundExpression.Error( + "Constructor not found for: " + identifier.text + ) + ) + case Option.Some(ctorSymbol) => + Result.Success(ctorSymbol) + } + } } case _ => diagnosticBag.reportInternalError( @@ -1256,6 +1296,39 @@ case class ExprBinder( } } + def bindPatternWithType( + pattern: PatternSyntax, + scope: Scope, + expectedType: Type + ): Result[BoundExpression.Error, BoundPattern] = { + pattern match { + case PatternSyntax.Literal(token) => + bindLiteralFromSyntaxToken(token) match { + case Result.Error(value) => Result.Error(value) + case Result.Success(literal) => + Result.Success( + BoundPattern.Literal( + literal + ) + ) + } + case PatternSyntax.Discard(_) => + Result.Success(BoundPattern.Discard) + case PatternSyntax.Identifier(identifier) => + bindIdentifierPattern(scope, identifier, expectedType) + case PatternSyntax.Type(_) => + // For type patterns, we create a wildcard pattern + Result.Success(BoundPattern.Discard) + case PatternSyntax.TypeAssertion(innerPattern, typeAnnotation) => + // Bind the inner pattern with the annotated type + val annotatedType = binder.bindTypeName(typeAnnotation.typ, scope) + bindPatternWithType(innerPattern, scope, annotatedType) + case PatternSyntax.Extract(constructorName, _, patterns, _) => + // For nested extract patterns, use the regular bindPattern + bindPattern(pattern, scope) + } + } + def bindPattern( pattern: PatternSyntax, scope: Scope @@ -1316,37 +1389,100 @@ case class ExprBinder( case Result.Error(error) => Result.Error(error) case Result.Success(constructor) => - // TODO: Verify that the constructor is valid - // get its parameter types and set up each pattern with a TypeAssertion - // and set the types for all the pattern variables - - // Bind each pattern parameter - val boundPatterns = new Array[BoundPattern](patterns.length) - var i = 0 - var hasError = false - var errorResult: Option[BoundExpression.Error] = Option.None - - while (i < patterns.length && !hasError) { - bindPattern(patterns(i).pattern, scope) match { - case Result.Error(error) => - hasError = true - errorResult = Option.Some(error) - case Result.Success(pattern) => - boundPatterns(i) = pattern - i = i + 1 - } - } + // Get constructor parameter types for type assertions + getFunctionParameterTypes(constructor) match { + case Either.Left(error) => + diagnosticBag.reportNotCallable( + AstUtils.locationOfName(constructorName) + ) + Result.Error( + BoundExpression.Error( + "Constructor is not callable: " + constructor.name + ) + ) + case Either.Right(parameterTypes) => - errorResult match { - case Option.None => - Result.Success(BoundPattern.Extract(constructor, boundPatterns)) - case Option.Some(value) => - Result.Error(value) + // Verify parameter count matches + if (patterns.length != parameterTypes.length) { + diagnosticBag.reportInternalError( + AstUtils.locationOfName(constructorName), + "Pattern parameter count mismatch" + ) + Result.Error( + BoundExpression.Error("Parameter count mismatch") + ) + } else { + // Bind each pattern parameter with its expected type + val boundPatterns = new Array[BoundPattern](patterns.length) + + var i = 0 + var hasError = false + var errorResult: Option[BoundExpression.Error] = Option.None + + while (i < patterns.length && !hasError) { + bindPatternWithType( + patterns(i).pattern, + scope, + parameterTypes(i) + ) match { + case Result.Error(error) => + hasError = true + errorResult = Option.Some(error) + case Result.Success(pattern) => + boundPatterns(i) = pattern + i = i + 1 + } + } + + errorResult match { + case Option.None => + Result.Success( + BoundPattern.Extract(constructor, boundPatterns) + ) + case Option.Some(value) => + Result.Error(value) + } + } } + } } } + def getFunctionParameterTypes( + symbol: Symbol + ): Either[Type.Error, Array[Type]] = { + binder.getSymbolType(symbol) match { + case Type.Error(message) => Either.Left(Type.Error(message)) + case f: Type.Function => + val paramTypes = getParameterTypes(f.parameters) + val result = new Array[Type](paramTypes.length) + fillParameterTypes(result, 0, paramTypes) + Either.Right(result) + case gf: Type.GenericFunction => + val paramTypes = getParameterTypes(gf.parameters) + val result = new Array[Type](paramTypes.length) + fillParameterTypes(result, 0, paramTypes) + Either.Right(result) + case x => + diagnosticBag.reportNotCallable(symbol.location) + Either.Left(Type.Error("Symbol is not a function: " + symbol.name)) + } + } + + def fillParameterTypes( + array: Array[Type], + index: int, + list: List[Type] + ): unit = { + list match { + case List.Nil => () + case List.Cons(head, tail) => + array(index) = head + fillParameterTypes(array, index + 1, tail) + } + } + def bindIdentifierPattern( scope: Scope, identifier: SyntaxToken,