Skip to content

Commit 33378c9

Browse files
committed
refactor(hooks/builtins): let DMR own its unload conventions
Address review feedback on #2706:

- Export pkg/model/provider/dmr.ProviderType and UnloadURL so the unload builtin becomes a dumb dispatcher and DMR owns the provider literal + the /v1 -> /_unload URL convention (sibling to the existing /_configure helper). Move the URL-resolution table test into the dmr package where the helper now lives.
- Gate the FromAgentModels snapshot in executeOnAgentSwitchHooks on executor.Has(EventOnAgentSwitch) so audit-free deployments don't pay the team-lookup + per-model allocation on every agent switch.
- Add a mixed-providers test pinning the per-element DMR filter: cloud entries (openai, anthropic) in the snapshot must be silently skipped, only DMR endpoints get POSTed.
- Add a comment on http.DefaultClient explaining why the SSRF-safe client used by http_post is wrong here (DMR is loopback).
1 parent c090e18 commit 33378c9

5 files changed

Lines changed: 213 additions & 129 deletions

File tree

pkg/hooks/builtins/unload.go

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import (
88
"io"
99
"log/slog"
1010
"net/http"
11-
"net/url"
1211
"strings"
1312
"time"
1413

1514
"github.com/docker/docker-agent/pkg/hooks"
15+
"github.com/docker/docker-agent/pkg/model/provider/dmr"
1616
)
1717

