diff --git a/src/Api/Auth/Controllers/WebAuthnController.cs b/src/Api/Auth/Controllers/WebAuthnController.cs index 821d9e9d9c3b..8872d598de2b 100644 --- a/src/Api/Auth/Controllers/WebAuthnController.cs +++ b/src/Api/Auth/Controllers/WebAuthnController.cs @@ -146,7 +146,7 @@ private async Task ValidateIfUserCanUsePasskeyLogin(Guid userId) return; } - var requireSsoPolicyRequirement = await _policyRequirementQuery.GetAsync(userId); + var requireSsoPolicyRequirement = await _policyRequirementQuery.GetAsyncVNext(userId); if (!requireSsoPolicyRequirement.CanUsePasskeyLogin) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs index 02d2dedfc114..466a914dcd9c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs @@ -16,6 +16,16 @@ public interface IPolicyRequirementQuery /// The IPolicyRequirement that corresponds to the policy you want to enforce. Task GetAsync(Guid userId) where T : IPolicyRequirement; + /// + /// Get a policy requirement for a specific user using the optimized single-user query. + /// The policy requirement represents how one or more policy types should be enforced against the user. + /// It will always return a value even if there are no policies that should be enforced. + /// This is the vNext version that uses the optimized GetPolicyDetailsByUserIdAndPolicyTypeAsync method. + /// + /// The user that you need to enforce the policy against. + /// The IPolicyRequirement that corresponds to the policy you want to enforce. + Task GetAsyncVNext(Guid userId) where T : IPolicyRequirement; + /// /// Get a policy requirement for a list of users. /// The policy requirement represents how one or more policy types should be enforced against the users. diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs index c38693fdfd9d..9fb7a2081df7 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs @@ -13,6 +13,20 @@ public class PolicyRequirementQuery( public async Task GetAsync(Guid userId) where T : IPolicyRequirement => (await GetAsync([userId])).Single().Requirement; + public async Task GetAsyncVNext(Guid userId) where T : IPolicyRequirement + { + var factory = factories.OfType>().SingleOrDefault(); + if (factory is null) + { + throw new NotImplementedException("No Requirement Factory found for " + typeof(T)); + } + + var policyDetails = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, factory.PolicyType); + var enforcedPolicyDetails = policyDetails.Where(factory.Enforce); + + return factory.Create(enforcedPolicyDetails); + } + public async Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement { var factory = factories.OfType>().SingleOrDefault(); diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/RequireSsoPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/RequireSsoPolicyRequirement.cs index f01ab1983ad2..6a6ce5fc79b3 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/RequireSsoPolicyRequirement.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/RequireSsoPolicyRequirement.cs @@ -47,11 +47,11 @@ public RequireSsoPolicyRequirementFactory(GlobalSettings globalSettings) public override RequireSsoPolicyRequirement Create(IEnumerable policyDetails) { + policyDetails = policyDetails.ToList(); var result = new RequireSsoPolicyRequirement { - CanUsePasskeyLogin = policyDetails.All(p => - p.OrganizationUserStatus == OrganizationUserStatusType.Revoked || - p.OrganizationUserStatus == OrganizationUserStatusType.Invited), + CanUsePasskeyLogin = policyDetails.Any(p => + p.OrganizationUserStatus is OrganizationUserStatusType.Accepted or OrganizationUserStatusType.Confirmed), SsoRequired = policyDetails.Any(p => p.OrganizationUserStatus == OrganizationUserStatusType.Confirmed) diff --git a/src/Core/AdminConsole/Repositories/IPolicyRepository.cs b/src/Core/AdminConsole/Repositories/IPolicyRepository.cs index d479809b890f..3abff7f4de4d 100644 --- a/src/Core/AdminConsole/Repositories/IPolicyRepository.cs +++ b/src/Core/AdminConsole/Repositories/IPolicyRepository.cs @@ -44,4 +44,20 @@ public interface IPolicyRepository : IRepository /// associated with the specified users and policy type. /// Task> GetPolicyDetailsByUserIdsAndPolicyType(IEnumerable userIds, PolicyType policyType); + + /// + /// Retrieves policy details for a single user filtered by the specified policy type. + /// + /// + /// Returns policy details only for enabled policies from enabled organizations that support policies. + /// This includes both confirmed users (matched by UserId) and invited users (matched by email). + /// Provider users are identified via the IsProvider flag. + /// + /// The user identifier for which policy details are to be fetched. + /// The type of policy for which the details are required. + /// + /// An asynchronous task that returns a collection of objects containing + /// the policy information associated with the specified user and policy type. + /// + Task> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType); } diff --git a/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs index 145ecc873730..2b64616bdbde 100644 --- a/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs @@ -80,7 +80,7 @@ private async Task RequireSsoAuthenticationAsync(User user, string grantTy // Check if user belongs to any organization with an active SSO policy var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) - ? (await _policyRequirementQuery.GetAsync(user.Id)) + ? (await _policyRequirementQuery.GetAsyncVNext(user.Id)) .SsoRequired : await _policyService.AnyPoliciesApplicableToUserAsync( user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs index 865c4f8e5cd5..f64e5fa94408 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs @@ -88,4 +88,19 @@ public async Task> GetPolicyDetailsByOrga return results.ToList(); } } + + public async Task> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType) + { + await using var connection = new SqlConnection(ConnectionString); + var results = await connection.QueryAsync( + $"[{Schema}].[PolicyDetails_ReadByUserIdPolicyType]", + new + { + UserId = userId, + PolicyType = (byte)policyType + }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs index 894fb255beb3..ff2e350f459a 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs @@ -234,4 +234,53 @@ where p.Enabled return allResults.ToList(); } + + public async Task> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + // Get user email for invited user matching + var userEmail = await dbContext.Users + .Where(u => u.Id == userId) + .Select(u => u.Email) + .FirstOrDefaultAsync(); + + // Get provider relationships + var providerOrganizationIds = await (from pu in dbContext.ProviderUsers + join po in dbContext.ProviderOrganizations on pu.ProviderId equals po.ProviderId + where pu.UserId == userId + select po.OrganizationId) + .Distinct() + .ToListAsync(); + + var providerSet = new HashSet(providerOrganizationIds); + + // Get organization users (both confirmed/accepted and invited) + var orgUsersQuery = dbContext.OrganizationUsers + .Where(ou => (ou.Status != OrganizationUserStatusType.Invited && ou.UserId == userId) || + (ou.Status == OrganizationUserStatusType.Invited && ou.Email == userEmail)); + + // Join with policies and organizations + var query = from policy in dbContext.Policies + join orgUser in orgUsersQuery on policy.OrganizationId equals orgUser.OrganizationId + join org in dbContext.Organizations on policy.OrganizationId equals org.Id + where policy.Type == policyType + && policy.Enabled + && org.Enabled + && org.UsePolicies + select new PolicyDetails + { + OrganizationUserId = orgUser.Id, + OrganizationId = policy.OrganizationId, + PolicyType = policy.Type, + PolicyData = policy.Data, + OrganizationUserType = orgUser.Type, + OrganizationUserStatus = orgUser.Status, + OrganizationUserPermissionsData = orgUser.Permissions, + IsProvider = providerSet.Contains(policy.OrganizationId) + }; + + return await query.ToListAsync(); + } } diff --git a/src/Sql/dbo/Stored Procedures/PolicyDetails_ReadByUserIdPolicyType.sql b/src/Sql/dbo/Stored Procedures/PolicyDetails_ReadByUserIdPolicyType.sql new file mode 100644 index 000000000000..464da7613c14 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/PolicyDetails_ReadByUserIdPolicyType.sql @@ -0,0 +1,77 @@ +CREATE PROCEDURE [dbo].[PolicyDetails_ReadByUserIdPolicyType] + @UserId UNIQUEIDENTIFIER, + @PolicyType TINYINT +AS +BEGIN + SET NOCOUNT ON + + DECLARE @UserEmail NVARCHAR(256) + SELECT @UserEmail = Email + FROM + [dbo].[UserView] + WHERE + Id = @UserId + + ;WITH OrgUsers AS + ( + -- Non-invited users (Status != 0): direct UserId match + SELECT + OU.[Id], + OU.[OrganizationId], + OU.[Type], + OU.[Status], + OU.[Permissions] + FROM + [dbo].[OrganizationUserView] OU + WHERE + OU.[Status] != 0 + AND OU.[UserId] = @UserId + + UNION ALL + + -- Invited users (Status = 0): email match + SELECT + OU.[Id], + OU.[OrganizationId], + OU.[Type], + OU.[Status], + OU.[Permissions] + FROM + [dbo].[OrganizationUserView] OU + WHERE + OU.[Status] = 0 + AND OU.[Email] = @UserEmail + AND @UserEmail IS NOT NULL + ), + Providers AS + ( + SELECT + OrganizationId + FROM + [dbo].[UserProviderAccessView] + WHERE + UserId = @UserId + ) + SELECT + OU.[Id] AS OrganizationUserId, + P.[OrganizationId], + P.[Type] AS PolicyType, + P.[Data] AS PolicyData, + OU.[Type] AS OrganizationUserType, + OU.[Status] AS OrganizationUserStatus, + OU.[Permissions] AS OrganizationUserPermissionsData, + CASE WHEN PR.[OrganizationId] IS NULL THEN 0 ELSE 1 END AS IsProvider + FROM + [dbo].[PolicyView] P + INNER JOIN + OrgUsers OU ON P.[OrganizationId] = OU.[OrganizationId] + INNER JOIN + [dbo].[OrganizationView] O ON P.[OrganizationId] = O.[Id] + LEFT JOIN + Providers PR ON PR.[OrganizationId] = OU.[OrganizationId] + WHERE + P.[Type] = @PolicyType + AND P.[Enabled] = 1 + AND O.[Enabled] = 1 + AND O.[UsePolicies] = 1 +END diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs index e652181a461d..28b06833b732 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs @@ -67,6 +67,60 @@ public async Task GetAsync_HandlesNoPolicies(Guid userId) Assert.Empty(requirement.Policies); } + [Theory, BitAutoData] + public async Task GetAsyncVNext_CallsEnforceCallback(Guid userId) + { + // Arrange policies + var policyRepository = Substitute.For(); + var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg }; + var otherPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg }; + policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, PolicyType.SingleOrg) + .Returns([thisPolicy, otherPolicy]); + + // Arrange a substitute Enforce function so that we can inspect the received calls + var callback = Substitute.For>(); + callback(Arg.Any()).Returns(x => x.Arg() == thisPolicy); + + // Arrange the sut + var factory = new TestPolicyRequirementFactory(callback); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + // Act + var requirement = await sut.GetAsyncVNext(userId); + + // Assert + Assert.Contains(thisPolicy, requirement.Policies); + Assert.DoesNotContain(otherPolicy, requirement.Policies); + callback.Received()(Arg.Is(thisPolicy)); + callback.Received()(Arg.Is(otherPolicy)); + } + + [Theory, BitAutoData] + public async Task GetAsyncVNext_ThrowsIfNoFactoryRegistered(Guid userId) + { + var policyRepository = Substitute.For(); + var sut = new PolicyRequirementQuery(policyRepository, []); + + var exception = await Assert.ThrowsAsync(() + => sut.GetAsyncVNext(userId)); + Assert.Contains("No Requirement Factory found", exception.Message); + } + + [Theory, BitAutoData] + public async Task GetAsyncVNext_HandlesNoPolicies(Guid userId) + { + var policyRepository = Substitute.For(); + policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, PolicyType.SingleOrg) + .Returns([]); + + var factory = new TestPolicyRequirementFactory(x => x.IsProvider); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirement = await sut.GetAsyncVNext(userId); + + Assert.Empty(requirement.Policies); + } + [Theory, BitAutoData] public async Task GetAsync_WithMultipleUserIds_ReturnsRequirementPerUser(Guid userIdA, Guid userIdB) { diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/PolicyRepository/GetPolicyDetailsByUserIdAndPolicyTypeTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/PolicyRepository/GetPolicyDetailsByUserIdAndPolicyTypeTests.cs new file mode 100644 index 000000000000..f6c6b696ec21 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/PolicyRepository/GetPolicyDetailsByUserIdAndPolicyTypeTests.cs @@ -0,0 +1,495 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.PolicyRepository; + +public class GetPolicyDetailsByUserIdAndPolicyTypeTests +{ + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithConfirmedUser_ReturnsPolicyDetails( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org = await CreateEnterpriseOrgAsync(organizationRepository); + var policy = await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.TwoFactorAuthentication, + Enabled = true + }); + + var customPermissions = "{\"accessReports\":true,\"manageGroups\":false}"; + var orgUser = await organizationUserRepository.CreateAsync(new OrganizationUser + { + OrganizationId = org.Id, + UserId = user.Id, + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.Custom, + Permissions = customPermissions + }); + + // Act + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.TwoFactorAuthentication); + + // Assert + var resultsList = results.ToList(); + var result = Assert.Single(resultsList); + Assert.Equal(orgUser.Id, result.OrganizationUserId); + Assert.Equal(org.Id, result.OrganizationId); + Assert.Equal(PolicyType.TwoFactorAuthentication, result.PolicyType); + Assert.Equal(policy.Data, result.PolicyData); + Assert.Equal(OrganizationUserStatusType.Confirmed, result.OrganizationUserStatus); + Assert.Equal(OrganizationUserType.Custom, result.OrganizationUserType); + Assert.Equal(customPermissions, result.OrganizationUserPermissionsData); + Assert.False(result.IsProvider); + + // Cleanup + await organizationRepository.DeleteAsync(org); + await userRepository.DeleteAsync(user); + } + + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithAcceptedUser_ReturnsPolicyDetails( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org = await CreateEnterpriseOrgAsync(organizationRepository); + var policy = await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.MasterPassword, + Data = "{\"minComplexity\":4}", + Enabled = true + }); + var orgUser = await organizationUserRepository.CreateAsync(GetAcceptedOrganizationUser(org, user)); + + // Act + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.MasterPassword); + + // Assert + var resultsList = results.ToList(); + var result = Assert.Single(resultsList); + Assert.Equal(orgUser.Id, result.OrganizationUserId); + Assert.Equal(org.Id, result.OrganizationId); + Assert.Equal(PolicyType.MasterPassword, result.PolicyType); + Assert.Equal(policy.Data, result.PolicyData); + Assert.Equal(OrganizationUserStatusType.Accepted, result.OrganizationUserStatus); + Assert.False(result.IsProvider); + + // Cleanup + await organizationRepository.DeleteAsync(org); + await userRepository.DeleteAsync(user); + } + + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithInvitedUser_ReturnsPolicyDetails( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org = await CreateEnterpriseOrgAsync(organizationRepository); + var policy = await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.RequireSso, + Enabled = true + }); + var orgUser = await organizationUserRepository.CreateAsync(GetInvitedOrganizationUser(org, user)); + + // Act + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.RequireSso); + + // Assert + var resultsList = results.ToList(); + var result = Assert.Single(resultsList); + Assert.Equal(orgUser.Id, result.OrganizationUserId); + Assert.Equal(org.Id, result.OrganizationId); + Assert.Equal(PolicyType.RequireSso, result.PolicyType); + Assert.Equal(OrganizationUserStatusType.Invited, result.OrganizationUserStatus); + Assert.False(result.IsProvider); + + // Cleanup + await organizationRepository.DeleteAsync(org); + await userRepository.DeleteAsync(user); + } + + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithMultipleOrganizations_ReturnsAllPolicyDetails( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org1 = await CreateEnterpriseOrgAsync(organizationRepository); + var org2 = await CreateEnterpriseOrgAsync(organizationRepository); + + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org1.Id, + Type = PolicyType.SingleOrg, + Enabled = true + }); + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org2.Id, + Type = PolicyType.SingleOrg, + Enabled = true + }); + + var orgUser1 = await organizationUserRepository.CreateAsync(GetConfirmedOrganizationUser(org1, user)); + var orgUser2 = await organizationUserRepository.CreateAsync(GetConfirmedOrganizationUser(org2, user)); + + // Act + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.SingleOrg); + + // Assert + var resultsList = results.ToList(); + Assert.Equal(2, resultsList.Count); + + var result1 = resultsList.First(r => r.OrganizationId == org1.Id); + Assert.Equal(orgUser1.Id, result1.OrganizationUserId); + Assert.Equal(PolicyType.SingleOrg, result1.PolicyType); + + var result2 = resultsList.First(r => r.OrganizationId == org2.Id); + Assert.Equal(orgUser2.Id, result2.OrganizationUserId); + Assert.Equal(PolicyType.SingleOrg, result2.PolicyType); + + // Cleanup + await organizationRepository.DeleteAsync(org1); + await organizationRepository.DeleteAsync(org2); + await userRepository.DeleteAsync(user); + } + + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithPolicyTypeFiltering_ReturnsOnlySpecifiedType( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org = await CreateEnterpriseOrgAsync(organizationRepository); + + // Create multiple enabled policies + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.TwoFactorAuthentication, + Enabled = true + }); + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.MasterPassword, + Enabled = true + }); + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.SingleOrg, + Enabled = true + }); + + await organizationUserRepository.CreateAsync(GetConfirmedOrganizationUser(org, user)); + + // Act - Request only TwoFactorAuthentication policy + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.TwoFactorAuthentication); + + // Assert + var resultsList = results.ToList(); + var result = Assert.Single(resultsList); + Assert.Equal(PolicyType.TwoFactorAuthentication, result.PolicyType); + + // Cleanup + await organizationRepository.DeleteAsync(org); + await userRepository.DeleteAsync(user); + } + + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithDisabledPolicy_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org = await CreateEnterpriseOrgAsync(organizationRepository); + + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.DisableSend, + Enabled = false // Disabled policy + }); + + await organizationUserRepository.CreateAsync(GetConfirmedOrganizationUser(org, user)); + + // Act + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.DisableSend); + + // Assert + Assert.Empty(results); + + // Cleanup + await organizationRepository.DeleteAsync(org); + await userRepository.DeleteAsync(user); + } + + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithDisabledOrganization_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org = await organizationRepository.CreateAsync(new Organization + { + Name = "Test Organization", + BillingEmail = $"billing+{Guid.NewGuid()}@example.com", + Plan = "EnterpriseAnnually", + PlanType = PlanType.EnterpriseAnnually, + Seats = 10, + MaxCollections = 10, + UsePolicies = true, + UseDirectory = true, + UseTotp = true, + Use2fa = true, + UseApi = true, + SelfHost = true, + Enabled = false // Disabled organization + }); + + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.PasswordGenerator, + Enabled = true + }); + + await organizationUserRepository.CreateAsync(GetConfirmedOrganizationUser(org, user)); + + // Act + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.PasswordGenerator); + + // Assert + Assert.Empty(results); + + // Cleanup + await organizationRepository.DeleteAsync(org); + await userRepository.DeleteAsync(user); + } + + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithOrganizationNotUsingPolicies_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org = await organizationRepository.CreateAsync(new Organization + { + Name = "Test Organization", + BillingEmail = $"billing+{Guid.NewGuid()}@example.com", + Plan = "EnterpriseAnnually", + PlanType = PlanType.EnterpriseAnnually, + Seats = 10, + MaxCollections = 10, + UsePolicies = false, // Not using policies + UseDirectory = true, + UseTotp = true, + Use2fa = true, + UseApi = true, + SelfHost = true, + Enabled = true + }); + + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.MaximumVaultTimeout, + Enabled = true + }); + + await organizationUserRepository.CreateAsync(GetConfirmedOrganizationUser(org, user)); + + // Act + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.MaximumVaultTimeout); + + // Assert + Assert.Empty(results); + + // Cleanup + await organizationRepository.DeleteAsync(org); + await userRepository.DeleteAsync(user); + } + + [Theory] + [DatabaseData] + public async Task GetPolicyDetailsByUserIdAndPolicyTypeAsync_WithProviderUser_SetsIsProviderFlag( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IProviderRepository providerRepository, + IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, + IPolicyRepository policyRepository) + { + // Arrange + var user = await userRepository.CreateAsync(GetDefaultUser()); + var org = await CreateEnterpriseOrgAsync(organizationRepository); + + await policyRepository.CreateAsync(new Policy + { + OrganizationId = org.Id, + Type = PolicyType.SingleOrg, + Enabled = true + }); + + await organizationUserRepository.CreateAsync(GetConfirmedOrganizationUser(org, user)); + + var provider = await providerRepository.CreateAsync(new Provider + { + Name = "Test Provider", + BusinessName = "Test Provider Business", + BusinessAddress1 = "123 Test St", + BusinessAddress2 = "Suite 456", + BusinessAddress3 = "Floor 7", + BusinessCountry = "US", + BusinessTaxNumber = "123456789", + BillingEmail = $"billing+{Guid.NewGuid()}@example.com" + }); + + await providerUserRepository.CreateAsync(new ProviderUser + { + ProviderId = provider.Id, + UserId = user.Id, + Status = ProviderUserStatusType.Confirmed, + Type = ProviderUserType.ProviderAdmin + }); + + await providerOrganizationRepository.CreateAsync(new ProviderOrganization + { + ProviderId = provider.Id, + OrganizationId = org.Id + }); + + // Act + var results = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync( + user.Id, + PolicyType.SingleOrg); + + // Assert + var resultsList = results.ToList(); + var result = Assert.Single(resultsList); + Assert.True(result.IsProvider); + Assert.Equal(org.Id, result.OrganizationId); + + // Cleanup + await organizationRepository.DeleteAsync(org); + await userRepository.DeleteAsync(user); + } + + private static async Task CreateEnterpriseOrgAsync(IOrganizationRepository orgRepo) + { + return await orgRepo.CreateAsync(new Organization + { + Name = "Test Organization", + BillingEmail = $"billing+{Guid.NewGuid()}@example.com", + Plan = "EnterpriseAnnually", + PlanType = PlanType.EnterpriseAnnually, + Seats = 10, + MaxCollections = 10, + UsePolicies = true, + UseDirectory = true, + UseTotp = true, + Use2fa = true, + UseApi = true, + SelfHost = true, + Enabled = true + }); + } + + private static User GetDefaultUser() => new() + { + Name = $"Test User {Guid.NewGuid()}", + Email = $"test+{Guid.NewGuid()}@example.com", + ApiKey = $"test.api.key.{Guid.NewGuid()}"[..30], + SecurityStamp = Guid.NewGuid().ToString() + }; + + private static OrganizationUser GetConfirmedOrganizationUser(Organization organization, User user) => new() + { + OrganizationId = organization.Id, + UserId = user.Id, + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User + }; + + private static OrganizationUser GetAcceptedOrganizationUser(Organization organization, User user) => new() + { + OrganizationId = organization.Id, + UserId = user.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }; + + private static OrganizationUser GetInvitedOrganizationUser(Organization organization, User user) => new() + { + OrganizationId = organization.Id, + UserId = null, // Invited users don't have UserId + Email = user.Email, + Status = OrganizationUserStatusType.Invited, + Type = OrganizationUserType.User + }; +} diff --git a/util/Migrator/DbScripts/2026-03-07_00_PolicyDetails_ReadByUserIdPolicyType.sql b/util/Migrator/DbScripts/2026-03-07_00_PolicyDetails_ReadByUserIdPolicyType.sql new file mode 100644 index 000000000000..403a61f09a91 --- /dev/null +++ b/util/Migrator/DbScripts/2026-03-07_00_PolicyDetails_ReadByUserIdPolicyType.sql @@ -0,0 +1,77 @@ +CREATE OR ALTER PROCEDURE [dbo].[PolicyDetails_ReadByUserIdPolicyType] + @UserId UNIQUEIDENTIFIER, + @PolicyType TINYINT +AS +BEGIN + SET NOCOUNT ON + + DECLARE @UserEmail NVARCHAR(256) + SELECT @UserEmail = Email + FROM + [dbo].[UserView] + WHERE + Id = @UserId + + ;WITH OrgUsers AS + ( + -- Non-invited users (Status != 0): direct UserId match + SELECT + OU.[Id], + OU.[OrganizationId], + OU.[Type], + OU.[Status], + OU.[Permissions] + FROM + [dbo].[OrganizationUserView] OU + WHERE + OU.[Status] != 0 + AND OU.[UserId] = @UserId + + UNION ALL + + -- Invited users (Status = 0): email match + SELECT + OU.[Id], + OU.[OrganizationId], + OU.[Type], + OU.[Status], + OU.[Permissions] + FROM + [dbo].[OrganizationUserView] OU + WHERE + OU.[Status] = 0 + AND OU.[Email] = @UserEmail + AND @UserEmail IS NOT NULL + ), + Providers AS + ( + SELECT + OrganizationId + FROM + [dbo].[UserProviderAccessView] + WHERE + UserId = @UserId + ) + SELECT + OU.[Id] AS OrganizationUserId, + P.[OrganizationId], + P.[Type] AS PolicyType, + P.[Data] AS PolicyData, + OU.[Type] AS OrganizationUserType, + OU.[Status] AS OrganizationUserStatus, + OU.[Permissions] AS OrganizationUserPermissionsData, + CASE WHEN PR.[OrganizationId] IS NULL THEN 0 ELSE 1 END AS IsProvider + FROM + [dbo].[PolicyView] P + INNER JOIN + OrgUsers OU ON P.[OrganizationId] = OU.[OrganizationId] + INNER JOIN + [dbo].[OrganizationView] O ON P.[OrganizationId] = O.[Id] + LEFT JOIN + Providers PR ON PR.[OrganizationId] = OU.[OrganizationId] + WHERE + P.[Type] = @PolicyType + AND P.[Enabled] = 1 + AND O.[Enabled] = 1 + AND O.[UsePolicies] = 1 +END