Skip to content

Commit 3877eb9

Browse files
committed
refactor(hooks/builtins): let DMR own its unload conventions
Address review feedback on #2706:

- Export pkg/model/provider/dmr.ProviderType and UnloadURL so the unload builtin becomes a dumb dispatcher and DMR owns the provider literal + the /v1 -> /_unload URL convention (sibling to the existing /_configure helper). Move the URL-resolution table test into the dmr package where the helper now lives.
- Gate the FromAgentModels snapshot in executeOnAgentSwitchHooks on executor.Has(EventOnAgentSwitch) so audit-free deployments don't pay the team-lookup + per-model allocation on every agent switch.
- Add a mixed-providers test pinning the per-element DMR filter: cloud entries (openai, anthropic) in the snapshot must be silently skipped; only DMR endpoints get POSTed.
- Add a comment on http.DefaultClient explaining why the SSRF-safe client used by http_post is wrong here (DMR is loopback).
1 parent 3ccc0c2 commit 3877eb9

5 files changed

Lines changed: 213 additions & 129 deletions

File tree

pkg/hooks/builtins/unload.go

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import (
88
"io"
99
"log/slog"
1010
"net/http"
11-
"net/url"
1211
"strings"
1312
"time"
1413

1514
"github.com/docker/docker-agent/pkg/hooks"
15+
"github.com/docker/docker-agent/pkg/model/provider/dmr"
1616
)
1717