1818
// Unload is the registered name of the on_agent_switch builtin that
@@ -31,6 +31,11 @@ import (
3131
// net/http. It carries no runtime-side coupling and silently skips any
3232
// model whose endpoint isn't reachable as plain HTTP (e.g. cloud
3333
// providers that don't expose [hooks.ModelEndpoint.BaseURL]).
34+
//
35+
// Provider dispatch and URL resolution are owned by
36+
// [pkg/model/provider/dmr] (see [dmr.ProviderType] and [dmr.UnloadURL]),
37+
// so this builtin stays a dumb dispatcher and DMR keeps full control
38+
// of its conventions.
3439
const Unload = "unload"
3540

3641
// unloadTimeout caps each per-model Unload call so a stalled engine
@@ -47,7 +52,7 @@ func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, er
4752
return nil, nil
4853
}
4954
for _, m := range in.FromAgentModels {
50-
if m.Provider != "dmr" {
55+
if m.Provider != dmr.ProviderType {
5156
continue
5257
}
5358
if err := unloadOne(ctx, m); err != nil {
@@ -63,7 +68,7 @@ func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, er
6368
// (no base_url and no unload_api) is a silent no-op so the hook stays
6469
// harmless on test / in-process providers.
6570
func unloadOne(parent context.Context, m hooks.ModelEndpoint) error {
66-
endpoint, err := unloadURL(m)
71+
endpoint, err := dmr.UnloadURL(m.BaseURL, m.UnloadAPI)
6772
if err != nil || endpoint == "" {
6873
return err
6974
}
@@ -79,6 +84,11 @@ func unloadOne(parent context.Context, m hooks.ModelEndpoint) error {
7984

8085
slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", m.Model)
8186

87+
// Unlike the http_post builtin, the unload target is the
88+
// operator-configured DMR base URL — typically a loopback engine
89+
// (Docker Desktop socket, 127.0.0.1:12434, …). The SSRF-safe
90+
// dialer used by http_post would refuse those addresses by
91+
// design, so we use the default client here.
8292
resp, err := http.DefaultClient.Do(req)
8393
if err != nil {
8494
return fmt.Errorf("calling unload endpoint %s: %w", endpoint, err)
@@ -96,36 +106,3 @@ func unloadOne(parent context.Context, m hooks.ModelEndpoint) error {
96106
_, _ = io.Copy(io.Discard, resp.Body)
97107
return nil
98108
}
99-
100-
// unloadURL picks the unload endpoint for one model, in this order:
101-
//
102-
// 1. unload_api is an absolute URL — used verbatim (lets users point
103-
// at a different host than the model's base_url);
104-
// 2. unload_api is set but relative — rebased onto base_url's
105-
// scheme + host (the model's path is dropped);
106-
// 3. unload_api is unset — the default `_unload` URL is derived from
107-
// base_url by replacing its trailing `/v1` segment.
108-
//
109-
// Returns ("", nil) when neither base_url nor unload_api is set, so
110-
// the caller can skip without erroring.
111-
func unloadURL(m hooks.ModelEndpoint) (string, error) {
112-
if strings.HasPrefix(m.UnloadAPI, "http://") || strings.HasPrefix(m.UnloadAPI, "https://") {
113-
return m.UnloadAPI, nil
114-
}
115-
if m.BaseURL == "" && m.UnloadAPI == "" {
116-
return "", nil
117-
}
118-
u, err := url.Parse(m.BaseURL)
119-
if err != nil || u.Scheme == "" || u.Host == "" {
120-
return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint", m.BaseURL)
121-
}
122-
switch {
123-
case m.UnloadAPI == "":
124-
u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload"
125-
case strings.HasPrefix(m.UnloadAPI, "/"):
126-
u.Path = m.UnloadAPI
127-
default:
128-
u.Path = "/" + m.UnloadAPI
129-
}
130-
return u.String(), nil
131-
}

pkg/hooks/builtins/unload_test.go

Lines changed: 40 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/stretchr/testify/require"
1313

1414
"github.com/docker/docker-agent/pkg/hooks"
15+
"github.com/docker/docker-agent/pkg/model/provider/dmr"
1516
)
1617

1718
// dmrInput builds an on_agent_switch [hooks.Input] carrying a single
@@ -26,7 +27,7 @@ func dmrInput(server *httptest.Server, unloadAPI string, opts ...func(*hooks.Inp
2627
FromAgent: "from",
2728
ToAgent: "to",
2829
FromAgentModels: []hooks.ModelEndpoint{{
29-
Provider: "dmr",
30+
Provider: dmr.ProviderType,
3031
Model: "ai/qwen3",
3132
BaseURL: server.URL + "/engines/v1",
3233
UnloadAPI: unloadAPI,
@@ -109,6 +110,43 @@ func TestUnload_HonoursOverrideUnloadAPI(t *testing.T) {
109110
assert.Equal(t, "/custom/unload", gotPath)
110111
}
111112

113+
// TestUnload_FiltersPerElement pins the per-element provider filter:
114+
// when the snapshot mixes DMR and non-DMR endpoints, only the DMR
115+
// ones are POSTed to. The non-DMR entries (cloud providers without a
116+
// reachable unload endpoint) must be silently skipped, not errored
117+
// on, not POSTed to a fabricated URL.
118+
func TestUnload_FiltersPerElement(t *testing.T) {
119+
t.Parallel()
120+
121+
var gotModels []string
122+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123+
var body struct {
124+
Model string `json:"model"`
125+
}
126+
_ = json.NewDecoder(r.Body).Decode(&body)
127+
gotModels = append(gotModels, body.Model)
128+
w.WriteHeader(http.StatusOK)
129+
}))
130+
defer server.Close()
131+
132+
in := &hooks.Input{
133+
FromAgent: "from",
134+
ToAgent: "to",
135+
FromAgentModels: []hooks.ModelEndpoint{
136+
{Provider: "openai", Model: "gpt-4", BaseURL: "https://api.openai.com/v1"},
137+
{Provider: dmr.ProviderType, Model: "ai/qwen3", BaseURL: server.URL + "/engines/v1"},
138+
{Provider: "anthropic", Model: "claude", BaseURL: "https://api.anthropic.com"},
139+
{Provider: dmr.ProviderType, Model: "ai/llama3.2", BaseURL: server.URL + "/engines/llama.cpp/v1"},
140+
},
141+
}
142+
143+
out, err := unload(t.Context(), in, nil)
144+
require.NoError(t, err)
145+
assert.Nil(t, out)
146+
assert.ElementsMatch(t, []string{"ai/qwen3", "ai/llama3.2"}, gotModels,
147+
"only DMR models must be POSTed; cloud providers must be silently skipped")
148+
}
149+
112150
// TestUnload_NoOpInputs pins the cheap-path properties the agent loop
113151
// relies on: the hook MUST NOT fire any HTTP call when the input
114152
// describes a transition where unloading would be wrong (back to the
@@ -157,7 +195,7 @@ func TestUnload_NoOpInputs(t *testing.T) {
157195
in: func(*httptest.Server) *hooks.Input {
158196
return &hooks.Input{
159197
FromAgent: "from", ToAgent: "to",
160-
FromAgentModels: []hooks.ModelEndpoint{{Provider: "dmr", Model: "ai/qwen3"}},
198+
FromAgentModels: []hooks.ModelEndpoint{{Provider: dmr.ProviderType, Model: "ai/qwen3"}},
161199
}
162200
},
163201
},
@@ -194,94 +232,3 @@ func TestUnload_SwallowsServerErrors(t *testing.T) {
194232
require.NoError(t, err)
195233
assert.Nil(t, out)
196234
}
197-
198-
// TestUnloadURL covers every branch of the URL-resolution algorithm in
199-
// one table. Replaces what used to be two separate tables for the now
200-
// inlined `defaultUnloadURL` and `rebaseURL` helpers.
201-
func TestUnloadURL(t *testing.T) {
202-
t.Parallel()
203-
204-
tests := []struct {
205-
name string
206-
baseURL string
207-
unloadAPI string
208-
want string
209-
errContains string // empty ⇒ expect success
210-
}{
211-
// Default derivation (no unload_api set).
212-
{
213-
name: "default: standard engines path",
214-
baseURL: "http://127.0.0.1:12434/engines/v1/",
215-
want: "http://127.0.0.1:12434/engines/_unload",
216-
},
217-
{
218-
name: "default: no trailing slash",
219-
baseURL: "http://127.0.0.1:12434/engines/v1",
220-
want: "http://127.0.0.1:12434/engines/_unload",
221-
},
222-
{
223-
name: "default: Docker Desktop experimental prefix",
224-
baseURL: "http://_/exp/vDD4.40/engines/v1",
225-
want: "http://_/exp/vDD4.40/engines/_unload",
226-
},
227-
{
228-
name: "default: backend-scoped path",
229-
baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/",
230-
want: "http://127.0.0.1:12434/engines/llama.cpp/_unload",
231-
},
232-
233-
// Override paths and absolute URLs.
234-
{
235-
name: "override: absolute https URL is returned verbatim",
236-
baseURL: "http://anything",
237-
unloadAPI: "https://api.example.com/unload",
238-
want: "https://api.example.com/unload",
239-
},
240-
{
241-
name: "override: rooted path drops base path",
242-
baseURL: "http://localhost:12434/engines/v1",
243-
unloadAPI: "/engines/_unload",
244-
want: "http://localhost:12434/engines/_unload",
245-
},
246-
{
247-
name: "override: relative path is rooted",
248-
baseURL: "http://localhost:12434/engines/v1",
249-
unloadAPI: "engines/_unload",
250-
want: "http://localhost:12434/engines/_unload",
251-
},
252-
253-
// Skip / error cases.
254-
{
255-
name: "skip: no base_url and no unload_api",
256-
want: "",
257-
},
258-
{
259-
name: "error: unload_api set but base_url empty",
260-
unloadAPI: "/engines/_unload",
261-
errContains: "is not absolute",
262-
},
263-
{
264-
name: "error: base_url without scheme",
265-
baseURL: "localhost:12434/engines/v1",
266-
unloadAPI: "/engines/_unload",
267-
errContains: "is not absolute",
268-
},
269-
}
270-
271-
for _, tt := range tests {
272-
t.Run(tt.name, func(t *testing.T) {
273-
t.Parallel()
274-
got, err := unloadURL(hooks.ModelEndpoint{
275-
BaseURL: tt.baseURL,
276-
UnloadAPI: tt.unloadAPI,
277-
})
278-
if tt.errContains != "" {
279-
require.Error(t, err)
280-
assert.Contains(t, err.Error(), tt.errContains)
281-
return
282-
}
283-
require.NoError(t, err)
284-
assert.Equal(t, tt.want, got)
285-
})
286-
}
287-
}

pkg/model/provider/dmr/unload.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package dmr
2+
3+
import (
4+
"fmt"
5+
"net/url"
6+
"strings"
7+
)
8+
9+
// ProviderType is the canonical [latest.ModelConfig.Provider] value
10+
// for Docker Model Runner. Exported so callers outside the package
11+
// (e.g. the `unload` hook builtin) can dispatch on provider type
12+
// without hard-coding the literal.
13+
const ProviderType = "dmr"
14+
15+
// UnloadURL resolves the URL of the per-model unload endpoint for a
16+
// DMR-served model, given the resolved provider base URL and the
17+
// per-model `unload_api` override (both as they appear on
18+
// [hooks.ModelEndpoint]).
19+
//
20+
// Resolution order:
21+
//
22+
// 1. unloadAPI is an absolute URL — used verbatim (lets users point
23+
// at a different host than baseURL);
24+
// 2. unloadAPI is set but relative — rebased onto baseURL's
25+
// scheme + host (the model's path is dropped);
26+
// 3. unloadAPI is unset — the default `_unload` URL is derived from
27+
// baseURL by replacing its trailing `/v1` segment, mirroring the
28+
// `/v1` → `/_configure` convention the configure path uses.
29+
//
30+
// Returns ("", nil) when neither baseURL nor unloadAPI is set, so the
31+
// caller can skip without erroring (in-process / test providers).
32+
func UnloadURL(baseURL, unloadAPI string) (string, error) {
33+
if strings.HasPrefix(unloadAPI, "http://") || strings.HasPrefix(unloadAPI, "https://") {
34+
return unloadAPI, nil
35+
}
36+
if baseURL == "" && unloadAPI == "" {
37+
return "", nil
38+
}
39+
u, err := url.Parse(baseURL)
40+
if err != nil || u.Scheme == "" || u.Host == "" {
41+
return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint", baseURL)
42+
}
43+
switch {
44+
case unloadAPI == "":
45+
u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload"
46+
case strings.HasPrefix(unloadAPI, "/"):
47+
u.Path = unloadAPI
48+
default:
49+
u.Path = "/" + unloadAPI
50+
}
51+
return u.String(), nil
52+
}

0 commit comments

Comments
 (0)