|
10 | 10 | from llama_cpp import CreateChatCompletionStreamResponse, Llama |
11 | 11 |
|
12 | 12 | from slm_server.config import Settings, get_settings |
| 13 | +from slm_server.embedding import OnnxEmbeddingModel |
13 | 14 | from slm_server.logging import setup_logging |
14 | 15 | from slm_server.metrics import setup_metrics |
15 | 16 | from slm_server.model import ( |
16 | 17 | ChatCompletionRequest, |
| 18 | + EmbeddingData, |
17 | 19 | EmbeddingRequest, |
| 20 | + EmbeddingResponse, |
18 | 21 | ModelInfo, |
19 | 22 | ModelListResponse, |
20 | 23 | ) |
@@ -62,13 +65,21 @@ def get_llm(settings: Annotated[Settings, Depends(get_settings)]) -> Llama: |
62 | 65 | seed=settings.seed, |
63 | 66 | chat_format=CHAT_FORMAT, |
64 | 67 | logits_all=False, |
65 | | - embedding=True, |
66 | | - use_mlock=True, # Use mlock to prevent memory swapping |
67 | | - use_mmap=True, # Use memory-mapped files for faster access |
| 68 | + embedding=False, |
| 69 | + use_mlock=True, |
| 70 | + use_mmap=True, |
68 | 71 | ) |
69 | 72 | return get_llm._instance |
70 | 73 |
|
71 | 74 |
|
def get_embedding_model(
    settings: Annotated[Settings, Depends(get_settings)],
) -> OnnxEmbeddingModel:
    """Return the process-wide ONNX embedding model, creating it on first use.

    The instance is cached on the function object itself so the (potentially
    expensive) model load happens exactly once per process.
    """
    try:
        # Fast path: model was already constructed by an earlier request.
        return get_embedding_model._instance
    except AttributeError:
        get_embedding_model._instance = OnnxEmbeddingModel(settings.embedding)
        return get_embedding_model._instance
| 81 | + |
| 82 | + |
72 | 83 | def get_app() -> FastAPI: |
73 | 84 | # Get settings when creating app. |
74 | 85 | settings = get_settings() |
@@ -176,41 +187,53 @@ async def create_chat_completion( |
@app.post("/api/v1/embeddings")
async def create_embeddings(
    req: EmbeddingRequest,
    emb_model: Annotated[OnnxEmbeddingModel, Depends(get_embedding_model)],
    _: Annotated[None, Depends(lock_llm_semaphor)],
    __: Annotated[None, Depends(raise_as_http_exception)],
):
    """Create embeddings using the dedicated ONNX embedding model.

    Accepts a single string or a list of strings, encodes them off the event
    loop via a worker thread, and returns an OpenAI-compatible response.
    """
    with slm_embedding_span(req) as span:
        # Normalize the request payload to a list of texts.
        if isinstance(req.input, list):
            texts = req.input
        else:
            texts = [req.input]

        # Run the (CPU-bound) encoder in a thread so the event loop stays free.
        # NOTE(review): the positional True flag presumably enables
        # normalization in encode() — confirm against OnnxEmbeddingModel.
        vectors = await asyncio.to_thread(emb_model.encode, texts, True)

        payload = [
            EmbeddingData(embedding=vector.tolist(), index=position)
            for position, vector in enumerate(vectors)
        ]
        result = EmbeddingResponse(data=payload, model=emb_model.model_id)

        set_attribute_response_embedding(span, result)
        return result
193 | 207 |
|
194 | 208 |
|
def _created_timestamp(path) -> int:
    """Best-effort file mtime as a Unix timestamp; 0 when unavailable."""
    try:
        return int(Path(path).stat().st_mtime)
    except (OSError, ValueError):
        return 0


@app.get("/api/v1/models", response_model=ModelListResponse)
async def list_models(
    settings: Annotated[Settings, Depends(get_settings)],
) -> ModelListResponse:
    """List available models (OpenAI-compatible).

    Reports both the chat model (GGUF file referenced by ``model_path``) and
    the dedicated ONNX embedding model, using each file's mtime as the
    ``created`` timestamp (0 when the file cannot be stat'ed).
    """
    return ModelListResponse(
        data=[
            ModelInfo(
                id=Path(settings.model_path).stem,
                created=_created_timestamp(settings.model_path),
                owned_by=settings.model_owner,
            ),
            ModelInfo(
                id=settings.embedding.model_id,
                created=_created_timestamp(settings.embedding.onnx_path),
                # Embedding model originates from the sentence-transformers hub.
                owned_by="sentence-transformers",
            ),
        ],
    )
216 | 239 |
|
|
0 commit comments