Skip to content
Open
137 changes: 110 additions & 27 deletions pipe/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@ import (
"golang.org/x/sync/errgroup"
)

// commandStage is a pipeline `Stage` based on running an external
// command and piping the data through its stdin and stdout.
// commandStage is a pipeline `Stage2` based on running an external
// command and piping the data through its stdin and stdout. It also
// implements `Stage2`.
type commandStage struct {
name string
stdin io.Closer
cmd *exec.Cmd
name string
cmd *exec.Cmd

// lateClosers is a list of things that have to be closed once the
// command has finished.
lateClosers []io.Closer

done chan struct{}
wg errgroup.Group
stderr bytes.Buffer
Expand All @@ -30,11 +35,15 @@ type commandStage struct {
ctxErr atomic.Value
}

// Command returns a pipeline `Stage` based on the specified external
var (
_ Stage2 = (*commandStage)(nil)
)

// Command returns a pipeline `Stage2` based on the specified external
// `command`, run with the given command-line `args`. Its stdin and
// stdout are handled as usual, and its stderr is collected and
// included in any `*exec.ExitError` that the command might emit.
func Command(command string, args ...string) Stage {
func Command(command string, args ...string) Stage2 {
if len(command) == 0 {
panic("attempt to create command with empty command")
}
Expand All @@ -47,7 +56,7 @@ func Command(command string, args ...string) Stage {
// the specified `cmd`. Its stdin and stdout are handled as usual, and
// its stderr is collected and included in any `*exec.ExitError` that
// the command might emit.
func CommandStage(name string, cmd *exec.Cmd) Stage {
func CommandStage(name string, cmd *exec.Cmd) Stage2 {
return &commandStage{
name: name,
cmd: cmd,
Expand All @@ -62,30 +71,101 @@ func (s *commandStage) Name() string {
func (s *commandStage) Start(
ctx context.Context, env Env, stdin io.ReadCloser,
) (io.ReadCloser, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}

if err := s.Start2(ctx, env, stdin, pw); err != nil {
_ = pr.Close()
_ = pw.Close()
return nil, err
}

// Now close our copy of the write end of the pipe (the subprocess
// has its own copy now and will keep it open as long as it is
// running). There's not much we can do now in the case of an
// error, so just ignore them.
_ = pw.Close()

// The caller is responsible for closing `pr`.
return pr, nil
}

func (s *commandStage) Preferences() StagePreferences {
prefs := StagePreferences{
StdinPreference: IOPreferenceFile,
StdoutPreference: IOPreferenceFile,
}
if s.cmd.Stdin != nil {
prefs.StdinPreference = IOPreferenceNil
}
if s.cmd.Stdout != nil {
prefs.StdoutPreference = IOPreferenceNil
}

return prefs
}

func (s *commandStage) Start2(
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser,
) error {
if s.cmd.Dir == "" {
s.cmd.Dir = env.Dir
}

s.setupEnv(ctx, env)

// Things that have to be closed as soon as the command has
// started:
var earlyClosers []io.Closer

// See the type command for `Stage` and the long comment in
// `Pipeline.WithStdin()` for the explanation of this unwrapping
// and closing behavior.

if stdin != nil {
// See the long comment in `Pipeline.Start()` for the
// explanation of this special case.
switch stdin := stdin.(type) {
case nopCloser:
case readerNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdin = stdin.Reader
case nopCloserWriterTo:
case readerWriterToNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdin = stdin.Reader
case *os.File:
// In this case, we can close stdin as soon as the command
// has started:
s.cmd.Stdin = stdin
earlyClosers = append(earlyClosers, stdin)
default:
// In this case, we need to close `stdin`, but we should
// only do so after the command has finished:
s.cmd.Stdin = stdin
s.lateClosers = append(s.lateClosers, stdin)
}
// Also keep a copy so that we can close it when the command exits:
s.stdin = stdin
}

stdout, err := s.cmd.StdoutPipe()
if err != nil {
return nil, err
if stdout != nil {
// See the long comment in `Pipeline.Start()` for the
// explanation of this special case.
switch stdout := stdout.(type) {
case writerNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdout = stdout.Writer
case *os.File:
// In this case, we can close stdout as soon as the command
// has started:
s.cmd.Stdout = stdout
earlyClosers = append(earlyClosers, stdout)
default:
// In this case, we need to close `stdout`, but we should
// only do so after the command has finished:
s.cmd.Stdout = stdout
s.lateClosers = append(s.lateClosers, stdout)
}
}

// If the caller hasn't arranged otherwise, read the command's
Expand All @@ -97,7 +177,7 @@ func (s *commandStage) Start(
// can be sure.
p, err := s.cmd.StderrPipe()
if err != nil {
return nil, err
return err
}
s.wg.Go(func() error {
_, err := io.Copy(&s.stderr, p)
Expand All @@ -114,7 +194,11 @@ func (s *commandStage) Start(
s.runInOwnProcessGroup()

if err := s.cmd.Start(); err != nil {
return nil, err
return err
}

for _, closer := range earlyClosers {
_ = closer.Close()
}

// Arrange for the process to be killed (gently) if the context
Expand All @@ -128,7 +212,7 @@ func (s *commandStage) Start(
}
}()

return stdout, nil
return nil
}

// setupEnv sets or modifies the environment that will be passed to
Expand Down Expand Up @@ -217,19 +301,18 @@ func (s *commandStage) Wait() error {

// Make sure that any stderr is copied before `s.cmd.Wait()`
// closes the read end of the pipe:
wErr := s.wg.Wait()
wgErr := s.wg.Wait()

err := s.cmd.Wait()
err = s.filterCmdError(err)

if err == nil && wErr != nil {
err = wErr
if err == nil && wgErr != nil {
err = wgErr
}

if s.stdin != nil {
cErr := s.stdin.Close()
if cErr != nil && err == nil {
return cErr
for _, closer := range s.lateClosers {
if closeErr := closer.Close(); closeErr != nil && err == nil {
err = closeErr
}
}

Expand Down
2 changes: 1 addition & 1 deletion pipe/command_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

// On linux, we can limit or observe memory usage in command stages.
var _ LimitableStage = (*commandStage)(nil)
var _ LimitableStage2 = (*commandStage)(nil)

var (
errProcessInfoMissing = errors.New("cmd.Process is nil")
Expand Down
3 changes: 2 additions & 1 deletion pipe/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ func TestCopyEnvWithOverride(t *testing.T) {
ex := ex
t.Run(ex.label, func(t *testing.T) {
assert.ElementsMatch(t, ex.expectedResult,
copyEnvWithOverrides(ex.env, ex.overrides))
copyEnvWithOverrides(ex.env, ex.overrides),
)
})
}
}
4 changes: 4 additions & 0 deletions pipe/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package pipe

// This file exports a functions to be used only for testing.
var UnwrapNopCloser = unwrapNopCloser
12 changes: 12 additions & 0 deletions pipe/filter-error.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
type ErrorFilter func(err error) error

func FilterError(s Stage, filter ErrorFilter) Stage {
if s, ok := s.(Stage2); ok {
return efStage2{Stage2: s, filter: filter}
}
return efStage{Stage: s, filter: filter}
}

Expand All @@ -26,6 +29,15 @@ func (s efStage) Wait() error {
return s.filter(s.Stage.Wait())
}

type efStage2 struct {
Stage2
filter ErrorFilter
}

func (s efStage2) Wait() error {
return s.filter(s.Stage2.Wait())
}

// ErrorMatcher decides whether its argument matches some class of
// errors (e.g., errors that we want to ignore). The function will
// only be invoked for non-nil errors.
Expand Down
51 changes: 45 additions & 6 deletions pipe/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// StageFunc is a function that can be used to power a `goStage`. It
// should read its input from `stdin` and write its output to
// `stdout`. `stdin` and `stdout` will be closed automatically (if
// necessary) once the function returns.
// non-nil) once the function returns.
//
// Neither `stdin` nor `stdout` are necessarily buffered. If the
// `StageFunc` requires buffering, it needs to arrange that itself.
Expand Down Expand Up @@ -38,26 +38,65 @@ type goStage struct {
err error
}

var (
_ Stage2 = (*goStage)(nil)
)

func (s *goStage) Name() string {
return s.name
}

func (s *goStage) Preferences() StagePreferences {
return StagePreferences{
StdinPreference: IOPreferenceUndefined,
StdoutPreference: IOPreferenceUndefined,
}
}

func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
r, w := io.Pipe()
pr, pw := io.Pipe()

if err := s.Start2(ctx, env, stdin, pw); err != nil {
_ = pr.Close()
_ = pw.Close()
return nil, err
}

return pr, nil
}

func (s *goStage) Start2(
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser,
) error {
var r io.Reader = stdin
if stdin, ok := stdin.(readerNopCloser); ok {
r = stdin.Reader
}

var w io.Writer = stdout
if stdout, ok := stdout.(writerNopCloser); ok {
w = stdout.Writer
}

go func() {
s.err = s.f(ctx, env, stdin, w)
if err := w.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
s.err = s.f(ctx, env, r, w)

if stdout != nil {
if err := stdout.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
}
}

if stdin != nil {
if err := stdin.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
}

close(s.done)
}()

return r, nil
return nil
}

func (s *goStage) Wait() error {
Expand Down
Loading