1818
// Unload is the registered name of the on_agent_switch builtin that
@@ -31,6 +31,11 @@ import (
3131
// net/http. It carries no runtime-side coupling and silently skips any
3232
// model whose endpoint isn't reachable as plain HTTP (e.g. cloud
3333
// providers that don't expose [hooks.ModelEndpoint.BaseURL]).
34+
//
35+
// Provider dispatch and URL resolution are owned by
36+
// [pkg/model/provider/dmr] (see [dmr.ProviderType] and [dmr.UnloadURL]),
37+
// so this builtin stays a dumb dispatcher and DMR keeps full control
38+
// of its conventions.
3439
const Unload = "unload"
3540

3641
// unloadTimeout caps each per-model Unload call so a stalled engine
@@ -47,7 +52,7 @@ func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, er
4752
return nil, nil
4853
}
4954
for _, m := range in.FromAgentModels {
50-
if m.Provider != "dmr" {
55+
if m.Provider != dmr.ProviderType {
5156
continue
5257
}
5358
if err := unloadOne(ctx, m); err != nil {
@@ -63,7 +68,7 @@ func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, er
6368
// (no base_url and no unload_api) is a silent no-op so the hook stays
6469
// harmless on test / in-process providers.
6570
func unloadOne(parent context.Context, m hooks.ModelEndpoint) error {
66-
endpoint, err := unloadURL(m)
71+
endpoint, err := dmr.UnloadURL(m.BaseURL, m.UnloadAPI)
6772
if err != nil || endpoint == "" {
6873
return err
6974
}
@@ -79,6 +84,11 @@ func unloadOne(parent context.Context, m hooks.ModelEndpoint) error {
7984

8085
slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", m.Model)
8186

87+
// Unlike the http_post builtin, the unload target is the
88+
// operator-configured DMR base URL — typically a loopback engine
89+
// (Docker Desktop socket, 127.0.0.1:12434, …). The SSRF-safe
90+
// dialer used by http_post would refuse those addresses by
91+
// design, so we use the default client here.
8292
resp, err := http.DefaultClient.Do(req)
8393
if err != nil {
8494
return fmt.Errorf("calling unload endpoint %s: %w", endpoint, err)
@@ -96,36 +106,3 @@ func unloadOne(parent context.Context, m hooks.ModelEndpoint) error {
96106
_, _ = io.Copy(io.Discard, resp.Body)
97107
return nil
98108
}
99-
100-
// unloadURL picks the unload endpoint for one model, in this order:
101-
//
102-
// 1. unload_api is an absolute URL — used verbatim (lets users point
103-
// at a different host than the model's base_url);
104-
// 2. unload_api is set but relative — rebased onto base_url's
105-
// scheme + host (the model's path is dropped);
106-
// 3. unload_api is unset — the default `_unload` URL is derived from
107-
// base_url by replacing its trailing `/v1` segment.
108-
//
109-
// Returns ("", nil) when neither base_url nor unload_api is set, so
110-
// the caller can skip without erroring.
111-
func unloadURL(m hooks.ModelEndpoint) (string, error) {
112-
if strings.HasPrefix(m.UnloadAPI, "http://") || strings.HasPrefix(m.UnloadAPI, "https://") {
113-
return m.UnloadAPI, nil
114-
}
115-
if m.BaseURL == "" && m.UnloadAPI == "" {
116-
return "", nil
117-
}
118-
u, err := url.Parse(m.BaseURL)
119-
if err != nil || u.Scheme == "" || u.Host == "" {
120-
return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint", m.BaseURL)
121-
}
122-
switch {
123-
case m.UnloadAPI == "":
124-
u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload"
125-
case strings.HasPrefix(m.UnloadAPI, "/"):
126-
u.Path = m.UnloadAPI
127-
default:
128-
u.Path = "/" + m.UnloadAPI
129-
}
130-
return u.String(), nil
131-
}

pkg/hooks/builtins/unload_test.go

Lines changed: 40 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/stretchr/testify/require"
1313

1414
"github.com/docker/docker-agent/pkg/hooks"
15+
"github.com/docker/docker-agent/pkg/model/provider/dmr"
1516
)
1617

1718
// dmrInput builds an on_agent_switch [hooks.Input] carrying a single
@@ -26,7 +27,7 @@ func dmrInput(server *httptest.Server, unloadAPI string, opts ...func(*hooks.Inp
2627
FromAgent: "from",
2728
ToAgent: "to",
2829
FromAgentModels: []hooks.ModelEndpoint{{
29-
Provider: "dmr",
30+
Provider: dmr.ProviderType,
3031
Model: "ai/qwen3",
3132
BaseURL: server.URL + "/engines/v1",
3233
UnloadAPI: unloadAPI,
@@ -108,6 +109,43 @@ func TestUnload_HonoursOverrideUnloadAPI(t *testing.T) {
108109
assert.Equal(t, "/custom/unload", gotPath)
109110
}
110111

112+
// TestUnload_FiltersPerElement pins the per-element provider filter:
113+
// when the snapshot mixes DMR and non-DMR endpoints, only the DMR
114+
// ones are POSTed to. The non-DMR entries (cloud providers without a
115+
// reachable unload endpoint) must be silently skipped, not errored
116+
// on, not POSTed to a fabricated URL.
117+
func TestUnload_FiltersPerElement(t *testing.T) {
118+
t.Parallel()
119+
120+
var gotModels []string
121+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
122+
var body struct {
123+
Model string `json:"model"`
124+
}
125+
_ = json.NewDecoder(r.Body).Decode(&body)
126+
gotModels = append(gotModels, body.Model)
127+
w.WriteHeader(http.StatusOK)
128+
}))
129+
defer server.Close()
130+
131+
in := &hooks.Input{
132+
FromAgent: "from",
133+
ToAgent: "to",
134+
FromAgentModels: []hooks.ModelEndpoint{
135+
{Provider: "openai", Model: "gpt-4", BaseURL: "https://api.openai.com/v1"},
136+
{Provider: dmr.ProviderType, Model: "ai/qwen3", BaseURL: server.URL + "/engines/v1"},
137+
{Provider: "anthropic", Model: "claude", BaseURL: "https://api.anthropic.com"},
138+
{Provider: dmr.ProviderType, Model: "ai/llama3.2", BaseURL: server.URL + "/engines/llama.cpp/v1"},
139+
},
140+
}
141+
142+
out, err := unload(t.Context(), in, nil)
143+
require.NoError(t, err)
144+
assert.Nil(t, out)
145+
assert.ElementsMatch(t, []string{"ai/qwen3", "ai/llama3.2"}, gotModels,
146+
"only DMR models must be POSTed; cloud providers must be silently skipped")
147+
}
148+
111149
// TestUnload_NoOpInputs pins the cheap-path properties the agent loop
112150
// relies on: the hook MUST NOT fire any HTTP call when the input
113151
// describes a transition where unloading would be wrong (back to the
@@ -156,7 +194,7 @@ func TestUnload_NoOpInputs(t *testing.T) {
156194
in: func(*httptest.Server) *hooks.Input {
157195
return &hooks.Input{
158196
FromAgent: "from", ToAgent: "to",
159-
FromAgentModels: []hooks.ModelEndpoint{{Provider: "dmr", Model: "ai/qwen3"}},
197+
FromAgentModels: []hooks.ModelEndpoint{{Provider: dmr.ProviderType, Model: "ai/qwen3"}},
160198
}
161199
},
162200
},
@@ -193,94 +231,3 @@ func TestUnload_SwallowsServerErrors(t *testing.T) {
193231
require.NoError(t, err)
194232
assert.Nil(t, out)
195233
}
196-
197-
// TestUnloadURL covers every branch of the URL-resolution algorithm in
198-
// one table. Replaces what used to be two separate tables for the now
199-
// inlined `defaultUnloadURL` and `rebaseURL` helpers.
200-
func TestUnloadURL(t *testing.T) {
201-
t.Parallel()
202-
203-
tests := []struct {
204-
name string
205-
baseURL string
206-
unloadAPI string
207-
want string
208-
errContains string // empty ⇒ expect success
209-
}{
210-
// Default derivation (no unload_api set).
211-
{
212-
name: "default: standard engines path",
213-
baseURL: "http://127.0.0.1:12434/engines/v1/",
214-
want: "http://127.0.0.1:12434/engines/_unload",
215-
},
216-
{
217-
name: "default: no trailing slash",
218-
baseURL: "http://127.0.0.1:12434/engines/v1",
219-
want: "http://127.0.0.1:12434/engines/_unload",
220-
},
221-
{
222-
name: "default: Docker Desktop experimental prefix",
223-
baseURL: "http://_/exp/vDD4.40/engines/v1",
224-
want: "http://_/exp/vDD4.40/engines/_unload",
225-
},
226-
{
227-
name: "default: backend-scoped path",
228-
baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/",
229-
want: "http://127.0.0.1:12434/engines/llama.cpp/_unload",
230-
},
231-
232-
// Override paths and absolute URLs.
233-
{
234-
name: "override: absolute https URL is returned verbatim",
235-
baseURL: "http://anything",
236-
unloadAPI: "https://api.example.com/unload",
237-
want: "https://api.example.com/unload",
238-
},
239-
{
240-
name: "override: rooted path drops base path",
241-
baseURL: "http://localhost:12434/engines/v1",
242-
unloadAPI: "/engines/_unload",
243-
want: "http://localhost:12434/engines/_unload",
244-
},
245-
{
246-
name: "override: relative path is rooted",
247-
baseURL: "http://localhost:12434/engines/v1",
248-
unloadAPI: "engines/_unload",
249-
want: "http://localhost:12434/engines/_unload",
250-
},
251-
252-
// Skip / error cases.
253-
{
254-
name: "skip: no base_url and no unload_api",
255-
want: "",
256-
},
257-
{
258-
name: "error: unload_api set but base_url empty",
259-
unloadAPI: "/engines/_unload",
260-
errContains: "is not absolute",
261-
},
262-
{
263-
name: "error: base_url without scheme",
264-
baseURL: "localhost:12434/engines/v1",
265-
unloadAPI: "/engines/_unload",
266-
errContains: "is not absolute",
267-
},
268-
}
269-
270-
for _, tt := range tests {
271-
t.Run(tt.name, func(t *testing.T) {
272-
t.Parallel()
273-
got, err := unloadURL(hooks.ModelEndpoint{
274-
BaseURL: tt.baseURL,
275-
UnloadAPI: tt.unloadAPI,
276-
})
277-
if tt.errContains != "" {
278-
require.Error(t, err)
279-
assert.Contains(t, err.Error(), tt.errContains)
280-
return
281-
}
282-
require.NoError(t, err)
283-
assert.Equal(t, tt.want, got)
284-
})
285-
}
286-
}

pkg/model/provider/dmr/unload.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package dmr
2+
3+
import (
4+
"fmt"
5+
"net/url"
6+
"strings"
7+
)
8+
9+
// ProviderType is the canonical [latest.ModelConfig.Provider] value
10+
// for Docker Model Runner. Exported so callers outside the package
11+
// (e.g. the `unload` hook builtin) can dispatch on provider type
12+
// without hard-coding the literal.
13+
const ProviderType = "dmr"
14+
15+
// UnloadURL resolves the URL of the per-model unload endpoint for a
// DMR-served model, given the resolved provider base URL and the
// per-model `unload_api` override (both as they appear on
// [hooks.ModelEndpoint]).
//
// Resolution order:
//
//  1. unloadAPI starts with "http://" or "https://" — returned
//     verbatim (lets users point at a different host than baseURL);
//     baseURL is not inspected at all in this case;
//  2. unloadAPI is set but relative — rebased onto baseURL's
//     scheme + host (baseURL's own path is dropped, and a leading
//     "/" is added when unloadAPI lacks one);
//  3. unloadAPI is unset — the default `_unload` URL is derived from
//     baseURL by stripping one trailing "/" and then a trailing
//     "/v1" from its path and appending "/_unload", mirroring the
//     `/v1` → `/_configure` convention the configure path uses.
//
// Returns ("", nil) when neither baseURL nor unloadAPI is set, so the
// caller can skip without erroring (in-process / test providers).
//
// Returns a non-nil error when baseURL is needed (cases 2 and 3) but
// cannot be parsed or has no scheme or host. A parse failure is
// wrapped with %w so the underlying [url.Error] stays reachable via
// [errors.As] / [errors.Unwrap].
func UnloadURL(baseURL, unloadAPI string) (string, error) {
	if strings.HasPrefix(unloadAPI, "http://") || strings.HasPrefix(unloadAPI, "https://") {
		return unloadAPI, nil
	}
	if baseURL == "" && unloadAPI == "" {
		return "", nil
	}

	u, err := url.Parse(baseURL)
	if err != nil {
		// Keep the historical message prefix, but surface the parse
		// failure instead of silently discarding it.
		return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint: %w", baseURL, err)
	}
	if u.Scheme == "" || u.Host == "" {
		return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint", baseURL)
	}

	switch {
	case unloadAPI == "":
		// Default convention: ".../v1" (optionally slash-terminated)
		// becomes ".../_unload".
		u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload"
	case strings.HasPrefix(unloadAPI, "/"):
		u.Path = unloadAPI
	default:
		u.Path = "/" + unloadAPI
	}
	return u.String(), nil
}

0 commit comments

Comments
 (0)