diff --git a/Dappi.HeadlessCms.Tests/Controllers/ModelsControllerTests.cs b/Dappi.HeadlessCms.Tests/Controllers/ModelsControllerTests.cs index 93c2cb1..e9448f0 100644 --- a/Dappi.HeadlessCms.Tests/Controllers/ModelsControllerTests.cs +++ b/Dappi.HeadlessCms.Tests/Controllers/ModelsControllerTests.cs @@ -1,4 +1,3 @@ -using System.Net.Http; using System.Net.Http.Json; using System.Text.RegularExpressions; using Dappi.HeadlessCms.Models; diff --git a/Dappi.HeadlessCms.Tests/TestData/PropertyAndClassNames.cs b/Dappi.HeadlessCms.Tests/TestData/PropertyAndClassNames.cs index d3f9a99..92c1f3e 100644 --- a/Dappi.HeadlessCms.Tests/TestData/PropertyAndClassNames.cs +++ b/Dappi.HeadlessCms.Tests/TestData/PropertyAndClassNames.cs @@ -1,5 +1,4 @@ using Dappi.HeadlessCms.Models; -using Xunit; namespace Dappi.HeadlessCms.Tests.TestData { diff --git a/Dappi.HeadlessCms/ActionFilters/IncludeQueryFilter.cs b/Dappi.HeadlessCms/ActionFilters/IncludeQueryFilter.cs new file mode 100644 index 0000000..e813be0 --- /dev/null +++ b/Dappi.HeadlessCms/ActionFilters/IncludeQueryFilter.cs @@ -0,0 +1,64 @@ +using Dappi.HeadlessCms.Models; +using Microsoft.AspNetCore.Mvc.Filters; + +namespace Dappi.HeadlessCms.ActionFilters +{ + public class IncludeQueryFilter : ActionFilterAttribute + { + public const string IncludeParamsKey = "Includes"; + + public override void OnActionExecuting(ActionExecutingContext context) + { + if (!context.HttpContext.Request.Query.TryGetValue("include", out var includeValues)) + return; + + var includeTree = new Dictionary(StringComparer.OrdinalIgnoreCase); + + foreach (var segments in from includeValue in includeValues.OfType() select includeValue + .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) into includePaths from includePath in includePaths select includePath + .Split('.', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .ToArray() into segments where segments.Length != 0 select segments) + { + AddSegments(includeTree, segments, 0); + } + + if (includeTree.Count == 0) + { + return; + } + + context.HttpContext.Items[IncludeParamsKey] = includeTree; + } + + private static void AddSegments(IDictionary nodes, IReadOnlyList segments, int index) + { + while (index != segments.Count) + { + var segment = CapitalizeSegment(segments[index]); + if (!nodes.TryGetValue(segment, out var current)) + { + current = new IncludeNode(segment); + nodes[segment] = current; + } + + nodes = current.Children; + index++; + } + } + + private static string CapitalizeSegment(string segment) + { + if (string.IsNullOrEmpty(segment)) + { + return segment; + } + + if (segment.Length == 1) + { + return segment.ToUpperInvariant(); + } + + return char.ToUpperInvariant(segment[0]) + segment.Substring(1); + } + } +} diff --git a/Dappi.HeadlessCms/Models/IncludeNode.cs b/Dappi.HeadlessCms/Models/IncludeNode.cs new file mode 100644 index 0000000..be82f2a --- /dev/null +++ b/Dappi.HeadlessCms/Models/IncludeNode.cs @@ -0,0 +1,14 @@ +namespace Dappi.HeadlessCms.Models +{ + public class IncludeNode + { + public IncludeNode(string name) + { + Name = name; + } + + public string Name { get; } + + public IDictionary Children { get; } = new Dictionary(StringComparer.OrdinalIgnoreCase); + } +} \ No newline at end of file diff --git a/Dappi.SourceGenerator/CrudGenerator.cs b/Dappi.SourceGenerator/CrudGenerator.cs index 8d9e711..da07da4 100644 --- a/Dappi.SourceGenerator/CrudGenerator.cs +++ b/Dappi.SourceGenerator/CrudGenerator.cs @@ -65,6 +65,7 @@ protected override void Execute(SourceProductionContext context, using System.IO; using System.Reflection; using System.Collections; +using System.Collections.Generic; using Dappi.Core.Constants; using System.Globalization; using System.Linq; @@ -136,6 +137,35 @@ private dynamic GetDbSetForType(string typeName) return dbSetProperty?.GetValue(dbContext); }} + + private IQueryable<{item.ClassName}> ApplyDynamicIncludes(IQueryable<{item.ClassName}> query) + {{ + var includeTree = HttpContext.Items[IncludeQueryFilter.IncludeParamsKey] as IDictionary; + if (includeTree is null || includeTree.Count == 0) + {{ + return query; + }} + + foreach (var include in includeTree) + {{ + query = ApplyIncludeRecursively(query, include.Key, include.Value); + }} + + return query; + }} + + private static IQueryable<{item.ClassName}> ApplyIncludeRecursively(IQueryable<{item.ClassName}> query, string path, IncludeNode node) + {{ + query = query.Include(path); + + foreach (var child in node.Children) + {{ + var childPath = string.Concat(path, ""."", child.Key); + query = ApplyIncludeRecursively(query, childPath, child.Value); + }} + + return query; + }} }}"; context.AddSource($"{item.ClassName}Controller.cs", generatedCode); diff --git a/Dappi.SourceGenerator/Generators/ActionsGenerator.cs b/Dappi.SourceGenerator/Generators/ActionsGenerator.cs index da70be7..02ff4bb 100644 --- a/Dappi.SourceGenerator/Generators/ActionsGenerator.cs +++ b/Dappi.SourceGenerator/Generators/ActionsGenerator.cs @@ -19,6 +19,7 @@ public static string GenerateGetByIdAction(List crudActions, Source [HttpGet("{id}")] {{PropagateDappiAuthorizationTags(item.AuthorizeAttributes, AuthorizeMethods.Get)}} + [IncludeQueryFilter] public async Task Get{{item.ClassName}}(Guid id, [FromQuery] string? fields = null) { try @@ -26,9 +27,9 @@ public static string GenerateGetByIdAction(List crudActions, Source if (id == Guid.Empty) return BadRequest(); - var query = dbContext.{{item.ClassName.Pluralize()}}.AsNoTracking().AsQueryable(); + var query = dbContext.{{item.ClassName.Pluralize()}}.AsNoTracking().AsQueryable(); - query = query{{includesCode}}; + query = ApplyDynamicIncludes(query); var result = await query .FirstOrDefaultAsync(p => p.Id == id); @@ -59,13 +60,14 @@ public static string GenerateGetAction(List crudActions, SourceMode [HttpGet] {{PropagateDappiAuthorizationTags(item.AuthorizeAttributes, AuthorizeMethods.Get)}} [CollectionFilter] + [IncludeQueryFilter] public async Task Get{{item.ClassName.Pluralize()}}([FromQuery] {{item.ClassName}}Filter? filter, [FromQuery] string? fields = null) { try { var query = dbContext.{{item.ClassName.Pluralize()}}.AsNoTracking().AsQueryable(); - query = query{{includesCode}}; + query = ApplyDynamicIncludes(query); var filters = HttpContext.Items[CollectionFilter.FilterParamsKey] as List; if (filters is not null && filters.Count > 0) @@ -115,11 +117,12 @@ public static string GenerateGetAllAction(List crudActions, SourceM [HttpGet("get-all")] {{PropagateDappiAuthorizationTags(item.AuthorizeAttributes, AuthorizeMethods.Get)}} + [IncludeQueryFilter] public async Task GetAll{{item.ClassName.Pluralize()}}() { var query = dbContext.{{item.ClassName.Pluralize()}}.AsNoTracking(); - query = query{{includesCode}}; + query = ApplyDynamicIncludes(query); return Ok(new {items = await query.ToListAsync()}); }