diff --git a/go/client.go b/go/client.go
index 22be47ec6..e8665c4b2 100644
--- a/go/client.go
+++ b/go/client.go
@@ -48,6 +48,7 @@ import (
 
 	"github.com/github/copilot-sdk/go/internal/embeddedcli"
 	"github.com/github/copilot-sdk/go/internal/jsonrpc2"
+	"github.com/github/copilot-sdk/go/internal/truncbuffer"
 	"github.com/github/copilot-sdk/go/rpc"
 )
 
@@ -1178,6 +1179,11 @@ func (c *Client) verifyProtocolVersion(ctx context.Context) error {
 	return nil
 }
 
+// stderrBufferSize is the maximum number of bytes kept from the CLI process's
+// stderr. Only the tail is retained so that memory stays bounded even when the
+// process produces a large amount of diagnostic output.
+const stderrBufferSize = 64 * 1024
+
 // startCLIServer starts the CLI server process.
 //
 // This spawns the CLI server as a subprocess using the configured transport
@@ -1279,6 +1285,8 @@ func (c *Client) startCLIServer(ctx context.Context) error {
 			return fmt.Errorf("failed to create stdout pipe: %w", err)
 		}
 
+		c.process.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)
+
 		if err := c.process.Start(); err != nil {
 			return fmt.Errorf("failed to start CLI server: %w", err)
 		}
@@ -1309,12 +1317,15 @@ func (c *Client) startCLIServer(ctx context.Context) error {
 			return fmt.Errorf("failed to create stdout pipe: %w", err)
 		}
 
+		c.process.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)
+
 		if err := c.process.Start(); err != nil {
 			return fmt.Errorf("failed to start CLI server: %w", err)
 		}
 
 		c.monitorProcess()
 
+		proc := c.process
 		scanner := bufio.NewScanner(stdout)
 		portRegex := regexp.MustCompile(`listening on port (\d+)`)
 
