Skip to content

Commit ad7b51c

Browse files
committed
tighten Find resolution and SSE error handling
1 parent 4366414 commit ad7b51c

7 files changed

Lines changed: 148 additions & 16 deletions

File tree

cmd/root/run_event_hooks.go

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,46 @@ func withEventHooks(hooks []onEventHook) app.Opt {
6767
}
6868
}
6969

70+
// maxHookOutput caps the diagnostic output we keep for a failed on-event
71+
// hook. Large enough to be useful, small enough that a chatty or runaway
72+
// command can't push unbounded data into the agent's heap.
73+
const maxHookOutput = 4 * 1024
74+
7075
func runEventHook(command string, payload []byte) {
7176
shell, argsPrefix := shellpath.DetectShell()
77+
// Hooks are detached from the app context on purpose: a hook still
78+
// flushing the last event when the user exits the TUI should be allowed
79+
// to finish. Each invocation receives a single event on stdin, processes
80+
// it, and exits; the spawning goroutine ends with the subprocess.
7281
cmd := exec.CommandContext(context.Background(), shell, append(argsPrefix, command)...)
7382
cmd.Stdin = bytes.NewReader(payload)
74-
out, err := cmd.CombinedOutput()
75-
if err != nil {
76-
slog.Warn("on-event hook failed", "command", command, "error", err, "output", strings.TrimSpace(string(out)))
83+
var out boundedBuffer
84+
cmd.Stdout = &out
85+
cmd.Stderr = &out
86+
if err := cmd.Run(); err != nil {
87+
slog.Warn("on-event hook failed", "command", command, "error", err, "output", strings.TrimSpace(out.String()))
7788
}
7889
}
90+
91+
// boundedBuffer captures up to maxHookOutput bytes from a hook subprocess
92+
// and silently discards the rest. It implements only io.Writer so it can be
93+
// assigned to exec.Cmd's Stdout/Stderr without forcing exec to buffer the
94+
// full output internally.
95+
type boundedBuffer struct {
96+
buf bytes.Buffer
97+
}
98+
99+
func (b *boundedBuffer) Write(p []byte) (int, error) {
100+
if remaining := maxHookOutput - b.buf.Len(); remaining > 0 {
101+
if len(p) > remaining {
102+
b.buf.Write(p[:remaining])
103+
} else {
104+
b.buf.Write(p)
105+
}
106+
}
107+
return len(p), nil
108+
}
109+
110+
func (b *boundedBuffer) String() string {
111+
return b.buf.String()
112+
}

cmd/root/run_event_hooks_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package root
22

33
import (
4+
"bytes"
5+
"strings"
46
"testing"
57

68
"github.com/stretchr/testify/assert"
@@ -27,3 +29,28 @@ func TestParseOnEventFlags_BadFormat(t *testing.T) {
2729
assert.Error(t, err, "expected error for %q", s)
2830
}
2931
}
32+
33+
func TestBoundedBuffer_CapsAtMaxHookOutput(t *testing.T) {
34+
var b boundedBuffer
35+
36+
n, err := b.Write(bytes.Repeat([]byte("a"), maxHookOutput-3))
37+
require.NoError(t, err)
38+
assert.Equal(t, maxHookOutput-3, n)
39+
40+
// A write that straddles the cap is fully accepted from the caller's
41+
// perspective (so io.Copy doesn't error) but only the bytes up to the
42+
// cap are retained.
43+
n, err = b.Write([]byte("bbbbbb"))
44+
require.NoError(t, err)
45+
assert.Equal(t, 6, n)
46+
47+
// Subsequent writes past the cap are silently discarded.
48+
n, err = b.Write([]byte("ccccc"))
49+
require.NoError(t, err)
50+
assert.Equal(t, 5, n)
51+
52+
got := b.String()
53+
assert.Len(t, got, maxHookOutput)
54+
assert.True(t, strings.HasPrefix(got, strings.Repeat("a", maxHookOutput-3)))
55+
assert.True(t, strings.HasSuffix(got, "bbb"))
56+
}

cmd/root/sse.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@ import (
88
"fmt"
99
"io"
1010
"net/http"
11+
"strings"
1112
)
1213

14+
// maxErrorBodyBytes caps how much of an error response body we read into
15+
// memory. SSE error replies should be tiny; this just defends the client
16+
// against a misbehaving server that streams an unbounded error payload.
17+
const maxErrorBodyBytes = 4 * 1024
18+
1319
// openEventStream connects to the SSE event stream of a session running on
1420
// addr and returns the response body. Callers are responsible for closing
1521
// the body. The body produces standard text/event-stream output with one
@@ -27,16 +33,16 @@ func openEventStream(ctx context.Context, addr, sessionID string) (io.ReadCloser
2733
return nil, fmt.Errorf("connecting to %s: %w", url, err)
2834
}
2935
if resp.StatusCode >= 400 {
30-
body, _ := io.ReadAll(resp.Body)
36+
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
3137
_ = resp.Body.Close()
32-
return nil, fmt.Errorf("server returned %s: %s", resp.Status, string(body))
38+
return nil, fmt.Errorf("server returned %s: %s", resp.Status, strings.TrimSpace(string(body)))
3339
}
3440
return resp.Body, nil
3541
}
3642

