-
Notifications
You must be signed in to change notification settings - Fork 197
Expand file tree
/
Copy pathlitellm_provider.py
More file actions
116 lines (94 loc) · 4.17 KB
/
litellm_provider.py
File metadata and controls
116 lines (94 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""LiteLLM-based embedding provider for semantic indexing.
Routes embedding requests to 100+ providers (OpenAI, Anthropic, Google, Azure,
Bedrock, Cohere, etc.) via the litellm SDK. No proxy server needed.
Model strings use the ``provider/model`` format, e.g.
``openai/text-embedding-3-small``, ``cohere/embed-english-v3.0``,
``azure/my-embedding-deployment``.
See https://docs.litellm.ai/docs/embedding/supported_embedding for all
supported embedding models.
"""
from __future__ import annotations
import asyncio
from typing import Any
from basic_memory.repository.embedding_provider import EmbeddingProvider
from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError
class LiteLLMEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the litellm SDK.

    Sends embedding requests directly through ``litellm.aembedding`` (no proxy
    server), splitting inputs into batches and bounding the number of
    in-flight requests with a semaphore.

    Args:
        model_name: litellm ``provider/model`` string, e.g.
            ``openai/text-embedding-3-small``.
        batch_size: Maximum number of texts sent per embedding request.
        request_concurrency: Maximum number of concurrent requests.
        dimensions: Expected vector width; responses are validated against it.
        api_key: Optional explicit API key forwarded to litellm; when omitted,
            litellm resolves credentials itself (e.g. from the environment).
        timeout: Per-request timeout in seconds, forwarded to litellm.

    Raises:
        ValueError: If ``batch_size``, ``request_concurrency`` or
            ``dimensions`` is less than 1.
    """

    def __init__(
        self,
        model_name: str = "openai/text-embedding-3-small",
        *,
        batch_size: int = 64,
        request_concurrency: int = 4,
        dimensions: int = 1536,
        api_key: str | None = None,
        timeout: float = 30.0,
    ) -> None:
        # Fail fast on settings that cannot work: batch_size <= 0 would make
        # range(0, n, batch_size) raise deep inside embed_documents, and
        # Semaphore(0) would deadlock every batch silently.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if request_concurrency < 1:
            raise ValueError(
                f"request_concurrency must be >= 1, got {request_concurrency}"
            )
        if dimensions < 1:
            raise ValueError(f"dimensions must be >= 1, got {dimensions}")
        self.model_name = model_name
        self.dimensions = dimensions
        self.batch_size = batch_size
        self.request_concurrency = request_concurrency
        self._api_key = api_key
        self._timeout = timeout

    def runtime_log_attrs(self) -> dict[str, int]:
        """Return provider-specific runtime settings suitable for startup logs."""
        return {
            "provider_batch_size": self.batch_size,
            "request_concurrency": self.request_concurrency,
        }

    async def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* and return vectors in the same order as the input.

        Texts are chunked into batches of at most ``batch_size`` and embedded
        concurrently, bounded by ``request_concurrency``.

        Raises:
            SemanticDependenciesMissingError: If litellm is not installed.
            RuntimeError: If a response is missing an expected vector index,
                a batch produces no vectors, or the returned vectors do not
                match the configured ``dimensions``.
        """
        if not texts:
            return []
        # Import lazily so the provider can be constructed (and the app can
        # start) even when litellm is not installed.
        try:
            import litellm
        except ImportError as exc:
            raise SemanticDependenciesMissingError(
                "litellm dependency is missing. Install with: pip install litellm"
            ) from exc
        batches = [
            texts[start : start + self.batch_size]
            for start in range(0, len(texts), self.batch_size)
        ]
        # Results are written into fixed slots so that out-of-order batch
        # completion cannot scramble the overall output ordering.
        batch_vectors: list[list[list[float]] | None] = [None] * len(batches)
        semaphore = asyncio.Semaphore(self.request_concurrency)

        async def embed_batch(batch_index: int, batch: list[str]) -> None:
            async with semaphore:
                params: dict[str, Any] = {
                    "model": self.model_name,
                    "input": batch,
                    # drop_params lets litellm drop kwargs that a particular
                    # provider does not support instead of erroring.
                    "drop_params": True,
                    "timeout": self._timeout,
                }
                if self._api_key:
                    params["api_key"] = self._api_key
                response = await litellm.aembedding(**params)
                # Re-order by each item's declared "index" field rather than
                # trusting the list order of response.data.
                vectors_by_index: dict[int, list[float]] = {}
                for item in response.data:
                    response_index = int(item["index"])
                    vectors_by_index[response_index] = [float(v) for v in item["embedding"]]
                ordered_vectors: list[list[float]] = []
                for index in range(len(batch)):
                    vector = vectors_by_index.get(index)
                    if vector is None:
                        raise RuntimeError(
                            "LiteLLM embedding response is missing expected vector index."
                        )
                    ordered_vectors.append(vector)
                batch_vectors[batch_index] = ordered_vectors

        await asyncio.gather(
            *(embed_batch(batch_index, batch) for batch_index, batch in enumerate(batches))
        )
        all_vectors: list[list[float]] = []
        for vectors in batch_vectors:
            if vectors is None:
                raise RuntimeError("LiteLLM embedding batch did not produce vectors.")
            all_vectors.extend(vectors)
        # Guard against a model/config mismatch that would silently corrupt
        # the semantic index with wrongly-sized vectors.
        if all_vectors and len(all_vectors[0]) != self.dimensions:
            raise RuntimeError(
                f"Embedding model returned {len(all_vectors[0])}-dimensional vectors "
                f"but provider was configured for {self.dimensions} dimensions."
            )
        return all_vectors

    async def embed_query(self, text: str) -> list[float]:
        """Embed a single query string (zero-vector fallback if no result)."""
        vectors = await self.embed_documents([text])
        return vectors[0] if vectors else [0.0] * self.dimensions