diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs index e140a1384158..8d9d72a9c151 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs @@ -3,13 +3,13 @@ using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Queries; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Context; using Stripe; using Stripe.Tax; namespace Bit.Commercial.Core.Billing.Providers.Queries; -using static Bit.Core.Constants; using static StripeConstants; using SuspensionWarning = ProviderWarnings.SuspensionWarning; using TaxIdWarning = ProviderWarnings.TaxIdWarning; @@ -61,7 +61,7 @@ await subscriberService.GetSubscription(provider, Provider provider, Customer customer) { - if (customer.Address?.Country == CountryAbbreviations.UnitedStates) + if (TaxHelpers.IsDirectTaxCountry(customer.Address?.Country)) { return null; } diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs index 19ee33c68260..82c39410b8e7 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs @@ -3,7 +3,6 @@ using System.Globalization; using Bit.Commercial.Core.Billing.Providers.Models; -using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -19,6 +18,7 @@ using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -34,7 +34,6 @@ namespace Bit.Commercial.Core.Billing.Providers.Services; -using static Constants; using static StripeConstants; public class ProviderBillingService( @@ -267,10 +266,13 @@ await subscriberService.GetCustomerOrThrow(provider, ] }; - if (providerCustomer.Address is not { Country: CountryAbbreviations.UnitedStates }) + var determinedTaxExemptStatus = TaxHelpers.DetermineTaxExemptStatus(providerCustomer.Address?.Country, providerCustomer.TaxExempt); + customerCreateOptions.TaxExempt = providerCustomer switch { - customerCreateOptions.TaxExempt = TaxExempt.Reverse; - } + { Address.Country: not null and not "", TaxExempt: var customerTaxExemptStatus } when + determinedTaxExemptStatus != customerTaxExemptStatus => determinedTaxExemptStatus, + _ => providerCustomer.TaxExempt + }; var customer = await stripeAdapter.CreateCustomerAsync(customerCreateOptions); @@ -467,6 +469,7 @@ public async Task SetupCustomer( TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) { + var determinedTaxExemptStatus = TaxHelpers.DetermineTaxExemptStatus(billingAddress.Country); var options = new CustomerCreateOptions { Address = new AddressOptions @@ -494,7 +497,7 @@ public async Task SetupCustomer( ] }, Metadata = new Dictionary { { "region", globalSettings.BaseServiceUri.CloudRegion } }, - TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None + TaxExempt = determinedTaxExemptStatus }; if (billingAddress.TaxId != null) diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs index 96dbacfa92c1..5c7b34a4ee76 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs @@ -520,6 +520,39 @@ public async Task Run_CombinesBothWarningTypes( Assert.Equal(cancelAt, response.Suspension.SubscriptionCancelsAt); } + [Theory, BitAutoData] + public async Task Run_SwissCustomer_NoTaxIdWarning( + Provider provider, + SutProvider sutProvider) + { + provider.Enabled = true; + + sutProvider.GetDependency() + .GetSubscription(provider, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + Status = SubscriptionStatus.Active, + Customer = new Customer + { + TaxIds = new StripeList { Data = [] }, + Address = new Address { Country = "CH" } + } + }); + + sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) + .Returns(new StripeList + { + Data = [new Registration { Country = "CH" }] + }); + + var response = await sutProvider.Sut.Run(provider); + + Assert.Null(response!.TaxId); + } + [Theory, BitAutoData] public async Task Run_USCustomer_NoTaxIdWarning( Provider provider, diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs index 89bb4d189962..2ab194c5b7aa 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs @@ -389,6 +389,55 @@ await sutProvider.GetDependency().Received(1).ReplaceAs org => org.GatewayCustomerId == "customer_id")); } + [Theory, BitAutoData] + public async Task CreateCustomer_ForClientOrg_USCustomer_SetsTaxExemptToNone( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + organization.Name = "Name"; + + var providerCustomer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Unit 4", + City = "Fake Town", + State = "Fake State" + }, + TaxIds = new StripeList + { + Data = + [ + new TaxId { Type = "TYPE", Value = "VALUE" } + ] + }, + TaxExempt = null + }; + + sutProvider.GetDependency().GetCustomerOrThrow(provider, Arg.Is( + options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids"))) + .Returns(providerCustomer); + + sutProvider.GetDependency().BaseServiceUri + .Returns(new Bit.Core.Settings.GlobalSettings.BaseServiceUriSettings(new Bit.Core.Settings.GlobalSettings()) + { + CloudRegion = "US" + }); + + sutProvider.GetDependency().CreateCustomerAsync(Arg.Any()) + .Returns(new Customer { Id = "customer_id" }); + + await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); + + await sutProvider.GetDependency().Received(1).CreateCustomerAsync( + Arg.Is(options => options.TaxExempt == StripeConstants.TaxExempt.None)); + } + #endregion #region GenerateClientInvoiceReport diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs index 464ba0c2fdee..5a48aa2fe2f0 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs @@ -3,8 +3,8 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; -using Bit.Core; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -139,7 +139,7 @@ public IEnumerable Validate(ValidationContext validationContex new string[] { nameof(BillingAddressCountry) }); } - if (PlanType != PlanType.Free && BillingAddressCountry == Constants.CountryAbbreviations.UnitedStates && + if (PlanType != PlanType.Free && TaxHelpers.IsDirectTaxCountry(BillingAddressCountry) && string.IsNullOrWhiteSpace(BillingAddressPostalCode)) { yield return new ValidationResult("Zip / postal code is required.", diff --git a/src/Api/Models/Request/Accounts/PremiumRequestModel.cs b/src/Api/Models/Request/Accounts/PremiumRequestModel.cs index 8e9aac8cc272..e70e4b626e8b 100644 --- a/src/Api/Models/Request/Accounts/PremiumRequestModel.cs +++ b/src/Api/Models/Request/Accounts/PremiumRequestModel.cs @@ -2,7 +2,7 @@ #nullable disable using System.ComponentModel.DataAnnotations; -using Bit.Core; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Settings; using Enums = Bit.Core.Enums; @@ -36,7 +36,7 @@ public IEnumerable Validate(ValidationContext validationContex { yield return new ValidationResult("Payment token or license is required."); } - if (Country == Constants.CountryAbbreviations.UnitedStates && string.IsNullOrWhiteSpace(PostalCode)) + if (TaxHelpers.IsDirectTaxCountry(Country) && string.IsNullOrWhiteSpace(PostalCode)) { yield return new ValidationResult("Zip / postal code is required.", new string[] { nameof(PostalCode) }); diff --git a/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs b/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs index d3e3f5ec554d..5c4c4de018db 100644 --- a/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs +++ b/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs @@ -2,7 +2,7 @@ #nullable disable using System.ComponentModel.DataAnnotations; -using Bit.Core; +using Bit.Core.Billing.Tax.Utilities; namespace Bit.Api.Models.Request.Accounts; @@ -14,7 +14,7 @@ public class TaxInfoUpdateRequestModel : IValidatableObject public virtual IEnumerable Validate(ValidationContext validationContext) { - if (Country == Constants.CountryAbbreviations.UnitedStates && string.IsNullOrWhiteSpace(PostalCode)) + if (TaxHelpers.IsDirectTaxCountry(Country) && string.IsNullOrWhiteSpace(PostalCode)) { yield return new ValidationResult("Zip / postal code is required.", new string[] { nameof(PostalCode) }); diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index ae2a76a7ce80..72aeb601ce80 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -8,6 +8,7 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Entities; using Bit.Core.Models.Mail.Billing.Renewal.Families2019Renewal; using Bit.Core.Models.Mail.Billing.Renewal.Families2020Renewal; @@ -157,24 +158,29 @@ private async Task AlignOrganizationTaxConcernsAsync( Customer customer, string eventId) { - var nonUSBusinessUse = - organization.PlanType.GetProductTier() != ProductTierType.Families && - customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates; + var isBusinessUse = organization.PlanType.GetProductTier() != ProductTierType.Families; - if (nonUSBusinessUse && customer.TaxExempt != TaxExempt.Reverse) + if (isBusinessUse) { - try - { - await stripeFacade.UpdateCustomer(subscription.CustomerId, - new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); - } - catch (Exception exception) + var determinedTaxExemptStatus = TaxHelpers.DetermineTaxExemptStatus(customer.Address.Country, customer.TaxExempt); + switch (customer) { - logger.LogError( - exception, - "Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}", - organization.Id, - eventId); + case { Address.Country: not null and not "", TaxExempt: var customerTaxExemptStatus } + when determinedTaxExemptStatus != customerTaxExemptStatus: + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = determinedTaxExemptStatus }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) to the required tax exemption while processing event with ID {EventID}", + organization.Id, + eventId); + } + break; } } @@ -449,22 +455,25 @@ private async Task AlignProviderTaxConcernsAsync( Customer customer, string eventId) { - if (customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates && - customer.TaxExempt != TaxExempt.Reverse) + var determinedTaxExemptStatus = TaxHelpers.DetermineTaxExemptStatus(customer.Address.Country, customer.TaxExempt); + switch (customer) { - try - { - await stripeFacade.UpdateCustomer(subscription.CustomerId, - new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}", - provider.Id, - eventId); - } + case { Address.Country: not null and not "", TaxExempt: var customerTaxExemptStatus } + when determinedTaxExemptStatus != customerTaxExemptStatus: + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = determinedTaxExemptStatus }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) to the required tax exemption while processing event with ID {EventID}", + provider.Id, + eventId); + } + break; } if (!subscription.AutomaticTax.Enabled) diff --git a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs index e06aab7b390c..2f7149c93974 100644 --- a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs +++ b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs @@ -8,6 +8,7 @@ using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Entities; using Bit.Core.Enums; using Microsoft.Extensions.Logging; @@ -16,7 +17,6 @@ namespace Bit.Core.Billing.Organizations.Commands; -using static Core.Constants; using static StripeConstants; public interface IPreviewOrganizationTaxCommand @@ -385,12 +385,24 @@ private static InvoiceCreatePreviewOptions GetBaseOptions( CustomerDetails = new InvoiceCustomerDetailsOptions { Address = new AddressOptions { Country = country, PostalCode = postalCode }, - TaxExempt = businessUse && country != CountryAbbreviations.UnitedStates - ? TaxExempt.Reverse - : TaxExempt.None } }; + switch (businessUse) + { + case true: + var existingTaxExemptStatus = addressChoice.Match( + customer => customer.TaxExempt, + _ => null!); + + var determinedTaxExemptStatus = TaxHelpers.DetermineTaxExemptStatus(country, existingTaxExemptStatus); + options.CustomerDetails.TaxExempt = determinedTaxExemptStatus; + break; + default: + options.CustomerDetails.TaxExempt = TaxExempt.None; + break; + } + var taxId = addressChoice.Match( customer => { diff --git a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs index af8dfa7aec0c..c876eb912f10 100644 --- a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs +++ b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs @@ -8,13 +8,13 @@ using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Context; using Stripe; using Stripe.Tax; namespace Bit.Core.Billing.Organizations.Queries; -using static Core.Constants; using static StripeConstants; using FreeTrialWarning = OrganizationWarnings.FreeTrialWarning; using InactiveSubscriptionWarning = OrganizationWarnings.InactiveSubscriptionWarning; @@ -230,7 +230,7 @@ on the subscription status. */ Customer customer, Provider? provider) { - if (customer.Address?.Country == CountryAbbreviations.UnitedStates) + if (TaxHelpers.IsDirectTaxCountry(customer.Address?.Country)) { return null; } diff --git a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs index 69d4cc1d9f60..1a1aee25b15f 100644 --- a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs +++ b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models.Sales; @@ -8,6 +7,7 @@ using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Services; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -17,8 +17,10 @@ using Stripe; using static Bit.Core.Billing.Utilities; using Customer = Stripe.Customer; +using StripeConstants = Bit.Core.Billing.Constants.StripeConstants; using Subscription = Stripe.Subscription; + namespace Bit.Core.Billing.Organizations.Services; public class OrganizationBillingService( @@ -238,7 +240,7 @@ private async Task CreateCustomerAsync( }; if (planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families && - customerSetup.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates) + !TaxHelpers.IsDirectTaxCountry(customerSetup.TaxInformation.Country)) { customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; } @@ -491,23 +493,13 @@ ProductTierType.TeamsStarter or } List expansions = ["tax", "tax_ids"]; - + var determinedTaxExemptStatus = TaxHelpers.DetermineTaxExemptStatus(customer.Address?.Country, customer.TaxExempt); customer = customer switch { - { Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse } => await - stripeAdapter.UpdateCustomerAsync(customer.Id, - new CustomerUpdateOptions - { - Expand = expansions, - TaxExempt = StripeConstants.TaxExempt.Reverse - }), - { Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: StripeConstants.TaxExempt.Reverse } => await - stripeAdapter.UpdateCustomerAsync(customer.Id, - new CustomerUpdateOptions - { - Expand = expansions, - TaxExempt = StripeConstants.TaxExempt.None - }), + { Address.Country: not null and not "", TaxExempt: var customerTaxExemptStatus } + when determinedTaxExemptStatus != customerTaxExemptStatus => + await stripeAdapter.UpdateCustomerAsync(customer.Id, + new CustomerUpdateOptions { Expand = expansions, TaxExempt = determinedTaxExemptStatus }), _ => customer }; diff --git a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs index daf39fb98191..37ccea97088a 100644 --- a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Entities; using Microsoft.Extensions.Logging; using Stripe; @@ -69,24 +70,23 @@ private async Task> UpdateBusinessBillingAd ISubscriber subscriber, BillingAddress billingAddress) { - var customer = - await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, - new CustomerUpdateOptions + var determinedTaxExemptStatus = await GetDeterminedTaxExemptStatusAsync(subscriber.GatewayCustomerId!, billingAddress.Country); + + var customer = await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, + new CustomerUpdateOptions + { + Address = new AddressOptions { - Address = new AddressOptions - { - Country = billingAddress.Country, - PostalCode = billingAddress.PostalCode, - Line1 = billingAddress.Line1, - Line2 = billingAddress.Line2, - City = billingAddress.City, - State = billingAddress.State - }, - Expand = ["subscriptions", "tax_ids"], - TaxExempt = billingAddress.Country != Core.Constants.CountryAbbreviations.UnitedStates - ? StripeConstants.TaxExempt.Reverse - : StripeConstants.TaxExempt.None - }); + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode, + Line1 = billingAddress.Line1, + Line2 = billingAddress.Line2, + City = billingAddress.City, + State = billingAddress.State + }, + Expand = ["subscriptions", "tax_ids"], + TaxExempt = determinedTaxExemptStatus + }); await EnableAutomaticTaxAsync(subscriber, customer); @@ -118,6 +118,13 @@ await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, return BillingAddress.From(customer.Address, updatedTaxId); } + + private async Task GetDeterminedTaxExemptStatusAsync(string customerId, string? billingCountry) + { + var existingCustomer = await stripeAdapter.GetCustomerAsync(customerId); + return TaxHelpers.DetermineTaxExemptStatus(billingCountry, existingCustomer.TaxExempt); + } + private async Task EnableAutomaticTaxAsync( ISubscriber subscriber, Customer customer) diff --git a/src/Core/Billing/Services/Implementations/StripePaymentService.cs b/src/Core/Billing/Services/Implementations/StripePaymentService.cs index 30eac78698c2..38ed075ec73f 100644 --- a/src/Core/Billing/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Billing/Services/Implementations/StripePaymentService.cs @@ -9,6 +9,7 @@ using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -122,14 +123,14 @@ private async Task FinalizeSubscriptionChangeAsync(ISubscriber subscribe if (subscriptionUpdate is CompleteSubscriptionUpdate) { - if (sub.Customer is - { - Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, - TaxExempt: not StripeConstants.TaxExempt.Reverse - }) + var determinedTaxExemptStatus = TaxHelpers.DetermineTaxExemptStatus(sub.Customer.Address.Country, sub.Customer.TaxExempt); + switch (sub.Customer) { - await _stripeAdapter.UpdateCustomerAsync(sub.CustomerId, - new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + case { Address.Country: not null and not "", TaxExempt: var customerTaxExemptStatus } + when determinedTaxExemptStatus != customerTaxExemptStatus: + await _stripeAdapter.UpdateCustomerAsync(sub.Customer.Id, + new CustomerUpdateOptions { TaxExempt = determinedTaxExemptStatus }); + break; } subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 00b72ab196f2..d80e7c84fdb4 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -10,6 +10,7 @@ using Bit.Core.Billing.Models; using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Services; +using Bit.Core.Billing.Tax.Utilities; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -19,7 +20,6 @@ using Braintree; using Microsoft.Extensions.Logging; using Stripe; - using static Bit.Core.Billing.Utilities; using Customer = Stripe.Customer; using Subscription = Stripe.Subscription; @@ -602,23 +602,13 @@ await stripeAdapter.CreateTaxIdAsync(customer.Id, if (isBusinessUseSubscriber) { + var determinedTaxExemptStatus = TaxHelpers.DetermineTaxExemptStatus(customer.Address.Country, customer.TaxExempt); switch (customer) { - case - { - Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, - TaxExempt: not TaxExempt.Reverse - }: - await stripeAdapter.UpdateCustomerAsync(customer.Id, - new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); - break; - case - { - Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, - TaxExempt: TaxExempt.Reverse - }: + case { Address.Country: not null and not "", TaxExempt: var customerTaxExemptStatus } + when determinedTaxExemptStatus != customerTaxExemptStatus: await stripeAdapter.UpdateCustomerAsync(customer.Id, - new CustomerUpdateOptions { TaxExempt = TaxExempt.None }); + new CustomerUpdateOptions { TaxExempt = determinedTaxExemptStatus }); break; } @@ -637,8 +627,8 @@ await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, { User => true, Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families || - customer.Address.Country == Core.Constants.CountryAbbreviations.UnitedStates || (customer.TaxIds?.Any() ?? false), - Provider => customer.Address.Country == Core.Constants.CountryAbbreviations.UnitedStates || (customer.TaxIds?.Any() ?? false), + TaxHelpers.IsDirectTaxCountry(customer.Address.Country) || (customer.TaxIds?.Any() ?? false), + Provider => TaxHelpers.IsDirectTaxCountry(customer.Address.Country) || (customer.TaxIds?.Any() ?? false), _ => false }; diff --git a/src/Core/Billing/Tax/Utilities/TaxHelpers.cs b/src/Core/Billing/Tax/Utilities/TaxHelpers.cs new file mode 100644 index 000000000000..407693182498 --- /dev/null +++ b/src/Core/Billing/Tax/Utilities/TaxHelpers.cs @@ -0,0 +1,56 @@ +using CountryAbbreviations = Bit.Core.Constants.CountryAbbreviations; +using TaxExempt = Bit.Core.Billing.Constants.StripeConstants.TaxExempt; +namespace Bit.Core.Billing.Tax.Utilities; + +public static class TaxHelpers +{ + /// + /// Countries where tax is collected directly from customers, rather than through VAT ID reverse charge. + /// To add a new country, add its ISO 3166 code to + /// and then add it to this set. + /// + private static readonly HashSet DirectTaxCountries = + [ + CountryAbbreviations.UnitedStates, + CountryAbbreviations.Switzerland + ]; + + /// + /// For countries where tax is collected directly, we generally want to default Stripe's tax_exempt to "none". + /// However, some customer may have been manually set up with "reverse" tax_exempt status, so we want to preserve that manual override for those customers. + /// This set defines the countries for which we should preserve that manual override. + /// to add a new country, add its ISO 3166 code to + /// + private static readonly HashSet PreserveReverseChargeCountries = + [CountryAbbreviations.Switzerland]; + + /// + /// Returns if is in , + /// meaning tax is collected directly and Stripe's tax_exempt should default to "none". + /// Returns for all other countries, where VAT reverse charge applies. + /// + public static bool IsDirectTaxCountry(string? country) => + country is not null and not "" && DirectTaxCountries.Contains(country); + + /// + /// Returns the Stripe tax_exempt value appropriate for .
+ /// For non-direct-tax countries, always returns "reverse".
+ /// For direct-tax countries, returns "none" — unless the country is in and + /// is already "reverse" + ///
+ public static string DetermineTaxExemptStatus(string? country, string? currentTaxExempt = null) => + !IsDirectTaxCountry(country) + ? TaxExempt.Reverse + : IsManualReverseChargeOverridden(country, currentTaxExempt) + ? TaxExempt.Reverse + : TaxExempt.None; + + /// + /// Returns if the current tax exempt status should be retained for the given country. + /// + private static bool IsManualReverseChargeOverridden(string? country, string? taxExemptStatus) => + country is not null and not "" + && taxExemptStatus is not null and not "" and TaxExempt.Reverse + && PreserveReverseChargeCountries.Contains(country); + +} diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 2578249df1a8..0bad0f18a199 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -69,8 +69,13 @@ public static class CountryAbbreviations /// This value must match what Stripe uses for the `Country` field value for the United States. /// public const string UnitedStates = "US"; - } + /// + /// Abbreviation for Switzerland. + /// This value must match what Stripe uses for the `Country` field value for Switzerland. + /// + public const string Switzerland = "CH"; + } /// /// Constants for our browser extensions IDs diff --git a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs index 82d6c8acfd49..67cdfb63869e 100644 --- a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs +++ b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs @@ -553,6 +553,146 @@ await _mailService.Received(1).SendInvoiceUpcoming( Arg.Is(b => b == true)); } + [Fact] + public async Task HandleAsync_WhenNonDirectTaxCountryOrganization_SetsReverseCharge() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice { CustomerId = "cus_123", AmountDue = 0, Lines = new StripeList { Data = [] } }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary() + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "DE" }, + TaxExempt = TaxExempt.None + }; + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.EnterpriseAnnually + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(new EnterprisePlan(isAnnual: true)); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateCustomer( + Arg.Is("cus_123"), + Arg.Is(o => o.TaxExempt == TaxExempt.Reverse)); + } + + [Fact] + public async Task HandleAsync_WhenUSOrganizationWithManualReverseCharge_CorrectsTaxExemptToNone() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice { CustomerId = "cus_123", AmountDue = 0, Lines = new StripeList { Data = [] } }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary() + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" }, + TaxExempt = TaxExempt.Reverse + }; + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.EnterpriseAnnually + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(new EnterprisePlan(isAnnual: true)); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateCustomer( + Arg.Is("cus_123"), + Arg.Is(o => o.TaxExempt == TaxExempt.None)); + } + + [Fact] + public async Task HandleAsync_WhenSwissOrganizationWithManualReverseCharge_PreservesReverseCharge() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice { CustomerId = "cus_123", AmountDue = 0, Lines = new StripeList { Data = [] } }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary() + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "CH" }, + TaxExempt = TaxExempt.Reverse + }; + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.EnterpriseAnnually + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(new EnterprisePlan(isAnnual: true)); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert — no customer update needed; required status matches current + await _stripeFacade.DidNotReceive().UpdateCustomer( + Arg.Any(), + Arg.Any()); + } [Fact] public async Task HandleAsync_WhenValidProviderSubscription_SendsEmail() @@ -606,7 +746,7 @@ public async Task HandleAsync_WhenValidProviderSubscription_SendsEmail() // Assert await _providerRepository.Received(2).GetByIdAsync(_providerId); - // Verify tax exempt was set to reverse for non-US providers + // Verify tax exempt was set to reverse for non-direct-tax-country providers await _stripeFacade.Received(1).UpdateCustomer( Arg.Is("cus_123"), Arg.Is(o => o.TaxExempt == TaxExempt.Reverse)); @@ -627,6 +767,146 @@ await _mailService.Received(1).SendProviderInvoiceUpcoming( Arg.Is(s => s == $"{paymentMethod.Brand} ending in {paymentMethod.Last4}")); } + [Fact] + public async Task HandleAsync_WhenSwissProviderWithManualReverseCharge_PreservesReverseCharge() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary(), + CollectionMethod = "charge_automatically" + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "CH" }, + TaxExempt = TaxExempt.Reverse + }; + var provider = new Provider { Id = _providerId, BillingEmail = "provider@example.com" }; + + var paymentMethod = new Card { Last4 = "4242", Brand = "visa" }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any()).Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, null, _providerId)); + + _providerRepository.GetByIdAsync(_providerId).Returns(provider); + _getPaymentMethodQuery.Run(provider).Returns(MaskedPaymentMethod.From(paymentMethod)); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _providerRepository.Received(2).GetByIdAsync(_providerId); + + // Manual reverse charge is preserved for Switzerland — no customer update + await _stripeFacade.DidNotReceive().UpdateCustomer( + Arg.Any(), + Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenNonDirectTaxCountryProvider_SetsReverseCharge() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice { CustomerId = "cus_123", AmountDue = 0, Lines = new StripeList { Data = [] } }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary(), + CollectionMethod = "charge_automatically" + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "DE" }, + TaxExempt = TaxExempt.None + }; + var provider = new Provider { Id = _providerId, BillingEmail = "provider@example.com" }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, null, _providerId)); + _providerRepository.GetByIdAsync(_providerId).Returns(provider); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateCustomer( + Arg.Is("cus_123"), + Arg.Is(o => o.TaxExempt == TaxExempt.Reverse)); + } + + [Fact] + public async Task HandleAsync_WhenUSProviderWithManualReverseCharge_CorrectsTaxExemptToNone() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice { CustomerId = "cus_123", AmountDue = 0, Lines = new StripeList { Data = [] } }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary(), + CollectionMethod = "charge_automatically" + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" }, + TaxExempt = TaxExempt.Reverse + }; + var provider = new Provider { Id = _providerId, BillingEmail = "provider@example.com" }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, null, _providerId)); + _providerRepository.GetByIdAsync(_providerId).Returns(provider); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateCustomer( + Arg.Is("cus_123"), + Arg.Is(o => o.TaxExempt == TaxExempt.None)); + } + [Fact] public async Task HandleAsync_WhenUpdateSubscriptionItemPriceIdFails_LogsErrorAndSendsTraditionalEmail() { @@ -1064,6 +1344,11 @@ await _mailer.Received(1).SendEmail( email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && email.View.DiscountAmount == $"{coupon.PercentOff}%" )); + + // Families plan is excluded from tax exempt alignment + await _stripeFacade.DidNotReceive().UpdateCustomer( + Arg.Any(), + Arg.Any()); } [Fact] @@ -1154,6 +1439,11 @@ await _organizationRepository.Received(1).ReplaceAsync( org.Plan == familiesPlan.Name && org.UsersGetPremium == familiesPlan.UsersGetPremium && org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + // Families plan is excluded from tax exempt alignment + await _stripeFacade.DidNotReceive().UpdateCustomer( + Arg.Any(), + Arg.Any()); } [Fact] @@ -1231,6 +1521,11 @@ await _stripeFacade.DidNotReceive().UpdateSubscription( await _organizationRepository.DidNotReceive().ReplaceAsync( Arg.Is(org => org.PlanType == PlanType.FamiliesAnnually)); + + // Families plan is excluded from tax exempt alignment + await _stripeFacade.DidNotReceive().UpdateCustomer( + Arg.Any(), + Arg.Any()); } [Fact] @@ -1302,6 +1597,10 @@ await _stripeFacade.DidNotReceive().UpdateSubscription( Arg.Is(o => o.Discounts != null)); await _organizationRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + // Families plan is excluded from tax exempt alignment + await _stripeFacade.DidNotReceive().UpdateCustomer( + Arg.Any(), + Arg.Any()); } [Fact] diff --git a/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs b/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs index 83f072fa3657..f30e6fb1d972 100644 --- a/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs +++ b/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs @@ -318,6 +318,57 @@ await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is()).Returns(invoice); + + var result = await _command.Run(_user, purchase, billingAddress); + + Assert.True(result.IsT0); + var (tax, total) = result.AsT0; + Assert.Equal(2.20m, tax); + Assert.Equal(29.20m, total); + + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true && + options.Currency == "usd" && + options.CustomerDetails.Address.Country == "CH" && + options.CustomerDetails.Address.PostalCode == "3001" && + options.CustomerDetails.TaxExempt == TaxExempt.None && + options.SubscriptionDetails.Items.Count == 1 && + options.SubscriptionDetails.Items[0].Price == "2023-teams-org-seat-monthly" && + options.SubscriptionDetails.Items[0].Quantity == 3 && + options.Discounts == null)); + } + [Fact] public async Task Run_OrganizationSubscriptionPurchase_SpanishNIFTaxId_AddsEUVATTaxId() { diff --git a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs index a7284410fe52..2f7f5051b3f3 100644 --- a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs +++ b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs @@ -421,6 +421,31 @@ public async Task Run_USCustomer_NoTaxIdWarning( Assert.Null(response.TaxId); } + [Theory, BitAutoData] + public async Task Run_CHCustomer_NoTaxIdWarning( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription + { + Customer = new Customer + { + Address = new Address { Country = "CH" }, + TaxIds = new StripeList { Data = new List() }, + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + } + }; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Any()) + .Returns(subscription); + + var response = await sutProvider.Sut.Run(organization); + + Assert.Null(response.TaxId); + } + [Theory, BitAutoData] public async Task Run_FreeCustomer_NoTaxIdWarning( Organization organization, diff --git a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs index 88728d4839b9..ac3d53d301fa 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs @@ -191,6 +191,9 @@ public async Task Run_BusinessOrganization_MakesCorrectInvocations_ReturnsBillin } }; + _stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId) + .Returns(new Customer { TaxExempt = TaxExempt.None }); + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && @@ -259,6 +262,9 @@ public async Task Run_BusinessOrganization_RemovingTaxId_MakesCorrectInvocations } }; + _stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId) + .Returns(new Customer { TaxExempt = TaxExempt.None }); + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && @@ -321,6 +327,9 @@ public async Task Run_NonUSBusinessOrganization_MakesCorrectInvocations_ReturnsB } }; + _stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId) + .Returns(new Customer { TaxExempt = TaxExempt.None }); + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && @@ -337,6 +346,69 @@ await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySub Arg.Is(options => options.AutomaticTax.Enabled == true)); } + [Fact] + public async Task Run_SwissBusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress() + { + var organization = new Organization + { + PlanType = PlanType.EnterpriseAnnually, + GatewayCustomerId = "cus_123", + GatewaySubscriptionId = "sub_123" + }; + + var input = new BillingAddress + { + Country = "CH", + PostalCode = "3001", + Line1 = "Bundesgasse 1", + Line2 = string.Empty, + City = "Bern", + State = "BE" + }; + + var customer = new Customer + { + Address = new Address + { + Country = "CH", + PostalCode = "3001", + Line1 = "Bundesgasse 1", + Line2 = string.Empty, + City = "Bern", + State = "BE" + }, + Subscriptions = new StripeList + { + Data = + [ + new Subscription + { + Id = organization.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } + }; + + _stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId) + .Returns(new Customer { TaxExempt = TaxExempt.None }); + + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Address.Matches(input) && + options.HasExpansions("subscriptions", "tax_ids") && + options.TaxExempt == TaxExempt.None + )).Returns(customer); + + var result = await _command.Run(organization, input); + + Assert.True(result.IsT0); + var output = result.AsT0; + Assert.Equivalent(input, output); + + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + } + [Fact] public async Task Run_BusinessOrganizationWithSpanishCIF_MakesCorrectInvocations_ReturnsBillingAddress() { @@ -383,6 +455,9 @@ public async Task Run_BusinessOrganizationWithSpanishCIF_MakesCorrectInvocations } }; + _stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId) + .Returns(new Customer { TaxExempt = TaxExempt.None }); + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && @@ -460,6 +535,9 @@ public async Task Run_BusinessOrganization_UpdatingWithSameTaxId_DeletesBeforeCr } }; + _stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId) + .Returns(new Customer { TaxExempt = TaxExempt.None }); + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && @@ -488,4 +566,67 @@ public async Task Run_BusinessOrganization_UpdatingWithSameTaxId_DeletesBeforeCr await _stripeAdapter.Received(1).CreateTaxIdAsync(customer.Id, Arg.Is( options => options.Type == "us_ein" && options.Value == "987654321")); } + + [Fact] + public async Task Run_SwissBusinessOrganization_ManuallySetReverse_PreservesReverse() + { + var organization = new Organization + { + PlanType = PlanType.EnterpriseAnnually, + GatewayCustomerId = "cus_123", + GatewaySubscriptionId = "sub_123" + }; + + var input = new BillingAddress + { + Country = "CH", + PostalCode = "3001", + Line1 = "Bundesgasse 1", + Line2 = string.Empty, + City = "Bern", + State = "BE" + }; + + var customer = new Customer + { + Address = new Address + { + Country = "CH", + PostalCode = "3001", + Line1 = "Bundesgasse 1", + Line2 = string.Empty, + City = "Bern", + State = "BE" + }, + Subscriptions = new StripeList + { + Data = + [ + new Subscription + { + Id = organization.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true } + } + ] + } + }; + + _stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId) + .Returns(new Customer { TaxExempt = TaxExempt.Reverse }); + + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Address.Matches(input) && + options.HasExpansions("subscriptions", "tax_ids") && + options.TaxExempt == TaxExempt.Reverse + )).Returns(customer); + + var result = await _command.Run(organization, input); + + Assert.True(result.IsT0); + var output = result.AsT0; + Assert.Equivalent(input, output); + + await _stripeAdapter.Received(1).UpdateCustomerAsync(organization.GatewayCustomerId, + Arg.Is(options => options.TaxExempt == TaxExempt.Reverse)); + } } diff --git a/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs index b44278acc47e..95201c1bdc55 100644 --- a/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs +++ b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs @@ -597,6 +597,154 @@ await sutProvider.GetDependency() #endregion + [Theory, BitAutoData] + public async Task Finalize_SwissBusinessWithManualSetReverseExempt_DoesNotOverwriteReverse( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var plan = MockPlans.Get(PlanType.TeamsAnnually); + organization.PlanType = PlanType.TeamsAnnually; + organization.GatewayCustomerId = "cus_test123"; + organization.GatewaySubscriptionId = null; + + var subscriptionSetup = new SubscriptionSetup + { + PlanType = PlanType.TeamsAnnually, + PasswordManagerOptions = new SubscriptionSetup.PasswordManager + { + Seats = 5, + Storage = null, + PremiumAccess = false + }, + SecretsManagerOptions = null, + SkipTrial = false + }; + + var sale = new OrganizationSale + { + Organization = organization, + SubscriptionSetup = subscriptionSetup + }; + + var customer = new Customer + { + Id = "cus_test123", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }, + Address = new Address { Country = "CH" }, + TaxExempt = StripeConstants.TaxExempt.Reverse + }; + + sutProvider.GetDependency() + .GetPlanOrThrow(PlanType.TeamsAnnually) + .Returns(plan); + + sutProvider.GetDependency() + .GetCustomerOrThrow(organization, Arg.Any()) + .Returns(customer); + + sutProvider.GetDependency() + .CreateSubscriptionAsync(Arg.Any()) + .Returns(new Subscription + { + Id = "sub_test123", + Status = StripeConstants.SubscriptionStatus.Active + }); + + sutProvider.GetDependency() + .ReplaceAsync(organization) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.Finalize(sale); + + // Assert: UpdateCustomerAsync should NOT be called since the "reverse" tax exempt status for Switzerland should be preserved + await sutProvider.GetDependency() + .DidNotReceive() + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task Finalize_USBusinessWithReverseExempt_CorrectsTaxExemptToNone( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var plan = MockPlans.Get(PlanType.TeamsAnnually); + organization.PlanType = PlanType.TeamsAnnually; + organization.GatewayCustomerId = "cus_test123"; + organization.GatewaySubscriptionId = null; + + var subscriptionSetup = new SubscriptionSetup + { + PlanType = PlanType.TeamsAnnually, + PasswordManagerOptions = new SubscriptionSetup.PasswordManager + { + Seats = 5, + Storage = null, + PremiumAccess = false + }, + SecretsManagerOptions = null, + SkipTrial = false + }; + + var sale = new OrganizationSale + { + Organization = organization, + SubscriptionSetup = subscriptionSetup + }; + + var customer = new Customer + { + Id = "cus_test123", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }, + Address = new Address { Country = "US" }, + TaxExempt = StripeConstants.TaxExempt.Reverse + }; + + var correctedCustomer = new Customer + { + Id = "cus_test123", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }, + Address = new Address { Country = "US" }, + TaxExempt = StripeConstants.TaxExempt.None + }; + + sutProvider.GetDependency() + .GetPlanOrThrow(PlanType.TeamsAnnually) + .Returns(plan); + + sutProvider.GetDependency() + .GetCustomerOrThrow(organization, Arg.Any()) + .Returns(customer); + + sutProvider.GetDependency() + .UpdateCustomerAsync(customer.Id, Arg.Is(options => + options.TaxExempt == StripeConstants.TaxExempt.None)) + .Returns(correctedCustomer); + + sutProvider.GetDependency() + .CreateSubscriptionAsync(Arg.Any()) + .Returns(new Subscription + { + Id = "sub_test123", + Status = StripeConstants.SubscriptionStatus.Active + }); + + sutProvider.GetDependency() + .ReplaceAsync(organization) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.Finalize(sale); + + // Assert: UpdateCustomerAsync called with TaxExempt = "none" to correct the erroneous "reverse" + await sutProvider.GetDependency() + .Received(1) + .UpdateCustomerAsync(customer.Id, Arg.Is(options => + options.TaxExempt == StripeConstants.TaxExempt.None)); + } + [Theory, BitAutoData] public async Task UpdateOrganizationNameAndEmail_UpdatesStripeCustomer( Organization organization, diff --git a/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs b/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs index 73f28113ca0a..e0a04a7030a5 100644 --- a/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs @@ -1,8 +1,12 @@ -using Bit.Core.Billing.Constants; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -11,6 +15,8 @@ namespace Bit.Core.Test.Services; +using static StripeConstants; + [SutProviderCustomize] public class StripePaymentServiceTests { @@ -408,4 +414,198 @@ await sutProvider.GetDependency() .DidNotReceive() .GetSubscriptionAsync(Arg.Any(), Arg.Any()); } + + #region AdjustSubscription — CompleteSubscriptionUpdate tax exempt alignment + + [Theory, BitAutoData] + public async Task AdjustSubscription_WhenNonDirectTaxCountry_SetsReverseCharge( + SutProvider sutProvider, + Organization organization) + { + var plan = new EnterprisePlan(isAnnual: true); + organization.PlanType = PlanType.EnterpriseAnnually; + organization.GatewaySubscriptionId = "sub_123"; + organization.Seats = 0; + organization.UseSecretsManager = false; + organization.MaxStorageGb = null; + + var subscription = new Subscription + { + Id = "sub_123", + Status = "active", + Customer = new Customer + { + Id = "cus_123", + Address = new Address { Country = "DE" }, + TaxExempt = TaxExempt.None + }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + Price = new Price { Id = plan.PasswordManager.StripeSeatPlanId }, + Plan = new Stripe.Plan { Id = plan.PasswordManager.StripeSeatPlanId }, + Quantity = 0 + } + ] + } + }; + + sutProvider.GetDependency() + .GetPlanOrThrow(PlanType.EnterpriseAnnually) + .Returns(plan); + + sutProvider.GetDependency() + .GetSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + + sutProvider.GetDependency() + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()) + .Returns(new Subscription { Id = "sub_123", LatestInvoiceId = "inv_123" }); + + sutProvider.GetDependency() + .GetInvoiceAsync("inv_123", Arg.Any()) + .Returns(new Invoice { Id = "inv_123", AmountDue = 0, Status = InvoiceStatus.Paid }); + + sutProvider.GetDependency() + .GetCustomerAsync("cus_123") + .Returns(new Customer { Id = "cus_123" }); + + await sutProvider.Sut.AdjustSubscription(organization, plan, 0, false, null, null, 0); + + await sutProvider.GetDependency().Received(1).UpdateCustomerAsync( + "cus_123", + Arg.Is(o => o.TaxExempt == TaxExempt.Reverse)); + } + + [Theory, BitAutoData] + public async Task AdjustSubscription_WhenUSWithManualReverse_CorrectsTaxExemptToNone( + SutProvider sutProvider, + Organization organization) + { + var plan = new EnterprisePlan(isAnnual: true); + organization.PlanType = PlanType.EnterpriseAnnually; + organization.GatewaySubscriptionId = "sub_123"; + organization.Seats = 0; + organization.UseSecretsManager = false; + organization.MaxStorageGb = null; + + var subscription = new Subscription + { + Id = "sub_123", + Status = "active", + Customer = new Customer + { + Id = "cus_123", + Address = new Address { Country = "US" }, + TaxExempt = TaxExempt.Reverse + }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + Price = new Price { Id = plan.PasswordManager.StripeSeatPlanId }, + Plan = new Stripe.Plan { Id = plan.PasswordManager.StripeSeatPlanId }, + Quantity = 0 + } + ] + } + }; + + sutProvider.GetDependency() + .GetPlanOrThrow(PlanType.EnterpriseAnnually) + .Returns(plan); + + sutProvider.GetDependency() + .GetSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + + sutProvider.GetDependency() + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()) + .Returns(new Subscription { Id = "sub_123", LatestInvoiceId = "inv_123" }); + + sutProvider.GetDependency() + .GetInvoiceAsync("inv_123", Arg.Any()) + .Returns(new Invoice { Id = "inv_123", AmountDue = 0, Status = InvoiceStatus.Paid }); + + sutProvider.GetDependency() + .GetCustomerAsync("cus_123") + .Returns(new Customer { Id = "cus_123" }); + + await sutProvider.Sut.AdjustSubscription(organization, plan, 0, false, null, null, 0); + + await sutProvider.GetDependency().Received(1).UpdateCustomerAsync( + "cus_123", + Arg.Is(o => o.TaxExempt == TaxExempt.None)); + } + + [Theory, BitAutoData] + public async Task AdjustSubscription_WhenSwissWithManualReverse_PreservesReverse( + SutProvider sutProvider, + Organization organization) + { + var plan = new EnterprisePlan(isAnnual: true); + organization.PlanType = PlanType.EnterpriseAnnually; + organization.GatewaySubscriptionId = "sub_123"; + organization.Seats = 0; + organization.UseSecretsManager = false; + organization.MaxStorageGb = null; + + var subscription = new Subscription + { + Id = "sub_123", + Status = "active", + Customer = new Customer + { + Id = "cus_123", + Address = new Address { Country = "CH" }, + TaxExempt = TaxExempt.Reverse + }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + Price = new Price { Id = plan.PasswordManager.StripeSeatPlanId }, + Plan = new Stripe.Plan { Id = plan.PasswordManager.StripeSeatPlanId }, + Quantity = 0 + } + ] + } + }; + + sutProvider.GetDependency() + .GetPlanOrThrow(PlanType.EnterpriseAnnually) + .Returns(plan); + + sutProvider.GetDependency() + .GetSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + + sutProvider.GetDependency() + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()) + .Returns(new Subscription { Id = "sub_123", LatestInvoiceId = "inv_123" }); + + sutProvider.GetDependency() + .GetInvoiceAsync("inv_123", Arg.Any()) + .Returns(new Invoice { Id = "inv_123", AmountDue = 0, Status = InvoiceStatus.Paid }); + + sutProvider.GetDependency() + .GetCustomerAsync("cus_123") + .Returns(new Customer { Id = "cus_123" }); + + await sutProvider.Sut.AdjustSubscription(organization, plan, 0, false, null, null, 0); + + // Manual reverse charge override is preserved for Switzerland — no customer update + await sutProvider.GetDependency().DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + #endregion } diff --git a/test/Core.Test/Billing/Tax/TaxHelpersTests.cs b/test/Core.Test/Billing/Tax/TaxHelpersTests.cs new file mode 100644 index 000000000000..c56765434b66 --- /dev/null +++ b/test/Core.Test/Billing/Tax/TaxHelpersTests.cs @@ -0,0 +1,39 @@ +using Bit.Core.Billing.Tax.Utilities; +using Xunit; +using CountryAbbreviations = Bit.Core.Constants.CountryAbbreviations; +using TaxExempt = Bit.Core.Billing.Constants.StripeConstants.TaxExempt; + +namespace Bit.Core.Test.Billing.Tax; + +public class TaxHelpersTests +{ + + [Theory] + [InlineData(CountryAbbreviations.UnitedStates, true)] + [InlineData(CountryAbbreviations.Switzerland, true)] + [InlineData("DE", false)] + [InlineData(null, false)] + [InlineData("", false)] + public void IsDirectTaxCountry_ReturnsExpectedResult(string? country, bool expected) + { + var result = TaxHelpers.IsDirectTaxCountry(country); + Assert.Equal(expected, result); + } + + [Theory] + [InlineData("DE", TaxExempt.None, TaxExempt.Reverse)] // non-direct-tax → Reverse + [InlineData(CountryAbbreviations.UnitedStates, TaxExempt.Reverse, TaxExempt.None)] // US manual Reverse → None + [InlineData(CountryAbbreviations.Switzerland, null, TaxExempt.None)] // CH no existing status → None + [InlineData(CountryAbbreviations.UnitedStates, TaxExempt.None, TaxExempt.None)] // US already None → None + [InlineData(CountryAbbreviations.Switzerland, TaxExempt.Reverse, TaxExempt.Reverse)] // CH manual Reverse → preserved + [InlineData("DE", TaxExempt.Reverse, TaxExempt.Reverse)] // non-direct-tax already Reverse → Reverse + [InlineData(null, TaxExempt.None, TaxExempt.Reverse)] // unknown country → Reverse + public void DetermineTaxExemptStatus_ReturnsExpectedResult( + string? country, + string? currentTaxExempt, + string expected) + { + var result = TaxHelpers.DetermineTaxExemptStatus(country, currentTaxExempt); + Assert.Equal(expected, result); + } +}