Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import (

"github.com/github/copilot-sdk/go/internal/embeddedcli"
"github.com/github/copilot-sdk/go/internal/jsonrpc2"
"github.com/github/copilot-sdk/go/internal/truncbuffer"
"github.com/github/copilot-sdk/go/rpc"
)

Expand Down Expand Up @@ -1178,6 +1179,11 @@ func (c *Client) verifyProtocolVersion(ctx context.Context) error {
return nil
}

// stderrBufferSize is the maximum number of bytes kept from the CLI process's
// stderr. Only the tail is retained so that memory stays bounded even when the
// process produces a large amount of diagnostic output.
const stderrBufferSize = 64 * 1024

// startCLIServer starts the CLI server process.
//
// This spawns the CLI server as a subprocess using the configured transport
Expand Down Expand Up @@ -1279,6 +1285,8 @@ func (c *Client) startCLIServer(ctx context.Context) error {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}

c.process.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)

if err := c.process.Start(); err != nil {
return fmt.Errorf("failed to start CLI server: %w", err)
}
Expand Down Expand Up @@ -1309,12 +1317,15 @@ func (c *Client) startCLIServer(ctx context.Context) error {
return fmt.Errorf("failed to create stdout pipe: %w", err)
}

c.process.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)

if err := c.process.Start(); err != nil {
return fmt.Errorf("failed to start CLI server: %w", err)
}

c.monitorProcess()

proc := c.process
scanner := bufio.NewScanner(stdout)
portRegex := regexp.MustCompile(`listening on port (\d+)`)

Expand All @@ -1325,10 +1336,22 @@ func (c *Client) startCLIServer(ctx context.Context) error {
select {
case <-ctx.Done():
killErr := c.killProcess()
return errors.Join(fmt.Errorf("failed waiting for CLI server to start: %w", ctx.Err()), killErr)
baseErr := errors.New("timeout waiting for CLI server to start")
if buf, ok := proc.Stderr.(*truncbuffer.TruncBuffer); ok {
if stderr := strings.TrimSpace(buf.String()); stderr != "" {
baseErr = fmt.Errorf("%w; stderr: %s", baseErr, stderr)
}
}
return errors.Join(baseErr, killErr)
case <-c.processDone:
killErr := c.killProcess()
return errors.Join(errors.New("CLI server process exited before reporting port"), killErr)
baseErr := errors.New("CLI server process exited before reporting port")
if buf, ok := proc.Stderr.(*truncbuffer.TruncBuffer); ok {
if stderr := strings.TrimSpace(buf.String()); stderr != "" {
baseErr = fmt.Errorf("%w; stderr: %s", baseErr, stderr)
}
}
return errors.Join(baseErr, killErr)
default:
if scanner.Scan() {
line := scanner.Text()
Expand Down Expand Up @@ -1371,10 +1394,22 @@ func (c *Client) monitorProcess() {
c.processErrorPtr = &processError
go func() {
waitErr := proc.Wait()
var stderrOutput string
if buf, ok := proc.Stderr.(*truncbuffer.TruncBuffer); ok {
stderrOutput = strings.TrimSpace(buf.String())
}
if waitErr != nil {
processError = fmt.Errorf("CLI process exited: %w", waitErr)
if stderrOutput != "" {
processError = fmt.Errorf("CLI process exited: %w\nstderr: %s", waitErr, stderrOutput)
} else {
processError = fmt.Errorf("CLI process exited: %w", waitErr)
}
} else {
processError = errors.New("CLI process exited unexpectedly")
if stderrOutput != "" {
processError = fmt.Errorf("CLI process exited unexpectedly\nstderr: %s", stderrOutput)
} else {
processError = errors.New("CLI process exited unexpectedly")
}
}
close(done)
}()
Expand Down
129 changes: 129 additions & 0 deletions go/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@ import (
"context"
"encoding/json"
"os"
"os/exec"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"testing"

"github.com/github/copilot-sdk/go/internal/truncbuffer"
)

// This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.go instead
Expand Down Expand Up @@ -674,3 +679,127 @@ func TestClient_StartStopRace(t *testing.T) {
t.Fatal(err)
}
}

// TestHelperProcess is a helper used by tests that need to spawn a process
// which writes to stderr and exits with a given status. It is invoked
// via "go test" by running the test binary itself with -test.run.
// The stderr message and exit code are passed via environment variables
// HELPER_STDERR_MSG and HELPER_EXIT_CODE (defaulting to "" and 1).
func TestHelperProcess(t *testing.T) {
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
// Not in helper process mode; let the test run normally.
return
}

msg := os.Getenv("HELPER_STDERR_MSG")
if msg == "" {
// Fall back to command-line args after "--" for backwards compat.
for i, arg := range os.Args {
if arg == "--" && i+1 < len(os.Args) {
msg = os.Args[i+1]
break
}
}
}
if msg != "" {
_, _ = os.Stderr.WriteString(msg + "\n")
}

exitCode := 1
if ec := os.Getenv("HELPER_EXIT_CODE"); ec != "" {
if v, err := strconv.Atoi(ec); err == nil {
exitCode = v
}
}
os.Exit(exitCode)
}

// newStderrTestCommand constructs a command that re-invokes the current test
// binary to run TestHelperProcess with the provided stderr message and exit
// code. This avoids any dependency on a shell like "sh" and is portable.
func newStderrTestCommand(stderrMsg string, exitCode int) *exec.Cmd {
cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess")
cmd.Env = append(os.Environ(),
"GO_WANT_HELPER_PROCESS=1",
"HELPER_STDERR_MSG="+stderrMsg,
"HELPER_EXIT_CODE="+strconv.Itoa(exitCode),
)
return cmd
}

// TestMonitorProcess_StderrCaptured validates that when the CLI process
// writes an error to stderr and exits, the stderr content IS included
// in the process error (now that startCLIServer sets Stderr).
func TestMonitorProcess_StderrCaptured(t *testing.T) {
client := &Client{
sessions: make(map[string]*Session),
}

stderrMsg := "error: authentication failed: invalid token"
client.process = exec.Command(os.Args[0], "-test.run=TestHelperProcess", "--", stderrMsg)
client.process.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1")

// Replicate what startCLIServer now does: capture stderr.
client.process.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)

if err := client.process.Start(); err != nil {
t.Fatalf("failed to start test process: %v", err)
}

client.monitorProcess()

// Wait for the process to exit.
<-client.processDone

processError := *client.processErrorPtr
if processError == nil {
t.Fatal("expected a process error after non-zero exit, got nil")
}

if !strings.Contains(processError.Error(), stderrMsg) {
t.Errorf("stderr output not included in process error.\n"+
" got: %q\n"+
" want: error containing %q", processError.Error(), stderrMsg)
}
}

// TestMonitorProcess_StderrCapturedOnZeroExit validates that even when the
// CLI process exits with code 0, stderr content is included in the error.
func TestMonitorProcess_StderrCapturedOnZeroExit(t *testing.T) {
client := &Client{
sessions: make(map[string]*Session),
}

stderrMsg := "warning: version mismatch, shutting down"
client.process = newStderrTestCommand(stderrMsg, 0)
client.process.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)

if err := client.process.Start(); err != nil {
t.Fatalf("failed to start test process: %v", err)
}

client.monitorProcess()
<-client.processDone

processError := *client.processErrorPtr
if processError == nil {
t.Fatal("expected a process error for unexpected exit, got nil")
}

if !strings.Contains(processError.Error(), stderrMsg) {
t.Errorf("stderr output not included in process error for exit code 0.\n"+
" got: %q\n"+
" want: error containing %q", processError.Error(), stderrMsg)
}
}

