Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 89 additions & 13 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
Expand Down Expand Up @@ -91,6 +92,9 @@ type Client struct {
processDone chan struct{}
processErrorPtr *error
osProcess atomic.Pointer[os.Process]
stderrBuf []byte
stderrBufMux sync.Mutex
stderrDone chan struct{} // closed when the current process's stderr drain goroutine finishes

// RPC provides typed server-scoped RPC methods.
// This field is nil until the client is connected via Start().
Expand Down Expand Up @@ -275,15 +279,17 @@ func (c *Client) Start(ctx context.Context) error {
// Connect to the server
if err := c.connectToServer(ctx); err != nil {
killErr := c.killProcess()
stderrErr := c.stderrError()
c.state = StateError
return errors.Join(err, killErr)
return errors.Join(err, killErr, stderrErr)
}

// Verify protocol version compatibility
if err := c.verifyProtocolVersion(ctx); err != nil {
killErr := c.killProcess()
stderrErr := c.stderrError()
c.state = StateError
return errors.Join(err, killErr)
return errors.Join(err, killErr, stderrErr)
}

c.state = StateConnected
Expand Down Expand Up @@ -335,6 +341,9 @@ func (c *Client) Stop() error {
}
}
c.process = nil
c.stderrBufMux.Lock()
c.stderrBuf = nil
c.stderrBufMux.Unlock()

// Close external TCP connection if exists
if c.isExternalServer && c.conn != nil {
Expand Down Expand Up @@ -408,6 +417,9 @@ func (c *Client) ForceStop() {
_ = c.killProcess() // Ignore errors since we're force stopping
}
c.process = nil
c.stderrBufMux.Lock()
c.stderrBuf = nil
c.stderrBufMux.Unlock()

// Close external TCP connection if exists
if c.isExternalServer && c.conn != nil {
Expand All @@ -434,6 +446,42 @@ func (c *Client) ForceStop() {
c.RPC = nil
}

func (c *Client) getStderrOutput() string {
c.stderrBufMux.Lock()
defer c.stderrBufMux.Unlock()
return string(c.stderrBuf)
}

func (c *Client) stderrError() error {
if output := c.getStderrOutput(); output != "" {
return errors.New("stderr: " + output)
}
return nil
}

func (c *Client) startStderrDrain(stderr io.ReadCloser, done chan struct{}) {
go func() {
defer close(done)
buf := make([]byte, 1024)
for {
n, err := stderr.Read(buf)
if n > 0 {
c.stderrBufMux.Lock()
// Append to buffer, keep tail if > 64KB
c.stderrBuf = append(c.stderrBuf, buf[:n]...)
if len(c.stderrBuf) > 64*1024 {
n := copy(c.stderrBuf, c.stderrBuf[len(c.stderrBuf)-64*1024:])
c.stderrBuf = c.stderrBuf[:n]
}
c.stderrBufMux.Unlock()
}
if err != nil {
return
}
}
}()
}

func (c *Client) ensureConnected() error {
if c.client != nil {
return nil
Expand Down Expand Up @@ -1105,6 +1153,11 @@ func (c *Client) startCLIServer(ctx context.Context) error {
c.process.Env = append(c.process.Env, "COPILOT_SDK_AUTH_TOKEN="+c.options.GitHubToken)
}

// Clear previous stderr buffer
c.stderrBufMux.Lock()
c.stderrBuf = nil
c.stderrBufMux.Unlock()

if c.useStdio {
// For stdio mode, we need stdin/stdout pipes
stdin, err := c.process.StdinPipe()
Expand All @@ -1117,11 +1170,17 @@ func (c *Client) startCLIServer(ctx context.Context) error {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}

stderr, err := c.process.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %w", err)
}

if err := c.process.Start(); err != nil {
return fmt.Errorf("failed to start CLI server: %w", err)
closeErr := stderr.Close()
return errors.Join(fmt.Errorf("failed to start CLI server: %w", err), closeErr)
}

c.monitorProcess()
c.monitorProcess(stderr)

// Create JSON-RPC client immediately
c.client = jsonrpc2.NewClient(stdin, stdout)
Expand All @@ -1138,11 +1197,17 @@ func (c *Client) startCLIServer(ctx context.Context) error {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}

stderr, err := c.process.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %w", err)
}

if err := c.process.Start(); err != nil {
return fmt.Errorf("failed to start CLI server: %w", err)
closeErr := stderr.Close()
return errors.Join(fmt.Errorf("failed to start CLI server: %w", err), closeErr)
}

c.monitorProcess()
c.monitorProcess(stderr)

scanner := bufio.NewScanner(stdout)
timeout := time.After(10 * time.Second)
Expand All @@ -1152,18 +1217,21 @@ func (c *Client) startCLIServer(ctx context.Context) error {
select {
case <-timeout:
killErr := c.killProcess()
return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr)
stderrErr := c.stderrError()
return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr, stderrErr)
case <-c.processDone:
killErr := c.killProcess()
return errors.Join(errors.New("CLI server process exited before reporting port"), killErr)
stderrErr := c.stderrError()
return errors.Join(errors.New("CLI server process exited before reporting port"), killErr, stderrErr)
default:
if scanner.Scan() {
line := scanner.Text()
if matches := portRegex.FindStringSubmatch(line); len(matches) > 1 {
port, err := strconv.Atoi(matches[1])
if err != nil {
killErr := c.killProcess()
return errors.Join(fmt.Errorf("failed to parse port: %w", err), killErr)
stderrErr := c.stderrError()
return errors.Join(fmt.Errorf("failed to parse port: %w", err), killErr, stderrErr)
}
c.actualPort = port
return nil
Expand All @@ -1175,21 +1243,28 @@ func (c *Client) startCLIServer(ctx context.Context) error {
}

func (c *Client) killProcess() error {
var err error
if p := c.osProcess.Swap(nil); p != nil {
if err := p.Kill(); err != nil {
return fmt.Errorf("failed to kill CLI process: %w", err)
if err = p.Kill(); err != nil {
err = fmt.Errorf("failed to kill CLI process: %w", err)
}
}
if c.stderrDone != nil {
<-c.stderrDone
}
c.process = nil
return nil
return err
}

// monitorProcess signals when the CLI process exits and captures any exit error.
// processError is intentionally a local: each process lifecycle gets its own
// error value, so goroutines from previous processes can't overwrite the
// current one. Closing the channel synchronizes with readers, guaranteeing
// they see the final processError value.
func (c *Client) monitorProcess() {
func (c *Client) monitorProcess(stderr io.ReadCloser) {
stderrDone := make(chan struct{})
c.stderrDone = stderrDone
c.startStderrDrain(stderr, stderrDone)
done := make(chan struct{})
c.processDone = done
proc := c.process
Expand All @@ -1198,6 +1273,7 @@ func (c *Client) monitorProcess() {
c.processErrorPtr = &processError
go func() {
waitErr := proc.Wait()
<-stderrDone
if waitErr != nil {
processError = fmt.Errorf("CLI process exited: %w", waitErr)
} else {
Expand Down
95 changes: 95 additions & 0 deletions go/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package copilot
import (
"encoding/json"
"os"
"os/exec"
"path/filepath"
"reflect"
"regexp"
"strings"
"sync"
"testing"
)
Expand Down Expand Up @@ -528,3 +530,96 @@ func TestClient_StartStopRace(t *testing.T) {
t.Fatal(err)
}
}

func TestClient_StderrCapture(t *testing.T) {
buildStderrFixtureCLI := func(t *testing.T) string {
t.Helper()

tmpDir := t.TempDir()
mainPath := filepath.Join(tmpDir, "main.go")
binaryName := "stderr-fixture.exe"
binaryPath := filepath.Join(tmpDir, binaryName)

source := `package main

import (
"os"
"strconv"
"strings"
)

func main() {
if sizeString := os.Getenv("TEST_STDERR_SIZE"); sizeString != "" {
if size, _ := strconv.Atoi(sizeString); size > 0 {
_, _ = os.Stderr.WriteString(strings.Repeat("x", size))
}
}
if message := os.Getenv("TEST_STDERR"); message != "" {
_, _ = os.Stderr.WriteString(message)
}
os.Exit(1)
}`

if err := os.WriteFile(mainPath, []byte(source), 0600); err != nil {
t.Fatal(err)
}

cmd := exec.Command("go", "build", "-o", binaryPath, mainPath)
if output, err := cmd.CombinedOutput(); err != nil {
t.Fatal(err, "\n", string(output))
}

return binaryPath
}

fixturePath := buildStderrFixtureCLI(t)

t.Run("captures stderr on startup failure", func(t *testing.T) {
client := NewClient(&ClientOptions{
CLIPath: fixturePath,
Env: append(os.Environ(), "TEST_STDERR=something went wrong"),
})

err := client.Start(t.Context())
if err == nil {
t.Fatal("Expected error, got nil")
}

if !strings.Contains(err.Error(), "something went wrong") {
t.Errorf("Expected error to contain stderr output, got: %v", err)
}
})

t.Run("caps stderr buffer", func(t *testing.T) {
client := NewClient(&ClientOptions{
CLIPath: fixturePath,
Env: append(os.Environ(), "TEST_STDERR_SIZE=70000"),
})

if err := client.Start(t.Context()); err == nil {
t.Fatal("Expected error, got nil")
}

output := client.getStderrOutput()
if len(output) > 64*1024+100 { // Allow some slack but it should be close to 64KB
t.Errorf("Expected buffer to be capped around 64KB, got %d bytes", len(output))
}
})

t.Run("clears buffer on stop", func(t *testing.T) {
client := NewClient(&ClientOptions{})
client.stderrBufMux.Lock()
client.stderrBuf = []byte("dirty buffer")
client.stderrBufMux.Unlock()

if err := client.Stop(); err != nil {
t.Fatal(err)
}

client.stderrBufMux.Lock()
if client.stderrBuf != nil {
t.Error("Expected stderrBuf to be nil after Stop")
}
client.stderrBufMux.Unlock()
})
}
24 changes: 9 additions & 15 deletions go/internal/jsonrpc2/jsonrpc2.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ type Client struct {
stopChan chan struct{}
wg sync.WaitGroup
processDone chan struct{} // closed when the underlying process exits
processError error // set before processDone is closed
processErrorMu sync.RWMutex // protects processError
processErrorPtr *error // points to error set before processDone is closed
}

// NewClient creates a new JSON-RPC client
Expand All @@ -78,22 +77,17 @@ func NewClient(stdin io.WriteCloser, stdout io.ReadCloser) *Client {
// and stores the error that should be returned to pending/future requests.
func (c *Client) SetProcessDone(done chan struct{}, errPtr *error) {
c.processDone = done
// Monitor the channel and copy the error when it closes
go func() {
<-done
if errPtr != nil {
c.processErrorMu.Lock()
c.processError = *errPtr
c.processErrorMu.Unlock()
}
}()
c.processErrorPtr = errPtr
}

// getProcessError returns the process exit error if the process has exited
// getProcessError returns the process exit error if the process has exited.
// Must only be called after <-c.processDone to ensure visibility of the error
// written before close(done) in the monitor goroutine.
func (c *Client) getProcessError() error {
c.processErrorMu.RLock()
defer c.processErrorMu.RUnlock()
return c.processError
if c.processErrorPtr != nil {
return *c.processErrorPtr
}
return nil
}

// Start begins listening for messages in a background goroutine
Expand Down