Skip to content

Commit c617756

Browse files
Merge pull request #13 from XyLearningProgramming/feature/embed-allmini
✨ Added all-MiniLM-L6-v2 as a dedicated embedding model
2 parents cafc6b5 + ba4e9df commit c617756

14 files changed

Lines changed: 739 additions & 462 deletions

File tree

Makefile

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
.PHONY: dev run download install lint format check test smoke clean help

help: ## Show this help
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "  \033[36m%-15s\033[0m %s\n", $$1, $$2}'

install: ## Install dependencies (including dev)
	uv sync

download: ## Download model files
	bash scripts/download.sh

dev: ## Start dev server with auto-reload
	uv run uvicorn slm_server.app:app --reload --host 0.0.0.0 --port 8000

run: ## Start server via start.sh
	bash scripts/start.sh

lint: ## Run ruff linter
	uv run ruff check slm_server/

format: ## Run ruff formatter
	uv run ruff format slm_server/

check: lint ## Run linter + formatter check
	uv run ruff format --check slm_server/

smoke: ## Smoke-test the running server APIs with curl
	bash scripts/smoke.sh

test: ## Run tests with coverage
	uv run pytest tests/ -v --cov=slm_server --cov-report=term-missing

clean: ## Remove caches and build artifacts
	rm -rf __pycache__ .pytest_cache .ruff_cache .coverage htmlcov build dist *.egg-info
	find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true

deploy/helm/values.yaml

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,35 @@ env: {}
7979

8080
# Resource requests and limits for the container.
8181
# See https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/
82-
# Tuned for Qwen3-0.6B-Q4_K_M (484 MB) + n_ctx=8192 KV cache (~448 MB) on 1-CPU / 1 GB VPS nodes.
82+
#
83+
# Memory budget breakdown (target node: 1-CPU / 2 GB VPS):
84+
# Chat LLM – Qwen3-0.6B-Q4_K_M.gguf ~484 MB (4-bit quantised)
85+
# Embedding – all-MiniLM-L6-v2 quint8 ONNX ~23 MB (uint8 AVX2 quantised)
86+
# KV cache – n_ctx=2048 ~50-80 MB
87+
# Runtime – Python, FastAPI, onnxruntime ~50-100 MB
88+
# -------------------------------------------------------
89+
# Request: 550 Mi    Hard limit: 1 Gi   (component sum is ~610-690 MB; see below)
90+
#
91+
# Why these models:
92+
# - Qwen3-0.6B-Q4_K_M is the smallest instruction-tuned LLM that still
93+
# supports function calling (chatml format) at usable quality.
94+
# - all-MiniLM-L6-v2 (384-dim, 6-layer) is purpose-trained for sentence
95+
# embeddings via mean pooling, ranking well on STS benchmarks for its
96+
# size. The quint8 AVX2 variant keeps the file at 23 MB vs 90 MB fp32.
97+
#
98+
# Why the limit is reasonable:
99+
# - The worker node (active-nerd-2) has 2 GiB total RAM shared with the
100+
# OS and other pods. 550 Mi request leaves headroom; the 1 Gi hard
101+
# limit prevents OOM-kill from bursty KV-cache growth.
102+
# - MAX_CONCURRENCY=1 ensures only one inference runs at a time, so peak
103+
# memory is predictable (no concurrent KV-cache allocations).
83104
resources:
84105
limits:
85-
cpu: 1
106+
cpu: 900m
86107
memory: 1Gi
87108
requests:
88-
cpu: 200m
89-
memory: 600Mi
109+
cpu: 50m
110+
memory: 550Mi
90111

91112
# Readiness and liveness probes configuration
92113
probes:

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ dependencies = [
1818
"prometheus-client>=0.22.1",
1919
"prometheus-fastapi-instrumentator>=7.1.0",
2020
"psutil>=6.1.0",
21+
"onnxruntime>=1.17.0",
22+
"tokenizers>=0.21.0",
2123
]
2224

2325
[tool.ruff.lint]

scripts/download.sh

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,30 @@ for file in "${FILES_TO_DOWNLOAD[@]}"; do
3636
fi
3737
done
3838

# --- Embedding model: all-MiniLM-L6-v2 (ONNX, quantized UINT8 for AVX2) ---
EMBEDDING_REPO_URL="https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_MODEL_DIR="$MODEL_DIR/all-MiniLM-L6-v2"

mkdir -p "$EMBEDDING_MODEL_DIR/onnx"

EMBEDDING_FILES=(
  "onnx/model_quint8_avx2.onnx"
  "tokenizer.json"
)

echo "Downloading all-MiniLM-L6-v2 ONNX embedding model..."

