Skip to content

Commit ad2a2bc

Browse files
committed
Centralize model-change TUI notification in the main loop
Add a change-detection mechanism (lastEmittedModelID + emitModelInfo closure) in RunStream that automatically emits AgentInfo only when the effective model actually changes. This is checked before and after each LLM call, covering per-tool overrides, fallback, model picker, cooldowns, and any future model-switching feature — without each one having to remember to notify the TUI. - loop.go: replace 3 scattered manual AgentInfo emissions with emitModelInfo calls driven by the closure - model_picker.go: remove AgentInfo emission from the tool handler; rename setModelAndEmitInfo to setCurrentAgentModel (no longer emits) - agent_delegation.go: use getEffectiveModelID instead of getAgentModelID so agent-switch events reflect active fallback cooldowns Assisted-By: docker-agent
1 parent 8ad9626 commit ad2a2bc

4 files changed

Lines changed: 153 additions & 30 deletions

File tree

pkg/runtime/agent_delegation.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,13 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses
168168

169169
// Restore original agent info in sidebar
170170
if originalAgent, err := r.team.Agent(ca); err == nil {
171-
evts <- AgentInfo(originalAgent.Name(), getAgentModelID(originalAgent), originalAgent.Description(), originalAgent.WelcomeMessage())
171+
evts <- AgentInfo(originalAgent.Name(), r.getEffectiveModelID(originalAgent), originalAgent.Description(), originalAgent.WelcomeMessage())
172172
}
173173
}()
174174

175175
// Emit agent info for the new agent
176176
if newAgent, err := r.team.Agent(params.Agent); err == nil {
177-
evts <- AgentInfo(newAgent.Name(), getAgentModelID(newAgent), newAgent.Description(), newAgent.WelcomeMessage())
177+
evts <- AgentInfo(newAgent.Name(), r.getEffectiveModelID(newAgent), newAgent.Description(), newAgent.WelcomeMessage())
178178
}
179179

180180
slog.Debug("Creating new session with parent session", "parent_session_id", sess.ID, "tools_approved", sess.ToolsApproved, "thinking", sess.Thinking)

pkg/runtime/loop.go

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,21 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
9191

9292
a := r.resolveSessionAgent(sess)
9393

94+
// lastEmittedModelID tracks what the TUI currently displays.
95+
// emitModelInfo sends an AgentInfo only when the model actually changed,
96+
// so new features (routing, alloy, fallback, model picker, …) never need
97+
// to notify the TUI themselves — the loop handles it.
98+
lastEmittedModelID := r.getEffectiveModelID(a)
99+
emitModelInfo := func(a *agent.Agent, modelID string) {
100+
if modelID == lastEmittedModelID {
101+
return
102+
}
103+
lastEmittedModelID = modelID
104+
events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())
105+
}
106+
94107
// Emit agent information for sidebar display
95-
// Use getEffectiveModelID to account for active fallback cooldowns
96-
events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())
108+
events <- AgentInfo(a.Name(), lastEmittedModelID, a.Description(), a.WelcomeMessage())
97109

98110
// Emit team information
99111
events <- TeamInfo(r.agentDetailsFromTeam(), a.Name())
@@ -241,10 +253,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
241253

242254
modelID := model.ID()
243255

244-
// Notify sidebar when this turn uses a different model (per-tool override).
245-
if modelID != defaultModelID {
246-
events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())
247-
}
256+
// Notify sidebar when this turn uses a different model
257+
// (per-tool override, model picker, fallback cooldown, …).
258+
emitModelInfo(a, modelID)
248259

249260
slog.Debug("Using agent", "agent", a.Name(), "model", modelID)
250261
slog.Debug("Getting model definition", "model_id", modelID)
@@ -319,17 +330,15 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
319330
return
320331
}
321332

