diff --git a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs index 62a5c2adff56..112fa69010b2 100644 --- a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs +++ b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs @@ -26,7 +26,7 @@ public class UpgradePremiumToOrganizationRequest public required ProductTierType TargetProductTierType { get; set; } [Required] - public required MinimalBillingAddressRequest BillingAddress { get; set; } + public required CheckoutBillingAddressRequest BillingAddress { get; set; } private PlanType PlanType { diff --git a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs index ffb7993c75e3..a7bf3e20cea5 100644 --- a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs +++ b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Entities; @@ -12,6 +13,8 @@ using Bit.Core.Utilities; using Microsoft.Extensions.Logging; using Stripe; +using CountryAbbreviations = Bit.Core.Constants.CountryAbbreviations; +using TaxExempt = Bit.Core.Billing.Constants.StripeConstants.TaxExempt; namespace Bit.Core.Billing.Premium.Commands; /// @@ -40,7 +43,7 @@ Task> Run( string encryptedPrivateKey, string? collectionName, PlanType targetPlanType, - Payment.Models.BillingAddress billingAddress); + BillingAddress billingAddress); } public class UpgradePremiumToOrganizationCommand( @@ -65,7 +68,7 @@ public Task> Run( string encryptedPrivateKey, string? collectionName, PlanType targetPlanType, - Payment.Models.BillingAddress billingAddress) => HandleAsync(async () => + BillingAddress billingAddress) => HandleAsync(async () => { // Validate that the user has an active Premium subscription if (user is not { Premium: true, GatewaySubscriptionId: not null and not "" }) @@ -198,9 +201,16 @@ public Task> Run( { Country = billingAddress.Country, PostalCode = billingAddress.PostalCode - } + }, + TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None }); + // Add tax ID to customer for accurate tax calculation if provided + if (billingAddress.TaxId != null) + { + await AddTaxIdToCustomerAsync(user, billingAddress.TaxId); + } + // Update the subscription in Stripe await stripeAdapter.UpdateSubscriptionAsync(currentSubscription.Id, subscriptionUpdateOptions); @@ -271,4 +281,26 @@ await organizationApiKeyRepository.CreateAsync(new OrganizationApiKey return organization.Id; }); + + /// + /// Adds a tax ID to the Stripe customer for accurate tax calculation. + /// If the tax ID is a Spanish NIF, also adds the corresponding EU VAT ID. + /// + /// The user whose Stripe customer will be updated with the tax ID. + /// The tax ID to add, including the type and value. + private async Task AddTaxIdToCustomerAsync(User user, TaxID taxId) + { + await stripeAdapter.CreateTaxIdAsync(user.GatewayCustomerId, + new TaxIdCreateOptions { Type = taxId.Code, Value = taxId.Value }); + + if (taxId.Code == StripeConstants.TaxIdType.SpanishNIF) + { + await stripeAdapter.CreateTaxIdAsync(user.GatewayCustomerId, + new TaxIdCreateOptions + { + Type = StripeConstants.TaxIdType.EUVAT, + Value = $"ES{taxId.Value}" + }); + } + } } diff --git a/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs index b9cb7754d8aa..d43084cdf710 100644 --- a/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs +++ b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs @@ -22,7 +22,7 @@ public void ToDomain_ValidTierTypes_ReturnsPlanType(ProductTierType tierType, Pl EncryptedPrivateKey = "encrypted-private-key", CollectionName = "Default Collection", TargetProductTierType = tierType, - BillingAddress = new MinimalBillingAddressRequest + BillingAddress = new CheckoutBillingAddressRequest { Country = "US", PostalCode = "12345" @@ -56,7 +56,7 @@ public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTie PublicKey = "public-key", EncryptedPrivateKey = "encrypted-private-key", TargetProductTierType = tierType, - BillingAddress = new MinimalBillingAddressRequest + BillingAddress = new CheckoutBillingAddressRequest { Country = "US", PostalCode = "12345" @@ -67,4 +67,53 @@ public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTie var exception = Assert.Throws(() => sut.ToDomain()); Assert.Contains($"Cannot upgrade Premium subscription to {tierType} plan", exception.Message); } + + [Theory] + [InlineData(ProductTierType.Teams, PlanType.TeamsAnnually, "DE", "10115", "eu_vat", "DE123456789")] + [InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually, "FR", "75001", "eu_vat", "FR12345678901")] + public void ToDomain_BusinessPlansWithNonUsTaxId_IncludesTaxIdInBillingAddress( + ProductTierType tierType, + PlanType expectedPlanType, + string country, + string postalCode, + string taxIdCode, + string taxIdValue) + { + // Arrange + var sut = new UpgradePremiumToOrganizationRequest + { + OrganizationName = "International Business", + Key = "encrypted-key", + TargetProductTierType = tierType, + PublicKey = "public-key", + EncryptedPrivateKey = "encrypted-private-key", + CollectionName = "Default Collection", + BillingAddress = new CheckoutBillingAddressRequest + { + Country = country, + PostalCode = postalCode, + TaxId = new CheckoutBillingAddressRequest.TaxIdRequest + { + Code = taxIdCode, + Value = taxIdValue + } + } + }; + + // Act + var (organizationName, key, publicKey, encryptedPrivateKey, collectionName, planType, billingAddress) = sut.ToDomain(); + + // Assert + Assert.Equal("International Business", organizationName); + Assert.Equal("encrypted-key", key); + Assert.Equal("public-key", publicKey); + Assert.Equal("encrypted-private-key", encryptedPrivateKey); + Assert.Equal("Default Collection", collectionName); + Assert.Equal(expectedPlanType, planType); + Assert.Equal(country, billingAddress.Country); + Assert.Equal(postalCode, billingAddress.PostalCode); + Assert.NotNull(billingAddress.TaxId); + Assert.Equal(taxIdCode, billingAddress.TaxId.Code); + Assert.Equal(taxIdValue, billingAddress.TaxId.Value); + } } diff --git a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs index 158f08f048c7..181a5e6d33e5 100644 --- a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs @@ -1202,4 +1202,200 @@ await _collectionRepository.Received(1).CreateAsync( Arg.Any>(), Arg.Any>()); } + + [Theory, BitAutoData] + public async Task Run_WithNoTaxId_SetsTaxExemptToNone_DoesNotCreateTaxId(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + var billingAddress = new Core.Billing.Payment.Models.BillingAddress + { + Country = "US", + PostalCode = "12345", + TaxId = null + }; + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", "public-key", "encrypted-private-key", "Default Collection", PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + await _stripeAdapter.Received(1).UpdateCustomerAsync( + "cus_123", + Arg.Is(options => + options.TaxExempt == StripeConstants.TaxExempt.None)); + await _stripeAdapter.DidNotReceive().CreateTaxIdAsync(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task Run_WithTaxId_SetsTaxExemptToReverse_CreatesOneTaxId(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _stripeAdapter.CreateTaxIdAsync(Arg.Any(), Arg.Any()).Returns(new TaxId()); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + var billingAddress = new Core.Billing.Payment.Models.BillingAddress + { + Country = "DE", + PostalCode = "10115", + TaxId = new Core.Billing.Payment.Models.TaxID("eu_vat", "DE123456789") + }; + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", "public-key", "encrypted-private-key", "Default Collection", PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + await _stripeAdapter.Received(1).UpdateCustomerAsync( + "cus_123", + Arg.Is(options => + options.TaxExempt == StripeConstants.TaxExempt.Reverse)); + await _stripeAdapter.Received(1).CreateTaxIdAsync( + "cus_123", + Arg.Is(options => + options.Type == "eu_vat" && + options.Value == "DE123456789")); + } + + [Theory, BitAutoData] + public async Task Run_WithSpanishNIF_SetsTaxExemptToReverse_CreatesBothSpanishNIFAndEUVAT(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _stripeAdapter.CreateTaxIdAsync(Arg.Any(), Arg.Any()).Returns(new TaxId()); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + var billingAddress = new Core.Billing.Payment.Models.BillingAddress + { + Country = "ES", + PostalCode = "28001", + TaxId = new Core.Billing.Payment.Models.TaxID(StripeConstants.TaxIdType.SpanishNIF, "A12345678") + }; + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", "public-key", "encrypted-private-key", "Default Collection", PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).UpdateCustomerAsync( + "cus_123", + Arg.Is(options => + options.TaxExempt == StripeConstants.TaxExempt.Reverse)); + + // Verify Spanish NIF was created + await _stripeAdapter.Received(1).CreateTaxIdAsync( + "cus_123", + Arg.Is(options => + options.Type == StripeConstants.TaxIdType.SpanishNIF && + options.Value == "A12345678")); + + // Verify EU VAT was created with ES prefix + await _stripeAdapter.Received(1).CreateTaxIdAsync( + "cus_123", + Arg.Is(options => + options.Type == StripeConstants.TaxIdType.EUVAT && + options.Value == "ESA12345678")); + + + } }