for file in "${EMBEDDING_FILES[@]}"; do
  dest="$EMBEDDING_MODEL_DIR/$file"
  # Skip only NON-EMPTY files: a previously failed transfer may have left a
  # zero-byte $dest, which a plain -f existence check would wrongly skip.
  if [ -s "$dest" ]; then
    echo "$file already exists, skipping download."
  else
    echo "Downloading $file..."
    url="$EMBEDDING_REPO_URL/resolve/main/$file"
    tmp="$dest.part"
    # Download to a temp file and move into place atomically so a failed or
    # truncated transfer never leaves a bad $dest behind. curl gets -f so an
    # HTTP error (e.g. 404 page) is a failure instead of saved HTML.
    if wget -O "$tmp" "$url" || {
      echo "Failed to download $file with wget, trying curl..."
      curl -fL -o "$tmp" "$url"
    }; then
      mv "$tmp" "$dest"
    else
      rm -f "$tmp"
      echo "ERROR: failed to download $file" >&2
      exit 1
    fi
  fi
done
3965
echo "Download process complete! Files are in $MODEL_DIR"

scripts/smoke.sh

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
#!/bin/bash
# Smoke-test the running slm_server HTTP API with curl.
# The server must already be up; override the target with BASE_URL.

# -e: stop at the first failing command; -u: unset variables are errors;
# -o pipefail: a failing curl inside `curl | python3 -m json.tool` fails the
# whole pipeline instead of being masked by json.tool's exit status.
set -euo pipefail

BASE_URL="${BASE_URL:-http://localhost:8000}"

echo "=== Health check ==="
curl -sf "$BASE_URL/health"
echo

echo "=== List models ==="
curl -sf "$BASE_URL/api/v1/models" | python3 -m json.tool
echo

echo "=== Chat completion ==="
curl -sf "$BASE_URL/api/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "max_tokens": 64
  }' | python3 -m json.tool
echo

echo "=== Chat completion (streaming) ==="
curl -sf "$BASE_URL/api/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "What is 2+2?"}],
    "max_tokens": 32,
    "stream": true
  }'
echo

echo "=== Embeddings (single) ==="
curl -sf "$BASE_URL/api/v1/embeddings" \
  -H "Content-Type: application/json" \
  -d '{
    "input": "Hello world"
  }' | python3 -m json.tool
echo

echo "=== Embeddings (batch) ==="
curl -sf "$BASE_URL/api/v1/embeddings" \
  -H "Content-Type: application/json" \
  -d '{
    "input": ["The cat sat on the mat.", "A dog played in the park."]
  }' | python3 -m json.tool
echo

echo "All smoke tests passed."

slm_server/app.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
from llama_cpp import CreateChatCompletionStreamResponse, Llama
1111

1212
from slm_server.config import Settings, get_settings
13+
from slm_server.embedding import OnnxEmbeddingModel
1314
from slm_server.logging import setup_logging
1415
from slm_server.metrics import setup_metrics
1516
from slm_server.model import (
1617
ChatCompletionRequest,
18+
EmbeddingData,
1719
EmbeddingRequest,
20+
EmbeddingResponse,
1821
ModelInfo,
1922
ModelListResponse,
2023
)
@@ -62,13 +65,21 @@ def get_llm(settings: Annotated[Settings, Depends(get_settings)]) -> Llama:
6265
seed=settings.seed,
6366
chat_format=CHAT_FORMAT,
6467
logits_all=False,
65-
embedding=True,
66-
use_mlock=True, # Use mlock to prevent memory swapping
67-
use_mmap=True, # Use memory-mapped files for faster access
68+
embedding=False,
69+
use_mlock=True,
70+
use_mmap=True,
6871
)
6972
return get_llm._instance
7073

7174

def get_embedding_model(
    settings: Annotated[Settings, Depends(get_settings)],
) -> OnnxEmbeddingModel:
    """Return the process-wide ONNX embedding model, building it on first use.

    Cached as a function attribute so the ONNX session is created once per
    process, mirroring the lazy-singleton pattern used by ``get_llm``.
    """
    model = getattr(get_embedding_model, "_instance", None)
    if model is None:
        model = OnnxEmbeddingModel(settings.embedding)
        get_embedding_model._instance = model
    return model


