diff --git a/go/client.go b/go/client.go index c88a68ac3..d522b3bce 100644 --- a/go/client.go +++ b/go/client.go @@ -34,6 +34,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "os" "os/exec" @@ -91,6 +92,9 @@ type Client struct { processDone chan struct{} processErrorPtr *error osProcess atomic.Pointer[os.Process] + stderrBuf []byte + stderrBufMux sync.Mutex + stderrDone chan struct{} // closed when the current process's stderr drain goroutine finishes // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). @@ -275,15 +279,17 @@ func (c *Client) Start(ctx context.Context) error { // Connect to the server if err := c.connectToServer(ctx); err != nil { killErr := c.killProcess() + stderrErr := c.stderrError() c.state = StateError - return errors.Join(err, killErr) + return errors.Join(err, killErr, stderrErr) } // Verify protocol version compatibility if err := c.verifyProtocolVersion(ctx); err != nil { killErr := c.killProcess() + stderrErr := c.stderrError() c.state = StateError - return errors.Join(err, killErr) + return errors.Join(err, killErr, stderrErr) } c.state = StateConnected @@ -335,6 +341,9 @@ func (c *Client) Stop() error { } } c.process = nil + c.stderrBufMux.Lock() + c.stderrBuf = nil + c.stderrBufMux.Unlock() // Close external TCP connection if exists if c.isExternalServer && c.conn != nil { @@ -408,6 +417,9 @@ func (c *Client) ForceStop() { _ = c.killProcess() // Ignore errors since we're force stopping } c.process = nil + c.stderrBufMux.Lock() + c.stderrBuf = nil + c.stderrBufMux.Unlock() // Close external TCP connection if exists if c.isExternalServer && c.conn != nil { @@ -434,6 +446,42 @@ func (c *Client) ForceStop() { c.RPC = nil } +func (c *Client) getStderrOutput() string { + c.stderrBufMux.Lock() + defer c.stderrBufMux.Unlock() + return string(c.stderrBuf) +} + +func (c *Client) stderrError() error { + if output := c.getStderrOutput(); output != "" { + return errors.New("stderr: " + output) + } + return nil +} + +func (c *Client) startStderrDrain(stderr io.ReadCloser, done chan struct{}) { + go func() { + defer close(done) + buf := make([]byte, 1024) + for { + n, err := stderr.Read(buf) + if n > 0 { + c.stderrBufMux.Lock() + // Append to buffer, keep tail if > 64KB + c.stderrBuf = append(c.stderrBuf, buf[:n]...) + if len(c.stderrBuf) > 64*1024 { + n := copy(c.stderrBuf, c.stderrBuf[len(c.stderrBuf)-64*1024:]) + c.stderrBuf = c.stderrBuf[:n] + } + c.stderrBufMux.Unlock() + } + if err != nil { + return + } + } + }() +} + func (c *Client) ensureConnected() error { if c.client != nil { return nil @@ -1105,6 +1153,11 @@ func (c *Client) startCLIServer(ctx context.Context) error { c.process.Env = append(c.process.Env, "COPILOT_SDK_AUTH_TOKEN="+c.options.GitHubToken) } + // Clear previous stderr buffer + c.stderrBufMux.Lock() + c.stderrBuf = nil + c.stderrBufMux.Unlock() + if c.useStdio { // For stdio mode, we need stdin/stdout pipes stdin, err := c.process.StdinPipe() @@ -1117,11 +1170,17 @@ func (c *Client) startCLIServer(ctx context.Context) error { return fmt.Errorf("failed to create stdout pipe: %w", err) } + stderr, err := c.process.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + if err := c.process.Start(); err != nil { - return fmt.Errorf("failed to start CLI server: %w", err) + closeErr := stderr.Close() + return errors.Join(fmt.Errorf("failed to start CLI server: %w", err), closeErr) } - c.monitorProcess() + c.monitorProcess(stderr) // Create JSON-RPC client immediately c.client = jsonrpc2.NewClient(stdin, stdout) @@ -1138,11 +1197,17 @@ func (c *Client) startCLIServer(ctx context.Context) error { return fmt.Errorf("failed to create stdout pipe: %w", err) } + stderr, err := c.process.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + if err := c.process.Start(); err != nil { - return fmt.Errorf("failed to start CLI server: %w", err) + closeErr := stderr.Close() + return errors.Join(fmt.Errorf("failed to start CLI server: %w", err), closeErr) } - c.monitorProcess() + c.monitorProcess(stderr) scanner := bufio.NewScanner(stdout) timeout := time.After(10 * time.Second) @@ -1152,10 +1217,12 @@ func (c *Client) startCLIServer(ctx context.Context) error { select { case <-timeout: killErr := c.killProcess() - return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr) + stderrErr := c.stderrError() + return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr, stderrErr) case <-c.processDone: killErr := c.killProcess() - return errors.Join(errors.New("CLI server process exited before reporting port"), killErr) + stderrErr := c.stderrError() + return errors.Join(errors.New("CLI server process exited before reporting port"), killErr, stderrErr) default: if scanner.Scan() { line := scanner.Text() @@ -1163,7 +1230,8 @@ func (c *Client) startCLIServer(ctx context.Context) error { port, err := strconv.Atoi(matches[1]) if err != nil { killErr := c.killProcess() - return errors.Join(fmt.Errorf("failed to parse port: %w", err), killErr) + stderrErr := c.stderrError() + return errors.Join(fmt.Errorf("failed to parse port: %w", err), killErr, stderrErr) } c.actualPort = port return nil @@ -1175,13 +1243,17 @@ func (c *Client) startCLIServer(ctx context.Context) error { } func (c *Client) killProcess() error { + var err error if p := c.osProcess.Swap(nil); p != nil { - if err := p.Kill(); err != nil { - return fmt.Errorf("failed to kill CLI process: %w", err) + if err = p.Kill(); err != nil { + err = fmt.Errorf("failed to kill CLI process: %w", err) } } + if c.stderrDone != nil { + <-c.stderrDone + } c.process = nil - return nil + return err } // monitorProcess signals when the CLI process exits and captures any exit error. @@ -1189,7 +1261,10 @@ func (c *Client) killProcess() error { // error value, so goroutines from previous processes can't overwrite the // current one. Closing the channel synchronizes with readers, guaranteeing // they see the final processError value. -func (c *Client) monitorProcess() { +func (c *Client) monitorProcess(stderr io.ReadCloser) { + stderrDone := make(chan struct{}) + c.stderrDone = stderrDone + c.startStderrDrain(stderr, stderrDone) done := make(chan struct{}) c.processDone = done proc := c.process @@ -1198,6 +1273,7 @@ func (c *Client) monitorProcess() { c.processErrorPtr = &processError go func() { waitErr := proc.Wait() + <-stderrDone if waitErr != nil { processError = fmt.Errorf("CLI process exited: %w", waitErr) } else { diff --git a/go/client_test.go b/go/client_test.go index 752bdc758..91e42b147 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -3,9 +3,11 @@ package copilot import ( "encoding/json" "os" + "os/exec" "path/filepath" "reflect" "regexp" + "strings" "sync" "testing" ) @@ -528,3 +530,96 @@ func TestClient_StartStopRace(t *testing.T) { t.Fatal(err) } } + +func TestClient_StderrCapture(t *testing.T) { + buildStderrFixtureCLI := func(t *testing.T) string { + t.Helper() + + tmpDir := t.TempDir() + mainPath := filepath.Join(tmpDir, "main.go") + binaryName := "stderr-fixture.exe" + binaryPath := filepath.Join(tmpDir, binaryName) + + source := `package main + + import ( + "os" + "strconv" + "strings" + ) + + func main() { + if sizeString := os.Getenv("TEST_STDERR_SIZE"); sizeString != "" { + if size, _ := strconv.Atoi(sizeString); size > 0 { + _, _ = os.Stderr.WriteString(strings.Repeat("x", size)) + } + } + if message := os.Getenv("TEST_STDERR"); message != "" { + _, _ = os.Stderr.WriteString(message) + } + os.Exit(1) + }` + + if err := os.WriteFile(mainPath, []byte(source), 0600); err != nil { + t.Fatal(err) + } + + cmd := exec.Command("go", "build", "-o", binaryPath, mainPath) + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatal(err, "\n", string(output)) + } + + return binaryPath + } + + fixturePath := buildStderrFixtureCLI(t) + + t.Run("captures stderr on startup failure", func(t *testing.T) { + client := NewClient(&ClientOptions{ + CLIPath: fixturePath, + Env: append(os.Environ(), "TEST_STDERR=something went wrong"), + }) + + err := client.Start(t.Context()) + if err == nil { + t.Fatal("Expected error, got nil") + } + + if !strings.Contains(err.Error(), "something went wrong") { + t.Errorf("Expected error to contain stderr output, got: %v", err) + } + }) + + t.Run("caps stderr buffer", func(t *testing.T) { + client := NewClient(&ClientOptions{ + CLIPath: fixturePath, + Env: append(os.Environ(), "TEST_STDERR_SIZE=70000"), + }) + + if err := client.Start(t.Context()); err == nil { + t.Fatal("Expected error, got nil") + } + + output := client.getStderrOutput() + if len(output) > 64*1024+100 { // Allow some slack but it should be close to 64KB + t.Errorf("Expected buffer to be capped around 64KB, got %d bytes", len(output)) + } + }) + + t.Run("clears buffer on stop", func(t *testing.T) { + client := NewClient(&ClientOptions{}) + client.stderrBufMux.Lock() + client.stderrBuf = []byte("dirty buffer") + client.stderrBufMux.Unlock() + + if err := client.Stop(); err != nil { + t.Fatal(err) + } + + client.stderrBufMux.Lock() + if client.stderrBuf != nil { + t.Error("Expected stderrBuf to be nil after Stop") + } + client.stderrBufMux.Unlock() + }) +} diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index 09505c06d..997828d51 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -59,8 +59,7 @@ type Client struct { stopChan chan struct{} wg sync.WaitGroup processDone chan struct{} // closed when the underlying process exits - processError error // set before processDone is closed - processErrorMu sync.RWMutex // protects processError + processErrorPtr *error // points to error set before processDone is closed } // NewClient creates a new JSON-RPC client @@ -78,22 +77,17 @@ func NewClient(stdin io.WriteCloser, stdout io.ReadCloser) *Client { // and stores the error that should be returned to pending/future requests. func (c *Client) SetProcessDone(done chan struct{}, errPtr *error) { c.processDone = done - // Monitor the channel and copy the error when it closes - go func() { - <-done - if errPtr != nil { - c.processErrorMu.Lock() - c.processError = *errPtr - c.processErrorMu.Unlock() - } - }() + c.processErrorPtr = errPtr } -// getProcessError returns the process exit error if the process has exited +// getProcessError returns the process exit error if the process has exited. +// Must only be called after <-c.processDone to ensure visibility of the error +// written before close(done) in the monitor goroutine. func (c *Client) getProcessError() error { - c.processErrorMu.RLock() - defer c.processErrorMu.RUnlock() - return c.processError + if c.processErrorPtr != nil { + return *c.processErrorPtr + } + return nil } // Start begins listening for messages in a background goroutine