Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
ο»Ώusing Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
Expand All @@ -13,10 +13,13 @@
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using Stripe;
using static Bit.Core.Billing.Constants.StripeConstants;
using static Bit.Core.Billing.Utilities;
using CountryAbbreviations = Bit.Core.Constants.CountryAbbreviations;
using TaxExempt = Bit.Core.Billing.Constants.StripeConstants.TaxExempt;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;

namespace Bit.Core.Billing.Premium.Commands;

/// <summary>
/// Upgrades a user's Premium subscription to an Organization plan by creating a new Organization
/// and transferring the subscription from the User to the Organization.
Expand Down Expand Up @@ -55,7 +58,8 @@ public class UpgradePremiumToOrganizationCommand(
IOrganizationUserRepository organizationUserRepository,
IOrganizationApiKeyRepository organizationApiKeyRepository,
ICollectionRepository collectionRepository,
IApplicationCacheService applicationCacheService)
IApplicationCacheService applicationCacheService,
IBraintreeService braintreeService)
: BaseBillingCommand<UpgradePremiumToOrganizationCommand>(logger), IUpgradePremiumToOrganizationCommand
{
private readonly ILogger<UpgradePremiumToOrganizationCommand> _logger = logger;
Expand All @@ -76,15 +80,7 @@ public Task<BillingCommandResult<Guid>> Run(
return new BadRequest("User does not have an active Premium subscription.");
}

// Fetch the current Premium subscription from Stripe
var currentSubscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId);

// Fetch all premium plans to find which specific plan the user is on
var premiumPlans = await pricingClient.ListPremiumPlans();

// Find the password manager subscription item (seat, not storage) and match it to a plan
var passwordManagerItem = currentSubscription.Items.Data.FirstOrDefault(i =>
premiumPlans.Any(p => p.Seat.StripePriceId == i.Price.Id));
var (currentSubscription, premiumPlans, passwordManagerItem) = await GetPremiumPlanAndSubscriptionDetailsAsync(user);

if (passwordManagerItem == null)
{
Expand All @@ -96,66 +92,126 @@ public Task<BillingCommandResult<Guid>> Run(
// Get the target organization plan
var targetPlan = await pricingClient.GetPlanOrThrow(targetPlanType);

var subscriptionItemOptions = BuildSubscriptionItemOptions(
currentSubscription, usersPremiumPlan, targetPlan, passwordManagerItem);

// Generate organization ID early to include in metadata
var organizationId = CoreHelpers.GenerateComb();

// Create the Organization entity
var organization = BuildOrganization(
organizationId, user, organizationName, publicKey, encryptedPrivateKey, targetPlan, currentSubscription.Id);

// Update customer billing address for tax calculation
var customer = await stripeAdapter.UpdateCustomerAsync(user.GatewayCustomerId, new CustomerUpdateOptions
{
Address = new AddressOptions
{
Country = billingAddress.Country,
PostalCode = billingAddress.PostalCode
},
TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None
});


await UpdateSubscriptionAsync(currentSubscription.Id, organizationId, customer, subscriptionItemOptions);

// Add tax ID to the customer for accurate tax calculation if provided
if (billingAddress.TaxId != null)
{
await AddTaxIdToCustomerAsync(user.GatewayCustomerId!, billingAddress.TaxId);
}

var organizationUser = await SaveOrganizationAsync(organization, user, key);

// Create a default collection if a collection name is provided
if (!string.IsNullOrWhiteSpace(collectionName))
{
await CreateDefaultCollectionAsync(organization, organizationUser, collectionName);
}

// Remove subscription from user
user.Premium = false;
user.PremiumExpirationDate = null;
user.GatewaySubscriptionId = null;
user.GatewayCustomerId = null;
user.RevisionDate = DateTime.UtcNow;
await userService.SaveUserAsync(user);

return organization.Id;
});

private async Task<(Subscription currentSubscription, List<PremiumPlan> premiumPlans, SubscriptionItem? passwordManagerItem)> GetPremiumPlanAndSubscriptionDetailsAsync(User user)
{
// Fetch the current Premium subscription from Stripe
var currentSubscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId);

// Fetch all premium plans to find which specific plan the user is on
var premiumPlans = await pricingClient.ListPremiumPlans();

// Find the password manager subscription item (seat, not storage) and match it to a plan
var passwordManagerItem = currentSubscription.Items.Data.FirstOrDefault(i =>
premiumPlans.Any(p => p.Seat.StripePriceId == i.Price.Id));

return (currentSubscription, premiumPlans, passwordManagerItem);
}

