From 67aaf007d17b603827a6b1d6ac94d9ccd855f2ff Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 13 Mar 2026 00:25:31 +0100 Subject: [PATCH 1/3] Add limit iterator and context plumbing for --limit flag --- libs/cmdio/limit.go | 50 +++++++++++++ libs/cmdio/limit_test.go | 146 ++++++++++++++++++++++++++++++++++++++ libs/cmdio/render.go | 3 + libs/cmdio/render_test.go | 30 ++++++++ 4 files changed, 229 insertions(+) create mode 100644 libs/cmdio/limit.go create mode 100644 libs/cmdio/limit_test.go diff --git a/libs/cmdio/limit.go b/libs/cmdio/limit.go new file mode 100644 index 0000000000..d6ef653a45 --- /dev/null +++ b/libs/cmdio/limit.go @@ -0,0 +1,50 @@ +package cmdio + +import ( + "context" + + "github.com/databricks/databricks-sdk-go/listing" +) + +type limitKey struct{} + +// WithLimit stores the limit in the context. +func WithLimit(ctx context.Context, n int) context.Context { + return context.WithValue(ctx, limitKey{}, n) +} + +// GetLimit retrieves the limit from context. Returns 0 if not set. +func GetLimit(ctx context.Context) int { + v, ok := ctx.Value(limitKey{}).(int) + if !ok { + return 0 + } + return v +} + +type limitIterator[T any] struct { + inner listing.Iterator[T] + remaining int +} + +func (l *limitIterator[T]) HasNext(ctx context.Context) bool { + return l.remaining > 0 && l.inner.HasNext(ctx) +} + +func (l *limitIterator[T]) Next(ctx context.Context) (T, error) { + v, err := l.inner.Next(ctx) + if err != nil { + return v, err + } + l.remaining-- + return v, nil +} + +// ApplyLimit wraps a listing.Iterator to yield at most the limit from context. +// It returns the iterator unchanged if the limit is not positive. +func ApplyLimit[T any](ctx context.Context, i listing.Iterator[T]) listing.Iterator[T] { + if limit := GetLimit(ctx); limit > 0 { + return &limitIterator[T]{inner: i, remaining: limit} + } + return i +} diff --git a/libs/cmdio/limit_test.go b/libs/cmdio/limit_test.go new file mode 100644 index 0000000000..33ec7e6401 --- /dev/null +++ b/libs/cmdio/limit_test.go @@ -0,0 +1,146 @@ +package cmdio_test + +import ( + "context" + "errors" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/listing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type sliceIterator[T any] struct { + items []T +} + +func (s *sliceIterator[T]) HasNext(_ context.Context) bool { + return len(s.items) > 0 +} + +func (s *sliceIterator[T]) Next(_ context.Context) (T, error) { + if len(s.items) == 0 { + var zero T + return zero, errors.New("no more items") + } + item := s.items[0] + s.items = s.items[1:] + return item, nil +} + +func drain[T any](ctx context.Context, iter listing.Iterator[T]) ([]T, error) { + var result []T + for iter.HasNext(ctx) { + v, err := iter.Next(ctx) + if err != nil { + return result, err + } + result = append(result, v) + } + return result, nil +} + +type errorIterator[T any] struct { + items []T + failAt int + callCount int +} + +func (e *errorIterator[T]) HasNext(_ context.Context) bool { + return e.callCount <= e.failAt && e.callCount < len(e.items) +} + +func (e *errorIterator[T]) Next(_ context.Context) (T, error) { + idx := e.callCount + e.callCount++ + if idx == e.failAt { + var zero T + return zero, errors.New("fetch error") + } + return e.items[idx], nil +} + +func TestWithLimitRoundTrip(t *testing.T) { + ctx := cmdio.WithLimit(t.Context(), 42) + assert.Equal(t, 42, cmdio.GetLimit(ctx)) +} + +func TestGetLimitReturnsZeroWhenNotSet(t *testing.T) { + assert.Equal(t, 0, cmdio.GetLimit(t.Context())) +} + +func TestApplyLimit(t *testing.T) { + tests := []struct { + name string + limit int + setLimit bool + items []int + want []int + }{ + { + name: "caps results", + limit: 5, + setLimit: true, + items: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + want: []int{1, 2, 3, 4, 5}, + }, + { + name: "no-op when unset", + items: []int{1, 2, 3}, + want: []int{1, 2, 3}, + }, + { + name: "greater than total", + limit: 10, + setLimit: true, + items: []int{1, 2, 3}, + want: []int{1, 2, 3}, + }, + { + name: "one", + limit: 1, + setLimit: true, + items: []int{1, 2, 3}, + want: []int{1}, + }, + { + name: "zero", + limit: 0, + setLimit: true, + items: []int{1, 2, 3}, + want: []int{1, 2, 3}, + }, + { + name: "negative", + limit: -1, + setLimit: true, + items: []int{1, 2, 3}, + want: []int{1, 2, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + if tt.setLimit { + ctx = cmdio.WithLimit(ctx, tt.limit) + } + + iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: tt.items}) + + result, err := drain(t.Context(), iter) + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestApplyLimitPreservesErrors(t *testing.T) { + ctx := cmdio.WithLimit(t.Context(), 5) + iter := cmdio.ApplyLimit(ctx, &errorIterator[int]{items: []int{1, 2, 3}, failAt: 2}) + + result, err := drain(t.Context(), iter) + assert.ErrorContains(t, err, "fetch error") + assert.Equal(t, []int{1, 2}, result) +} diff --git a/libs/cmdio/render.go b/libs/cmdio/render.go index c344c3d028..2bd52cc64d 100644 --- a/libs/cmdio/render.go +++ b/libs/cmdio/render.go @@ -264,6 +264,7 @@ func Render(ctx context.Context, v any) error { } func RenderIterator[T any](ctx context.Context, i listing.Iterator[T]) error { + i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template) } @@ -277,11 +278,13 @@ func RenderWithTemplate(ctx context.Context, v any, headerTemplate, template str } func RenderIteratorWithTemplate[T any](ctx context.Context, i listing.Iterator[T], headerTemplate, template string) error { + i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, headerTemplate, template) } func RenderIteratorJson[T any](ctx context.Context, i listing.Iterator[T]) error { + i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template) } diff --git a/libs/cmdio/render_test.go b/libs/cmdio/render_test.go index be41f80c38..67440b6b2d 100644 --- a/libs/cmdio/render_test.go +++ b/libs/cmdio/render_test.go @@ -167,6 +167,36 @@ var testCases = []testCase{ }, } +func TestRenderIteratorWithLimit(t *testing.T) { + output := &bytes.Buffer{} + ctx := t.Context() + cmdIO := NewIO(ctx, flags.OutputText, nil, output, output, + "id\tname", + "{{range .}}{{.WorkspaceId}}\t{{.WorkspaceName}}\n{{end}}") + ctx = InContext(ctx, cmdIO) + ctx = WithLimit(ctx, 3) + + err := RenderIterator(ctx, makeIterator(10)) + assert.NoError(t, err) + assert.Equal(t, "id name\n"+makeBigOutput(3), output.String()) +} + +func TestRenderIteratorWithLimitJSON(t *testing.T) { + output := &bytes.Buffer{} + ctx := t.Context() + cmdIO := NewIO(ctx, flags.OutputJSON, nil, output, output, "", "") + ctx = InContext(ctx, cmdIO) + ctx = WithLimit(ctx, 2) + + err := RenderIterator(ctx, makeIterator(10)) + assert.NoError(t, err) + + var items []provisioning.Workspace + err = json.Unmarshal(output.Bytes(), &items) + assert.NoError(t, err) + assert.Len(t, items, 2) +} + func TestRender(t *testing.T) { for _, c := range testCases { t.Run(c.name, func(t *testing.T) { From 8d90102f94775f6e79d13e23f5aa0f6b35204555 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 13 Mar 2026 15:39:54 +0100 Subject: [PATCH 2/3] Add defensive guard in limitIterator.Next() for exhausted limit Return listing.ErrNoMoreItems when Next() is called with remaining <= 0, so the limit is enforced even if the caller skips HasNext(). --- libs/cmdio/limit.go | 4 ++++ libs/cmdio/limit_test.go | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/libs/cmdio/limit.go b/libs/cmdio/limit.go index d6ef653a45..8bcf84151b 100644 --- a/libs/cmdio/limit.go +++ b/libs/cmdio/limit.go @@ -32,6 +32,10 @@ func (l *limitIterator[T]) HasNext(ctx context.Context) bool { } func (l *limitIterator[T]) Next(ctx context.Context) (T, error) { + if l.remaining <= 0 { + var zero T + return zero, listing.ErrNoMoreItems + } v, err := l.inner.Next(ctx) if err != nil { return v, err diff --git a/libs/cmdio/limit_test.go b/libs/cmdio/limit_test.go index 33ec7e6401..2165a17da2 100644 --- a/libs/cmdio/limit_test.go +++ b/libs/cmdio/limit_test.go @@ -144,3 +144,21 @@ func TestApplyLimitPreservesErrors(t *testing.T) { assert.ErrorContains(t, err, "fetch error") assert.Equal(t, []int{1, 2}, result) } + +func TestLimitIteratorNextWithoutHasNextReturnsError(t *testing.T) { + ctx := cmdio.WithLimit(t.Context(), 2) + iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: []int{1, 2, 3, 4, 5}}) + + // Drain the allowed items. + v1, err := iter.Next(t.Context()) + require.NoError(t, err) + assert.Equal(t, 1, v1) + + v2, err := iter.Next(t.Context()) + require.NoError(t, err) + assert.Equal(t, 2, v2) + + // Calling Next() again without HasNext() must return ErrNoMoreItems. + _, err = iter.Next(t.Context()) + assert.ErrorIs(t, err, listing.ErrNoMoreItems) +} From 2798fb3f50c74dcdee1a4ca80639b8e37d96b892 Mon Sep 17 00:00:00 2001 From: simon Date: Wed, 18 Mar 2026 23:15:13 +0100 Subject: [PATCH 3/3] Refactor --limit to use SDK's NewLimitIterator instead of context-based plumbing Remove the context-based limit mechanism (WithLimit, GetLimit, ApplyLimit and limitIterator) from libs/cmdio. The limit will now be applied at the callsite by wrapping the iterator with listing.NewLimitIterator from the SDK, keeping the render functions clean and decoupled from limit logic. Co-authored-by: Isaac --- go.mod | 2 + go.sum | 4 +- libs/cmdio/limit.go | 54 ------------- libs/cmdio/limit_test.go | 164 -------------------------------------- libs/cmdio/render.go | 3 - libs/cmdio/render_test.go | 30 ------- 6 files changed, 4 insertions(+), 253 deletions(-) delete mode 100644 libs/cmdio/limit.go delete mode 100644 libs/cmdio/limit_test.go diff --git a/go.mod b/go.mod index 97cab5b6f4..d5d77db55b 100644 --- a/go.mod +++ b/go.mod @@ -107,3 +107,5 @@ require ( google.golang.org/grpc v1.78.0 // indirect google.golang.org/protobuf v1.36.11 // indirect ) + +replace github.com/databricks/databricks-sdk-go => github.com/databricks/databricks-sdk-go v0.124.1-0.20260318213920-cda6ea5b4323 diff --git a/go.sum b/go.sum index f72c2e937d..270db59481 100644 --- a/go.sum +++ b/go.sum @@ -75,8 +75,8 @@ github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= -github.com/databricks/databricks-sdk-go v0.120.0 h1:XLEoLeVUB/MFygyklLiB2HtQTeaULnfr1RyGtYcl2gQ= -github.com/databricks/databricks-sdk-go v0.120.0/go.mod h1:hWoHnHbNLjPKiTm5K/7bcIv3J3Pkgo5x9pPzh8K3RVE= +github.com/databricks/databricks-sdk-go v0.124.1-0.20260318213920-cda6ea5b4323 h1:z71kyh8dyFeQIXIj09B0eF28WhrGIpkQTUz0iURmpkI= +github.com/databricks/databricks-sdk-go v0.124.1-0.20260318213920-cda6ea5b4323/go.mod h1:hWoHnHbNLjPKiTm5K/7bcIv3J3Pkgo5x9pPzh8K3RVE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/libs/cmdio/limit.go b/libs/cmdio/limit.go deleted file mode 100644 index 8bcf84151b..0000000000 --- a/libs/cmdio/limit.go +++ /dev/null @@ -1,54 +0,0 @@ -package cmdio - -import ( - "context" - - "github.com/databricks/databricks-sdk-go/listing" -) - -type limitKey struct{} - -// WithLimit stores the limit in the context. -func WithLimit(ctx context.Context, n int) context.Context { - return context.WithValue(ctx, limitKey{}, n) -} - -// GetLimit retrieves the limit from context. Returns 0 if not set. -func GetLimit(ctx context.Context) int { - v, ok := ctx.Value(limitKey{}).(int) - if !ok { - return 0 - } - return v -} - -type limitIterator[T any] struct { - inner listing.Iterator[T] - remaining int -} - -func (l *limitIterator[T]) HasNext(ctx context.Context) bool { - return l.remaining > 0 && l.inner.HasNext(ctx) -} - -func (l *limitIterator[T]) Next(ctx context.Context) (T, error) { - if l.remaining <= 0 { - var zero T - return zero, listing.ErrNoMoreItems - } - v, err := l.inner.Next(ctx) - if err != nil { - return v, err - } - l.remaining-- - return v, nil -} - -// ApplyLimit wraps a listing.Iterator to yield at most the limit from context. -// It returns the iterator unchanged if the limit is not positive. -func ApplyLimit[T any](ctx context.Context, i listing.Iterator[T]) listing.Iterator[T] { - if limit := GetLimit(ctx); limit > 0 { - return &limitIterator[T]{inner: i, remaining: limit} - } - return i -} diff --git a/libs/cmdio/limit_test.go b/libs/cmdio/limit_test.go deleted file mode 100644 index 2165a17da2..0000000000 --- a/libs/cmdio/limit_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package cmdio_test - -import ( - "context" - "errors" - "testing" - - "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/databricks-sdk-go/listing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type sliceIterator[T any] struct { - items []T -} - -func (s *sliceIterator[T]) HasNext(_ context.Context) bool { - return len(s.items) > 0 -} - -func (s *sliceIterator[T]) Next(_ context.Context) (T, error) { - if len(s.items) == 0 { - var zero T - return zero, errors.New("no more items") - } - item := s.items[0] - s.items = s.items[1:] - return item, nil -} - -func drain[T any](ctx context.Context, iter listing.Iterator[T]) ([]T, error) { - var result []T - for iter.HasNext(ctx) { - v, err := iter.Next(ctx) - if err != nil { - return result, err - } - result = append(result, v) - } - return result, nil -} - -type errorIterator[T any] struct { - items []T - failAt int - callCount int -} - -func (e *errorIterator[T]) HasNext(_ context.Context) bool { - return e.callCount <= e.failAt && e.callCount < len(e.items) -} - -func (e *errorIterator[T]) Next(_ context.Context) (T, error) { - idx := e.callCount - e.callCount++ - if idx == e.failAt { - var zero T - return zero, errors.New("fetch error") - } - return e.items[idx], nil -} - -func TestWithLimitRoundTrip(t *testing.T) { - ctx := cmdio.WithLimit(t.Context(), 42) - assert.Equal(t, 42, cmdio.GetLimit(ctx)) -} - -func TestGetLimitReturnsZeroWhenNotSet(t *testing.T) { - assert.Equal(t, 0, cmdio.GetLimit(t.Context())) -} - -func TestApplyLimit(t *testing.T) { - tests := []struct { - name string - limit int - setLimit bool - items []int - want []int - }{ - { - name: "caps results", - limit: 5, - setLimit: true, - items: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - want: []int{1, 2, 3, 4, 5}, - }, - { - name: "no-op when unset", - items: []int{1, 2, 3}, - want: []int{1, 2, 3}, - }, - { - name: "greater than total", - limit: 10, - setLimit: true, - items: []int{1, 2, 3}, - want: []int{1, 2, 3}, - }, - { - name: "one", - limit: 1, - setLimit: true, - items: []int{1, 2, 3}, - want: []int{1}, - }, - { - name: "zero", - limit: 0, - setLimit: true, - items: []int{1, 2, 3}, - want: []int{1, 2, 3}, - }, - { - name: "negative", - limit: -1, - setLimit: true, - items: []int{1, 2, 3}, - want: []int{1, 2, 3}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := t.Context() - if tt.setLimit { - ctx = cmdio.WithLimit(ctx, tt.limit) - } - - iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: tt.items}) - - result, err := drain(t.Context(), iter) - require.NoError(t, err) - assert.Equal(t, tt.want, result) - }) - } -} - -func TestApplyLimitPreservesErrors(t *testing.T) { - ctx := cmdio.WithLimit(t.Context(), 5) - iter := cmdio.ApplyLimit(ctx, &errorIterator[int]{items: []int{1, 2, 3}, failAt: 2}) - - result, err := drain(t.Context(), iter) - assert.ErrorContains(t, err, "fetch error") - assert.Equal(t, []int{1, 2}, result) -} - -func TestLimitIteratorNextWithoutHasNextReturnsError(t *testing.T) { - ctx := cmdio.WithLimit(t.Context(), 2) - iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: []int{1, 2, 3, 4, 5}}) - - // Drain the allowed items. - v1, err := iter.Next(t.Context()) - require.NoError(t, err) - assert.Equal(t, 1, v1) - - v2, err := iter.Next(t.Context()) - require.NoError(t, err) - assert.Equal(t, 2, v2) - - // Calling Next() again without HasNext() must return ErrNoMoreItems. - _, err = iter.Next(t.Context()) - assert.ErrorIs(t, err, listing.ErrNoMoreItems) -} diff --git a/libs/cmdio/render.go b/libs/cmdio/render.go index 2bd52cc64d..c344c3d028 100644 --- a/libs/cmdio/render.go +++ b/libs/cmdio/render.go @@ -264,7 +264,6 @@ func Render(ctx context.Context, v any) error { } func RenderIterator[T any](ctx context.Context, i listing.Iterator[T]) error { - i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template) } @@ -278,13 +277,11 @@ func RenderWithTemplate(ctx context.Context, v any, headerTemplate, template str } func RenderIteratorWithTemplate[T any](ctx context.Context, i listing.Iterator[T], headerTemplate, template string) error { - i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, headerTemplate, template) } func RenderIteratorJson[T any](ctx context.Context, i listing.Iterator[T]) error { - i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template) } diff --git a/libs/cmdio/render_test.go b/libs/cmdio/render_test.go index 67440b6b2d..be41f80c38 100644 --- a/libs/cmdio/render_test.go +++ b/libs/cmdio/render_test.go @@ -167,36 +167,6 @@ var testCases = []testCase{ }, } -func TestRenderIteratorWithLimit(t *testing.T) { - output := &bytes.Buffer{} - ctx := t.Context() - cmdIO := NewIO(ctx, flags.OutputText, nil, output, output, - "id\tname", - "{{range .}}{{.WorkspaceId}}\t{{.WorkspaceName}}\n{{end}}") - ctx = InContext(ctx, cmdIO) - ctx = WithLimit(ctx, 3) - - err := RenderIterator(ctx, makeIterator(10)) - assert.NoError(t, err) - assert.Equal(t, "id name\n"+makeBigOutput(3), output.String()) -} - -func TestRenderIteratorWithLimitJSON(t *testing.T) { - output := &bytes.Buffer{} - ctx := t.Context() - cmdIO := NewIO(ctx, flags.OutputJSON, nil, output, output, "", "") - ctx = InContext(ctx, cmdIO) - ctx = WithLimit(ctx, 2) - - err := RenderIterator(ctx, makeIterator(10)) - assert.NoError(t, err) - - var items []provisioning.Workspace - err = json.Unmarshal(output.Bytes(), &items) - assert.NoError(t, err) - assert.Len(t, items, 2) -} - func TestRender(t *testing.T) { for _, c := range testCases { t.Run(c.name, func(t *testing.T) {