Skip to content

Commit 3b3b5b7

Browse files
committed
fix(knowledge): record embedding usage cost for KB document processing
Adds billing tracking to the KB embedding pipeline, which was previously generating OpenAI API calls with no cost recorded. Token counts are now captured from the actual API response and recorded via recordUsage after successful embedding insertion. BYOK workspaces are excluded from billing. Applies to all execution paths: direct, BullMQ, and Trigger.dev.
1 parent f016eb3 commit 3b3b5b7

6 files changed

Lines changed: 80 additions & 13 deletions

File tree

apps/sim/app/api/knowledge/utils.test.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ vi.stubGlobal(
7777
{ embedding: [0.1, 0.2], index: 0 },
7878
{ embedding: [0.3, 0.4], index: 1 },
7979
],
80+
usage: { prompt_tokens: 2, total_tokens: 2 },
8081
}),
8182
})
8283
)
@@ -294,7 +295,7 @@ describe('Knowledge Utils', () => {
294295
it.concurrent('should return same length as input', async () => {
295296
const result = await generateEmbeddings(['a', 'b'])
296297

297-
expect(result.length).toBe(2)
298+
expect(result.embeddings.length).toBe(2)
298299
})
299300

300301
it('should use Azure OpenAI when Azure config is provided', async () => {
@@ -313,6 +314,7 @@ describe('Knowledge Utils', () => {
313314
ok: true,
314315
json: async () => ({
315316
data: [{ embedding: [0.1, 0.2], index: 0 }],
317+
usage: { prompt_tokens: 1, total_tokens: 1 },
316318
}),
317319
} as any)
318320

@@ -342,6 +344,7 @@ describe('Knowledge Utils', () => {
342344
ok: true,
343345
json: async () => ({
344346
data: [{ embedding: [0.1, 0.2], index: 0 }],
347+
usage: { prompt_tokens: 1, total_tokens: 1 },
345348
}),
346349
} as any)
347350

apps/sim/lib/billing/core/usage-log.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ export type UsageLogSource =
2121
| 'workspace-chat'
2222
| 'mcp_copilot'
2323
| 'mothership_block'
24+
| 'knowledge-base'
2425

2526
/**
2627
* Metadata for 'model' category charges

apps/sim/lib/chunkers/docs-chunker.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ export class DocsChunker {
8181
const textChunks = await this.splitContent(markdownContent)
8282

8383
logger.info(`Generating embeddings for ${textChunks.length} chunks in ${relativePath}`)
84-
const embeddings = textChunks.length > 0 ? await generateEmbeddings(textChunks) : []
84+
const { embeddings } =
85+
textChunks.length > 0
86+
? await generateEmbeddings(textChunks)
87+
: { embeddings: [] as number[][] }
8588
const embeddingModel = 'text-embedding-3-small'
8689

8790
const chunks: DocChunk[] = []

apps/sim/lib/knowledge/chunks/service.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ export async function createChunk(
110110
workspaceId?: string | null
111111
): Promise<ChunkData> {
112112
logger.info(`[${requestId}] Generating embedding for manual chunk`)
113-
const embeddings = await generateEmbeddings([chunkData.content], undefined, workspaceId)
113+
const { embeddings } = await generateEmbeddings([chunkData.content], undefined, workspaceId)
114114

115115
// Calculate accurate token count
116116
const tokenCount = estimateTokenCount(chunkData.content, 'openai')
@@ -359,7 +359,7 @@ export async function updateChunk(
359359
if (content !== currentChunk[0].content) {
360360
logger.info(`[${requestId}] Content changed, regenerating embedding for chunk ${chunkId}`)
361361

362-
const embeddings = await generateEmbeddings([content], undefined, workspaceId)
362+
const { embeddings } = await generateEmbeddings([content], undefined, workspaceId)
363363

364364
// Calculate accurate token count
365365
const tokenCount = estimateTokenCount(content, 'openai')

apps/sim/lib/knowledge/documents/service.ts

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ import {
2525
type SQL,
2626
sql,
2727
} from 'drizzle-orm'
28+
import { recordUsage } from '@/lib/billing/core/usage-log'
2829
import { createBullMQJobData, isBullMQEnabled } from '@/lib/core/bullmq'
2930
import { env } from '@/lib/core/config/env'
30-
import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
31+
import { getCostMultiplier, isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
3132
import { enqueueWorkspaceDispatch } from '@/lib/core/workspace-dispatch'
3233
import { processDocument } from '@/lib/knowledge/documents/document-processor'
3334
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
@@ -43,6 +44,7 @@ import type { ProcessedDocumentTags } from '@/lib/knowledge/types'
4344
import { deleteFile } from '@/lib/uploads/core/storage-service'
4445
import { extractStorageKey } from '@/lib/uploads/utils/file-utils'
4546
import type { DocumentProcessingPayload } from '@/background/knowledge-processing'
47+
import { getEmbeddingModelPricing } from '@/providers/models'
4648

4749
const logger = createLogger('DocumentService')
4850

@@ -460,6 +462,9 @@ export async function processDocumentAsync(
460462
overlap: rawConfig?.overlap ?? 200,
461463
}
462464

465+
let totalEmbeddingTokens = 0
466+
let embeddingIsBYOK = false
467+
463468
await withTimeout(
464469
(async () => {
465470
const processed = await processDocument(
@@ -500,10 +505,16 @@ export async function processDocumentAsync(
500505
const batchNum = Math.floor(i / batchSize) + 1
501506

502507
logger.info(`[${documentId}] Processing embedding batch ${batchNum}/${totalBatches}`)
503-
const batchEmbeddings = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
508+
const {
509+
embeddings: batchEmbeddings,
510+
totalTokens: batchTokens,
511+
isBYOK,
512+
} = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
504513
for (const emb of batchEmbeddings) {
505514
embeddings.push(emb)
506515
}
516+
totalEmbeddingTokens += batchTokens
517+
embeddingIsBYOK = isBYOK
507518
}
508519
}
509520

@@ -638,6 +649,34 @@ export async function processDocumentAsync(
638649

639650
const processingTime = Date.now() - startTime
640651
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)
652+
653+
if (!embeddingIsBYOK && totalEmbeddingTokens > 0 && kb[0].userId) {
654+
try {
655+
const embeddingModel = 'text-embedding-3-small'
656+
const pricing = getEmbeddingModelPricing(embeddingModel)
657+
if (pricing) {
658+
const cost = (totalEmbeddingTokens / 1_000_000) * pricing.input * getCostMultiplier()
659+
await recordUsage({
660+
userId: kb[0].userId,
661+
workspaceId: kb[0].workspaceId ?? undefined,
662+
entries: [
663+
{
664+
category: 'model',
665+
source: 'knowledge-base',
666+
description: embeddingModel,
667+
cost,
668+
metadata: { inputTokens: totalEmbeddingTokens, outputTokens: 0 },
669+
},
670+
],
671+
additionalStats: {
672+
totalTokensUsed: sql`total_tokens_used + ${totalEmbeddingTokens}`,
673+
},
674+
})
675+
}
676+
} catch (billingError) {
677+
logger.error(`[${documentId}] Failed to record embedding usage`, { error: billingError })
678+
}
679+
}
641680
} catch (error) {
642681
const processingTime = Date.now() - startTime
643682
const errorMessage = error instanceof Error ? error.message : 'Unknown error'

apps/sim/lib/knowledge/embeddings.ts

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ interface EmbeddingConfig {
3535
apiUrl: string
3636
headers: Record<string, string>
3737
modelName: string
38+
isBYOK: boolean
3839
}
3940

4041
interface EmbeddingResponseItem {
@@ -71,16 +72,19 @@ async function getEmbeddingConfig(
7172
'Content-Type': 'application/json',
7273
},
7374
modelName: kbModelName,
75+
isBYOK: false,
7476
}
7577
}
7678

7779
let openaiApiKey = env.OPENAI_API_KEY
80+
let isBYOK = false
7881

7982
if (workspaceId) {
8083
const byokResult = await getBYOKKey(workspaceId, 'openai')
8184
if (byokResult) {
8285
logger.info('Using workspace BYOK key for OpenAI embeddings')
8386
openaiApiKey = byokResult.apiKey
87+
isBYOK = true
8488
}
8589
}
8690

@@ -98,12 +102,16 @@ async function getEmbeddingConfig(
98102
'Content-Type': 'application/json',
99103
},
100104
modelName: embeddingModel,
105+
isBYOK,
101106
}
102107
}
103108

104109
const EMBEDDING_REQUEST_TIMEOUT_MS = 60_000
105110

106-
async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Promise<number[][]> {
111+
async function callEmbeddingAPI(
112+
inputs: string[],
113+
config: EmbeddingConfig
114+
): Promise<{ embeddings: number[][]; totalTokens: number }> {
107115
return retryWithExponentialBackoff(
108116
async () => {
109117
const useDimensions = supportsCustomDimensions(config.modelName)
@@ -140,7 +148,10 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
140148
}
141149

142150
const data: EmbeddingAPIResponse = await response.json()
143-
return data.data.map((item) => item.embedding)
151+
return {
152+
embeddings: data.data.map((item) => item.embedding),
153+
totalTokens: data.usage.total_tokens,
154+
}
144155
},
145156
{
146157
maxRetries: 3,
@@ -178,14 +189,22 @@ async function processWithConcurrency<T, R>(
178189
return results
179190
}
180191

192+
export interface GenerateEmbeddingsResult {
193+
embeddings: number[][]
194+
totalTokens: number
195+
isBYOK: boolean
196+
}
197+
181198
/**
182-
* Generate embeddings for multiple texts with token-aware batching and parallel processing
199+
* Generate embeddings for multiple texts with token-aware batching and parallel processing.
200+
* Returns embeddings alongside the actual token count from the API and whether a BYOK key was used.
201+
* Callers should use `totalTokens` and `isBYOK` to record billing via `recordUsage`.
183202
*/
184203
export async function generateEmbeddings(
185204
texts: string[],
186205
embeddingModel = 'text-embedding-3-small',
187206
workspaceId?: string | null
188-
): Promise<number[][]> {
207+
): Promise<GenerateEmbeddingsResult> {
189208
const config = await getEmbeddingConfig(embeddingModel, workspaceId)
190209

191210
const batches = batchByTokenLimit(texts, MAX_TOKENS_PER_REQUEST, embeddingModel)
@@ -204,13 +223,15 @@ export async function generateEmbeddings(
204223
)
205224

206225
const allEmbeddings: number[][] = []
226+
let totalTokens = 0
207227
for (const batch of batchResults) {
208-
for (const emb of batch) {
228+
for (const emb of batch.embeddings) {
209229
allEmbeddings.push(emb)
210230
}
231+
totalTokens += batch.totalTokens
211232
}
212233

213-
return allEmbeddings
234+
return { embeddings: allEmbeddings, totalTokens, isBYOK: config.isBYOK }
214235
}
215236

216237
/**
@@ -227,6 +248,6 @@ export async function generateSearchEmbedding(
227248
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for search embedding generation`
228249
)
229250

230-
const embeddings = await callEmbeddingAPI([query], config)
251+
const { embeddings } = await callEmbeddingAPI([query], config)
231252
return embeddings[0]
232253
}

0 commit comments

Comments (0)