private List<SubscriptionItemOptions> BuildSubscriptionItemOptions(
Subscription currentSubscription,
PremiumPlan usersPremiumPlan,
Core.Models.StaticStore.Plan targetPlan,
SubscriptionItem passwordManagerItem)
{
var isNonSeatBasedPmPlan = targetPlan.HasNonSeatBasedPasswordManagerPlan();

// if the target plan is non-seat-based, set seats to the base seats of the target plan, otherwise set to 1
var initialSeats = isNonSeatBasedPmPlan ? targetPlan.PasswordManager.BaseSeats : 1;

// Build the list of subscription item updates
var subscriptionItemOptions = new List<SubscriptionItemOptions>();
var options = new List<SubscriptionItemOptions>();

// Delete the storage item if it exists for this user's plan
var storageItem = currentSubscription.Items.Data.FirstOrDefault(i =>
i.Price.Id == usersPremiumPlan.Storage.StripePriceId);

if (storageItem != null)
{
subscriptionItemOptions.Add(new SubscriptionItemOptions
{
Id = storageItem.Id,
Deleted = true
});
options.Add(new SubscriptionItemOptions { Id = storageItem.Id, Deleted = true });
}

// Add new organization subscription items
if (isNonSeatBasedPmPlan)
{
subscriptionItemOptions.Add(new SubscriptionItemOptions
options.Add(isNonSeatBasedPmPlan
? new SubscriptionItemOptions
{
Id = passwordManagerItem.Id,
Price = targetPlan.PasswordManager.StripePlanId,
Quantity = 1
});
}
else
{
subscriptionItemOptions.Add(new SubscriptionItemOptions
}
: new SubscriptionItemOptions
{
Id = passwordManagerItem.Id,
Price = targetPlan.PasswordManager.StripeSeatPlanId,
Quantity = initialSeats
});
}

// Generate organization ID early to include in metadata
var organizationId = CoreHelpers.GenerateComb();
return options;
}

// Build the subscription update options
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
{
Items = subscriptionItemOptions,
ProrationBehavior = StripeConstants.ProrationBehavior.AlwaysInvoice,
BillingCycleAnchor = SubscriptionBillingCycleAnchor.Unchanged,
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
Metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.OrganizationId] = organizationId.ToString(),
[StripeConstants.MetadataKeys.UserId] = string.Empty // Remove userId to unlink subscription from User
}
};
private Organization BuildOrganization(
Guid organizationId,
User user,
string organizationName,
string publicKey,
string encryptedPrivateKey,
Core.Models.StaticStore.Plan targetPlan,
string subscriptionId)
{
var isNonSeatBasedPmPlan = targetPlan.HasNonSeatBasedPasswordManagerPlan();

// Create the Organization entity
var organization = new Organization
// if the target plan is non-seat-based, set seats to the base seats of the target plan, otherwise set to 1
var initialSeats = isNonSeatBasedPmPlan ? targetPlan.PasswordManager.BaseSeats : 1;

return new Organization
{
Id = organizationId,
Name = organizationName,
Expand All @@ -165,7 +221,7 @@ public Task<BillingCommandResult<Guid>> Run(
MaxCollections = targetPlan.PasswordManager.MaxCollections,
MaxStorageGb = targetPlan.PasswordManager.BaseStorageGb,
UsePolicies = targetPlan.HasPolicies,
UseMyItems = targetPlan.HasPolicies, // TODO: use the plan property when added (PM-32366)
UseMyItems = targetPlan.HasMyItems,
UseSso = targetPlan.HasSso,
UseGroups = targetPlan.HasGroups,
UseEvents = targetPlan.HasEvents,
Expand All @@ -191,33 +247,52 @@ public Task<BillingCommandResult<Guid>> Run(
UseSecretsManager = false,
UseOrganizationDomains = targetPlan.HasOrganizationDomains,
GatewayCustomerId = user.GatewayCustomerId,
GatewaySubscriptionId = currentSubscription.Id
GatewaySubscriptionId = subscriptionId
};
}

// Update customer billing address for tax calculation
await stripeAdapter.UpdateCustomerAsync(user.GatewayCustomerId, new CustomerUpdateOptions
private async Task UpdateSubscriptionAsync(
string subscriptionId,
Guid organizationId,
Customer customer,
List<SubscriptionItemOptions> subscriptionItemOptions)
{
var usingPayPal = customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) ?? false;

// Build the subscription update options
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
{
Address = new AddressOptions
Items = subscriptionItemOptions,
ProrationBehavior = ProrationBehavior.AlwaysInvoice,
BillingCycleAnchor = SubscriptionBillingCycleAnchor.Unchanged,
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
Metadata = new Dictionary<string, string>
{
Country = billingAddress.Country,
PostalCode = billingAddress.PostalCode
[MetadataKeys.OrganizationId] = organizationId.ToString(),
[MetadataKeys.UserId] = string.Empty // Remove userId to unlink the subscription from User
},
TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None
});
PaymentBehavior = usingPayPal ? PaymentBehavior.DefaultIncomplete : null
};

// Add tax ID to customer for accurate tax calculation if provided
if (billingAddress.TaxId != null)
// Update the subscription in Stripe
var subscription = await stripeAdapter.UpdateSubscriptionAsync(subscriptionId, subscriptionUpdateOptions);

// If using PayPal, update the subscription in Braintree
if (usingPayPal)
{
await AddTaxIdToCustomerAsync(user, billingAddress.TaxId);
await PayInvoiceUsingPayPalAsync(subscription, organizationId);
}
}