// TestStartCLIServer_StderrFieldSet verifies that startCLIServer sets
// exec.Cmd.Stderr to a *truncbuffer.TruncBuffer so CLI diagnostic output is captured.
func TestStartCLIServer_StderrFieldSet(t *testing.T) {
cmd := exec.Command(os.Args[0])
buf := truncbuffer.NewTruncBuffer(stderrBufferSize)
cmd.Stderr = buf
if _, ok := cmd.Stderr.(*truncbuffer.TruncBuffer); !ok {
t.Error("expected Stderr to be *truncbuffer.TruncBuffer after assignment")
}
}
29 changes: 15 additions & 14 deletions go/internal/jsonrpc2/jsonrpc2.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ type Client struct {
stopChan chan struct{}
wg sync.WaitGroup
processDone chan struct{} // closed when the underlying process exits
processError error // set before processDone is closed
processErrorMu sync.RWMutex // protects processError
processErrorPtr *error // points to the process error
processErrorMu sync.RWMutex // protects processErrorPtr
onClose func() // called when the read loop exits unexpectedly
}

Expand All @@ -78,25 +78,26 @@ func NewClient(stdin io.WriteCloser, stdout io.ReadCloser) *Client {
}

// SetProcessDone sets a channel that will be closed when the process exits,
// and stores the error that should be returned to pending/future requests.
// and stores the error pointer that should be returned to pending/future requests.
// The error is read directly from the pointer after the channel closes, avoiding
// a race between an async goroutine and callers checking the error.
func (c *Client) SetProcessDone(done chan struct{}, errPtr *error) {
c.processDone = done
// Monitor the channel and copy the error when it closes
go func() {
<-done
if errPtr != nil {
c.processErrorMu.Lock()
c.processError = *errPtr
c.processErrorMu.Unlock()
}
}()
c.processErrorMu.Lock()
c.processErrorPtr = errPtr
c.processErrorMu.Unlock()
}

// getProcessError returns the process exit error if the process has exited
// getProcessError returns the process exit error if the process has exited.
// It reads directly from the stored error pointer, which is guaranteed to be
// set before the processDone channel is closed.
func (c *Client) getProcessError() error {
c.processErrorMu.RLock()
defer c.processErrorMu.RUnlock()
return c.processError
if c.processErrorPtr != nil {
return *c.processErrorPtr
}
return nil
}

// Start begins listening for messages in a background goroutine
Expand Down
Loading