Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
11e9efd
refactor(billing): change billing address request type
sbrown-livefront Mar 3, 2026
1c66f0a
feat(billing): add tax id support for international business plans
sbrown-livefront Mar 3, 2026
d0d664d
feat(billing): add billing address tax id handling
sbrown-livefront Mar 3, 2026
8389b48
test: add tests for tax id handling during upgrade
sbrown-livefront Mar 3, 2026
4f799ae
fix(billing): run dotnet format
sbrown-livefront Mar 3, 2026
25738f0
fix(billing): remove extra line
sbrown-livefront Mar 4, 2026
0c206f3
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 4, 2026
7d34622
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 4, 2026
b359e4f
fix(billing): modify return type of HandleAsync
sbrown-livefront Mar 4, 2026
590c055
test(billing): update tests to reflect updated command signature
sbrown-livefront Mar 4, 2026
507bf44
fix(billing): run dotnet format
sbrown-livefront Mar 4, 2026
df1f8f0
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 4, 2026
ef14e35
tests(billing): fix tests
sbrown-livefront Mar 4, 2026
c18ea1b
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 4, 2026
f3dc91a
test(billing): format
sbrown-livefront Mar 4, 2026
ccd2ea3
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 5, 2026
69171f2
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 9, 2026
8227aaf
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 9, 2026
e13ef58
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 9, 2026
7b8bd46
Merge branch 'main' into billing/pm-33061/tax-id-should-update-when-u…
sbrown-livefront Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class UpgradePremiumToOrganizationRequest
public required ProductTierType TargetProductTierType { get; set; }

[Required]
public required MinimalBillingAddressRequest BillingAddress { get; set; }
public required CheckoutBillingAddressRequest BillingAddress { get; set; }

private PlanType PlanType
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
Expand All @@ -12,6 +13,8 @@
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using Stripe;
using CountryAbbreviations = Bit.Core.Constants.CountryAbbreviations;
using TaxExempt = Bit.Core.Billing.Constants.StripeConstants.TaxExempt;

namespace Bit.Core.Billing.Premium.Commands;
/// <summary>
Expand Down Expand Up @@ -40,7 +43,7 @@ Task<BillingCommandResult<Guid>> Run(
string encryptedPrivateKey,
string? collectionName,
PlanType targetPlanType,
Payment.Models.BillingAddress billingAddress);
BillingAddress billingAddress);
}

public class UpgradePremiumToOrganizationCommand(
Expand All @@ -65,7 +68,7 @@ public Task<BillingCommandResult<Guid>> Run(
string encryptedPrivateKey,
string? collectionName,
PlanType targetPlanType,
Payment.Models.BillingAddress billingAddress) => HandleAsync<Guid>(async () =>
BillingAddress billingAddress) => HandleAsync<Guid>(async () =>
{
// Validate that the user has an active Premium subscription
if (user is not { Premium: true, GatewaySubscriptionId: not null and not "" })
Expand Down Expand Up @@ -198,9 +201,16 @@ public Task<BillingCommandResult<Guid>> Run(
{
Country = billingAddress.Country,
PostalCode = billingAddress.PostalCode
}
},
TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None
});

// Add tax ID to customer for accurate tax calculation if provided
if (billingAddress.TaxId != null)
{
await AddTaxIdToCustomerAsync(user, billingAddress.TaxId);
}

// Update the subscription in Stripe
await stripeAdapter.UpdateSubscriptionAsync(currentSubscription.Id, subscriptionUpdateOptions);

Expand Down Expand Up @@ -271,4 +281,26 @@ await organizationApiKeyRepository.CreateAsync(new OrganizationApiKey

return organization.Id;
});

