Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies = [
"fastembed>=0.7.4",
"sqlite-vec>=0.1.6",
"openai>=1.100.2",
"litellm>=1.60.0,<2.0.0",
"logfire>=4.19.0",
"psutil>=5.9.0",
]
Expand Down
10 changes: 10 additions & 0 deletions src/basic_memory/repository/embedding_provider_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ def create_embedding_provider(app_config: BasicMemoryConfig) -> EmbeddingProvide
request_concurrency=app_config.semantic_embedding_request_concurrency,
**extra_kwargs,
)
elif provider_name == "litellm":
from basic_memory.repository.litellm_provider import LiteLLMEmbeddingProvider

model_name = app_config.semantic_embedding_model or "openai/text-embedding-3-small"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Map the built-in default model for LiteLLM

When users switch only semantic_embedding_provider to litellm, BasicMemoryConfig still supplies the non-empty default model bge-small-en-v1.5, so this or never selects the LiteLLM provider default. The factory then instantiates LiteLLMEmbeddingProvider(model_name="bge-small-en-v1.5") instead of a LiteLLM-routable model such as openai/text-embedding-3-small, making the new provider fail for the documented minimal configuration; mirror the OpenAI branch's remapping of the FastEmbed default or otherwise treat it as unset.

Useful? React with 👍 / 👎.

provider = LiteLLMEmbeddingProvider(
model_name=model_name,
batch_size=app_config.semantic_embedding_batch_size,
request_concurrency=app_config.semantic_embedding_request_concurrency,
**extra_kwargs,
)
else:
raise ValueError(f"Unsupported semantic embedding provider: {provider_name}")

Expand Down
116 changes: 116 additions & 0 deletions src/basic_memory/repository/litellm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""LiteLLM-based embedding provider for semantic indexing.

Routes embedding requests to 100+ providers (OpenAI, Anthropic, Google, Azure,
Bedrock, Cohere, etc.) via the litellm SDK. No proxy server needed.

Model strings use the ``provider/model`` format, e.g.
``openai/text-embedding-3-small``, ``cohere/embed-english-v3.0``,
``azure/my-embedding-deployment``.

See https://docs.litellm.ai/docs/embedding/supported_embedding for all
supported embedding models.
"""

from __future__ import annotations

import asyncio
from typing import Any

from basic_memory.repository.embedding_provider import EmbeddingProvider
from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError


class LiteLLMEmbeddingProvider(EmbeddingProvider):
"""Embedding provider backed by the litellm SDK."""

def __init__(
self,
model_name: str = "openai/text-embedding-3-small",
*,
batch_size: int = 64,
request_concurrency: int = 4,
dimensions: int = 1536,
api_key: str | None = None,
timeout: float = 30.0,
) -> None:
self.model_name = model_name
self.dimensions = dimensions
self.batch_size = batch_size
self.request_concurrency = request_concurrency
self._api_key = api_key
self._timeout = timeout

def runtime_log_attrs(self) -> dict[str, int]:
"""Return provider-specific runtime settings suitable for startup logs."""
return {
"provider_batch_size": self.batch_size,
"request_concurrency": self.request_concurrency,
}

async def embed_documents(self, texts: list[str]) -> list[list[float]]:
if not texts:
return []

try:
import litellm
except ImportError as exc:
raise SemanticDependenciesMissingError(
"litellm dependency is missing. Install with: pip install litellm"
) from exc

batches = [
texts[start : start + self.batch_size]
for start in range(0, len(texts), self.batch_size)
]
batch_vectors: list[list[list[float]] | None] = [None] * len(batches)
semaphore = asyncio.Semaphore(self.request_concurrency)

async def embed_batch(batch_index: int, batch: list[str]) -> None:
async with semaphore:
params: dict[str, Any] = {
"model": self.model_name,
"input": batch,
"drop_params": True,
"timeout": self._timeout,
}
if self._api_key:
params["api_key"] = self._api_key

response = await litellm.aembedding(**params)

vectors_by_index: dict[int, list[float]] = {}
for item in response.data:
response_index = int(item["index"])
vectors_by_index[response_index] = [float(v) for v in item["embedding"]]

ordered_vectors: list[list[float]] = []
for index in range(len(batch)):
vector = vectors_by_index.get(index)
if vector is None:
raise RuntimeError(
"LiteLLM embedding response is missing expected vector index."
)
ordered_vectors.append(vector)

batch_vectors[batch_index] = ordered_vectors

await asyncio.gather(
*(embed_batch(batch_index, batch) for batch_index, batch in enumerate(batches))
)

all_vectors: list[list[float]] = []
for vectors in batch_vectors:
if vectors is None:
raise RuntimeError("LiteLLM embedding batch did not produce vectors.")
all_vectors.extend(vectors)

if all_vectors and len(all_vectors[0]) != self.dimensions:
raise RuntimeError(
f"Embedding model returned {len(all_vectors[0])}-dimensional vectors "
f"but provider was configured for {self.dimensions} dimensions."
)
return all_vectors

async def embed_query(self, text: str) -> list[float]:
vectors = await self.embed_documents([text])
return vectors[0] if vectors else [0.0] * self.dimensions
181 changes: 181 additions & 0 deletions tests/repository/test_litellm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""Tests for LiteLLMEmbeddingProvider.

Uses AST parsing and direct SDK mocking to avoid importing the full
basic_memory dependency chain (logfire, alembic, etc.).
"""

