diff --git a/docs/generators/visitor-generator.md b/docs/generators/visitor-generator.md index c749470..c0afabf 100644 --- a/docs/generators/visitor-generator.md +++ b/docs/generators/visitor-generator.md @@ -1,6 +1,6 @@ # Visitor Pattern Generator -The Visitor Pattern Generator automatically generates fluent, type-safe visitor infrastructure for class hierarchies marked with the `[GenerateVisitor]` attribute. This eliminates boilerplate code and provides modern C# ergonomics including async/await support, ValueTask, and generic type inference. +The Visitor Pattern Generator automatically generates fluent, type-safe visitor infrastructure for hierarchies marked with the `[GenerateVisitor]` attribute. It supports class, interface, struct, and record hierarchies, eliminating boilerplate code and providing modern C# ergonomics including async/await support, ValueTask, and generic type inference. ## Overview @@ -118,6 +118,129 @@ var asyncLogger = new AstNodeAsyncActionVisitorBuilder() await myExpression.AcceptAsync(asyncLogger); ``` +## Supported Hierarchy Types + +The visitor generator supports multiple types of hierarchies, providing flexibility in design: + +### Class-Based Hierarchies + +Traditional class inheritance hierarchies are fully supported: + +```csharp +[GenerateVisitor] +public abstract partial class Animal +{ +} + +public partial class Dog : Animal +{ + public string Breed { get; init; } +} + +public partial class Cat : Animal +{ + public bool IsIndoor { get; init; } +} +``` + +### Interface-Based Hierarchies + +Hierarchies based on interfaces work seamlessly: + +```csharp +[GenerateVisitor] +public partial interface IShape +{ +} + +public partial class Circle : IShape +{ + public double Radius { get; init; } +} + +public partial class Rectangle : IShape +{ + public double Width { get; init; } + public double Height { get; init; } +} + +public partial class Triangle : IShape +{ + public double Base { get; init; } + public double Height { get; init; } +} +``` + +**Note:** For interface base types, the generated visitor interface name is intelligently derived. `IShape` generates `IShapeVisitor` (not `IIShapeVisitor`). + +### Struct-Based Hierarchies + +Value types can implement visitable interfaces for allocation-free visitor patterns: + +```csharp +[GenerateVisitor] +public partial interface IValue +{ +} + +public partial struct IntValue : IValue +{ + public int Value { get; init; } +} + +public partial struct DoubleValue : IValue +{ + public double Value { get; init; } +} + +// No boxing occurs during visitation +var visitor = new IValueVisitorBuilder() + .When(i => $"Int:{i.Value}") + .When(d => $"Double:{d.Value:F2}") + .Build(); + +var intVal = new IntValue { Value = 42 }; +var result = intVal.Accept(visitor); // "Int:42" +``` + +### Record Types + +Records are also supported: + +```csharp +[GenerateVisitor] +public abstract partial record Message; + +public partial record TextMessage(string Content) : Message; +public partial record ImageMessage(byte[] Data, string Format) : Message; +``` + +### Mixed Hierarchies + +You can mix interfaces, classes, and structs in complex hierarchies: + +```csharp +[GenerateVisitor] +public partial interface INode +{ +} + +public abstract partial class Expression : INode +{ +} + +public partial class Literal : Expression +{ + public object Value { get; init; } +} + +public partial struct Position : INode +{ + public int Line { get; init; } + public int Column { get; init; } +} +``` + ## Attribute Options The `[GenerateVisitor]` attribute supports several options: @@ -391,6 +514,62 @@ var validator = new DocumentVisitorBuilder() .Build(); ``` +## Diagnostics + +The generator provides helpful diagnostics to catch common issues: + +### PKVIS001: No concrete types found + +**Severity:** Warning + +This warning appears when the generator cannot find any concrete types implementing or deriving from the marked base type. + +```csharp +[GenerateVisitor] +public partial interface IEmptyHierarchy { } + +// Warning PKVIS001: No concrete types implementing or deriving from 'IEmptyHierarchy' were found +``` + +**Solutions:** +- Add concrete types that implement the interface or derive from the class +- Set `AutoDiscoverDerivedTypes = false` if you're building types manually + +### PKVIS002: Type must be partial + +**Severity:** Error + +The base type (class, struct, or interface) must be declared as `partial` to allow Accept method generation. + +```csharp +[GenerateVisitor] +public class NonPartialBase { } // Error! + +// Fix: +[GenerateVisitor] +public partial class PartialBase { } // Correct +``` + +**Solution:** Add the `partial` keyword to the type declaration. + +### PKVIS004: Derived type must be partial + +**Severity:** Error + +All derived types must be `partial` to allow Accept method generation. + +```csharp +[GenerateVisitor] +public partial class Base { } + +public class Derived : Base { } // Error! + +// Fix: +public partial class Derived : Base { } // Correct +``` + +**Solution:** Add the `partial` keyword to all derived types in the hierarchy. + ## Troubleshooting ### "No handler registered for type X" diff --git a/src/PatternKit.Generators.Abstractions/Visitors/VisitorAttributes.cs b/src/PatternKit.Generators.Abstractions/Visitors/VisitorAttributes.cs index 522987a..20bb1b4 100644 --- a/src/PatternKit.Generators.Abstractions/Visitors/VisitorAttributes.cs +++ b/src/PatternKit.Generators.Abstractions/Visitors/VisitorAttributes.cs @@ -12,6 +12,7 @@ namespace PatternKit.Generators.Visitors; /// Fluent builder APIs for composing visitors /// /// +/// Class-based hierarchy: /// /// [GenerateVisitor] /// public partial class AstNode { } @@ -19,9 +20,17 @@ namespace PatternKit.Generators.Visitors; /// public partial class Expression : AstNode { } /// public partial class Statement : AstNode { } /// +/// Interface-based hierarchy: +/// +/// [GenerateVisitor] +/// public partial interface IShape { } +/// +/// public partial class Circle : IShape { } +/// public partial class Rectangle : IShape { } +/// /// /// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)] +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface | AttributeTargets.Struct, AllowMultiple = false, Inherited = false)] public sealed class GenerateVisitorAttribute : Attribute { /// diff --git a/src/PatternKit.Generators/AnalyzerReleases.Unshipped.md b/src/PatternKit.Generators/AnalyzerReleases.Unshipped.md index 90f116b..e530359 100644 --- a/src/PatternKit.Generators/AnalyzerReleases.Unshipped.md +++ b/src/PatternKit.Generators/AnalyzerReleases.Unshipped.md @@ -37,3 +37,6 @@ PKMEM003 | PatternKit.Generators.Memento | Warning | Unsafe reference capture PKMEM004 | PatternKit.Generators.Memento | Error | Clone strategy requested but mechanism missing PKMEM005 | PatternKit.Generators.Memento | Error | Record restore generation failed PKMEM006 | PatternKit.Generators.Memento | Info | Init-only or readonly restrictions prevent in-place restore +PKVIS001 | PatternKit.Generators.Visitor | Warning | No concrete types found for visitor generation +PKVIS002 | PatternKit.Generators.Visitor | Error | Type must be partial for Accept method generation +PKVIS004 | PatternKit.Generators.Visitor | Error | Derived type must be partial for Accept method generation diff --git a/src/PatternKit.Generators/VisitorGenerator.cs b/src/PatternKit.Generators/VisitorGenerator.cs index 470b5ca..7971246 100644 --- a/src/PatternKit.Generators/VisitorGenerator.cs +++ b/src/PatternKit.Generators/VisitorGenerator.cs @@ -1,4 +1,5 @@ using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using System.Collections.Immutable; using System.Text; @@ -15,10 +16,13 @@ public sealed class VisitorGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { - // Find all classes marked with [GenerateVisitor] + // Find all types (classes, interfaces, structs, records) marked with [GenerateVisitor] var visitorRoots = context.SyntaxProvider.ForAttributeWithMetadataName( fullyQualifiedMetadataName: "PatternKit.Generators.Visitors.GenerateVisitorAttribute", - predicate: static (node, _) => node is ClassDeclarationSyntax, + predicate: static (node, _) => node is ClassDeclarationSyntax + or InterfaceDeclarationSyntax + or StructDeclarationSyntax + or RecordDeclarationSyntax, transform: static (gasc, ct) => GetVisitorRoot(gasc, ct) ).Where(static x => x is not null); @@ -52,6 +56,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var baseName = baseType.Name; var baseFullName = baseType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + // Generate default visitor interface name + var defaultVisitorName = GetDefaultVisitorInterfaceName(baseType); // Discover derived types in the same assembly var derivedTypes = autoDiscover @@ -63,12 +70,35 @@ public void Initialize(IncrementalGeneratorInitializationContext context) BaseName: baseName, BaseFullName: baseFullName, BaseType: baseType, - VisitorInterfaceName: visitorInterfaceName ?? $"I{baseName}Visitor", + VisitorInterfaceName: visitorInterfaceName ?? defaultVisitorName, GenerateAsync: generateAsync, GenerateActions: generateActions, DerivedTypes: derivedTypes ); } + + /// + /// Generates a default visitor interface name based on the base type. + /// For interfaces with I-prefix (e.g., IShape), generates IShapeVisitor. + /// For other types (e.g., Shape), generates IShapeVisitor. + /// + private static string GetDefaultVisitorInterfaceName(INamedTypeSymbol baseType) + { + var baseName = baseType.Name; + + // If base is an interface with I-prefix (Hungarian notation), don't add another I + if (baseType.TypeKind == TypeKind.Interface && + baseName.StartsWith("I") && + baseName.Length > 1 && + char.IsUpper(baseName[1])) + { + // Interface name like "IShape" -> "IShapeVisitor" + return $"{baseName}Visitor"; + } + + // Class name like "Shape" -> "IShapeVisitor" + return $"I{baseName}Visitor"; + } private static T? GetAttributeProperty(AttributeData attr, string propertyName) { @@ -89,30 +119,108 @@ private static ImmutableArray DiscoverDerivedTypes( var semanticModel = compilation.GetSemanticModel(tree); var root = tree.GetRoot(); + // Discover classes foreach (var classDecl in root.DescendantNodes().OfType()) { var symbol = semanticModel.GetDeclaredSymbol(classDecl) as INamedTypeSymbol; - if (symbol is null) continue; + if (symbol is null || SymbolEqualityComparer.Default.Equals(symbol, baseType)) continue; + + if (IsDerivedFrom(symbol, baseType)) + { + derived.Add(symbol); + } + } + + // Discover structs + foreach (var structDecl in root.DescendantNodes().OfType()) + { + var symbol = semanticModel.GetDeclaredSymbol(structDecl) as INamedTypeSymbol; + if (symbol is null || SymbolEqualityComparer.Default.Equals(symbol, baseType)) continue; - // Check if this type derives from baseType - var current = symbol.BaseType; - while (current is not null) + if (IsDerivedFrom(symbol, baseType)) { - if (SymbolEqualityComparer.Default.Equals(current, baseType)) - { - derived.Add(symbol); - break; - } - current = current.BaseType; + derived.Add(symbol); + } + } + + // Discover records + foreach (var recordDecl in root.DescendantNodes().OfType()) + { + var symbol = semanticModel.GetDeclaredSymbol(recordDecl) as INamedTypeSymbol; + if (symbol is null || SymbolEqualityComparer.Default.Equals(symbol, baseType)) continue; + + if (IsDerivedFrom(symbol, baseType)) + { + derived.Add(symbol); } } } return derived.ToImmutableArray(); } + + private static bool IsDerivedFrom(INamedTypeSymbol type, INamedTypeSymbol baseType) + { + // Check class inheritance + var current = type.BaseType; + while (current is not null) + { + if (SymbolEqualityComparer.Default.Equals(current, baseType)) + return true; + current = current.BaseType; + } + + // Check interface implementation + return ImplementsInterface(type, baseType); + } + + private static bool ImplementsInterface(INamedTypeSymbol type, INamedTypeSymbol interfaceType) + { + if (interfaceType.TypeKind != TypeKind.Interface) + return false; + + foreach (var iface in type.AllInterfaces) + { + if (SymbolEqualityComparer.Default.Equals(iface, interfaceType)) + return true; + } + + return false; + } private static void GenerateVisitorInfrastructure(SourceProductionContext context, VisitorRootInfo root) { + // Check if we have any concrete types to visit + if (root.DerivedTypes.IsEmpty) + { + var location = root.BaseType.Locations.FirstOrDefault(); + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.NoConcretTypes, + location, + root.BaseName)); + } + + // Check if base type is partial (all types need to be partial for Accept method generation) + if (!IsPartial(root.BaseType)) + { + var location = root.BaseType.Locations.FirstOrDefault(); + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.TypeMustBePartial, + location, + root.BaseName)); + return; // Can't generate if not partial + } + + // Check if derived types are partial + foreach (var derivedType in root.DerivedTypes.Where(dt => !IsPartial(dt))) + { + var location = derivedType.Locations.FirstOrDefault(); + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.DerivedTypeNotPartial, + location, + derivedType.Name)); + } + // Generate visitor interfaces GenerateVisitorInterfaces(context, root); @@ -122,6 +230,15 @@ private static void GenerateVisitorInfrastructure(SourceProductionContext contex // Generate fluent builders GenerateFluentBuilders(context, root); } + + private static bool IsPartial(INamedTypeSymbol type) + { + // Check if any of the declarations is marked as partial + return type.DeclaringSyntaxReferences + .Select(syntaxRef => syntaxRef.GetSyntax()) + .OfType() + .Any(typeDecl => typeDecl.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword))); + } private static void GenerateVisitorInterfaces(SourceProductionContext context, VisitorRootInfo root) { @@ -230,7 +347,25 @@ private static void GenerateAcceptMethods(SourceProductionContext context, Visit sb.AppendLine(); } - sb.AppendLine($"public partial class {type.Name}"); + // Determine the type keyword (class, struct, interface, record, record struct) + string typeKeyword; + if (type.IsRecord) + { + typeKeyword = type.TypeKind == TypeKind.Struct + ? "record struct" + : "record"; + } + else + { + typeKeyword = type.TypeKind switch + { + TypeKind.Interface => "interface", + TypeKind.Struct => "struct", + _ => "class" + }; + } + + sb.AppendLine($"public partial {typeKeyword} {type.Name}"); sb.AppendLine("{"); // Sync Accept with result @@ -686,4 +821,45 @@ private readonly record struct VisitorRootInfo( bool GenerateActions, ImmutableArray DerivedTypes ); + + /// + /// Diagnostic descriptors for visitor pattern generation. + /// + private static class Diagnostics + { + private const string Category = "PatternKit.Generators.Visitor"; + + /// + /// PKVIS001: No concrete types found for a marked base type. + /// + public static readonly DiagnosticDescriptor NoConcretTypes = new( + "PKVIS001", + "No concrete types found", + "No concrete types implementing or deriving from '{0}' were found. Add derived types or set AutoDiscoverDerivedTypes = false", + Category, + DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + /// + /// PKVIS002: Type must be partial for Accept method generation. + /// + public static readonly DiagnosticDescriptor TypeMustBePartial = new( + "PKVIS002", + "Type must be partial", + "Type '{0}' must be declared as partial to allow Accept method generation", + Category, + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + /// + /// PKVIS004: Derived type is not partial. + /// + public static readonly DiagnosticDescriptor DerivedTypeNotPartial = new( + "PKVIS004", + "Derived type must be partial", + "Derived type '{0}' must be declared as partial to allow Accept method generation", + Category, + DiagnosticSeverity.Error, + isEnabledByDefault: true); + } } diff --git a/test/PatternKit.Generators.Tests/VisitorGeneratorTests.cs b/test/PatternKit.Generators.Tests/VisitorGeneratorTests.cs index c1b3990..4e60439 100644 --- a/test/PatternKit.Generators.Tests/VisitorGeneratorTests.cs +++ b/test/PatternKit.Generators.Tests/VisitorGeneratorTests.cs @@ -466,4 +466,503 @@ public partial class ResultLeaf : ResultNode var emit = updated.Emit(Stream.Null); Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); } + + [Fact] + public void Generates_Visitor_For_Interface_Hierarchy() + { + const string interfaceHierarchy = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples.Shapes; + + [GenerateVisitor] + public partial interface IShape { } + + public partial class Circle : IShape + { + public double Radius { get; init; } + } + + public partial class Rectangle : IShape + { + public double Width { get; init; } + public double Height { get; init; } + } + + public partial class Triangle : IShape + { + public double Base { get; init; } + public double Height { get; init; } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + interfaceHierarchy, + assemblyName: nameof(Generates_Visitor_For_Interface_Hierarchy)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + // No generator diagnostics + Assert.All(run.Results, r => Assert.Empty(r.Diagnostics)); + + // Confirm we generated expected files + var names = run.Results.SelectMany(r => r.GeneratedSources).Select(gs => gs.HintName).ToArray(); + + // Interfaces + Assert.Contains("IShapeVisitor.Interfaces.g.cs", names); + + // Accept methods for each type (interface + concrete classes) + Assert.Contains("IShape.Accept.g.cs", names); + Assert.Contains("Circle.Accept.g.cs", names); + Assert.Contains("Rectangle.Accept.g.cs", names); + Assert.Contains("Triangle.Accept.g.cs", names); + + // Builders + Assert.Contains("IShapeVisitorBuilder.g.cs", names); + + // Verify compilation succeeds + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + } + + [Fact] + public void Interface_Hierarchy_Visitor_Dispatches_Correctly() + { + const string interfaceHierarchyWithUsage = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples.Shapes; + + [GenerateVisitor] + public partial interface IShape { } + + public partial class Circle : IShape + { + public double Radius { get; init; } + } + + public partial class Rectangle : IShape + { + public double Width { get; init; } + public double Height { get; init; } + } + + public static class Demo + { + public static double Run() + { + var areaCalculator = new IShapeVisitorBuilder() + .When(c => 3.14159 * c.Radius * c.Radius) + .When(r => r.Width * r.Height) + .Default(_ => 0.0) + .Build(); + + var circle = new Circle { Radius = 5.0 }; + var rectangle = new Rectangle { Width = 4.0, Height = 6.0 }; + + var circleArea = circle.Accept(areaCalculator); + var rectangleArea = rectangle.Accept(areaCalculator); + + return circleArea + rectangleArea; + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + interfaceHierarchyWithUsage, + assemblyName: nameof(Interface_Hierarchy_Visitor_Dispatches_Correctly)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + Assert.All(run.Results, r => Assert.Empty(r.Diagnostics)); + + // Emit and execute + using var pe = new MemoryStream(); + using var pdb = new MemoryStream(); + var res = updated.Emit(pe, pdb); + Assert.True(res.Success, string.Join("\n", res.Diagnostics)); + pe.Position = 0; + + var asm = AssemblyLoadContext.Default.LoadFromStream(pe, pdb); + var demo = asm.GetType("PatternKit.Examples.Shapes.Demo")!; + var runMethod = demo.GetMethod("Run")!; + var result = (double)runMethod.Invoke(null, null)!; + + // Circle area: π * 5^2 ≈ 78.54 + // Rectangle area: 4 * 6 = 24 + // Total ≈ 102.54 + Assert.True(result > 100 && result < 105, $"Expected ~102.54, got {result}"); + } + + [Fact] + public void Generates_Visitor_For_Struct_Hierarchy() + { + const string structHierarchy = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples.Values; + + [GenerateVisitor] + public partial interface IValue { } + + public partial struct IntValue : IValue + { + public int Value { get; init; } + } + + public partial struct DoubleValue : IValue + { + public double Value { get; init; } + } + + public partial struct StringValue : IValue + { + public string Value { get; init; } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + structHierarchy, + assemblyName: nameof(Generates_Visitor_For_Struct_Hierarchy)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + // No generator diagnostics + Assert.All(run.Results, r => Assert.Empty(r.Diagnostics)); + + // Confirm we generated expected files + var names = run.Results.SelectMany(r => r.GeneratedSources).Select(gs => gs.HintName).ToArray(); + + // Interfaces + Assert.Contains("IValueVisitor.Interfaces.g.cs", names); + + // Accept methods for interface and structs + Assert.Contains("IValue.Accept.g.cs", names); + Assert.Contains("IntValue.Accept.g.cs", names); + Assert.Contains("DoubleValue.Accept.g.cs", names); + Assert.Contains("StringValue.Accept.g.cs", names); + + // Verify compilation succeeds + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + } + + [Fact] + public void Struct_Visitor_Dispatches_Without_Boxing() + { + const string structHierarchyWithUsage = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples.Values; + + [GenerateVisitor] + public partial interface IValue { } + + public partial struct IntValue : IValue + { + public int Value { get; init; } + } + + public partial struct DoubleValue : IValue + { + public double Value { get; init; } + } + + public static class Demo + { + public static string Run() + { + var formatter = new IValueVisitorBuilder() + .When(i => $"Int:{i.Value}") + .When(d => $"Double:{d.Value:F2}") + .Default(_ => "Unknown") + .Build(); + + var intVal = new IntValue { Value = 42 }; + var doubleVal = new DoubleValue { Value = 3.14159 }; + + var intStr = intVal.Accept(formatter); + var doubleStr = doubleVal.Accept(formatter); + + return $"{intStr},{doubleStr}"; + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + structHierarchyWithUsage, + assemblyName: nameof(Struct_Visitor_Dispatches_Without_Boxing)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + Assert.All(run.Results, r => Assert.Empty(r.Diagnostics)); + + // Emit and execute + using var pe = new MemoryStream(); + using var pdb = new MemoryStream(); + var res = updated.Emit(pe, pdb); + Assert.True(res.Success, string.Join("\n", res.Diagnostics)); + pe.Position = 0; + + var asm = AssemblyLoadContext.Default.LoadFromStream(pe, pdb); + var demo = asm.GetType("PatternKit.Examples.Values.Demo")!; + var runMethod = demo.GetMethod("Run")!; + var result = (string)runMethod.Invoke(null, null)!; + + Assert.Equal("Int:42,Double:3.14", result); + } + + [Fact] + public void Diagnostic_PKVIS001_EmittedWhenNoConcretTypesFound() + { + const string noDerivedTypes = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples; + + [GenerateVisitor] + public partial interface IEmptyHierarchy { } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + noDerivedTypes, + assemblyName: nameof(Diagnostic_PKVIS001_EmittedWhenNoConcretTypesFound)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + // Should have PKVIS001 warning + var diagnostics = run.Results.SelectMany(r => r.Diagnostics).ToArray(); + Assert.Contains(diagnostics, d => d.Id == "PKVIS001"); + + var pkvis001 = diagnostics.First(d => d.Id == "PKVIS001"); + Assert.Contains("IEmptyHierarchy", pkvis001.GetMessage()); + } + + [Fact] + public void Diagnostic_PKVIS002_EmittedWhenBaseTypeNotPartial() + { + const string nonPartialBase = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples; + + [GenerateVisitor] + public class NonPartialBase { } + + public partial class DerivedType : NonPartialBase { } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + nonPartialBase, + assemblyName: nameof(Diagnostic_PKVIS002_EmittedWhenBaseTypeNotPartial)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + // Should have PKVIS002 error + var diagnostics = run.Results.SelectMany(r => r.Diagnostics).ToArray(); + Assert.Contains(diagnostics, d => d.Id == "PKVIS002"); + + var pkvis002 = diagnostics.First(d => d.Id == "PKVIS002"); + Assert.Contains("NonPartialBase", pkvis002.GetMessage()); + } + + [Fact] + public void Diagnostic_PKVIS004_EmittedWhenDerivedTypeNotPartial() + { + const string nonPartialDerived = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples; + + [GenerateVisitor] + public partial class Base { } + + public class NonPartialDerived : Base { } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + nonPartialDerived, + assemblyName: nameof(Diagnostic_PKVIS004_EmittedWhenDerivedTypeNotPartial)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + // Should have PKVIS004 error + var diagnostics = run.Results.SelectMany(r => r.Diagnostics).ToArray(); + Assert.Contains(diagnostics, d => d.Id == "PKVIS004"); + + var pkvis004 = diagnostics.First(d => d.Id == "PKVIS004"); + Assert.Contains("NonPartialDerived", pkvis004.GetMessage()); + } + + [Fact] + public void No_Diagnostics_For_Valid_Hierarchy() + { + const string validHierarchy = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples; + + [GenerateVisitor] + public partial class ValidBase { } + + public partial class ValidDerived : ValidBase { } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + validHierarchy, + assemblyName: nameof(No_Diagnostics_For_Valid_Hierarchy)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + // Should have no generator diagnostics + var diagnostics = run.Results.SelectMany(r => r.Diagnostics) + .Where(d => d.Id.StartsWith("PKVIS")) + .ToArray(); + Assert.Empty(diagnostics); + } + + [Fact] + public void Generates_Visitor_For_Record_Hierarchy() + { + const string recordHierarchy = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples.Records; + + [GenerateVisitor] + public abstract partial record Message; + + public partial record TextMessage(string Content) : Message; + + public partial record ImageMessage(byte[] Data, string Format) : Message; + + public partial record AudioMessage(string Url, int DurationSeconds) : Message; + """; + + var comp = RoslynTestHelpers.CreateCompilation( + recordHierarchy, + assemblyName: nameof(Generates_Visitor_For_Record_Hierarchy)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + // No generator diagnostics + Assert.All(run.Results, r => Assert.Empty(r.Diagnostics)); + + // Confirm we generated expected files + var names = run.Results.SelectMany(r => r.GeneratedSources).Select(gs => gs.HintName).ToArray(); + + // Interfaces + Assert.Contains("IMessageVisitor.Interfaces.g.cs", names); + + // Accept methods for each record type + Assert.Contains("Message.Accept.g.cs", names); + Assert.Contains("TextMessage.Accept.g.cs", names); + Assert.Contains("ImageMessage.Accept.g.cs", names); + Assert.Contains("AudioMessage.Accept.g.cs", names); + + // Builders + Assert.Contains("MessageVisitorBuilder.g.cs", names); + + // Verify compilation succeeds + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + } + + [Fact] + public void Record_Visitor_Dispatches_Correctly() + { + const string recordHierarchyWithUsage = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples.Records; + + [GenerateVisitor] + public abstract partial record Message; + + public partial record TextMessage(string Content) : Message; + + public partial record ImageMessage(byte[] Data, string Format) : Message; + + public static class Demo + { + public static string Run() + { + var formatter = new MessageVisitorBuilder() + .When(m => $"Text: {m.Content}") + .When(m => $"Image: {m.Format}") + .Default(_ => "Unknown") + .Build(); + + var text = new TextMessage("Hello World"); + var image = new ImageMessage(new byte[] { 1, 2, 3 }, "PNG"); + + var textStr = text.Accept(formatter); + var imageStr = image.Accept(formatter); + + return $"{textStr}|{imageStr}"; + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + recordHierarchyWithUsage, + assemblyName: nameof(Record_Visitor_Dispatches_Correctly)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + Assert.All(run.Results, r => Assert.Empty(r.Diagnostics)); + + // Emit and execute + using var pe = new MemoryStream(); + using var pdb = new MemoryStream(); + var res = updated.Emit(pe, pdb); + Assert.True(res.Success, string.Join("\n", res.Diagnostics)); + pe.Position = 0; + + var asm = AssemblyLoadContext.Default.LoadFromStream(pe, pdb); + var demo = asm.GetType("PatternKit.Examples.Records.Demo")!; + var runMethod = demo.GetMethod("Run")!; + var result = (string)runMethod.Invoke(null, null)!; + + Assert.Equal("Text: Hello World|Image: PNG", result); + } + + [Fact] + public void Diagnostic_PKVIS002_EmittedWhenInterfaceBaseTypeNotPartial() + { + const string nonPartialInterface = """ + using PatternKit.Generators.Visitors; + + namespace PatternKit.Examples; + + [GenerateVisitor] + public interface INotPartial { } + + public partial class DerivedType : INotPartial { } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + nonPartialInterface, + assemblyName: nameof(Diagnostic_PKVIS002_EmittedWhenInterfaceBaseTypeNotPartial)); + + var gen = new VisitorGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); + + // Should have PKVIS002 error + var diagnostics = run.Results.SelectMany(r => r.Diagnostics).ToArray(); + Assert.Contains(diagnostics, d => d.Id == "PKVIS002"); + + var pkvis002 = diagnostics.First(d => d.Id == "PKVIS002"); + Assert.Contains("INotPartial", pkvis002.GetMessage()); + } }