Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 55 additions & 51 deletions apps/sim/lib/billing/core/billing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { getUserUsageData } from '@/lib/billing/core/usage'
import { getCreditBalance } from '@/lib/billing/credits/balance'
import { getFreeTierLimit, getPlanPricing } from '@/lib/billing/subscriptions/utils'
import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal'

export { getPlanPricing }

Expand Down Expand Up @@ -99,18 +100,18 @@ export async function calculateSubscriptionOverage(sub: {
return 0
}

let totalOverage = 0
let totalOverageDecimal = new Decimal(0)

if (sub.plan === 'team') {
const members = await db
.select({ userId: member.userId })
.from(member)
.where(eq(member.organizationId, sub.referenceId))

let totalTeamUsage = 0
let totalTeamUsageDecimal = new Decimal(0)
for (const m of members) {
const usage = await getUserUsageData(m.userId)
totalTeamUsage += usage.currentUsage
totalTeamUsageDecimal = totalTeamUsageDecimal.plus(toDecimal(usage.currentUsage))
}

const orgData = await db
Expand All @@ -119,28 +120,29 @@ export async function calculateSubscriptionOverage(sub: {
.where(eq(organization.id, sub.referenceId))
.limit(1)

const departedUsage =
orgData.length > 0 && orgData[0].departedMemberUsage
? Number.parseFloat(orgData[0].departedMemberUsage)
: 0
const departedUsageDecimal =
orgData.length > 0 ? toDecimal(orgData[0].departedMemberUsage) : new Decimal(0)

const totalUsageWithDeparted = totalTeamUsage + departedUsage
const totalUsageWithDepartedDecimal = totalTeamUsageDecimal.plus(departedUsageDecimal)
const { basePrice } = getPlanPricing(sub.plan)
const baseSubscriptionAmount = (sub.seats ?? 0) * basePrice
totalOverage = Math.max(0, totalUsageWithDeparted - baseSubscriptionAmount)
totalOverageDecimal = Decimal.max(
0,
totalUsageWithDepartedDecimal.minus(baseSubscriptionAmount)
)

logger.info('Calculated team overage', {
subscriptionId: sub.id,
currentMemberUsage: totalTeamUsage,
departedMemberUsage: departedUsage,
totalUsage: totalUsageWithDeparted,
currentMemberUsage: toNumber(totalTeamUsageDecimal),
departedMemberUsage: toNumber(departedUsageDecimal),
totalUsage: toNumber(totalUsageWithDepartedDecimal),
baseSubscriptionAmount,
totalOverage,
totalOverage: toNumber(totalOverageDecimal),
})
} else if (sub.plan === 'pro') {
// Pro plan: include snapshot if user joined a team
const usage = await getUserUsageData(sub.referenceId)
let totalProUsage = usage.currentUsage
let totalProUsageDecimal = toDecimal(usage.currentUsage)

// Add any snapshotted Pro usage (from when they joined a team)
const userStatsRows = await db
Expand All @@ -150,41 +152,41 @@ export async function calculateSubscriptionOverage(sub: {
.limit(1)

if (userStatsRows.length > 0 && userStatsRows[0].proPeriodCostSnapshot) {
const snapshotUsage = Number.parseFloat(userStatsRows[0].proPeriodCostSnapshot.toString())
totalProUsage += snapshotUsage
const snapshotUsageDecimal = toDecimal(userStatsRows[0].proPeriodCostSnapshot)
totalProUsageDecimal = totalProUsageDecimal.plus(snapshotUsageDecimal)
logger.info('Including snapshotted Pro usage in overage calculation', {
userId: sub.referenceId,
currentUsage: usage.currentUsage,
snapshotUsage,
totalProUsage,
snapshotUsage: toNumber(snapshotUsageDecimal),
totalProUsage: toNumber(totalProUsageDecimal),
})
}

const { basePrice } = getPlanPricing(sub.plan)
totalOverage = Math.max(0, totalProUsage - basePrice)
totalOverageDecimal = Decimal.max(0, totalProUsageDecimal.minus(basePrice))

logger.info('Calculated pro overage', {
subscriptionId: sub.id,
totalProUsage,
totalProUsage: toNumber(totalProUsageDecimal),
basePrice,
totalOverage,
totalOverage: toNumber(totalOverageDecimal),
})
} else {
// Free plan or unknown plan type
const usage = await getUserUsageData(sub.referenceId)
const { basePrice } = getPlanPricing(sub.plan || 'free')
totalOverage = Math.max(0, usage.currentUsage - basePrice)
totalOverageDecimal = Decimal.max(0, toDecimal(usage.currentUsage).minus(basePrice))

logger.info('Calculated overage for plan', {
subscriptionId: sub.id,
plan: sub.plan || 'free',
usage: usage.currentUsage,
basePrice,
totalOverage,
totalOverage: toNumber(totalOverageDecimal),
})
}

return totalOverage
return toNumber(totalOverageDecimal)
}

/**
Expand Down Expand Up @@ -272,14 +274,16 @@ export async function getSimplifiedBillingSummary(
const licensedSeats = subscription.seats ?? 0
const totalBasePrice = basePricePerSeat * licensedSeats // Based on Stripe subscription

let totalCurrentUsage = 0
let totalCopilotCost = 0
let totalLastPeriodCopilotCost = 0
let totalCurrentUsageDecimal = new Decimal(0)
let totalCopilotCostDecimal = new Decimal(0)
let totalLastPeriodCopilotCostDecimal = new Decimal(0)

// Calculate total team usage across all members
for (const memberInfo of members) {
const memberUsageData = await getUserUsageData(memberInfo.userId)
totalCurrentUsage += memberUsageData.currentUsage
totalCurrentUsageDecimal = totalCurrentUsageDecimal.plus(
toDecimal(memberUsageData.currentUsage)
)

// Fetch copilot cost for this member
const memberStats = await db
Expand All @@ -292,17 +296,21 @@ export async function getSimplifiedBillingSummary(
.limit(1)

if (memberStats.length > 0) {
totalCopilotCost += Number.parseFloat(
memberStats[0].currentPeriodCopilotCost?.toString() || '0'
totalCopilotCostDecimal = totalCopilotCostDecimal.plus(
toDecimal(memberStats[0].currentPeriodCopilotCost)
)
totalLastPeriodCopilotCost += Number.parseFloat(
memberStats[0].lastPeriodCopilotCost?.toString() || '0'
totalLastPeriodCopilotCostDecimal = totalLastPeriodCopilotCostDecimal.plus(
toDecimal(memberStats[0].lastPeriodCopilotCost)
)
}
}

const totalCurrentUsage = toNumber(totalCurrentUsageDecimal)
const totalCopilotCost = toNumber(totalCopilotCostDecimal)
const totalLastPeriodCopilotCost = toNumber(totalLastPeriodCopilotCostDecimal)

// Calculate team-level overage: total usage beyond what was already paid to Stripe
const totalOverage = Math.max(0, totalCurrentUsage - totalBasePrice)
const totalOverage = toNumber(Decimal.max(0, totalCurrentUsageDecimal.minus(totalBasePrice)))

// Get user's personal limits for warnings
const percentUsed =
Expand Down Expand Up @@ -380,14 +388,10 @@ export async function getSimplifiedBillingSummary(
.limit(1)

const copilotCost =
userStatsRows.length > 0
? Number.parseFloat(userStatsRows[0].currentPeriodCopilotCost?.toString() || '0')
: 0
userStatsRows.length > 0 ? toNumber(toDecimal(userStatsRows[0].currentPeriodCopilotCost)) : 0

const lastPeriodCopilotCost =
userStatsRows.length > 0
? Number.parseFloat(userStatsRows[0].lastPeriodCopilotCost?.toString() || '0')
: 0
userStatsRows.length > 0 ? toNumber(toDecimal(userStatsRows[0].lastPeriodCopilotCost)) : 0

// For team and enterprise plans, calculate total team usage instead of individual usage
let currentUsage = usageData.currentUsage
Expand All @@ -400,12 +404,12 @@ export async function getSimplifiedBillingSummary(
.from(member)
.where(eq(member.organizationId, subscription.referenceId))

let totalTeamUsage = 0
let totalTeamCopilotCost = 0
let totalTeamLastPeriodCopilotCost = 0
let totalTeamUsageDecimal = new Decimal(0)
let totalTeamCopilotCostDecimal = new Decimal(0)
let totalTeamLastPeriodCopilotCostDecimal = new Decimal(0)
for (const teamMember of teamMembers) {
const memberUsageData = await getUserUsageData(teamMember.userId)
totalTeamUsage += memberUsageData.currentUsage
totalTeamUsageDecimal = totalTeamUsageDecimal.plus(toDecimal(memberUsageData.currentUsage))

// Fetch copilot cost for this team member
const memberStats = await db
Expand All @@ -418,20 +422,20 @@ export async function getSimplifiedBillingSummary(
.limit(1)

if (memberStats.length > 0) {
totalTeamCopilotCost += Number.parseFloat(
memberStats[0].currentPeriodCopilotCost?.toString() || '0'
totalTeamCopilotCostDecimal = totalTeamCopilotCostDecimal.plus(
toDecimal(memberStats[0].currentPeriodCopilotCost)
)
totalTeamLastPeriodCopilotCost += Number.parseFloat(
memberStats[0].lastPeriodCopilotCost?.toString() || '0'
totalTeamLastPeriodCopilotCostDecimal = totalTeamLastPeriodCopilotCostDecimal.plus(
toDecimal(memberStats[0].lastPeriodCopilotCost)
)
}
}
currentUsage = totalTeamUsage
totalCopilotCost = totalTeamCopilotCost
totalLastPeriodCopilotCost = totalTeamLastPeriodCopilotCost
currentUsage = toNumber(totalTeamUsageDecimal)
totalCopilotCost = toNumber(totalTeamCopilotCostDecimal)
totalLastPeriodCopilotCost = toNumber(totalTeamLastPeriodCopilotCostDecimal)
}

const overageAmount = Math.max(0, currentUsage - basePrice)
const overageAmount = toNumber(Decimal.max(0, toDecimal(currentUsage).minus(basePrice)))
const percentUsed = usageData.limit > 0 ? (currentUsage / usageData.limit) * 100 : 0

// Calculate days remaining in billing period
Expand Down
40 changes: 21 additions & 19 deletions apps/sim/lib/billing/core/usage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
getPlanPricing,
} from '@/lib/billing/subscriptions/utils'
import type { BillingData, UsageData, UsageLimitInfo } from '@/lib/billing/types'
import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal'
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { sendEmail } from '@/lib/messaging/email/mailer'
Expand Down Expand Up @@ -45,7 +46,7 @@ export async function getOrgUsageLimit(

const configured =
orgData.length > 0 && orgData[0].orgUsageLimit
? Number.parseFloat(orgData[0].orgUsageLimit)
? toNumber(toDecimal(orgData[0].orgUsageLimit))
: null

if (plan === 'enterprise') {
Expand Down Expand Up @@ -111,30 +112,31 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
}

const stats = userStatsData[0]
let currentUsage = Number.parseFloat(stats.currentPeriodCost?.toString() ?? '0')
let currentUsageDecimal = toDecimal(stats.currentPeriodCost)

// For Pro users, include any snapshotted usage (from when they joined a team)
// This ensures they see their total Pro usage in the UI
if (subscription && subscription.plan === 'pro' && subscription.referenceId === userId) {
const snapshotUsage = Number.parseFloat(stats.proPeriodCostSnapshot?.toString() ?? '0')
if (snapshotUsage > 0) {
currentUsage += snapshotUsage
const snapshotUsageDecimal = toDecimal(stats.proPeriodCostSnapshot)
if (snapshotUsageDecimal.greaterThan(0)) {
currentUsageDecimal = currentUsageDecimal.plus(snapshotUsageDecimal)
logger.info('Including Pro snapshot in usage display', {
userId,
currentPeriodCost: stats.currentPeriodCost,
proPeriodCostSnapshot: snapshotUsage,
totalUsage: currentUsage,
proPeriodCostSnapshot: toNumber(snapshotUsageDecimal),
totalUsage: toNumber(currentUsageDecimal),
})
}
}
const currentUsage = toNumber(currentUsageDecimal)

// Determine usage limit based on plan type
let limit: number

if (!subscription || subscription.plan === 'free' || subscription.plan === 'pro') {
// Free/Pro: Use individual user limit from userStats
limit = stats.currentUsageLimit
? Number.parseFloat(stats.currentUsageLimit)
? toNumber(toDecimal(stats.currentUsageLimit))
: getFreeTierLimit()
} else {
// Team/Enterprise: Use organization limit
Expand Down Expand Up @@ -163,7 +165,7 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
isExceeded,
billingPeriodStart,
billingPeriodEnd,
lastPeriodCost: Number.parseFloat(stats.lastPeriodCost?.toString() || '0'),
lastPeriodCost: toNumber(toDecimal(stats.lastPeriodCost)),
}
} catch (error) {
logger.error('Failed to get user usage data', { userId, error })
Expand Down Expand Up @@ -195,7 +197,7 @@ export async function getUserUsageLimitInfo(userId: string): Promise<UsageLimitI
if (!subscription || subscription.plan === 'free' || subscription.plan === 'pro') {
// Free/Pro: Use individual limits
currentLimit = stats.currentUsageLimit
? Number.parseFloat(stats.currentUsageLimit)
? toNumber(toDecimal(stats.currentUsageLimit))
: getFreeTierLimit()
minimumLimit = getPerUserMinimumLimit(subscription)
canEdit = canEditUsageLimit(subscription)
Expand Down Expand Up @@ -353,7 +355,7 @@ export async function getUserUsageLimit(userId: string): Promise<number> {
)
}

return Number.parseFloat(userStatsQuery[0].currentUsageLimit)
return toNumber(toDecimal(userStatsQuery[0].currentUsageLimit))
}
// Team/Enterprise: Verify org exists then use organization limit
const orgExists = await db
Expand Down Expand Up @@ -438,7 +440,7 @@ export async function syncUsageLimitsFromSubscription(userId: string): Promise<v
// Free/Pro: Handle individual limits
const defaultLimit = getPerUserMinimumLimit(subscription)
const currentLimit = currentStats.currentUsageLimit
? Number.parseFloat(currentStats.currentUsageLimit)
? toNumber(toDecimal(currentStats.currentUsageLimit))
: 0

if (!subscription || subscription.status !== 'active') {
Expand Down Expand Up @@ -503,9 +505,9 @@ export async function getTeamUsageLimits(organizationId: string): Promise<
userId: memberData.userId,
userName: memberData.userName,
userEmail: memberData.userEmail,
currentLimit: Number.parseFloat(memberData.currentLimit || getFreeTierLimit().toString()),
currentUsage: Number.parseFloat(memberData.currentPeriodCost || '0'),
totalCost: Number.parseFloat(memberData.totalCost || '0'),
currentLimit: toNumber(toDecimal(memberData.currentLimit || getFreeTierLimit().toString())),
currentUsage: toNumber(toDecimal(memberData.currentPeriodCost)),
totalCost: toNumber(toDecimal(memberData.totalCost)),
lastActive: memberData.lastActive,
}))
} catch (error) {
Expand All @@ -531,7 +533,7 @@ export async function getEffectiveCurrentPeriodCost(userId: string): Promise<num
.limit(1)

if (rows.length === 0) return 0
return rows[0].current ? Number.parseFloat(rows[0].current.toString()) : 0
return toNumber(toDecimal(rows[0].current))
}

// Team/Enterprise: pooled usage across org members
Expand All @@ -548,11 +550,11 @@ export async function getEffectiveCurrentPeriodCost(userId: string): Promise<num
.from(userStats)
.where(inArray(userStats.userId, memberIds))

let pooled = 0
let pooled = new Decimal(0)
for (const r of rows) {
pooled += r.current ? Number.parseFloat(r.current.toString()) : 0
pooled = pooled.plus(toDecimal(r.current))
}
return pooled
return toNumber(pooled)
}

/**
Expand Down
Loading
Loading