Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,8 @@
"typescript": "5.9.3",
"vite": "^7.2.7",
"vitest": "^4.0.14"
},
"dependencies": {
"@ag-ui/core": "0.0.49"
}
}
61 changes: 51 additions & 10 deletions packages/typescript/ai-client/src/chat-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ export class ChatClient {
// Create StreamProcessor with event handlers
this.processor = new StreamProcessor({
chunkStrategy: options.streamProcessor?.chunkStrategy,
initialMessages: options.initialMessages,
initialMessages: options.initialMessages as unknown as Array<UIMessage> | undefined,
events: {
onMessagesChange: (messages: Array<UIMessage>) => {
this.callbacksRef.current.onMessagesChange(messages)
Expand Down Expand Up @@ -390,17 +390,21 @@ export class ChatClient {
}
this.callbacksRef.current.onChunk(chunk)
this.processor.processChunk(chunk)
if (chunk.type === 'RUN_STARTED') {
this.activeRunIds.add(chunk.runId)
const chunkType = (chunk as unknown as { type: string }).type
if (chunkType === 'RUN_STARTED') {
this.activeRunIds.add((chunk as unknown as { runId: string }).runId)
this.setSessionGenerating(true)
}
// RUN_FINISHED / RUN_ERROR signal run completion — resolve processing
// (redundant if onStreamEnd already resolved it, harmless)
if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') {
const runId = chunk.type === 'RUN_FINISHED' ? chunk.runId : undefined
if (chunkType === 'RUN_FINISHED' || chunkType === 'RUN_ERROR') {
const runId =
chunkType === 'RUN_FINISHED'
? (chunk as unknown as { runId: string }).runId
: undefined
if (runId) {
this.activeRunIds.delete(runId)
} else if (chunk.type === 'RUN_ERROR') {
} else if (chunkType === 'RUN_ERROR') {
// RUN_ERROR without runId is a session-level error; clear all runs
this.activeRunIds.clear()
}
Expand Down Expand Up @@ -522,22 +526,25 @@ export class ChatClient {
*/
async append(message: UIMessage | ModelMessage): Promise<void> {
// Normalize the message to ensure it has id and createdAt
const normalizedMessage = normalizeToUIMessage(message, generateMessageId)
const normalizedMessage = normalizeToUIMessage(
message as Parameters<typeof normalizeToUIMessage>[0],
generateMessageId,
)

// Skip system messages - they're handled via systemPrompts, not UIMessages
if (normalizedMessage.role === 'system') {
return
}

// Type assertion: after checking for system, we know it's user or assistant
const uiMessage = normalizedMessage as UIMessage
const uiMessage = normalizedMessage as unknown as UIMessage

// Emit message appended event
this.events.messageAppended(uiMessage)

// Add to messages
const messages = this.processor.getMessages()
this.processor.setMessages([...messages, uiMessage])
this.processor.setMessages([...messages, uiMessage] as unknown as Array<UIMessage>)

// If stream is in progress, queue the response for after it ends
if (this.isLoading) {
Expand Down Expand Up @@ -805,6 +812,8 @@ export class ChatClient {
// Find the tool call ID from the approval ID
const messages = this.processor.getMessages()
let foundToolCallId: string | undefined
let foundToolName: string | undefined
let foundToolInput: any | undefined

for (const msg of messages) {
const toolCallPart = msg.parts.find(
Expand All @@ -813,6 +822,12 @@ export class ChatClient {
)
if (toolCallPart) {
foundToolCallId = toolCallPart.id
foundToolName = toolCallPart.name
try {
foundToolInput = JSON.parse(toolCallPart.arguments)
} catch {
// Ignore parse errors
}
break
}
}
Expand All @@ -828,6 +843,32 @@ export class ChatClient {
// Add response via processor
this.processor.addToolApprovalResponse(response.id, response.approved)

// Execute client-side tool if approved
if (response.approved && foundToolCallId && foundToolName) {
const clientTool = this.clientToolsRef.current.get(foundToolName)
if (clientTool?.execute) {
try {
const output = await clientTool.execute(foundToolInput)
await this.addToolResult({
toolCallId: foundToolCallId,
tool: foundToolName,
output,
state: 'output-available',
})
return
} catch (error: any) {
await this.addToolResult({
toolCallId: foundToolCallId,
tool: foundToolName,
output: null,
state: 'output-error',
errorText: error.message,
})
return
}
}
}

// If stream is in progress, queue continuation check for after it ends
if (this.isLoading) {
this.queuePostStreamAction(() => this.checkForContinuation())
Expand Down Expand Up @@ -961,7 +1002,7 @@ export class ChatClient {
* Manually set messages
*/
setMessagesManually(messages: Array<UIMessage>): void {
this.processor.setMessages(messages)
this.processor.setMessages(messages as unknown as Array<UIMessage>)
}

/**
Expand Down
9 changes: 8 additions & 1 deletion packages/typescript/ai-client/src/connection-adapters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ export function normalizeConnectionAdapter(
try {
const stream = connection.connect(messages, data, abortSignal)
for await (const chunk of stream) {
if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') {
const chunkType = (chunk as unknown as { type: string }).type
if (chunkType === 'RUN_FINISHED' || chunkType === 'RUN_ERROR') {
hasTerminalEvent = true
}
push(chunk)
Expand Down Expand Up @@ -225,6 +226,12 @@ export interface FetchConnectionOptions {
signal?: AbortSignal
body?: Record<string, any>
fetchClient?: typeof globalThis.fetch
/**
* Send full UIMessage objects (including `parts`) instead of ModelMessages.
* Required for advanced server features that depend on UIMessage metadata
* (e.g. tool approvals and client tool results tracked in parts).
*/
sendFullMessages?: boolean
}

/**
Expand Down
149 changes: 149 additions & 0 deletions packages/typescript/ai-client/tests/chat-client-approval.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import { describe, expect, it, vi } from 'vitest'
import { ChatClient } from '../src/chat-client'
import { stream } from '../src/connection-adapters'
import type { StreamChunk } from '@tanstack/ai'

/**
 * Wrap a fixed list of chunks in a stream-based connection adapter that
 * replays them in order every time the client connects.
 */
function createMockConnectionAdapter(options: { chunks: StreamChunk[] }) {
  const { chunks } = options
  return stream(async function* () {
    yield* chunks
  })
}

/** Loosely cast a plain event object to StreamChunk (EventType enum compat). */
const asChunk = (chunk: Record<string, unknown>) => chunk as unknown as StreamChunk

/**
 * Build the chunk sequence for an assistant turn that issues one or more
 * tool calls, each gated behind an approval request, and then finishes
 * the run with finishReason 'tool_calls'.
 */
function createApprovalToolCallChunks(
  toolCalls: Array<{
    id: string
    name: string
    arguments: string
    approvalId: string
  }>,
): StreamChunk[] {
  const timestamp = Date.now()

  // The assistant message opens first.
  const header = asChunk({
    type: 'TEXT_MESSAGE_START',
    messageId: 'msg-1',
    role: 'assistant',
    timestamp,
  })

  // Each tool call contributes three chunks in order:
  // start, args, then the approval-requested custom event.
  const perCall = toolCalls.flatMap((call) => [
    asChunk({
      type: 'TOOL_CALL_START',
      toolCallId: call.id,
      toolName: call.name,
      model: 'test-model',
      timestamp,
    }),
    asChunk({
      type: 'TOOL_CALL_ARGS',
      toolCallId: call.id,
      delta: call.arguments,
      args: call.arguments,
      model: 'test-model',
      timestamp,
    }),
    asChunk({
      type: 'CUSTOM',
      name: 'approval-requested',
      model: 'test-model',
      timestamp,
      value: {
        toolCallId: call.id,
        toolName: call.name,
        input: JSON.parse(call.arguments),
        approval: {
          id: call.approvalId,
          needsApproval: true,
        },
      },
    }),
  ])

  // The run terminates with a tool_calls finish so the client waits on approvals.
  const footer = asChunk({
    type: 'RUN_FINISHED',
    runId: 'run-1',
    threadId: 'thread-1',
    model: 'test-model',
    timestamp,
    finishReason: 'tool_calls',
  })

  return [header, ...perCall, footer]
}

// End-to-end test for the client-side tool approval flow: the stream requests
// approval for a tool call, and approving it should trigger the locally
// registered tool's execute() with the parsed tool-call arguments.
describe('ChatClient Approval Flow', () => {
  it('should execute client tool when approved', async () => {
    const toolName = 'delete_local_data'
    const toolCallId = 'call_123'
    const approvalId = 'approval_123'
    const input = { key: 'test-key' }

    // One tool call whose arguments round-trip through JSON.stringify/parse.
    const chunks = createApprovalToolCallChunks([
      {
        id: toolCallId,
        name: toolName,
        arguments: JSON.stringify(input),
        approvalId,
      },
    ])

    const adapter = createMockConnectionAdapter({ chunks })

    // Client-side tool whose execute resolves successfully when invoked.
    const execute = vi.fn().mockResolvedValue({ deleted: true })
    const clientTool = {
      name: toolName,
      description: 'Delete data',
      execute,
    }

    const client = new ChatClient({
      connection: adapter,
      tools: [clientTool],
    })

    // Start the flow
    await client.sendMessage('Delete data')

    // Wait for stream to finish (approval request should be pending)
    await new Promise((resolve) => setTimeout(resolve, 100))

    // Verify tool execution hasn't happened yet — the call is gated on approval
    expect(execute).not.toHaveBeenCalled()

    // Approve the tool
    await client.addToolApprovalResponse({
      id: approvalId,
      approved: true,
    })

    // Give the client a tick to run the approved tool and record its result
    await new Promise((resolve) => setTimeout(resolve, 100))

    // Expect execute to have been called
    expect(execute).toHaveBeenCalledWith(input)
  })
})
Loading