 import json
 import traceback
 from http import HTTPStatus
-from typing import Annotated, AsyncGenerator
+from typing import Annotated, AsyncGenerator, Generator, Literal
 
 from fastapi import Depends, FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from slm_server.utils import (
     set_atrribute_response,
     set_atrribute_response_stream,
-    set_attribute_cancelled,
     set_attribute_response_embedding,
     slm_embedding_span,
     slm_span,
 # for single thread. Meanwhile, value larger than 1 allows
 # threads to compete for same resources.
 MAX_CONCURRENCY = 1
+# Keeps function calling and is also compatible with ReAct agents.
+CHAT_FORMAT = "chatml-function-calling"
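+# NOTE: "chatml-function-calling" is assumed to be one of the chat format
+# handlers bundled with llama-cpp-python, so no custom handler is wired up here.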
 # Default timeout message in detail field.
 DETAIL_SEM_TIMEOUT = "Server is busy, please try again later."
 # Status code for semaphore timeout.
 STATUS_CODE_SEM_TIMEOUT = HTTPStatus.REQUEST_TIMEOUT
 # Status code for unexpected errors.
 # This is used when the server encounters an error that is not handled
 STATUS_CODE_EXCEPTION = HTTPStatus.INTERNAL_SERVER_ERROR
+# Media type for streaming responses.
+STREAM_RESPONSE_MEDIA_TYPE = "text/event-stream"
 
 
 def get_llm_semaphor() -> asyncio.Semaphore:
@@ -54,11 +57,11 @@ def get_llm(settings: Annotated[Settings, Depends(get_settings)]) -> Llama:
         n_batch=settings.n_batch,
         verbose=settings.logging.verbose,
         seed=settings.seed,
+        chat_format=CHAT_FORMAT,
         logits_all=False,
         embedding=True,
         use_mlock=True,  # Use mlock to prevent memory swapping
         use_mmap=True,  # Use memory-mapped files for faster access
-        chat_format="chatml-function-calling",
     )
     return get_llm._instance
 
@@ -89,11 +92,11 @@ def get_app() -> FastAPI:
 async def lock_llm_semaphor(
     sem: Annotated[asyncio.Semaphore, Depends(get_llm_semaphor)],
     settings: Annotated[Settings, Depends(get_settings)],
-) -> AsyncGenerator[None, None]:
+) -> AsyncGenerator[Literal[True], None]:
     """Context manager to acquire and release the LLM semaphore with a timeout."""
     try:
         await asyncio.wait_for(sem.acquire(), settings.s_timeout)
-        yield None
+        yield True
     except asyncio.TimeoutError:
         raise HTTPException(
             status_code=STATUS_CODE_SEM_TIMEOUT, detail=DETAIL_SEM_TIMEOUT
@@ -103,28 +106,37 @@ async def lock_llm_semaphor(
         sem.release()
 
 
+def raise_as_http_exception() -> Generator[Literal[True], None, None]:
+    """Re-raise unexpected exceptions as HTTPException with the stack trace in detail."""
+    try:
+        yield True
+    except Exception:
+        error_str = traceback.format_exc()
+        raise HTTPException(status_code=STATUS_CODE_EXCEPTION, detail=error_str)
+
+
 async def run_llm_streaming(
     llm: Llama, req: ChatCompletionRequest
 ) -> AsyncGenerator[str, None]:
     """Generator that runs the LLM and yields SSE chunks under lock."""
     with slm_span(req, is_streaming=True) as span:
-        try:
-            completion_stream = await asyncio.to_thread(
-                llm.create_chat_completion,
-                **req.model_dump(),
-            )
+        completion_stream = await asyncio.to_thread(
+            llm.create_chat_completion,
+            **req.model_dump(),
+        )
 
-            # Use traced iterator that automatically handles chunk spans
-            # and parent span updates
-            chunk: CreateChatCompletionStreamResponse
-            for chunk in completion_stream:
-                set_atrribute_response_stream(span, chunk)
-                yield f"data: {json.dumps(chunk)}\n\n"
+        # Use traced iterator that automatically handles chunk spans
+        # and parent span updates
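+        # NOTE: create_chat_completion is assumed to return an iterator of chunk
+        # dicts here, since the dumped request carries stream=True
+        # (llama-cpp-python behavior when streaming is requested).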
+        chunk: CreateChatCompletionStreamResponse
+        for chunk in completion_stream:
+            set_atrribute_response_stream(span, chunk)
+            yield f"data: {json.dumps(chunk)}\n\n"
+            # NOTE: This is a workaround to yield control back to the event loop
+            # so the client socket can be checked after each chunk and a
+            # CancelledError can be raised on disconnect.
+            # Ref: https://github.com/encode/starlette/discussions/1776#discussioncomment-3207518
+            await asyncio.sleep(0)
 
-            yield "data: [DONE]\n\n"
-        except asyncio.CancelledError:
-            # Handle cancellation gracefully during sse.
-            set_attribute_cancelled(span)
+        yield "data: [DONE]\n\n"
 
 
 async def run_llm_non_streaming(llm: Llama, req: ChatCompletionRequest):
@@ -144,44 +156,37 @@ async def create_chat_completion(
     req: ChatCompletionRequest,
     llm: Annotated[Llama, Depends(get_llm)],
     _: Annotated[None, Depends(lock_llm_semaphor)],
+    __: Annotated[None, Depends(raise_as_http_exception)],
 ):
     """
     Generates a chat completion, handling both streaming and non-streaming cases.
     Concurrency is managed by the `locked_llm_session` context manager.
     """
-    try:
-        if req.stream:
-            return StreamingResponse(
-                run_llm_streaming(llm, req), media_type="text/event-stream"
-            )
-        else:
-            return await run_llm_non_streaming(llm, req)
-    except Exception:
-        # Catch any other unexpected errors
-        error_str = traceback.format_exc()
-        raise HTTPException(status_code=STATUS_CODE_EXCEPTION, detail=error_str)
+    if req.stream:
+        return StreamingResponse(
+            run_llm_streaming(llm, req), media_type=STREAM_RESPONSE_MEDIA_TYPE
+        )
+    else:
+        return await run_llm_non_streaming(llm, req)
 
 
 @app.post("/api/v1/embeddings")
 async def create_embeddings(
     req: EmbeddingRequest,
     llm: Annotated[Llama, Depends(get_llm)],
     _: Annotated[None, Depends(lock_llm_semaphor)],
+    __: Annotated[None, Depends(raise_as_http_exception)],
 ):
     """Create embeddings for the given input text(s)."""
-    try:
-        with slm_embedding_span(req) as span:
-            # Use llama-cpp-python's create_embedding method directly
-            embedding_result = await asyncio.to_thread(
-                llm.create_embedding,
-                **req.model_dump(),
-            )
-            # Convert llama-cpp response using model_validate like chat completion
-            set_attribute_response_embedding(span, embedding_result)
-            return embedding_result
-    except Exception:
-        error_str = traceback.format_exc()
-        raise HTTPException(status_code=STATUS_CODE_EXCEPTION, detail=error_str)
+    with slm_embedding_span(req) as span:
+        # Use llama-cpp-python's create_embedding method directly
+        embedding_result = await asyncio.to_thread(
+            llm.create_embedding,
+            **req.model_dump(),
+        )
+        # Convert llama-cpp response using model_validate like chat completion
+        set_attribute_response_embedding(span, embedding_result)
+        return embedding_result
 
 
 @app.get("/health")