Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Api/Auth/Controllers/WebAuthnController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ private async Task ValidateIfUserCanUsePasskeyLogin(Guid userId)
return;
}

var requireSsoPolicyRequirement = await _policyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(userId);
var requireSsoPolicyRequirement = await _policyRequirementQuery.GetAsyncVNext<RequireSsoPolicyRequirement>(userId);

if (!requireSsoPolicyRequirement.CanUsePasskeyLogin)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ public interface IPolicyRequirementQuery
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy you want to enforce.</typeparam>
Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement;

/// <summary>
/// Get a policy requirement for a specific user using the optimized single-user query.
/// The policy requirement represents how one or more policy types should be enforced against the user.
/// It will always return a value even if there are no policies that should be enforced.
/// This is the vNext version that uses the optimized GetPolicyDetailsByUserIdAndPolicyTypeAsync method.
/// </summary>
/// <param name="userId">The user that you need to enforce the policy against.</param>
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy you want to enforce.</typeparam>
Task<T> GetAsyncVNext<T>(Guid userId) where T : IPolicyRequirement;

/// <summary>
/// Get a policy requirement for a list of users.
/// The policy requirement represents how one or more policy types should be enforced against the users.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ public class PolicyRequirementQuery(
public async Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement
=> (await GetAsync<T>([userId])).Single().Requirement;

public async Task<T> GetAsyncVNext<T>(Guid userId) where T : IPolicyRequirement
{
var factory = factories.OfType<IPolicyRequirementFactory<T>>().SingleOrDefault();
if (factory is null)
{
throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
}

var policyDetails = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, factory.PolicyType);
var enforcedPolicyDetails = policyDetails.Where(factory.Enforce);

return factory.Create(enforcedPolicyDetails);
}

public async Task<IEnumerable<(Guid UserId, T Requirement)>> GetAsync<T>(IEnumerable<Guid> userIds) where T : IPolicyRequirement
{
var factory = factories.OfType<IPolicyRequirementFactory<T>>().SingleOrDefault();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ public RequireSsoPolicyRequirementFactory(GlobalSettings globalSettings)

public override RequireSsoPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
{
policyDetails = policyDetails.ToList();
var result = new RequireSsoPolicyRequirement
{
CanUsePasskeyLogin = policyDetails.All(p =>
p.OrganizationUserStatus == OrganizationUserStatusType.Revoked ||
p.OrganizationUserStatus == OrganizationUserStatusType.Invited),
CanUsePasskeyLogin = policyDetails.Any(p =>
p.OrganizationUserStatus is OrganizationUserStatusType.Accepted or OrganizationUserStatusType.Confirmed),

SsoRequired = policyDetails.Any(p =>
p.OrganizationUserStatus == OrganizationUserStatusType.Confirmed)
Expand Down
16 changes: 16 additions & 0 deletions src/Core/AdminConsole/Repositories/IPolicyRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,20 @@ public interface IPolicyRepository : IRepository<Policy, Guid>
/// associated with the specified users and policy type.
/// </returns>
Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetailsByUserIdsAndPolicyType(IEnumerable<Guid> userIds, PolicyType policyType);

/// <summary>
/// Retrieves policy details for a single user filtered by the specified policy type.
/// </summary>
/// <remarks>
/// Returns policy details only for enabled policies from enabled organizations that support policies.
/// This includes both confirmed users (matched by UserId) and invited users (matched by email).
/// Provider users are identified via the IsProvider flag.
/// </remarks>
/// <param name="userId">The user identifier for which policy details are to be fetched.</param>
/// <param name="policyType">The type of policy for which the details are required.</param>
/// <returns>
/// An asynchronous task that returns a collection of <see cref="PolicyDetails"/> objects containing
/// the policy information associated with the specified user and policy type.
/// </returns>
Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private async Task<bool> RequireSsoAuthenticationAsync(User user, string grantTy

// Check if user belongs to any organization with an active SSO policy
var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements)
? (await _policyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(user.Id))
? (await _policyRequirementQuery.GetAsyncVNext<RequireSsoPolicyRequirement>(user.Id))
.SsoRequired
: await _policyService.AnyPoliciesApplicableToUserAsync(
user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,19 @@ public async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetailsByOrga
return results.ToList();
}
}

public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType)
{
await using var connection = new SqlConnection(ConnectionString);
var results = await connection.QueryAsync<PolicyDetails>(
$"[{Schema}].[PolicyDetails_ReadByUserIdPolicyType]",
new
{
UserId = userId,
PolicyType = (byte)policyType
},
commandType: CommandType.StoredProcedure);

return results.ToList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,53 @@ where p.Enabled

return allResults.ToList();
}

public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType)
{
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);

// Get user email for invited user matching
var userEmail = await dbContext.Users
.Where(u => u.Id == userId)
.Select(u => u.Email)
.FirstOrDefaultAsync();

// Get provider relationships
var providerOrganizationIds = await (from pu in dbContext.ProviderUsers
join po in dbContext.ProviderOrganizations on pu.ProviderId equals po.ProviderId
where pu.UserId == userId
select po.OrganizationId)
.Distinct()
.ToListAsync();

var providerSet = new HashSet<Guid>(providerOrganizationIds);

// Get organization users (both confirmed/accepted and invited)
var orgUsersQuery = dbContext.OrganizationUsers
.Where(ou => (ou.Status != OrganizationUserStatusType.Invited && ou.UserId == userId) ||
(ou.Status == OrganizationUserStatusType.Invited && ou.Email == userEmail));

