Skip to content

Commit ba4e9df

Browse files
🎨 reformatted by ruff
1 parent c54e79a commit ba4e9df

2 files changed

Lines changed: 4 additions & 9 deletions

File tree

slm_server/app.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,7 @@ async def create_embeddings(
194194
"""Create embeddings using the dedicated ONNX embedding model."""
195195
with slm_embedding_span(req) as span:
196196
inputs = req.input if isinstance(req.input, list) else [req.input]
197-
vectors = await asyncio.to_thread(
198-
emb_model.encode, inputs, True
199-
)
197+
vectors = await asyncio.to_thread(emb_model.encode, inputs, True)
200198
result = EmbeddingResponse(
201199
data=[
202200
EmbeddingData(embedding=vec.tolist(), index=i)

slm_server/embedding.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
KEY_ATTENTION_MASK: str = "attention_mask"
1919
KEY_TOKEN_TYPE_IDS: str = "token_type_ids"
2020

21+
2122
class OnnxEmbeddingModel:
2223
"""Lightweight ONNX-based sentence embedding model.
2324
@@ -35,9 +36,7 @@ def __init__(self, settings: EmbeddingSettings):
3536
self.tokenizer.enable_truncation(max_length=settings.max_length)
3637
self.tokenizer.enable_padding(length=None)
3738

38-
self.session = InferenceSession(
39-
settings.onnx_path, providers=ONNX_PROVIDERS
40-
)
39+
self.session = InferenceSession(settings.onnx_path, providers=ONNX_PROVIDERS)
4140

4241
elapsed_ms = (time.monotonic() - start) * 1000
4342
logger.info(
@@ -52,9 +51,7 @@ def encode(self, texts: list[str], normalize: bool = True) -> np.ndarray:
5251
encodings = self.tokenizer.encode_batch(texts)
5352

5453
input_ids = np.array([e.ids for e in encodings], dtype=np.int64)
55-
attention_mask = np.array(
56-
[e.attention_mask for e in encodings], dtype=np.int64
57-
)
54+
attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64)
5855
token_type_ids = np.zeros_like(input_ids)
5956

6057
outputs = self.session.run(

0 commit comments

Comments
 (0)