diff --git a/pipe/command.go b/pipe/command.go index 2c465e9..ae6eebb 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -18,9 +18,13 @@ import ( // commandStage is a pipeline `Stage` based on running an external // command and piping the data through its stdin and stdout. type commandStage struct { - name string - stdin io.Closer - cmd *exec.Cmd + name string + cmd *exec.Cmd + + // lateClosers is a list of things that have to be closed once the + // command has finished. + lateClosers []io.Closer + done chan struct{} wg errgroup.Group stderr bytes.Buffer @@ -30,6 +34,10 @@ type commandStage struct { ctxErr atomic.Value } +var ( + _ Stage = (*commandStage)(nil) +) + // Command returns a pipeline `Stage` based on the specified external // `command`, run with the given command-line `args`. Its stdin and // stdout are handled as usual, and its stderr is collected and @@ -59,33 +67,80 @@ func (s *commandStage) Name() string { return s.name } +func (s *commandStage) Preferences() StagePreferences { + prefs := StagePreferences{ + StdinPreference: IOPreferenceFile, + StdoutPreference: IOPreferenceFile, + } + if s.cmd.Stdin != nil { + prefs.StdinPreference = IOPreferenceNil + } + if s.cmd.Stdout != nil { + prefs.StdoutPreference = IOPreferenceNil + } + + return prefs +} + func (s *commandStage) Start( - ctx context.Context, env Env, stdin io.ReadCloser, -) (io.ReadCloser, error) { + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { if s.cmd.Dir == "" { s.cmd.Dir = env.Dir } s.setupEnv(ctx, env) + // Things that have to be closed as soon as the command has + // started: + var earlyClosers []io.Closer + + // See the type comment for `Stage` and the long comment in + // `WithStdin()` for the explanation of this unwrapping + // and closing behavior. + if stdin != nil { - // See the long comment in `Pipeline.Start()` for the - // explanation of this special case. 
switch stdin := stdin.(type) { - case nopCloser: + case readerNopCloser: + // In this case, we shouldn't close it. But unwrap it for + // efficiency's sake: s.cmd.Stdin = stdin.Reader - case nopCloserWriterTo: + case readerWriterToNopCloser: + // In this case, we shouldn't close it. But unwrap it for + // efficiency's sake: s.cmd.Stdin = stdin.Reader + case *os.File: + // In this case, we can close stdin as soon as the command + // has started: + s.cmd.Stdin = stdin + earlyClosers = append(earlyClosers, stdin) default: + // In this case, we need to close `stdin`, but we should + // only do so after the command has finished: s.cmd.Stdin = stdin + s.lateClosers = append(s.lateClosers, stdin) } - // Also keep a copy so that we can close it when the command exits: - s.stdin = stdin } - stdout, err := s.cmd.StdoutPipe() - if err != nil { - return nil, err + if stdout != nil { + // See the long comment in `WithStdin()` for the + // explanation of this special case. + switch stdout := stdout.(type) { + case writerNopCloser: + // In this case, we shouldn't close it. But unwrap it for + // efficiency's sake: + s.cmd.Stdout = stdout.Writer + case *os.File: + // In this case, we can close stdout as soon as the command + // has started: + s.cmd.Stdout = stdout + earlyClosers = append(earlyClosers, stdout) + default: + // In this case, we need to close `stdout`, but we should + // only do so after the command has finished: + s.cmd.Stdout = stdout + s.lateClosers = append(s.lateClosers, stdout) + } } // If the caller hasn't arranged otherwise, read the command's @@ -97,7 +152,7 @@ func (s *commandStage) Start( // can be sure. 
p, err := s.cmd.StderrPipe() if err != nil { - return nil, err + return err } s.wg.Go(func() error { _, err := io.Copy(&s.stderr, p) @@ -114,7 +169,11 @@ func (s *commandStage) Start( s.runInOwnProcessGroup() if err := s.cmd.Start(); err != nil { - return nil, err + return err + } + + for _, closer := range earlyClosers { + _ = closer.Close() } // Arrange for the process to be killed (gently) if the context @@ -128,7 +187,7 @@ func (s *commandStage) Start( } }() - return stdout, nil + return nil } // setupEnv sets or modifies the environment that will be passed to @@ -217,19 +276,18 @@ func (s *commandStage) Wait() error { // Make sure that any stderr is copied before `s.cmd.Wait()` // closes the read end of the pipe: - wErr := s.wg.Wait() + wgErr := s.wg.Wait() err := s.cmd.Wait() err = s.filterCmdError(err) - if err == nil && wErr != nil { - err = wErr + if err == nil && wgErr != nil { + err = wgErr } - if s.stdin != nil { - cErr := s.stdin.Close() - if cErr != nil && err == nil { - return cErr + for _, closer := range s.lateClosers { + if closeErr := closer.Close(); closeErr != nil && err == nil { + err = closeErr } } diff --git a/pipe/command_test.go b/pipe/command_test.go index 92fd37a..67cd55e 100644 --- a/pipe/command_test.go +++ b/pipe/command_test.go @@ -79,7 +79,8 @@ func TestCopyEnvWithOverride(t *testing.T) { ex := ex t.Run(ex.label, func(t *testing.T) { assert.ElementsMatch(t, ex.expectedResult, - copyEnvWithOverrides(ex.env, ex.overrides)) + copyEnvWithOverrides(ex.env, ex.overrides), + ) }) } } diff --git a/pipe/export_test.go b/pipe/export_test.go new file mode 100644 index 0000000..2812292 --- /dev/null +++ b/pipe/export_test.go @@ -0,0 +1,4 @@ +package pipe + +// This file exports functions to be used only for testing. 
+var UnwrapNopCloser = unwrapNopCloser diff --git a/pipe/function.go b/pipe/function.go index bc5d0bd..fe8abd0 100644 --- a/pipe/function.go +++ b/pipe/function.go @@ -9,7 +9,7 @@ import ( // StageFunc is a function that can be used to power a `goStage`. It // should read its input from `stdin` and write its output to // `stdout`. `stdin` and `stdout` will be closed automatically (if -// necessary) once the function returns. +// non-nil) once the function returns. // // Neither `stdin` nor `stdout` are necessarily buffered. If the // `StageFunc` requires buffering, it needs to arrange that itself. @@ -38,26 +38,53 @@ type goStage struct { err error } +var ( + _ Stage = (*goStage)(nil) +) + func (s *goStage) Name() string { return s.name } -func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { - r, w := io.Pipe() +func (s *goStage) Preferences() StagePreferences { + return StagePreferences{ + StdinPreference: IOPreferenceUndefined, + StdoutPreference: IOPreferenceUndefined, + } +} + +func (s *goStage) Start( + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { + var r io.Reader = stdin + if stdin, ok := stdin.(readerNopCloser); ok { + r = stdin.Reader + } + + var w io.Writer = stdout + if stdout, ok := stdout.(writerNopCloser); ok { + w = stdout.Writer + } + go func() { - s.err = s.f(ctx, env, stdin, w) - if err := w.Close(); err != nil && s.err == nil { - s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err) + s.err = s.f(ctx, env, r, w) + + if stdout != nil { + if err := stdout.Close(); err != nil && s.err == nil { + s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err) + } } + if stdin != nil { if err := stdin.Close(); err != nil && s.err == nil { s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err) } } + close(s.done) }() - return r, nil + return nil } func (s *goStage) Wait() error { diff --git a/pipe/iocopier.go 
b/pipe/iocopier.go deleted file mode 100644 index 78a9143..0000000 --- a/pipe/iocopier.go +++ /dev/null @@ -1,62 +0,0 @@ -package pipe - -import ( - "context" - "errors" - "io" - "os" -) - -// ioCopier is a stage that copies its stdin to a specified -// `io.Writer`. It generates no stdout itself. -type ioCopier struct { - w io.WriteCloser - done chan struct{} - err error -} - -func newIOCopier(w io.WriteCloser) *ioCopier { - return &ioCopier{ - w: w, - done: make(chan struct{}), - } -} - -func (s *ioCopier) Name() string { - return "ioCopier" -} - -// This method always returns `nil, nil`. -func (s *ioCopier) Start(_ context.Context, _ Env, r io.ReadCloser) (io.ReadCloser, error) { - go func() { - _, err := io.Copy(s.w, r) - // We don't consider `ErrClosed` an error (FIXME: is this - // correct?): - if err != nil && !errors.Is(err, os.ErrClosed) { - s.err = err - } - if err := r.Close(); err != nil && s.err == nil { - s.err = err - } - if err := s.w.Close(); err != nil && s.err == nil { - s.err = err - } - close(s.done) - }() - - // FIXME: if `s.w.Write()` is blocking (e.g., because there is a - // downstream process that is not reading from the other side), - // there's no way to terminate the copy when the context expires. - // This is not too bad, because the `io.Copy()` call will exit by - // itself when its input is closed. - // - // We could, however, be smarter about exiting more quickly if the - // context expires but `s.w.Write()` is not blocking. - - return nil, nil -} - -func (s *ioCopier) Wait() error { - <-s.done - return s.err -} diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index f21ee15..b2d7b39 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -11,12 +11,12 @@ import ( const memoryPollInterval = time.Second -// ErrMemoryLimitExceeded is the error that will be used to kill a process, if -// necessary, from MemoryLimit. 
+// ErrMemoryLimitExceeded is the error that will be used to kill a +// process, if necessary, from MemoryLimit. var ErrMemoryLimitExceeded = errors.New("memory limit exceeded") -// LimitableStage is the superset of Stage that must be implemented by stages -// passed to MemoryLimit and MemoryObserver. +// LimitableStage is the superset of `Stage` that must be implemented +// by stages passed to MemoryLimit and MemoryObserver. type LimitableStage interface { Stage @@ -175,12 +175,24 @@ func (m *memoryWatchStage) Name() string { return m.stage.Name() + m.nameSuffix } -func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { - io, err := m.stage.Start(ctx, env, stdin) - if err != nil { - return nil, err +func (m *memoryWatchStage) Preferences() StagePreferences { + return m.stage.Preferences() +} + +func (m *memoryWatchStage) Start( + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { + if err := m.stage.Start(ctx, env, stdin, stdout); err != nil { + return err } + m.monitor(ctx) + + return nil +} + +// monitor starts up a goroutine that monitors the memory of `m`. 
+func (m *memoryWatchStage) monitor(ctx context.Context) { ctx, cancel := context.WithCancel(ctx) m.cancel = cancel m.wg.Add(1) @@ -189,8 +201,6 @@ func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadClos m.watch(ctx, m.stage) m.wg.Done() }() - - return io, nil } func (m *memoryWatchStage) Wait() error { diff --git a/pipe/memorylimit_test.go b/pipe/memorylimit_test.go index 7501c80..e84f465 100644 --- a/pipe/memorylimit_test.go +++ b/pipe/memorylimit_test.go @@ -8,6 +8,7 @@ import ( "log" "os" "strings" + "syscall" "testing" "time" @@ -112,54 +113,36 @@ func TestMemoryLimitTreeMem(t *testing.T) { require.ErrorContains(t, err, "memory limit exceeded") } -type closeWrapper struct { - io.Writer - close func() error -} - -func (w closeWrapper) Close() error { - return w.close() -} - func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { ctx := context.Background() - stdinReader, stdinWriter := io.Pipe() - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) require.NoError(t, err) - // io.Pipe doesn't know if anything is listening on the other end, so once - // our process is expectedly killed then we'll end up blocked trying to - // write to it. To workaround this, make sure we close the pipe reader when - // we've detected that the process has exited (i.e. when stdout has been - // closed). This will cause our write to immediately fail with this error. 
- closedErr := fmt.Errorf("stdout was closed") - stdout := closeWrapper{ - Writer: devNull, - close: func() error { - require.NoError(t, stdinReader.CloseWithError(closedErr)) - return nil - }, - } - buf := &bytes.Buffer{} logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdoutCloser(stdout)) - p.Add(pipe.MemoryLimit(stage, limit, LogEventHandler(logger))) + p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull)) + p.Add( + pipe.Function( + "write-to-less", + func(ctx context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + // Write some nonsense data to less. + var bytes [1_000_000]byte + for i := 0; i < mbs; i++ { + _, err := stdout.Write(bytes[:]) + if err != nil { + require.ErrorIs(t, err, syscall.EPIPE) + } + } + + return nil + }, + ), + pipe.MemoryLimit(stage, limit, LogEventHandler(logger)), + ) require.NoError(t, p.Start(ctx)) - // Write some nonsense data to less. - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdinWriter.Write(bytes[:]) - if err != nil { - require.ErrorIs(t, err, closedErr) - } - } - - require.NoError(t, stdinWriter.Close()) err = p.Wait() return buf.String(), err diff --git a/pipe/nop_closer.go b/pipe/nop_closer.go index d435d0a..18cf7a9 100644 --- a/pipe/nop_closer.go +++ b/pipe/nop_closer.go @@ -6,29 +6,64 @@ package pipe import "io" -// newNopCloser returns a ReadCloser with a no-op Close method wrapping -// the provided io.Reader r. -// If r implements io.WriterTo, the returned io.ReadCloser will implement io.WriterTo -// by forwarding calls to r. -func newNopCloser(r io.Reader) io.ReadCloser { +// newReaderNopCloser returns a ReadCloser with a no-op Close method, +// wrapping the provided io.Reader `r`. If `r` implements +// `io.WriterTo`, the returned `io.ReadCloser` will also implement +// `io.WriterTo` by forwarding calls to `r`. 
+func newReaderNopCloser(r io.Reader) io.ReadCloser { if _, ok := r.(io.WriterTo); ok { - return nopCloserWriterTo{r} + return readerWriterToNopCloser{r} } - return nopCloser{r} + return readerNopCloser{r} } -type nopCloser struct { +// readerNopCloser is a ReadCloser that wraps a provided `io.Reader`, +// but whose `Close()` method does nothing. We don't need to check +// whether the wrapped reader also implements `io.WriterTo`, since +// it's always unwrapped before use. +type readerNopCloser struct { io.Reader } -func (nopCloser) Close() error { return nil } +func (readerNopCloser) Close() error { + return nil +} -type nopCloserWriterTo struct { +// readerWriterToNopCloser is like `readerNopCloser` except that it +// also implements `io.WriterTo` by delegating `WriteTo()` to the +// wrapped `io.Reader` (which must also implement `io.WriterTo`). +type readerWriterToNopCloser struct { io.Reader } -func (nopCloserWriterTo) Close() error { return nil } +func (readerWriterToNopCloser) Close() error { return nil } + +func (r readerWriterToNopCloser) WriteTo(w io.Writer) (n int64, err error) { + return r.Reader.(io.WriterTo).WriteTo(w) +} + +// writerNopCloser is a WriteCloser that wraps a provided `io.Writer`, +// but whose `Close()` method does nothing. +type writerNopCloser struct { + io.Writer +} + +func (w writerNopCloser) Close() error { + return nil +} -func (c nopCloserWriterTo) WriteTo(w io.Writer) (n int64, err error) { - return c.Reader.(io.WriterTo).WriteTo(w) +// unwrapNopCloser unwraps the object if it is some kind of nop +// closer, and returns the underlying object. This function is used +// only for testing. 
+func unwrapNopCloser(obj any) (any, bool) { + switch obj := obj.(type) { + case readerNopCloser: + return obj.Reader, true + case readerWriterToNopCloser: + return obj.Reader, true + case writerNopCloser: + return obj.Writer, true + default: + return nil, false + } } diff --git a/pipe/pipe_matching_test.go b/pipe/pipe_matching_test.go new file mode 100644 index 0000000..5528ba2 --- /dev/null +++ b/pipe/pipe_matching_test.go @@ -0,0 +1,376 @@ +package pipe_test + +import ( + "context" + "fmt" + "io" + "os" + "testing" + + "github.com/github/go-pipe/pipe" + "github.com/stretchr/testify/assert" +) + +// Tests that `Pipeline.Start()` uses the correct types of pipes in +// various situations. +// +// The type of pipe to use depends on both the source and the consumer +// of the data, including the overall pipeline's stdin and stdout. So +// there are a lot of possibilities to consider. + +// Additional values used for the expected types of stdin/stdout: +const ( + IOPreferenceUndefinedNopCloser pipe.IOPreference = iota + 100 + IOPreferenceFileNopCloser +) + +func file(t *testing.T) *os.File { + f, err := os.Open(os.DevNull) + assert.NoError(t, err) + return f +} + +func readCloser() io.ReadCloser { + r, w := io.Pipe() + w.Close() + return r +} + +func writeCloser() io.WriteCloser { + r, w := io.Pipe() + r.Close() + return w +} + +func newPipeSniffingStage( + stdinPreference, stdinExpectation pipe.IOPreference, + stdoutPreference, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStage { + return &pipeSniffingStage{ + prefs: pipe.StagePreferences{ + StdinPreference: stdinPreference, + StdoutPreference: stdoutPreference, + }, + expect: pipe.StagePreferences{ + StdinPreference: stdinExpectation, + StdoutPreference: stdoutExpectation, + }, + } +} + +func newPipeSniffingFunc( + stdinExpectation, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStage { + return newPipeSniffingStage( + pipe.IOPreferenceUndefined, stdinExpectation, + pipe.IOPreferenceUndefined, 
stdoutExpectation, + ) +} + +func newPipeSniffingCmd( + stdinExpectation, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStage { + return newPipeSniffingStage( + pipe.IOPreferenceFile, stdinExpectation, + pipe.IOPreferenceFile, stdoutExpectation, + ) +} + +type pipeSniffingStage struct { + prefs pipe.StagePreferences + expect pipe.StagePreferences + stdin io.ReadCloser + stdout io.WriteCloser +} + +func (*pipeSniffingStage) Name() string { + return "pipe-sniffer" +} + +func (s *pipeSniffingStage) Preferences() pipe.StagePreferences { + return s.prefs +} + +func (s *pipeSniffingStage) Start( + _ context.Context, _ pipe.Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { + s.stdin = stdin + if stdin != nil { + _ = stdin.Close() + } + s.stdout = stdout + if stdout != nil { + _ = stdout.Close() + } + return nil +} + +func (s *pipeSniffingStage) check(t *testing.T, i int) { + t.Helper() + + checkStdinExpectation(t, i, s.expect.StdinPreference, s.stdin) + checkStdoutExpectation(t, i, s.expect.StdoutPreference, s.stdout) +} + +func (s *pipeSniffingStage) Wait() error { + return nil +} + +var _ pipe.Stage = (*pipeSniffingStage)(nil) + +func ioTypeString(f any) string { + if f == nil { + return "nil" + } + if f, ok := pipe.UnwrapNopCloser(f); ok { + return fmt.Sprintf("nopCloser(%s)", ioTypeString(f)) + } + switch f := f.(type) { + case *os.File: + return "*os.File" + case io.Reader: + return "other" + case io.Writer: + return "other" + default: + return fmt.Sprintf("%T", f) + } +} + +func prefString(pref pipe.IOPreference) string { + switch pref { + case pipe.IOPreferenceUndefined: + return "other" + case pipe.IOPreferenceFile: + return "*os.File" + case pipe.IOPreferenceNil: + return "nil" + case IOPreferenceUndefinedNopCloser: + return "nopCloser(other)" + case IOPreferenceFileNopCloser: + return "nopCloser(*os.File)" + default: + panic(fmt.Sprintf("invalid IOPreference: %d", pref)) + } +} + +func checkStdinExpectation(t *testing.T, i int, pref 
pipe.IOPreference, stdin io.ReadCloser) { + t.Helper() + + ioType := ioTypeString(stdin) + expType := prefString(pref) + assert.Equalf( + t, expType, ioType, + "stage %d stdin: expected %s, got %s (%T)", i, expType, ioType, stdin, + ) +} + +type WriterNopCloser interface { + NopCloserWriter() io.Writer +} + +func checkStdoutExpectation(t *testing.T, i int, pref pipe.IOPreference, stdout io.WriteCloser) { + t.Helper() + + ioType := ioTypeString(stdout) + expType := prefString(pref) + assert.Equalf( + t, expType, ioType, + "stage %d stdout: expected %s, got %s (%T)", i, expType, ioType, stdout, + ) +} + +type checker interface { + check(t *testing.T, i int) +} + +func TestPipeTypes(t *testing.T) { + ctx := context.Background() + + t.Parallel() + + for _, tc := range []struct { + name string + opts []pipe.Option + stages []pipe.Stage + stdin io.Reader + stdout io.Writer + }{ + { + name: "func", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingFunc(pipe.IOPreferenceNil, pipe.IOPreferenceNil), + }, + }, + { + name: "func-file-stdin", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(IOPreferenceFileNopCloser, pipe.IOPreferenceNil), + }, + }, + { + name: "func-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(pipe.IOPreferenceNil, IOPreferenceFileNopCloser), + }, + }, + { + name: "func-file-stdout-closer", + opts: []pipe.Option{ + pipe.WithStdoutCloser(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(pipe.IOPreferenceNil, pipe.IOPreferenceFile), + }, + }, + { + name: "func-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceUndefined), + }, + }, + { + name: "cmd", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + 
newPipeSniffingCmd(pipe.IOPreferenceNil, pipe.IOPreferenceNil), + }, + }, + { + name: "cmd-file-stdin", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(IOPreferenceFileNopCloser, pipe.IOPreferenceNil), + }, + }, + { + name: "cmd-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(pipe.IOPreferenceNil, IOPreferenceFileNopCloser), + }, + }, + { + name: "cmd-file-stdout-closer", + opts: []pipe.Option{ + pipe.WithStdoutCloser(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(pipe.IOPreferenceNil, pipe.IOPreferenceFile), + }, + }, + { + name: "cmd-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceUndefined), + }, + }, + { + name: "func-func", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(IOPreferenceFileNopCloser, pipe.IOPreferenceUndefined), + newPipeSniffingFunc(pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined), + }, + }, + { + name: "func-cmd", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(pipe.IOPreferenceNil, pipe.IOPreferenceFile), + newPipeSniffingCmd(pipe.IOPreferenceFile, IOPreferenceFileNopCloser), + }, + }, + { + name: "cmd-func", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceFile), + newPipeSniffingFunc(pipe.IOPreferenceFile, pipe.IOPreferenceNil), + }, + }, + { + name: "cmd-cmd", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingCmd(pipe.IOPreferenceNil, pipe.IOPreferenceFile), + newPipeSniffingCmd(pipe.IOPreferenceFile, pipe.IOPreferenceNil), + }, + }, + { + name: "hybrid1", + 
opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingStage( + pipe.IOPreferenceUndefined, pipe.IOPreferenceNil, + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + ), + newPipeSniffingStage( + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + pipe.IOPreferenceFile, pipe.IOPreferenceFile, + ), + newPipeSniffingStage( + pipe.IOPreferenceUndefined, pipe.IOPreferenceFile, + pipe.IOPreferenceUndefined, pipe.IOPreferenceNil, + ), + }, + }, + { + name: "hybrid2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingStage( + pipe.IOPreferenceUndefined, pipe.IOPreferenceNil, + pipe.IOPreferenceUndefined, pipe.IOPreferenceFile, + ), + newPipeSniffingStage( + pipe.IOPreferenceFile, pipe.IOPreferenceFile, + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + ), + newPipeSniffingStage( + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + pipe.IOPreferenceUndefined, pipe.IOPreferenceNil, + ), + }, + }, + } { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := pipe.New(tc.opts...) + p.Add(tc.stages...) + assert.NoError(t, p.Run(ctx)) + for i, s := range tc.stages { + s.(checker).check(t, i) + } + }) + } +} diff --git a/pipe/pipeline.go b/pipe/pipeline.go index e591c63..4fdd4a5 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "os" "sync/atomic" ) @@ -53,7 +54,7 @@ type ContextValuesFunc func(context.Context) []EnvVar type Pipeline struct { env Env - stdin io.Reader + stdin io.ReadCloser stdout io.WriteCloser stages []Stage cancel func() @@ -68,14 +69,6 @@ type Pipeline struct { var emptyEventHandler = func(e *Event) {} -type nopWriteCloser struct { - io.Writer -} - -func (w nopWriteCloser) Close() error { - return nil -} - type NewPipeFn func(opts ...Option) *Pipeline // NewPipeline returns a Pipeline struct with all of the `options` @@ -105,14 +98,58 @@ func WithDir(dir string) Option { // WithStdin assigns stdin to the first command in the pipeline. 
func WithStdin(stdin io.Reader) Option { return func(p *Pipeline) { - p.stdin = stdin + // We don't want the first stage to close `stdin`, and it is + // not even necessarily an `io.ReadCloser`. So wrap it in a + // fake `io.ReadCloser` whose `Close()` method doesn't do + // anything. + // + // We could use `io.NopCloser()` for this purpose, but that + // would have a subtle problem. If the first stage is a + // `Command`, then it wants to set the `exec.Cmd`'s `Stdin` to + // an `io.Reader` corresponding to `p.stdin`. If `Cmd.Stdin` + // is an `*os.File`, then `exec.Cmd` will pass the file + // descriptor to the subcommand directly; there is no need to + // create a pipe and copy the data into the input side of the + // pipe. But if `p.stdin` is not an `*os.File`, then this + // optimization is prevented. And even worse, it also has the + // side effect that the goroutine that copies from `Cmd.Stdin` + // into the pipe doesn't terminate until that fd is closed by + // the writing side. + // + // That isn't always what we want. Consider, for example, the + // following snippet, where the subcommand's stdin is set to + // the stdin of the enclosing Go program, but wrapped with + // `io.NopCloser`: + // + // cmd := exec.Command("ls") + // cmd.Stdin = io.NopCloser(os.Stdin) + // cmd.Stdout = os.Stdout + // cmd.Stderr = os.Stderr + // cmd.Run() + // + // In this case, we don't want the Go program to wait for + // `os.Stdin` to close (because `ls` isn't even trying to read + // from its stdin). But it does: `exec.Cmd` doesn't recognize + // that `Cmd.Stdin` is an `*os.File`, so it sets up a pipe and + // copies the data itself, and this goroutine doesn't + // terminate until `cmd.Stdin` (i.e., the Go program's own + // stdin) is closed. But if, for example, the Go program is + // run from an interactive shell session, that might never + // happen, in which case the program will fail to terminate, + // even after `ls` exits. 
+ // + // So instead, in this special case, we wrap `stdin` in our + // own `nopCloser`, which behaves like `io.NopCloser`, except + // that `pipe.CommandStage` knows how to unwrap it before + // passing it to `exec.Cmd`. + p.stdin = newReaderNopCloser(stdin) } } // WithStdout assigns stdout to the last command in the pipeline. func WithStdout(stdout io.Writer) Option { return func(p *Pipeline) { - p.stdout = nopWriteCloser{stdout} + p.stdout = writerNopCloser{stdout} } } @@ -204,6 +241,12 @@ func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) { } } +type stageStarter struct { + prefs StagePreferences + stdin io.ReadCloser + stdout io.WriteCloser +} + // Start starts the commands in the pipeline. If `Start()` exits // without an error, `Wait()` must also be called, to allow all // resources to be freed. @@ -215,89 +258,94 @@ func (p *Pipeline) Start(ctx context.Context) error { atomic.StoreUint32(&p.started, 1) ctx, p.cancel = context.WithCancel(ctx) - var nextStdin io.ReadCloser + // We need to decide how to start the stages, especially what + // pipes to use to connect adjacent stages (`os.Pipe()` vs. + // `io.Pipe()`) based on the two stages' preferences. + stageStarters := make([]stageStarter, len(p.stages), len(p.stages)+1) + + // Collect information about each stage's type and preferences: + for i, s := range p.stages { + stageStarters[i].prefs = s.Preferences() + } + if p.stdin != nil { - // We don't want the first stage to actually close this, and - // `p.stdin` is not even necessarily an `io.ReadCloser`. So - // wrap it in a fake `io.ReadCloser` whose `Close()` method - // doesn't do anything. - // - // We could use `io.NopCloser()` for this purpose, but it has - // a subtle problem. If the first stage is a `Command`, then - // it wants to set the `exec.Cmd`'s `Stdin` to an `io.Reader` - // corresponding to `p.stdin`. 
If `Cmd.Stdin` is an - // `*os.File`, then the file descriptor can be passed to the - // subcommand directly; there is no need for this process to - // create a pipe and copy the data into the input side of the - // pipe. But if `p.stdin` is not an `*os.File`, then this - // optimization is prevented. And even worse, it also has the - // side effect that the goroutine that copies from `Cmd.Stdin` - // into the pipe doesn't terminate until that fd is closed by - // the writing side. - // - // That isn't always what we want. Consider, for example, the - // following snippet, where the subcommand's stdin is set to - // the stdin of the enclosing Go program, but wrapped with - // `io.NopCloser`: - // - // cmd := exec.Command("ls") - // cmd.Stdin = io.NopCloser(os.Stdin) - // cmd.Stdout = os.Stdout - // cmd.Stderr = os.Stderr - // cmd.Run() - // - // In this case, we don't want the Go program to wait for - // `os.Stdin` to close (because `ls` isn't even trying to read - // from its stdin). But it does: `exec.Cmd` doesn't recognize - // that `Cmd.Stdin` is an `*os.File`, so it sets up a pipe and - // copies the data itself, and this goroutine doesn't - // terminate until `cmd.Stdin` (i.e., the Go program's own - // stdin) is closed. But if, for example, the Go program is - // run from an interactive shell session, that might never - // happen, in which case the program will fail to terminate, - // even after `ls` exits. - // - // So instead, in this special case, we wrap `p.stdin` in our - // own `nopCloser`, which behaves like `io.NopCloser`, except - // that `pipe.CommandStage` knows how to unwrap it before - // passing it to `exec.Cmd`. - nextStdin = newNopCloser(p.stdin) + // Arrange for the input of the 0th stage to come from + // `p.stdin`: + stageStarters[0].stdin = p.stdin } - for i, s := range p.stages { - var err error - stdout, err := s.Start(ctx, p.env, nextStdin) - if err != nil { - // Close the pipe that the previous stage was writing to. 
- // That should cause it to exit even if it's not minding - // its context. - if nextStdin != nil { - _ = nextStdin.Close() - } + if p.stdout != nil { + i := len(p.stages) - 1 + ss := &stageStarters[i] + ss.stdout = p.stdout + } - // Kill and wait for any stages that have been started - // already to finish: - p.cancel() - for _, s := range p.stages[:i] { - _ = s.Wait() + // Clean up any processes and pipes that have been created. `i` is + // the index of the stage that failed to start (whose output pipe + // has already been cleaned up if necessary). + abort := func(i int, err error) error { + // Close the pipe that the previous stage was writing to. + // That should cause it to exit even if it's not minding + // its context. + if stageStarters[i].stdin != nil { + _ = stageStarters[i].stdin.Close() + } + + // Kill and wait for any stages that have been started + // already to finish: + p.cancel() + for _, s := range p.stages[:i] { + _ = s.Wait() + } + p.eventHandler(&Event{ + Command: p.stages[i].Name(), + Msg: "failed to start pipeline stage", + Err: err, + }) + return fmt.Errorf( + "starting pipeline stage %q: %w", p.stages[i].Name(), err, + ) + } + + // Loop over all but the last stage, starting them. By the time we + // get to a stage, its stdin will have already been determined, + // but we still need to figure out its stdout and set the stdin + // that will be used for the subsequent stage. 
+ for i, s := range p.stages[:len(p.stages)-1] { + ss := &stageStarters[i] + nextSS := &stageStarters[i+1] + + // We need to generate a pipe pair for this stage to use + // to communicate with its successor: + if ss.prefs.StdoutPreference == IOPreferenceFile || + nextSS.prefs.StdinPreference == IOPreferenceFile { + // Use an OS-level pipe for the communication: + var err error + nextSS.stdin, ss.stdout, err = os.Pipe() + if err != nil { + return abort(i, err) } - p.eventHandler(&Event{ - Command: s.Name(), - Msg: "failed to start pipeline stage", - Err: err, - }) - return fmt.Errorf("starting pipeline stage %q: %w", s.Name(), err) + } else { + nextSS.stdin, ss.stdout = io.Pipe() + } + if err := s.Start(ctx, p.env, ss.stdin, ss.stdout); err != nil { + nextSS.stdin.Close() + ss.stdout.Close() + return abort(i, err) } - nextStdin = stdout } - // If the pipeline was configured with a `stdout`, add a synthetic - // stage to copy the last stage's stdout to that writer: - if p.stdout != nil { - c := newIOCopier(p.stdout) - p.stages = append(p.stages, c) - // `ioCopier.Start()` never fails: - _, _ = c.Start(ctx, p.env, nextStdin) + // The last stage needs special handling, because its stdout + // doesn't need to flow into another stage (it's already set in + // `ss.stdout` if it's needed). 
+ { + i := len(p.stages) - 1 + s := p.stages[i] + ss := &stageStarters[i] + + if err := s.Start(ctx, p.env, ss.stdin, ss.stdout); err != nil { + return abort(i, err) + } } return nil @@ -305,7 +353,7 @@ func (p *Pipeline) Start(ctx context.Context) error { func (p *Pipeline) Output(ctx context.Context) ([]byte, error) { var buf bytes.Buffer - p.stdout = nopWriteCloser{&buf} + p.stdout = writerNopCloser{&buf} err := p.Run(ctx) return buf.Bytes(), err } diff --git a/pipe/pipeline_test.go b/pipe/pipeline_test.go index d925aee..16a88b4 100644 --- a/pipe/pipeline_test.go +++ b/pipe/pipeline_test.go @@ -87,7 +87,7 @@ func TestPipelineSingleCommandWithStdout(t *testing.T) { } } -func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { +func TestPipelineStdinOSPipeThatIsNeverClosed(t *testing.T) { t.Parallel() // Make sure that the subprocess terminates on its own, as opposed @@ -105,7 +105,10 @@ func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { var stdout bytes.Buffer - p := pipe.New(pipe.WithStdin(r), pipe.WithStdout(&stdout)) + p := pipe.New( + pipe.WithStdin(r), + pipe.WithStdout(&stdout), + ) // Note that this command doesn't read from its stdin, so it will // terminate regardless of whether `w` gets closed: p.Add(pipe.Command("true")) @@ -115,7 +118,7 @@ func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { assert.NoError(t, p.Run(ctx)) } -func TestPipelineStdinThatIsNeverClosed(t *testing.T) { +func TestPipelineIOPipeStdinThatIsNeverClosed(t *testing.T) { t.Skip("test not run because it currently deadlocks") t.Parallel() @@ -131,8 +134,7 @@ func TestPipelineStdinThatIsNeverClosed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - r, w, err := os.Pipe() - require.NoError(t, err) + r, w := io.Pipe() t.Cleanup(func() { _ = w.Close() _ = r.Close() @@ -140,10 +142,8 @@ func TestPipelineStdinThatIsNeverClosed(t *testing.T) { var stdout bytes.Buffer - // The point here is to wrap `r` so that 
`exec.Cmd` doesn't - // recognize that it's an `*os.File`: p := pipe.New( - pipe.WithStdin(io.NopCloser(r)), + pipe.WithStdin(r), pipe.WithStdout(&stdout), ) // Note that this command doesn't read from its stdin, so it will @@ -159,9 +159,7 @@ func TestNontrivialPipeline(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Command("sed", "s/hello/goodbye/"), @@ -172,12 +170,13 @@ func TestNontrivialPipeline(t *testing.T) { } } -func TestPipelineReadFromSlowly(t *testing.T) { +func TestOSPipePipelineReadFromSlowly(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - r, w := io.Pipe() + r, w, err := os.Pipe() + require.NoError(t, err) var buf []byte readErr := make(chan error, 1) @@ -189,14 +188,34 @@ func TestPipelineReadFromSlowly(t *testing.T) { readErr <- err }() - p := pipe.New(pipe.WithStdout(w)) + p := pipe.New(pipe.WithStdoutCloser(w)) p.Add(pipe.Command("echo", "hello world")) assert.NoError(t, p.Run(ctx)) - time.Sleep(100 * time.Millisecond) - // It's not super-intuitive, but `w` has to be closed here so that - // the `io.ReadAll()` call above knows that it's done: - _ = w.Close() + assert.NoError(t, <-readErr) + assert.Equal(t, "hello world\n", string(buf)) +} + +func TestIOPipePipelineReadFromSlowly(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + r, w := io.Pipe() + + var buf []byte + readErr := make(chan error, 1) + + go func() { + time.Sleep(200 * time.Millisecond) + var err error + buf, err = io.ReadAll(r) + readErr <- err + }() + + p := pipe.New(pipe.WithStdoutCloser(w)) + p.Add(pipe.Command("echo", "hello world")) + assert.NoError(t, p.Run(ctx)) assert.NoError(t, <-readErr) assert.Equal(t, "hello world\n", string(buf)) @@ -211,8 +230,6 @@ func TestPipelineReadFromSlowly2(t 
*testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - dir := t.TempDir() - r, w := io.Pipe() var buf []byte @@ -236,15 +253,10 @@ func TestPipelineReadFromSlowly2(t *testing.T) { } }() - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(w)) + p := pipe.New(pipe.WithStdoutCloser(w)) p.Add(pipe.Command("seq", "100")) assert.NoError(t, p.Run(ctx)) - time.Sleep(200 * time.Millisecond) - // It's not super-intuitive, but `w` has to be closed here so that - // the `io.ReadAll()` call above knows that it's done: - _ = w.Close() - assert.NoError(t, <-readErr) assert.Equal(t, 292, len(buf)) } @@ -253,9 +265,7 @@ func TestPipelineTwoCommandsPiping(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Command("echo", "hello world")) assert.Panics(t, func() { p.Add(pipe.Command("")) }) out, err := p.Output(ctx) @@ -283,9 +293,7 @@ func TestPipelineExit(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("false"), pipe.Command("true"), @@ -316,11 +324,10 @@ func TestPipelineInterrupted(t *testing.T) { } t.Parallel() - dir := t.TempDir() stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("sleep", "10")) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) @@ -339,11 +346,10 @@ func TestPipelineCanceled(t *testing.T) { } t.Parallel() - dir := t.TempDir() stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("sleep", "10")) ctx, cancel := context.WithCancel(context.Background()) @@ -367,9 +373,8 @@ func TestLittleEPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( 
pipe.Command("sh", "-c", "sleep 1; echo foo"), pipe.Command("true"), @@ -391,9 +396,8 @@ func TestBigEPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("seq", "100000"), pipe.Command("true"), @@ -415,9 +419,8 @@ func TestIgnoredSIGPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.IgnoreError(pipe.Command("seq", "100000"), pipe.IsSIGPIPE), pipe.Command("echo", "foo"), @@ -434,9 +437,7 @@ func TestFunction(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Print("hello world"), pipe.Function( @@ -464,9 +465,7 @@ func TestPipelineWithFunction(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "-n", "hello world"), pipe.Function( @@ -499,10 +498,23 @@ func (s ErrorStartingStage) Name() string { return "errorStartingStage" } +func (s ErrorStartingStage) Preferences() pipe.StagePreferences { + return pipe.StagePreferences{ + StdinPreference: pipe.IOPreferenceUndefined, + StdoutPreference: pipe.IOPreferenceUndefined, + } +} + func (s ErrorStartingStage) Start( - _ context.Context, _ pipe.Env, _ io.ReadCloser, -) (io.ReadCloser, error) { - return io.NopCloser(&bytes.Buffer{}), s.err + _ context.Context, _ pipe.Env, stdin io.ReadCloser, stdout io.WriteCloser, +) error { + if stdin != nil { + _ = stdin.Close() + } + if stdout != nil { + _ = stdout.Close() + } + return s.err } func (s ErrorStartingStage) Wait() error { @@ -528,9 +540,7 @@ func TestPipelineWithLinewiseFunction(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() // Print the numbers from 1 to 20 (generated from scratch): p.Add( seqFunction(20), @@ -581,7 +591,7 @@ func 
TestScannerAlwaysFlushes(t *testing.T) { var length int64 - p := pipe.New(pipe.WithDir(".")) + p := pipe.New() // Print the numbers from 1 to 20 (generated from scratch): p.Add( pipe.IgnoreError( @@ -629,7 +639,7 @@ func TestScannerFinishEarly(t *testing.T) { var length int64 - p := pipe.New(pipe.WithDir(".")) + p := pipe.New() // Print the numbers from 1 to 20 (generated from scratch): p.Add( pipe.IgnoreError( @@ -670,9 +680,7 @@ func TestPrintln(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Println("Look Ma, no hands!")) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -684,9 +692,7 @@ func TestPrintf(t *testing.T) { t.Parallel() ctx := context.Background() - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Printf("Strangely recursive: %T", p)) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -880,10 +886,8 @@ func TestErrors(t *testing.T) { func BenchmarkSingleProgram(b *testing.B) { ctx := context.Background() - dir := b.TempDir() - for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("true"), ) @@ -894,10 +898,8 @@ func BenchmarkSingleProgram(b *testing.B) { func BenchmarkTenPrograms(b *testing.B) { ctx := context.Background() - dir := b.TempDir() - for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Command("cat"), @@ -920,15 +922,13 @@ func BenchmarkTenPrograms(b *testing.B) { func BenchmarkTenFunctions(b *testing.B) { ctx := context.Background() - dir := b.TempDir() - cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { _, err := io.Copy(stdout, stdin) return err } for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Println("hello world"), pipe.Function("copy1", cp), @@ -951,15 +951,13 @@ func BenchmarkTenFunctions(b 
*testing.B) { func BenchmarkTenMixedStages(b *testing.B) { ctx := context.Background() - dir := b.TempDir() - cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { _, err := io.Copy(stdout, stdin) return err } for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Function("copy1", cp), @@ -979,6 +977,97 @@ func BenchmarkTenMixedStages(b *testing.B) { } } +func BenchmarkMoreDataUnbuffered(b *testing.B) { + ctx := context.Background() + + cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + } + + for i := 0; i < b.N; i++ { + count := 0 + p := pipe.New() + p.Add( + pipe.Function( + "seq", + func(ctx context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + for i := 1; i <= 100000; i++ { + fmt.Fprintln(stdout, i) + } + return nil + }, + ), + pipe.Command("cat"), + pipe.Function("copy2", cp), + pipe.Command("cat"), + pipe.Function("copy3", cp), + pipe.Command("cat"), + pipe.Function("copy4", cp), + pipe.Command("cat"), + pipe.Function("copy5", cp), + pipe.Command("cat"), + pipe.LinewiseFunction( + "count", + func(ctx context.Context, _ pipe.Env, line []byte, stdout *bufio.Writer) error { + count++ + return nil + }, + ), + ) + err := p.Run(ctx) + if assert.NoError(b, err) { + assert.EqualValues(b, 100000, count) + } + } +} + +func BenchmarkMoreDataBuffered(b *testing.B) { + ctx := context.Background() + + cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + } + + for i := 0; i < b.N; i++ { + count := 0 + p := pipe.New() + p.Add( + pipe.Function( + "seq", + func(ctx context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + out := bufio.NewWriter(stdout) + for i := 1; i <= 1000000; i++ { + fmt.Fprintln(out, i) + } + return out.Flush() + }, + ), + pipe.Command("cat"), + 
pipe.Function("copy2", cp), + pipe.Command("cat"), + pipe.Function("copy3", cp), + pipe.Command("cat"), + pipe.Function("copy4", cp), + pipe.Command("cat"), + pipe.Function("copy5", cp), + pipe.Command("cat"), + pipe.LinewiseFunction( + "count", + func(ctx context.Context, _ pipe.Env, line []byte, stdout *bufio.Writer) error { + count++ + return nil + }, + ), + ) + err := p.Run(ctx) + if assert.NoError(b, err) { + assert.EqualValues(b, 1000000, count) + } + } +} + func genErr(err error) pipe.StageFunc { return func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { return err diff --git a/pipe/scanner.go b/pipe/scanner.go index b56b58c..5ec16e8 100644 --- a/pipe/scanner.go +++ b/pipe/scanner.go @@ -56,11 +56,7 @@ func ScannerFunction( return err } } - if err := scanner.Err(); err != nil { - return err - } - - return nil + return scanner.Err() // `p.AddFunction()` arranges for `stdout` to be closed. }, ) diff --git a/pipe/stage.go b/pipe/stage.go index f3d74d9..3e12774 100644 --- a/pipe/stage.go +++ b/pipe/stage.go @@ -5,30 +5,142 @@ import ( "io" ) -// Stage is an element of a `Pipeline`. +// +// From the point of view of the pipeline as a whole, if stdin is +// provided by the user (`WithStdin()`), then we don't want to close +// it at all, whether it's an `*os.File` or not. For this reason, +// stdin has to be wrapped using a `readerNopCloser` before being +// passed into the first stage. For efficiency reasons, it's +// advantageous for the first stage to unwrap its stdin +// argument before actually using it. If the wrapped value is an +// `*os.File` and the stage is a command stage, then unwrapping is +// also important to get the right semantics. +// +// For stdout, it depends on whether the user supplied it using +// `WithStdout()` or `WithStdoutCloser()`. If the former, then the +// considerations are the same as for stdin.
+// +// [1] It's theoretically possible for a command to pass the open file +// descriptor to another, longer-lived process, in which case the +// file descriptor wouldn't necessarily get closed when the +// command finishes. But that's ill-behaved in a command that is +// being used in a pipeline, so we'll ignore that possibility. + +// Stage is an element of a `Pipeline`. It reads from standard input +// and writes to standard output. +// +// Who closes stdin and stdout? +// +// A `Stage` as a whole needs to be responsible for closing its end of +// stdin and stdout (assuming that `Start()` returns successfully). +// Its doing so tells the previous/next stage that it is done +// reading/writing data, which can affect their behavior. Therefore, +// it should close each one as soon as it is done with it. If the +// caller wants to suppress the closing of stdin/stdout, it can always +// wrap the corresponding argument in a "nopCloser". +// +// How this should be done depends on whether stdin/stdout are of type +// `*os.File`. +// +// If a stage is an external command, then the subprocess ultimately +// needs its own copies of `*os.File` file descriptors for its stdin +// and stdout. The external command will "always" [1] close those when +// it exits. +// +// If the stage is an external command and one of the arguments is an +// `*os.File`, then it can set the corresponding field of `exec.Cmd` +// to that argument directly. This has the result that `exec.Cmd` +// duplicates that file descriptor and passes the dup to the +// subprocess. Therefore, the stage must close its copy of that +// argument as soon as the external command has started, because the +// external command will keep its own copy open as long as necessary +// (and no longer!). 
It should use roughly the following sequence: +// +// cmd.Stdin = f // Similarly for stdout +// cmd.Start(…) +// f.Close() // close our copy +// cmd.Wait() +// +// If the stage is an external command and one of its arguments is not +// an `*os.File`, then `exec.Cmd` will take care of creating an +// `os.Pipe()`, copying from the provided argument in/out of the pipe, +// and eventually closing both ends of the pipe. The stage must close +// the argument itself, but only _after_ the external command has +// finished, like so: +// +// cmd.Stdin = r // Similarly for stdout +// cmd.Start(…) +// cmd.Wait() +// r.Close() +// +// If the stage is a Go function, then it holds the only copy of +// stdin/stdout, so it must wait until the function is done before +// closing them (regardless of their underlying type), like so: +// +// go func() { +// f(…, stdin, stdout) +// stdin.Close() +// stdout.Close() +// }() type Stage interface { // Name returns the name of the stage. Name() string + // Preferences() returns this stage's preferences regarding how it + // should be run. + Preferences() StagePreferences + // Start starts the stage in the background, in the environment - // described by `env`, and using `stdin` as input. (`stdin` should - // be set to `nil` if the stage is to receive no input, which - // might be the case for the first stage in a pipeline.) It - // returns an `io.ReadCloser` from which the stage's output can be - // read (or `nil` if it generates no output, which should only be - // the case for the last stage in a pipeline). It is the stages' - // responsibility to close `stdin` (if it is not nil) when it has - // read all of the input that it needs, and to close the write end - // of its output reader when it is done, as that is generally how - // the subsequent stage knows that it has received all of its - // input and can finish its work, too. + // described by `env`, using `stdin` to provide its input and + // `stdout` to collect its output.
(`stdin`/`stdout` might be set + // to `nil` if the stage is to receive no input or produce no output, which might be + // the case for the first/last stage in a pipeline.) See the + // `Stage` type comment for more information about responsibility + // for closing stdin and stdout. // // If `Start()` returns without an error, `Wait()` must also be // called, to allow all resources to be freed. - Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) + Start(ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser) error // Wait waits for the stage to be done, either because it has // finished or because it has been killed due to the expiration of // the context passed to `Start()`. Wait() error } + +// StagePreferences is the way that a `Stage` indicates its +// preferences about how it is run. This is used within +// `pipe.Pipeline` to decide when to use `os.Pipe()` vs. `io.Pipe()` +// for creating the pipes between stages. +type StagePreferences struct { + StdinPreference IOPreference + StdoutPreference IOPreference +} + +// IOPreference describes what type of stdin / stdout a stage would +// prefer. +// +// External commands prefer `*os.File`s (such as those produced by +// `os.Pipe()`) as their stdin and stdout, because those can be passed +// directly to the external process without any extra copying and also +// simplify the semantics around process termination. Go function +// stages are typically happy with any `io.ReadCloser` (such as one +// produced by `io.Pipe()`), which can be more efficient because +// traffic through an `io.Pipe()` happens entirely in userspace. +type IOPreference int + +const ( + // IOPreferenceUndefined indicates that the stage doesn't care + // what form the specified stdin / stdout takes (i.e., any old + // `io.ReadCloser` / `io.WriteCloser` is just fine).
+ IOPreferenceUndefined IOPreference = iota + + // IOPreferenceFile indicates that the stage would prefer for the + // specified stdin / stdout to be an `*os.File`, to avoid copying. + IOPreferenceFile + + // IOPreferenceNil indicates that the stage does not use the + // specified stdin / stdout, so `nil` should be passed in. This + // should only happen at the beginning / end of a pipeline. + IOPreferenceNil +)