3743
// readEventStream reads SSE "data:" lines from r and invokes onEvent with
38-
// each raw JSON payload. The function returns when ctx is cancelled, the
39-
// stream ends, or onEvent returns an error.
44+
// each raw JSON payload. The function returns when ctx is cancelled (with
45+
// ctx.Err), the stream ends (nil), or onEvent returns an error.
4046
//
4147
// Payloads are passed through as json.RawMessage so callers can either
4248
// forward the bytes verbatim or re-decode them into a typed value without
@@ -46,7 +52,7 @@ func readEventStream(ctx context.Context, r io.Reader, onEvent func(json.RawMess
4652
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
4753
for scanner.Scan() {
4854
if err := ctx.Err(); err != nil {
49-
return nil
55+
return err
5056
}
5157
after, ok := bytes.CutPrefix(scanner.Bytes(), []byte("data: "))
5258
if !ok {

cmd/root/sse_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package root
22

33
import (
4+
"context"
45
"encoding/json"
56
"net/http"
67
"net/http/httptest"
@@ -67,3 +68,34 @@ func TestOpenEventStream_StreamsSuccess(t *testing.T) {
6768
require.NoError(t, err)
6869
assert.Equal(t, []string{`{"hello":"world"}`}, got)
6970
}
71+
72+
// TestOpenEventStream_CapsErrorBody guards against a misbehaving server
73+
// pushing an unbounded error body into client memory.
74+
func TestOpenEventStream_CapsErrorBody(t *testing.T) {
75+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
76+
w.WriteHeader(http.StatusInternalServerError)
77+
_, _ = w.Write([]byte(strings.Repeat("x", 10*maxErrorBodyBytes)))
78+
}))
79+
defer srv.Close()
80+
81+
_, err := openEventStream(t.Context(), srv.URL, "s1")
82+
require.Error(t, err)
83+
// The error message embeds at most maxErrorBodyBytes of the body.
84+
assert.LessOrEqual(t, len(err.Error()), maxErrorBodyBytes+256)
85+
}
86+
87+
// TestReadEventStream_ReturnsCtxErr verifies the helper surfaces ctx
88+
// cancellation so callers can distinguish it from a clean stream end.
89+
func TestReadEventStream_ReturnsCtxErr(t *testing.T) {
90+
ctx, cancel := context.WithCancel(t.Context())
91+
body := strings.NewReader("data: {}\n\ndata: {}\n\n")
92+
93+
calls := 0
94+
err := readEventStream(ctx, body, func(json.RawMessage) error {
95+
calls++
96+
cancel()
97+
return nil
98+
})
99+
require.ErrorIs(t, err, context.Canceled)
100+
assert.Equal(t, 1, calls)
101+
}

cmd/root/watch.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package root
22

33
import (
4+
"context"
45
"encoding/json"
6+
"errors"
57
"fmt"
68

79
"github.com/spf13/cobra"
@@ -56,8 +58,13 @@ func (f *watchFlags) run(cmd *cobra.Command, args []string) (commandErr error) {
5658
out.Println("Watching", rec.Addr, "(session", rec.SessionID+")")
5759

5860
stdout := cmd.OutOrStdout()
59-
return readEventStream(ctx, body, func(payload json.RawMessage) error {
61+
err = readEventStream(ctx, body, func(payload json.RawMessage) error {
6062
_, err := fmt.Fprintln(stdout, string(payload))
6163
return err
6264
})
65+
// Ctrl+C is the normal way to stop watching; don't treat it as failure.
66+
if errors.Is(err, context.Canceled) {
67+
return nil
68+
}
69+
return err
6370
}

pkg/runregistry/registry.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,10 @@ var ErrNoRun = errors.New("no live docker-agent run found; start one with: docke
145145
// An empty target returns the most recently started run. A numeric target is
146146
// matched by PID; a target starting with "http://" or "https://" is matched
147147
// against record addresses; anything else is matched as a (possibly partial)
148-
// session ID. Matching is exact for PID and addr, prefix-or-substring for
149-
// session ID; ambiguous matches return an error so callers don't act on the
150-
// wrong session.
148+
// session ID. PID and address matches are exact. Session-ID matching prefers
149+
// exact equality and only falls back to substring matching when no record
150+
// matches exactly; ambiguous substring matches return an error so callers
151+
// don't act on the wrong session.
151152
func Find(target string) (Record, error) {
152153
target = strings.TrimSpace(target)
153154
if target == "" {
@@ -175,7 +176,7 @@ func Find(target string) (Record, error) {
175176
return r, nil
176177
}
177178
}
178-
return Record{}, fmt.Errorf("no live run with pid %d", pid)
179+
return Record{}, fmt.Errorf("no live run with pid %d: %w", pid, ErrNoRun)
179180
}
180181

181182
if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") {
@@ -185,18 +186,25 @@ func Find(target string) (Record, error) {
185186
return r, nil
186187
}
187188
}
188-
return Record{}, fmt.Errorf("no live run at %s", target)
189+
return Record{}, fmt.Errorf("no live run at %s: %w", target, ErrNoRun)
189190
}
190191

192+
// Prefer an exact session-id match: an unambiguous full id must always
193+
// resolve, even when other ids contain it as a substring.
194+
for _, r := range records {
195+
if r.SessionID == target {
196+
return r, nil
197+
}
198+
}
191199
var matches []Record
192200
for _, r := range records {
193-
if r.SessionID == target || strings.HasPrefix(r.SessionID, target) || strings.Contains(r.SessionID, target) {
201+
if strings.Contains(r.SessionID, target) {
194202
matches = append(matches, r)
195203
}
196204
}
197205
switch len(matches) {
198206
case 0:
199-
return Record{}, fmt.Errorf("no live run matches %q (pid, http URL, or session id)", target)
207+
return Record{}, fmt.Errorf("no live run matches %q (pid, http URL, or session id): %w", target, ErrNoRun)
200208
case 1:
201209
return matches[0], nil
202210
default:

pkg/runregistry/registry_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,21 @@ func TestFind(t *testing.T) {
135135
_, err := Find("999999999")
136136
require.Error(t, err)
137137
assert.Contains(t, err.Error(), "no live run with pid")
138+
assert.ErrorIs(t, err, ErrNoRun)
138139
})
139140

140141
t.Run("unknown addr errors", func(t *testing.T) {
141142
_, err := Find("http://nope")
142143
require.Error(t, err)
143144
assert.Contains(t, err.Error(), "no live run at")
145+
assert.ErrorIs(t, err, ErrNoRun)
144146
})
145147

146148
t.Run("unknown session id errors", func(t *testing.T) {
147149
_, err := Find("zzz")
148150
require.Error(t, err)
149151
assert.Contains(t, err.Error(), "no live run matches")
152+
assert.ErrorIs(t, err, ErrNoRun)
150153
})
151154
}
152155

@@ -162,6 +165,21 @@ func TestFind_AmbiguousSessionID(t *testing.T) {
162165
assert.Contains(t, err.Error(), "ambiguous")
163166
}
164167

168+
// TestFind_ExactMatchBeatsSubstring guards against a regression where an
169+
// exact session-id match was reported as ambiguous because a longer id
170+
// contained it as a substring.
171+
func TestFind_ExactMatchBeatsSubstring(t *testing.T) {
172+
withTempDataDir(t)
173+
174+
pid := os.Getpid()
175+
writeRecord(t, "1.json", Record{PID: pid, Addr: "http://a", SessionID: "abc", StartedAt: time.Now()})
176+
writeRecord(t, "2.json", Record{PID: pid, Addr: "http://b", SessionID: "abcd", StartedAt: time.Now()})
177+
178+
rec, err := Find("abc")
179+
require.NoError(t, err)
180+
assert.Equal(t, "abc", rec.SessionID)
181+
}
182+
165183
func TestFind_EmptyRegistry(t *testing.T) {
166184
withTempDataDir(t)
167185

0 commit comments

Comments
 (0)