Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions pnc/src/ExprBinder.pn
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,14 @@ class ExprBinder(
}
}

def getArgumentTypes(arguments: List[BoundExpression]): List[Type] = {
arguments match {
case List.Nil => List.Nil
case List.Cons(head, tail) =>
List.Cons(binder.getType(head), getArgumentTypes(tail))
}
}

def bindArguments(
parameters: List[BoundParameter],
arguments: List[BoundExpression],
Expand Down Expand Up @@ -2006,9 +2014,39 @@ class ExprBinder(
case Option.Some(
Type.GenericFunction(loc, generics, traits, params, _)
) =>
// Instantiate generic function with inferred type arguments
// For generic constructors, infer type arguments from constructor arguments
val parameterTypes = getParameterTypes(params)
val argumentTypes = getArgumentTypes(args)
val inferredTypeArgs = inferTypeArgumentsFromCall(
generics,
parameterTypes,
argumentTypes
)

// Create the instantiated return type using inferred type arguments
val inferredInstantiationType = instantiationType match {
case Type.Class(clsLoc, clsNs, clsName, _, clsSymbol) =>
// Replace generic type arguments with inferred ones
Type.Class(
clsLoc,
clsNs,
clsName,
inferredTypeArgs,
clsSymbol
)
case other => other
}

// Substitute type variables in parameter types
val substitutedParams =
substituteParameterTypes(params, inferredTypeArgs)

val instantiatedFunction: Type.Function =
Type.Function(loc, params, instantiationType)
Type.Function(
loc,
substitutedParams,
inferredInstantiationType
)
bindNewExpressionForSymbol(
location,
ctor,
Expand Down Expand Up @@ -2048,8 +2086,20 @@ class ExprBinder(
val location = AstUtils.locationOfExpression(node)

// Try to infer type arguments from constructor arguments
val inferredTypeArgs =
inferTypeArgumentsFromConstructor(genericParams, ctor, args)
val inferredTypeArgs = binder.tryGetSymbolType(ctor) match {
case Option.Some(
Type.GenericFunction(_, ctorGenerics, _, ctorParams, _)
) =>
val parameterTypes = getParameterTypes(ctorParams)
val argumentTypes = getArgumentTypes(args)
inferTypeArgumentsFromCall(
ctorGenerics,
parameterTypes,
argumentTypes
)
case _ =>
inferTypeArgumentsFromConstructor(genericParams, ctor, args)
}
val instantiatedType = Type.Class(
location,
ns,
Expand Down
58 changes: 54 additions & 4 deletions pncs/src/main/scala/ExprBinder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,14 @@ case class ExprBinder(
}
}

def getArgumentTypes(arguments: List[BoundExpression]): List[Type] = {
arguments match {
case List.Nil => List.Nil
case List.Cons(head, tail) =>
List.Cons(binder.getType(head), getArgumentTypes(tail))
}
}

def bindArguments(
parameters: List[BoundParameter],
arguments: List[BoundExpression],
Expand Down Expand Up @@ -2006,9 +2014,39 @@ case class ExprBinder(
case Option.Some(
Type.GenericFunction(loc, generics, traits, params, _)
) =>
// Instantiate generic function with inferred type arguments
// For generic constructors, infer type arguments from constructor arguments
val parameterTypes = getParameterTypes(params)
val argumentTypes = getArgumentTypes(args)
val inferredTypeArgs = inferTypeArgumentsFromCall(
generics,
parameterTypes,
argumentTypes
)

// Create the instantiated return type using inferred type arguments
val inferredInstantiationType = instantiationType match {
case Type.Class(clsLoc, clsNs, clsName, _, clsSymbol) =>
// Replace generic type arguments with inferred ones
Type.Class(
clsLoc,
clsNs,
clsName,
inferredTypeArgs,
clsSymbol
)
case other => other
}

// Substitute type variables in parameter types
val substitutedParams =
substituteParameterTypes(params, inferredTypeArgs)

val instantiatedFunction: Type.Function =
Type.Function(loc, params, instantiationType)
Type.Function(
loc,
substitutedParams,
inferredInstantiationType
)
bindNewExpressionForSymbol(
location,
ctor,
Expand Down Expand Up @@ -2048,8 +2086,20 @@ case class ExprBinder(
val location = AstUtils.locationOfExpression(node)

// Try to infer type arguments from constructor arguments
val inferredTypeArgs =
inferTypeArgumentsFromConstructor(genericParams, ctor, args)
val inferredTypeArgs = binder.tryGetSymbolType(ctor) match {
case Option.Some(
Type.GenericFunction(_, ctorGenerics, _, ctorParams, _)
) =>
val parameterTypes = getParameterTypes(ctorParams)
val argumentTypes = getArgumentTypes(args)
inferTypeArgumentsFromCall(
ctorGenerics,
parameterTypes,
argumentTypes
)
case _ =>
inferTypeArgumentsFromConstructor(genericParams, ctor, args)
}
val instantiatedType = Type.Class(
location,
ns,
Expand Down
20 changes: 12 additions & 8 deletions test/src/test/scala/TypeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,18 @@ object TypeTests extends TestSuite {
// assertExprTypeWithSetup(setup, "first(new Array[char](1))", "char")
// }
//
// test("generic container creation (commented until generics work)") {
// val containerSetup = "class Container[T](value: T)\n" +
// "def wrap[T](x: T): Container[T] = new Container(x)"
//
// assertExprTypeWithSetup(containerSetup, "wrap(42)", "Container[int]")
// assertExprTypeWithSetup(containerSetup, "wrap(true)", "Container[bool]")
// assertExprTypeWithSetup(containerSetup, "wrap(\"test\")", "Container[string]")
// }
test("generic container creation (commented until generics work)") {
val containerSetup = "class Container[T](value: T)\n" +
"def wrap[T](x: T): Container[T] = new Container(x)"

assertExprTypeWithSetup(containerSetup, "wrap(42)", "Container<int>")
assertExprTypeWithSetup(containerSetup, "wrap(true)", "Container<bool>")
assertExprTypeWithSetup(
containerSetup,
"wrap(\"test\")",
"Container<string>"
)
}
//
// test("option-like generic methods (commented until generics work)") {
// val optionSetup = "enum Option[T] {\n" +
Expand Down