Skip to content

Commit 773489e

Browse files
authored
fix(chat): do not retry if we had chatdeltas or tooldeltas from backend (#9244)
* fix(chat): do not retry if we had ChatDeltas or ToolDeltas from backend
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* fix: use OAI compat for llama.cpp
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* fix: apply to non-streaming path too
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* map also other fields
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent 06fbe48 commit 773489e

File tree

6 files changed

+478
-40
lines changed

6 files changed

+478
-40
lines changed

backend/cpp/llama-cpp/grpc-server.cpp

Lines changed: 96 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
13091309

13101310
body_json["messages"] = messages_json;
13111311
body_json["stream"] = true; // PredictStream is always streaming
1312+
body_json["stream_options"] = {{"include_usage", true}}; // Ensure token counts in final chunk
13121313

13131314
// Check if grammar is provided from Go layer (NoGrammar=false)
13141315
// If grammar is provided, we must use it and NOT let template generate grammar from tools
@@ -1616,8 +1617,11 @@ class BackendServiceImpl final : public backend::Backend::Service {
16161617
data);
16171618
task.id_slot = json_value(data, "id_slot", -1);
16181619

1619-
// OAI-compat
1620-
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
1620+
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
1621+
// reasoning, tool calls, and content are classified into ChatDeltas.
1622+
// Without this, the PEG parser never produces diffs and the Go side
1623+
// cannot detect tool calls or separate reasoning from content.
1624+
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
16211625
task.params.oaicompat_cmpl_id = completion_id;
16221626
// oaicompat_model is already populated by params_from_json_cmpl
16231627

@@ -1642,19 +1646,47 @@ class BackendServiceImpl final : public backend::Backend::Service {
16421646
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
16431647
}
16441648

1645-
// Lambda to build a Reply from JSON + attach chat deltas from a result
1649+
// Lambda to build a Reply from JSON + attach chat deltas from a result.
1650+
// Handles both native format ({"content": "..."}) and OAI chat format
1651+
// ({"choices": [{"delta": {"content": "...", "reasoning": "..."}}]}).
16461652
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
16471653
backend::Reply reply;
1648-
std::string completion_text = res_json.value("content", "");
1654+
std::string completion_text;
1655+
1656+
if (res_json.contains("choices")) {
1657+
// OAI chat format — extract content from choices[0].delta
1658+
const auto & choices = res_json.at("choices");
1659+
if (!choices.empty()) {
1660+
const auto & delta = choices[0].value("delta", json::object());
1661+
if (delta.contains("content") && !delta.at("content").is_null()) {
1662+
completion_text = delta.at("content").get<std::string>();
1663+
}
1664+
}
1665+
} else {
1666+
// Native llama.cpp format
1667+
completion_text = res_json.value("content", "");
1668+
}
1669+
16491670
reply.set_message(completion_text);
1650-
reply.set_tokens(res_json.value("tokens_predicted", 0));
1651-
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
16521671

1672+
// Token counts: native format has top-level fields,
1673+
// OAI format has them in "usage" (final chunk only)
1674+
if (res_json.contains("usage")) {
1675+
const auto & usage = res_json.at("usage");
1676+
reply.set_tokens(usage.value("completion_tokens", 0));
1677+
reply.set_prompt_tokens(usage.value("prompt_tokens", 0));
1678+
} else {
1679+
reply.set_tokens(res_json.value("tokens_predicted", 0));
1680+
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
1681+
}
1682+
1683+
// Timings: present as top-level "timings" in both formats
16531684
if (res_json.contains("timings")) {
16541685
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
16551686
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
16561687
}
16571688

1689+
// Logprobs: extract_logprobs_from_json handles both formats
16581690
json logprobs_json = extract_logprobs_from_json(res_json);
16591691
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
16601692
reply.set_logprobs(logprobs_json.dump());
@@ -1663,21 +1695,17 @@ class BackendServiceImpl final : public backend::Backend::Service {
16631695
return reply;
16641696
};
16651697

1698+
// Attach chat deltas from the autoparser to a Reply.
1699+
// When diffs are available, populate ChatDeltas on the reply.
1700+
// The raw message is always preserved so the Go side can use it
1701+
// for reasoning extraction and tool call parsing as a fallback
1702+
// (important in distributed mode where ChatDeltas may not be
1703+
// the primary parsing path).
16661704
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
16671705
// Try streaming partial result first
16681706
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
1669-
if (partial) {
1670-
if (!partial->oaicompat_msg_diffs.empty()) {
1671-
populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs);
1672-
} else if (partial->is_updated) {
1673-
// Autoparser is active but hasn't classified this chunk yet
1674-
// (PEG parser warming up). Clear the raw message so the Go
1675-
// side doesn't try to parse partial tag tokens (e.g. "<|channel>"
1676-
// before the full "<|channel>thought\n" is received).
1677-
// This matches llama.cpp server behavior which only emits SSE
1678-
// chunks when the parser produces diffs.
1679-
reply.set_message("");
1680-
}
1707+
if (partial && !partial->oaicompat_msg_diffs.empty()) {
1708+
populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs);
16811709
return;
16821710
}
16831711
// Try final result
@@ -2357,8 +2385,9 @@ class BackendServiceImpl final : public backend::Backend::Service {
23572385
data);
23582386
task.id_slot = json_value(data, "id_slot", -1);
23592387

