Skip to content

Commit af9c082

Browse files
committed
chore: move lastUserPrompt and correlatingToolCallID to ResponsesRequestPayload
1 parent 5f14de7 commit af9c082

6 files changed

Lines changed: 297 additions & 328 deletions

File tree

intercept/responses/base.go

Lines changed: 3 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7-
"errors"
87
"fmt"
98
"io"
109
"net/http"
@@ -20,15 +19,13 @@ import (
2019
"github.com/coder/aibridge/intercept"
2120
"github.com/coder/aibridge/intercept/apidump"
2221
"github.com/coder/aibridge/mcp"
23-
"github.com/coder/aibridge/metrics"
2422
"github.com/coder/aibridge/recorder"
2523
"github.com/coder/aibridge/tracing"
2624
"github.com/coder/quartz"
2725
"github.com/google/uuid"
2826
"github.com/openai/openai-go/v3/option"
2927
"github.com/openai/openai-go/v3/responses"
3028
"github.com/openai/openai-go/v3/shared/constant"
31-
"github.com/tidwall/gjson"
3229
"go.opentelemetry.io/otel/attribute"
3330
"go.opentelemetry.io/otel/trace"
3431
)
@@ -48,9 +45,8 @@ type responsesInterceptionBase struct {
4845
recorder recorder.Recorder
4946
mcpProxy mcp.ServerProxier
5047

51-
logger slog.Logger
52-
metrics metrics.Metrics
53-
tracer trace.Tracer
48+
logger slog.Logger
49+
tracer trace.Tracer
5450
}
5551

5652
func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
@@ -96,27 +92,7 @@ func (i *responsesInterceptionBase) Model() string {
9692
}
9793

9894
func (i *responsesInterceptionBase) CorrelatingToolCallID() *string {
99-
items := gjson.GetBytes(i.reqPayload, "input")
100-
if !items.IsArray() {
101-
return nil
102-
}
103-
104-
arr := items.Array()
105-
if len(arr) == 0 {
106-
return nil
107-
}
108-
109-
last := arr[len(arr)-1]
110-
if last.Get(string(constant.ValueOf[constant.Type]())).String() != string(constant.ValueOf[constant.FunctionCallOutput]()) {
111-
return nil
112-
}
113-
114-
callID := last.Get("call_id").String()
115-
if callID == "" {
116-
return nil
117-
}
118-
119-
return &callID
95+
return i.reqPayload.correlatingToolCallID()
12096
}
12197

12298
func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
@@ -178,88 +154,6 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o
178154
return opts
179155
}
180156

