diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index bf3364d38d7..2ac8edb1055 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -208,6 +208,7 @@ const openRouterSchema = baseProviderSettingsSchema.extend({ openRouterModelId: z.string().optional(), openRouterBaseUrl: z.string().optional(), openRouterSpecificProvider: z.string().optional(), + openRouterExcludeLowQuantization: z.boolean().optional(), }) const bedrockSchema = apiModelIdProviderModelSchema.extend({ diff --git a/src/api/providers/__tests__/openrouter.spec.ts b/src/api/providers/__tests__/openrouter.spec.ts index ba039459202..4db1da9d2d4 100644 --- a/src/api/providers/__tests__/openrouter.spec.ts +++ b/src/api/providers/__tests__/openrouter.spec.ts @@ -1166,4 +1166,136 @@ describe("OpenRouterHandler", () => { ) }) }) + + describe("quantization filter", () => { + it("includes quantizations in providerOptions when openRouterExcludeLowQuantization is enabled", async () => { + const handler = new OpenRouterHandler({ + openRouterApiKey: "test-key", + openRouterModelId: "openai/gpt-4o", + openRouterExcludeLowQuantization: true, + }) + + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "test", id: "1" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }), + totalUsage: Promise.resolve({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }), + }) + + const generator = handler.createMessage("test", [{ role: "user", content: "test" }]) + + for await (const _ of generator) { + // consume + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + providerOptions: { + openrouter: { + provider: { + quantizations: ["fp16", "bf16", "fp8", "int8"], + }, + }, + }, + }), + ) + }) + + it("does not include quantizations in providerOptions when openRouterExcludeLowQuantization is disabled", async () => { + const handler = new OpenRouterHandler({ + openRouterApiKey: "test-key", + openRouterModelId: "openai/gpt-4o", + openRouterExcludeLowQuantization: false, + }) + + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "test", id: "1" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }), + totalUsage: Promise.resolve({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }), + }) + + const generator = handler.createMessage("test", [{ role: "user", content: "test" }]) + + for await (const _ of generator) { + // consume + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + providerOptions: undefined, + }), + ) + }) + + it("combines quantizations with specific provider routing", async () => { + const handler = new OpenRouterHandler({ + openRouterApiKey: "test-key", + openRouterModelId: "openai/gpt-4o", + openRouterExcludeLowQuantization: true, + openRouterSpecificProvider: "DeepInfra", + }) + + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "test", id: "1" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }), + totalUsage: Promise.resolve({ inputTokens: 10, outputTokens: 20, totalTokens: 30 }), + }) + + const generator = handler.createMessage("test", [{ role: "user", content: "test" }]) + + for await (const _ of generator) { + // consume + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + providerOptions: { + openrouter: { + provider: { + order: ["DeepInfra"], + only: ["DeepInfra"], + allow_fallbacks: false, + quantizations: ["fp16", "bf16", "fp8", "int8"], + }, + }, + }, + }), + ) + }) + + it("includes quantizations in completePrompt when openRouterExcludeLowQuantization is enabled", async () => { + const handler = new OpenRouterHandler({ + openRouterApiKey: "test-key", + openRouterModelId: "openai/gpt-4o", + openRouterExcludeLowQuantization: true, + }) + + mockGenerateText.mockResolvedValue({ text: "test" }) + + await handler.completePrompt("test prompt") + + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + providerOptions: { + openrouter: { + provider: { + quantizations: ["fp16", "bf16", "fp8", "int8"], + }, + }, + }, + }), + ) + }) + }) }) diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index d48fc4bb430..98bec153105 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -155,25 +155,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH const tools = convertToolsForAiSdk(metadata?.tools) - const providerOptions: - | { - openrouter?: { - provider?: { order: string[]; only: string[]; allow_fallbacks: boolean } - } - } - | undefined = - this.options.openRouterSpecificProvider && - this.options.openRouterSpecificProvider !== OPENROUTER_DEFAULT_PROVIDER_NAME - ? { - openrouter: { - provider: { - order: [this.options.openRouterSpecificProvider], - only: [this.options.openRouterSpecificProvider], - allow_fallbacks: false, - }, - }, - } - : undefined + const providerOptions = this.buildProviderOptions() let accumulatedReasoningText = "" @@ -281,6 +263,47 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH return { id, info, topP: isDeepSeekR1 ? 0.95 : undefined, ...params } } + private buildProviderOptions(): + | { + openrouter?: { + provider?: { + order?: string[] + only?: string[] + allow_fallbacks?: boolean + quantizations?: string[] + } + } + } + | undefined { + const hasSpecificProvider = + this.options.openRouterSpecificProvider && + this.options.openRouterSpecificProvider !== OPENROUTER_DEFAULT_PROVIDER_NAME + const excludeLowQuantization = this.options.openRouterExcludeLowQuantization + + if (!hasSpecificProvider && !excludeLowQuantization) { + return undefined + } + + const provider: { + order?: string[] + only?: string[] + allow_fallbacks?: boolean + quantizations?: string[] + } = {} + + if (hasSpecificProvider) { + provider.order = [this.options.openRouterSpecificProvider!] + provider.only = [this.options.openRouterSpecificProvider!] + provider.allow_fallbacks = false + } + + if (excludeLowQuantization) { + provider.quantizations = ["fp16", "bf16", "fp8", "int8"] + } + + return { openrouter: { provider } } + } + async completePrompt(prompt: string): Promise { let { id: modelId, maxTokens, temperature, topP, reasoning } = await this.fetchModel() @@ -298,25 +321,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH const openrouter = this.createOpenRouterProvider({ reasoning, headers }) - const providerOptions: - | { - openrouter?: { - provider?: { order: string[]; only: string[]; allow_fallbacks: boolean } - } - } - | undefined = - this.options.openRouterSpecificProvider && - this.options.openRouterSpecificProvider !== OPENROUTER_DEFAULT_PROVIDER_NAME - ? { - openrouter: { - provider: { - order: [this.options.openRouterSpecificProvider], - only: [this.options.openRouterSpecificProvider], - allow_fallbacks: false, - }, - }, - } - : undefined + const providerOptions = this.buildProviderOptions() try { const result = await generateText({ diff --git a/webview-ui/src/components/settings/providers/OpenRouter.tsx b/webview-ui/src/components/settings/providers/OpenRouter.tsx index 2dba8c8459f..8ada9cc6bb4 100644 --- a/webview-ui/src/components/settings/providers/OpenRouter.tsx +++ b/webview-ui/src/components/settings/providers/OpenRouter.tsx @@ -103,6 +103,18 @@ export const OpenRouter = ({ )} )} +
+ { + setApiConfigurationField("openRouterExcludeLowQuantization", checked) + }}> + {t("settings:providers.openRouter.excludeLowQuantization.label")} + +
+ {t("settings:providers.openRouter.excludeLowQuantization.description")} +
+