2360-
// OAI-compat
2361-
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
2388+
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
2389+
// reasoning, tool calls, and content are classified into ChatDeltas.
2390+
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
23622391
task.params.oaicompat_cmpl_id = completion_id;
23632392
// oaicompat_model is already populated by params_from_json_cmpl
23642393

@@ -2389,25 +2418,48 @@ class BackendServiceImpl final : public backend::Backend::Service {
23892418
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
23902419
GGML_ASSERT(final_res != nullptr);
23912420
json result_json = all_results.results[0]->to_json();
2392-
reply->set_message(result_json.value("content", ""));
23932421

2394-
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
2422+
// Handle both native format ({"content": "...", "tokens_predicted": N})
2423+
// and OAI chat format ({"choices": [{"message": {"content": "..."}}],
2424+
// "usage": {"completion_tokens": N, "prompt_tokens": N}}).
2425+
std::string completion_text;
2426+
int32_t tokens_predicted = 0;
2427+
int32_t tokens_evaluated = 0;
2428+
2429+
if (result_json.contains("choices")) {
2430+
// OAI chat format
2431+
const auto & choices = result_json.at("choices");
2432+
if (!choices.empty()) {
2433+
const auto & msg = choices[0].value("message", json::object());
2434+
if (msg.contains("content") && !msg.at("content").is_null()) {
2435+
completion_text = msg.at("content").get<std::string>();
2436+
}
2437+
}
2438+
if (result_json.contains("usage")) {
2439+
const auto & usage = result_json.at("usage");
2440+
tokens_predicted = usage.value("completion_tokens", 0);
2441+
tokens_evaluated = usage.value("prompt_tokens", 0);
2442+
}
2443+
} else {
2444+
// Native llama.cpp format
2445+
completion_text = result_json.value("content", "");
2446+
tokens_predicted = result_json.value("tokens_predicted", 0);
2447+
tokens_evaluated = result_json.value("tokens_evaluated", 0);
2448+
}
2449+
reply->set_message(completion_text);
23952450
reply->set_tokens(tokens_predicted);
2396-
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
23972451
reply->set_prompt_tokens(tokens_evaluated);
23982452

2453+
// Timings: present in both formats as a top-level "timings" object
23992454
if (result_json.contains("timings")) {
2400-
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
2401-
reply->set_timing_prompt_processing(timing_prompt_processing);
2402-
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
2403-
reply->set_timing_token_generation(timing_token_generation);
2455+
reply->set_timing_prompt_processing(result_json.at("timings").value("prompt_ms", 0.0));
2456+
reply->set_timing_token_generation(result_json.at("timings").value("predicted_ms", 0.0));
24042457
}
24052458

2406-
// Extract and set logprobs if present
2459+
// Logprobs: extract_logprobs_from_json handles both formats
24072460
json logprobs_json = extract_logprobs_from_json(result_json);
24082461
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
2409-
std::string logprobs_str = logprobs_json.dump();
2410-
reply->set_logprobs(logprobs_str);
2462+
reply->set_logprobs(logprobs_json.dump());
24112463
}
24122464

