
Commit a83926a

feat: add reasoning_effort chat template kwarg
1 parent 7613aca commit a83926a

4 files changed: 127 additions & 4 deletions


llama_cpp/llama.py

Lines changed: 6 additions & 0 deletions
@@ -1973,6 +1973,9 @@ def create_chat_completion(
         logit_bias: Optional[Dict[int, float]] = None,
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
+        reasoning_effort: Optional[
+            Literal["none", "minimal", "low", "medium", "high", "xhigh"]
+        ] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
@@ -2005,6 +2008,8 @@ def create_chat_completion(
             logits_processor: A list of logits processors to use.
             grammar: A grammar to use.
             logit_bias: A logit bias to use.
+            reasoning_effort: Optional reasoning hint forwarded to chat handlers as a
+                chat-template keyword argument.

         Returns:
             Generated chat completion or a stream of chat completion chunks.
@@ -2044,6 +2049,7 @@ def create_chat_completion(
             logits_processor=logits_processor,
             grammar=grammar,
             logit_bias=logit_bias,
+            reasoning_effort=reasoning_effort,
         )

     def create_chat_completion_openai_v1(
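
With this change, callers can pass the hint straight through `create_chat_completion`, and it is forwarded to the chat handler as a template kwarg. A minimal usage sketch, not part of this commit; the model path is a placeholder:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path

# The new kwarg is forwarded to the chat handler, which exposes it to the
# model's chat template as the `reasoning_effort` variable.
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello, world!"}],
    reasoning_effort="high",
)
print(response["choices"][0]["message"]["content"])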

llama_cpp/llama_chat_format.py

Lines changed: 6 additions & 1 deletion
@@ -243,6 +243,7 @@ def raise_exception(message: str):
             tools=tools,
             tool_choice=tool_choice,
             strftime_now=self.strftime_now,
+            **kwargs,
         )

         stopping_criteria = None
@@ -617,6 +618,7 @@ def chat_completion_handler(
             function_call=function_call,
             tools=tools,
             tool_choice=tool_choice,
+            **kwargs,
         )
         prompt = llama.tokenize(
             result.prompt.encode("utf-8"),
@@ -734,7 +736,9 @@ def format_autotokenizer(
         **kwargs: Any,
     ) -> ChatFormatterResponse:
         tokenizer.use_default_system_prompt = False  # type: ignore
-        prompt: str = tokenizer.apply_chat_template(messages, tokenize=False)  # type: ignore
+        prompt: str = tokenizer.apply_chat_template(  # type: ignore
+            messages, tokenize=False, **kwargs
+        )
         assert isinstance(prompt, str)
         # Return formatted prompt and eos token by default
         return ChatFormatterResponse(
@@ -791,6 +795,7 @@ def format_tokenizer_config(
             messages=messages,
             bos_token=bos_token,
             eos_token=eos_token,
+            **kwargs,
         )
         return ChatFormatterResponse(
             prompt=prompt, stop=[eos_token, bos_token], added_special=True
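
On the template side, the forwarded kwarg is just another Jinja2 variable. A toy sketch mirroring the tests added below; the `default('unset')` filter keeps the template safe when no hint is supplied, since the variable is then undefined:

from llama_cpp.llama_chat_format import Jinja2ChatFormatter

# Toy template that renders the forwarded kwarg next to the first message.
formatter = Jinja2ChatFormatter(
    template="{{ reasoning_effort | default('unset') }} {{ messages[0]['content'] }}",
    bos_token="<s>",
    eos_token="</s>",
)
result = formatter(
    messages=[{"role": "user", "content": "Hello, world!"}],
    reasoning_effort="low",
)
assert result.prompt == "low Hello, world!"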

llama_cpp/server/types.py

Lines changed: 6 additions & 0 deletions
@@ -235,6 +235,12 @@ class CreateChatCompletionRequest(BaseModel):
     response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
         default=None,
     )
+    reasoning_effort: Optional[
+        Literal["none", "minimal", "low", "medium", "high", "xhigh"]
+    ] = Field(
+        default=None,
+        description="Optional reasoning-effort hint exposed to chat templates as the `reasoning_effort` keyword argument.",
+    )

     # ignored or currently unsupported
     model: Optional[str] = model_field
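
Since `CreateChatCompletionRequest` now carries the field, the OpenAI-compatible server accepts it in the JSON body of a chat completion request. A hedged sketch against a locally running `llama_cpp.server` instance; the URL and payload are illustrative:

import requests

# Assumes `python -m llama_cpp.server` is listening on the default port.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "reasoning_effort": "medium",
    },
)
print(resp.json()["choices"][0]["message"]["content"])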

tests/test_llama_chat_format.py

Lines changed: 109 additions & 3 deletions
@@ -1,12 +1,13 @@
 import json
+import inspect

 import jinja2

-from llama_cpp import (
-    ChatCompletionRequestUserMessage,
-)
+import llama_cpp
+from llama_cpp import ChatCompletionRequestUserMessage
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_chat_format as llama_chat_format
+import llama_cpp.server.types as server_types

 from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter

@@ -92,3 +93,108 @@ def test_hf_tokenizer_config_str_to_chat_formatter():
     )

     assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>")
+
+
+def test_jinja2_chat_formatter_passes_template_kwargs():
+    chat_formatter = llama_chat_format.Jinja2ChatFormatter(
+        template="{{ reasoning_effort | default('unset') }} {{ messages[0]['content'] }}",
+        bos_token="<s>",
+        eos_token="</s>",
+    )
+    response = chat_formatter(
+        messages=[
+            ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+        ],
+        reasoning_effort="low",
+    )
+
+    assert response.prompt == "low Hello, world!"
+
+
+def test_hf_tokenizer_config_chat_formatter_passes_template_kwargs():
+    tokenizer_config = {
+        "chat_template": "{{ bos_token }}{{ reasoning_effort | default('unset') }} {{ messages[0]['content'] }}",
+        "bos_token": "<s>",
+        "eos_token": "</s>",
+    }
+    chat_formatter = hf_tokenizer_config_to_chat_formatter(
+        tokenizer_config, add_generation_prompt=False
+    )
+    response = chat_formatter(
+        messages=[
+            ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+        ],
+        reasoning_effort="medium",
+    )
+
+    assert response.prompt == "<s>medium Hello, world!"
+
+
+def test_chat_completion_handler_passes_template_kwargs():
+    captured = {}
+
+    def chat_formatter(*, messages, **kwargs):
+        captured["messages"] = messages
+        captured["kwargs"] = kwargs
+        return llama_chat_format.ChatFormatterResponse(prompt="Hello")
+
+    handler = llama_chat_format.chat_formatter_to_chat_completion_handler(
+        chat_formatter
+    )
+
+    class DummyLlama:
+        verbose = False
+
+        def tokenize(self, data, add_bos, special):
+            return [1]
+
+        def create_completion(self, **kwargs):
+            return {
+                "id": "cmpl-test",
+                "object": "text_completion",
+                "created": 0,
+                "model": "dummy",
+                "choices": [
+                    {
+                        "text": "world",
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "stop",
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": 1,
+                    "completion_tokens": 1,
+                    "total_tokens": 2,
+                },
+            }
+
+    response = handler(
+        llama=DummyLlama(),
+        messages=[
+            ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+        ],
+        reasoning_effort="high",
+    )
+
+    assert response["choices"][0]["message"]["content"] == "world"
+    assert captured["kwargs"]["reasoning_effort"] == "high"
+
+
+def test_create_chat_completion_exposes_reasoning_effort_parameter():
+    parameter = inspect.signature(llama_cpp.Llama.create_chat_completion).parameters[
+        "reasoning_effort"
+    ]
+
+    assert parameter.default is None
+
+
+def test_server_chat_completion_request_accepts_reasoning_effort():
+    request = server_types.CreateChatCompletionRequest(
+        messages=[
+            ChatCompletionRequestUserMessage(role="user", content="Hello, world!")
+        ],
+        reasoning_effort="minimal",
+    )
+
+    assert request.reasoning_effort == "minimal"