322-
// Update sidebar model info to reflect what was actually used this turn.
323-
// Fallback models are sticky (cooldown system persists them), so we only
324-
// emit once. Per-tool model overrides are temporary (one turn), so we
325-
// emit the override and then revert to the agent's default.
333+
// Update sidebar to reflect the model actually used this turn.
334+
// When no fallback kicked in, revert to the agent's default
335+
// (undoes any temporary per-tool override).
336+
actualModelID := defaultModelID
326337
if usedModel != nil && usedModel.ID() != model.ID() {
327338
slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID())
328-
events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage())
329-
} else if model.ID() != defaultModelID {
330-
// Per-tool override was active: revert sidebar to the agent's default model.
331-
events <- AgentInfo(a.Name(), defaultModelID, a.Description(), a.WelcomeMessage())
339+
actualModelID = usedModel.ID()
332340
}
341+
emitModelInfo(a, actualModelID)
333342
streamSpan.SetAttributes(
334343
attribute.Int("tool.calls", len(res.Calls)),
335344
attribute.Int("content.length", len(res.Content)),
@@ -350,6 +359,11 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
350359

351360
r.processToolCalls(ctx, sess, res.Calls, agentTools, events)
352361

362+
// Tool handlers (e.g. change_model, revert_model) may have
363+
// switched the effective model. Notify the TUI now so the
364+
// sidebar updates even when the model stops after the tool call.
365+
emitModelInfo(a, r.getEffectiveModelID(a))
366+
353367
// Record per-toolset model override for the next LLM turn.
354368
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)
355369

pkg/runtime/model_picker.go

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func (r *LocalRuntime) findModelPickerTool() *builtin.ModelPickerTool {
3030
}
3131

3232
// handleChangeModel handles the change_model tool call by switching the current agent's model.
33-
func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) {
33+
func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) {
3434
var params builtin.ChangeModelArgs
3535
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil {
3636
return nil, fmt.Errorf("invalid arguments: %w", err)
@@ -53,29 +53,24 @@ func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session
5353
)), nil
5454
}
5555

56-
return r.setModelAndEmitInfo(ctx, params.Model, events)
56+
return r.setCurrentAgentModel(ctx, params.Model)
5757
}
5858

5959
// handleRevertModel handles the revert_model tool call by reverting the current agent to its default model.
60-
func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) {
61-
return r.setModelAndEmitInfo(ctx, "", events)
60+
func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) {
61+
return r.setCurrentAgentModel(ctx, "")
6262
}
6363