/// <summary>
/// Adds a tax ID to the Stripe customer for accurate tax calculation.
/// If the tax ID is a Spanish NIF, also adds the corresponding EU VAT ID.
/// </summary>
/// <param name="user"> The user whose Stripe customer will be updated with the tax ID.</param>
/// <param name="taxId"> The tax ID to add, including the type and value.</param>
private async Task AddTaxIdToCustomerAsync(User user, TaxID taxId)
{
await stripeAdapter.CreateTaxIdAsync(user.GatewayCustomerId,
new TaxIdCreateOptions { Type = taxId.Code, Value = taxId.Value });

if (taxId.Code == StripeConstants.TaxIdType.SpanishNIF)
{
await stripeAdapter.CreateTaxIdAsync(user.GatewayCustomerId,
new TaxIdCreateOptions
{
Type = StripeConstants.TaxIdType.EUVAT,
Value = $"ES{taxId.Value}"
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public void ToDomain_ValidTierTypes_ReturnsPlanType(ProductTierType tierType, Pl
EncryptedPrivateKey = "encrypted-private-key",
CollectionName = "Default Collection",
TargetProductTierType = tierType,
BillingAddress = new MinimalBillingAddressRequest
BillingAddress = new CheckoutBillingAddressRequest
{
Country = "US",
PostalCode = "12345"
Expand Down Expand Up @@ -56,7 +56,7 @@ public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTie
PublicKey = "public-key",
EncryptedPrivateKey = "encrypted-private-key",
TargetProductTierType = tierType,
BillingAddress = new MinimalBillingAddressRequest
BillingAddress = new CheckoutBillingAddressRequest
{
Country = "US",
PostalCode = "12345"
Expand All @@ -67,4 +67,53 @@ public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTie
var exception = Assert.Throws<InvalidOperationException>(() => sut.ToDomain());
Assert.Contains($"Cannot upgrade Premium subscription to {tierType} plan", exception.Message);
}

[Theory]
[InlineData(ProductTierType.Teams, PlanType.TeamsAnnually, "DE", "10115", "eu_vat", "DE123456789")]
[InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually, "FR", "75001", "eu_vat", "FR12345678901")]
public void ToDomain_BusinessPlansWithNonUsTaxId_IncludesTaxIdInBillingAddress(
ProductTierType tierType,
PlanType expectedPlanType,
string country,
string postalCode,
string taxIdCode,
string taxIdValue)
{
// Arrange
var sut = new UpgradePremiumToOrganizationRequest
{
OrganizationName = "International Business",
Key = "encrypted-key",
TargetProductTierType = tierType,
PublicKey = "public-key",
EncryptedPrivateKey = "encrypted-private-key",
CollectionName = "Default Collection",
BillingAddress = new CheckoutBillingAddressRequest
{
Country = country,
PostalCode = postalCode,
TaxId = new CheckoutBillingAddressRequest.TaxIdRequest
{
Code = taxIdCode,
Value = taxIdValue
}
}
};

// Act
var (organizationName, key, publicKey, encryptedPrivateKey, collectionName, planType, billingAddress) = sut.ToDomain();

// Assert
Assert.Equal("International Business", organizationName);
Assert.Equal("encrypted-key", key);
Assert.Equal("public-key", publicKey);
Assert.Equal("encrypted-private-key", encryptedPrivateKey);
Assert.Equal("Default Collection", collectionName);
Assert.Equal(expectedPlanType, planType);
Assert.Equal(country, billingAddress.Country);
Assert.Equal(postalCode, billingAddress.PostalCode);
Assert.NotNull(billingAddress.TaxId);
Assert.Equal(taxIdCode, billingAddress.TaxId.Code);
Assert.Equal(taxIdValue, billingAddress.TaxId.Value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1202,4 +1202,200 @@ await _collectionRepository.Received(1).CreateAsync(
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
}

[Theory, BitAutoData]
public async Task Run_WithNoTaxId_SetsTaxExemptToNone_DoesNotCreateTaxId(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";

var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};

var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");

_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);

var billingAddress = new Core.Billing.Payment.Models.BillingAddress
{
Country = "US",
PostalCode = "12345",
TaxId = null
};

// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", "public-key", "encrypted-private-key", "Default Collection", PlanType.TeamsAnnually, billingAddress);

// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateCustomerAsync(
"cus_123",
Arg.Is<CustomerUpdateOptions>(options =>
options.TaxExempt == StripeConstants.TaxExempt.None));
await _stripeAdapter.DidNotReceive().CreateTaxIdAsync(Arg.Any<string>(), Arg.Any<TaxIdCreateOptions>());
}

[Theory, BitAutoData]
public async Task Run_WithTaxId_SetsTaxExemptToReverse_CreatesOneTaxId(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";

var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};

var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");

_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_stripeAdapter.CreateTaxIdAsync(Arg.Any<string>(), Arg.Any<TaxIdCreateOptions>()).Returns(new TaxId());
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);

var billingAddress = new Core.Billing.Payment.Models.BillingAddress
{
Country = "DE",
PostalCode = "10115",
TaxId = new Core.Billing.Payment.Models.TaxID("eu_vat", "DE123456789")
};

// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", "public-key", "encrypted-private-key", "Default Collection", PlanType.TeamsAnnually, billingAddress);

// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateCustomerAsync(
"cus_123",
Arg.Is<CustomerUpdateOptions>(options =>
options.TaxExempt == StripeConstants.TaxExempt.Reverse));
await _stripeAdapter.Received(1).CreateTaxIdAsync(
"cus_123",
Arg.Is<TaxIdCreateOptions>(options =>
options.Type == "eu_vat" &&
options.Value == "DE123456789"));
}

[Theory, BitAutoData]
public async Task Run_WithSpanishNIF_SetsTaxExemptToReverse_CreatesBothSpanishNIFAndEUVAT(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";

var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};

var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");

_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_stripeAdapter.CreateTaxIdAsync(Arg.Any<string>(), Arg.Any<TaxIdCreateOptions>()).Returns(new TaxId());
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);

var billingAddress = new Core.Billing.Payment.Models.BillingAddress
{
Country = "ES",
PostalCode = "28001",
TaxId = new Core.Billing.Payment.Models.TaxID(StripeConstants.TaxIdType.SpanishNIF, "A12345678")
};

// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", "public-key", "encrypted-private-key", "Default Collection", PlanType.TeamsAnnually, billingAddress);

// Assert
Assert.True(result.IsT0);

await _stripeAdapter.Received(1).UpdateCustomerAsync(
"cus_123",
Arg.Is<CustomerUpdateOptions>(options =>
options.TaxExempt == StripeConstants.TaxExempt.Reverse));

// Verify Spanish NIF was created
await _stripeAdapter.Received(1).CreateTaxIdAsync(
"cus_123",
Arg.Is<TaxIdCreateOptions>(options =>
options.Type == StripeConstants.TaxIdType.SpanishNIF &&
options.Value == "A12345678"));

// Verify EU VAT was created with ES prefix
await _stripeAdapter.Received(1).CreateTaxIdAsync(
"cus_123",
Arg.Is<TaxIdCreateOptions>(options =>
options.Type == StripeConstants.TaxIdType.EUVAT &&
options.Value == "ESA12345678"));


}
}
Loading