diff --git a/src/lib/balances.ts b/src/lib/balances.ts index 77f4622..0bcd915 100644 --- a/src/lib/balances.ts +++ b/src/lib/balances.ts @@ -1,5 +1,6 @@ import { getGroupExpenses } from '@/lib/api' import { Participant } from '@prisma/client' +import { match } from 'ts-pattern' export type Balances = Record< Participant['id'], @@ -19,34 +20,42 @@ export function getBalances( for (const expense of expenses) { const paidBy = expense.paidById - const paidFors = expense.paidFor.map((p) => p.participantId) + const paidFors = expense.paidFor if (!balances[paidBy]) balances[paidBy] = { paid: 0, paidFor: 0, total: 0 } balances[paidBy].paid += expense.amount balances[paidBy].total += expense.amount - paidFors.forEach((paidFor, index) => { - if (!balances[paidFor]) - balances[paidFor] = { paid: 0, paidFor: 0, total: 0 } - const dividedAmount = divide( - expense.amount, - paidFors.length, - index === paidFors.length - 1, - ) - balances[paidFor].paidFor += dividedAmount - balances[paidFor].total -= dividedAmount + const totalPaidForShares = paidFors.reduce( + (sum, paidFor) => sum + paidFor.shares, + 0, + ) + let remaining = expense.amount + paidFors.forEach((paidFor, index) => { + if (!balances[paidFor.participantId]) + balances[paidFor.participantId] = { paid: 0, paidFor: 0, total: 0 } + + const isLast = index === paidFors.length - 1 + + const [shares, totalShares] = match(expense.splitMode) + .with('EVENLY', () => [1, paidFors.length]) + .with('BY_SHARES', () => [paidFor.shares, totalPaidForShares]) + .with('BY_PERCENTAGE', () => [paidFor.shares, totalPaidForShares]) + .with('BY_AMOUNT', () => [paidFor.shares, totalPaidForShares]) + .exhaustive() + + const dividedAmount = isLast + ? remaining + : Math.floor((expense.amount * shares) / totalShares) + remaining -= dividedAmount + balances[paidFor.participantId].paidFor += dividedAmount + balances[paidFor.participantId].total -= dividedAmount }) } return balances } -function divide(total: number, count: number, isLast: boolean): number { - if (!isLast) return Math.floor(total / count) - - return total - divide(total, count, false) * (count - 1) -} - export function getSuggestedReimbursements( balances: Balances, ): Reimbursement[] {