diff --git a/packages/tools/src/openai/middleware.ts b/packages/tools/src/openai/middleware.ts index bce986d0f..6580a1b2f 100644 --- a/packages/tools/src/openai/middleware.ts +++ b/packages/tools/src/openai/middleware.ts @@ -1,6 +1,10 @@ import type OpenAI from "openai" import Supermemory from "supermemory" import { addConversation } from "../conversations-client" +import { + findLastUserMessage, + extractTextFromMessageContent, +} from "../shared/memory-client" import { deduplicateMemories } from "../tools-shared" import { createLogger, type Logger } from "../vercel/logger" import { convertProfileToMarkdown } from "../vercel/util" @@ -54,14 +58,9 @@ interface SupermemoryProfileSearch { const getLastUserMessage = ( messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[], ) => { - const lastUserMessage = messages - .slice() - .reverse() - .find((msg) => msg.role === "user") - - return typeof lastUserMessage?.content === "string" - ? lastUserMessage.content - : "" + const userMessage = findLastUserMessage(messages) + if (!userMessage) return "" + return extractTextFromMessageContent(userMessage.content) ?? "" } /** diff --git a/packages/tools/src/shared/memory-client.ts b/packages/tools/src/shared/memory-client.ts index d55926c02..bc1adbba0 100644 --- a/packages/tools/src/shared/memory-client.ts +++ b/packages/tools/src/shared/memory-client.ts @@ -176,6 +176,50 @@ export interface GenericMessage { content: string | Array<{ type: string; text?: string }> } +export const joinTextParts = ( + parts: Array<{ type: string; text?: string }>, +): string => + parts + .filter((p) => p.type === "text") + .map((p) => p.text ?? "") + .join(" ") + +export const findLastUserMessage = ( + messages: T[], +): T | undefined => { + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === "user") return messages[i] + } + return undefined +} + +export const extractTextFromMessageContent = ( + content: unknown, +): string | undefined => { + if (typeof content === "string") { + return content + } + + if (Array.isArray(content)) { + return joinTextParts(content as Array<{ type: string; text?: string }>) + } + + const objContent = content as { + content?: string + parts?: Array<{ type: string; text?: string }> + } | null + if (typeof objContent === "object" && objContent !== null) { + if (typeof objContent.content === "string") { + return objContent.content + } + if (Array.isArray(objContent.parts)) { + return joinTextParts(objContent.parts) + } + } + + return undefined +} + /** * Extracts the query text from messages based on mode. * For "profile" mode, returns empty string (no query needed). @@ -195,42 +239,9 @@ export const extractQueryText = ( return "" } - const userMessage = messages - .slice() - .reverse() - .find((msg) => msg.role === "user") - - const content = userMessage?.content - if (!content) return "" - - if (typeof content === "string") { - return content - } - - if (Array.isArray(content)) { - return content - .filter((part) => part.type === "text") - .map((part) => part.text || "") - .join(" ") - } - - const objContent = content as unknown as { - content?: string - parts?: Array<{ type: string; text?: string }> - } - if (typeof objContent === "object" && objContent !== null) { - if ("content" in objContent && typeof objContent.content === "string") { - return objContent.content - } - if ("parts" in objContent && Array.isArray(objContent.parts)) { - return objContent.parts - .filter((part) => part.type === "text") - .map((part) => part.text || "") - .join(" ") - } - } - - return "" + const userMessage = findLastUserMessage(messages) + if (!userMessage) return "" + return extractTextFromMessageContent(userMessage.content) ?? "" } /** @@ -242,43 +253,7 @@ export const extractQueryText = ( export const getLastUserMessageText = ( messages: GenericMessage[], ): string | undefined => { - const lastUserMessage = messages - .slice() - .reverse() - .find((msg) => msg.role === "user") - - if (!lastUserMessage) { - return undefined - } - - const content = lastUserMessage.content - - if (typeof content === "string") { - return content - } - - if (Array.isArray(content)) { - return content - .filter((part) => part.type === "text") - .map((part) => part.text || "") - .join(" ") - } - - const objContent = content as unknown as { - content?: string - parts?: Array<{ type: string; text?: string }> - } - if (typeof objContent === "object" && objContent !== null) { - if ("content" in objContent && typeof objContent.content === "string") { - return objContent.content - } - if ("parts" in objContent && Array.isArray(objContent.parts)) { - return objContent.parts - .filter((part) => part.type === "text") - .map((part) => part.text || "") - .join(" ") - } - } - - return undefined + const userMessage = findLastUserMessage(messages) + if (!userMessage) return undefined + return extractTextFromMessageContent(userMessage.content) } diff --git a/packages/tools/src/vercel/memory-prompt.ts b/packages/tools/src/vercel/memory-prompt.ts index 29f2c6cb9..91306a0cb 100644 --- a/packages/tools/src/vercel/memory-prompt.ts +++ b/packages/tools/src/vercel/memory-prompt.ts @@ -9,6 +9,10 @@ export { } from "../shared" import type { Logger } from "../shared" +import { + findLastUserMessage, + extractTextFromMessageContent, +} from "../shared/memory-client" import type { LanguageModelCallOptions } from "./util" /** @@ -28,23 +32,9 @@ export const extractQueryText = ( return "" } - const userMessage = params.prompt - .slice() - .reverse() - .find((prompt: { role: string }) => prompt.role === "user") - - const content = userMessage?.content - if (!content) return "" - - if (typeof content === "string") { - return content - } - - // biome-ignore lint/suspicious/noExplicitAny: Union type compatibility between V2 and V3 - return (content as any[]) - .filter((part) => part.type === "text") - .map((part) => part.text || "") - .join(" ") + const userMessage = findLastUserMessage(params.prompt) + if (!userMessage) return "" + return extractTextFromMessageContent(userMessage.content) ?? "" } /** diff --git a/packages/tools/src/vercel/util.ts b/packages/tools/src/vercel/util.ts index 49ab30c5c..688a047be 100644 --- a/packages/tools/src/vercel/util.ts +++ b/packages/tools/src/vercel/util.ts @@ -47,30 +47,17 @@ export type OutputContentItem = // Re-export convertProfileToMarkdown from shared for backward compatibility export { convertProfileToMarkdown } from "../shared" +import { + findLastUserMessage, + extractTextFromMessageContent, +} from "../shared/memory-client" + export const getLastUserMessage = ( params: LanguageModelCallOptions, ): string | undefined => { - const lastUserMessage = params.prompt - .slice() - .reverse() - .find((prompt: LanguageModelMessage) => prompt.role === "user") - - if (!lastUserMessage) { - return undefined - } - - const content = lastUserMessage.content - - // Handle string content directly - if (typeof content === "string") { - return content - } - - // Handle array content - extract text parts - return content - .filter((part) => part.type === "text") - .map((part) => (part as { type: "text"; text: string }).text) - .join(" ") + const userMessage = findLastUserMessage(params.prompt) + if (!userMessage) return undefined + return extractTextFromMessageContent(userMessage.content) } export const filterOutSupermemories = (content: string) => {