24132465
// Populate chat deltas from the autoparser's final parsed message
@@ -2423,7 +2475,20 @@ class BackendServiceImpl final : public backend::Backend::Service {
24232475
for (auto & res : all_results.results) {
24242476
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
24252477
json res_json = res->to_json();
2426-
arr.push_back(res_json.value("content", ""));
2478+
// Handle both native and OAI chat formats
2479+
std::string result_content;
2480+
if (res_json.contains("choices")) {
2481+
const auto & choices = res_json.at("choices");
2482+
if (!choices.empty()) {
2483+
const auto & msg = choices[0].value("message", json::object());
2484+
if (msg.contains("content") && !msg.at("content").is_null()) {
2485+
result_content = msg.at("content").get<std::string>();
2486+
}
2487+
}
2488+
} else {
2489+
result_content = res_json.value("content", "");
2490+
}
2491+
arr.push_back(result_content);
24272492

24282493
// Extract logprobs for each result
24292494
json logprobs_json = extract_logprobs_from_json(res_json);

core/http/endpoints/openai/chat.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,23 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
147147
result := ""
148148
lastEmittedCount := 0
149149
sentInitialRole := false
150+
hasChatDeltaToolCalls := false
151+
hasChatDeltaContent := false
150152

151153
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
152154
result += s
153155

156+
// Track whether ChatDeltas from the C++ autoparser contain
157+
// tool calls or content, so the retry decision can account for them.
158+
for _, d := range usage.ChatDeltas {
159+
if len(d.ToolCalls) > 0 {
160+
hasChatDeltaToolCalls = true
161+
}
162+
if d.Content != "" {
163+
hasChatDeltaContent = true
164+
}
165+
}
166+
154167
var reasoningDelta, contentDelta string
155168

156169
goReasoning, goContent := extractor.ProcessToken(s)
@@ -309,15 +322,22 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
309322
// After streaming completes: check if we got actionable content
310323
cleaned := extractor.CleanedContent()
311324
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
312-
// but we need to know here whether to retry)
313-
hasToolCalls := lastEmittedCount > 0
314-
if cleaned == "" && !hasToolCalls {
325+
// but we need to know here whether to retry).
326+
// Also check ChatDelta flags — when the C++ autoparser is active,
327+
// tool calls and content are delivered via ChatDeltas while the
328+
// raw message is cleared. Without this check, we'd retry
329+
// unnecessarily, losing valid results and concatenating output.
330+
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
331+
hasContent := cleaned != "" || hasChatDeltaContent
332+
if !hasContent && !hasToolCalls {
315333
xlog.Warn("Streaming: backend produced only reasoning, retrying",
316334
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
317335
extractor.ResetAndSuppressReasoning()
318336
result = ""
319337
lastEmittedCount = 0
320338
sentInitialRole = false
339+
hasChatDeltaToolCalls = false
340+
hasChatDeltaContent = false
321341
return true
322342
}
323343
return false

core/http/endpoints/openai/inference.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,23 @@ func ComputeChoices(
113113
}
114114
prediction = p
115115

116-
// Built-in: retry on truly empty response (no tokens at all)
116+
// Built-in: retry on truly empty response (no tokens at all).
117+
// However, when the C++ autoparser is active, it clears the raw
118+
// message and delivers content via ChatDeltas instead. Do NOT
119+
// retry if ChatDeltas contain tool calls or content.
117120
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries {
118-
xlog.Warn("Backend returned empty response, retrying",
119-
"attempt", attempt+1, "maxRetries", maxRetries)
120-
continue
121+
hasChatDeltaData := false
122+
for _, d := range prediction.ChatDeltas {
123+
if d.Content != "" || len(d.ToolCalls) > 0 {
124+
hasChatDeltaData = true
125+
break
126+
}
127+
}
128+
if !hasChatDeltaData {
129+
xlog.Warn("Backend returned empty response, retrying",
130+
"attempt", attempt+1, "maxRetries", maxRetries)
131+
continue
132+
}
121133
}
122134

123135
tokenUsage.Prompt = prediction.Usage.Prompt
@@ -130,8 +142,21 @@ func ComputeChoices(
130142
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
131143
cb(finetunedResponse, &result)
132144

133-
// Caller-driven retry (tool parsing, reasoning-only, etc.)
134-
if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries {
145+
// Caller-driven retry (tool parsing, reasoning-only, etc.).
146+
// When the C++ autoparser is active, it clears the raw response
147+
// and delivers data via ChatDeltas. If the response is empty but
148+
// ChatDeltas contain actionable data, skip the caller retry —
149+
// the autoparser already parsed the response successfully.
150+
skipCallerRetry := false
151+
if strings.TrimSpace(prediction.Response) == "" && len(prediction.ChatDeltas) > 0 {
152+
for _, d := range prediction.ChatDeltas {
153+
if d.Content != "" || len(d.ToolCalls) > 0 {
154+
skipCallerRetry = true
155+
break
156+
}
157+
}
158+
}
159+
if shouldRetryFn != nil && !skipCallerRetry && shouldRetryFn(attempt) && attempt < maxRetries {
135160
// Caller has already reset its state inside shouldRetry
136161
result = result[:0]
137162
allChatDeltas = nil

tests/e2e/e2e_suite_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,25 @@ var _ = BeforeSuite(func() {
101101
Expect(err).ToNot(HaveOccurred())
102102
Expect(os.WriteFile(configPath, configYAML, 0644)).To(Succeed())
103103

104+
// Create model config for autoparser tests (NoGrammar so tool calls
105+
// are driven entirely by the backend's ChatDeltas, not grammar enforcement)
106+
autoparserConfig := map[string]any{
107+
"name": "mock-model-autoparser",
108+
"backend": "mock-backend",
109+
"parameters": map[string]any{
110+
"model": "mock-model.bin",
111+
},
112+
"function": map[string]any{
113+
"grammar": map[string]any{
114+
"disable": true,
115+
},
116+
},
117+
}
118+
autoparserPath := filepath.Join(modelsPath, "mock-model-autoparser.yaml")
119+
autoparserYAML, err := yaml.Marshal(autoparserConfig)
120+
Expect(err).ToNot(HaveOccurred())
121+
Expect(os.WriteFile(autoparserPath, autoparserYAML, 0644)).To(Succeed())
122+
104123
// Start mock MCP server and create MCP-enabled model config
105124
mcpServerURL, mcpServerShutdown = startMockMCPServer()
106125
mcpConfig := mcpModelConfig(mcpServerURL)

0 commit comments

Comments
 (0)