diff --git a/src/PatternKit.Examples/MediatorComprehensiveDemo/README.md b/src/PatternKit.Examples/MediatorComprehensiveDemo/README.md index cff6bac..9f342eb 100644 --- a/src/PatternKit.Examples/MediatorComprehensiveDemo/README.md +++ b/src/PatternKit.Examples/MediatorComprehensiveDemo/README.md @@ -12,6 +12,10 @@ This 780+ line demo implements a complete e-commerce domain with: - ✅ **Event-Driven Architecture** with notification fan-out - ✅ **Async Streaming** with `IAsyncEnumerable` - ✅ **Pipeline Behaviors** for cross-cutting concerns +- ✅ **Around Middleware** for wrapping handlers +- ✅ **OnError Handling** for exception management +- ✅ **Module System** for modular registration +- ✅ **Object Overloads** for dynamic dispatch (optional) - ✅ **Repository Pattern** for data access - ✅ **Real-World Domain** (e-commerce scenario) @@ -43,6 +47,111 @@ var customer = await dispatcher.Send( default); ``` +## New Features + +### Around Middleware + +Wrap handler execution with full control over the pipeline: + +```csharp +var dispatcher = AppDispatcher.Create() + .Command(handler) + .Around(async (req, ct, next) => + { + // Before handler + Console.WriteLine("Before"); + + var result = await next(); + + // After handler + Console.WriteLine("After"); + + return result; + }, order: 1) + .Build(); +``` + +### OnError Handling + +Handle exceptions gracefully with error handlers: + +```csharp +var dispatcher = AppDispatcher.Create() + .Command(handler) + .OnError((req, ex, ct) => + { + Console.WriteLine($"Error: {ex.Message}"); + return ValueTask.CompletedTask; + }) + .Build(); +``` + +### Module System + +Organize handlers into reusable modules: + +```csharp +public class OrderModule : IModule +{ + public void Register(IDispatcherBuilder builder) + { + builder.Command(PlaceOrderHandler); + builder.Notification(NotifyInventory); + builder.Notification(SendConfirmation); + } +} + +var dispatcher = AppDispatcher.Create() + .AddModule(new OrderModule()) + .Build(); +``` + +### Object Overloads + +Enable dynamic dispatch for runtime scenarios: + +```csharp +// Generate with object overloads enabled +[assembly: GenerateDispatcher( + Namespace = "MyApp", + Name = "AppDispatcher", + IncludeObjectOverloads = true)] + +// Use dynamically +object command = new GetCustomer(id); +var result = await dispatcher.Send(command, ct); +``` + +### Stream Pipelines + +Add pipeline hooks for stream requests: + +```csharp +var dispatcher = AppDispatcher.Create() + .Stream(SearchHandler) + .PreStream((req, ct) => + { + Console.WriteLine($"Searching for: {req.Query}"); + return ValueTask.CompletedTask; + }) + .Build(); +``` + +### Pipeline Ordering + +Control execution order with explicit ordering: + +```csharp +var dispatcher = AppDispatcher.Create() + .Around(OuterMiddleware, order: 1) + .Around(InnerMiddleware, order: 2) + .Pre(PreHook, order: 0) + .Post(PostHook, order: 0) + .Build(); + +// Execution order: Pre(0) -> Around(1) -> Around(2) -> Handler -> Post(0) +``` + ## Architecture ### Domain Model @@ -69,6 +178,46 @@ var customer = await dispatcher.Send( **Streams:** - `SearchProductsQuery` → `IAsyncEnumerable` +## Pipeline Execution Flow + +### Command with Full Pipeline + +``` +Request → Pre Hooks (ordered) + → Around Middleware (outer to inner) + → Handler + → Around Middleware (inner to outer) + → Post Hooks (ordered) + → Response + +On Exception: + → OnError Hooks (ordered) + → Exception propagated +``` + +### Example with Multiple Behaviors + +```csharp +var dispatcher = AppDispatcher.Create() + .Pre(ValidateRequest, order: 0) + .Around(LoggingMiddleware, order: 1) + .Around(TransactionMiddleware, order: 2) + .Command(Handler) + .Post(CacheResult, order: 0) + .OnError(LogError, order: 0) + .Build(); + +// Execution flow: +// 1. Pre: ValidateRequest +// 2. Around(1) Begin: LoggingMiddleware +// 3. Around(2) Begin: TransactionMiddleware +// 4. Handler +// 5. Around(2) End: TransactionMiddleware +// 6. Around(1) End: LoggingMiddleware +// 7. Post: CacheResult +// On error: LogError +``` + ### Extension Methods #### `AddSourceGeneratedMediator()` diff --git a/src/PatternKit.Generators/Messaging/DispatcherGenerator.cs b/src/PatternKit.Generators/Messaging/DispatcherGenerator.cs index 71bba36..5a3a74b 100644 --- a/src/PatternKit.Generators/Messaging/DispatcherGenerator.cs +++ b/src/PatternKit.Generators/Messaging/DispatcherGenerator.cs @@ -90,6 +90,7 @@ private static string GenerateMainDispatcherFile(DispatcherConfig config, string { "System", "System.Collections.Generic", + "System.Linq", "System.Threading", "System.Threading.Tasks" }; @@ -99,9 +100,17 @@ private static string GenerateMainDispatcherFile(DispatcherConfig config, string usings.Add("System.Runtime.CompilerServices"); } + if (config.IncludeObjectOverloads) + { + usings.Add("System.Reflection"); + } + AppendUsings(sb, usings.ToArray()); AppendNamespaceAndClassHeader(sb, config.Namespace, visibility, config.Name); + // PipelineEntry class + GeneratePipelineEntry(sb); + // Internal state sb.AppendLine(" private readonly Dictionary _commandHandlers = new();"); sb.AppendLine(" private readonly Dictionary> _notificationHandlers = new();"); @@ -111,11 +120,11 @@ private static string GenerateMainDispatcherFile(DispatcherConfig config, string sb.AppendLine(" private readonly Dictionary _streamHandlers = new();"); } - sb.AppendLine(" private readonly Dictionary> _commandPipelines = new();"); + sb.AppendLine(" private readonly Dictionary> _commandPipelines = new();"); if (config.IncludeStreaming) { - sb.AppendLine(" private readonly Dictionary> _streamPipelines = new();"); + sb.AppendLine(" private readonly Dictionary> _streamPipelines = new();"); } sb.AppendLine(); @@ -150,6 +159,12 @@ private static string GenerateMainDispatcherFile(DispatcherConfig config, string sb.AppendLine(" }"); sb.AppendLine(); + // Object overload for Send + if (config.IncludeObjectOverloads) + { + GenerateObjectSendMethod(sb); + } + // Publish method (notifications) sb.AppendLine(" /// "); sb.AppendLine(" /// Publishes a notification to all registered handlers."); @@ -170,6 +185,12 @@ private static string GenerateMainDispatcherFile(DispatcherConfig config, string sb.AppendLine(" }"); sb.AppendLine(); + // Object overload for Publish + if (config.IncludeObjectOverloads) + { + GenerateObjectPublishMethod(sb); + } + // Stream method if (config.IncludeStreaming) { @@ -179,6 +200,18 @@ private static string GenerateMainDispatcherFile(DispatcherConfig config, string sb.AppendLine(" public async IAsyncEnumerable Stream(TRequest request, [EnumeratorCancellation] CancellationToken ct = default)"); sb.AppendLine(" {"); sb.AppendLine(" var requestType = typeof(TRequest);"); + sb.AppendLine(); + sb.AppendLine(" // Execute Pre hooks if registered"); + sb.AppendLine(" if (_streamPipelines.TryGetValue(requestType, out var pipelines))"); + sb.AppendLine(" {"); + sb.AppendLine(" var orderedPipelines = pipelines.OrderBy(p => p.Order).ToList();"); + sb.AppendLine(" foreach (var entry in orderedPipelines.Where(e => e.Type == PipelineType.Pre))"); + sb.AppendLine(" {"); + sb.AppendLine(" var pre = (Func)entry.Delegate;"); + sb.AppendLine(" await pre(request, ct);"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); sb.AppendLine(" if (!_streamHandlers.TryGetValue(requestType, out var handlerDelegate))"); sb.AppendLine(" {"); sb.AppendLine(" throw new InvalidOperationException($\"No stream handler registered for request type {requestType.Name}\");"); @@ -193,51 +226,295 @@ private static string GenerateMainDispatcherFile(DispatcherConfig config, string sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(); + + // Object overload for Stream + if (config.IncludeObjectOverloads) + { + GenerateObjectStreamMethod(sb); + } } // Helper method for pipeline execution + GenerateExecuteWithPipelineMethod(sb); + + sb.AppendLine("}"); + + return sb.ToString(); + } + + private static void GeneratePipelineEntry(StringBuilder sb) + { + sb.AppendLine(" private enum PipelineType { Pre, Around, Post, OnError }"); + sb.AppendLine(); + sb.AppendLine(" private sealed class PipelineEntry"); + sb.AppendLine(" {"); + sb.AppendLine(" public PipelineType Type { get; set; }"); + sb.AppendLine(" public int Order { get; set; }"); + sb.AppendLine(" public Delegate Delegate { get; set; } = null!;"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void GenerateExecuteWithPipelineMethod(StringBuilder sb) + { sb.AppendLine(" private async ValueTask ExecuteWithPipeline("); sb.AppendLine(" TRequest request,"); sb.AppendLine(" Func> handler,"); - sb.AppendLine(" List pipelines,"); + sb.AppendLine(" List pipelines,"); sb.AppendLine(" CancellationToken ct)"); sb.AppendLine(" {"); + sb.AppendLine(" var orderedPipelines = pipelines.OrderBy(p => p.Order).ToList();"); + sb.AppendLine(); sb.AppendLine(" // Execute Pre hooks"); - sb.AppendLine(" foreach (var pipeline in pipelines)"); + sb.AppendLine(" foreach (var entry in orderedPipelines.Where(e => e.Type == PipelineType.Pre))"); + sb.AppendLine(" {"); + sb.AppendLine(" var pre = (Func)entry.Delegate;"); + sb.AppendLine(" await pre(request, ct);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" // Build Around chain (compose from innermost to outermost)"); + sb.AppendLine(" var arounds = orderedPipelines.Where(e => e.Type == PipelineType.Around).ToList();"); + sb.AppendLine(); + sb.AppendLine(" Func> next = () => handler(request, ct);"); + sb.AppendLine(); + sb.AppendLine(" for (int i = arounds.Count - 1; i >= 0; i--)"); sb.AppendLine(" {"); - sb.AppendLine(" if (pipeline is Func pre)"); + sb.AppendLine(" var around = (Func>, ValueTask>)arounds[i].Delegate;"); + sb.AppendLine(" var currentNext = next;"); + sb.AppendLine(" next = () => around(request, ct, currentNext);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" TResponse response;"); + sb.AppendLine(" try"); + sb.AppendLine(" {"); + sb.AppendLine(" response = await next();"); + sb.AppendLine(" }"); + sb.AppendLine(" catch (Exception ex)"); + sb.AppendLine(" {"); + sb.AppendLine(" // Execute OnError hooks"); + sb.AppendLine(" foreach (var entry in orderedPipelines.Where(e => e.Type == PipelineType.OnError))"); sb.AppendLine(" {"); - sb.AppendLine(" await pre(request, ct);"); + sb.AppendLine(" var onError = (Func)entry.Delegate;"); + sb.AppendLine(" await onError(request, ex, ct);"); sb.AppendLine(" }"); + sb.AppendLine(" throw;"); sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine(" // Execute handler"); - sb.AppendLine(" var response = await handler(request, ct);"); - sb.AppendLine(); sb.AppendLine(" // Execute Post hooks"); - sb.AppendLine(" foreach (var pipeline in pipelines)"); + sb.AppendLine(" foreach (var entry in orderedPipelines.Where(e => e.Type == PipelineType.Post))"); + sb.AppendLine(" {"); + sb.AppendLine(" var post = (Func)entry.Delegate;"); + sb.AppendLine(" await post(request, response, ct);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" return response;"); + sb.AppendLine(" }"); + } + + private static void GenerateObjectSendMethod(StringBuilder sb) + { + sb.AppendLine(" /// "); + sb.AppendLine(" /// Sends a command and returns a response (object-based overload)."); + sb.AppendLine(" /// Note: Uses reflection. For best performance, use generic Send."); + sb.AppendLine(" /// "); + sb.AppendLine(" public async ValueTask Send(object request, CancellationToken ct = default)"); + sb.AppendLine(" {"); + sb.AppendLine(" var requestType = request.GetType();"); + sb.AppendLine(" if (!_commandHandlers.TryGetValue(requestType, out var handlerDelegate))"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new InvalidOperationException($\"No handler registered for command type {requestType.Name}\");"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" // Invoke handler via reflection"); + sb.AppendLine(" var delegateType = handlerDelegate.GetType();"); + sb.AppendLine(" var invokeMethod = delegateType.GetMethod(\"Invoke\");"); + sb.AppendLine(" if (invokeMethod == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new InvalidOperationException(\"Could not find Invoke method on handler delegate\");"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" var result = invokeMethod.Invoke(handlerDelegate, new object?[] { request, ct });"); + sb.AppendLine(" if (result is ValueTask vtObj)"); sb.AppendLine(" {"); - sb.AppendLine(" if (pipeline is Func post)"); + sb.AppendLine(" return await vtObj;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" // Handle generic ValueTask"); + sb.AppendLine(" // Note: This reflection-based path is only used for object overloads (opt-in)"); + sb.AppendLine(" // Regular generic Send is zero-reflection"); + sb.AppendLine(" var resultType = result?.GetType();"); + sb.AppendLine(" if (resultType != null && resultType.IsGenericType && resultType.GetGenericTypeDefinition() == typeof(ValueTask<>))"); + sb.AppendLine(" {"); + sb.AppendLine(" var asTaskMethod = resultType.GetMethod(\"AsTask\");"); + sb.AppendLine(" if (asTaskMethod != null)"); sb.AppendLine(" {"); - sb.AppendLine(" await post(request, response, ct);"); + sb.AppendLine(" var task = asTaskMethod.Invoke(result, null) as Task;"); + sb.AppendLine(" if (task != null)"); + sb.AppendLine(" {"); + sb.AppendLine(" await task;"); + sb.AppendLine(" return task.GetType().GetProperty(\"Result\")?.GetValue(task);"); + sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine(" return response;"); + sb.AppendLine(" return result;"); sb.AppendLine(" }"); - - sb.AppendLine("}"); - - return sb.ToString(); + sb.AppendLine(); + } + + private static void GenerateObjectPublishMethod(StringBuilder sb) + { + sb.AppendLine(" /// "); + sb.AppendLine(" /// Publishes a notification to all registered handlers (object-based overload)."); + sb.AppendLine(" /// Note: Uses reflection. For best performance, use generic Publish."); + sb.AppendLine(" /// "); + sb.AppendLine(" public async ValueTask Publish(object notification, CancellationToken ct = default)"); + sb.AppendLine(" {"); + sb.AppendLine(" var notificationType = notification.GetType();"); + sb.AppendLine(" if (!_notificationHandlers.TryGetValue(notificationType, out var handlers))"); + sb.AppendLine(" {"); + sb.AppendLine(" return; // No-op if no handlers"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" foreach (var handlerDelegate in handlers)"); + sb.AppendLine(" {"); + sb.AppendLine(" var invokeMethod = handlerDelegate.GetType().GetMethod(\"Invoke\");"); + sb.AppendLine(" if (invokeMethod != null)"); + sb.AppendLine(" {"); + sb.AppendLine(" var result = invokeMethod.Invoke(handlerDelegate, new object?[] { notification, ct });"); + sb.AppendLine(" if (result is ValueTask vt)"); + sb.AppendLine(" {"); + sb.AppendLine(" await vt;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void GenerateObjectStreamMethod(StringBuilder sb) + { + sb.AppendLine(" /// "); + sb.AppendLine(" /// Streams items from a stream request (object-based overload)."); + sb.AppendLine(" /// Note: Uses reflection. For best performance, use generic Stream."); + sb.AppendLine(" /// "); + sb.AppendLine(" public async IAsyncEnumerable Stream(object request, [EnumeratorCancellation] CancellationToken ct = default)"); + sb.AppendLine(" {"); + sb.AppendLine(" var requestType = request.GetType();"); + sb.AppendLine(" if (!_streamHandlers.TryGetValue(requestType, out var handlerDelegate))"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new InvalidOperationException($\"No stream handler registered for request type {requestType.Name}\");"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" // Invoke handler to get IAsyncEnumerable"); + sb.AppendLine(" var invokeMethod = handlerDelegate.GetType().GetMethod(\"Invoke\");"); + sb.AppendLine(" if (invokeMethod == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new InvalidOperationException(\"Could not find Invoke method on stream handler delegate\");"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" var result = invokeMethod.Invoke(handlerDelegate, new object?[] { request, ct });"); + sb.AppendLine(" if (result == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" yield break;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" // Find IAsyncEnumerable interface on result"); + sb.AppendLine(" var asyncEnumerableInterface = result.GetType().GetInterfaces()"); + sb.AppendLine(" .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>));"); + sb.AppendLine(); + sb.AppendLine(" if (asyncEnumerableInterface == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new InvalidOperationException(\"Handler result does not implement IAsyncEnumerable\");"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" // Get GetAsyncEnumerator method from the interface"); + sb.AppendLine(" var getEnumeratorMethod = asyncEnumerableInterface.GetMethod(\"GetAsyncEnumerator\");"); + sb.AppendLine(" if (getEnumeratorMethod == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new InvalidOperationException(\"Could not find GetAsyncEnumerator method\");"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" // Get enumerator"); + sb.AppendLine(" var enumerator = getEnumeratorMethod.Invoke(result, new object[] { ct });"); + sb.AppendLine(" if (enumerator == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" yield break;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" // Get IAsyncEnumerator interface"); + sb.AppendLine(" var asyncEnumeratorInterface = enumerator.GetType().GetInterfaces()"); + sb.AppendLine(" .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>));"); + sb.AppendLine(); + sb.AppendLine(" if (asyncEnumeratorInterface == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" yield break;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" var moveNextAsyncMethod = asyncEnumeratorInterface.GetMethod(\"MoveNextAsync\");"); + sb.AppendLine(" var currentProperty = asyncEnumeratorInterface.GetProperty(\"Current\");"); + sb.AppendLine(); + sb.AppendLine(" if (moveNextAsyncMethod == null || currentProperty == null)"); + sb.AppendLine(" {"); + sb.AppendLine(" yield break;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" try"); + sb.AppendLine(" {"); + sb.AppendLine(" while (true)"); + sb.AppendLine(" {"); + sb.AppendLine(" var moveNextResult = moveNextAsyncMethod.Invoke(enumerator, null);"); + sb.AppendLine(" if (moveNextResult is ValueTask vtBool)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (!await vtBool)"); + sb.AppendLine(" {"); + sb.AppendLine(" break;"); + sb.AppendLine(" }"); + sb.AppendLine(" yield return currentProperty.GetValue(enumerator);"); + sb.AppendLine(" }"); + sb.AppendLine(" else"); + sb.AppendLine(" {"); + sb.AppendLine(" break;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" finally"); + sb.AppendLine(" {"); + sb.AppendLine(" if (enumerator is IAsyncDisposable asyncDisposable)"); + sb.AppendLine(" {"); + sb.AppendLine(" await asyncDisposable.DisposeAsync();"); + sb.AppendLine(" }"); + sb.AppendLine(" else"); + sb.AppendLine(" {"); + sb.AppendLine(" var disposeAsyncMethod = enumerator.GetType().GetMethod(\"DisposeAsync\", Type.EmptyTypes);"); + sb.AppendLine(" if (disposeAsyncMethod != null)"); + sb.AppendLine(" {"); + sb.AppendLine(" var disposeResult = disposeAsyncMethod.Invoke(enumerator, null);"); + sb.AppendLine(" if (disposeResult is ValueTask vt)"); + sb.AppendLine(" {"); + sb.AppendLine(" await vt;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); } private static string GenerateBuilderFile(DispatcherConfig config, string visibility) { var sb = CreateFileHeader(); AppendUsings(sb, "System", "System.Collections.Generic", "System.Threading", "System.Threading.Tasks"); + + if (config.IncludeStreaming) + { + sb.AppendLine("using System.Runtime.CompilerServices;"); + } + + sb.AppendLine(); AppendNamespaceAndClassHeader(sb, config.Namespace, visibility, config.Name); - sb.AppendLine($" {visibility} sealed class Builder"); + sb.AppendLine($" {visibility} sealed class Builder : IDispatcherBuilder"); sb.AppendLine(" {"); sb.AppendLine($" private readonly {config.Name} _dispatcher = new();"); sb.AppendLine(); @@ -286,35 +563,105 @@ private static string GenerateBuilderFile(DispatcherConfig config, string visibi } // Pipeline registration - Pre - sb.AppendLine(" public Builder Pre(Func pre)"); + sb.AppendLine(" public Builder Pre(Func pre, int order = 0)"); + sb.AppendLine(" {"); + sb.AppendLine(" var requestType = typeof(TRequest);"); + sb.AppendLine(" if (!_dispatcher._commandPipelines.TryGetValue(requestType, out var pipelines))"); + sb.AppendLine(" {"); + sb.AppendLine(" pipelines = new List();"); + sb.AppendLine(" _dispatcher._commandPipelines[requestType] = pipelines;"); + sb.AppendLine(" }"); + sb.AppendLine(" pipelines.Add(new PipelineEntry { Type = PipelineType.Pre, Order = order, Delegate = pre });"); + sb.AppendLine(" return this;"); + sb.AppendLine(" }"); + sb.AppendLine(); + + // Pipeline registration - Around + sb.AppendLine(" public Builder Around(Func>, ValueTask> around, int order = 0)"); sb.AppendLine(" {"); sb.AppendLine(" var requestType = typeof(TRequest);"); sb.AppendLine(" if (!_dispatcher._commandPipelines.TryGetValue(requestType, out var pipelines))"); sb.AppendLine(" {"); - sb.AppendLine(" pipelines = new List();"); + sb.AppendLine(" pipelines = new List();"); sb.AppendLine(" _dispatcher._commandPipelines[requestType] = pipelines;"); sb.AppendLine(" }"); - sb.AppendLine(" pipelines.Add(pre);"); + sb.AppendLine(" pipelines.Add(new PipelineEntry { Type = PipelineType.Around, Order = order, Delegate = around });"); sb.AppendLine(" return this;"); sb.AppendLine(" }"); sb.AppendLine(); // Pipeline registration - Post - sb.AppendLine(" public Builder Post(Func post)"); + sb.AppendLine(" public Builder Post(Func post, int order = 0)"); + sb.AppendLine(" {"); + sb.AppendLine(" var requestType = typeof(TRequest);"); + sb.AppendLine(" if (!_dispatcher._commandPipelines.TryGetValue(requestType, out var pipelines))"); + sb.AppendLine(" {"); + sb.AppendLine(" pipelines = new List();"); + sb.AppendLine(" _dispatcher._commandPipelines[requestType] = pipelines;"); + sb.AppendLine(" }"); + sb.AppendLine(" pipelines.Add(new PipelineEntry { Type = PipelineType.Post, Order = order, Delegate = post });"); + sb.AppendLine(" return this;"); + sb.AppendLine(" }"); + sb.AppendLine(); + + // Pipeline registration - OnError + sb.AppendLine(" public Builder OnError(Func onError, int order = 0)"); sb.AppendLine(" {"); sb.AppendLine(" var requestType = typeof(TRequest);"); sb.AppendLine(" if (!_dispatcher._commandPipelines.TryGetValue(requestType, out var pipelines))"); sb.AppendLine(" {"); - sb.AppendLine(" pipelines = new List();"); + sb.AppendLine(" pipelines = new List();"); sb.AppendLine(" _dispatcher._commandPipelines[requestType] = pipelines;"); sb.AppendLine(" }"); - sb.AppendLine(" pipelines.Add(post);"); + sb.AppendLine(" pipelines.Add(new PipelineEntry { Type = PipelineType.OnError, Order = order, Delegate = onError });"); + sb.AppendLine(" return this;"); + sb.AppendLine(" }"); + sb.AppendLine(); + + // Stream pipeline registration - PreStream + if (config.IncludeStreaming) + { + sb.AppendLine(" public Builder PreStream(Func pre, int order = 0)"); + sb.AppendLine(" {"); + sb.AppendLine(" var requestType = typeof(TRequest);"); + sb.AppendLine(" if (!_dispatcher._streamPipelines.TryGetValue(requestType, out var pipelines))"); + sb.AppendLine(" {"); + sb.AppendLine(" pipelines = new List();"); + sb.AppendLine(" _dispatcher._streamPipelines[requestType] = pipelines;"); + sb.AppendLine(" }"); + sb.AppendLine(" pipelines.Add(new PipelineEntry { Type = PipelineType.Pre, Order = order, Delegate = pre });"); + sb.AppendLine(" return this;"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + // Module registration + sb.AppendLine(" public Builder AddModule(IModule module)"); + sb.AppendLine(" {"); + sb.AppendLine(" module.Register(this);"); sb.AppendLine(" return this;"); sb.AppendLine(" }"); sb.AppendLine(); // Build method sb.AppendLine($" public {config.Name} Build() => _dispatcher;"); + sb.AppendLine(); + + // IDispatcherBuilder implementation + sb.AppendLine(" // Explicit interface implementations"); + sb.AppendLine(" IDispatcherBuilder IDispatcherBuilder.Command(Func> handler)"); + sb.AppendLine(" => Command(handler);"); + sb.AppendLine(); + sb.AppendLine(" IDispatcherBuilder IDispatcherBuilder.Notification(Func handler)"); + sb.AppendLine(" => Notification(handler);"); + sb.AppendLine(); + + if (config.IncludeStreaming) + { + sb.AppendLine(" IDispatcherBuilder IDispatcherBuilder.Stream(Func> handler)"); + sb.AppendLine(" => Stream(handler);"); + sb.AppendLine(); + } sb.AppendLine(" }"); sb.AppendLine("}"); @@ -330,6 +677,33 @@ private static string GenerateContractsFile(DispatcherConfig config, string visi sb.AppendLine($"namespace {config.Namespace};"); sb.AppendLine(); + // IModule interface + sb.AppendLine("/// "); + sb.AppendLine("/// Interface for modular registration of handlers."); + sb.AppendLine("/// "); + sb.AppendLine($"{visibility} interface IModule"); + sb.AppendLine("{"); + sb.AppendLine(" void Register(IDispatcherBuilder builder);"); + sb.AppendLine("}"); + sb.AppendLine(); + + // IDispatcherBuilder interface + sb.AppendLine("/// "); + sb.AppendLine("/// Builder interface for registering handlers."); + sb.AppendLine("/// "); + sb.AppendLine($"{visibility} interface IDispatcherBuilder"); + sb.AppendLine("{"); + sb.AppendLine(" IDispatcherBuilder Command(System.Func> handler);"); + sb.AppendLine(" IDispatcherBuilder Notification(System.Func handler);"); + + if (config.IncludeStreaming) + { + sb.AppendLine(" IDispatcherBuilder Stream(System.Func> handler);"); + } + + sb.AppendLine("}"); + sb.AppendLine(); + // Command handler interface sb.AppendLine("/// "); sb.AppendLine("/// Handler for a command that returns a response."); diff --git a/test/PatternKit.Generators.Tests/DispatcherGeneratorTests.cs b/test/PatternKit.Generators.Tests/DispatcherGeneratorTests.cs index d47c4e0..811a468 100644 --- a/test/PatternKit.Generators.Tests/DispatcherGeneratorTests.cs +++ b/test/PatternKit.Generators.Tests/DispatcherGeneratorTests.cs @@ -453,4 +453,821 @@ namespace MyApp; Assert.Contains("interface IStreamHandler", text); Assert.Contains("delegate ValueTask CommandNext", text); } + + #region Around Middleware Tests + + [Fact] + public void AroundMiddleware_SingleBehavior_WrapsHandler() + { + var source = """ + using PatternKit.Generators.Messaging; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher(Namespace = "MyApp.Messaging", Name = "AppDispatcher")] + + namespace MyApp; + + using MyApp.Messaging; + + public record Multiply(int Value); + public record Result(int Value); + + public static class Demo + { + private static List log = new(); + + public static async Task Run() + { + log.Clear(); + var dispatcher = AppDispatcher.Create() + .Command((req, ct) => + { + log.Add($"Handler:{req.Value}"); + return new ValueTask(new Result(req.Value * 2)); + }) + .Around(async (req, ct, next) => + { + log.Add("Around:Before"); + var result = await next(); + log.Add($"Around:After:{result.Value}"); + return result; + }) + .Build(); + + var response = await dispatcher.Send(new Multiply(5), default); + log.Add($"Final:{response.Value}"); + return string.Join("|", log); + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(AroundMiddleware_SingleBehavior_WrapsHandler)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + Assert.Equal("Around:Before|Handler:5|Around:After:10|Final:10", result); + } + + [Fact] + public void AroundMiddleware_MultipleBehaviors_ComposesInOrder() + { + var source = """ + using PatternKit.Generators.Messaging; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher(Namespace = "MyApp.Messaging", Name = "AppDispatcher")] + + namespace MyApp; + + using MyApp.Messaging; + + public record Add(int Value); + public record Result(int Value); + + public static class Demo + { + private static List log = new(); + + public static async Task Run() + { + log.Clear(); + var dispatcher = AppDispatcher.Create() + .Command((req, ct) => + { + log.Add("Handler"); + return new ValueTask(new Result(req.Value + 10)); + }) + .Around(async (req, ct, next) => + { + log.Add("Around1:Before"); + var result = await next(); + log.Add("Around1:After"); + return result; + }, order: 1) + .Around(async (req, ct, next) => + { + log.Add("Around2:Before"); + var result = await next(); + log.Add("Around2:After"); + return result; + }, order: 2) + .Build(); + + await dispatcher.Send(new Add(5), default); + return string.Join("|", log); + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(AroundMiddleware_MultipleBehaviors_ComposesInOrder)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + // Order: 1 (outer) wraps 2 (inner) + // Execution: Around1:Before -> Around2:Before -> Handler -> Around2:After -> Around1:After + Assert.Equal("Around1:Before|Around2:Before|Handler|Around2:After|Around1:After", result); + } + + [Fact] + public void AroundMiddleware_ModifiesRequestAndResponse_VerifiesNesting() + { + var source = """ + using PatternKit.Generators.Messaging; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher(Namespace = "MyApp.Messaging", Name = "AppDispatcher")] + + namespace MyApp; + + using MyApp.Messaging; + + public record Request(int Value); + public record Response(int Value); + + public static class Demo + { + public static async Task Run() + { + var dispatcher = AppDispatcher.Create() + .Command((req, ct) => + new ValueTask(new Response(req.Value))) + .Around(async (req, ct, next) => + { + // Outer Around adds 10 after handler + var result = await next(); + return new Response(result.Value + 10); + }, order: 1) + .Around(async (req, ct, next) => + { + // Inner Around multiplies by 2 after handler + var result = await next(); + return new Response(result.Value * 2); + }, order: 2) + .Build(); + + var response = await dispatcher.Send(new Request(5), default); + return response.Value; + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(AroundMiddleware_ModifiesRequestAndResponse_VerifiesNesting)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + // Flow: 5 -> handler(5) -> inner(*2=10) -> outer(+10=20) + Assert.Equal(20, result); + } + + [Fact] + public void AroundMiddleware_WithPreAndPost_ExecutesInCorrectOrder() + { + var source = """ + using PatternKit.Generators.Messaging; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher(Namespace = "MyApp.Messaging", Name = "AppDispatcher")] + + namespace MyApp; + + using MyApp.Messaging; + + public record DoWork(int Value); + public record Result(int Value); + + public static class Demo + { + private static List log = new(); + + public static async Task Run() + { + log.Clear(); + var dispatcher = AppDispatcher.Create() + .Pre((req, ct) => { log.Add("Pre"); return ValueTask.CompletedTask; }) + .Command((req, ct) => + { + log.Add("Handler"); + return new ValueTask(new Result(req.Value)); + }) + .Around(async (req, ct, next) => + { + log.Add("Around:Before"); + var result = await next(); + log.Add("Around:After"); + return result; + }) + .Post((req, res, ct) => { log.Add("Post"); return ValueTask.CompletedTask; }) + .Build(); + + await dispatcher.Send(new DoWork(1), default); + return string.Join("|", log); + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(AroundMiddleware_WithPreAndPost_ExecutesInCorrectOrder)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + // Expected: Pre -> Around Before -> Handler -> Around After -> Post + Assert.Equal("Pre|Around:Before|Handler|Around:After|Post", result); + } + + #endregion + + #region OnError Handling Tests + + [Fact] + public void OnError_HandlerThrows_ExecutesErrorHandler() + { + var source = """ + using PatternKit.Generators.Messaging; + using System; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher(Namespace = "MyApp.Messaging", Name = "AppDispatcher")] + + namespace MyApp; + + using MyApp.Messaging; + + public record FailingCommand(string Message); + public record Result(string Data); + + public static class Demo + { + private static List log = new(); + + public static async Task Run() + { + log.Clear(); + var dispatcher = AppDispatcher.Create() + .Command((req, ct) => + { + log.Add("Handler:Throwing"); + throw new InvalidOperationException(req.Message); + }) + .OnError((req, ex, ct) => + { + log.Add($"OnError:{ex.Message}"); + return ValueTask.CompletedTask; + }) + .Build(); + + try + { + await dispatcher.Send(new FailingCommand("TestError"), default); + } + catch (InvalidOperationException) + { + log.Add("Caught"); + } + + return string.Join("|", log); + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(OnError_HandlerThrows_ExecutesErrorHandler)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + Assert.Equal("Handler:Throwing|OnError:TestError|Caught", result); + } + + [Fact] + public void OnError_PrePostAndOnError_ExecutesCorrectly() + { + var source = """ + using PatternKit.Generators.Messaging; + using System; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher(Namespace = "MyApp.Messaging", Name = "AppDispatcher")] + + namespace MyApp; + + using MyApp.Messaging; + + public record FailCommand(string Message); + public record Result(string Data); + + public static class Demo + { + private static List log = new(); + + public static async Task Run() + { + log.Clear(); + var dispatcher = AppDispatcher.Create() + .Pre((req, ct) => { log.Add("Pre"); return ValueTask.CompletedTask; }) + .Command((req, ct) => throw new Exception("Fail")) + .Post((req, res, ct) => { log.Add("Post:ShouldNotRun"); return ValueTask.CompletedTask; }) + .OnError((req, ex, ct) => { log.Add("OnError"); return ValueTask.CompletedTask; }) + .Build(); + + try + { + await dispatcher.Send(new FailCommand("Test"), default); + } + catch + { + log.Add("Caught"); + } + + return string.Join("|", log); + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(OnError_PrePostAndOnError_ExecutesCorrectly)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + // Pre runs, handler throws, OnError runs, Post does NOT run + Assert.Equal("Pre|OnError|Caught", result); + } + + #endregion + + #region Stream Pipeline Tests + + [Fact] + public void StreamPipeline_PreHook_ExecutesBeforeStream() + { + var source = """ + using PatternKit.Generators.Messaging; + using System.Collections.Generic; + using System.Runtime.CompilerServices; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher(Namespace = "MyApp.Messaging", Name = "AppDispatcher")] + + namespace MyApp; + + using MyApp.Messaging; + + public record GetNumbers(int Count); + + public static class Demo + { + private static List log = new(); + + private static async IAsyncEnumerable GenerateNumbers(GetNumbers req, [EnumeratorCancellation] CancellationToken ct) + { + for (int i = 1; i <= req.Count; i++) + { + log.Add($"Item:{i}"); + yield return i; + } + } + + public static async Task Run() + { + log.Clear(); + var dispatcher = AppDispatcher.Create() + .PreStream((req, ct) => { log.Add("PreStream"); return ValueTask.CompletedTask; }) + .Stream(GenerateNumbers) + .Build(); + + await foreach (var num in dispatcher.Stream(new GetNumbers(3), default)) + { + // Consume + } + + return string.Join("|", log); + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(StreamPipeline_PreHook_ExecutesBeforeStream)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + Assert.Equal("PreStream|Item:1|Item:2|Item:3", result); + } + + #endregion + + #region Object Overload Tests + + [Fact] + public void ObjectOverloads_Send_DispatchesCorrectly() + { + var source = """ + using PatternKit.Generators.Messaging; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher( + Namespace = "MyApp.Messaging", + Name = "AppDispatcher", + IncludeObjectOverloads = true)] + + namespace MyApp; + + using MyApp.Messaging; + + public record GetValue(int Id); + public record ValueResult(int Value); + + public static class Demo + { + public static async Task Run() + { + var dispatcher = AppDispatcher.Create() + .Command((req, ct) => + new ValueTask(new ValueResult(req.Id * 10))) + .Build(); + + object request = new GetValue(5); + var response = await dispatcher.Send(request, default); + + if (response is ValueResult vr) + return $"Result:{vr.Value}"; + + return "Failed"; + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(ObjectOverloads_Send_DispatchesCorrectly)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + Assert.Equal("Result:50", result); + } + + [Fact] + public void ObjectOverloads_Publish_DispatchesCorrectly() + { + var source = """ + using PatternKit.Generators.Messaging; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher( + Namespace = "MyApp.Messaging", + Name = "AppDispatcher", + IncludeObjectOverloads = true)] + + namespace MyApp; + + using MyApp.Messaging; + + public record SomethingHappened(string Message); + + public static class Demo + { + private static List log = new(); + + public static async Task Run() + { + log.Clear(); + var dispatcher = AppDispatcher.Create() + .Notification((evt, ct) => + { + log.Add($"Handler:{evt.Message}"); + return ValueTask.CompletedTask; + }) + .Build(); + + object notification = new SomethingHappened("Test"); + await dispatcher.Publish(notification, default); + + return string.Join("|", log); + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(ObjectOverloads_Publish_DispatchesCorrectly)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + Assert.Equal("Handler:Test", result); + } + + [Fact] + public void ObjectOverloads_Stream_DispatchesCorrectly() + { + var source = """ + using PatternKit.Generators.Messaging; + using System; + using System.Collections.Generic; + using System.Runtime.CompilerServices; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher( + Namespace = "MyApp.Messaging", + Name = "AppDispatcher", + IncludeStreaming = true, + IncludeObjectOverloads = true)] + + namespace MyApp; + + using MyApp.Messaging; + + public record RangeRequest(int Start, int Count); + + public static class Demo + { + private static async IAsyncEnumerable GenerateRange(RangeRequest req, [EnumeratorCancellation] CancellationToken ct) + { + for (int i = req.Start; i < req.Start + req.Count; i++) + { + yield return i; + } + } + + public static async Task Run() + { + try + { + var dispatcher = AppDispatcher.Create() + .Stream(GenerateRange) + .Build(); + + var items = new List(); + object request = new RangeRequest(10, 5); + + await foreach (var item in dispatcher.Stream(request, default)) + { + items.Add((int)item!); + } + + return string.Join(",", items); + } + catch (Exception ex) + { + return $"ERROR:{ex.Message}"; + } + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(ObjectOverloads_Stream_DispatchesCorrectly)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + // Should either be the expected result or an error message + if (result.StartsWith("ERROR:")) + { + Assert.Fail($"Test threw exception: {result}"); + } + Assert.Equal("10,11,12,13,14", result); + } + + #endregion + + #region Module System Tests + + [Fact] + public void ModuleSystem_AddModule_RegistersHandlers() + { + var source = """ + using PatternKit.Generators.Messaging; + using System.Threading; + using System.Threading.Tasks; + + [assembly: GenerateDispatcher(Namespace = "MyApp.Messaging", Name = "AppDispatcher")] + + namespace MyApp; + + using MyApp.Messaging; + + public record Ping(string Message); + public record Pong(string Reply); + + public class TestModule : IModule + { + public void Register(IDispatcherBuilder builder) + { + builder.Command((req, ct) => + new ValueTask(new Pong($"Module:{req.Message}"))); + } + } + + public static class Demo + { + public static async Task Run() + { + var dispatcher = AppDispatcher.Create() + .AddModule(new TestModule()) + .Build(); + + var response = await dispatcher.Send(new Ping("Hello"), default); + return response.Reply; + } + } + """; + + var comp = RoslynTestHelpers.CreateCompilation( + source, + assemblyName: nameof(ModuleSystem_AddModule_RegistersHandlers)); + + var gen = new PatternKit.Generators.Messaging.DispatcherGenerator(); + _ = RoslynTestHelpers.Run(comp, gen, out _, out var updated); + + var emit = updated.Emit(Stream.Null); + Assert.True(emit.Success, string.Join("\n", emit.Diagnostics)); + + using var pe = new MemoryStream(); + var emitResult = updated.Emit(pe); + Assert.True(emitResult.Success); + + pe.Seek(0, SeekOrigin.Begin); + var asm = System.Reflection.Assembly.Load(pe.ToArray()); + var demo = asm.GetType("MyApp.Demo"); + var run = demo!.GetMethod("Run"); + var task = (Task)run!.Invoke(null, null)!; + var result = task.Result; + + Assert.Equal("Module:Hello", result); + } + + #endregion }