// Update the subscription in Stripe
await stripeAdapter.UpdateSubscriptionAsync(currentSubscription.Id, subscriptionUpdateOptions);

private async Task<OrganizationUser> SaveOrganizationAsync(
Organization organization,
User user,
string key)
{
// Save the organization
await organizationRepository.CreateAsync(organization);

// Create organization API key
// Create the organization API key
await organizationApiKeyRepository.CreateAsync(new OrganizationApiKey
{
OrganizationId = organization.Id,
Expand All @@ -244,61 +319,66 @@ await organizationApiKeyRepository.CreateAsync(new OrganizationApiKey
organizationUser.SetNewId();
await organizationUserRepository.CreateAsync(organizationUser);

// Create default collection if collection name is provided
if (!string.IsNullOrWhiteSpace(collectionName))
return organizationUser;
}

private async Task CreateDefaultCollectionAsync(
Organization organization,
OrganizationUser organizationUser,
string collectionName)
{
try
{
try
{
// Give the owner Can Manage access over the default collection
List<CollectionAccessSelection> defaultOwnerAccess =
[new CollectionAccessSelection { Id = organizationUser.Id, HidePasswords = false, ReadOnly = false, Manage = true }];
// Give the owner Can Manage access over the default collection
List<CollectionAccessSelection> defaultOwnerAccess =
[new() { Id = organizationUser.Id, HidePasswords = false, ReadOnly = false, Manage = true }];

var defaultCollection = new Collection
{
Name = collectionName,
OrganizationId = organization.Id,
CreationDate = organization.CreationDate,
RevisionDate = organization.CreationDate
};
await collectionRepository.CreateAsync(defaultCollection, null, defaultOwnerAccess);
}
catch (Exception ex)
var defaultCollection = new Collection
{
_logger.LogWarning(ex,
"{Command}: Failed to create default collection for organization {OrganizationId}. Organization upgrade will continue.",
CommandName, organization.Id);
// Continue - organization is fully functional without default collection
}
Name = collectionName,
OrganizationId = organization.Id,
CreationDate = organization.CreationDate,
RevisionDate = organization.CreationDate
};
await collectionRepository.CreateAsync(defaultCollection, null, defaultOwnerAccess);
}
catch (Exception ex)
{
_logger.LogWarning(ex,
"{Command}: Failed to create default collection for organization {OrganizationId}. Organization upgrade will continue.",
CommandName, organization.Id);
// Continue - organization is fully functional without default collection
}
}

// Remove subscription from user
user.Premium = false;
user.PremiumExpirationDate = null;
user.GatewaySubscriptionId = null;
user.GatewayCustomerId = null;
user.RevisionDate = DateTime.UtcNow;
await userService.SaveUserAsync(user);
private async Task PayInvoiceUsingPayPalAsync(Subscription subscription, Guid organizationId)
{
var invoice = await stripeAdapter.UpdateInvoiceAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions
{
AutoAdvance = false,
Expand = ["customer"]
});

return organization.Id;
});
await braintreeService.PayInvoice(new UserId(organizationId), invoice);
}

/// <summary>
/// Adds a tax ID to the Stripe customer for accurate tax calculation.
/// If the tax ID is a Spanish NIF, also adds the corresponding EU VAT ID.
/// </summary>
/// <param name="user"> The user whose Stripe customer will be updated with the tax ID.</param>
/// <param name="taxId"> The tax ID to add, including the type and value.</param>
private async Task AddTaxIdToCustomerAsync(User user, TaxID taxId)
/// <param name="customerId">The Stripe customer ID to add the tax ID to.</param>
/// <param name="taxId">The tax ID to add, including the type and value.</param>
private async Task AddTaxIdToCustomerAsync(string customerId, TaxID taxId)
{
await stripeAdapter.CreateTaxIdAsync(user.GatewayCustomerId,
await stripeAdapter.CreateTaxIdAsync(customerId,
new TaxIdCreateOptions { Type = taxId.Code, Value = taxId.Value });

if (taxId.Code == StripeConstants.TaxIdType.SpanishNIF)
if (taxId.Code == TaxIdType.SpanishNIF)
{
await stripeAdapter.CreateTaxIdAsync(user.GatewayCustomerId,
await stripeAdapter.CreateTaxIdAsync(customerId,
new TaxIdCreateOptions
{
Type = StripeConstants.TaxIdType.EUVAT,
Type = TaxIdType.EUVAT,
Value = $"ES{taxId.Value}"
});
}
Expand Down
Loading
Loading