@@ -1325,10 +1336,25 @@ func (c *Client) startCLIServer(ctx context.Context) error {
 			select {
 			case <-ctx.Done():
 				killErr := c.killProcess()
-				return errors.Join(fmt.Errorf("failed waiting for CLI server to start: %w", ctx.Err()), killErr)
+				// Keep ctx.Err() wrapped so callers can still use errors.Is with
+				// context.Canceled / context.DeadlineExceeded; ctx may also be cancelled.
+				baseErr := fmt.Errorf("failed waiting for CLI server to start: %w", ctx.Err())
+				// NOTE(review): the child may still be writing stderr here; confirm TruncBuffer allows a concurrent String.
+				if buf, ok := proc.Stderr.(*truncbuffer.TruncBuffer); ok {
+					if stderr := strings.TrimSpace(buf.String()); stderr != "" {
+						baseErr = fmt.Errorf("%w; stderr: %s", baseErr, stderr)
+					}
+				}
+				return errors.Join(baseErr, killErr)
 			case <-c.processDone:
 				killErr := c.killProcess()
-				return errors.Join(errors.New("CLI server process exited before reporting port"), killErr)
+				baseErr := errors.New("CLI server process exited before reporting port")
+				if buf, ok := proc.Stderr.(*truncbuffer.TruncBuffer); ok {
+					if stderr := strings.TrimSpace(buf.String()); stderr != "" {
+						baseErr = fmt.Errorf("%w; stderr: %s", baseErr, stderr)
+					}
+				}
+				return errors.Join(baseErr, killErr)
 			default:
 				if scanner.Scan() {
 					line := scanner.Text()
@@ -1371,10 +1397,19 @@ func (c *Client) monitorProcess() {
 	c.processErrorPtr = &processError
 	go func() {
 		waitErr := proc.Wait()
+		// Read stderr only after Wait, which also waits for exec's stderr copy to finish.
+		var stderrOutput string
+		if buf, ok := proc.Stderr.(*truncbuffer.TruncBuffer); ok {
+			stderrOutput = strings.TrimSpace(buf.String())
+		}
 		if waitErr != nil {
 			processError = fmt.Errorf("CLI process exited: %w", waitErr)
 		} else {
 			processError = errors.New("CLI process exited unexpectedly")
 		}
+		// Append the captured tail once instead of duplicating both constructors.
+		if stderrOutput != "" {
+			processError = fmt.Errorf("%w\nstderr: %s", processError, stderrOutput)
+		}
 		close(done)
 	}()
diff --git a/go/client_test.go b/go/client_test.go
index d7a526cab..cd804f308 100644
--- a/go/client_test.go
+++ b/go/client_test.go
@@ -4,11 +4,16 @@ import (
 	"context"
 	"encoding/json"
 	"os"
+	"os/exec"
 	"path/filepath"
 	"reflect"
 	"regexp"
+	"strconv"
+	"strings"
 	"sync"
 	"testing"
+
+	"github.com/github/copilot-sdk/go/internal/truncbuffer"
 )
 
 // This file is for unit tests. 
Where relevant, prefer to add e2e tests in e2e/*.test.go instead
@@ -674,3 +679,127 @@ func TestClient_StartStopRace(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+// TestHelperProcess is the entry point for child processes spawned by other
+// tests: the test binary re-runs itself with -test.run=TestHelperProcess.
+// When GO_WANT_HELPER_PROCESS is "1" it writes HELPER_STDERR_MSG (or, if that
+// is empty, the argument following "--") to stderr and exits with
+// HELPER_EXIT_CODE (1 when unset or unparsable). Otherwise it does nothing.
+func TestHelperProcess(t *testing.T) {
+	if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
+		// Ordinary test run: behave as a no-op test.
+		return
+	}
+
+	stderrText := os.Getenv("HELPER_STDERR_MSG")
+	if stderrText == "" {
+		// Older callers pass the message as the argument right after "--".
+		for idx := 0; idx+1 < len(os.Args); idx++ {
+			if os.Args[idx] == "--" {
+				stderrText = os.Args[idx+1]
+				break
+			}
+		}
+	}
+	if stderrText != "" {
+		_, _ = os.Stderr.WriteString(stderrText + "\n")
+	}
+
+	status := 1
+	if raw := os.Getenv("HELPER_EXIT_CODE"); raw != "" {
+		if parsed, err := strconv.Atoi(raw); err == nil {
+			status = parsed
+		}
+	}
+	os.Exit(status)
+}
+
+// newStderrTestCommand returns a command that re-runs the current test binary
+// as TestHelperProcess, passing stderrMsg and exitCode through the environment.
+// No shell is involved, so it does not depend on "sh" being available.
+func newStderrTestCommand(stderrMsg string, exitCode int) *exec.Cmd {
+	env := os.Environ()
+	env = append(env, "GO_WANT_HELPER_PROCESS=1")
+	env = append(env, "HELPER_STDERR_MSG="+stderrMsg)
+	env = append(env, "HELPER_EXIT_CODE="+strconv.Itoa(exitCode))
+	cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess")
+	cmd.Env = env
+	return cmd
+}
+
+// TestMonitorProcess_StderrCaptured validates that when the CLI process
+// writes an error to stderr and exits, the stderr content IS included
+// in the process error (now that startCLIServer sets Stderr).
+func TestMonitorProcess_StderrCaptured(t *testing.T) {
+	client := &Client{
+		sessions: make(map[string]*Session),
+	}
+
+	stderrMsg := "error: authentication failed: invalid token"
+	// Use the shared helper so both stderr tests spawn the child the same way (exit code 1).
+	client.process = newStderrTestCommand(stderrMsg, 1)
+
+	// Replicate what startCLIServer now does: capture stderr.
+	client.process.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)
+
+	if err := client.process.Start(); err != nil {
+		t.Fatalf("failed to start test process: %v", err)
+	}
+
+	client.monitorProcess()
+
+	// Wait for the process to exit.
+	<-client.processDone
+
+	processError := *client.processErrorPtr
+	if processError == nil {
+		t.Fatal("expected a process error after non-zero exit, got nil")
+	}
+
+	if !strings.Contains(processError.Error(), stderrMsg) {
+		t.Errorf("stderr output not included in process error.\n"+
+			" got: %q\n"+
+			" want: error containing %q", processError.Error(), stderrMsg)
+	}
+}
+
+// TestMonitorProcess_StderrCapturedOnZeroExit validates that even when the
+// CLI process exits with code 0, stderr content is included in the error. 
+func TestMonitorProcess_StderrCapturedOnZeroExit(t *testing.T) {
+	const wantStderr = "warning: version mismatch, shutting down"
+	// The child writes wantStderr and exits 0; its stderr goes to a TruncBuffer.
+	child := newStderrTestCommand(wantStderr, 0)
+	child.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)
+	c := &Client{
+		sessions: make(map[string]*Session),
+		process:  child,
+	}
+
+	if startErr := c.process.Start(); startErr != nil {
+		t.Fatalf("failed to start test process: %v", startErr)
+	}
+
+	c.monitorProcess()
+	<-c.processDone
+
+	got := *c.processErrorPtr
+	switch {
+	case got == nil:
+		t.Fatal("expected a process error for unexpected exit, got nil")
+	case !strings.Contains(got.Error(), wantStderr):
+		t.Errorf("stderr output not included in process error for exit code 0.\n"+
+			" got: %q\n"+
+			" want: error containing %q", got.Error(), wantStderr)
+	}
+}
+
+// TestStartCLIServer_StderrFieldSet verifies that startCLIServer sets
+// exec.Cmd.Stderr to a *truncbuffer.TruncBuffer so CLI diagnostic output is captured. 
+func TestStartCLIServer_StderrFieldSet(t *testing.T) {
+	cmd := exec.Command(os.Args[0])
+	cmd.Stderr = truncbuffer.NewTruncBuffer(stderrBufferSize)
+	// NOTE(review): this only checks the assignment made above; it does not call startCLIServer.
+	if _, ok := cmd.Stderr.(*truncbuffer.TruncBuffer); !ok {
+		t.Error("expected Stderr to be *truncbuffer.TruncBuffer after assignment")
+	}
+}
diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go
index fbc5b931c..a9c8b07d4 100644
--- a/go/internal/jsonrpc2/jsonrpc2.go
+++ b/go/internal/jsonrpc2/jsonrpc2.go
@@ -61,8 +61,8 @@ type Client struct {
 	stopChan       chan struct{}
 	wg             sync.WaitGroup
 	processDone    chan struct{} // closed when the underlying process exits
-	processError   error         // set before processDone is closed
-	processErrorMu sync.RWMutex  // protects processError
+	processErrorPtr *error       // points to the process error
+	processErrorMu  sync.RWMutex // protects processErrorPtr
 	onClose        func()        // called when the read loop exits unexpectedly
 }
 
@@ -78,25 +78,33 @@ func NewClient(stdin io.WriteCloser, stdout io.ReadCloser) *Client {
 }
 
 // SetProcessDone sets a channel that will be closed when the process exits,
-// and stores the error that should be returned to pending/future requests.
+// and stores the error pointer that should be returned to pending/future requests.
+// The error is read directly from the pointer after the channel closes, avoiding
+// a race between an async goroutine and callers checking the error.
 func (c *Client) SetProcessDone(done chan struct{}, errPtr *error) {
 	c.processDone = done
-	// Monitor the channel and copy the error when it closes
-	go func() {
-		<-done
-		if errPtr != nil {
-			c.processErrorMu.Lock()
-			c.processError = *errPtr
-			c.processErrorMu.Unlock()
-		}
-	}()
+	c.processErrorMu.Lock()
+	c.processErrorPtr = errPtr
+	c.processErrorMu.Unlock()
 }
 
-// getProcessError returns the process exit error if the process has exited
+// getProcessError returns the process exit error if the process has exited.
+// The stored pointer is dereferenced only after processDone is closed: the
+// owner must write *errPtr before closing the channel, so the close orders
+// that write before this read. Before then (or with no channel) it returns nil.
 func (c *Client) getProcessError() error {
 	c.processErrorMu.RLock()
 	defer c.processErrorMu.RUnlock()
-	return c.processError
+	if c.processErrorPtr == nil {
+		return nil
+	}
+	select {
+	case <-c.processDone:
+		return *c.processErrorPtr
+	default:
+		// Still running, or processDone is nil: the pointee may be mid-write.
+		return nil
+	}
 }
 
 // Start begins listening for messages in a background goroutine
diff --git a/go/internal/jsonrpc2/jsonrpc2_test.go b/go/internal/jsonrpc2/jsonrpc2_test.go
index 9f542049d..26aa5a472 100644
--- a/go/internal/jsonrpc2/jsonrpc2_test.go
+++ b/go/internal/jsonrpc2/jsonrpc2_test.go
@@ -1,6 +1,7 @@
 package jsonrpc2
 
 import (
+	"errors"
 	"io"
 	"sync"
 	"testing"
@@ -67,3 +68,120 @@ func TestOnCloseNotCalledOnIntentionalStop(t *testing.T) {
 		t.Error("onClose should not be called on intentional Stop()")
 	}
 }
+
+// TestSetProcessDone_ErrorAvailableImmediately validates that getProcessError()
+// returns the correct error immediately after processDone is closed.
+// SetProcessDone stores the pointer to the process error synchronously, and
+// the error itself is written before the channel is closed, so callers should
+// never observe a nil error after the channel has been closed.
+func TestSetProcessDone_ErrorAvailableImmediately(t *testing.T) {
+	misses := 0
+	const iterations = 1000
+
+	for i := 0; i < iterations; i++ {
+		stdinR, stdinW := io.Pipe()
+		stdoutR, stdoutW := io.Pipe()
+
+		client := NewClient(stdinW, stdoutR)
+
+		done := make(chan struct{})
+		processErr := errors.New("CLI process exited: exit status 1")
+
+		client.SetProcessDone(done, &processErr)
+
+		// Simulate process exit: error is already set, close the channel.
+		close(done)
+
+		// Do NOT yield to the scheduler — check immediately.
+		// SetProcessDone stores the pointer synchronously, so no goroutine
+		// has to run before the error is visible here.
+		if err := client.getProcessError(); err == nil {
+			misses++
+		}
+
+		stdinR.Close()
+		stdinW.Close()
+		stdoutR.Close()
+		stdoutW.Close()
+	}
+
+	if misses > 0 {
+		t.Errorf("SetProcessDone regression: getProcessError() returned nil %d/%d times "+
+			"immediately after processDone was closed, even though the error pointer "+
+			"should be stored synchronously.", misses, iterations)
+	}
+}
+
+// TestSetProcessDone_RequestMissesProcessError validates that the Request()
+// method returns the specific process error instead of the generic
+// "process exited unexpectedly" message once processDone has been closed.
+func TestSetProcessDone_RequestMissesProcessError(t *testing.T) {
+	misses := 0
+	const iterations = 100
+
+	for i := 0; i < iterations; i++ {
+		stdinR, stdinW := io.Pipe()
+		stdoutR, stdoutW := io.Pipe()
+
+		client := NewClient(stdinW, stdoutR)
+		client.Start()
+
+		done := make(chan struct{})
+		processErr := errors.New("CLI process exited: authentication failed")
+
+		client.SetProcessDone(done, &processErr)
+
+		// Simulate process exit.
+		close(done)
+		// Close the writer so the readLoop can exit.
+		stdoutW.Close()
+
+		// Make a request — should get the specific process error.
+		_, err := client.Request("test.method", nil)
+		if err != nil && err.Error() == "process exited unexpectedly" {
+			misses++
+		}
+
+		client.Stop()
+		stdinR.Close()
+		stdinW.Close()
+		stdoutR.Close()
+	}
+
+	if misses > 0 {
+		t.Errorf("Request() bug: returned generic 'process exited unexpectedly' %d/%d times "+
+			"instead of the actual process error after process exit; the process "+
+			"error was not correctly propagated from SetProcessDone.", misses, iterations)
+	}
+}
+
+// TestSetProcessDone_ErrorCopiedEventually verifies that the process error is
+// available as soon as the done channel is closed; with the pointer-based
+// implementation no asynchronous copy is required.
+func TestSetProcessDone_ErrorCopiedEventually(t *testing.T) {
+	inR, inW := io.Pipe()
+	outR, outW := io.Pipe()
+	// All four pipe ends are closed when the test returns.
+	defer inR.Close()
+	defer inW.Close()
+	defer outR.Close()
+	defer outW.Close()
+
+	exited := make(chan struct{})
+	wantErr := errors.New("CLI process exited: version mismatch")
+
+	rpcClient := NewClient(inW, outR)
+	rpcClient.SetProcessDone(exited, &wantErr)
+
+	// Simulate process exit; the error must be visible right away, without
+	// handing control to any other goroutine first.
+	close(exited)
+
+	// The error must be present and carry the message stored through the pointer.
+	switch got := rpcClient.getProcessError(); {
+	case got == nil:
+		t.Fatal("expected process error to be available immediately after done is closed, got nil")
+	case got.Error() != wantErr.Error():
+		t.Errorf("expected %q, got %q", wantErr.Error(), got.Error())
+	}
+}