Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 163 additions & 27 deletions pnc/src/ExprBinder.pn
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,27 @@ class ExprBinder(
)
)
case Option.Some(symbol) =>
Result.Success(symbol)
// First check if this symbol itself is a constructor
if (symbol.kind == SymbolKind.Constructor) {
Result.Success(symbol)
} else {
// Look for the constructor inside the class/enum case
val classScope = Scope(symbol, List.Nil)
classScope.lookup(".ctor") match {
case Option.None =>
diagnosticBag.reportSymbolNotFound(
identifier.location,
".ctor"
)
Result.Error(
BoundExpression.Error(
"Constructor not found for: " + identifier.text
)
)
case Option.Some(ctorSymbol) =>
Result.Success(ctorSymbol)
}
}
}
case NameSyntax.QualifiedName(
left,
Expand All @@ -1241,7 +1261,27 @@ class ExprBinder(
)
)
case Option.Some(symbol) =>
Result.Success(symbol)
// First check if this symbol itself is a constructor
if (symbol.kind == SymbolKind.Constructor) {
Result.Success(symbol)
} else {
// Look for the constructor inside the class/enum case
val classScope = Scope(symbol, List.Nil)
classScope.lookup(".ctor") match {
case Option.None =>
diagnosticBag.reportSymbolNotFound(
identifier.location,
".ctor"
)
Result.Error(
BoundExpression.Error(
"Constructor not found for: " + identifier.text
)
)
case Option.Some(ctorSymbol) =>
Result.Success(ctorSymbol)
}
}
}
case _ =>
diagnosticBag.reportInternalError(
Expand All @@ -1256,6 +1296,39 @@ class ExprBinder(
}
}

def bindPatternWithType(
pattern: PatternSyntax,
scope: Scope,
expectedType: Type
): Result[BoundExpression.Error, BoundPattern] = {
pattern match {
case PatternSyntax.Literal(token) =>
bindLiteralFromSyntaxToken(token) match {
case Result.Error(value) => Result.Error(value)
case Result.Success(literal) =>
Result.Success(
BoundPattern.Literal(
literal
)
)
}
case PatternSyntax.Discard(_) =>
Result.Success(BoundPattern.Discard)
case PatternSyntax.Identifier(identifier) =>
bindIdentifierPattern(scope, identifier, expectedType)
case PatternSyntax.Type(_) =>
// For type patterns, we create a wildcard pattern
Result.Success(BoundPattern.Discard)
case PatternSyntax.TypeAssertion(innerPattern, typeAnnotation) =>
// Bind the inner pattern with the annotated type
val annotatedType = binder.bindTypeName(typeAnnotation.typ, scope)
bindPatternWithType(innerPattern, scope, annotatedType)
case PatternSyntax.Extract(constructorName, _, patterns, _) =>
// For nested extract patterns, use the regular bindPattern
bindPattern(pattern, scope)
}
}

def bindPattern(
pattern: PatternSyntax,
scope: Scope
Expand Down Expand Up @@ -1316,37 +1389,100 @@ class ExprBinder(
case Result.Error(error) => Result.Error(error)
case Result.Success(constructor) =>

// TODO: Verify that the constructor is valid
// get its parameter types and set up each pattern with a TypeAssertion
// and set the types for all the pattern variables

// Bind each pattern parameter
val boundPatterns = new Array[BoundPattern](patterns.length)
var i = 0
var hasError = false
var errorResult: Option[BoundExpression.Error] = Option.None

while (i < patterns.length && !hasError) {
bindPattern(patterns(i).pattern, scope) match {
case Result.Error(error) =>
hasError = true
errorResult = Option.Some(error)
case Result.Success(pattern) =>
boundPatterns(i) = pattern
i = i + 1
}
}
// Get constructor parameter types for type assertions
getFunctionParameterTypes(constructor) match {
case Either.Left(error) =>
diagnosticBag.reportNotCallable(
AstUtils.locationOfName(constructorName)
)
Result.Error(
BoundExpression.Error(
"Constructor is not callable: " + constructor.name
)
)
case Either.Right(parameterTypes) =>

errorResult match {
case Option.None =>
Result.Success(BoundPattern.Extract(constructor, boundPatterns))
case Option.Some(value) =>
Result.Error(value)
// Verify parameter count matches
if (patterns.length != parameterTypes.length) {
diagnosticBag.reportInternalError(
AstUtils.locationOfName(constructorName),
"Pattern parameter count mismatch"
)
Result.Error(
BoundExpression.Error("Parameter count mismatch")
)
} else {
// Bind each pattern parameter with its expected type
val boundPatterns = new Array[BoundPattern](patterns.length)

var i = 0
var hasError = false
var errorResult: Option[BoundExpression.Error] = Option.None

while (i < patterns.length && !hasError) {
bindPatternWithType(
patterns(i).pattern,
scope,
parameterTypes(i)
) match {
case Result.Error(error) =>
hasError = true
errorResult = Option.Some(error)
case Result.Success(pattern) =>
boundPatterns(i) = pattern
i = i + 1
}
}

errorResult match {
case Option.None =>
Result.Success(
BoundPattern.Extract(constructor, boundPatterns)
)
case Option.Some(value) =>
Result.Error(value)
}
}
}

}
}
}

def getFunctionParameterTypes(
symbol: Symbol
): Either[Type.Error, Array[Type]] = {
binder.getSymbolType(symbol) match {
case Type.Error(message) => Either.Left(Type.Error(message))
case f: Type.Function =>
val paramTypes = getParameterTypes(f.parameters)
val result = new Array[Type](paramTypes.length)
fillParameterTypes(result, 0, paramTypes)
Either.Right(result)
case gf: Type.GenericFunction =>
val paramTypes = getParameterTypes(gf.parameters)
val result = new Array[Type](paramTypes.length)
fillParameterTypes(result, 0, paramTypes)
Either.Right(result)
case x =>
diagnosticBag.reportNotCallable(symbol.location)
Either.Left(Type.Error("Symbol is not a function: " + symbol.name))
}
}

def fillParameterTypes(
array: Array[Type],
index: int,
list: List[Type]
): unit = {
list match {
case List.Nil => ()
case List.Cons(head, tail) =>
array(index) = head
fillParameterTypes(array, index + 1, tail)
}
}

def bindIdentifierPattern(
scope: Scope,
identifier: SyntaxToken,
Expand Down
Loading