diff --git a/go.mod b/go.mod index e5084c8..11b6473 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.25.6 require ( cdr.dev/slog/v3 v3.0.0-rc1 - github.com/coder/coder/v2 v2.31.0 + github.com/coder/coder/v2 v2.30.1 + github.com/coder/retry v1.5.1 + github.com/coder/websocket v1.8.14 github.com/docker/docker v28.5.2+incompatible github.com/docker/go-connections v0.6.0 github.com/google/uuid v1.6.0 @@ -63,7 +65,6 @@ require ( github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 // indirect github.com/coder/serpent v0.13.0 // indirect github.com/coder/terraform-provider-coder/v2 v2.13.1 // indirect - github.com/coder/websocket v1.8.14 // indirect github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/coreos/go-oidc/v3 v3.17.0 // indirect diff --git a/go.sum b/go.sum index c747d55..519e9ee 100644 --- a/go.sum +++ b/go.sum @@ -118,10 +118,12 @@ github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuh github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= -github.com/coder/coder/v2 v2.31.0 h1:Qx6yLXyaaxZ5k9UOTmUQ9457iz7togEnHtc05xylIks= -github.com/coder/coder/v2 v2.31.0/go.mod h1:2DE8CPEuDj07THeNcbClDlv0bxR9ThA3A4TOHxBHuZY= +github.com/coder/coder/v2 v2.30.1 h1:5dxGKxWx80xb6lNd958y8Y4h3fBbQubDgIooHTTv/RQ= +github.com/coder/coder/v2 v2.30.1/go.mod h1:w40ThqnpVr727SVnu3wwUrK2woxNx1MrV1zVxxABimk= github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 h1:3A0ES21Ke+FxEM8CXx9n47SZOKOpgSE1bbJzlE4qPVs= github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0/go.mod h1:5UuS2Ts+nTToAMeOjNlnHFkPahrtDkmpydBen/3wgZc= +github.com/coder/retry v1.5.1 h1:iWu8YnD8YqHs3XwqrqsjoBTAVqT9ml6z9ViJ2wlMiqc= +github.com/coder/retry v1.5.1/go.mod h1:blHMk9vs6LkoRT9ZHyuZo360cufXEhrxqvEzeMtRGoY= github.com/coder/serpent v0.13.0 h1:6EoWjpEypkb8cS6i0eCF4qoAv9vrEVaX26RW+3FMMvo= github.com/coder/serpent v0.13.0/go.mod h1:7OIvFBYMd+OqarMy5einBl8AtRr8LliopVU7pyrwucY= github.com/coder/terraform-provider-coder/v2 v2.13.1 h1:dtPaJUvueFm+XwBPUMWQCc5Z1QUQBW4B4RNyzX4h4y8= diff --git a/internal/provider/template_resource.go b/internal/provider/template_resource.go index 84de0f1..f053773 100644 --- a/internal/provider/template_resource.go +++ b/internal/provider/template_resource.go @@ -8,11 +8,13 @@ import ( "io" "slices" "strings" + "time" "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionersdk" + "github.com/coder/retry" "github.com/coder/terraform-provider-coderd/internal/codersdkvalidator" "github.com/google/uuid" "github.com/hashicorp/terraform-plugin-framework-validators/listvalidator" @@ -1104,49 +1106,70 @@ func uploadDirectory(ctx context.Context, client *codersdk.Client, logger slog.L func waitForJob(ctx context.Context, client *codersdk.Client, version *codersdk.TemplateVersion) ([]codersdk.ProvisionerJobLog, error) { const maxRetries = 3 - var jobLogs []codersdk.ProvisionerJobLog - for retries := 0; retries < maxRetries; retries++ { - logs, closer, err := client.TemplateVersionLogsAfter(ctx, version.ID, 0) + var allLogs []codersdk.ProvisionerJobLog + var lastLogID int64 + + for attempts, retrier := 0, retry.New(500*time.Millisecond, 5*time.Second); attempts < maxRetries && retrier.Wait(ctx); attempts++ { + logs, done, err := waitForJobOnce(ctx, client, version, lastLogID) + allLogs = append(allLogs, logs...) + if len(logs) > 0 { + lastLogID = logs[len(logs)-1].ID + } if err != nil { - return jobLogs, fmt.Errorf("begin streaming logs: %w", err) + return allLogs, err } - defer func() { - if err := closer.Close(); err != nil { - tflog.Warn(ctx, "error closing template version log stream", map[string]any{ - "error": err, - }) - } - }() - for { - logs, ok := <-logs - if !ok { - break - } - tflog.Info(ctx, logs.Output, map[string]interface{}{ - "job_id": logs.ID, - "job_stage": logs.Stage, - "log_source": logs.Source, - "level": logs.Level, - "created_at": logs.CreatedAt, - }) - if logs.Output != "" { - jobLogs = append(jobLogs, logs) - } + if done { + return allLogs, nil } - latestResp, err := client.TemplateVersion(ctx, version.ID) - if err != nil { - return jobLogs, err + tflog.Warn(ctx, fmt.Sprintf("provisioner job still active, retrying (attempt %d/%d)", attempts+1, maxRetries)) + } + + if err := ctx.Err(); err != nil { + return allLogs, err + } + return allLogs, fmt.Errorf("provisioner job did not complete after %d retries", maxRetries) +} + +func waitForJobOnce(ctx context.Context, client *codersdk.Client, version *codersdk.TemplateVersion, after int64) ([]codersdk.ProvisionerJobLog, bool, error) { + logCh, closer, err := client.TemplateVersionLogsAfter(ctx, version.ID, after) + if err != nil { + return nil, false, fmt.Errorf("begin streaming logs: %w", err) + } + defer func() { + if err := closer.Close(); err != nil { + tflog.Warn(ctx, "error closing template version log stream", map[string]any{ + "error": err, + }) } - if latestResp.Job.Status.Active() { - tflog.Warn(ctx, fmt.Sprintf("provisioner job still active, continuing to wait...: %s", latestResp.Job.Status)) - continue + }() + var jobLogs []codersdk.ProvisionerJobLog + for { + logMsg, ok := <-logCh + if !ok { + break } - if latestResp.Job.Status != codersdk.ProvisionerJobSucceeded { - return jobLogs, fmt.Errorf("provisioner job did not succeed: %s (%s)", latestResp.Job.Status, latestResp.Job.Error) + tflog.Info(ctx, logMsg.Output, map[string]interface{}{ + "job_id": logMsg.ID, + "job_stage": logMsg.Stage, + "log_source": logMsg.Source, + "level": logMsg.Level, + "created_at": logMsg.CreatedAt, + }) + if logMsg.Output != "" { + jobLogs = append(jobLogs, logMsg) } - return jobLogs, nil } - return jobLogs, fmt.Errorf("provisioner job did not complete after %d retries", maxRetries) + latestResp, err := client.TemplateVersion(ctx, version.ID) + if err != nil { + return jobLogs, false, err + } + if latestResp.Job.Status.Active() { + return jobLogs, false, nil + } + if latestResp.Job.Status != codersdk.ProvisionerJobSucceeded { + return jobLogs, false, fmt.Errorf("provisioner job did not succeed: %s (%s)", latestResp.Job.Status, latestResp.Job.Error) + } + return jobLogs, true, nil } type newVersionRequest struct { diff --git a/internal/provider/wait_for_job_test.go b/internal/provider/wait_for_job_test.go new file mode 100644 index 0000000..2dd2dba --- /dev/null +++ b/internal/provider/wait_for_job_test.go @@ -0,0 +1,249 @@ +package provider + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + + "github.com/coder/coder/v2/codersdk" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestWaitForJobOnce_Success(t *testing.T) { + t.Parallel() + versionID := uuid.New() + + handler := http.NewServeMux() + handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.RawQuery, "follow") { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + ctx := r.Context() + _ = wsjson.Write(ctx, conn, codersdk.ProvisionerJobLog{ + ID: 1, + Output: "test log line", + }) + _ = conn.Close(websocket.StatusNormalClosure, "done") + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{ + ID: versionID, + Job: codersdk.ProvisionerJob{ + Status: codersdk.ProvisionerJobSucceeded, + }, + }) + }) + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + srvURL, err := url.Parse(srv.URL) + require.NoError(t, err) + client := codersdk.New(srvURL) + + version := &codersdk.TemplateVersion{ID: versionID} + logs, done, err := waitForJobOnce(context.Background(), client, version, 0) + require.NoError(t, err) + require.True(t, done) + require.Len(t, logs, 1) + require.Equal(t, "test log line", logs[0].Output) +} + +func TestWaitForJobOnce_JobFailed(t *testing.T) { + t.Parallel() + versionID := uuid.New() + + handler := http.NewServeMux() + handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.RawQuery, "follow") { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _ = conn.Close(websocket.StatusNormalClosure, "done") + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{ + ID: versionID, + Job: codersdk.ProvisionerJob{ + Status: codersdk.ProvisionerJobFailed, + Error: "something went wrong", + }, + }) + }) + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + srvURL, err := url.Parse(srv.URL) + require.NoError(t, err) + client := codersdk.New(srvURL) + + version := &codersdk.TemplateVersion{ID: versionID} + _, done, err := waitForJobOnce(context.Background(), client, version, 0) + require.Error(t, err) + require.False(t, done) + require.Contains(t, err.Error(), "provisioner job did not succeed") + require.Contains(t, err.Error(), "something went wrong") +} + +func TestWaitForJobOnce_StillActive(t *testing.T) { + t.Parallel() + versionID := uuid.New() + + handler := http.NewServeMux() + handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.RawQuery, "follow") { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _ = conn.Close(websocket.StatusNormalClosure, "done") + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{ + ID: versionID, + Job: codersdk.ProvisionerJob{ + Status: codersdk.ProvisionerJobRunning, + }, + }) + }) + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + srvURL, err := url.Parse(srv.URL) + require.NoError(t, err) + client := codersdk.New(srvURL) + + version := &codersdk.TemplateVersion{ID: versionID} + _, done, err := waitForJobOnce(context.Background(), client, version, 0) + require.NoError(t, err) + require.False(t, done) +} + +func TestWaitForJob_UsesAfterCursorAcrossRetries(t *testing.T) { + t.Parallel() + versionID := uuid.New() + var versionCallCount atomic.Int32 + var wsCallCount atomic.Int32 + secondAfterCh := make(chan string, 1) + + handler := http.NewServeMux() + handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.RawQuery, "follow") { + call := wsCallCount.Add(1) + if call == 2 { + secondAfterCh <- r.URL.Query().Get("after") + } + + conn, err := websocket.Accept(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + ctx := r.Context() + if call == 1 { + _ = wsjson.Write(ctx, conn, codersdk.ProvisionerJobLog{ID: 1, Output: "log 1"}) + _ = wsjson.Write(ctx, conn, codersdk.ProvisionerJobLog{ID: 2, Output: "log 2"}) + _ = wsjson.Write(ctx, conn, codersdk.ProvisionerJobLog{ID: 3, Output: "log 3"}) + } else { + _ = wsjson.Write(ctx, conn, codersdk.ProvisionerJobLog{ID: 4, Output: "log 4"}) + _ = wsjson.Write(ctx, conn, codersdk.ProvisionerJobLog{ID: 5, Output: "log 5"}) + } + _ = conn.Close(websocket.StatusNormalClosure, "done") + return + } + + count := versionCallCount.Add(1) + status := codersdk.ProvisionerJobRunning + if count >= 2 { + status = codersdk.ProvisionerJobSucceeded + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{ + ID: versionID, + Job: codersdk.ProvisionerJob{Status: status}, + }) + }) + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + srvURL, err := url.Parse(srv.URL) + require.NoError(t, err) + client := codersdk.New(srvURL) + + version := &codersdk.TemplateVersion{ID: versionID} + logs, err := waitForJob(context.Background(), client, version) + require.NoError(t, err) + require.Len(t, logs, 5) + for i, log := range logs { + require.Equal(t, int64(i+1), log.ID) + } + require.Equal(t, int32(2), wsCallCount.Load()) + select { + case got := <-secondAfterCh: + require.Equal(t, "3", got) + default: + t.Fatal("missing second after cursor") + } +} + +func TestWaitForJob_ContextCanceledDuringBackoff(t *testing.T) { + t.Parallel() + versionID := uuid.New() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + var statusCallCount atomic.Int32 + firstStatusSeen := make(chan struct{}, 1) + go func() { + <-firstStatusSeen + cancel() + }() + + handler := http.NewServeMux() + handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.RawQuery, "follow") { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _ = conn.Close(websocket.StatusNormalClosure, "done") + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{ + ID: versionID, + Job: codersdk.ProvisionerJob{Status: codersdk.ProvisionerJobRunning}, + }) + // Cancel after the first status response so waitForJob hits cancellation while waiting to retry. + if statusCallCount.Add(1) == 1 { + firstStatusSeen <- struct{}{} + } + }) + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + srvURL, err := url.Parse(srv.URL) + require.NoError(t, err) + client := codersdk.New(srvURL) + + version := &codersdk.TemplateVersion{ID: versionID} + _, err = waitForJob(ctx, client, version) + require.ErrorIs(t, err, context.Canceled) +}