64-
// setModelAndEmitInfo sets the model for the current agent and emits an updated
65-
// AgentInfo event so the UI reflects the change. An empty modelRef reverts to
66-
// the agent's default model.
67-
func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, events chan Event) (*tools.ToolCallResult, error) {
64+
// setCurrentAgentModel sets the model for the current agent. An empty modelRef
65+
// reverts to the agent's default model. The main loop detects the resulting
66+
// model change and automatically notifies the TUI, so no AgentInfo event is
67+
// emitted here.
68+
func (r *LocalRuntime) setCurrentAgentModel(ctx context.Context, modelRef string) (*tools.ToolCallResult, error) {
6869
currentName := r.CurrentAgentName()
6970
if err := r.SetAgentModel(ctx, currentName, modelRef); err != nil {
7071
return tools.ResultError(fmt.Sprintf("failed to set model: %v", err)), nil
7172
}
7273

73-
if a, err := r.team.Agent(currentName); err == nil {
74-
events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())
75-
} else {
76-
slog.Warn("Failed to retrieve agent after model change; UI may not reflect the update", "agent", currentName, "error", err)
77-
}
78-
7974
if modelRef == "" {
8075
slog.Info("Model reverted via model_picker tool", "agent", currentName)
8176
return tools.ResultSuccess("Model reverted to the agent's default model"), nil

pkg/runtime/model_picker_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/docker/docker-agent/pkg/agent"
11+
"github.com/docker/docker-agent/pkg/chat"
12+
"github.com/docker/docker-agent/pkg/session"
13+
"github.com/docker/docker-agent/pkg/team"
14+
"github.com/docker/docker-agent/pkg/tools"
15+
)
16+
17+
// staticToolSet is a simple ToolSet that returns a fixed list of tools.
18+
type staticToolSet struct {
19+
tools []tools.Tool
20+
}
21+
22+
func (s *staticToolSet) Tools(context.Context) ([]tools.Tool, error) {
23+
return s.tools, nil
24+
}
25+
26+
// TestModelChangeEmitsAgentInfo verifies that when a tool call changes the
27+
// agent's model (like change_model does), an AgentInfoEvent with the new
28+
// model ID is emitted even when the model stops in the same turn.
29+
// This is the scenario where the TUI sidebar must be updated.
30+
func TestModelChangeEmitsAgentInfo(t *testing.T) {
31+
newModel := &mockProvider{id: "openai/gpt-4o-mini"}
32+
33+
// Stream 1: model calls the custom "switch_model" tool and stops.
34+
stream1 := newStreamBuilder().
35+
AddToolCallName("call_1", "switch_model").
36+
AddToolCallArguments("call_1", `{}`).
37+
AddStopWithUsage(5, 5).
38+
Build()
39+
40+
// Stream 2: after the tool result is returned, model says "Done" and stops.
41+
stream2 := newStreamBuilder().
42+
AddContent("Model switched.").
43+
AddStopWithUsage(5, 5).
44+
Build()
45+
46+
prov := &queueProvider{id: "test/original-model", streams: []chat.MessageStream{
47+
stream1,
48+
stream2,
49+
}}
50+
51+
// Create a toolset that exposes the "switch_model" tool.
52+
switchToolSet := &staticToolSet{tools: []tools.Tool{
53+
{
54+
Name: "switch_model",
55+
Description: "switch the model",
56+
Annotations: tools.ToolAnnotations{ReadOnlyHint: true},
57+
},
58+
}}
59+
60+
root := agent.New("root", "test agent",
61+
agent.WithModel(prov),
62+
agent.WithToolSets(switchToolSet),
63+
)
64+
tm := team.New(team.WithAgents(root))
65+
66+
rt, err := NewLocalRuntime(tm,
67+
WithSessionCompaction(false),
68+
WithModelStore(mockModelStore{}),
69+
)
70+
require.NoError(t, err)
71+
72+
// Register a custom handler that switches the agent's model override,
73+
// mimicking what handleChangeModel does internally.
74+
rt.toolMap["switch_model"] = func(_ context.Context, _ *session.Session, _ tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) {
75+
a2, _ := rt.team.Agent("root")
76+
a2.SetModelOverride(newModel)
77+
return tools.ResultSuccess("Model changed to openai/gpt-4o-mini"), nil
78+
}
79+
80+
sess := session.New(session.WithUserMessage("Switch the model"), session.WithToolsApproved(true))
81+
sess.Title = "Test"
82+
83+
evCh := rt.RunStream(t.Context(), sess)
84+
var events []Event
85+
for ev := range evCh {
86+
events = append(events, ev)
87+
}
88+
89+
// Collect all AgentInfoEvents.
90+
var agentInfoEvents []*AgentInfoEvent
91+
for _, ev := range events {
92+
if ai, ok := ev.(*AgentInfoEvent); ok {
93+
agentInfoEvents = append(agentInfoEvents, ai)
94+
}
95+
}
96+
97+
// There should be at least two AgentInfoEvents:
98+
// 1. The initial one with "test/original-model"
99+
// 2. One after the tool call with "openai/gpt-4o-mini"
100+
require.GreaterOrEqual(t, len(agentInfoEvents), 2, "expected at least 2 AgentInfoEvents, got %d", len(agentInfoEvents))
101+
102+
// The first should show the original model.
103+
assert.Equal(t, "test/original-model", agentInfoEvents[0].Model)
104+
105+
// At least one AgentInfoEvent should show the new model.
106+
foundNewModel := false
107+
for _, ai := range agentInfoEvents {
108+
if ai.Model == "openai/gpt-4o-mini" {
109+
foundNewModel = true
110+
break
111+
}
112+
}
113+
assert.True(t, foundNewModel, "expected an AgentInfoEvent with model 'openai/gpt-4o-mini'")
114+
}

0 commit comments

Comments
 (0)