7283
def get_app() -> FastAPI:
7384
# Get settings when creating app.
7485
settings = get_settings()
@@ -176,41 +187,53 @@ async def create_chat_completion(
176187
@app.post("/api/v1/embeddings")
177188
async def create_embeddings(
178189
req: EmbeddingRequest,
179-
llm: Annotated[Llama, Depends(get_llm)],
190+
emb_model: Annotated[OnnxEmbeddingModel, Depends(get_embedding_model)],
180191
_: Annotated[None, Depends(lock_llm_semaphor)],
181192
__: Annotated[None, Depends(raise_as_http_exception)],
182193
):
183-
"""Create embeddings for the given input text(s)."""
194+
"""Create embeddings using the dedicated ONNX embedding model."""
184195
with slm_embedding_span(req) as span:
185-
# Use llama-cpp-python's create_embedding method directly
186-
embedding_result = await asyncio.to_thread(
187-
llm.create_embedding,
188-
**req.model_dump(),
196+
inputs = req.input if isinstance(req.input, list) else [req.input]
197+
vectors = await asyncio.to_thread(emb_model.encode, inputs, True)
198+
result = EmbeddingResponse(
199+
data=[
200+
EmbeddingData(embedding=vec.tolist(), index=i)
201+
for i, vec in enumerate(vectors)
202+
],
203+
model=emb_model.model_id,
189204
)
190-
# Convert llama-cpp response using model_validate like chat completion
191-
set_attribute_response_embedding(span, embedding_result)
192-
return embedding_result
205+
set_attribute_response_embedding(span, result)
206+
return result
193207

194208

195209
@app.get("/api/v1/models", response_model=ModelListResponse)
196210
async def list_models(
197211
settings: Annotated[Settings, Depends(get_settings)],
198212
) -> ModelListResponse:
199-
"""List available models (OpenAI-compatible). Returns the single loaded model."""
200-
model_id = Path(settings.model_path).stem
213+
"""List available models (OpenAI-compatible)."""
214+
chat_model_id = Path(settings.model_path).stem
201215
try:
202-
created = int(Path(settings.model_path).stat().st_mtime)
216+
chat_created = int(Path(settings.model_path).stat().st_mtime)
203217
except (OSError, ValueError):
204-
created = 0
218+
chat_created = 0
219+
220+
try:
221+
emb_created = int(Path(settings.embedding.onnx_path).stat().st_mtime)
222+
except (OSError, ValueError):
223+
emb_created = 0
224+
205225
return ModelListResponse(
206-
object="list",
207226
data=[
208227
ModelInfo(
209-
id=model_id,
210-
object="model",
211-
created=created,
228+
id=chat_model_id,
229+
created=chat_created,
212230
owned_by=settings.model_owner,
213-
)
231+
),
232+
ModelInfo(
233+
id=settings.embedding.model_id,
234+
created=emb_created,
235+
owned_by="sentence-transformers",
236+
),
214237
],
215238
)
216239

slm_server/config.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,34 @@
1616
MODEL_PATH_DEFAULT = str(MODELS_DIR / "Qwen3-0.6B-Q4_K_M.gguf")
1717
MODEL_OWNER_DEFAULT = "second-state"
1818

# Default on-disk locations for the embedding model artefacts, as laid out by
# scripts/download.sh.
_EMBEDDING_MODEL_DIR = MODELS_DIR / "all-MiniLM-L6-v2"

EMBEDDING_TOKENIZER_PATH_DEFAULT = str(_EMBEDDING_MODEL_DIR / "tokenizer.json")
EMBEDDING_ONNX_PATH_DEFAULT = str(
    _EMBEDDING_MODEL_DIR / "onnx" / "model_quint8_avx2.onnx"
)


class EmbeddingSettings(BaseModel):
    """Settings for the dedicated ONNX sentence-embedding model."""

    # Name surfaced through the OpenAI-compatible API responses.
    model_id: str = Field(
        "all-MiniLM-L6-v2",
        description="Model identifier returned in API responses.",
    )
    # HuggingFace `tokenizers` JSON file used to tokenize inputs.
    tokenizer_path: str = Field(
        EMBEDDING_TOKENIZER_PATH_DEFAULT,
        description="Full path to the tokenizer.json file.",
    )
    # Quantised ONNX graph loaded by onnxruntime.
    onnx_path: str = Field(
        EMBEDDING_ONNX_PATH_DEFAULT,
        description="Full path to the ONNX model file.",
    )
    # Inputs longer than this are truncated at tokenization time.
    max_length: int = Field(
        256,
        description="Maximum token sequence length for the tokenizer. "
        "all-MiniLM-L6-v2 was trained with 256; increase only if "
        "swapping to a model that supports longer sequences.",
    )

1947

2048
class LoggingSettings(BaseModel):
2149
verbose: bool = Field(True, description="If logging to stdout by cpp llama")
@@ -75,6 +103,7 @@ class Settings(BaseSettings):
75103
1, description="Seconds to wait if undergoing another inference."
76104
)
77105

106+
embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
78107
logging: LoggingSettings = Field(default_factory=LoggingSettings)
79108
metrics: MetricsSettings = Field(default_factory=MetricsSettings)
80109
tracing: TraceSettings = Field(default_factory=TraceSettings)

0 commit comments

Comments
 (0)