181-
// lastUserPrompt returns input text with "user" role from last input item
182-
// or string input value if it is present + bool indicating if input was found or not.
183-
// If no such input was found empty string + false is returned.
184-
func (i *responsesInterceptionBase) lastUserPrompt() (string, bool, error) {
185-
if i == nil {
186-
return "", false, errors.New("cannot get last user prompt: nil struct")
187-
}
188-
if i.reqPayload == nil {
189-
return "", false, errors.New("cannot get last user prompt: nil request struct")
190-
}
191-
192-
// 'input' can be either a string or an array of input items:
193-
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input
194-
inputItems := gjson.GetBytes(i.reqPayload, "input")
195-
if !inputItems.Exists() || inputItems.Type == gjson.Null {
196-
return "", false, nil
197-
}
198-
199-
// String variant: treat the whole input as the user prompt.
200-
if inputItems.Type == gjson.String {
201-
return inputItems.String(), true, nil
202-
}
203-
204-
// Array variant: checking only the last input item
205-
if !inputItems.IsArray() {
206-
return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type)
207-
}
208-
209-
inputItemsArr := inputItems.Array()
210-
if len(inputItemsArr) == 0 {
211-
return "", false, nil
212-
}
213-
214-
lastItem := inputItemsArr[len(inputItemsArr)-1]
215-
if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) {
216-
// Request was likely not initiated by a prompt but is an iteration of agentic loop.
217-
return "", false, nil
218-
}
219-
220-
// Message content can be either a string or an array of typed content items:
221-
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content
222-
content := lastItem.Get(string(constant.ValueOf[constant.Content]()))
223-
if !content.Exists() || content.Type == gjson.Null {
224-
return "", false, nil
225-
}
226-
227-
// String variant: use it directly as the prompt.
228-
if content.Type == gjson.String {
229-
return content.Str, true, nil
230-
}
231-
232-
if !content.IsArray() {
233-
return "", false, fmt.Errorf("unexpected input content type: %s", content.Type)
234-
}
235-
236-
var sb strings.Builder
237-
promptExists := false
238-
for _, c := range content.Array() {
239-
// Ignore non-text content blocks such as images or files.
240-
if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) {
241-
continue
242-
}
243-
244-
text := c.Get(string(constant.ValueOf[constant.Text]()))
245-
if text.Type != gjson.String {
246-
continue
247-
}
248-
249-
if promptExists {
250-
sb.WriteByte('\n')
251-
}
252-
promptExists = true
253-
sb.WriteString(text.Str)
254-
}
255-
256-
if !promptExists {
257-
return "", false, nil
258-
}
259-
260-
return sb.String(), true, nil
261-
}
262-
263157
func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) {
264158
if responseID == "" {
265159
i.logger.Warn(ctx, "got empty response ID, skipping prompt recording")

intercept/responses/base_test.go

Lines changed: 0 additions & 216 deletions
Original file line numberDiff line numberDiff line change
@@ -6,229 +6,13 @@ import (
66
"time"
77

88
"cdr.dev/slog/v3"
9-
"github.com/coder/aibridge/fixtures"
109
"github.com/coder/aibridge/internal/testutil"
1110
"github.com/coder/aibridge/recorder"
12-
"github.com/coder/aibridge/utils"
1311
"github.com/google/uuid"
1412
oairesponses "github.com/openai/openai-go/v3/responses"
15-
"github.com/stretchr/testify/assert"
1613
"github.com/stretchr/testify/require"
1714
)
1815

19-
func TestScanForCorrelatingToolCallID(t *testing.T) {
20-
t.Parallel()
21-
22-
tests := []struct {
23-
name string
24-
payload []byte
25-
wantCall *string
26-
}{
27-
{
28-
name: "no input",
29-
payload: []byte(`{"model":"gpt-4o"}`),
30-
},
31-
{
32-
name: "empty input array",
33-
payload: []byte(`{"model":"gpt-4o","input":[]}`),
34-
},
35-
{
36-
name: "no function_call_output items",
37-
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`),
38-
},
39-
{
40-
name: "single function_call_output",
41-
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`),
42-
wantCall: utils.PtrTo("call_abc"),
43-
},
44-
{
45-
name: "multiple function_call_outputs returns last",
46-
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`),
47-
wantCall: utils.PtrTo("call_second"),
48-
},
49-
{
50-
name: "last input is not a tool result",
51-
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`),
52-
},
53-
{
54-
name: "missing call id",
55-
payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`),
56-
},
57-
}
58-
59-
for _, tc := range tests {
60-
t.Run(tc.name, func(t *testing.T) {
61-
t.Parallel()
62-
63-
rp, err := NewResponsesRequestPayload(tc.payload)
64-
require.NoError(t, err)
65-
base := &responsesInterceptionBase{
66-
reqPayload: rp,
67-
}
68-
69-
callID := base.CorrelatingToolCallID()
70-
assert.Equal(t, tc.wantCall, callID)
71-
})
72-
}
73-
}
74-
75-
func TestLastUserPrompt(t *testing.T) {
76-
t.Parallel()
77-
78-
tests := []struct {
79-
name string
80-
reqPayload []byte
81-
expect string
82-
}{
83-
{
84-
name: "input_empty_string",
85-
reqPayload: []byte(`{"input": ""}`),
86-
expect: "",
87-
},
88-
{
89-
name: "input_array_content_empty_string",
90-
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`),
91-
expect: "",
92-
},
93-
{
94-
name: "input_array_content_array_empty_string",
95-
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`),
96-
},
97-
{
98-
name: "input_array_content_array_multiple_inputs",
99-
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`),
100-
expect: "a\nb",
101-
},
102-
{
103-
name: "simple_string_input",
104-
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
105-
expect: "tell me a joke",
106-
},
107-
{
108-
name: "array_single_input_string",
109-
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool),
110-
expect: "Is 3 + 5 a prime number? Use the add function to calculate the sum.",
111-
},
112-
{
113-
name: "array_multiple_items_content_objects",
114-
reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex),
115-
expect: "hello",
116-
},
117-
}
118-
119-
for _, tc := range tests {
120-
t.Run(tc.name, func(t *testing.T) {
121-
t.Parallel()
122-
123-
rp, err := NewResponsesRequestPayload(tc.reqPayload)
124-
require.NoError(t, err)
125-
base := &responsesInterceptionBase{
126-
reqPayload: rp,
127-
}
128-
129-
prompt, promptFound, err := base.lastUserPrompt()
130-
require.NoError(t, err)
131-
require.Equal(t, tc.expect, prompt)
132-
require.True(t, promptFound)
133-
})
134-
}
135-
}
136-
137-
func TestLastUserPromptNotFound(t *testing.T) {
138-
t.Parallel()
139-
140-
t.Run("nil_struct", func(t *testing.T) {
141-
t.Parallel()
142-
143-
var base *responsesInterceptionBase
144-
prompt, promptFound, err := base.lastUserPrompt()
145-
require.Error(t, err)
146-
require.Empty(t, prompt)
147-
require.False(t, promptFound)
148-
require.Contains(t, "cannot get last user prompt: nil struct", err.Error())
149-
})
150-
151-
t.Run("nil_request", func(t *testing.T) {
152-
t.Parallel()
153-
154-
base := responsesInterceptionBase{}
155-
prompt, promptFound, err := base.lastUserPrompt()
156-
require.Error(t, err)
157-
require.Empty(t, prompt)
158-
require.False(t, promptFound)
159-
require.Contains(t, "cannot get last user prompt: nil request struct", err.Error())
160-
})
161-
162-
// Cases where the user prompt is not found / wrong format.
163-
tests := []struct {
164-
name string
165-
reqPayload []byte
166-
expectErr string
167-
}{
168-
{
169-
name: "non_existing_input",
170-
reqPayload: []byte(`{"model": "gpt-4o"}`),
171-
},
172-
{
173-
name: "input_empty_array",
174-
reqPayload: []byte(`{"model": "gpt-4o", "input": []}`),
175-
},
176-
{
177-
name: "input_integer",
178-
reqPayload: []byte(`{"model": "gpt-4o", "input": 123}`),
179-
expectErr: "unexpected input type",
180-
},
181-
{
182-
name: "no_user_role",
183-
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`),
184-
},
185-
{
186-
name: "user_with_empty_content_array",
187-
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`),
188-
},
189-
{
190-
name: "input_array_integer",
191-
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": 123}]}`),
192-
expectErr: "unexpected input content type",
193-
},
194-
{
195-
name: "user_with_non_input_text_content",
196-
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`),
197-
},
198-
{
199-
name: "user_content_not_last",
200-
reqPayload: []byte(`{"model": "gpt-4o", "input": [ {"role": "user", "content":"input"}, {"role": "assistant", "content": "hello"} ]}`),
201-
},
202-
{
203-
name: "input_array_content_array_integer",
204-
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": 123}] } ] }`),
205-
},
206-
}
207-
208-
for _, tc := range tests {
209-
t.Run(tc.name, func(t *testing.T) {
210-
t.Parallel()
211-
212-
rp, err := NewResponsesRequestPayload(tc.reqPayload)
213-
require.NoError(t, err)
214-
215-
base := &responsesInterceptionBase{
216-
reqPayload: rp,
217-
}
218-
219-
prompt, promptFound, err := base.lastUserPrompt()
220-
if tc.expectErr != "" {
221-
require.Error(t, err)
222-
require.Contains(t, err.Error(), tc.expectErr)
223-
} else {
224-
require.NoError(t, err)
225-
}
226-
require.Empty(t, prompt)
227-
require.False(t, promptFound)
228-
})
229-
}
230-
}
231-
23216
func TestRecordPrompt(t *testing.T) {
23317
t.Parallel()
23418

intercept/responses/blocking.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
7474
firstResponseID string
7575
)
7676

77-
prompt, promptFound, err := i.lastUserPrompt()
77+
prompt, promptFound, err := i.reqPayload.lastUserPrompt()
7878
if err != nil {
7979
i.logger.Warn(ctx, "failed to get user prompt", slog.Error(err))
8080
}

0 commit comments

Comments
 (0)