|
1 | 1 | import json |
| 2 | +import inspect |
2 | 3 |
|
3 | 4 | import jinja2 |
4 | 5 |
|
5 | | -from llama_cpp import ( |
6 | | - ChatCompletionRequestUserMessage, |
7 | | -) |
| 6 | +import llama_cpp |
| 7 | +from llama_cpp import ChatCompletionRequestUserMessage |
8 | 8 | import llama_cpp.llama_types as llama_types |
9 | 9 | import llama_cpp.llama_chat_format as llama_chat_format |
| 10 | +import llama_cpp.server.types as server_types |
10 | 11 |
|
11 | 12 | from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter |
12 | 13 |
|
@@ -92,3 +93,108 @@ def test_hf_tokenizer_config_str_to_chat_formatter(): |
92 | 93 | ) |
93 | 94 |
|
94 | 95 | assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>") |
| 96 | + |
| 97 | + |
def test_jinja2_chat_formatter_passes_template_kwargs():
    """Extra keyword arguments given to a Jinja2ChatFormatter call must be
    exposed to the template (here as ``reasoning_effort``)."""
    formatter = llama_chat_format.Jinja2ChatFormatter(
        template="{{ reasoning_effort | default('unset') }} {{ messages[0]['content'] }}",
        bos_token="<s>",
        eos_token="</s>",
    )
    conversation = [
        ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
    ]

    result = formatter(messages=conversation, reasoning_effort="low")

    # The template saw the kwarg, so the default('unset') branch was not taken.
    assert result.prompt == "low Hello, world!"
| 112 | + |
| 113 | + |
def test_hf_tokenizer_config_chat_formatter_passes_template_kwargs():
    """A formatter built from an HF tokenizer_config must forward extra
    call-time kwargs into its chat template."""
    config = {
        "chat_template": "{{ bos_token }}{{ reasoning_effort | default('unset') }} {{ messages[0]['content'] }}",
        "bos_token": "<s>",
        "eos_token": "</s>",
    }
    formatter = hf_tokenizer_config_to_chat_formatter(config, add_generation_prompt=False)

    result = formatter(
        messages=[ChatCompletionRequestUserMessage(role="user", content="Hello, world!")],
        reasoning_effort="medium",
    )

    # bos_token is rendered by the template itself; the kwarg replaced 'unset'.
    assert result.prompt == "<s>medium Hello, world!"
| 131 | + |
| 132 | + |
def test_chat_completion_handler_passes_template_kwargs():
    """The generic handler produced by chat_formatter_to_chat_completion_handler
    must forward unknown kwargs (e.g. ``reasoning_effort``) to the formatter."""
    seen = {}

    def recording_formatter(*, messages, **extra):
        # Capture exactly what the handler forwarded so we can assert on it.
        seen["messages"] = messages
        seen["kwargs"] = extra
        return llama_chat_format.ChatFormatterResponse(prompt="Hello")

    handler = llama_chat_format.chat_formatter_to_chat_completion_handler(
        recording_formatter
    )

    class StubLlama:
        # Minimal surface the handler touches: verbose flag, tokenize, and
        # create_completion returning a canned OpenAI-style completion.
        verbose = False

        def tokenize(self, data, add_bos, special):
            return [1]

        def create_completion(self, **kwargs):
            choice = {
                "text": "world",
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",
            }
            usage = {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
            }
            return {
                "id": "cmpl-test",
                "object": "text_completion",
                "created": 0,
                "model": "dummy",
                "choices": [choice],
                "usage": usage,
            }

    result = handler(
        llama=StubLlama(),
        messages=[
            ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
        ],
        reasoning_effort="high",
    )

    assert result["choices"][0]["message"]["content"] == "world"
    assert seen["kwargs"]["reasoning_effort"] == "high"
| 182 | + |
| 183 | + |
def test_create_chat_completion_exposes_reasoning_effort_parameter():
    """Llama.create_chat_completion must declare a ``reasoning_effort``
    parameter whose default is None."""
    signature = inspect.signature(llama_cpp.Llama.create_chat_completion)
    parameter = signature.parameters["reasoning_effort"]

    assert parameter.default is None
| 190 | + |
| 191 | + |
def test_server_chat_completion_request_accepts_reasoning_effort():
    """The server request model must accept and round-trip a
    ``reasoning_effort`` field."""
    conversation = [
        ChatCompletionRequestUserMessage(role="user", content="Hello, world!")
    ]
    request = server_types.CreateChatCompletionRequest(
        messages=conversation,
        reasoning_effort="minimal",
    )

    assert request.reasoning_effort == "minimal"
0 commit comments