From 408f5c14189e99821707db15f1dd81ceeec0e1a7 Mon Sep 17 00:00:00 2001
From: Hannes Rudolph
Date: Wed, 11 Feb 2026 13:10:03 -0700
Subject: [PATCH] refactor: extract shared prompt cache breakpoint layer from 4 providers

Move cache breakpoint logic from individual providers to a shared utility
called from Task.ts before createMessage(). Messages arrive at providers
pre-annotated with providerOptions, and the AI SDK routes the correct
options to the active provider automatically.

New files:
- src/api/transform/prompt-cache.ts: resolveCacheProviderOptions() +
  applyCacheBreakpoints() with provider adapter mapping
- src/api/transform/__tests__/prompt-cache.spec.ts: 14 test cases

Changes per provider:
- anthropic.ts: removed targeting block + applyCacheControlToAiSdkMessages()
- anthropic-vertex.ts: same
- minimax.ts: same
- bedrock.ts: removed targeting block + applyCachePointsToAiSdkMessages()

Key improvements:
- Targets non-assistant batches (user + tool) instead of only role=user.
  After PR #11409, tool results are separate role=tool messages that now
  correctly receive cache breakpoints.
- Single source of truth: cache strategy defined once in prompt-cache.ts
- Provider-specific config preserved: Bedrock gets 3 breakpoints + anchor,
  Anthropic family gets 2 breakpoints

Preserved (untouched):
- systemProviderOptions in all providers' streamText() calls
- OpenAI Native promptCacheRetention (provider-level, not per-message)
- Bedrock usePromptCache opt-in + supportsAwsPromptCache()

5,491 tests pass, 0 regressions.
---
 src/api/providers/__tests__/minimax.spec.ts   |   5 -
 src/api/providers/anthropic-vertex.ts         |  54 ----
 src/api/providers/anthropic.ts                |  43 ----
 src/api/providers/bedrock.ts                  |  84 +------
 src/api/providers/minimax.ts                  |  32 ---
 .../transform/__tests__/prompt-cache.spec.ts  | 232 ++++++++++++++++++
 src/api/transform/prompt-cache.ts             | 174 +++++++++++++
 src/core/task/Task.ts                         |  18 ++
 8 files changed, 426 insertions(+), 216 deletions(-)
 create mode 100644 src/api/transform/__tests__/prompt-cache.spec.ts
 create mode 100644 src/api/transform/prompt-cache.ts

diff --git a/src/api/providers/__tests__/minimax.spec.ts b/src/api/providers/__tests__/minimax.spec.ts
index 3538184eee..4d5dfc96da 100644
--- a/src/api/providers/__tests__/minimax.spec.ts
+++ b/src/api/providers/__tests__/minimax.spec.ts
@@ -343,11 +343,6 @@ describe("MiniMaxHandler", () => {
 				expect.objectContaining({
 					role: "user",
 					content: [{ type: "text", text: "Merged message" }],
-					providerOptions: {
-						anthropic: {
-							cacheControl: { type: "ephemeral" },
-						},
-					},
 				}),
 			]),
 		)
diff --git a/src/api/providers/anthropic-vertex.ts b/src/api/providers/anthropic-vertex.ts
index 131b36992e..1077d1f480 100644
--- a/src/api/providers/anthropic-vertex.ts
+++ b/src/api/providers/anthropic-vertex.ts
@@ -119,37 +119,6 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleCompletionHandler {
 			anthropicProviderOptions.disableParallelToolUse = true
 		}
 
-		/**
-		 * Vertex API has specific limitations for prompt caching:
-		 * 1. Maximum of 4 blocks can have cache_control
-		 * 2. Only text blocks can be cached (images and other content types cannot)
-		 * 3. Cache control can only be applied to user messages, not assistant messages
-		 *
-		 * Our caching strategy:
-		 * - Cache the system prompt (1 block)
-		 * - Cache the last text block of the second-to-last user message (1 block)
-		 * - Cache the last text block of the last user message (1 block)
-		 * This ensures we stay under the 4-block limit while maintaining effective caching
-		 * for the most relevant context.
- */ - const cacheProviderOption = { anthropic: { cacheControl: { type: "ephemeral" as const } } } - - const userMsgIndices = messages.reduce( - (acc, msg, index) => ("role" in msg && msg.role === "user" ? [...acc, index] : acc), - [] as number[], - ) - - const targetIndices = new Set() - const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1 - const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 - - if (lastUserMsgIndex >= 0) targetIndices.add(lastUserMsgIndex) - if (secondLastUserMsgIndex >= 0) targetIndices.add(secondLastUserMsgIndex) - - if (targetIndices.size > 0) { - this.applyCacheControlToAiSdkMessages(messages as ModelMessage[], targetIndices, cacheProviderOption) - } - // Build streamText request // Cast providerOptions to any to bypass strict JSONObject typing — the AI SDK accepts the correct runtime values const requestOptions: Parameters[0] = { @@ -241,29 +210,6 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple } } - /** - * Apply cacheControl providerOptions to the correct AI SDK messages by walking - * the original Anthropic messages and converted AI SDK messages in parallel. - * - * convertToAiSdkMessages() can split a single Anthropic user message (containing - * tool_results + text) into 2 AI SDK messages (tool role + user role). This method - * accounts for that split so cache control lands on the right message. - */ - private applyCacheControlToAiSdkMessages( - aiSdkMessages: { role: string; providerOptions?: Record> }[], - targetIndices: Set, - cacheProviderOption: Record>, - ): void { - for (const idx of targetIndices) { - if (idx >= 0 && idx < aiSdkMessages.length) { - aiSdkMessages[idx].providerOptions = { - ...aiSdkMessages[idx].providerOptions, - ...cacheProviderOption, - } - } - } - } - getModel() { const modelId = this.options.apiModelId let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index 1f519250fa..47ced5c0c7 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -105,26 +105,6 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa anthropicProviderOptions.disableParallelToolUse = true } - // Apply cache control to user messages - // Strategy: cache the last 2 user messages (write-to-cache + read-from-cache) - const cacheProviderOption = { anthropic: { cacheControl: { type: "ephemeral" as const } } } - - const userMsgIndices = messages.reduce( - (acc, msg, index) => ("role" in msg && msg.role === "user" ? [...acc, index] : acc), - [] as number[], - ) - - const targetIndices = new Set() - const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1 - const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? 
-1 - - if (lastUserMsgIndex >= 0) targetIndices.add(lastUserMsgIndex) - if (secondLastUserMsgIndex >= 0) targetIndices.add(secondLastUserMsgIndex) - - if (targetIndices.size > 0) { - this.applyCacheControlToAiSdkMessages(messages as ModelMessage[], targetIndices, cacheProviderOption) - } - // Build streamText request // Cast providerOptions to any to bypass strict JSONObject typing — the AI SDK accepts the correct runtime values const requestOptions: Parameters[0] = { @@ -216,29 +196,6 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa } } - /** - * Apply cacheControl providerOptions to the correct AI SDK messages by walking - * the original Anthropic messages and converted AI SDK messages in parallel. - * - * convertToAiSdkMessages() can split a single Anthropic user message (containing - * tool_results + text) into 2 AI SDK messages (tool role + user role). This method - * accounts for that split so cache control lands on the right message. - */ - private applyCacheControlToAiSdkMessages( - aiSdkMessages: { role: string; providerOptions?: Record> }[], - targetIndices: Set, - cacheProviderOption: Record>, - ): void { - for (const idx of targetIndices) { - if (idx >= 0 && idx < aiSdkMessages.length) { - aiSdkMessages[idx].providerOptions = { - ...aiSdkMessages[idx].providerOptions, - ...cacheProviderOption, - } - } - } - } - getModel() { const modelId = this.options.apiModelId let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index bf713ea016..171321787b 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -252,67 +252,10 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } // Prompt caching: use AI SDK's cachePoint mechanism - // The AI SDK's @ai-sdk/amazon-bedrock supports cachePoint in providerOptions per message. - // - // Strategy: Bedrock allows up to 4 cache checkpoints. We use them as: - // 1. System prompt (via systemProviderOptions below) - // 2-4. Up to 3 user messages in the conversation history - // - // For the message cache points, we target the last 2 user messages (matching - // Anthropic's strategy: write-to-cache + read-from-cache) PLUS an earlier "anchor" - // user message near the middle of the conversation. This anchor ensures the 20-block - // lookback window has a stable cache entry to hit, covering all assistant/tool messages - // between the anchor and the recent messages. - // - // We identify targets in the ORIGINAL Anthropic messages (before AI SDK conversion) - // because convertToAiSdkMessages() splits user messages containing tool_results into - // separate "tool" + "user" role messages, which would skew naive counting. + // Determine whether to enable prompt caching for the system prompt. + // Per-message cache breakpoints are applied centrally in Task.ts. const usePromptCache = Boolean(this.options.awsUsePromptCache && this.supportsAwsPromptCache(modelConfig)) - if (usePromptCache) { - const cachePointOption = { bedrock: { cachePoint: { type: "default" as const } } } - - // Find all user message indices in the original (pre-conversion) message array. - const originalUserIndices = filteredMessages.reduce( - (acc, msg, idx) => ("role" in msg && msg.role === "user" ? 
[...acc, idx] : acc), - [], - ) - - // Select up to 3 user messages for cache points (system prompt uses the 4th): - // - Last user message: write to cache for next request - // - Second-to-last user message: read from cache for current request - // - An "anchor" message earlier in the conversation for 20-block window coverage - const targetOriginalIndices = new Set() - const numUserMsgs = originalUserIndices.length - - if (numUserMsgs >= 1) { - // Always cache the last user message - targetOriginalIndices.add(originalUserIndices[numUserMsgs - 1]) - } - if (numUserMsgs >= 2) { - // Cache the second-to-last user message - targetOriginalIndices.add(originalUserIndices[numUserMsgs - 2]) - } - if (numUserMsgs >= 5) { - // Add an anchor cache point roughly in the first third of user messages. - // This ensures that the 20-block lookback from the second-to-last breakpoint - // can find a stable cache entry, covering all the assistant and tool messages - // in the middle of the conversation. We pick the user message at ~1/3 position. - const anchorIdx = Math.floor(numUserMsgs / 3) - // Only add if it's not already one of the last-2 targets - if (!targetOriginalIndices.has(originalUserIndices[anchorIdx])) { - targetOriginalIndices.add(originalUserIndices[anchorIdx]) - } - } - - // Apply cachePoint to the correct AI SDK messages by walking both arrays in parallel. - // A single original user message with tool_results becomes [tool-role msg, user-role msg] - // in the AI SDK array, while a plain user message becomes [user-role msg]. - if (targetOriginalIndices.size > 0) { - this.applyCachePointsToAiSdkMessages(aiSdkMessages, targetOriginalIndices, cachePointOption) - } - } - // Build streamText request // Cast providerOptions to any to bypass strict JSONObject typing — the AI SDK accepts the correct runtime values const requestOptions: Parameters[0] = { @@ -706,29 +649,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH ) } - /** - * Apply cachePoint providerOptions to the correct AI SDK messages by walking - * the original Anthropic messages and converted AI SDK messages in parallel. - * - * convertToAiSdkMessages() can split a single Anthropic user message (containing - * tool_results + text) into 2 AI SDK messages (tool role + user role). This method - * accounts for that split so cache points land on the right message. - */ - private applyCachePointsToAiSdkMessages( - aiSdkMessages: { role: string; providerOptions?: Record> }[], - targetIndices: Set, - cachePointOption: Record>, - ): void { - for (const idx of targetIndices) { - if (idx >= 0 && idx < aiSdkMessages.length) { - aiSdkMessages[idx].providerOptions = { - ...aiSdkMessages[idx].providerOptions, - ...cachePointOption, - } - } - } - } - /************************************************************************************ * * AMAZON REGIONS diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts index 17b0055e4e..69e4cc8f60 100644 --- a/src/api/providers/minimax.ts +++ b/src/api/providers/minimax.ts @@ -89,23 +89,6 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand anthropicProviderOptions.disableParallelToolUse = true } - const cacheProviderOption = { anthropic: { cacheControl: { type: "ephemeral" as const } } } - const userMsgIndices = mergedMessages.reduce( - (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc), - [] as number[], - ) - - const targetIndices = new Set() - const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? 
-1 - const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 - - if (lastUserMsgIndex >= 0) targetIndices.add(lastUserMsgIndex) - if (secondLastUserMsgIndex >= 0) targetIndices.add(secondLastUserMsgIndex) - - if (targetIndices.size > 0) { - this.applyCacheControlToAiSdkMessages(aiSdkMessages, targetIndices, cacheProviderOption) - } - const requestOptions = { model: this.client(modelConfig.id), system: systemPrompt, @@ -187,21 +170,6 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand } } - private applyCacheControlToAiSdkMessages( - aiSdkMessages: { role: string; providerOptions?: Record> }[], - targetIndices: Set, - cacheProviderOption: Record>, - ): void { - for (const idx of targetIndices) { - if (idx >= 0 && idx < aiSdkMessages.length) { - aiSdkMessages[idx].providerOptions = { - ...aiSdkMessages[idx].providerOptions, - ...cacheProviderOption, - } - } - } - } - getModel() { const modelId = this.options.apiModelId diff --git a/src/api/transform/__tests__/prompt-cache.spec.ts b/src/api/transform/__tests__/prompt-cache.spec.ts new file mode 100644 index 0000000000..834715e0a4 --- /dev/null +++ b/src/api/transform/__tests__/prompt-cache.spec.ts @@ -0,0 +1,232 @@ +import type { + RooMessage, + RooUserMessage, + RooAssistantMessage, + RooToolMessage, +} from "../../../core/task-persistence/rooMessage" +import type { ModelInfo } from "@roo-code/types" +import { resolveCacheProviderOptions, applyCacheBreakpoints, type PromptCacheConfig } from "../prompt-cache" + +// ──────────────────────────────────────────────────────────────────────────── +// Test Helpers +// ──────────────────────────────────────────────────────────────────────────── + +/** Shorthand to read the runtime-assigned `providerOptions` from a mutated message. */ +function opts(msg: RooMessage): Record> | undefined { + return (msg as unknown as { providerOptions?: Record> }).providerOptions +} + +/** Shorthand to set `providerOptions` on a message (for pre-existing options tests). 
*/ +function setOpts(msg: RooMessage, value: Record): void { + ;(msg as unknown as { providerOptions: Record }).providerOptions = value +} + +function makeUserMsg(text: string = "hi"): RooUserMessage { + return { role: "user", content: text } +} + +function makeAssistantMsg(text: string = "hello"): RooAssistantMessage { + return { role: "assistant", content: text } +} + +function makeToolMsg(toolCallId: string = "tool-1"): RooToolMessage { + return { + role: "tool", + content: [ + { + type: "tool-result", + toolCallId, + toolName: "test", + output: { type: "text", value: "ok" }, + }, + ], + } as RooToolMessage +} + +// ──────────────────────────────────────────────────────────────────────────── +// Shared Fixtures +// ──────────────────────────────────────────────────────────────────────────── + +const baseModelInfo = { + contextWindow: 200_000, + supportsPromptCache: true, +} as ModelInfo + +const noPromptCacheModelInfo = { + contextWindow: 200_000, + supportsPromptCache: false, +} as ModelInfo + +const ANTHROPIC_CACHE = { anthropic: { cacheControl: { type: "ephemeral" } } } +const BEDROCK_CACHE = { bedrock: { cachePoint: { type: "default" } } } + +// ──────────────────────────────────────────────────────────────────────────── +// resolveCacheProviderOptions +// ──────────────────────────────────────────────────────────────────────────── + +describe("resolveCacheProviderOptions", () => { + it("returns null when model does not support prompt cache", () => { + const config: PromptCacheConfig = { + providerName: "anthropic", + modelInfo: noPromptCacheModelInfo, + } + + expect(resolveCacheProviderOptions(config)).toBeNull() + }) + + it("returns null for an unknown provider", () => { + const config: PromptCacheConfig = { + providerName: "some-unknown-provider", + modelInfo: baseModelInfo, + } + + expect(resolveCacheProviderOptions(config)).toBeNull() + }) + + it("returns anthropic cache options for anthropic provider", () => { + const config: PromptCacheConfig = { + providerName: "anthropic", + modelInfo: baseModelInfo, + } + + expect(resolveCacheProviderOptions(config)).toEqual(ANTHROPIC_CACHE) + }) + + it("returns anthropic cache options for vertex provider", () => { + const config: PromptCacheConfig = { + providerName: "vertex", + modelInfo: baseModelInfo, + } + + expect(resolveCacheProviderOptions(config)).toEqual(ANTHROPIC_CACHE) + }) + + it("returns anthropic cache options for minimax provider", () => { + const config: PromptCacheConfig = { + providerName: "minimax", + modelInfo: baseModelInfo, + } + + expect(resolveCacheProviderOptions(config)).toEqual(ANTHROPIC_CACHE) + }) + + it("returns bedrock cache options when awsUsePromptCache is enabled", () => { + const config: PromptCacheConfig = { + providerName: "bedrock", + modelInfo: baseModelInfo, + providerSettings: { awsUsePromptCache: true }, + } + + expect(resolveCacheProviderOptions(config)).toEqual(BEDROCK_CACHE) + }) + + it("returns null for bedrock without awsUsePromptCache", () => { + const config: PromptCacheConfig = { + providerName: "bedrock", + modelInfo: baseModelInfo, + } + + expect(resolveCacheProviderOptions(config)).toBeNull() + }) +}) + +// ──────────────────────────────────────────────────────────────────────────── +// applyCacheBreakpoints +// ──────────────────────────────────────────────────────────────────────────── + +describe("applyCacheBreakpoints", () => { + it("returns empty array unchanged when no messages", () => { + const messages: RooMessage[] = [] + const result = applyCacheBreakpoints(messages, 
ANTHROPIC_CACHE) + + expect(result).toHaveLength(0) + expect(result).toBe(messages) // same reference — mutates in place + }) + + it("places breakpoint on a single user message", () => { + const messages: RooMessage[] = [makeUserMsg()] + + applyCacheBreakpoints(messages, ANTHROPIC_CACHE) + + expect(opts(messages[0])).toEqual(ANTHROPIC_CACHE) + }) + + it("places breakpoint on a single tool message", () => { + const messages: RooMessage[] = [makeToolMsg()] + + applyCacheBreakpoints(messages, ANTHROPIC_CACHE) + + expect(opts(messages[0])).toEqual(ANTHROPIC_CACHE) + }) + + it("places breakpoints at end of each non-assistant batch (user → assistant → tool → user)", () => { + // Batch 1: [user(0)] Batch 2: [tool(2), user(3)] + const messages: RooMessage[] = [ + makeUserMsg("u0"), // index 0 — batch 1 end + makeAssistantMsg("a1"), // index 1 — assistant (separator) + makeToolMsg("t2"), // index 2 — batch 2 start + makeUserMsg("u3"), // index 3 — batch 2 end + ] + + applyCacheBreakpoints(messages, ANTHROPIC_CACHE, 2) + + expect(opts(messages[0])).toEqual(ANTHROPIC_CACHE) + expect(opts(messages[1])).toBeUndefined() + expect(opts(messages[2])).toBeUndefined() + expect(opts(messages[3])).toEqual(ANTHROPIC_CACHE) + }) + + it("only targets last message in each batch for consecutive non-assistant messages", () => { + // Batch 1: [tool(0), user(1)] Batch 2: [tool(3), user(4)] + const messages: RooMessage[] = [ + makeToolMsg("t0"), // index 0 — batch 1 start + makeUserMsg("u1"), // index 1 — batch 1 end + makeAssistantMsg("a2"), // index 2 — assistant (separator) + makeToolMsg("t3"), // index 3 — batch 2 start + makeUserMsg("u4"), // index 4 — batch 2 end + ] + + applyCacheBreakpoints(messages, ANTHROPIC_CACHE, 2) + + expect(opts(messages[0])).toBeUndefined() + expect(opts(messages[1])).toEqual(ANTHROPIC_CACHE) + expect(opts(messages[2])).toBeUndefined() + expect(opts(messages[3])).toBeUndefined() + expect(opts(messages[4])).toEqual(ANTHROPIC_CACHE) + }) + + it("places anchor breakpoint at ~1/3 for long conversations with useAnchor", () => { + // Build alternating user/assistant: 10 messages, 5 non-assistant batches + // Batches (by end index): [0], [2], [4], [6], [8] + const messages: RooMessage[] = [] + for (let i = 0; i < 5; i++) { + messages.push(makeUserMsg(`u${i * 2}`)) + messages.push(makeAssistantMsg(`a${i * 2 + 1}`)) + } + + // maxBreakpoints=2, useAnchor=true, anchorThreshold=5 + applyCacheBreakpoints(messages, BEDROCK_CACHE, 2, true, 5) + + // targetBatches = last 2 batch-ends: indices 6 and 8 + // anchorIndex = Math.floor(10 / 3) = 3 → first batch with end >= 3 is batch[4] + // batch[4] not in targetBatches → gets anchor breakpoint + // Total breakpoints at indices: 4 (anchor), 6 (trailing), 8 (trailing) + expect(opts(messages[0])).toBeUndefined() + expect(opts(messages[2])).toBeUndefined() + expect(opts(messages[4])).toEqual(BEDROCK_CACHE) + expect(opts(messages[6])).toEqual(BEDROCK_CACHE) + expect(opts(messages[8])).toEqual(BEDROCK_CACHE) + }) + + it("preserves existing providerOptions on message (merged, not replaced)", () => { + const messages: RooMessage[] = [makeUserMsg()] + setOpts(messages[0], { openai: { someOption: true } }) + + applyCacheBreakpoints(messages, ANTHROPIC_CACHE) + + const result = opts(messages[0]) + // Existing provider-specific key should be preserved alongside cache options + expect(result?.openai).toEqual({ someOption: true }) + expect(result?.anthropic).toEqual({ cacheControl: { type: "ephemeral" } }) + }) +}) diff --git a/src/api/transform/prompt-cache.ts 
b/src/api/transform/prompt-cache.ts
new file mode 100644
index 0000000000..e2ae83f366
--- /dev/null
+++ b/src/api/transform/prompt-cache.ts
@@ -0,0 +1,174 @@
+import type { RooMessage } from "../../core/task-persistence/rooMessage"
+import type { ModelInfo } from "@roo-code/types"
+
+// ────────────────────────────────────────────────────────────────────────────
+// Types
+// ────────────────────────────────────────────────────────────────────────────
+
+export interface PromptCacheConfig {
+	providerName: string
+	modelInfo: ModelInfo
+	providerSettings?: Record<string, unknown>
+	maxMessageBreakpoints?: number
+}
+
+/** Provider-specific options object attached to individual messages. */
+type CacheProviderOptions = Record<string, Record<string, unknown>>
+
+/**
+ * Typed intersection so we can set `providerOptions` on a `RooMessage`
+ * without widening to `any`. The base `RooMessage` union doesn't declare
+ * this field, but the AI SDK runtime accepts it.
+ */
+type MessageWithProviderOptions = RooMessage & {
+	providerOptions?: CacheProviderOptions
+}
+
+// ────────────────────────────────────────────────────────────────────────────
+// Constants
+// ────────────────────────────────────────────────────────────────────────────
+
+const ANTHROPIC_CACHE_OPTIONS: CacheProviderOptions = {
+	anthropic: { cacheControl: { type: "ephemeral" } },
+}
+
+const BEDROCK_CACHE_OPTIONS: CacheProviderOptions = {
+	bedrock: { cachePoint: { type: "default" } },
+}
+
+/** Providers that use the Anthropic-style cache control object. */
+const ANTHROPIC_STYLE_PROVIDERS: readonly string[] = ["anthropic", "vertex", "minimax"]
+
+// ────────────────────────────────────────────────────────────────────────────
+// resolveCacheProviderOptions
+// ────────────────────────────────────────────────────────────────────────────
+
+/**
+ * Returns the provider-specific cache options object for the given
+ * configuration, or `null` when prompt caching is not supported / enabled.
+ *
+ * Decision order:
+ * 1. Model must declare `supportsPromptCache`.
+ * 2. Bedrock requires `awsUsePromptCache` in provider settings.
+ * 3. Anthropic / Vertex / Minimax → Anthropic-style ephemeral cache.
+ * 4. Bedrock → Bedrock-style cache point.
+ * 5. All other providers → `null`.
+ */
+export function resolveCacheProviderOptions(config: PromptCacheConfig): CacheProviderOptions | null {
+	const { providerName, modelInfo, providerSettings } = config
+
+	if (!modelInfo.supportsPromptCache) {
+		return null
+	}
+
+	// Bedrock gate: must have awsUsePromptCache enabled.
+	if (providerName === "bedrock" && !providerSettings?.awsUsePromptCache) {
+		return null
+	}
+
+	if (ANTHROPIC_STYLE_PROVIDERS.includes(providerName)) {
+		return ANTHROPIC_CACHE_OPTIONS
+	}
+
+	if (providerName === "bedrock") {
+		return BEDROCK_CACHE_OPTIONS
+	}
+
+	return null
+}
+
+// ────────────────────────────────────────────────────────────────────────────
+// applyCacheBreakpoints
+// ────────────────────────────────────────────────────────────────────────────
+
+/** A contiguous run of non-assistant messages identified by inclusive indices. */
+interface Batch {
+	start: number
+	end: number
+}
+
+/**
+ * Applies cache breakpoints to an array of {@link RooMessage} by setting
+ * `providerOptions` on strategically chosen messages.
+ *
+ * **Strategy**
+ * 1. Identify "non-assistant batches" — consecutive runs of messages whose
+ *    role is *not* `"assistant"` (user, tool, reasoning, etc.).
+ * 2. Target the **last** message in each batch (the natural boundary before
+ *    an assistant turn).
+ * 3. Pick the last `maxMessageBreakpoints` batches.
+ * 4. Optionally place an "anchor" breakpoint at roughly 1/3 through the
+ *    conversation when `messages.length >= anchorThreshold`, which helps the
+ *    provider reuse cached prefixes across turns in long conversations.
+ *
+ * Mutates `messages` in place and returns the same array for chaining.
+ *
+ * @param messages The conversation history.
+ * @param cacheProviderOptions Provider-specific options (from {@link resolveCacheProviderOptions}).
+ * @param maxMessageBreakpoints Maximum trailing breakpoints to place (default `2`).
+ * @param useAnchor Whether to add an anchor breakpoint for long conversations.
+ * @param anchorThreshold Minimum message count before an anchor is considered (default `20`).
+ */
+export function applyCacheBreakpoints(
+	messages: RooMessage[],
+	cacheProviderOptions: CacheProviderOptions,
+	maxMessageBreakpoints: number = 2,
+	useAnchor: boolean = false,
+	anchorThreshold: number = 20,
+): RooMessage[] {
+	if (messages.length === 0 || maxMessageBreakpoints <= 0) {
+		return messages
+	}
+
+	// ── 1. Identify non-assistant batches ───────────────────────────────
+	const batches: Batch[] = []
+	let batchStart: number | null = null
+
+	for (let i = 0; i < messages.length; i++) {
+		const isAssistant = "role" in messages[i] && (messages[i] as { role: string }).role === "assistant"
+
+		if (!isAssistant) {
+			if (batchStart === null) {
+				batchStart = i
+			}
+		} else {
+			if (batchStart !== null) {
+				batches.push({ start: batchStart, end: i - 1 })
+				batchStart = null
+			}
+		}
+	}
+
+	// Close a trailing batch that runs to the end of the array.
+	if (batchStart !== null) {
+		batches.push({ start: batchStart, end: messages.length - 1 })
+	}
+
+	if (batches.length === 0) {
+		return messages
+	}
+
+	// ── 2. Pick the last N batches ──────────────────────────────────────
+	const targetBatches = batches.slice(-maxMessageBreakpoints)
+
+	// ── 3. Apply breakpoints ────────────────────────────────────────────
+	for (const batch of targetBatches) {
+		const target = messages[batch.end] as MessageWithProviderOptions
+		target.providerOptions = { ...target.providerOptions, ...cacheProviderOptions }
+	}
+
+	// ── 4. Optional anchor at ~1/3 of conversation ─────────────────────
+	if (useAnchor && messages.length >= anchorThreshold && batches.length > maxMessageBreakpoints) {
+		const anchorIndex = Math.floor(messages.length / 3)
+
+		// Find the first batch whose end is at or past the anchor point.
+		const anchorBatch = batches.find((b) => b.end >= anchorIndex)
+
+		if (anchorBatch && !targetBatches.includes(anchorBatch)) {
+			const anchor = messages[anchorBatch.end] as MessageWithProviderOptions
+			anchor.providerOptions = { ...anchor.providerOptions, ...cacheProviderOptions }
+		}
+	}
+
+	return messages
+}
diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts
index ee9f15e4e3..a4ec7eab74 100644
--- a/src/core/task/Task.ts
+++ b/src/core/task/Task.ts
@@ -63,6 +63,7 @@ import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from ".."
 import type { AssistantModelMessage } from "ai"
 import { ApiStream, GroundingSource } from "../../api/transform/stream"
 import { maybeRemoveImageBlocks } from "../../api/transform/image-cleaning"
+import { resolveCacheProviderOptions, applyCacheBreakpoints } from "../../api/transform/prompt-cache"
 
 // shared
 import { findLastIndex } from "../../shared/array"
@@ -4472,6 +4473,23 @@ export class Task extends EventEmitter implements TaskLike {
 		// Reset the flag after using it
 		this.skipPrevResponseIdOnce = false
 
+		// Apply cache breakpoints if the provider/model supports it
+		const cacheOptions = resolveCacheProviderOptions({
+			providerName: apiConfiguration?.apiProvider ?? "",
+			modelInfo,
+			providerSettings: apiConfiguration as Record<string, unknown>,
+		})
+		if (cacheOptions) {
+			const isBedrock = apiConfiguration?.apiProvider === "bedrock"
+			applyCacheBreakpoints(
+				cleanConversationHistory,
+				cacheOptions,
+				isBedrock ? 3 : 2,
+				isBedrock, // useAnchor
+				5, // anchorThreshold
+			)
+		}
+
 		const stream = this.api.createMessage(systemPrompt, cleanConversationHistory, metadata)
 		const iterator = stream[Symbol.asyncIterator]()
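
For reviewers, a minimal usage sketch of the adapter mapping described above. This is not part of the diff; the import paths and the `as ModelInfo` cast are illustrative (mirroring the spec file), and the `// =>` comments show the values implied by resolveCacheProviderOptions() in prompt-cache.ts:

import type { ModelInfo } from "@roo-code/types"
import { resolveCacheProviderOptions } from "./src/api/transform/prompt-cache"

// Minimal model info; only supportsPromptCache is consulted here.
const modelInfo = { contextWindow: 200_000, supportsPromptCache: true } as ModelInfo

// Anthropic-style providers (anthropic, vertex, minimax) resolve to the
// ephemeral cacheControl option.
resolveCacheProviderOptions({ providerName: "anthropic", modelInfo })
// => { anthropic: { cacheControl: { type: "ephemeral" } } }

// Bedrock is gated on the awsUsePromptCache provider setting.
resolveCacheProviderOptions({ providerName: "bedrock", modelInfo })
// => null
resolveCacheProviderOptions({
	providerName: "bedrock",
	modelInfo,
	providerSettings: { awsUsePromptCache: true },
})
// => { bedrock: { cachePoint: { type: "default" } } }

// Unknown providers, and models without supportsPromptCache, resolve to null,
// so Task.ts simply skips applyCacheBreakpoints() for them.
resolveCacheProviderOptions({ providerName: "some-unknown-provider", modelInfo })
// => null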
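
And a sketch of where breakpoints land under the Bedrock configuration Task.ts passes (maxMessageBreakpoints=3, useAnchor=true, anchorThreshold=5). Again not part of the diff; the message literals and casts are simplified stand-ins for real RooMessage values:

import type { RooMessage } from "./src/core/task-persistence/rooMessage"
import { applyCacheBreakpoints } from "./src/api/transform/prompt-cache"

const BEDROCK_CACHE = { bedrock: { cachePoint: { type: "default" } } }

// Alternating conversation: user/assistant pairs at indices 0..9, so the
// non-assistant batches end at indices 0, 2, 4, 6, 8.
const messages: RooMessage[] = []
for (let i = 0; i < 5; i++) {
	messages.push({ role: "user", content: `u${i * 2}` } as RooMessage)
	messages.push({ role: "assistant", content: `a${i * 2 + 1}` } as RooMessage)
}

// Bedrock settings from Task.ts: 3 trailing breakpoints, anchor enabled, threshold 5.
applyCacheBreakpoints(messages, BEDROCK_CACHE, 3, true, 5)

// Trailing breakpoints land on the last three batch ends (indices 4, 6, 8).
// The anchor looks for the first batch ending at or after floor(10 / 3) = 3,
// which is index 4 — already a trailing target — so no extra breakpoint is
// added for this particular shape; longer conversations get a distinct anchor.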