diff --git a/pnc/src/ExprBinder.pn b/pnc/src/ExprBinder.pn index b029884..44dc4e5 100644 --- a/pnc/src/ExprBinder.pn +++ b/pnc/src/ExprBinder.pn @@ -1125,6 +1125,14 @@ class ExprBinder( } } + def getArgumentTypes(arguments: List[BoundExpression]): List[Type] = { + arguments match { + case List.Nil => List.Nil + case List.Cons(head, tail) => + List.Cons(binder.getType(head), getArgumentTypes(tail)) + } + } + def bindArguments( parameters: List[BoundParameter], arguments: List[BoundExpression], @@ -2006,9 +2014,39 @@ class ExprBinder( case Option.Some( Type.GenericFunction(loc, generics, traits, params, _) ) => - // Instantiate generic function with inferred type arguments + // For generic constructors, infer type arguments from constructor arguments + val parameterTypes = getParameterTypes(params) + val argumentTypes = getArgumentTypes(args) + val inferredTypeArgs = inferTypeArgumentsFromCall( + generics, + parameterTypes, + argumentTypes + ) + + // Create the instantiated return type using inferred type arguments + val inferredInstantiationType = instantiationType match { + case Type.Class(clsLoc, clsNs, clsName, _, clsSymbol) => + // Replace generic type arguments with inferred ones + Type.Class( + clsLoc, + clsNs, + clsName, + inferredTypeArgs, + clsSymbol + ) + case other => other + } + + // Substitute type variables in parameter types + val substitutedParams = + substituteParameterTypes(params, inferredTypeArgs) + val instantiatedFunction: Type.Function = - Type.Function(loc, params, instantiationType) + Type.Function( + loc, + substitutedParams, + inferredInstantiationType + ) bindNewExpressionForSymbol( location, ctor, @@ -2048,8 +2086,20 @@ class ExprBinder( val location = AstUtils.locationOfExpression(node) // Try to infer type arguments from constructor arguments - val inferredTypeArgs = - inferTypeArgumentsFromConstructor(genericParams, ctor, args) + val inferredTypeArgs = binder.tryGetSymbolType(ctor) match { + case Option.Some( + Type.GenericFunction(_, ctorGenerics, _, ctorParams, _) + ) => + val parameterTypes = getParameterTypes(ctorParams) + val argumentTypes = getArgumentTypes(args) + inferTypeArgumentsFromCall( + ctorGenerics, + parameterTypes, + argumentTypes + ) + case _ => + inferTypeArgumentsFromConstructor(genericParams, ctor, args) + } val instantiatedType = Type.Class( location, ns, diff --git a/pncs/src/main/scala/ExprBinder.scala b/pncs/src/main/scala/ExprBinder.scala index 89043dc..5180ca3 100644 --- a/pncs/src/main/scala/ExprBinder.scala +++ b/pncs/src/main/scala/ExprBinder.scala @@ -1125,6 +1125,14 @@ case class ExprBinder( } } + def getArgumentTypes(arguments: List[BoundExpression]): List[Type] = { + arguments match { + case List.Nil => List.Nil + case List.Cons(head, tail) => + List.Cons(binder.getType(head), getArgumentTypes(tail)) + } + } + def bindArguments( parameters: List[BoundParameter], arguments: List[BoundExpression], @@ -2006,9 +2014,39 @@ case class ExprBinder( case Option.Some( Type.GenericFunction(loc, generics, traits, params, _) ) => - // Instantiate generic function with inferred type arguments + // For generic constructors, infer type arguments from constructor arguments + val parameterTypes = getParameterTypes(params) + val argumentTypes = getArgumentTypes(args) + val inferredTypeArgs = inferTypeArgumentsFromCall( + generics, + parameterTypes, + argumentTypes + ) + + // Create the instantiated return type using inferred type arguments + val inferredInstantiationType = instantiationType match { + case Type.Class(clsLoc, clsNs, clsName, _, clsSymbol) => + // Replace generic type arguments with inferred ones + Type.Class( + clsLoc, + clsNs, + clsName, + inferredTypeArgs, + clsSymbol + ) + case other => other + } + + // Substitute type variables in parameter types + val substitutedParams = + substituteParameterTypes(params, inferredTypeArgs) + val instantiatedFunction: Type.Function = - Type.Function(loc, params, instantiationType) + Type.Function( + loc, + substitutedParams, + inferredInstantiationType + ) bindNewExpressionForSymbol( location, ctor, @@ -2048,8 +2086,20 @@ case class ExprBinder( val location = AstUtils.locationOfExpression(node) // Try to infer type arguments from constructor arguments - val inferredTypeArgs = - inferTypeArgumentsFromConstructor(genericParams, ctor, args) + val inferredTypeArgs = binder.tryGetSymbolType(ctor) match { + case Option.Some( + Type.GenericFunction(_, ctorGenerics, _, ctorParams, _) + ) => + val parameterTypes = getParameterTypes(ctorParams) + val argumentTypes = getArgumentTypes(args) + inferTypeArgumentsFromCall( + ctorGenerics, + parameterTypes, + argumentTypes + ) + case _ => + inferTypeArgumentsFromConstructor(genericParams, ctor, args) + } val instantiatedType = Type.Class( location, ns, diff --git a/test/src/test/scala/TypeTests.scala b/test/src/test/scala/TypeTests.scala index 7141c5d..269a25d 100644 --- a/test/src/test/scala/TypeTests.scala +++ b/test/src/test/scala/TypeTests.scala @@ -317,14 +317,18 @@ object TypeTests extends TestSuite { // assertExprTypeWithSetup(setup, "first(new Array[char](1))", "char") // } // -// test("generic container creation (commented until generics work)") { -// val containerSetup = "class Container[T](value: T)\n" + -// "def wrap[T](x: T): Container[T] = new Container(x)" -// -// assertExprTypeWithSetup(containerSetup, "wrap(42)", "Container[int]") -// assertExprTypeWithSetup(containerSetup, "wrap(true)", "Container[bool]") -// assertExprTypeWithSetup(containerSetup, "wrap(\"test\")", "Container[string]") -// } + test("generic container creation (commented until generics work)") { + val containerSetup = "class Container[T](value: T)\n" + + "def wrap[T](x: T): Container[T] = new Container(x)" + + assertExprTypeWithSetup(containerSetup, "wrap(42)", "Container") + assertExprTypeWithSetup(containerSetup, "wrap(true)", "Container") + assertExprTypeWithSetup( + containerSetup, + "wrap(\"test\")", + "Container" + ) + } // // test("option-like generic methods (commented until generics work)") { // val optionSetup = "enum Option[T] {\n" +