// Join with policies and organizations
var query = from policy in dbContext.Policies
join orgUser in orgUsersQuery on policy.OrganizationId equals orgUser.OrganizationId
join org in dbContext.Organizations on policy.OrganizationId equals org.Id
where policy.Type == policyType
&& policy.Enabled
&& org.Enabled
&& org.UsePolicies
select new PolicyDetails
{
OrganizationUserId = orgUser.Id,
OrganizationId = policy.OrganizationId,
PolicyType = policy.Type,
PolicyData = policy.Data,
OrganizationUserType = orgUser.Type,
OrganizationUserStatus = orgUser.Status,
OrganizationUserPermissionsData = orgUser.Permissions,
IsProvider = providerSet.Contains(policy.OrganizationId)
};

return await query.ToListAsync();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
CREATE PROCEDURE [dbo].[PolicyDetails_ReadByUserIdPolicyType]
@UserId UNIQUEIDENTIFIER,
@PolicyType TINYINT
AS
BEGIN
SET NOCOUNT ON

DECLARE @UserEmail NVARCHAR(256)
SELECT @UserEmail = Email
FROM
[dbo].[UserView]
WHERE
Id = @UserId

;WITH OrgUsers AS
(
-- Non-invited users (Status != 0): direct UserId match
SELECT
OU.[Id],
OU.[OrganizationId],
OU.[Type],
OU.[Status],
OU.[Permissions]
FROM
[dbo].[OrganizationUserView] OU
WHERE
OU.[Status] != 0
AND OU.[UserId] = @UserId

UNION ALL

-- Invited users (Status = 0): email match
SELECT
OU.[Id],
OU.[OrganizationId],
OU.[Type],
OU.[Status],
OU.[Permissions]
FROM
[dbo].[OrganizationUserView] OU
WHERE
OU.[Status] = 0
AND OU.[Email] = @UserEmail
AND @UserEmail IS NOT NULL
),
Providers AS
(
SELECT
OrganizationId
FROM
[dbo].[UserProviderAccessView]
WHERE
UserId = @UserId
)
SELECT
OU.[Id] AS OrganizationUserId,
P.[OrganizationId],
P.[Type] AS PolicyType,
P.[Data] AS PolicyData,
OU.[Type] AS OrganizationUserType,
OU.[Status] AS OrganizationUserStatus,
OU.[Permissions] AS OrganizationUserPermissionsData,
CASE WHEN PR.[OrganizationId] IS NULL THEN 0 ELSE 1 END AS IsProvider
FROM
[dbo].[PolicyView] P
INNER JOIN
OrgUsers OU ON P.[OrganizationId] = OU.[OrganizationId]
INNER JOIN
[dbo].[OrganizationView] O ON P.[OrganizationId] = O.[Id]
LEFT JOIN
Providers PR ON PR.[OrganizationId] = OU.[OrganizationId]
WHERE
P.[Type] = @PolicyType
AND P.[Enabled] = 1
AND O.[Enabled] = 1
AND O.[UsePolicies] = 1
END
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,60 @@ public async Task GetAsync_HandlesNoPolicies(Guid userId)
Assert.Empty(requirement.Policies);
}

[Theory, BitAutoData]
public async Task GetAsyncVNext_CallsEnforceCallback(Guid userId)
{
// Arrange policies
var policyRepository = Substitute.For<IPolicyRepository>();
var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
var otherPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, PolicyType.SingleOrg)
.Returns([thisPolicy, otherPolicy]);

// Arrange a substitute Enforce function so that we can inspect the received calls
var callback = Substitute.For<Func<PolicyDetails, bool>>();
callback(Arg.Any<PolicyDetails>()).Returns(x => x.Arg<PolicyDetails>() == thisPolicy);

// Arrange the sut
var factory = new TestPolicyRequirementFactory(callback);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);

// Act
var requirement = await sut.GetAsyncVNext<TestPolicyRequirement>(userId);

// Assert
Assert.Contains(thisPolicy, requirement.Policies);
Assert.DoesNotContain(otherPolicy, requirement.Policies);
callback.Received()(Arg.Is(thisPolicy));
callback.Received()(Arg.Is(otherPolicy));
}

[Theory, BitAutoData]
public async Task GetAsyncVNext_ThrowsIfNoFactoryRegistered(Guid userId)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var sut = new PolicyRequirementQuery(policyRepository, []);

var exception = await Assert.ThrowsAsync<NotImplementedException>(()
=> sut.GetAsyncVNext<TestPolicyRequirement>(userId));
Assert.Contains("No Requirement Factory found", exception.Message);
}

[Theory, BitAutoData]
public async Task GetAsyncVNext_HandlesNoPolicies(Guid userId)
{
var policyRepository = Substitute.For<IPolicyRepository>();
policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, PolicyType.SingleOrg)
.Returns([]);

var factory = new TestPolicyRequirementFactory(x => x.IsProvider);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);

var requirement = await sut.GetAsyncVNext<TestPolicyRequirement>(userId);

Assert.Empty(requirement.Policies);
}

[Theory, BitAutoData]
public async Task GetAsync_WithMultipleUserIds_ReturnsRequirementPerUser(Guid userIdA, Guid userIdB)
{
Expand Down
Loading
Loading