From de83ec8514e40db072760a00b60a0f4e491ae193 Mon Sep 17 00:00:00 2001 From: DMarkovkin Date: Tue, 25 Mar 2025 09:52:12 +0100 Subject: [PATCH] use instance method or prop as ServiceProviderMemberData source --- docs/xunit.md | 14 +++- .../Sdk/ServiceProviderMemberDataAttribute.cs | 79 +++++++++++-------- .../ServiceProviderMemberDataDiscoverer.cs | 8 +- xunit/src/Directory.Build.props | 2 - 4 files changed, 65 insertions(+), 38 deletions(-) diff --git a/docs/xunit.md b/docs/xunit.md index c2a9e53..32f523a 100644 --- a/docs/xunit.md +++ b/docs/xunit.md @@ -38,13 +38,21 @@ public class EnvironmentInitializer : IFrameworkInitializer ``` **Parameterized tests with data from the IServiceProvider:** -Create a static method that accepts IServiceProvider and returns IEnumerable. Pass that method to ServiceProviderMemberDataAttribute to use it as data source for Theory. +Create a instance method or property that returns IEnumerable. Pass that method to ServiceProviderMemberDataAttribute to use it as data source for Theory. ``` [Bss.Testing.Xunit.Sdk.Theory] [ServiceProviderMemberData(nameof(GetMemberData))] public void GetDataFromServiceProvider(FullSecurityRole role) => Assert.NotEmpty(role.Name); -protected static IEnumerable GetMemberData(IServiceProvider serviceProvider) => - serviceProvider.GetRequiredService().SecurityRoles.Select(x => new [] { x }); +protected IEnumerable GetMemberData() => + this.ServiceProvider.GetRequiredService().SecurityRoles.Select(x => new [] { x }); +``` + +``` +[Bss.Testing.Xunit.Sdk.Theory] +[ServiceProviderMemberData(nameof(GetMemberData))] +public void GetDataFromServiceProvider(FullSecurityRole role) => Assert.NotEmpty(role.Name); + +protected IEnumerable GetMemberData => this.ServiceProvider.GetRequiredService().SecurityRoles.Select(x => new [] { x }); ``` \ No newline at end of file diff --git a/xunit/src/Bss.Testing.Xunit/Sdk/ServiceProviderMemberDataAttribute.cs b/xunit/src/Bss.Testing.Xunit/Sdk/ServiceProviderMemberDataAttribute.cs index 5a32650..d5d477b 100644 --- a/xunit/src/Bss.Testing.Xunit/Sdk/ServiceProviderMemberDataAttribute.cs +++ b/xunit/src/Bss.Testing.Xunit/Sdk/ServiceProviderMemberDataAttribute.cs @@ -5,86 +5,101 @@ using System.Reflection; +using Microsoft.Extensions.DependencyInjection; + namespace Bss.Testing.Xunit.Sdk; [DataDiscoverer("Bss.Testing.Xunit.Sdk.ServiceProviderMemberDataDiscoverer", "Bss.Testing.Xunit")] [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] -public class ServiceProviderMemberDataAttribute(string methodName) : DataAttribute +public class ServiceProviderMemberDataAttribute(string methodOrPropertyName) : DataAttribute { - public Type? MemberType { get; set; } + public Type MemberType { get; set; } - string MemberName { get; set; } = methodName; + string MemberName { get; set; } = methodOrPropertyName; - public override IEnumerable? GetData(MethodInfo testMethod) => null; + public override IEnumerable GetData(MethodInfo testMethod) => null; - public IEnumerable? GetData(MethodInfo testMethod, IServiceProvider? serviceProvider) + public IEnumerable GetData(MethodInfo testMethod, IServiceProvider serviceProvider) { var type = this.MemberType ?? testMethod.DeclaringType; - if (type == null) - { - throw new ArgumentException( - string.Format( - CultureInfo.CurrentCulture, - "Could not find type {0}", - type?.FullName) - ); - } + var accessor = this.GetMethodAccessor(type, serviceProvider) + ?? this.GetPropertyAccessor(type, serviceProvider); - var accessor = this.GetMethodAccessor(type, serviceProvider); if (accessor == null) { throw new ArgumentException( string.Format( CultureInfo.CurrentCulture, - "Could not find public static method named '{0}' on {1}{2}", + "Could not find parameterless method or property named '{0}' on {1} provided in ServiceProviderMemberDataAttribute", this.MemberName, - type?.FullName, - " with parameter types: IServiceProvider") + type?.FullName) ); } var obj = accessor(); if (obj == null) { - return (IEnumerable) Array.Empty(); + return null; } if (obj is not IEnumerable dataItems) { - throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, "Method {0} on {1} did not return IEnumerable", this.MemberName, type?.FullName)); + throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, "Method/property {0} on {1} did not return IEnumerable", this.MemberName, type?.FullName)); } - return dataItems.Cast().Select(item => this.ConvertDataItem(testMethod, item))!; + return dataItems.Cast().Select(item => this.ConvertDataItem(testMethod, item)); } - protected Func? GetMethodAccessor(Type type, IServiceProvider? serviceProvider) + private Func? GetMethodAccessor(Type type, IServiceProvider serviceProvider) { MethodInfo? methodInfo = null; for (var reflectionType = type; reflectionType != null; reflectionType = reflectionType.GetTypeInfo().BaseType) { - var runtimeMethodsWithGivenName = reflectionType.GetRuntimeMethods() - .Where(m => m.Name == this.MemberName) - .ToArray(); + methodInfo = reflectionType + .GetRuntimeMethods() + .FirstOrDefault(m => m.Name == this.MemberName); + if (methodInfo != null) + { + break; + } + } + + if (methodInfo == null) + { + return null; + } - methodInfo = runtimeMethodsWithGivenName - .FirstOrDefault(m => m.GetParameters() - .Count(x => x.ParameterType.IsAssignableTo(typeof(IServiceProvider))) == 1); + var @object = ActivatorUtilities.CreateInstance(serviceProvider, type); - if (methodInfo != null) + return () => methodInfo.Invoke(@object, null); + } + + private Func? GetPropertyAccessor(Type type, IServiceProvider serviceProvider) + { + PropertyInfo? propertyInfo = null; + for (var reflectionType = type; reflectionType != null; reflectionType = reflectionType.GetTypeInfo().BaseType) + { + propertyInfo = reflectionType + .GetProperties() + .FirstOrDefault(m => m.Name == this.MemberName); + + if (propertyInfo != null) { break; } } - if (methodInfo == null || !methodInfo.IsStatic) + if (propertyInfo == null) { return null; } - return () => methodInfo.Invoke(null, [serviceProvider])!; + var @object = ActivatorUtilities.CreateInstance(serviceProvider, type); + + return () => propertyInfo.GetValue(@object); } - protected object[]? ConvertDataItem(MethodInfo testMethod, object? item) + private object[]? ConvertDataItem(MethodInfo testMethod, object? item) { if (item == null) { diff --git a/xunit/src/Bss.Testing.Xunit/Sdk/ServiceProviderMemberDataDiscoverer.cs b/xunit/src/Bss.Testing.Xunit/Sdk/ServiceProviderMemberDataDiscoverer.cs index 7c8d2ad..1e7d25f 100644 --- a/xunit/src/Bss.Testing.Xunit/Sdk/ServiceProviderMemberDataDiscoverer.cs +++ b/xunit/src/Bss.Testing.Xunit/Sdk/ServiceProviderMemberDataDiscoverer.cs @@ -6,10 +6,16 @@ namespace Bss.Testing.Xunit.Sdk; public class ServiceProviderMemberDataDiscoverer : IDataDiscoverer { public IEnumerable GetData(IAttributeInfo dataAttribute, IMethodInfo testMethod) => - throw new ArgumentException($"ServiceProviderMemberDataDiscoverer cannot be used as discoverer for any *DataAttribute other than ServiceProviderMemberDataAttribute."); + throw new ArgumentException("ServiceProviderMemberDataDiscoverer cannot be used as discoverer for any *DataAttribute other than ServiceProviderMemberDataAttribute."); public IEnumerable? GetData(IAttributeInfo dataAttribute, IMethodInfo testMethod, IServiceProvider? serviceProvider) { + if (serviceProvider == null) + { + throw new ArgumentException($"ServiceProvider cannot be null for {nameof(dataAttribute)}"); + } + + if (dataAttribute is not IReflectionAttributeInfo reflectionDataAttribute || testMethod is not IReflectionMethodInfo reflectionTestMethod) { diff --git a/xunit/src/Directory.Build.props b/xunit/src/Directory.Build.props index f08551c..11d4507 100644 --- a/xunit/src/Directory.Build.props +++ b/xunit/src/Directory.Build.props @@ -10,8 +10,6 @@ false false - - true