import ast
import sys
import types
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import pytest

PROVIDER_PATH = (
Path(__file__).resolve().parents[2]
/ "src"
/ "basic_memory"
/ "repository"
/ "litellm_provider.py"
)
FACTORY_PATH = (
Path(__file__).resolve().parents[2]
/ "src"
/ "basic_memory"
/ "repository"
/ "embedding_provider_factory.py"
)


class TestLiteLLMProviderStructure:
"""Verify the provider file has the correct structure."""

def _parse(self):
return ast.parse(PROVIDER_PATH.read_text())

def test_file_exists(self):
assert PROVIDER_PATH.exists()

def test_has_litellm_embedding_provider_class(self):
tree = self._parse()
classes = [n.name for n in ast.walk(tree) if isinstance(n, ast.ClassDef)]
assert "LiteLLMEmbeddingProvider" in classes

def test_has_embed_documents_method(self):
tree = self._parse()
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider":
methods = [
n.name
for n in node.body
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
]
assert "embed_documents" in methods
assert "embed_query" in methods
return
pytest.fail("LiteLLMEmbeddingProvider class not found")

def test_embed_documents_is_async(self):
tree = self._parse()
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider":
for item in node.body:
if isinstance(item, ast.AsyncFunctionDef) and item.name == "embed_documents":
return
pytest.fail("embed_documents is not async")

def test_uses_drop_params_true(self):
src = PROVIDER_PATH.read_text()
assert "drop_params" in src

def test_uses_litellm_aembedding(self):
src = PROVIDER_PATH.read_text()
assert "aembedding" in src

def test_has_runtime_log_attrs(self):
tree = self._parse()
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider":
methods = [
n.name
for n in node.body
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
]
assert "runtime_log_attrs" in methods
return

def test_default_model_in_source(self):
src = PROVIDER_PATH.read_text()
assert "openai/text-embedding-3-small" in src


class TestFactoryRegistration:
"""Verify the factory recognizes litellm as a provider."""

def test_litellm_branch_in_factory(self):
src = FACTORY_PATH.read_text()
assert 'provider_name == "litellm"' in src

def test_imports_litellm_provider(self):
src = FACTORY_PATH.read_text()
assert "LiteLLMEmbeddingProvider" in src


class TestLiteLLMSDKInteraction:
"""Test litellm SDK calls directly (no basic_memory deps needed)."""

def test_aembedding_called_with_drop_params(self):
fake = types.ModuleType("litellm")
response = MagicMock()
response.data = [{"index": 0, "embedding": [0.1, 0.2]}]
fake.aembedding = AsyncMock(return_value=response)
sys.modules["litellm"] = fake

try:
import asyncio

async def run():
await fake.aembedding(
model="openai/text-embedding-3-small",
input=["hello"],
drop_params=True,
)

asyncio.run(run())
kwargs = fake.aembedding.call_args.kwargs
assert kwargs["drop_params"] is True
assert kwargs["model"] == "openai/text-embedding-3-small"
finally:
del sys.modules["litellm"]

def test_aembedding_forwards_api_key(self):
fake = types.ModuleType("litellm")
response = MagicMock()
response.data = [{"index": 0, "embedding": [0.1]}]
fake.aembedding = AsyncMock(return_value=response)
sys.modules["litellm"] = fake

try:
import asyncio

async def run():
await fake.aembedding(
model="openai/text-embedding-3-small",
input=["hello"],
api_key="sk-test",
drop_params=True,
)

asyncio.run(run())
assert fake.aembedding.call_args.kwargs["api_key"] == "sk-test"
finally:
del sys.modules["litellm"]

def test_aembedding_response_has_vectors(self):
fake = types.ModuleType("litellm")
response = MagicMock()
response.data = [
{"index": 0, "embedding": [0.1, 0.2, 0.3]},
{"index": 1, "embedding": [0.4, 0.5, 0.6]},
]
fake.aembedding = AsyncMock(return_value=response)
sys.modules["litellm"] = fake

try:
import asyncio

async def run():
resp = await fake.aembedding(
model="openai/text-embedding-3-small",
input=["hello", "world"],
drop_params=True,
)
return resp

resp = asyncio.run(run())
assert len(resp.data) == 2
assert resp.data[0]["embedding"] == [0.1, 0.2, 0.3]
assert resp.data[1]["embedding"] == [0.4, 0.5, 0.6]
finally:
del sys.modules["litellm"]