diff --git a/CHANGELOG.md b/CHANGELOG.md index 365de623a47..d499511c954 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,31 @@ # Changelog +### [Version 1.138.4](https://github.com/lobehub/lobe-chat/compare/v1.138.3...v1.138.4) + +Released on **2025-10-18** + +#### 🐛 Bug Fixes + +- **misc**: Fix response API tools calling issue. + +
+ +
+Improvements and Fixes + +#### What's fixed + +- **misc**: Fix response API tools calling issue, closes [#9760](https://github.com/lobehub/lobe-chat/issues/9760) ([0596692](https://github.com/lobehub/lobe-chat/commit/0596692)) + +
+ +
+ +[![](https://img.shields.io/badge/-BACK_TO_TOP-151515?style=flat-square)](#readme-top) + +
+ ### [Version 1.138.3](https://github.com/lobehub/lobe-chat/compare/v1.138.2...v1.138.3) Released on **2025-10-18** diff --git a/changelog/v1.json b/changelog/v1.json index 76a9629f3d0..33af04172d6 100644 --- a/changelog/v1.json +++ b/changelog/v1.json @@ -1,4 +1,11 @@ [ + { + "children": { + "fixes": ["Fix response API tools calling issue."] + }, + "date": "2025-10-18", + "version": "1.138.4" + }, { "children": { "fixes": ["Fix topic fetch not correct in custom agent."] diff --git a/package.json b/package.json index 9560e8a16f5..5cf7b91caa5 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@lobehub/chat", - "version": "1.138.3", + "version": "1.138.4", "description": "Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.", "keywords": [ "framework", diff --git a/packages/database/src/repositories/aiInfra/index.test.ts b/packages/database/src/repositories/aiInfra/index.test.ts index 2c3a380ac43..3174799380d 100644 --- a/packages/database/src/repositories/aiInfra/index.test.ts +++ b/packages/database/src/repositories/aiInfra/index.test.ts @@ -419,6 +419,318 @@ describe('AiInfraRepos', () => { // For custom provider, when user enables search with no builtin settings, default to 'params' expect(merged?.settings).toEqual({ searchImpl: 'params' }); }); + + // Scenario: the user model's abilities are empty while the builtin model has search ability and settings + it('should use builtin abilities and drop settings when user model has no abilities (empty) and builtin search is false', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: {}, // Empty object, no search + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + abilities: { search: false }, // builtin abilities are used + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + // builtin abilities are used + expect(merged?.abilities?.search).toEqual(false); + // builtin settings are removed + expect(merged?.settings).toBeUndefined(); + }); + + it('should retain builtin abilities and settings when user model has no abilities (empty) and builtin search is true', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: {}, // Empty object, no search + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + abilities: { search: true }, // builtin abilities are used + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + 
const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + // builtin abilities are used + expect(merged?.abilities?.search).toEqual(true); + // builtin settings are retained + expect(merged?.settings).toEqual({ searchImpl: 'params', searchProvider: 'google' }); + }); + + // Scenario: the user model has search disabled (abilities.search is undefined) while the builtin model has search ability and settings + it('should retain builtin settings when user model has no abilities.search (undefined) and builtin search is false', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: { vision: true }, // vision enabled, no search + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + abilities: { search: false }, // builtin abilities do not take effect + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + // abilities.search stays undefined (backward compatible) + expect(merged?.abilities?.search).toBeUndefined(); + // builtin settings are retained + expect(merged?.settings).toEqual({ searchImpl: 'params', searchProvider: 'google' }); + }); + + it('should retain builtin settings when user model has no abilities.search (undefined) and builtin search is true', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: { vision: true }, // vision enabled, no search + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + abilities: { search: true }, // builtin abilities do not take effect + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 
'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + // abilities.search stays undefined (backward compatible) + expect(merged?.abilities?.search).toBeUndefined(); + // builtin settings are retained + expect(merged?.settings).toEqual({ searchImpl: 'params', searchProvider: 'google' }); + }); + + // Scenario: the user model has search disabled (abilities.search is undefined) and the builtin model has no search ability or settings + it('should retain no settings when user model has no abilities.search (undefined) and builtin has no settings', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: {}, // no search + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + abilities: {}, + // builtin has no settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities?.search).toBeUndefined(); + // no settings + expect(merged?.settings).toBeUndefined(); + }); + + // Tests where the user model has abilities.search: true + it('should inject defaults when user has search: true, no existing settings (builtin none)', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: { search: true }, // user enables search + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + abilities: {}, + // no settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities).toEqual({ search: true }); + // defaults are injected (openai: params) + expect(merged?.settings).toEqual({ searchImpl: 'params' }); + }); + + it('should retain existing settings when user has search: true and builtin has settings', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: { search: true }, + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + settings: { searchImpl: 'tool' }, // builtin has settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities).toEqual({ search: true }); + // builtin settings are used + expect(merged?.settings).toEqual({ searchImpl: 'tool' }); + }); + + // Tests where the user model has abilities.search: false + it('should remove settings when user has search: false and builtin has settings', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: { search: false }, // user disables search + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + settings: { searchImpl: 'tool', extendParams: [] }, // builtin has settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities).toEqual({ search: false }); + // search-related settings removed, the rest kept + expect(merged?.settings).toEqual({ 
extendParams: [] }); + }); + + it('should keep no settings when user has search: false and no existing settings', async () => { + const mockProviders = [ + { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' as const }, + ]; + + const userModel: EnabledAiModel = { + id: 'gpt-4', + providerId: 'openai', + enabled: true, + type: 'chat', + abilities: { search: false }, + }; + + const builtinModel = { + id: 'gpt-4', + enabled: true, + type: 'chat' as const, + // no settings + }; + + vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders); + vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue([userModel]); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([builtinModel]); + + const result = await repo.getEnabledModels(); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities).toEqual({ search: false }); + // no settings + expect(merged?.settings).toBeUndefined(); + }); }); describe('getAiProviderModelList', () => { @@ -614,6 +926,350 @@ describe('AiInfraRepos', () => { // For custom provider, when user enables search with no builtin settings, default to 'params' expect(merged.settings).toEqual({ searchImpl: 'params' }); }); + + // Scenario: the user model's abilities are empty while the builtin model has search ability and settings + it('should use builtin abilities and drop settings when user model has no abilities (empty) and builtin search is false', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: {}, // Empty object, no search + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { search: false }, // builtin abilities are used + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has settings + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + // builtin abilities are used + expect(merged?.abilities?.search).toEqual(false); + // builtin settings are removed + expect(merged?.settings).toBeUndefined(); + }); + + it('should retain builtin abilities and settings when user model has no abilities (empty) and builtin search is true', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: {}, // Empty object, no search + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { search: true }, // builtin abilities are used + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has settings + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + // builtin abilities are used + expect(merged?.abilities?.search).toEqual(true); + // builtin settings are retained + expect(merged?.settings).toEqual({ searchImpl: 'params', searchProvider: 'google' }); + }); + + // Scenario: the user model has search disabled (abilities.search is undefined) while the builtin model has search ability and settings + it('should drop 
builtin settings when user model has no abilities.search (undefined) and builtin search is false', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { vision: true }, // vision enabled, no search + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { search: false }, // builtin abilities get merged in + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has settings + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + // abilities.search is merged to false; this differs from getEnabledModels + expect(merged?.abilities?.search).toEqual(false); + // builtin settings are removed + expect(merged?.settings).toBeUndefined(); + }); + + it('should retain builtin settings when user model has no abilities.search (undefined) and builtin search is true', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { vision: true }, // vision enabled, no search + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { search: true }, // builtin abilities get merged in + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has settings + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + // abilities.search is merged to true; this differs from getEnabledModels + expect(merged?.abilities?.search).toEqual(true); + // builtin settings are retained + expect(merged?.settings).toEqual({ searchImpl: 'params', searchProvider: 'google' }); + }); + + // Tests: user model has no abilities.search (undefined); builtin settings are kept (mergeArrayById prefers user values and falls back to builtin) + it('should retain builtin settings when user model has no abilities.search (undefined) and builtin has settings', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: {}, // no search + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: {}, + settings: { searchImpl: 'params', searchProvider: 'google' }, // builtin has them + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities?.search).toBeUndefined(); + // builtin settings are retained + expect(merged?.settings).toEqual({ searchImpl: 'params', searchProvider: 'google' }); + }); + + it('should retain no settings when user model has no abilities.search (undefined) and builtin has no settings', async () => { + const providerId = 'openai'; + + const userModels: 
AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: {}, // no search + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + // no settings + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities?.search).toBeUndefined(); + // no settings + expect(merged?.settings).toBeUndefined(); + }); + + // Tests where the user model has abilities.search: true + it('should inject defaults when user has search: true, no existing settings (builtin none)', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { search: true }, // user enables search + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + // no settings + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities).toEqual({ search: true }); + // defaults are injected + expect(merged?.settings).toEqual({ searchImpl: 'params' }); + }); + + it('should retain existing settings when user has search: true and builtin has settings', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { search: true }, + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + settings: { searchImpl: 'tool' }, + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities).toEqual({ search: true }); + // builtin settings are used + expect(merged?.settings).toEqual({ searchImpl: 'tool' }); + }); + + // Tests where the user model has abilities.search: false + it('should remove settings when user has search: false and builtin has settings', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { search: false }, // user disables search + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + settings: { searchImpl: 'tool', extendParams: [] }, + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities).toEqual({ search: false }); + // search-related settings removed, the rest kept + expect(merged?.settings).toEqual({ extendParams: [] }); + }); + + it('should keep no settings when user has 
search: false and no existing settings', async () => { + const providerId = 'openai'; + + const userModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + abilities: { search: false }, + }, + ]; + + const builtinModels: AiProviderModelListItem[] = [ + { + id: 'gpt-4', + type: 'chat', + enabled: true, + // no settings + }, + ]; + + vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(userModels); + vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(builtinModels); + + const result = await repo.getAiProviderModelList(providerId); + + const merged = result.find((m) => m.id === 'gpt-4'); + expect(merged).toBeDefined(); + expect(merged?.abilities).toEqual({ search: false }); + // no settings + expect(merged?.settings).toBeUndefined(); + }); }); describe('getAiProviderRuntimeState', () => { diff --git a/packages/model-runtime/src/core/contextBuilders/google.test.ts b/packages/model-runtime/src/core/contextBuilders/google.test.ts new file mode 100644 index 00000000000..4e654220444 --- /dev/null +++ b/packages/model-runtime/src/core/contextBuilders/google.test.ts @@ -0,0 +1,585 @@ +// @vitest-environment node +import { Type as SchemaType } from '@google/genai'; +import { describe, expect, it, vi } from 'vitest'; + +import { ChatCompletionTool, OpenAIChatMessage, UserMessageContentPart } from '../../types'; +import * as imageToBase64Module from '../../utils/imageToBase64'; +import { parseDataUri } from '../../utils/uriParser'; +import { + buildGoogleMessage, + buildGoogleMessages, + buildGooglePart, + buildGoogleTool, + buildGoogleTools, +} from './google'; + +// Mock the utils +vi.mock('../../utils/uriParser', () => ({ + parseDataUri: vi.fn(), +})); + +vi.mock('../../utils/imageToBase64', () => ({ + imageUrlToBase64: vi.fn(), +})); + +describe('google contextBuilders', () => { + describe('buildGooglePart', () => { + it('should handle text type messages', async () => { + const content: UserMessageContentPart = { + text: 'Hello', + type: 'text', + }; + + const result = await buildGooglePart(content); + + expect(result).toEqual({ text: 'Hello' }); + }); + + it('should handle thinking type messages', async () => { + const content: UserMessageContentPart = { + signature: 'abc', + thinking: 'Hello', + type: 'thinking', + }; + + const result = await buildGooglePart(content); + + expect(result).toEqual(undefined); + }); + + it('should handle base64 type images', async () => { + const base64Image = + 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=='; + + vi.mocked(parseDataUri).mockReturnValueOnce({ + base64: + 'iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==', + mimeType: 'image/png', + type: 'base64', + }); + + const content: UserMessageContentPart = { + image_url: { url: base64Image }, + type: 'image_url', + }; + + const result = await buildGooglePart(content); + + expect(result).toEqual({ + inlineData: { + data: 'iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==', + mimeType: 'image/png', + }, + }); + }); + + it('should handle URL type images', async () => { + const imageUrl = 'http://example.com/image.png'; + const mockBase64 = 'mockBase64Data'; + + vi.mocked(parseDataUri).mockReturnValueOnce({ + base64: null, + mimeType: 'image/png', + type: 'url', + }); + + vi.spyOn(imageToBase64Module, 
'imageUrlToBase64').mockResolvedValueOnce({ + base64: mockBase64, + mimeType: 'image/png', + }); + + const content: UserMessageContentPart = { + image_url: { url: imageUrl }, + type: 'image_url', + }; + + const result = await buildGooglePart(content); + + expect(result).toEqual({ + inlineData: { + data: mockBase64, + mimeType: 'image/png', + }, + }); + + expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl); + }); + + it('should throw TypeError for unsupported image URL types', async () => { + const unsupportedImageUrl = 'unsupported://example.com/image.png'; + + vi.mocked(parseDataUri).mockReturnValueOnce({ + base64: null, + mimeType: null, + type: 'unknown' as any, + }); + + const content: UserMessageContentPart = { + image_url: { url: unsupportedImageUrl }, + type: 'image_url', + }; + + await expect(buildGooglePart(content)).rejects.toThrow(TypeError); + }); + + it('should handle base64 video', async () => { + const base64Video = 'data:video/mp4;base64,mockVideoBase64Data'; + + vi.mocked(parseDataUri).mockReturnValueOnce({ + base64: 'mockVideoBase64Data', + mimeType: 'video/mp4', + type: 'base64', + }); + + const content: UserMessageContentPart = { + type: 'video_url', + video_url: { url: base64Video }, + }; + + const result = await buildGooglePart(content); + + expect(result).toEqual({ + inlineData: { + data: 'mockVideoBase64Data', + mimeType: 'video/mp4', + }, + }); + }); + }); + + describe('buildGoogleMessage', () => { + it('should correctly convert assistant message', async () => { + const message: OpenAIChatMessage = { + content: 'Hello', + role: 'assistant', + }; + + const converted = await buildGoogleMessage(message); + + expect(converted).toEqual({ + parts: [{ text: 'Hello' }], + role: 'model', + }); + }); + + it('should correctly convert user message', async () => { + const message: OpenAIChatMessage = { + content: 'Hi', + role: 'user', + }; + + const converted = await buildGoogleMessage(message); + + expect(converted).toEqual({ + parts: [{ text: 'Hi' }], + role: 'user', + }); + }); + + it('should correctly convert message with inline base64 image parts', async () => { + vi.mocked(parseDataUri).mockReturnValueOnce({ + base64: '...', + mimeType: 'image/png', + type: 'base64', + }); + + const message: OpenAIChatMessage = { + content: [ + { text: 'Check this image:', type: 'text' }, + { image_url: { url: 'data:image/png;base64,...' 
}, type: 'image_url' }, + ], + role: 'user', + }; + + const converted = await buildGoogleMessage(message); + + expect(converted).toEqual({ + parts: [ + { text: 'Check this image:' }, + { inlineData: { data: '...', mimeType: 'image/png' } }, + ], + role: 'user', + }); + }); + + it('should correctly convert function call message', async () => { + const message = { + role: 'assistant', + tool_calls: [ + { + function: { + arguments: JSON.stringify({ location: 'London', unit: 'celsius' }), + name: 'get_current_weather', + }, + id: 'call_1', + type: 'function', + }, + ], + } as OpenAIChatMessage; + + const converted = await buildGoogleMessage(message); + + expect(converted).toEqual({ + parts: [ + { + functionCall: { + args: { location: 'London', unit: 'celsius' }, + name: 'get_current_weather', + }, + }, + ], + role: 'model', + }); + }); + + it('should correctly handle empty content', async () => { + const message: OpenAIChatMessage = { + content: '' as any, // explicitly set as empty string + role: 'user', + }; + + const converted = await buildGoogleMessage(message); + + expect(converted).toEqual({ + parts: [{ text: '' }], + role: 'user', + }); + }); + + it('should correctly convert tool response message', async () => { + const toolCallNameMap = new Map<string, string>(); + toolCallNameMap.set('call_1', 'get_current_weather'); + + const message: OpenAIChatMessage = { + content: '{"success":true,"data":{"temperature":"14°C"}}', + name: 'get_current_weather', + role: 'tool', + tool_call_id: 'call_1', + }; + + const converted = await buildGoogleMessage(message, toolCallNameMap); + + expect(converted).toEqual({ + parts: [ + { + functionResponse: { + name: 'get_current_weather', + response: { result: '{"success":true,"data":{"temperature":"14°C"}}' }, + }, + }, + ], + role: 'user', + }); + }); + }); + + describe('buildGoogleMessages', () => { + it('should convert a simple user message', async () => { + const messages: OpenAIChatMessage[] = [{ content: 'Hello', role: 'user' }]; + + const contents = await buildGoogleMessages(messages); + + expect(contents).toHaveLength(1); + expect(contents).toEqual([{ parts: [{ text: 'Hello' }], role: 'user' }]); + }); + + it('should keep the message count for a multi-turn conversation', async () => { + const messages: OpenAIChatMessage[] = [ + { content: 'Hello', role: 'user' }, + { content: 'Hi', role: 'assistant' }, + ]; + + const contents = await buildGoogleMessages(messages); + + expect(contents).toHaveLength(2); + expect(contents).toEqual([ + { parts: [{ text: 'Hello' }], role: 'user' }, + { parts: [{ text: 'Hi' }], role: 'model' }, + ]); + }); + + it('should convert messages that include image parts', async () => { + vi.mocked(parseDataUri).mockReturnValueOnce({ + base64: '...', + mimeType: 'image/png', + type: 'base64', + }); + + const messages: OpenAIChatMessage[] = [ + { + content: [ + { text: 'Hello', type: 'text' }, + { image_url: { url: 'data:image/png;base64,...' 
}, type: 'image_url' }, + ], + role: 'user', + }, + ]; + + const contents = await buildGoogleMessages(messages); + + expect(contents).toHaveLength(1); + expect(contents).toEqual([ + { + parts: [{ text: 'Hello' }, { inlineData: { data: '...', mimeType: 'image/png' } }], + role: 'user', + }, + ]); + }); + + it('should correctly convert function response message', async () => { + const messages: OpenAIChatMessage[] = [ + { + content: '', + role: 'assistant', + tool_calls: [ + { + function: { + arguments: JSON.stringify({ location: 'London', unit: 'celsius' }), + name: 'get_current_weather', + }, + id: 'call_1', + type: 'function', + }, + ], + }, + { + content: '{"success":true,"data":{"temperature":"14°C"}}', + name: 'get_current_weather', + role: 'tool', + tool_call_id: 'call_1', + }, + ]; + + const contents = await buildGoogleMessages(messages); + + expect(contents).toHaveLength(2); + expect(contents).toEqual([ + { + parts: [ + { + functionCall: { + args: { location: 'London', unit: 'celsius' }, + name: 'get_current_weather', + }, + }, + ], + role: 'model', + }, + { + parts: [ + { + functionResponse: { + name: 'get_current_weather', + response: { result: '{"success":true,"data":{"temperature":"14°C"}}' }, + }, + }, + ], + role: 'user', + }, + ]); + }); + + it('should filter out function role messages', async () => { + const messages: OpenAIChatMessage[] = [ + { content: 'Hello', role: 'user' }, + { content: 'function result', name: 'test_func', role: 'function' }, + { content: 'Hi', role: 'assistant' }, + ]; + + const contents = await buildGoogleMessages(messages); + + expect(contents).toHaveLength(2); + expect(contents).toEqual([ + { parts: [{ text: 'Hello' }], role: 'user' }, + { parts: [{ text: 'Hi' }], role: 'model' }, + ]); + }); + + it('should filter out empty messages', async () => { + const messages: OpenAIChatMessage[] = [ + { content: 'Hello', role: 'user' }, + { content: [], role: 'user' }, + { content: 'Hi', role: 'assistant' }, + ]; + + const contents = await buildGoogleMessages(messages); + + expect(contents).toHaveLength(2); + expect(contents).toEqual([ + { parts: [{ text: 'Hello' }], role: 'user' }, + { parts: [{ text: 'Hi' }], role: 'model' }, + ]); + }); + }); + + describe('buildGoogleTool', () => { + it('should correctly convert ChatCompletionTool to FunctionDeclaration', () => { + const tool: ChatCompletionTool = { + function: { + description: 'A test tool', + name: 'testTool', + parameters: { + properties: { + param1: { type: 'string' }, + param2: { type: 'number' }, + }, + required: ['param1'], + type: 'object', + }, + }, + type: 'function', + }; + + const result = buildGoogleTool(tool); + + expect(result).toEqual({ + description: 'A test tool', + name: 'testTool', + parameters: { + description: undefined, + properties: { + param1: { type: 'string' }, + param2: { type: 'number' }, + }, + required: ['param1'], + type: SchemaType.OBJECT, + }, + }); + }); + + it('should handle tools with empty parameters', () => { + const tool: ChatCompletionTool = { + function: { + description: 'A simple function with no parameters', + name: 'simple_function', + parameters: { + properties: {}, + type: 'object', + }, + }, + type: 'function', + }; + + const result = buildGoogleTool(tool); + + // Should use dummy property for empty parameters + expect(result).toEqual({ + description: 'A simple function with no parameters', + name: 'simple_function', + parameters: { + description: undefined, + properties: { dummy: { type: 'string' } }, + required: undefined, + type: SchemaType.OBJECT, + }, + 
}); + }); + + it('should preserve parameter description', () => { + const tool: ChatCompletionTool = { + function: { + description: 'A test tool', + name: 'testTool', + parameters: { + description: 'Test parameters', + properties: { + param1: { type: 'string' }, + }, + type: 'object', + }, + }, + type: 'function', + }; + + const result = buildGoogleTool(tool); + + expect(result.parameters?.description).toBe('Test parameters'); + }); + }); + + describe('buildGoogleTools', () => { + it('should return undefined when tools is undefined or empty', () => { + expect(buildGoogleTools(undefined)).toBeUndefined(); + expect(buildGoogleTools([])).toBeUndefined(); + }); + + it('should correctly convert ChatCompletionTool array to GoogleFunctionCallTool', () => { + const tools: ChatCompletionTool[] = [ + { + function: { + description: 'A test tool', + name: 'testTool', + parameters: { + properties: { + param1: { type: 'string' }, + param2: { type: 'number' }, + }, + required: ['param1'], + type: 'object', + }, + }, + type: 'function', + }, + ]; + + const googleTools = buildGoogleTools(tools); + + expect(googleTools).toHaveLength(1); + expect(googleTools![0].functionDeclarations).toHaveLength(1); + expect(googleTools![0].functionDeclarations![0]).toEqual({ + description: 'A test tool', + name: 'testTool', + parameters: { + description: undefined, + properties: { + param1: { type: 'string' }, + param2: { type: 'number' }, + }, + required: ['param1'], + type: SchemaType.OBJECT, + }, + }); + }); + + it('should handle multiple tools', () => { + const tools: ChatCompletionTool[] = [ + { + function: { + description: 'Get weather information', + name: 'get_weather', + parameters: { + properties: { + city: { type: 'string' }, + unit: { type: 'string' }, + }, + required: ['city'], + type: 'object', + }, + }, + type: 'function', + }, + { + function: { + description: 'Get current time', + name: 'get_time', + parameters: { + properties: { + timezone: { type: 'string' }, + }, + required: ['timezone'], + type: 'object', + }, + }, + type: 'function', + }, + ]; + + const googleTools = buildGoogleTools(tools); + + expect(googleTools).toHaveLength(1); + expect(googleTools![0].functionDeclarations).toHaveLength(2); + expect(googleTools![0].functionDeclarations![0].name).toBe('get_weather'); + expect(googleTools![0].functionDeclarations![1].name).toBe('get_time'); + }); + }); +}); diff --git a/packages/model-runtime/src/core/contextBuilders/google.ts b/packages/model-runtime/src/core/contextBuilders/google.ts new file mode 100644 index 00000000000..768e0814ed2 --- /dev/null +++ b/packages/model-runtime/src/core/contextBuilders/google.ts @@ -0,0 +1,201 @@ +import { + Content, + FunctionDeclaration, + Tool as GoogleFunctionCallTool, + Part, + Type as SchemaType, +} from '@google/genai'; + +import { ChatCompletionTool, OpenAIChatMessage, UserMessageContentPart } from '../../types'; + import { imageUrlToBase64 } from '../../utils/imageToBase64'; +import { safeParseJSON } from '../../utils/safeParseJSON'; +import { parseDataUri } from '../../utils/uriParser'; + +/** + * Convert OpenAI content part to Google Part format + */ +export const buildGooglePart = async ( + content: UserMessageContentPart, +): Promise<Part | undefined> => { + switch (content.type) { + default: { + return undefined; + } + + case 'text': { + return { text: content.text }; + } + + case 'image_url': { + const { mimeType, base64, type } = parseDataUri(content.image_url.url); + + if (type === 'base64') { + if (!base64) { + throw new TypeError("Image URL doesn't contain base64 
data"); + } + + return { + inlineData: { data: base64, mimeType: mimeType || 'image/png' }, + }; + } + + if (type === 'url') { + const { base64, mimeType } = await imageUrlToBase64(content.image_url.url); + + return { + inlineData: { data: base64, mimeType }, + }; + } + + throw new TypeError(`currently we don't support image url: ${content.image_url.url}`); + } + + case 'video_url': { + const { mimeType, base64, type } = parseDataUri(content.video_url.url); + + if (type === 'base64') { + if (!base64) { + throw new TypeError("Video URL doesn't contain base64 data"); + } + + return { + inlineData: { data: base64, mimeType: mimeType || 'video/mp4' }, + }; + } + + if (type === 'url') { + // For video URLs, we need to fetch and convert to base64 + // Note: This might need size/duration limits for practical use + const response = await fetch(content.video_url.url); + const arrayBuffer = await response.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const mimeType = response.headers.get('content-type') || 'video/mp4'; + + return { + inlineData: { data: base64, mimeType }, + }; + } + + throw new TypeError(`currently we don't support video url: ${content.video_url.url}`); + } + } +}; + +/** + * Convert OpenAI message to Google Content format + */ +export const buildGoogleMessage = async ( + message: OpenAIChatMessage, + toolCallNameMap?: Map, +): Promise => { + const content = message.content as string | UserMessageContentPart[]; + + // Handle assistant messages with tool_calls + if (!!message.tool_calls) { + return { + parts: message.tool_calls.map((tool) => ({ + functionCall: { + args: safeParseJSON(tool.function.arguments)!, + name: tool.function.name, + }, + })), + role: 'model', + }; + } + + // Convert tool_call result to functionResponse part + if (message.role === 'tool' && toolCallNameMap && message.tool_call_id) { + const functionName = toolCallNameMap.get(message.tool_call_id); + if (functionName) { + return { + parts: [ + { + functionResponse: { + name: functionName, + response: { result: message.content }, + }, + }, + ], + role: 'user', + }; + } + } + + const getParts = async () => { + if (typeof content === 'string') return [{ text: content }]; + + const parts = await Promise.all(content.map(async (c) => await buildGooglePart(c))); + return parts.filter(Boolean) as Part[]; + }; + + return { + parts: await getParts(), + role: message.role === 'assistant' ? 'model' : 'user', + }; +}; + +/** + * Convert messages from the OpenAI format to Google GenAI SDK format + */ +export const buildGoogleMessages = async (messages: OpenAIChatMessage[]): Promise => { + const toolCallNameMap = new Map(); + + // Build tool call id to name mapping + messages.forEach((message) => { + if (message.role === 'assistant' && message.tool_calls) { + message.tool_calls.forEach((toolCall) => { + if (toolCall.type === 'function') { + toolCallNameMap.set(toolCall.id, toolCall.function.name); + } + }); + } + }); + + const pools = messages + .filter((message) => message.role !== 'function') + .map(async (msg) => await buildGoogleMessage(msg, toolCallNameMap)); + + const contents = await Promise.all(pools); + + // Filter out empty messages: contents.parts must not be empty. 
+ return contents.filter((content: Content) => content.parts && content.parts.length > 0); +}; + +/** + * Convert ChatCompletionTool to Google FunctionDeclaration + */ +export const buildGoogleTool = (tool: ChatCompletionTool): FunctionDeclaration => { + const functionDeclaration = tool.function; + const parameters = functionDeclaration.parameters; + // refs: https://github.com/lobehub/lobe-chat/pull/5002 + const properties = + parameters?.properties && Object.keys(parameters.properties).length > 0 + ? parameters.properties + : { dummy: { type: 'string' } }; // dummy property to avoid empty object + + return { + description: functionDeclaration.description, + name: functionDeclaration.name, + parameters: { + description: parameters?.description, + properties: properties, + required: parameters?.required, + type: SchemaType.OBJECT, + }, + }; +}; + +/** + * Build Google function declarations from ChatCompletionTool array + */ +export const buildGoogleTools = ( + tools: ChatCompletionTool[] | undefined, +): GoogleFunctionCallTool[] | undefined => { + if (!tools || tools.length === 0) return; + + return [ + { + functionDeclarations: tools.map((tool) => buildGoogleTool(tool)), + }, + ]; +}; diff --git a/packages/model-runtime/src/core/openaiCompatibleFactory/index.test.ts b/packages/model-runtime/src/core/openaiCompatibleFactory/index.test.ts index 57f9633f1bd..246a4c5750e 100644 --- a/packages/model-runtime/src/core/openaiCompatibleFactory/index.test.ts +++ b/packages/model-runtime/src/core/openaiCompatibleFactory/index.test.ts @@ -119,11 +119,11 @@ describe('LobeOpenAICompatibleFactory', () => { max_tokens: 1024, messages: [{ content: 'Hello', role: 'user' }], model: 'mistralai/mistral-7b-instruct:free', - temperature: 0.7, stream: true, stream_options: { include_usage: true, }, + temperature: 0.7, top_p: 1, }, { headers: { Accept: '*/*' } }, @@ -136,14 +136,14 @@ describe('LobeOpenAICompatibleFactory', () => { const mockStream = new ReadableStream({ start(controller) { controller.enqueue({ + choices: [ + { delta: { content: 'hello' }, finish_reason: null, index: 0, logprobs: null }, + ], + created: 1_709_125_675, id: 'a', - object: 'chat.completion.chunk', - created: 1709125675, model: 'mistralai/mistral-7b-instruct:free', + object: 'chat.completion.chunk', system_fingerprint: 'fp_86156a94a0', - choices: [ - { index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null }, - ], }); controller.close(); }, @@ -163,6 +163,7 @@ describe('LobeOpenAICompatibleFactory', () => { // Collect all chunks const chunks = []; + // eslint-disable-next-line no-constant-condition while (true) { const { value, done } = await reader.read(); if (done) break; @@ -185,13 +186,13 @@ describe('LobeOpenAICompatibleFactory', () => { object: '', prompt_filter_results: [ { - prompt_index: 0, content_filter_results: { hate: { filtered: false, severity: 'safe' }, self_harm: { filtered: false, severity: 'safe' }, sexual: { filtered: false, severity: 'safe' }, violence: { filtered: false, severity: 'safe' }, }, + prompt_index: 0, }, ], }, @@ -204,7 +205,7 @@ describe('LobeOpenAICompatibleFactory', () => { logprobs: null, }, ], - created: 1717249403, + created: 1_717_249_403, id: 'chatcmpl-9VJIxA3qNM2C2YdAnNYA2KgDYfFnX', model: 'gpt-4o-2024-05-13', object: 'chat.completion.chunk', @@ -212,7 +213,7 @@ describe('LobeOpenAICompatibleFactory', () => { }, { choices: [{ delta: { content: '1' }, finish_reason: null, index: 0, logprobs: null }], - created: 1717249403, + created: 1_717_249_403, id: 
'chatcmpl-9VJIxA3qNM2C2YdAnNYA2KgDYfFnX', model: 'gpt-4o-2024-05-13', object: 'chat.completion.chunk', @@ -220,7 +221,7 @@ describe('LobeOpenAICompatibleFactory', () => { }, { choices: [{ delta: {}, finish_reason: 'stop', index: 0, logprobs: null }], - created: 1717249403, + created: 1_717_249_403, id: 'chatcmpl-9VJIxA3qNM2C2YdAnNYA2KgDYfFnX', model: 'gpt-4o-2024-05-13', object: 'chat.completion.chunk', @@ -229,7 +230,7 @@ describe('LobeOpenAICompatibleFactory', () => { { choices: [ { - content_filter_offsets: { check_offset: 35, start_offset: 35, end_offset: 36 }, + content_filter_offsets: { check_offset: 35, end_offset: 36, start_offset: 35 }, content_filter_results: { hate: { filtered: false, severity: 'safe' }, self_harm: { filtered: false, severity: 'safe' }, @@ -268,6 +269,7 @@ describe('LobeOpenAICompatibleFactory', () => { const decoder = new TextDecoder(); const reader = result.body!.getReader(); + // eslint-disable-next-line no-constant-condition while (true) { const { value, done } = await reader.read(); if (done) break; @@ -278,7 +280,7 @@ describe('LobeOpenAICompatibleFactory', () => { [ 'id: ', 'event: data', - 'data: {"choices":[],"created":0,"id":"","model":"","object":"","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}]}\n', + 'data: {"choices":[],"created":0,"id":"","model":"","object":"","prompt_filter_results":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"prompt_index":0}]}\n', 'id: chatcmpl-9VJIxA3qNM2C2YdAnNYA2KgDYfFnX', 'event: text', 'data: ""\n', @@ -299,21 +301,21 @@ describe('LobeOpenAICompatibleFactory', () => { vi.useFakeTimers(); const mockResponse = { - id: 'a', - object: 'chat.completion', - created: 123, - model: 'mistralai/mistral-7b-instruct:free', choices: [ { - index: 0, - message: { role: 'assistant', content: 'Hello' }, finish_reason: 'stop', + index: 0, logprobs: null, + message: { content: 'Hello', role: 'assistant' }, }, ], + created: 123, + id: 'a', + model: 'mistralai/mistral-7b-instruct:free', + object: 'chat.completion', usage: { - prompt_tokens: 5, completion_tokens: 5, + prompt_tokens: 5, total_tokens: 10, }, } as OpenAI.ChatCompletion; @@ -324,8 +326,8 @@ describe('LobeOpenAICompatibleFactory', () => { const chatPromise = instance.chat({ messages: [{ content: 'Hello', role: 'user' }], model: 'mistralai/mistral-7b-instruct:free', - temperature: 0, stream: false, + temperature: 0, }); // Advance time to simulate processing delay @@ -337,6 +339,7 @@ describe('LobeOpenAICompatibleFactory', () => { const reader = result.body!.getReader(); const stream: string[] = []; + // eslint-disable-next-line no-constant-condition while (true) { const { value, done } = await reader.read(); if (done) break; @@ -352,13 +355,14 @@ describe('LobeOpenAICompatibleFactory', () => { 'data: {"inputTextTokens":5,"outputTextTokens":5,"totalInputTokens":5,"totalOutputTokens":5,"totalTokens":10}\n\n', 'id: output_speed\n', 'event: speed\n', - expect.stringMatching(/^data: \{.*"tps":.*,"ttft":.*}\n\n$/), // tps ttft should be calculated with elapsed time + expect.stringMatching(/^data: {.*"tps":.*,"ttft":.*}\n\n$/), // tps ttft should be calculated with elapsed time 'id: a\n', 'event: stop\n', 'data: 
"stop"\n\n', ]); - expect((await reader.read()).done).toBe(true); + const finalRead = await reader.read(); + expect(finalRead.done).toBe(true); vi.useRealTimers(); }); @@ -367,25 +371,25 @@ describe('LobeOpenAICompatibleFactory', () => { vi.useFakeTimers(); const mockResponse = { - id: 'a', - object: 'chat.completion', - created: 123, - model: 'deepseek/deepseek-reasoner', choices: [ { + finish_reason: 'stop', index: 0, + logprobs: null, message: { - role: 'assistant', content: 'Hello', reasoning_content: 'Thinking content', + role: 'assistant', }, - finish_reason: 'stop', - logprobs: null, }, ], + created: 123, + id: 'a', + model: 'deepseek/deepseek-reasoner', + object: 'chat.completion', usage: { - prompt_tokens: 5, completion_tokens: 5, + prompt_tokens: 5, total_tokens: 10, }, } as unknown as OpenAI.ChatCompletion; @@ -396,8 +400,8 @@ describe('LobeOpenAICompatibleFactory', () => { const chatPromise = instance.chat({ messages: [{ content: 'Hello', role: 'user' }], model: 'deepseek/deepseek-reasoner', - temperature: 0, stream: false, + temperature: 0, }); // Advance time to simulate processing delay @@ -409,6 +413,7 @@ describe('LobeOpenAICompatibleFactory', () => { const reader = result.body!.getReader(); const stream: string[] = []; + // eslint-disable-next-line no-constant-condition while (true) { const { value, done } = await reader.read(); if (done) break; @@ -427,13 +432,14 @@ describe('LobeOpenAICompatibleFactory', () => { 'data: {"inputTextTokens":5,"outputTextTokens":5,"totalInputTokens":5,"totalOutputTokens":5,"totalTokens":10}\n\n', 'id: output_speed\n', 'event: speed\n', - expect.stringMatching(/^data: \{.*"tps":.*,"ttft":.*}\n\n$/), // tps ttft should be calculated with elapsed time + expect.stringMatching(/^data: {.*"tps":.*,"ttft":.*}\n\n$/), // tps ttft should be calculated with elapsed time 'id: a\n', 'event: stop\n', 'data: "stop"\n\n', ]); - expect((await reader.read()).done).toBe(true); + const finalRead = await reader.read(); + expect(finalRead.done).toBe(true); vi.useRealTimers(); }); @@ -607,10 +613,10 @@ describe('LobeOpenAICompatibleFactory', () => { const apiError = new OpenAI.APIError( 400, { - status: 400, error: { message: 'Bad Request', }, + status: 400, }, 'Error message', {}, @@ -770,13 +776,13 @@ describe('LobeOpenAICompatibleFactory', () => { } catch (e) { expect(e).toEqual({ endpoint: defaultBaseURL, - errorType: 'AgentRuntimeError', - provider, error: { - name: genericError.name, cause: genericError.cause, message: genericError.message, + name: genericError.name, }, + errorType: 'AgentRuntimeError', + provider, }); } }); @@ -791,14 +797,14 @@ describe('LobeOpenAICompatibleFactory', () => { new ReadableStream({ start(controller) { controller.enqueue({ + choices: [ + { delta: { content: 'hello' }, finish_reason: null, index: 0, logprobs: null }, + ], + created: 1_709_125_675, id: 'chatcmpl-8xDx5AETP8mESQN7UB30GxTN2H1SO', - object: 'chat.completion.chunk', - created: 1709125675, model: 'mistralai/mistral-7b-instruct:free', + object: 'chat.completion.chunk', system_fingerprint: 'fp_86156a94a0', - choices: [ - { index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null }, - ], }); controller.close(); }, @@ -807,8 +813,8 @@ describe('LobeOpenAICompatibleFactory', () => { // Prepare callback and headers const mockCallback: ChatStreamCallbacks = { - onStart: vi.fn(), onCompletion: vi.fn(), + onStart: vi.fn(), }; const mockHeaders = { 'Custom-Header': 'TestValue' }; @@ -848,6 +854,7 @@ describe('LobeOpenAICompatibleFactory', () => { const 
reader = readableStream.getReader(); const process = async () => { try { + // eslint-disable-next-line no-constant-condition while (true) { const { done, value } = await reader.read(); if (done) break; @@ -877,9 +884,9 @@ describe('LobeOpenAICompatibleFactory', () => { const mockStream = new ReadableStream({ start(controller) { controller.enqueue({ - id: 'test-id', choices: [{ delta: { content: 'Hello' }, index: 0 }], created: Date.now(), + id: 'test-id', model: 'test-model', object: 'chat.completion.chunk', }); @@ -908,12 +915,12 @@ describe('LobeOpenAICompatibleFactory', () => { start(controller) { // Transform the completion to chunk format controller.enqueue({ - id: data.id, choices: data.choices.map((choice) => ({ delta: { content: choice.message.content }, index: choice.index, })), created: data.created, + id: data.id, model: data.model, object: 'chat.completion.chunk', }); @@ -933,20 +940,20 @@ describe('LobeOpenAICompatibleFactory', () => { const instance = new LobeMockProvider({ apiKey: 'test' }); const mockResponse: OpenAI.ChatCompletion = { - id: 'test-id', choices: [ { + finish_reason: 'stop', index: 0, + logprobs: null, message: { - role: 'assistant', content: 'Test response', refusal: null, + role: 'assistant', }, - logprobs: null, - finish_reason: 'stop', }, ], created: Date.now(), + id: 'test-id', model: 'test-model', object: 'chat.completion', usage: { completion_tokens: 2, prompt_tokens: 1, total_tokens: 3 }, @@ -959,8 +966,8 @@ describe('LobeOpenAICompatibleFactory', () => { const payload: ChatStreamPayload = { messages: [{ content: 'Test', role: 'user' }], model: 'test-model', - temperature: 0.7, stream: false, + temperature: 0.7, }; await instance.chat(payload); @@ -1110,8 +1117,8 @@ describe('LobeOpenAICompatibleFactory', () => { model: 'dall-e-3', params: { prompt: 'A beautiful sunset', - size: '1024x1024', quality: 'standard', + size: '1024x1024', }, }; @@ -1121,9 +1128,9 @@ describe('LobeOpenAICompatibleFactory', () => { model: 'dall-e-3', n: 1, prompt: 'A beautiful sunset', - size: '1024x1024', quality: 'standard', response_format: 'b64_json', + size: '1024x1024', }); expect(result).toEqual({ @@ -1200,9 +1207,9 @@ describe('LobeOpenAICompatibleFactory', () => { const payload = { model: 'dall-e-2', params: { - prompt: 'Add a rainbow to this image', imageUrls: ['https://example.com/image1.jpg'], mask: 'https://example.com/mask.jpg', + prompt: 'Add a rainbow to this image', }, }; @@ -1212,13 +1219,13 @@ describe('LobeOpenAICompatibleFactory', () => { 'https://example.com/image1.jpg', ); expect(instance['client'].images.edit).toHaveBeenCalledWith({ + image: expect.any(File), + input_fidelity: 'high', + mask: 'https://example.com/mask.jpg', model: 'dall-e-2', n: 1, prompt: 'Add a rainbow to this image', - image: expect.any(File), - mask: 'https://example.com/mask.jpg', response_format: 'b64_json', - input_fidelity: 'high', }); expect(result).toEqual({ @@ -1243,8 +1250,8 @@ describe('LobeOpenAICompatibleFactory', () => { const payload = { model: 'dall-e-2', params: { - prompt: 'Merge these images', imageUrls: ['https://example.com/image1.jpg', 'https://example.com/image2.jpg'], + prompt: 'Merge these images', }, }; @@ -1259,12 +1266,12 @@ describe('LobeOpenAICompatibleFactory', () => { ); expect(instance['client'].images.edit).toHaveBeenCalledWith({ + image: [mockFile1, mockFile2], + input_fidelity: 'high', model: 'dall-e-2', n: 1, prompt: 'Merge these images', - image: [mockFile1, mockFile2], response_format: 'b64_json', - input_fidelity: 'high', }); 
expect(result).toEqual({ @@ -1280,8 +1287,8 @@ describe('LobeOpenAICompatibleFactory', () => { const payload = { model: 'dall-e-2', params: { - prompt: 'Edit this image', imageUrls: ['https://invalid-url.com/image.jpg'], + prompt: 'Edit this image', }, }; @@ -1379,22 +1386,22 @@ describe('LobeOpenAICompatibleFactory', () => { const payload = { model: 'dall-e-2', params: { - prompt: 'Test prompt', - imageUrls: ['https://example.com/image.jpg'], customParam: 'should remain unchanged', + imageUrls: ['https://example.com/image.jpg'], + prompt: 'Test prompt', }, }; await (instance as any).createImage(payload); expect(instance['client'].images.edit).toHaveBeenCalledWith({ + customParam: 'should remain unchanged', + image: expect.any(File), + input_fidelity: 'high', model: 'dall-e-2', n: 1, prompt: 'Test prompt', - image: expect.any(File), - customParam: 'should remain unchanged', response_format: 'b64_json', - input_fidelity: 'high', }); }); @@ -1421,8 +1428,8 @@ describe('LobeOpenAICompatibleFactory', () => { n: 1, prompt: 'Test prompt', quality: 'hd', - style: 'vivid', response_format: 'b64_json', + style: 'vivid', }); }); }); @@ -1438,17 +1445,17 @@ describe('LobeOpenAICompatibleFactory', () => { const payload = { messages: [{ content: 'Generate a person object', role: 'user' as const }], + model: 'gpt-4o', + responseApi: true, schema: { - name: 'person_extractor', description: 'Extract person information', + name: 'person_extractor', schema: { + properties: { age: { type: 'number' }, name: { type: 'string' } }, type: 'object' as const, - properties: { name: { type: 'string' }, age: { type: 'number' } }, }, strict: true, }, - model: 'gpt-4o', - responseApi: true, }; const result = await instance.generateObject(payload); @@ -1464,7 +1471,7 @@ describe('LobeOpenAICompatibleFactory', () => { { headers: undefined, signal: undefined }, ); - expect(result).toEqual({ name: 'John', age: 30 }); + expect(result).toEqual({ age: 30, name: 'John' }); }); it('should handle options correctly', async () => { @@ -1476,18 +1483,18 @@ describe('LobeOpenAICompatibleFactory', () => { const payload = { messages: [{ content: 'Generate status', role: 'user' as const }], + model: 'gpt-4o', + responseApi: true, schema: { name: 'status_extractor', - schema: { type: 'object' as const, properties: { status: { type: 'string' } } }, + schema: { properties: { status: { type: 'string' } }, type: 'object' as const }, }, - model: 'gpt-4o', - responseApi: true, }; const options = { headers: { 'Custom-Header': 'test-value' }, - user: 'test-user', signal: new AbortController().signal, + user: 'test-user', }; const result = await instance.generateObject(payload, options); @@ -1516,12 +1523,12 @@ describe('LobeOpenAICompatibleFactory', () => { const payload = { messages: [{ content: 'Generate data', role: 'user' as const }], + model: 'gpt-4o', + responseApi: true, schema: { name: 'test_tool', - schema: { type: 'object' as const, properties: {} }, + schema: { properties: {}, type: 'object' as const }, }, - model: 'gpt-4o', - responseApi: true, }; const result = await instance.generateObject(payload); @@ -1542,12 +1549,12 @@ describe('LobeOpenAICompatibleFactory', () => { const payload = { messages: [{ content: 'Generate data', role: 'user' as const }], + model: 'gpt-4o', + responseApi: true, schema: { name: 'test_tool', - schema: { type: 'object' as const, properties: {} }, + schema: { properties: {}, type: 'object' as const }, }, - model: 'gpt-4o', - responseApi: true, }; const result = await instance.generateObject(payload); @@ 
@@ -1568,35 +1575,38 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate complex user data', role: 'user' as const }],
+        model: 'gpt-4o',
+        responseApi: true,
         schema: {
           name: 'user_extractor',
           schema: {
-            type: 'object' as const,
             properties: {
+              metadata: { type: 'object' },
               user: {
-                type: 'object',
                 properties: {
                   name: { type: 'string' },
                   profile: {
-                    type: 'object',
                     properties: {
                       age: { type: 'number' },
-                      preferences: { type: 'array', items: { type: 'string' } },
+                      preferences: { items: { type: 'string' }, type: 'array' },
                     },
+                    type: 'object',
                   },
                 },
+                type: 'object',
               },
-              metadata: { type: 'object' },
             },
+            type: 'object' as const,
           },
         },
-        model: 'gpt-4o',
-        responseApi: true,
       };

       const result = await instance.generateObject(payload);

       expect(result).toEqual({
+        metadata: {
+          created: '2024-01-01',
+        },
         user: {
           name: 'Alice',
           profile: {
@@ -1604,9 +1614,6 @@ describe('LobeOpenAICompatibleFactory', () => {
             preferences: ['music', 'sports'],
           },
         },
-        metadata: {
-          created: '2024-01-01',
-        },
       });
     });
@@ -1617,12 +1624,12 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'gpt-4o',
+        responseApi: true,
         schema: {
           name: 'test_tool',
-          schema: { type: 'object' as const, properties: {} },
+          schema: { properties: {}, type: 'object' as const },
         },
-        model: 'gpt-4o',
-        responseApi: true,
       };

       await expect(instance.generateObject(payload)).rejects.toThrow(
@@ -1648,14 +1655,14 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate a person object', role: 'user' as const }],
+        model: 'gpt-4o',
         schema: {
           name: 'person_extractor',
           schema: {
+            properties: { age: { type: 'number' }, name: { type: 'string' } },
             type: 'object' as const,
-            properties: { name: { type: 'string' }, age: { type: 'number' } },
           },
         },
-        model: 'gpt-4o',
         // responseApi: false or undefined - uses chat completions API
       };
@@ -1671,7 +1678,7 @@ describe('LobeOpenAICompatibleFactory', () => {
         { headers: undefined, signal: undefined },
       );

-      expect(result).toEqual({ name: 'Bob', age: 25 });
+      expect(result).toEqual({ age: 25, name: 'Bob' });
     });

     it('should handle options correctly with chat completions API', async () => {
@@ -1691,18 +1698,18 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate status', role: 'user' as const }],
+        model: 'gpt-4o',
+        responseApi: false,
         schema: {
           name: 'status_extractor',
-          schema: { type: 'object' as const, properties: { status: { type: 'string' } } },
+          schema: { properties: { status: { type: 'string' } }, type: 'object' as const },
         },
-        model: 'gpt-4o',
-        responseApi: false,
       };

       const options = {
         headers: { Authorization: 'Bearer token' },
-        user: 'test-user-123',
         signal: new AbortController().signal,
+        user: 'test-user-123',
       };

       const result = await instance.generateObject(payload, options);
@@ -1738,12 +1745,12 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'gpt-4o',
+        responseApi: false,
         schema: {
           name: 'test_tool',
-          schema: { type: 'object' as const, properties: {} },
+          schema: { properties: {}, type: 'object' as const },
         },
-        model: 'gpt-4o',
-        responseApi: false,
       };

       const result = await instance.generateObject(payload);
@@ -1772,12 +1779,12 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'gpt-4o',
+        responseApi: false,
         schema: {
           name: 'test_tool',
-          schema: { type: 'object' as const, properties: {} },
+          schema: { properties: {}, type: 'object' as const },
         },
-        model: 'gpt-4o',
-        responseApi: false,
       };

       const result = await instance.generateObject(payload);
@@ -1806,26 +1813,26 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate items list', role: 'user' as const }],
+        model: 'gpt-4o',
         schema: {
           name: 'abc',
           schema: {
-            type: 'object' as const,
             properties: {
               items: {
-                type: 'array',
                 items: {
-                  type: 'object',
                   properties: {
                     id: { type: 'number' },
                     name: { type: 'string' },
                   },
+                  type: 'object',
                 },
+                type: 'array',
               },
               total: { type: 'number' },
             },
+            type: 'object' as const,
           },
         },
-        model: 'gpt-4o',
       };

       const result = await instance.generateObject(payload);
@@ -1846,9 +1853,9 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
-        schema: { name: 'abc', schema: { type: 'object' } as any },
         model: 'gpt-4o',
         responseApi: false,
+        schema: { name: 'abc', schema: { type: 'object' } as any },
       };

       await expect(instance.generateObject(payload)).rejects.toThrow(
@@ -1865,18 +1872,18 @@ describe('LobeOpenAICompatibleFactory', () => {
             message: {
               tool_calls: [
                 {
-                  type: 'function' as const,
                   function: {
-                    name: 'get_weather',
                     arguments: '{"city":"Tokyo","unit":"celsius"}',
+                    name: 'get_weather',
                   },
+                  type: 'function' as const,
                 },
                 {
-                  type: 'function' as const,
                   function: {
-                    name: 'get_time',
                     arguments: '{"timezone":"Asia/Tokyo"}',
+                    name: 'get_time',
                   },
+                  type: 'function' as const,
                 },
               ],
             },
@@ -1890,32 +1897,38 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'What is the weather and time in Tokyo?', role: 'user' as const }],
+        model: 'gpt-4o',
         tools: [
           {
-            name: 'get_weather',
-            description: 'Get weather information',
-            parameters: {
-              type: 'object' as const,
-              properties: {
-                city: { type: 'string' },
-                unit: { type: 'string' },
+            function: {
+              description: 'Get weather information',
+              name: 'get_weather',
+              parameters: {
+                properties: {
+                  city: { type: 'string' },
+                  unit: { type: 'string' },
+                },
+                required: ['city'],
+                type: 'object' as const,
               },
-              required: ['city'],
             },
+            type: 'function' as const,
           },
           {
-            name: 'get_time',
-            description: 'Get current time',
-            parameters: {
-              type: 'object' as const,
-              properties: {
-                timezone: { type: 'string' },
+            function: {
+              description: 'Get current time',
+              name: 'get_time',
+              parameters: {
+                properties: {
+                  timezone: { type: 'string' },
+                },
+                required: ['timezone'],
+                type: 'object' as const,
              },
-              required: ['timezone'],
            },
+            type: 'function' as const,
           },
         ],
-        model: 'gpt-4o',
       };

       const result = await instance.generateObject(payload);
@@ -1927,33 +1940,33 @@ describe('LobeOpenAICompatibleFactory', () => {
           tool_choice: 'required',
           tools: [
             {
-              type: 'function',
               function: {
-                name: 'get_weather',
                 description: 'Get weather information',
+                name: 'get_weather',
                 parameters: {
-                  type: 'object',
                   properties: {
                     city: { type: 'string' },
                     unit: { type: 'string' },
                   },
                   required: ['city'],
+                  type: 'object',
                 },
               },
+              type: 'function',
             },
             {
-              type: 'function',
               function: {
-                name: 'get_time',
                 description: 'Get current time',
+                name: 'get_time',
                 parameters: {
-                  type: 'object',
                   properties: {
                     timezone: { type: 'string' },
                   },
                   required: ['timezone'],
+                  type: 'object',
                 },
               },
+              type: 'function',
             },
           ],
           user: undefined,
@@ -1974,11 +1987,11 @@ describe('LobeOpenAICompatibleFactory', () => {
             message: {
               tool_calls: [
                 {
-                  type: 'function' as const,
                   function: {
-                    name: 'calculate',
                     arguments: '{"result":8}',
+                    name: 'calculate',
                   },
+                  type: 'function' as const,
                 },
               ],
             },
@@ -1992,31 +2005,30 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Add 5 and 3', role: 'user' as const }],
+        model: 'gpt-4o',
         tools: [
           {
-            name: 'calculate',
-            description: 'Perform calculation',
-            parameters: {
-              type: 'object' as const,
-              properties: {
-                result: { type: 'number' },
+            function: {
+              description: 'Perform calculation',
+              name: 'calculate',
+              parameters: {
+                properties: {
+                  result: { type: 'number' },
+                },
+                required: ['result'],
+                type: 'object' as const,
               },
-              required: ['result'],
             },
+            type: 'function' as const,
           },
         ],
-        systemRole: 'You are a helpful calculator',
-        model: 'gpt-4o',
       };

       const result = await instance.generateObject(payload);

       expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
         expect.objectContaining({
-          messages: [
-            { content: 'Add 5 and 3', role: 'user' },
-            { content: 'You are a helpful calculator', role: 'system' },
-          ],
+          messages: [{ content: 'Add 5 and 3', role: 'user' }],
         }),
         expect.any(Object),
       );
@@ -2058,11 +2070,11 @@ describe('LobeOpenAICompatibleFactory', () => {
             message: {
               tool_calls: [
                 {
-                  type: 'function' as const,
                   function: {
-                    name: 'person_extractor',
                     arguments: '{"name":"Alice","age":28}',
+                    name: 'person_extractor',
                   },
+                  type: 'function' as const,
                 },
               ],
             },
@@ -2076,15 +2088,15 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Extract person info', role: 'user' as const }],
+        model: 'test-model',
         schema: {
-          name: 'person_extractor',
           description: 'Extract person information',
+          name: 'person_extractor',
           schema: {
+            properties: { age: { type: 'number' }, name: { type: 'string' } },
             type: 'object' as const,
-            properties: { name: { type: 'string' }, age: { type: 'number' } },
           },
         },
-        model: 'test-model',
       };

       const result = await instanceWithToolCalling.generateObject(payload);
@@ -2093,24 +2105,24 @@ describe('LobeOpenAICompatibleFactory', () => {
         {
           messages: payload.messages,
           model: payload.model,
+          tool_choice: { function: { name: 'person_extractor' }, type: 'function' },
           tools: [
             {
-              type: 'function',
               function: {
-                name: 'person_extractor',
                 description: 'Extract person information',
+                name: 'person_extractor',
                 parameters: payload.schema.schema,
               },
+              type: 'function',
             },
           ],
-          tool_choice: { type: 'function', function: { name: 'person_extractor' } },
           user: undefined,
         },
         { headers: undefined, signal: undefined },
       );

       expect(result).toEqual([
-        { arguments: { name: 'Alice', age: 28 }, name: 'person_extractor' },
+        { arguments: { age: 28, name: 'Alice' }, name: 'person_extractor' },
       ]);
     });
@@ -2132,11 +2144,11 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'test-model',
         schema: {
           name: 'test_tool',
-          schema: { type: 'object' as const, properties: {} },
+          schema: { properties: {}, type: 'object' as const },
         },
-        model: 'test-model',
       };

       const result = await instanceWithToolCalling.generateObject(payload);
@@ -2154,11 +2166,11 @@ describe('LobeOpenAICompatibleFactory', () => {
             message: {
               tool_calls: [
                 {
-                  type: 'function' as const,
                   function: {
-                    name: 'test_tool',
                     arguments: 'invalid json',
+                    name: 'test_tool',
                   },
+                  type: 'function' as const,
                 },
               ],
             },
@@ -2173,11 +2185,11 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'test-model',
         schema: {
           name: 'test_tool',
-          schema: { type: 'object' as const, properties: {} },
+          schema: { properties: {}, type: 'object' as const },
         },
-        model: 'test-model',
       };

       const result = await instanceWithToolCalling.generateObject(payload);
@@ -2198,11 +2210,11 @@ describe('LobeOpenAICompatibleFactory', () => {
             message: {
               tool_calls: [
                 {
-                  type: 'function' as const,
                   function: {
-                    name: 'data_extractor',
                     arguments: '{"data":"test"}',
+                    name: 'data_extractor',
                   },
+                  type: 'function' as const,
                 },
               ],
             },
@@ -2216,17 +2228,17 @@ describe('LobeOpenAICompatibleFactory', () => {
       const payload = {
         messages: [{ content: 'Extract data', role: 'user' as const }],
+        model: 'test-model',
         schema: {
           name: 'data_extractor',
-          schema: { type: 'object' as const, properties: { data: { type: 'string' } } },
+          schema: { properties: { data: { type: 'string' } }, type: 'object' as const },
         },
-        model: 'test-model',
       };

       const options = {
         headers: { 'X-Custom': 'header' },
-        user: 'test-user',
         signal: new AbortController().signal,
+        user: 'test-user',
       };

       const result = await instanceWithToolCalling.generateObject(payload, options);
@@ -2245,10 +2257,10 @@ describe('LobeOpenAICompatibleFactory', () => {
     it('should get models with third party model list', async () => {
       vi.spyOn(instance['client'].models, 'list').mockResolvedValue({
         data: [
-          { id: 'gpt-4o', object: 'model', created: 1698218177 },
+          { created: 1_698_218_177, id: 'gpt-4o', object: 'model' },
           { id: 'claude-3-haiku-20240307', object: 'model' },
-          { id: 'gpt-4o-mini', object: 'model', created: 1698318177 * 1000 },
-          { id: 'gemini', object: 'model', created: 1736499509125 },
+          { created: 1_698_318_177 * 1000, id: 'gpt-4o-mini', object: 'model' },
+          { created: 1_736_499_509_125, id: 'gemini', object: 'model' },
         ],
       } as any);
@@ -2263,7 +2275,7 @@ describe('LobeOpenAICompatibleFactory', () => {
           config: {
             deploymentName: 'gpt-4o',
           },
-          contextWindowTokens: 128000,
+          contextWindowTokens: 128_000,
           description:
             'ChatGPT-4o 是一款动态模型,实时更新以保持当前最新版本。它结合了强大的语言理解与生成能力,适合于大规模应用场景,包括客户服务、教育和技术支持。',
           displayName: 'GPT-4o',
@@ -2302,7 +2314,7 @@ describe('LobeOpenAICompatibleFactory', () => {
             functionCall: true,
             vision: true,
           },
-          contextWindowTokens: 200000,
+          contextWindowTokens: 200_000,
          description:
            'Claude 3 Haiku 是 Anthropic 的最快且最紧凑的模型,旨在实现近乎即时的响应。它具有快速且准确的定向性能。',
           displayName: 'Claude 3 Haiku',
@@ -2359,7 +2371,7 @@ describe('LobeOpenAICompatibleFactory', () => {
           config: {
             deploymentName: 'gpt-4o-mini',
           },
-          contextWindowTokens: 128000,
+          contextWindowTokens: 128_000,
           description: 'GPT-4o Mini,小型高效模型,具备与GPT-4o相似的卓越性能。',
           displayName: 'GPT 4o Mini',
           enabled: false,
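The structured-output tests above all drive the same entry point. As a usage sketch grounded in those assertions (`instance` stands in for any runtime built by the factory; the `declare` is only there to make the snippet self-contained):

```ts
// Payload shape mirrors the tests: a JSON-Schema-style `schema` plus a model
// id; `responseApi: true` forces the Responses API path for this call.
declare const instance: {
  generateObject: (payload: unknown, options?: unknown) => Promise<unknown>;
};

const person = await instance.generateObject({
  messages: [{ content: 'Generate a person object', role: 'user' as const }],
  model: 'gpt-4o',
  responseApi: true,
  schema: {
    description: 'Extract person information',
    name: 'person_extractor',
    schema: {
      properties: { age: { type: 'number' }, name: { type: 'string' } },
      type: 'object' as const,
    },
    strict: true,
  },
});
// => e.g. { age: 30, name: 'John' } — parsed JSON, or undefined on a parse failure.
```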
diff --git a/packages/model-runtime/src/core/openaiCompatibleFactory/index.ts b/packages/model-runtime/src/core/openaiCompatibleFactory/index.ts
index a6a743b9951..c76f79cbf47 100644
--- a/packages/model-runtime/src/core/openaiCompatibleFactory/index.ts
+++ b/packages/model-runtime/src/core/openaiCompatibleFactory/index.ts
@@ -1,11 +1,13 @@
 import type { ChatModelCard } from '@lobechat/types';
 import dayjs from 'dayjs';
 import utc from 'dayjs/plugin/utc';
+import debug from 'debug';
 import { LOBE_DEFAULT_MODEL_LIST } from 'model-bank';
 import type { AiModelType } from 'model-bank';
 import OpenAI, { ClientOptions } from 'openai';
 import { Stream } from 'openai/streaming';

+import { responsesAPIModels } from '../../const/models';
 import {
   ChatCompletionErrorPayload,
   ChatCompletionTool,
@@ -117,6 +119,15 @@ export interface OpenAICompatibleFactoryOptions
     invalidAPIKey: ILobeAgentRuntimeErrorType;
   };
   generateObject?: {
+    /**
+     * If true, route generateObject requests to Responses API path directly
+     */
+    useResponse?: boolean;
+    /**
+     * Allow only some models to use Responses API by simple matching.
+     * If any string appears in model id or RegExp matches, Responses API is used.
+     */
+    useResponseModels?: Array<string | RegExp>;
     /**
      * Use tool calling to simulate structured output for providers that don't support native structured output
      */
@@ -141,7 +152,7 @@ export const createOpenAICompatibleRuntime
     baseURL: DEFAULT_BASE_URL,
     apiKey: DEFAULT_API_LEY,
     errorType,
-    debug,
+    debug: debugParams,
     constructorOptions,
     chatCompletion,
     models,
@@ -159,6 +170,7 @@ export const createOpenAICompatibleRuntime
     client!: OpenAI;

     private id: string;
+    private logPrefix: string;

     baseURL!: string;
     protected _options: ConstructorOptions;
@@ -186,12 +198,16 @@ export const createOpenAICompatibleRuntime
       this.baseURL = baseURL || this.client.baseURL;
       this.id = options.id || provider;
+      this.logPrefix = `lobe-model-runtime:${this.id}`;
     }

     async chat({ responseMode, ...payload }: ChatStreamPayload, options?: ChatMethodOptions) {
       try {
+        const log = debug(`${this.logPrefix}:chat`);
         const inputStartAt = Date.now();

+        log('chat called with model: %s, stream: %s', payload.model, payload.stream ?? true);
+
         // Factory-level Responses API routing control (supports instance override)
         const modelId = (payload as any).model as string | undefined;
         const shouldUseResponses = (() => {
@@ -214,7 +230,10 @@ export const createOpenAICompatibleRuntime
         let processedPayload: any = payload;
         if (shouldUseResponses) {
+          log('using Responses API mode');
           processedPayload = { ...payload, apiMode: 'responses' } as any;
+        } else {
+          log('using Chat Completions API mode');
         }

         // Then apply factory-level processing
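A configuration sketch for the two new `generateObject` options. This is partial and hypothetical: the provider id and baseURL are placeholders, and other required factory options (error types, model list, and so on) are omitted here.

```ts
// Hypothetical provider wiring; only useResponse/useResponseModels come from
// the options added above — everything else is a placeholder.
const LobeExampleAI = createOpenAICompatibleRuntime({
  baseURL: 'https://api.example.com/v1',
  generateObject: {
    // Route only matching model ids through the Responses API…
    useResponseModels: ['gpt-5', /^o\d/],
    // …or set `useResponse: true` to route every generateObject call.
  },
  provider: 'example',
});
```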
@@ -244,6 +263,7 @@ export const createOpenAICompatibleRuntime
         };

         if (customClient?.createChatCompletionStream) {
+          log('using custom client for chat completion stream');
           response = customClient.createChatCompletionStream(
             this.client,
             processedPayload,
@@ -260,7 +280,9 @@ export const createOpenAICompatibleRuntime
             : undefined,
         };

-        if (debug?.chatCompletion?.()) {
+        log('sending chat completion request with %d messages', messages.length);
+
+        if (debugParams?.chatCompletion?.()) {
           console.log('[requestPayload]');
           console.log(JSON.stringify(finalPayload), '\n');
         }
@@ -273,9 +295,10 @@ export const createOpenAICompatibleRuntime
         }

         if (postPayload.stream) {
+          log('processing streaming response');
           const [prod, useForDebug] = response.tee();

-          if (debug?.chatCompletion?.()) {
+          if (debugParams?.chatCompletion?.()) {
             const useForDebugStream =
               useForDebug instanceof ReadableStream ? useForDebug : useForDebug.toReadableStream();
@@ -298,12 +321,16 @@ export const createOpenAICompatibleRuntime
           );
         }

-        if (debug?.chatCompletion?.()) {
+        if (debugParams?.chatCompletion?.()) {
           debugResponse(response);
         }

-        if (responseMode === 'json') return Response.json(response);
+        if (responseMode === 'json') {
+          log('returning JSON response mode');
+          return Response.json(response);
+        }

+        log('transforming non-streaming response to stream');
         const transformHandler =
           chatCompletion?.handleTransformResponseToStream || transformResponseToStream;
         const stream = transformHandler(response as unknown as OpenAI.ChatCompletion);
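All of these `log(...)` calls go through the `debug` package with a `lobe-model-runtime:<id>:<method>` namespace, so they stay silent unless a matching `DEBUG` pattern is set in the environment. A small sketch of the pattern:

```ts
import debug from 'debug';

// Namespaced logger, as used throughout this file; output only appears when
// e.g. DEBUG=lobe-model-runtime:* (or a narrower pattern) is set.
const log = debug('lobe-model-runtime:openai:models');

log('fetched %d models', 42); // printf-style formatting (%s, %d, %O) from debug
```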
@@ -325,8 +352,11 @@ export const createOpenAICompatibleRuntime
     }

     async createImage(payload: CreateImagePayload) {
+      const log = debug(`${this.logPrefix}:createImage`);
+
       // If custom createImage implementation is provided, use it
       if (customCreateImage) {
+        log('using custom createImage implementation');
         return customCreateImage(payload, {
           ...this._options,
           apiKey: this._options.apiKey!,
@@ -334,15 +364,21 @@ export const createOpenAICompatibleRuntime
         });
       }

+      log('using default createOpenAICompatibleImage');
       // Use the new createOpenAICompatibleImage function
       return createOpenAICompatibleImage(this.client, payload, this.id);
     }

     async models() {
+      const log = debug(`${this.logPrefix}:models`);
+      log('fetching available models');
+
       let resultModels: ChatModelCard[] = [];

       if (typeof models === 'function') {
+        log('using custom models function');
         resultModels = await models({ client: this.client });
       } else {
+        log('fetching models from client API');
         const list = await this.client.models.list();
         resultModels = list.data
           .filter((model) => {
@@ -390,49 +426,34 @@ export const createOpenAICompatibleRuntime
           .filter(Boolean) as ChatModelCard[];
       }

+      log('fetched %d models', resultModels.length);
+
       return await postProcessModelList(resultModels, (modelId) =>
         getModelPropertyWithFallback(modelId, 'type'),
       );
     }

     async generateObject(payload: GenerateObjectPayload, options?: GenerateObjectOptions) {
-      const { messages, schema, model, responseApi, tools, systemRole } = payload;
+      const { messages, schema, model, responseApi, tools } = payload;
+
+      const log = debug(`${this.logPrefix}:generateObject`);
+      log(
+        'generateObject called with model: %s, hasTools: %s, hasSchema: %s',
+        model,
+        !!tools,
+        !!schema,
+      );

       if (tools) {
-        const msgs = messages;
-
-        if (!!systemRole) {
-          msgs.push({ content: systemRole, role: 'system' });
-        }
-
-        const res = await this.client.chat.completions.create(
-          {
-            messages: msgs,
-            model,
-            tool_choice: 'required',
-            tools: tools.map((tool) => ({ function: tool, type: 'function' })),
-            user: options?.user,
-          },
-          { headers: options?.headers, signal: options?.signal },
-        );
-
-        const toolCalls = res.choices[0].message.tool_calls!;
-
-        try {
-          return toolCalls.map((item) => ({
-            arguments: JSON.parse(item.function.arguments),
-            name: item.function.name,
-          }));
-        } catch {
-          console.error('parse tool call arguments error:', res);
-          return undefined;
-        }
+        log('using tools-based generation');
+        return this.generateObjectWithTools(payload, options);
       }

       if (!schema) throw new Error('tools or schema is required');

       // Use tool calling fallback if configured
       if (generateObjectConfig?.useToolsCalling) {
+        log('using tool calling fallback for structured output');
         const tool: ChatCompletionTool = {
           function: {
             description:
@@ -467,7 +488,51 @@ export const createOpenAICompatibleRuntime
         }
       }

-      if (responseApi) {
+      // Factory-level Responses API routing control (supports instance override)
+      const shouldUseResponses = (() => {
+        const instanceGenerateObject = ((this._options as any).generateObject || {}) as {
+          useResponse?: boolean;
+          useResponseModels?: Array<string | RegExp>;
+        };
+        const flagUseResponse =
+          instanceGenerateObject.useResponse ??
+          (generateObjectConfig ? generateObjectConfig.useResponse : undefined);
+        const flagUseResponseModels =
+          instanceGenerateObject.useResponseModels ?? generateObjectConfig?.useResponseModels;
+
+        if (responseApi) {
+          log('using Responses API due to explicit responseApi flag');
+          return true;
+        }
+
+        if (flagUseResponse) {
+          log('using Responses API due to useResponse flag');
+          return true;
+        }
+
+        // Use factory-configured model list if provided
+        if (model && flagUseResponseModels?.length) {
+          const matches = flagUseResponseModels.some((m: string | RegExp) =>
+            typeof m === 'string' ? model.includes(m) : (m as RegExp).test(model),
+          );
+          if (matches) {
+            log('using Responses API: model %s matches useResponseModels config', model);
+            return true;
+          }
+        }
+
+        // Default: use built-in responsesAPIModels
+        if (model && responsesAPIModels.has(model)) {
+          log('using Responses API: model %s in built-in responsesAPIModels', model);
+          return true;
+        }
+
+        log('using Chat Completions API for generateObject');
+        return false;
+      })();
+
+      if (shouldUseResponses) {
+        log('calling responses.create for structured output');
         const res = await this.client!.responses.create(
           {
             input: messages,
@@ -479,14 +544,19 @@ export const createOpenAICompatibleRuntime
         );

         const text = res.output_text;
+        log('received structured output from Responses API, length: %d', text?.length || 0);

         try {
-          return JSON.parse(text);
-        } catch {
+          const result = JSON.parse(text);
+          log('successfully parsed JSON output');
+          return result;
+        } catch (error) {
+          log('failed to parse JSON output: %O', error);
           console.error('parse json error:', text);
           return undefined;
         }
       }

+      log('calling chat.completions.create for structured output');
       const res = await this.client.chat.completions.create(
         {
           messages,
@@ -498,9 +568,14 @@ export const createOpenAICompatibleRuntime
       );

       const text = res.choices[0].message.content!;
+      log('received structured output from Chat Completions API, length: %d', text?.length || 0);
+
       try {
-        return JSON.parse(text);
-      } catch {
+        const result = JSON.parse(text);
+        log('successfully parsed JSON output');
+        return result;
+      } catch (error) {
+        log('failed to parse JSON output: %O', error);
         console.error('parse json error:', text);
         return undefined;
       }
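The routing IIFE above encodes a fixed precedence. Restated as a standalone sketch (names simplified; `responsesAPIModels` is the built-in set imported at the top of the file):

```ts
// Precedence, highest first: explicit per-call flag, factory/instance
// useResponse, useResponseModels matching, then the built-in model set.
const shouldUseResponsesAPI = (
  model: string,
  responseApi: boolean | undefined,
  useResponse: boolean | undefined,
  useResponseModels: Array<string | RegExp> | undefined,
  builtinModels: Set<string>,
): boolean => {
  if (responseApi) return true;
  if (useResponse) return true;
  if (useResponseModels?.some((m) => (typeof m === 'string' ? model.includes(m) : m.test(model))))
    return true;
  return builtinModels.has(model);
};
```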
@@ -510,12 +585,20 @@ export const createOpenAICompatibleRuntime
       payload: EmbeddingsPayload,
       options?: EmbeddingsOptions,
     ): Promise<Embeddings[]> {
+      const log = debug(`${this.logPrefix}:embeddings`);
+      log(
+        'embeddings called with model: %s, input items: %d',
+        payload.model,
+        Array.isArray(payload.input) ? payload.input.length : 1,
+      );
+
       try {
         const res = await this.client.embeddings.create(
           { ...payload, encoding_format: 'float', user: options?.user },
           { headers: options?.headers, signal: options?.signal },
         );
+        log('received %d embeddings', res.data.length);
         return res.data.map((item) => item.embedding);
       } catch (error) {
         throw this.handleError(error);
@@ -523,8 +606,12 @@ export const createOpenAICompatibleRuntime
     }

     async textToImage(payload: TextToImagePayload) {
+      const log = debug(`${this.logPrefix}:textToImage`);
+      log('textToImage called with prompt length: %d', payload.prompt?.length || 0);
+
       try {
         const res = await this.client.images.generate(payload);
+        log('generated %d images', res.data?.length || 0);
         return (res.data || []).map((o) => o.url) as string[];
       } catch (error) {
         throw this.handleError(error);
@@ -532,18 +619,30 @@ export const createOpenAICompatibleRuntime
     }

     async textToSpeech(payload: TextToSpeechPayload, options?: TextToSpeechOptions) {
+      const log = debug(`${this.logPrefix}:textToSpeech`);
+      log(
+        'textToSpeech called with input length: %d, voice: %s',
+        payload.input?.length || 0,
+        payload.voice,
+      );
+
       try {
         const mp3 = await this.client.audio.speech.create(payload as any, {
           headers: options?.headers,
           signal: options?.signal,
         });
-        return mp3.arrayBuffer();
+        const buffer = await mp3.arrayBuffer();
+        log('generated audio with size: %d bytes', buffer.byteLength);
+        return buffer;
       } catch (error) {
         throw this.handleError(error);
       }
     }

     protected handleError(error: any): ChatCompletionErrorPayload {
+      const log = debug(`${this.logPrefix}:error`);
+      log('handling error: %O', error);
+
       let desensitizedEndpoint = this.baseURL;

       // refs: https://github.com/lobehub/lobe-chat/issues/842
@@ -552,6 +651,7 @@ export const createOpenAICompatibleRuntime
       }

       if (chatCompletion?.handleError) {
+        log('using custom error handler');
         const errorResult = chatCompletion.handleError(error, this._options);

         if (errorResult)
@@ -562,8 +662,12 @@ export const createOpenAICompatibleRuntime
       }

       if ('status' in (error as any)) {
-        switch ((error as Response).status) {
+        const status = (error as Response).status;
+        log('HTTP error with status: %d', status);
+
+        switch (status) {
           case 401: {
+            log('invalid API key error');
             return AgentRuntimeError.chat({
               endpoint: desensitizedEndpoint,
               error: error as any,
@@ -580,8 +684,11 @@ export const createOpenAICompatibleRuntime
       const { errorResult, RuntimeError } = handleOpenAIError(error);

+      log('error code: %s, message: %s', errorResult.code, errorResult.message);
+
       switch (errorResult.code) {
         case 'insufficient_quota': {
+          log('insufficient quota error');
           return AgentRuntimeError.chat({
             endpoint: desensitizedEndpoint,
             error: errorResult,
@@ -591,6 +698,7 @@ export const createOpenAICompatibleRuntime
         }

         case 'model_not_found': {
+          log('model not found error');
           return AgentRuntimeError.chat({
             endpoint: desensitizedEndpoint,
             error: errorResult,
@@ -602,6 +710,7 @@ export const createOpenAICompatibleRuntime
         // content too long
         case 'context_length_exceeded':
         case 'string_above_max_length': {
+          log('context length exceeded error');
           return AgentRuntimeError.chat({
             endpoint: desensitizedEndpoint,
             error: errorResult,
@@ -611,6 +720,7 @@ export const createOpenAICompatibleRuntime
         }
       }

+      log('returning generic error');
       return AgentRuntimeError.chat({
         endpoint: desensitizedEndpoint,
         error: errorResult,
@@ -623,6 +733,9 @@ export const createOpenAICompatibleRuntime
       payload: ChatStreamPayload,
       options?: ChatMethodOptions,
     ): Promise<Response> {
+      const log = debug(`${this.logPrefix}:handleResponseAPIMode`);
+      log('handleResponseAPIMode called with model: %s', payload.model);
+
       const inputStartAt = Date.now();

       const { messages, reasoning_effort, tools, reasoning, responseMode, ...res } =
@@ -638,6 +751,12 @@ export const createOpenAICompatibleRuntime
       const input = await convertOpenAIResponseInputs(messages as any);

       const isStreaming = payload.stream !== false;
+      log(
+        'isStreaming: %s, hasTools: %s, hasReasoning: %s',
+        isStreaming,
+        !!tools,
+        !!(reasoning || reasoning_effort),
+      );

       const postPayload = {
         ...res,
@@ -655,11 +774,13 @@ export const createOpenAICompatibleRuntime
         tools: tools?.map((tool) => this.convertChatCompletionToolToResponseTool(tool)),
       } as OpenAI.Responses.ResponseCreateParamsStreaming | OpenAI.Responses.ResponseCreateParams;

-      if (debug?.responses?.()) {
+      if (debugParams?.responses?.()) {
         console.log('[requestPayload]');
         console.log(JSON.stringify(postPayload), '\n');
       }

+      log('sending responses.create request');
+
       const response = await this.client.responses.create(postPayload, {
         headers: options?.requestHeaders,
         signal: options?.signal,
@@ -676,10 +797,11 @@ export const createOpenAICompatibleRuntime
       };

       if (isStreaming) {
+        log('processing streaming Responses API response');
         const stream = response as Stream<OpenAI.Responses.ResponseStreamEvent>;
         const [prod, useForDebug] = stream.tee();

-        if (debug?.responses?.()) {
+        if (debugParams?.responses?.()) {
           const useForDebugStream =
             useForDebug instanceof ReadableStream ? useForDebug : useForDebug.toReadableStream();
@@ -691,13 +813,19 @@ export const createOpenAICompatibleRuntime
         });
       }

+      log('processing non-streaming Responses API response');
+
       // Handle non-streaming response
-      if (debug?.responses?.()) {
+      if (debugParams?.responses?.()) {
         debugResponse(response);
       }

-      if (responseMode === 'json') return Response.json(response);
+      if (responseMode === 'json') {
+        log('returning JSON response mode');
+        return Response.json(response);
+      }

+      log('transforming non-streaming Responses API response to stream');
       const stream = transformResponseAPIToStream(response as OpenAI.Responses.Response);

       return StreamingResponse(
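The conversion helper typed in the next hunk flattens a Chat Completions tool into the Responses API shape, where `name`/`description`/`parameters` sit at the top level instead of under `function`. Roughly:

```ts
// Chat Completions shape:
//   { type: 'function', function: { name, description, parameters } }
// Responses API shape (flattened):
//   { type: 'function', name, description, parameters }
const toResponsesTool = (tool: {
  function: { description?: string; name: string; parameters?: object };
  type: 'function';
}) => ({ type: tool.type, ...tool.function });
```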
@@ -708,8 +836,138 @@ export const createOpenAICompatibleRuntime
     );
   }

-    private convertChatCompletionToolToResponseTool = (tool: ChatCompletionTool) => {
-      return { type: tool.type, ...tool.function };
+    private convertChatCompletionToolToResponseTool = (
+      tool: ChatCompletionTool,
+    ): OpenAI.Responses.Tool => {
+      return { type: tool.type, ...tool.function } as any;
     };
+
+    private async generateObjectWithTools(
+      payload: GenerateObjectPayload,
+      options?: GenerateObjectOptions,
+    ) {
+      const { messages, model, tools, responseApi } = payload;
+      const log = debug(`${this.logPrefix}:generateObject`);
+
+      log(
+        'generateObjectWithTools called with model: %s, toolsCount: %d',
+        model,
+        tools?.length || 0,
+      );
+
+      // Factory-level Responses API routing control (supports instance override)
+      const shouldUseResponses = (() => {
+        const instanceGenerateObject = ((this._options as any).generateObject || {}) as {
+          useResponse?: boolean;
+          useResponseModels?: Array<string | RegExp>;
+        };
+        const flagUseResponse =
+          instanceGenerateObject.useResponse ??
+          (generateObjectConfig ? generateObjectConfig.useResponse : undefined);
+        const flagUseResponseModels =
+          instanceGenerateObject.useResponseModels ?? generateObjectConfig?.useResponseModels;
+
+        if (responseApi) {
+          log('using Responses API due to explicit responseApi flag');
+          return true;
+        }
+
+        if (flagUseResponse) {
+          log('using Responses API due to useResponse flag');
+          return true;
+        }
+
+        // Use factory-configured model list if provided
+        if (model && flagUseResponseModels?.length) {
+          const matches = flagUseResponseModels.some((m: string | RegExp) =>
+            typeof m === 'string' ? model.includes(m) : (m as RegExp).test(model),
+          );
+          if (matches) {
+            log('using Responses API: model %s matches useResponseModels config', model);
+            return true;
+          }
+        }
+
+        // Default: use built-in responsesAPIModels
+        if (model && responsesAPIModels.has(model)) {
+          log('using Responses API: model %s in built-in responsesAPIModels', model);
+          return true;
+        }
+
+        log('using Chat Completions API for tool calling');
+        return false;
+      })();
+
+      if (shouldUseResponses) {
+        log('calling responses.create for tool calling');
+        const input = await convertOpenAIResponseInputs(messages as any);
+
+        const res = await this.client.responses.create(
+          {
+            input,
+            model,
+            tool_choice: 'required',
+            tools: tools!.map((tool) => this.convertChatCompletionToolToResponseTool(tool)),
+            user: options?.user,
+          },
+          { headers: options?.headers, signal: options?.signal },
+        );
+
+        const functionCalls = res.output?.filter((item: any) => item.type === 'function_call');
+
+        log('received %d function calls from Responses API', functionCalls?.length || 0);
+
+        try {
+          const result = functionCalls?.map((item: any) => ({
+            arguments:
+              typeof item.arguments === 'string' ? JSON.parse(item.arguments) : item.arguments,
+            name: item.name,
+          }));
+          log(
+            'successfully parsed function calls: %O',
+            result?.map((r) => r.name),
+          );
+          return result;
+        } catch (error) {
+          log('failed to parse tool call arguments: %O', error);
+          console.error('parse tool call arguments error:', res);
+          return undefined;
+        }
+      }
+
+      log('calling chat.completions.create for tool calling');
+      const msgs = messages;
+
+      const res = await this.client.chat.completions.create(
+        {
+          messages: msgs,
+          model,
+          tool_choice: 'required',
+          tools,
+          user: options?.user,
+        },
+        { headers: options?.headers, signal: options?.signal },
+      );
+
+      const toolCalls = res.choices[0].message.tool_calls!;
+
+      log('received %d tool calls from Chat Completions API', toolCalls?.length || 0);
+
+      try {
+        const result = toolCalls.map((item) => ({
+          arguments: JSON.parse(item.function.arguments),
+          name: item.function.name,
+        }));
+        log(
+          'successfully parsed tool calls: %O',
+          result.map((r) => r.name),
+        );
+        return result;
+      } catch (error) {
+        log('failed to parse tool call arguments: %O', error);
+        console.error('parse tool call arguments error:', res);
+        return undefined;
+      }
+    }
   };
 };
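Both branches of `generateObjectWithTools` end the same way: tool-call arguments arrive as JSON strings and are parsed defensively, with a parse failure yielding `undefined` rather than throwing. A condensed sketch of that contract:

```ts
// Mirrors the parsing step above: arguments are JSON strings; a parse
// failure returns undefined instead of propagating the error.
interface ParsedToolCall {
  arguments: unknown;
  name: string;
}

const parseToolCalls = (
  calls: Array<{ function: { arguments: string; name: string } }>,
): ParsedToolCall[] | undefined => {
  try {
    return calls.map((item) => ({
      arguments: JSON.parse(item.function.arguments),
      name: item.function.name,
    }));
  } catch {
    return undefined;
  }
};
```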
diff --git a/packages/model-runtime/src/providers/anthropic/generateObject.test.ts b/packages/model-runtime/src/providers/anthropic/generateObject.test.ts
index a457068e135..e8ce9c22e8c 100644
--- a/packages/model-runtime/src/providers/anthropic/generateObject.test.ts
+++ b/packages/model-runtime/src/providers/anthropic/generateObject.test.ts
@@ -28,9 +28,9 @@ describe('Anthropic generateObject', () => {
         create: vi.fn().mockResolvedValue({
           content: [
             {
-              type: 'tool_use',
+              input: { age: 30, name: 'John' },
               name: 'person_extractor',
-              input: { name: 'John', age: 30 },
+              type: 'tool_use',
             },
           ],
         }),
@@ -39,48 +39,48 @@ describe('Anthropic generateObject', () => {
       const payload = {
         messages: [{ content: 'Generate a person object', role: 'user' as const }],
+        model: 'claude-3-5-sonnet-20241022',
         schema: {
-          name: 'person_extractor',
           description: 'Extract person information',
+          name: 'person_extractor',
           schema: {
-            type: 'object' as const,
-            properties: { name: { type: 'string' }, age: { type: 'number' } },
+            properties: { age: { type: 'number' }, name: { type: 'string' } },
             required: ['name', 'age'],
+            type: 'object' as const,
           },
         },
-        model: 'claude-3-5-sonnet-20241022',
       };

       const result = await createAnthropicGenerateObject(mockClient as any, payload);

       expect(mockClient.messages.create).toHaveBeenCalledWith(
         expect.objectContaining({
-          model: 'claude-3-5-sonnet-20241022',
           max_tokens: 8192,
           messages: [{ content: 'Generate a person object', role: 'user' }],
+          model: 'claude-3-5-sonnet-20241022',
+          tool_choice: {
+            name: 'person_extractor',
+            type: 'tool',
+          },
           tools: [
             {
-              name: 'person_extractor',
               description: 'Extract person information',
               input_schema: {
-                type: 'object',
                 properties: {
-                  name: { type: 'string' },
                   age: { type: 'number' },
+                  name: { type: 'string' },
                 },
                 required: ['name', 'age'],
+                type: 'object',
               },
+              name: 'person_extractor',
             },
           ],
-          tool_choice: {
-            type: 'tool',
-            name: 'person_extractor',
-          },
         }),
         expect.objectContaining({}),
       );

-      expect(result).toEqual({ name: 'John', age: 30 });
+      expect(result).toEqual({ age: 30, name: 'John' });
     });

     it('should handle system messages correctly', async () => {
@@ -89,9 +89,9 @@ describe('Anthropic generateObject', () => {
         create: vi.fn().mockResolvedValue({
           content: [
             {
-              type: 'tool_use',
-              name: 'status_extractor',
               input: { status: 'success' },
+              name: 'status_extractor',
+              type: 'tool_use',
             },
           ],
         }),
@@ -103,19 +103,19 @@ describe('Anthropic generateObject', () => {
         { content: 'You are a helpful assistant', role: 'system' as const },
         { content: 'Generate status', role: 'user' as const },
       ],
+        model: 'claude-3-5-sonnet-20241022',
         schema: {
           name: 'status_extractor',
-          schema: { type: 'object' as const, properties: { status: { type: 'string' } } },
+          schema: { properties: { status: { type: 'string' } }, type: 'object' as const },
         },
-        model: 'claude-3-5-sonnet-20241022',
       };

       const result = await createAnthropicGenerateObject(mockClient as any, payload);

       expect(mockClient.messages.create).toHaveBeenCalledWith(
         expect.objectContaining({
-          system: [{ text: 'You are a helpful assistant', type: 'text' }],
           messages: expect.any(Array),
+          system: [{ text: 'You are a helpful assistant', type: 'text' }],
         }),
         expect.any(Object),
       );
@@ -129,9 +129,9 @@ describe('Anthropic generateObject', () => {
         create: vi.fn().mockResolvedValue({
           content: [
             {
-              type: 'tool_use',
-              name: 'data_extractor',
               input: { data: 'test' },
+              name: 'data_extractor',
+              type: 'tool_use',
             },
           ],
         }),
@@ -140,11 +140,11 @@ describe('Anthropic generateObject', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'claude-3-5-sonnet-20241022',
         schema: {
           name: 'data_extractor',
-          schema: { type: 'object' as const, properties: { data: { type: 'string' } } },
+          schema: { properties: { data: { type: 'string' } }, type: 'object' as const },
         },
-        model: 'claude-3-5-sonnet-20241022',
       };

       const options = {
@@ -169,8 +169,8 @@ describe('Anthropic generateObject', () => {
         create: vi.fn().mockResolvedValue({
           content: [
             {
-              type: 'text',
               text: 'Some text response without tool use',
+              type: 'text',
             },
           ],
         }),
@@ -179,11 +179,11 @@ describe('Anthropic generateObject', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'claude-3-5-sonnet-20241022',
         schema: {
           name: 'test_tool',
           schema: { type: 'object' },
         },
-        model: 'claude-3-5-sonnet-20241022',
       };

       const result = await createAnthropicGenerateObject(mockClient as any, payload as any);
@@ -197,9 +197,10 @@ describe('Anthropic generateObject', () => {
         create: vi.fn().mockResolvedValue({
           content: [
             {
-              type: 'tool_use',
-              name: 'user_extractor',
               input: {
+                metadata: {
+                  created: '2024-01-01',
+                },
                 user: {
                   name: 'Alice',
                   profile: {
@@ -207,10 +208,9 @@ describe('Anthropic generateObject', () => {
                     preferences: ['music', 'sports'],
                   },
                 },
-                metadata: {
-                  created: '2024-01-01',
-                },
               },
+              name: 'user_extractor',
+              type: 'tool_use',
             },
           ],
         }),
@@ -219,35 +219,38 @@ describe('Anthropic generateObject', () => {
       const payload = {
         messages: [{ content: 'Generate complex user data', role: 'user' as const }],
+        model: 'claude-3-5-sonnet-20241022',
         schema: {
-          name: 'user_extractor',
           description: 'Extract complex user information',
+          name: 'user_extractor',
           schema: {
-            type: 'object' as const,
             properties: {
+              metadata: { type: 'object' },
               user: {
-                type: 'object',
                 properties: {
                   name: { type: 'string' },
                   profile: {
-                    type: 'object',
                     properties: {
                       age: { type: 'number' },
-                      preferences: { type: 'array', items: { type: 'string' } },
+                      preferences: { items: { type: 'string' }, type: 'array' },
                     },
+                    type: 'object',
                   },
                 },
+                type: 'object',
              },
-              metadata: { type: 'object' },
            },
+            type: 'object' as const,
          },
        },
-        model: 'claude-3-5-sonnet-20241022',
       };

       const result = await createAnthropicGenerateObject(mockClient as any, payload);

       expect(result).toEqual({
+        metadata: {
+          created: '2024-01-01',
+        },
         user: {
           name: 'Alice',
           profile: {
@@ -255,9 +258,6 @@ describe('Anthropic generateObject', () => {
             preferences: ['music', 'sports'],
           },
         },
-        metadata: {
-          created: '2024-01-01',
-        },
       });
     });
   });
@@ -269,14 +269,14 @@ describe('Anthropic generateObject', () => {
         create: vi.fn().mockResolvedValue({
           content: [
             {
-              type: 'tool_use',
-              name: 'get_weather',
               input: { city: 'New York', unit: 'celsius' },
+              name: 'get_weather',
+              type: 'tool_use',
             },
             {
-              type: 'tool_use',
-              name: 'get_time',
               input: { timezone: 'America/New_York' },
+              name: 'get_time',
+              type: 'tool_use',
             },
           ],
         }),
@@ -285,69 +285,75 @@ describe('Anthropic generateObject', () => {
       const payload = {
         messages: [{ content: 'What is the weather and time in New York?', role: 'user' as const }],
+        model: 'claude-3-5-sonnet-20241022',
         tools: [
           {
-            name: 'get_weather',
-            description: 'Get weather information',
-            parameters: {
-              type: 'object' as const,
-              properties: {
-                city: { type: 'string' },
-                unit: { type: 'string' },
+            function: {
+              description: 'Get weather information',
+              name: 'get_weather',
+              parameters: {
+                properties: {
+                  city: { type: 'string' },
+                  unit: { type: 'string' },
+                },
+                required: ['city'],
+                type: 'object' as const,
              },
-              required: ['city'],
            },
+            type: 'function' as const,
           },
           {
-            name: 'get_time',
-            description: 'Get current time',
-            parameters: {
-              type: 'object' as const,
-              properties: {
-                timezone: { type: 'string' },
+            function: {
+              description: 'Get current time',
+              name: 'get_time',
+              parameters: {
+                properties: {
+                  timezone: { type: 'string' },
+                },
+                required: ['timezone'],
+                type: 'object' as const,
              },
-              required: ['timezone'],
            },
+            type: 'function' as const,
           },
         ],
-        model: 'claude-3-5-sonnet-20241022',
       };

       const result = await createAnthropicGenerateObject(mockClient as any, payload as any);

       expect(mockClient.messages.create).toHaveBeenCalledWith(
         expect.objectContaining({
-          model: 'claude-3-5-sonnet-20241022',
           max_tokens: 8192,
           messages: [{ content: 'What is the weather and time in New York?', role: 'user' }],
+          model: 'claude-3-5-sonnet-20241022',
+          tool_choice: {
+            type: 'any',
+          },
           tools: [
             {
-              name: 'get_weather',
               description: 'Get weather information',
               input_schema: {
-                type: 'object',
                 properties: {
                   city: { type: 'string' },
                   unit: { type: 'string' },
                 },
                 required: ['city'],
+                type: 'object',
               },
+              name: 'get_weather',
             },
             {
-              name: 'get_time',
               description: 'Get current time',
               input_schema: {
-                type: 'object',
                 properties: {
                   timezone: { type: 'string' },
                 },
                 required: ['timezone'],
+                type: 'object',
               },
+              name: 'get_time',
             },
           ],
-          tool_choice: {
-            type: 'any',
-          },
         }),
         expect.objectContaining({}),
       );
@@ -364,9 +370,9 @@ describe('Anthropic generateObject', () => {
         create: vi.fn().mockResolvedValue({
           content: [
             {
-              type: 'tool_use',
+              input: { a: 5, b: 3, operation: 'add' },
               name: 'calculate',
-              input: { operation: 'add', a: 5, b: 3 },
+              type: 'tool_use',
             },
           ],
         }),
@@ -375,27 +381,30 @@ describe('Anthropic generateObject', () => {
       const payload = {
         messages: [{ content: 'Add 5 and 3', role: 'user' as const }],
+        model: 'claude-3-5-sonnet-20241022',
         tools: [
           {
-            name: 'calculate',
-            description: 'Perform mathematical calculation',
-            parameters: {
-              type: 'object' as const,
-              properties: {
-                operation: { type: 'string' },
-                a: { type: 'number' },
-                b: { type: 'number' },
+            function: {
+              description: 'Perform mathematical calculation',
+              name: 'calculate',
+              parameters: {
+                properties: {
+                  a: { type: 'number' },
+                  b: { type: 'number' },
+                  operation: { type: 'string' },
+                },
+                required: ['operation', 'a', 'b'],
+                type: 'object' as const,
              },
-              required: ['operation', 'a', 'b'],
            },
+            type: 'function' as const,
           },
         ],
-        model: 'claude-3-5-sonnet-20241022',
       };

       const result = await createAnthropicGenerateObject(mockClient as any, payload as any);

-      expect(result).toEqual([{ arguments: { operation: 'add', a: 5, b: 3 }, name: 'calculate' }]);
+      expect(result).toEqual([{ arguments: { a: 5, b: 3, operation: 'add' }, name: 'calculate' }]);
     });
   });
@@ -410,11 +419,11 @@ describe('Anthropic generateObject', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'claude-3-5-sonnet-20241022',
         schema: {
           name: 'test_tool',
           schema: { type: 'object' },
         },
-        model: 'claude-3-5-sonnet-20241022',
       };

       await expect(createAnthropicGenerateObject(mockClient as any, payload as any)).rejects.toThrow(
@@ -434,11 +443,11 @@ describe('Anthropic generateObject', () => {
       const payload = {
         messages: [{ content: 'Generate data', role: 'user' as const }],
+        model: 'claude-3-5-sonnet-20241022',
         schema: {
           name: 'test_tool',
           schema: { type: 'object' },
         },
-        model: 'claude-3-5-sonnet-20241022',
       };

       const options = {
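These assertions pin down the two Anthropic `tool_choice` modes: a single schema is forced via `{ type: 'tool', name }`, while caller-supplied tools use `{ type: 'any' }` (the model must call some tool but may pick which). A minimal sketch against the Anthropic SDK, with placeholder tool contents:

```ts
import Anthropic from '@anthropic-ai/sdk';

const anthropic = new Anthropic({ apiKey: process.env.ANTHROPIC_API_KEY });

// `{ type: 'any' }` forces *a* tool call; `{ type: 'tool', name: 'calculate' }`
// would force this specific tool. The tool body is a placeholder.
const message = await anthropic.messages.create({
  max_tokens: 8192,
  messages: [{ content: 'Add 5 and 3', role: 'user' }],
  model: 'claude-3-5-sonnet-20241022',
  tool_choice: { type: 'any' },
  tools: [
    {
      description: 'Perform mathematical calculation',
      input_schema: {
        properties: { a: { type: 'number' }, b: { type: 'number' } },
        required: ['a', 'b'],
        type: 'object',
      },
      name: 'calculate',
    },
  ],
});
```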
diff --git a/packages/model-runtime/src/providers/anthropic/generateObject.ts b/packages/model-runtime/src/providers/anthropic/generateObject.ts
index 3e76c52fcc9..c93955f3225 100644
--- a/packages/model-runtime/src/providers/anthropic/generateObject.ts
+++ b/packages/model-runtime/src/providers/anthropic/generateObject.ts
@@ -14,14 +14,14 @@ export const createAnthropicGenerateObject = async (
   payload: GenerateObjectPayload,
   options?: GenerateObjectOptions,
 ) => {
-  const { schema, messages, systemRole, model, tools } = payload;
+  const { schema, messages, model, tools } = payload;

   log('generateObject called with model: %s', model);
   log('schema: %O', schema);
   log('messages count: %d', messages.length);

   // Convert messages to Anthropic format
-  const system_message = systemRole || messages.find((m) => m.role === 'system')?.content;
+  const system_message = messages.find((m) => m.role === 'system')?.content;
   const user_messages = messages.filter((m) => m.role !== 'system');

   const anthropicMessages = await buildAnthropicMessages(user_messages);
@@ -39,7 +39,7 @@ export const createAnthropicGenerateObject = async (
   let finalTools;
   let tool_choice: Anthropic.ToolChoiceAny | Anthropic.ToolChoiceTool;
   if (tools) {
-    finalTools = buildAnthropicTools(tools.map((item) => ({ function: item, type: 'function' })));
+    finalTools = buildAnthropicTools(tools);
     tool_choice = { type: 'any' };
   } else if (schema) {
     // Convert OpenAI-style schema to Anthropic tool format
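With `systemRole` removed from the payload, the system prompt is now derived solely from the message list. The extraction step, restated in isolation:

```ts
// System content now comes only from the messages array; no separate
// systemRole field is consulted.
type Msg = { content: string; role: 'assistant' | 'system' | 'user' };

const splitSystemMessage = (messages: Msg[]) => ({
  system: messages.find((m) => m.role === 'system')?.content,
  userMessages: messages.filter((m) => m.role !== 'system'),
});
```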
diff --git a/packages/model-runtime/src/providers/google/generateObject.test.ts b/packages/model-runtime/src/providers/google/generateObject.test.ts
index ee95405b779..1e4bfe209c3 100644
--- a/packages/model-runtime/src/providers/google/generateObject.test.ts
+++ b/packages/model-runtime/src/providers/google/generateObject.test.ts
@@ -2,145 +2,174 @@ import { Type as SchemaType } from '@google/genai';
 import { describe, expect, it, vi } from 'vitest';

-import { convertOpenAISchemaToGoogleSchema, createGoogleGenerateObject } from './generateObject';
+import {
+  convertOpenAISchemaToGoogleSchema,
+  createGoogleGenerateObject,
+  createGoogleGenerateObjectWithTools,
+} from './generateObject';

 describe('Google generateObject', () => {
   describe('convertOpenAISchemaToGoogleSchema', () => {
     it('should convert basic types correctly', () => {
       const openAISchema = {
-        type: 'object',
-        properties: {
-          name: { type: 'string' },
-          age: { type: 'number' },
-          isActive: { type: 'boolean' },
-          count: { type: 'integer' },
+        name: 'person',
+        schema: {
+          properties: {
+            age: { type: 'number' },
+            count: { type: 'integer' },
+            isActive: { type: 'boolean' },
+            name: { type: 'string' },
+          },
+          type: 'object' as const,
         },
       };

       const result = convertOpenAISchemaToGoogleSchema(openAISchema);

       expect(result).toEqual({
-        type: SchemaType.OBJECT,
         properties: {
-          name: { type: SchemaType.STRING },
           age: { type: SchemaType.NUMBER },
-          isActive: { type: SchemaType.BOOLEAN },
           count: { type: SchemaType.INTEGER },
+          isActive: { type: SchemaType.BOOLEAN },
+          name: { type: SchemaType.STRING },
         },
+        type: SchemaType.OBJECT,
       });
     });

     it('should convert array schemas correctly', () => {
       const openAISchema = {
-        type: 'array',
-        items: {
-          type: 'object',
+        name: 'recipes',
+        schema: {
           properties: {
-            recipeName: { type: 'string' },
-            ingredients: {
+            recipes: {
+              items: {
+                properties: {
+                  ingredients: {
+                    items: { type: 'string' },
+                    type: 'array',
+                  },
+                  recipeName: { type: 'string' },
+                },
+                propertyOrdering: ['recipeName', 'ingredients'],
+                type: 'object',
+              },
               type: 'array',
-              items: { type: 'string' },
             },
           },
-          propertyOrdering: ['recipeName', 'ingredients'],
+          type: 'object' as const,
         },
       };

       const result = convertOpenAISchemaToGoogleSchema(openAISchema);

       expect(result).toEqual({
-        type: SchemaType.ARRAY,
-        items: {
-          type: SchemaType.OBJECT,
-          properties: {
-            recipeName: { type: SchemaType.STRING },
-            ingredients: {
-              type: SchemaType.ARRAY,
-              items: { type: SchemaType.STRING },
+        properties: {
+          recipes: {
+            items: {
+              properties: {
+                ingredients: {
+                  items: { type: SchemaType.STRING },
+                  type: SchemaType.ARRAY,
+                },
+                recipeName: { type: SchemaType.STRING },
+              },
+              propertyOrdering: ['recipeName', 'ingredients'],
+              type: SchemaType.OBJECT,
             },
+            type: SchemaType.ARRAY,
           },
-          propertyOrdering: ['recipeName', 'ingredients'],
         },
+        type: SchemaType.OBJECT,
       });
     });

     it('should handle nested objects', () => {
       const openAISchema = {
-        type: 'object',
-        properties: {
-          user: {
-            type: 'object',
-            properties: {
-              profile: {
-                type: 'object',
-                properties: {
-                  preferences: {
-                    type: 'array',
-                    items: { type: 'string' },
+        name: 'user_data',
+        schema: {
+          properties: {
+            user: {
+              properties: {
+                profile: {
+                  properties: {
+                    preferences: {
+                      items: { type: 'string' },
+                      type: 'array',
+                    },
                   },
+                  type: 'object',
                 },
               },
+              type: 'object',
             },
           },
+          type: 'object' as const,
         },
       };

       const result = convertOpenAISchemaToGoogleSchema(openAISchema);

       expect(result).toEqual({
-        type: SchemaType.OBJECT,
         properties: {
           user: {
-            type: SchemaType.OBJECT,
             properties: {
               profile: {
-                type: SchemaType.OBJECT,
                 properties: {
                   preferences: {
-                    type: SchemaType.ARRAY,
                     items: { type: SchemaType.STRING },
+                    type: SchemaType.ARRAY,
                   },
                 },
+                type: SchemaType.OBJECT,
               },
             },
+            type: SchemaType.OBJECT,
           },
         },
+        type: SchemaType.OBJECT,
       });
     });

     it('should preserve additional properties like description, enum, required', () => {
       const openAISchema = {
-        type: 'object',
-        description: 'A person object',
-        properties: {
-          status: {
-            type: 'string',
-            enum: ['active', 'inactive'],
-            description: 'The status of the person',
+        name: 'person',
+        schema: {
+          description: 'A person object',
+          properties: {
+            status: {
+              description: 'The status of the person',
+              enum: ['active', 'inactive'],
+              type: 'string',
+            },
           },
-        },
-        required: ['status'],
+          required: ['status'],
+          type: 'object' as const,
+        } as any,
       };

       const result = convertOpenAISchemaToGoogleSchema(openAISchema);

       expect(result).toEqual({
-        type: SchemaType.OBJECT,
         description: 'A person object',
         properties: {
           status: {
-            type: SchemaType.STRING,
-            enum: ['active', 'inactive'],
             description: 'The status of the person',
+            enum: ['active', 'inactive'],
+            type: SchemaType.STRING,
           },
         },
         required: ['status'],
+        type: SchemaType.OBJECT,
       });
     });

     it('should handle unknown types by defaulting to STRING', () => {
       const openAISchema = {
-        type: 'unknown-type',
+        name: 'test',
+        schema: {
+          type: 'unknown-type' as any,
+        } as any,
       };

       const result = convertOpenAISchemaToGoogleSchema(openAISchema);
@@ -161,15 +190,18 @@ describe('Google generateObject', () => {
         },
       };

-      const contents = [{ role: 'user', parts: [{ text: 'Generate a person object' }] }];
+      const contents = [{ parts: [{ text: 'Generate a person object' }], role: 'user' }];

       const payload = {
         contents,
+        model: 'gemini-2.5-flash',
         schema: {
-          type: 'object',
-          properties: { name: { type: 'string' }, age: { type: 'number' } },
+          name: 'person',
+          schema: {
+            properties: { age: { type: 'number' }, name: { type: 'string' } },
+            type: 'object' as const,
+          },
         },
-        model: 'gemini-2.5-flash',
       };

       const result = await createGoogleGenerateObject(mockClient as any, payload);
@@ -178,11 +210,11 @@ describe('Google generateObject', () => {
         config: expect.objectContaining({
           responseMimeType: 'application/json',
           responseSchema: expect.objectContaining({
-            type: SchemaType.OBJECT,
             properties: expect.objectContaining({
-              name: { type: SchemaType.STRING },
               age: { type: SchemaType.NUMBER },
+              name: { type: SchemaType.STRING },
             }),
+            type: SchemaType.OBJECT,
           }),
           safetySettings: expect.any(Array),
         }),
@@ -190,7 +222,7 @@ describe('Google generateObject', () => {
         model: 'gemini-2.5-flash',
       });

-      expect(result).toEqual({ name: 'John', age: 30 });
+      expect(result).toEqual({ age: 30, name: 'John' });
     });

     it('should handle options correctly', async () => {
@@ -202,12 +234,18 @@ describe('Google generateObject', () => {
         },
       };

-      const contents = [{ role: 'user', parts: [{ text: 'Generate status' }] }];
+      const contents = [{ parts: [{ text: 'Generate status' }], role: 'user' }];

       const payload = {
         contents,
-        schema: { type: 'object', properties: { status: { type: 'string' } } },
         model: 'gemini-2.5-flash',
+        schema: {
+          name: 'status',
+          schema: {
+            properties: { status: { type: 'string' } },
+            type: 'object' as const,
+          },
+        },
       };

       const options = {
@@ -221,10 +259,10 @@ describe('Google generateObject', () => {
           abortSignal: options.signal,
           responseMimeType: 'application/json',
           responseSchema: expect.objectContaining({
-            type: SchemaType.OBJECT,
             properties: expect.objectContaining({
               status: { type: SchemaType.STRING },
             }),
+            type: SchemaType.OBJECT,
           }),
         }),
         contents,
@@ -248,8 +286,14 @@ describe('Google generateObject', () => {
       const payload = {
         contents,
-        schema: { type: 'object' },
         model: 'gemini-2.5-flash',
+        schema: {
+          name: 'test',
+          schema: {
+            properties: {},
+            type: 'object' as const,
+          },
+        },
       };

       const result = await createGoogleGenerateObject(mockClient as any, payload);
@@ -273,31 +317,37 @@ describe('Google generateObject', () => {
       const payload = {
         contents,
+        model: 'gemini-2.5-flash',
         schema: {
-          type: 'object',
-          properties: {
-            user: {
-              type: 'object',
-              properties: {
-                name: { type: 'string' },
-                profile: {
-                  type: 'object',
-                  properties: {
-                    age: { type: 'number' },
-                    preferences: { type: 'array', items: { type: 'string' } },
+          name: 'user_data',
+          schema: {
+            properties: {
+              metadata: { type: 'object' },
+              user: {
+                properties: {
+                  name: { type: 'string' },
+                  profile: {
+                    properties: {
+                      age: { type: 'number' },
+                      preferences: { items: { type: 'string' }, type: 'array' },
+                    },
+                    type: 'object',
                   },
                 },
+                type: 'object',
               },
             },
-            metadata: { type: 'object' },
+            type: 'object' as const,
           },
         },
-        model: 'gemini-2.5-flash',
       };

       const result = await createGoogleGenerateObject(mockClient as any, payload);

       expect(result).toEqual({
+        metadata: {
+          created: '2024-01-01',
+        },
         user: {
           name: 'Alice',
           profile: {
@@ -305,9 +355,6 @@ describe('Google generateObject', () => {
             preferences: ['music', 'sports'],
           },
         },
-        metadata: {
-          created: '2024-01-01',
-        },
       });
     });

@@ -324,8 +371,14 @@ describe('Google generateObject', () => {
       const payload = {
         contents,
-        schema: { type: 'object' },
         model: 'gemini-2.5-flash',
+        schema: {
+          name: 'test',
+          schema: {
+            properties: {},
+            type: 'object' as const,
+          },
+        },
       };

       await expect(createGoogleGenerateObject(mockClient as any, payload)).rejects.toThrow();
@@ -345,8 +398,14 @@ describe('Google generateObject', () => {
       const payload = {
         contents,
-        schema: { type: 'object' },
         model: 'gemini-2.5-flash',
+        schema: {
+          name: 'test',
+          schema: {
+            properties: {},
+            type: 'object' as const,
+          },
+        },
       };

       const options = {
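The conversion tests above fix the new input shape: `convertOpenAISchemaToGoogleSchema` now takes the whole `{ name, schema }` wrapper and maps JSON-Schema type strings onto the `@google/genai` `Type` enums. A usage sketch grounded in the first test (the relative import path is an assumption; the tests import it from the module under test):

```ts
import { Type as SchemaType } from '@google/genai';

import { convertOpenAISchemaToGoogleSchema } from './generateObject';

// Input is the OpenAI-style wrapper; output uses SchemaType enums,
// e.g. 'string' -> SchemaType.STRING, as the tests assert.
const googleSchema = convertOpenAISchemaToGoogleSchema({
  name: 'person',
  schema: {
    properties: { age: { type: 'number' }, name: { type: 'string' } },
    type: 'object' as const,
  },
});

console.log(googleSchema.type === SchemaType.OBJECT); // true
```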
@@ -358,4 +417,450 @@ describe('Google generateObject', () => {
       ).rejects.toThrow();
     });
   });
+
+  describe('createGoogleGenerateObjectWithTools', () => {
+    it('should return function calls on successful API call with tools', async () => {
+      const mockClient = {
+        models: {
+          generateContent: vi.fn().mockResolvedValue({
+            candidates: [
+              {
+                content: {
+                  parts: [
+                    {
+                      functionCall: {
+                        args: { city: 'New York', unit: 'celsius' },
+                        name: 'get_weather',
+                      },
+                    },
+                  ],
+                },
+              },
+            ],
+          }),
+        },
+      };
+
+      const contents = [{ parts: [{ text: 'What is the weather in New York?' }], role: 'user' }];
+
+      const payload = {
+        contents,
+        model: 'gemini-2.5-flash',
+        tools: [
+          {
+            function: {
+              description: 'Get weather information',
+              name: 'get_weather',
+              parameters: {
+                properties: {
+                  city: { type: 'string' },
+                  unit: { type: 'string' },
+                },
+                required: ['city'],
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+        ],
+      };
+
+      const result = await createGoogleGenerateObjectWithTools(mockClient as any, payload);
+
+      expect(mockClient.models.generateContent).toHaveBeenCalledWith({
+        config: expect.objectContaining({
+          safetySettings: expect.any(Array),
+          toolConfig: {
+            functionCallingConfig: {
+              mode: 'ANY',
+            },
+          },
+          tools: [
+            {
+              functionDeclarations: [
+                {
+                  description: 'Get weather information',
+                  name: 'get_weather',
+                  parameters: {
+                    description: undefined,
+                    properties: {
+                      city: { type: 'string' },
+                      unit: { type: 'string' },
+                    },
+                    required: ['city'],
+                    type: SchemaType.OBJECT,
+                  },
+                },
+              ],
+            },
+          ],
+        }),
+        contents,
+        model: 'gemini-2.5-flash',
+      });
+
+      expect(result).toEqual([
+        { arguments: { city: 'New York', unit: 'celsius' }, name: 'get_weather' },
+      ]);
+    });
+
+    it('should handle multiple function calls', async () => {
+      const mockClient = {
+        models: {
+          generateContent: vi.fn().mockResolvedValue({
+            candidates: [
+              {
+                content: {
+                  parts: [
+                    {
+                      functionCall: {
+                        args: { city: 'New York', unit: 'celsius' },
+                        name: 'get_weather',
+                      },
+                    },
+                    {
+                      functionCall: {
+                        args: { timezone: 'America/New_York' },
+                        name: 'get_time',
+                      },
+                    },
+                  ],
+                },
+              },
+            ],
+          }),
+        },
+      };
+
+      const contents: any[] = [];
+
+      const payload = {
+        contents,
+        model: 'gemini-2.5-flash',
+        tools: [
+          {
+            function: {
+              description: 'Get weather information',
+              name: 'get_weather',
+              parameters: {
+                properties: {
+                  city: { type: 'string' },
+                  unit: { type: 'string' },
+                },
+                required: ['city'],
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+          {
+            function: {
+              description: 'Get current time',
+              name: 'get_time',
+              parameters: {
+                properties: {
+                  timezone: { type: 'string' },
+                },
+                required: ['timezone'],
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+        ],
+      };
+
+      const result = await createGoogleGenerateObjectWithTools(mockClient as any, payload);
+
+      expect(result).toEqual([
+        { arguments: { city: 'New York', unit: 'celsius' }, name: 'get_weather' },
+        { arguments: { timezone: 'America/New_York' }, name: 'get_time' },
+      ]);
+    });
+
+    it('should handle options correctly', async () => {
+      const mockClient = {
+        models: {
+          generateContent: vi.fn().mockResolvedValue({
+            candidates: [
+              {
+                content: {
+                  parts: [
+                    {
+                      functionCall: {
+                        args: { a: 5, b: 3, operation: 'add' },
+                        name: 'calculate',
+                      },
+                    },
+                  ],
+                },
+              },
+            ],
+          }),
+        },
+      };
+
+      const contents: any[] = [];
+
+      const payload = {
+        contents,
+        model: 'gemini-2.5-flash',
+        tools: [
+          {
+            function: {
+              description: 'Perform mathematical calculation',
+              name: 'calculate',
+              parameters: {
+                properties: {
+                  a: { type: 'number' },
+                  b: { type: 'number' },
+                  operation: { type: 'string' },
+                },
+                required: ['operation', 'a', 'b'],
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+        ],
+      };
+
+      const options = {
+        signal: new AbortController().signal,
+      };
+
+      const result = await createGoogleGenerateObjectWithTools(mockClient as any, payload, options);
+
+      expect(mockClient.models.generateContent).toHaveBeenCalledWith({
+        config: expect.objectContaining({
+          abortSignal: options.signal,
+        }),
+        contents,
+        model: 'gemini-2.5-flash',
+      });
+
+      expect(result).toEqual([{ arguments: { a: 5, b: 3, operation: 'add' }, name: 'calculate' }]);
+    });
+
+    it('should return undefined when no function calls in response', async () => {
+      const mockClient = {
+        models: {
+          generateContent: vi.fn().mockResolvedValue({
+            candidates: [
+              {
+                content: {
+                  parts: [
+                    {
+                      text: 'Some text response without function call',
+                    },
+                  ],
+                },
+              },
+            ],
+          }),
+        },
+      };
+
+      const contents: any[] = [];
+
+      const payload = {
+        contents,
+        model: 'gemini-2.5-flash',
+        tools: [
+          {
+            function: {
+              description: 'Test function',
+              name: 'test_function',
+              parameters: {
+                properties: {},
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+        ],
+      };
+
+      const result = await createGoogleGenerateObjectWithTools(mockClient as any, payload);
+
+      expect(result).toBeUndefined();
+    });
+
+    it('should return undefined when no content parts in response', async () => {
+      const mockClient = {
+        models: {
+          generateContent: vi.fn().mockResolvedValue({
+            candidates: [
+              {
+                content: {},
+              },
+            ],
+          }),
+        },
+      };
+
+      const contents: any[] = [];
+
+      const payload = {
+        contents,
+        model: 'gemini-2.5-flash',
+        tools: [
+          {
+            function: {
+              description: 'Test function',
+              name: 'test_function',
+              parameters: {
+                properties: {},
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+        ],
+      };
+
+      const result = await createGoogleGenerateObjectWithTools(mockClient as any, payload);
+
+      expect(result).toBeUndefined();
+    });
+
+    it('should propagate API errors correctly', async () => {
+      const apiError = new Error('API Error: Model not found');
+
+      const mockClient = {
+        models: {
+          generateContent: vi.fn().mockRejectedValue(apiError),
+        },
+      };
+
+      const contents: any[] = [];
+
+      const payload = {
+        contents,
+        model: 'gemini-2.5-flash',
+        tools: [
+          {
+            function: {
+              description: 'Test function',
+              name: 'test_function',
+              parameters: {
+                properties: {},
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+        ],
+      };
+
+      await expect(createGoogleGenerateObjectWithTools(mockClient as any, payload)).rejects.toThrow(
+        'API Error: Model not found',
+      );
+    });
+
+    it('should handle abort signals correctly', async () => {
+      const apiError = new Error('Request was cancelled');
+      apiError.name = 'AbortError';
+
+      const mockClient = {
+        models: {
+          generateContent: vi.fn().mockRejectedValue(apiError),
+        },
+      };
+
+      const contents: any[] = [];
+
+      const payload = {
+        contents,
+        model: 'gemini-2.5-flash',
+        tools: [
+          {
+            function: {
+              description: 'Test function',
+              name: 'test_function',
+              parameters: {
+                properties: {},
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+        ],
+      };
+
+      const options = {
+        signal: new AbortController().signal,
+      };
+
+      await expect(
+        createGoogleGenerateObjectWithTools(mockClient as any, payload, options),
+      ).rejects.toThrow();
+    });
+
+    it('should handle tools with empty parameters', async () => {
+      const mockClient = {
+        models: {
+          generateContent: vi.fn().mockResolvedValue({
+            candidates: [
+              {
+                content: {
+                  parts: [
+                    {
+                      functionCall: {
+                        args: {},
+                        name: 'simple_function',
+                      },
+                    },
+                  ],
+                },
+              },
+            ],
+          }),
+        },
+      };
+
+      const contents: any[] = [];
+
+      const payload = {
+        contents,
+        model: 'gemini-2.5-flash',
+        tools: [
+          {
+            function: {
+              description: 'A simple function with no parameters',
+              name: 'simple_function',
+              parameters: {
+                properties: {},
+                type: 'object' as const,
+              },
+            },
+            type: 'function' as const,
+          },
+        ],
+      };
+
+      const result
= await createGoogleGenerateObjectWithTools(mockClient as any, payload);
+
+      // Should use dummy property for empty parameters
+      expect(mockClient.models.generateContent).toHaveBeenCalledWith({
+        config: expect.objectContaining({
+          tools: [
+            {
+              functionDeclarations: [
+                expect.objectContaining({
+                  parameters: expect.objectContaining({
+                    properties: { dummy: { type: 'string' } },
+                  }),
+                }),
+              ],
+            },
+          ],
+        }),
+        contents,
+        model: 'gemini-2.5-flash',
+      });
+
+      expect(result).toEqual([{ arguments: {}, name: 'simple_function' }]);
+    });
+  });
 });
diff --git a/packages/model-runtime/src/providers/google/generateObject.ts b/packages/model-runtime/src/providers/google/generateObject.ts
index 5726a8734e5..d3b8e07e032 100644
--- a/packages/model-runtime/src/providers/google/generateObject.ts
+++ b/packages/model-runtime/src/providers/google/generateObject.ts
@@ -1,9 +1,15 @@
-import { GenerateContentConfig, GoogleGenAI, Type as SchemaType } from '@google/genai';
+import {
+  FunctionCallingConfigMode,
+  GenerateContentConfig,
+  GoogleGenAI,
+  Type as SchemaType,
+} from '@google/genai';
 import Debug from 'debug';

-import { GenerateObjectOptions } from '../../types';
+import { buildGoogleTool } from '../../core/contextBuilders/google';
+import { ChatCompletionTool, GenerateObjectOptions, GenerateObjectSchema } from '../../types';

-const debug = Debug('mode-runtime:google:generateObject');
+const debug = Debug('lobe-model-runtime:google:generateObject');

 enum HarmCategory {
   HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT',
@@ -54,7 +60,7 @@ const convertType = (type: string): SchemaType => {
 /**
  * Convert OpenAI JSON schema to Google Gemini schema format
  */
-export const convertOpenAISchemaToGoogleSchema = (openAISchema: any): any => {
+export const convertOpenAISchemaToGoogleSchema = (openAISchema: GenerateObjectSchema): any => {
   const convertSchema = (schema: any): any => {
     if (!schema) return schema;
@@ -92,7 +98,7 @@ export const convertOpenAISchemaToGoogleSchema = (openAISchema: any): any => {
     return converted;
   };

-  return convertSchema(openAISchema);
+  return convertSchema(openAISchema.schema);
 };

 /**
@@ -104,7 +110,7 @@ export const createGoogleGenerateObject = async (
   payload: {
     contents: any[];
     model: string;
-    schema: any;
+    schema: GenerateObjectSchema;
   },
   options?: GenerateObjectOptions,
 ) => {
@@ -175,3 +181,95 @@ export const createGoogleGenerateObject = async (
     return undefined;
   }
 };
+
+/**
+ * Generate structured output using Google Gemini API with tool calling
+ * @see https://ai.google.dev/gemini-api/docs/function-calling
+ */
+export const createGoogleGenerateObjectWithTools = async (
+  client: GoogleGenAI,
+  payload: {
+    contents: any[];
+    model: string;
+    tools: ChatCompletionTool[];
+  },
+  options?: GenerateObjectOptions,
+) => {
+  const { tools, contents, model } = payload;
+
+  debug('createGoogleGenerateObjectWithTools started', {
+    contentsLength: contents.length,
+    model,
+    toolsCount: tools.length,
+  });
+
+  // Convert tools to Google FunctionDeclaration format
+  const functionDeclarations = tools.map(buildGoogleTool);
+  debug('Tools conversion completed', { functionDeclarations });
+
+  const config: GenerateContentConfig = {
+    abortSignal: options?.signal,
+    // relax safety thresholds to avoid over-broad sensitive-word blocking
+    safetySettings: [
+      {
+        category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+        threshold: getThreshold(model),
+      },
+      {
+        category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+        threshold: getThreshold(model),
+      },
+      {
+        category:
HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: getThreshold(model), + }, + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: getThreshold(model), + }, + ], + // Force tool calling with 'any' mode + toolConfig: { + functionCallingConfig: { + mode: FunctionCallingConfigMode.ANY, + }, + }, + tools: [{ functionDeclarations }], + }; + + debug('Config prepared', { + hasAbortSignal: !!config.abortSignal, + hasSafetySettings: !!config.safetySettings, + hasTools: !!config.tools, + model, + }); + + const response = await client.models.generateContent({ + config, + contents, + model, + }); + + debug('API response received', { + candidatesCount: response.candidates?.length, + hasContent: !!response.candidates?.[0]?.content, + }); + + // Extract function calls from response + const candidate = response.candidates?.[0]; + if (!candidate?.content?.parts) { + debug('no content parts in response'); + return undefined; + } + + const functionCalls = candidate.content.parts + .filter((part) => part.functionCall) + .map((part) => ({ + arguments: part.functionCall!.args, + name: part.functionCall!.name, + })); + + debug('extracted function calls', { count: functionCalls.length, functionCalls }); + + return functionCalls.length > 0 ? functionCalls : undefined; +}; diff --git a/packages/model-runtime/src/providers/google/index.test.ts b/packages/model-runtime/src/providers/google/index.test.ts index 775a2693b51..2e8f307281e 100644 --- a/packages/model-runtime/src/providers/google/index.test.ts +++ b/packages/model-runtime/src/providers/google/index.test.ts @@ -432,401 +432,6 @@ describe('LobeGoogleAI', () => { }); describe('private method', () => { - describe('convertContentToGooglePart', () => { - it('should handle text type messages', async () => { - const result = await instance['convertContentToGooglePart']({ - type: 'text', - text: 'Hello', - }); - expect(result).toEqual({ text: 'Hello' }); - }); - it('should handle thinking type messages', async () => { - const result = await instance['convertContentToGooglePart']({ - type: 'thinking', - thinking: 'Hello', - signature: 'abc', - }); - expect(result).toEqual(undefined); - }); - - it('should handle base64 type images', async () => { - const base64Image = - 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=='; - const result = await instance['convertContentToGooglePart']({ - type: 'image_url', - image_url: { url: base64Image }, - }); - - expect(result).toEqual({ - inlineData: { - data: 'iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==', - mimeType: 'image/png', - }, - }); - }); - - it('should handle URL type images', async () => { - const imageUrl = 'http://example.com/image.png'; - const mockBase64 = 'mockBase64Data'; - - // Mock the imageUrlToBase64 function - vi.spyOn(imageToBase64Module, 'imageUrlToBase64').mockResolvedValueOnce({ - base64: mockBase64, - mimeType: 'image/png', - }); - - const result = await instance['convertContentToGooglePart']({ - type: 'image_url', - image_url: { url: imageUrl }, - }); - - expect(result).toEqual({ - inlineData: { - data: mockBase64, - mimeType: 'image/png', - }, - }); - - expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl); - }); - - it('should throw TypeError for unsupported image URL types', async () => { - const unsupportedImageUrl = 'unsupported://example.com/image.png'; - - await expect( - 
instance['convertContentToGooglePart']({ - type: 'image_url', - image_url: { url: unsupportedImageUrl }, - }), - ).rejects.toThrow(TypeError); - }); - }); - - describe('buildGoogleMessages', () => { - it('get default result with gemini-pro', async () => { - const messages: OpenAIChatMessage[] = [{ content: 'Hello', role: 'user' }]; - - const contents = await instance['buildGoogleMessages'](messages); - - expect(contents).toHaveLength(1); - expect(contents).toEqual([{ parts: [{ text: 'Hello' }], role: 'user' }]); - }); - - it('should not modify the length if model is gemini-1.5-pro', async () => { - const messages: OpenAIChatMessage[] = [ - { content: 'Hello', role: 'user' }, - { content: 'Hi', role: 'assistant' }, - ]; - - const contents = await instance['buildGoogleMessages'](messages); - - expect(contents).toHaveLength(2); - expect(contents).toEqual([ - { parts: [{ text: 'Hello' }], role: 'user' }, - { parts: [{ text: 'Hi' }], role: 'model' }, - ]); - }); - - it('should use specified model when images are included in messages', async () => { - const messages: OpenAIChatMessage[] = [ - { - content: [ - { type: 'text', text: 'Hello' }, - { type: 'image_url', image_url: { url: 'data:image/png;base64,...' } }, - ], - role: 'user', - }, - ]; - - // Call the buildGoogleMessages method - const contents = await instance['buildGoogleMessages'](messages); - - expect(contents).toHaveLength(1); - expect(contents).toEqual([ - { - parts: [{ text: 'Hello' }, { inlineData: { data: '...', mimeType: 'image/png' } }], - role: 'user', - }, - ]); - }); - - it('should correctly convert function response message', async () => { - const messages: OpenAIChatMessage[] = [ - { - content: '', - role: 'assistant', - tool_calls: [ - { - id: 'call_1', - function: { - name: 'get_current_weather', - arguments: JSON.stringify({ location: 'London', unit: 'celsius' }), - }, - type: 'function', - }, - ], - }, - { - content: '{"success":true,"data":{"temperature":"14°C"}}', - name: 'get_current_weather', - role: 'tool', - tool_call_id: 'call_1', - }, - ]; - - const contents = await instance['buildGoogleMessages'](messages); - expect(contents).toHaveLength(2); - expect(contents).toEqual([ - { - parts: [ - { - functionCall: { - args: { location: 'London', unit: 'celsius' }, - name: 'get_current_weather', - }, - }, - ], - role: 'model', - }, - { - parts: [ - { - functionResponse: { - name: 'get_current_weather', - response: { result: '{"success":true,"data":{"temperature":"14°C"}}' }, - }, - }, - ], - role: 'user', - }, - ]); - }); - }); - - describe('buildGoogleTools', () => { - it('should return undefined when tools is undefined or empty', () => { - expect(instance['buildGoogleTools'](undefined)).toBeUndefined(); - expect(instance['buildGoogleTools']([])).toBeUndefined(); - }); - - it('should correctly convert ChatCompletionTool to GoogleFunctionCallTool', () => { - const tools: OpenAI.ChatCompletionTool[] = [ - { - function: { - name: 'testTool', - description: 'A test tool', - parameters: { - type: 'object', - properties: { - param1: { type: 'string' }, - param2: { type: 'number' }, - }, - required: ['param1'], - }, - }, - type: 'function', - }, - ]; - - const googleTools = instance['buildGoogleTools'](tools); - - expect(googleTools).toHaveLength(1); - expect((googleTools![0] as Tool).functionDeclarations![0]).toEqual({ - name: 'testTool', - description: 'A test tool', - parameters: { - type: 'OBJECT', - properties: { - param1: { type: 'string' }, - param2: { type: 'number' }, - }, - required: ['param1'], - }, - }); - }); - 
- it('should also add tools when tool_calls exists', () => { - const tools: OpenAI.ChatCompletionTool[] = [ - { - function: { - name: 'testTool', - description: 'A test tool', - parameters: { - type: 'object', - properties: { - param1: { type: 'string' }, - param2: { type: 'number' }, - }, - required: ['param1'], - }, - }, - type: 'function', - }, - ]; - - const payload: ChatStreamPayload = { - messages: [ - { - role: 'user', - content: '', - tool_calls: [ - { function: { name: 'some_func', arguments: '' }, id: 'func_1', type: 'function' }, - ], - }, - ], - model: 'gemini-2.5-flash-preview-04-17', - temperature: 1, - }; - - const googleTools = instance['buildGoogleTools'](tools, payload); - - expect(googleTools).toHaveLength(1); - expect((googleTools![0] as Tool).functionDeclarations![0]).toEqual({ - name: 'testTool', - description: 'A test tool', - parameters: { - type: 'OBJECT', - properties: { - param1: { type: 'string' }, - param2: { type: 'number' }, - }, - required: ['param1'], - }, - }); - }); - - it('should handle googleSearch', () => { - const payload: ChatStreamPayload = { - messages: [ - { - role: 'user', - content: '', - }, - ], - model: 'gemini-2.5-flash-preview-04-17', - temperature: 1, - enabledSearch: true, - }; - - const googleTools = instance['buildGoogleTools'](undefined, payload); - - expect(googleTools).toHaveLength(1); - expect(googleTools![0] as Tool).toEqual({ googleSearch: {} }); - }); - }); - - describe('convertOAIMessagesToGoogleMessage', () => { - it('should correctly convert assistant message', async () => { - const message: OpenAIChatMessage = { - role: 'assistant', - content: 'Hello', - }; - - const converted = await instance['convertOAIMessagesToGoogleMessage'](message); - - expect(converted).toEqual({ - role: 'model', - parts: [{ text: 'Hello' }], - }); - }); - - it('should correctly convert user message', async () => { - const message: OpenAIChatMessage = { - role: 'user', - content: 'Hi', - }; - - const converted = await instance['convertOAIMessagesToGoogleMessage'](message); - - expect(converted).toEqual({ - role: 'user', - parts: [{ text: 'Hi' }], - }); - }); - - it('should correctly convert message with inline base64 image parts', async () => { - const message: OpenAIChatMessage = { - role: 'user', - content: [ - { type: 'text', text: 'Check this image:' }, - { type: 'image_url', image_url: { url: 'data:image/png;base64,...' 
} }, - ], - }; - - const converted = await instance['convertOAIMessagesToGoogleMessage'](message); - - expect(converted).toEqual({ - role: 'user', - parts: [ - { text: 'Check this image:' }, - { inlineData: { data: '...', mimeType: 'image/png' } }, - ], - }); - }); - it.skip('should correctly convert message with image url parts', async () => { - const message: OpenAIChatMessage = { - role: 'user', - content: [ - { type: 'text', text: 'Check this image:' }, - { type: 'image_url', image_url: { url: 'https://image-file.com' } }, - ], - }; - - const converted = await instance['convertOAIMessagesToGoogleMessage'](message); - - expect(converted).toEqual({ - role: 'user', - parts: [ - { text: 'Check this image:' }, - { inlineData: { data: '...', mimeType: 'image/png' } }, - ], - }); - }); - - it('should correctly convert function call message', async () => { - const message = { - role: 'assistant', - tool_calls: [ - { - id: 'call_1', - function: { - name: 'get_current_weather', - arguments: JSON.stringify({ location: 'London', unit: 'celsius' }), - }, - type: 'function', - }, - ], - } as OpenAIChatMessage; - - const converted = await instance['convertOAIMessagesToGoogleMessage'](message); - expect(converted).toEqual({ - role: 'model', - parts: [ - { - functionCall: { - name: 'get_current_weather', - args: { location: 'London', unit: 'celsius' }, - }, - }, - ], - }); - }); - - it('should correctly handle empty content', async () => { - const message: OpenAIChatMessage = { - role: 'user', - content: '' as any, // explicitly set as empty string - }; - - const converted = await instance['convertOAIMessagesToGoogleMessage'](message); - - expect(converted).toEqual({ - role: 'user', - parts: [{ text: '' }], - }); - }); - }); - describe('createEnhancedStream', () => { it('should handle stream cancellation with data gracefully', async () => { const mockStream = (async function* () { diff --git a/packages/model-runtime/src/providers/google/index.ts b/packages/model-runtime/src/providers/google/index.ts index 3e74f3e44e3..aeab210440f 100644 --- a/packages/model-runtime/src/providers/google/index.ts +++ b/packages/model-runtime/src/providers/google/index.ts @@ -1,17 +1,14 @@ import { - Content, - FunctionDeclaration, GenerateContentConfig, Tool as GoogleFunctionCallTool, GoogleGenAI, HttpOptions, - Part, - Type as SchemaType, ThinkingConfig, } from '@google/genai'; import debug from 'debug'; import { LobeRuntimeAI } from '../../core/BaseAI'; +import { buildGoogleMessages, buildGoogleTools } from '../../core/contextBuilders/google'; import { GoogleGenerativeAIStream, VertexAIStream } from '../../core/streams'; import { LOBE_ERROR_KEY } from '../../core/streams/google'; import { @@ -20,8 +17,6 @@ import { ChatStreamPayload, GenerateObjectOptions, GenerateObjectPayload, - OpenAIChatMessage, - UserMessageContentPart, } from '../../types'; import { AgentRuntimeErrorType } from '../../types/error'; import { CreateImagePayload, CreateImageResponse } from '../../types/image'; @@ -29,12 +24,9 @@ import { AgentRuntimeError } from '../../utils/createError'; import { debugStream } from '../../utils/debugStream'; import { getModelPricing } from '../../utils/getModelPricing'; import { parseGoogleErrorMessage } from '../../utils/googleErrorParser'; -import { imageUrlToBase64 } from '../../utils/imageToBase64'; import { StreamingResponse } from '../../utils/response'; -import { safeParseJSON } from '../../utils/safeParseJSON'; -import { parseDataUri } from '../../utils/uriParser'; import { createGoogleImage } from 
'./createImage';
-import { createGoogleGenerateObject } from './generateObject';
+import { createGoogleGenerateObject, createGoogleGenerateObjectWithTools } from './generateObject';

 const log = debug('model-runtime:google');

@@ -217,7 +209,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
       thinkingBudget: resolvedThinkingBudget,
     };

-    const contents = await this.buildGoogleMessages(payload.messages);
+    const contents = await buildGoogleMessages(payload.messages);

     const controller = new AbortController();
     const originalSignal = options?.signal;
@@ -264,7 +256,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
           modelsDisableInstuction.has(model) || model.toLowerCase().includes('learnlm')
             ? undefined
             : thinkingConfig,
-        tools: this.buildGoogleTools(payload.tools, payload),
+        tools: this.buildGoogleToolsWithSearch(payload.tools, payload),
         topP: payload.top_p,
       };

@@ -330,16 +322,31 @@ export class LobeGoogleAI implements LobeRuntimeAI {
   /**
    * Generate structured output using Google Gemini API
    * @see https://ai.google.dev/gemini-api/docs/structured-output
+   * @see https://ai.google.dev/gemini-api/docs/function-calling
    */
   async generateObject(payload: GenerateObjectPayload, options?: GenerateObjectOptions) {
     // Convert OpenAI messages to Google format
-    const contents = await this.buildGoogleMessages(payload.messages);
+    const contents = await buildGoogleMessages(payload.messages);
+
+    // Handle tools-based structured output
+    if (payload.tools && payload.tools.length > 0) {
+      return createGoogleGenerateObjectWithTools(
+        this.client,
+        { contents, model: payload.model, tools: payload.tools },
+        options,
+      );
+    }
+
+    // Handle schema-based structured output
+    if (payload.schema) {
+      return createGoogleGenerateObject(
+        this.client,
+        { contents, model: payload.model, schema: payload.schema },
+        options,
+      );
+    }

-    return createGoogleGenerateObject(
-      this.client,
-      { contents, model: payload.model, schema: payload.schema },
-      options,
-    );
+    return undefined;
   }

   private createEnhancedStream(originalStream: any, signal: AbortSignal): ReadableStream {
@@ -489,147 +496,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
     };
   }

-  private convertContentToGooglePart = async (
-    content: UserMessageContentPart,
-  ): Promise<Part | undefined> => {
-    switch (content.type) {
-      default: {
-        return undefined;
-      }
-
-      case 'text': {
-        return { text: content.text };
-      }
-
-      case 'image_url': {
-        const { mimeType, base64, type } = parseDataUri(content.image_url.url);
-
-        if (type === 'base64') {
-          if (!base64) {
-            throw new TypeError("Image URL doesn't contain base64 data");
-          }
-
-          return {
-            inlineData: { data: base64, mimeType: mimeType || 'image/png' },
-          };
-        }
-
-        if (type === 'url') {
-          const { base64, mimeType } = await imageUrlToBase64(content.image_url.url);
-
-          return {
-            inlineData: { data: base64, mimeType },
-          };
-        }
-
-        throw new TypeError(`currently we don't support image url: ${content.image_url.url}`);
-      }
-
-      case 'video_url': {
-        const { mimeType, base64, type } = parseDataUri(content.video_url.url);
-
-        if (type === 'base64') {
-          if (!base64) {
-            throw new TypeError("Video URL doesn't contain base64 data");
-          }
-
-          return {
-            inlineData: { data: base64, mimeType: mimeType || 'video/mp4' },
-          };
-        }
-
-        if (type === 'url') {
-          // For video URLs, we need to fetch and convert to base64
-          // Note: This might need size/duration limits for practical use
-          const response = await fetch(content.video_url.url);
-          const arrayBuffer = await response.arrayBuffer();
-          const base64 = Buffer.from(arrayBuffer).toString('base64');
-          const mimeType = response.headers.get('content-type') || 'video/mp4';
-
-          return {
-            inlineData: { data: base64, mimeType },
-          };
-        }
-
-        throw new TypeError(`currently we don't support video url: ${content.video_url.url}`);
-      }
-    }
-  };
-
-  private convertOAIMessagesToGoogleMessage = async (
-    message: OpenAIChatMessage,
-    toolCallNameMap?: Map<string, string>,
-  ): Promise<Content> => {
-    const content = message.content as string | UserMessageContentPart[];
-    if (!!message.tool_calls) {
-      return {
-        parts: message.tool_calls.map((tool) => ({
-          functionCall: {
-            args: safeParseJSON(tool.function.arguments)!,
-            name: tool.function.name,
-          },
-        })),
-        role: 'model',
-      };
-    }
-
-    // convert the tool_call result into a functionResponse part
-    if (message.role === 'tool' && toolCallNameMap && message.tool_call_id) {
-      const functionName = toolCallNameMap.get(message.tool_call_id);
-      if (functionName) {
-        return {
-          parts: [
-            {
-              functionResponse: {
-                name: functionName,
-                response: { result: message.content },
-              },
-            },
-          ],
-          role: 'user',
-        };
-      }
-    }
-
-    const getParts = async () => {
-      if (typeof content === 'string') return [{ text: content }];
-
-      const parts = await Promise.all(
-        content.map(async (c) => await this.convertContentToGooglePart(c)),
-      );
-      return parts.filter(Boolean) as Part[];
-    };
-
-    return {
-      parts: await getParts(),
-      role: message.role === 'assistant' ? 'model' : 'user',
-    };
-  };
-
-  // convert messages from the OpenAI format to Google GenAI SDK
-  private buildGoogleMessages = async (messages: OpenAIChatMessage[]): Promise<Content[]> => {
-    const toolCallNameMap = new Map<string, string>();
-    messages.forEach((message) => {
-      if (message.role === 'assistant' && message.tool_calls) {
-        message.tool_calls.forEach((toolCall) => {
-          if (toolCall.type === 'function') {
-            toolCallNameMap.set(toolCall.id, toolCall.function.name);
-          }
-        });
-      }
-    });
-
-    const pools = messages
-      .filter((message) => message.role !== 'function')
-      .map(async (msg) => await this.convertOAIMessagesToGoogleMessage(msg, toolCallNameMap));
-
-    const contents = await Promise.all(pools);
-
-    // filter out empty messages: contents.parts must not be empty.
-    return contents.filter((content: Content) => content.parts && content.parts.length > 0);
-  };
-
-  private buildGoogleTools(
+  private buildGoogleToolsWithSearch(
     tools: ChatCompletionTool[] | undefined,
     payload?: ChatStreamPayload,
   ): GoogleFunctionCallTool[] | undefined {
@@ -640,7 +507,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {

     // if tool_calls already exist, handle function declarations first
     if (hasToolCalls && hasFunctionTools) {
-      return this.buildFunctionDeclarations(tools);
+      return buildGoogleTools(tools);
     }

     // build and return search-related tools (search tools cannot be used together with FunctionCall)
@@ -655,41 +522,8 @@ export class LobeGoogleAI implements LobeRuntimeAI {
     }

     // finally, consider function declarations
-    return this.buildFunctionDeclarations(tools);
+    return buildGoogleTools(tools);
   }
-
-  private buildFunctionDeclarations(
-    tools: ChatCompletionTool[] | undefined,
-  ): GoogleFunctionCallTool[] | undefined {
-    if (!tools || tools.length === 0) return;
-
-    return [
-      {
-        functionDeclarations: tools.map((tool) => this.convertToolToGoogleTool(tool)),
-      },
-    ];
-  }
-
-  private convertToolToGoogleTool = (tool: ChatCompletionTool): FunctionDeclaration => {
-    const functionDeclaration = tool.function;
-    const parameters = functionDeclaration.parameters;
-    // refs: https://github.com/lobehub/lobe-chat/pull/5002
-    const properties =
-      parameters?.properties && Object.keys(parameters.properties).length > 0
-        ? parameters.properties
-        : { dummy: { type: 'string' } }; // dummy property to avoid empty object
-
-    return {
-      description: functionDeclaration.description,
-      name: functionDeclaration.name,
-      parameters: {
-        description: parameters?.description,
-        properties: properties,
-        required: parameters?.required,
-        type: SchemaType.OBJECT,
-      },
-    };
-  };
 }

 export default LobeGoogleAI;
diff --git a/packages/model-runtime/src/providers/openai/index.test.ts b/packages/model-runtime/src/providers/openai/index.test.ts
index 9e157f43d69..57f30bb7a4a 100644
--- a/packages/model-runtime/src/providers/openai/index.test.ts
+++ b/packages/model-runtime/src/providers/openai/index.test.ts
@@ -59,10 +59,10 @@ describe('LobeOpenAI', () => {
       const apiError = new OpenAI.APIError(
         400,
         {
-          status: 400,
           error: {
             message: 'Bad Request',
           },
+          status: 400,
         },
         'Error message',
         {},
@@ -178,13 +178,13 @@ describe('LobeOpenAI', () => {
       } catch (e) {
         expect(e).toEqual({
           endpoint: 'https://api.openai.com/v1',
-          errorType: 'AgentRuntimeError',
-          provider: 'openai',
           error: {
-            name: genericError.name,
             cause: genericError.cause,
             message: genericError.message,
+            name: genericError.name,
           },
+          errorType: 'AgentRuntimeError',
+          provider: 'openai',
         });
       }
     });
@@ -261,10 +261,10 @@ describe('LobeOpenAI', () => {
     it('should use responses API when enabledSearch is true', async () => {
       const payload = {
+        enabledSearch: true,
         messages: [{ content: 'Hello', role: 'user' as const }],
         model: 'gpt-4o',
         temperature: 0.7,
-        enabledSearch: true,
       };

       await instance.chat(payload);
@@ -275,12 +275,12 @@ describe('LobeOpenAI', () => {
     it('should handle -search- models with stripped parameters', async () => {
       const payload = {
+        frequency_penalty: 0.5,
         messages: [{ content: 'Hello', role: 'user' as const }],
         model: 'gpt-4o-search-2024',
+        presence_penalty: 0.3,
         temperature: 0.7,
         top_p: 0.9,
-        frequency_penalty: 0.5,
-        presence_penalty: 0.3,
       };

       await instance.chat(payload);
@@ -296,12 +296,12 @@ describe('LobeOpenAI', () => {
     it('should handle regular models with all parameters', async () => {
       const payload = {
+        frequency_penalty: 0.5,
         messages: [{ content: 'Hello', role: 'user' as const }],
         model: 'gpt-4o',
+        presence_penalty: 0.3,
         temperature: 0.7,
         top_p: 0.9,
-        frequency_penalty: 0.5,
-        presence_penalty: 0.3,
       };

       await instance.chat(payload);
@@ -319,18 +319,19 @@ describe('LobeOpenAI', () => {
   describe('responses.handlePayload', () => {
     it('should add web_search tool when enabledSearch is true', async () => {
       const payload = {
+        enabledSearch: true,
         messages: [{ content: 'Hello', role: 'user' as const }],
-        model: 'gpt-4o', // use a regular model; the responses API is triggered via enabledSearch
+        model: 'gpt-4o',
+        // use a regular model; the responses API is triggered via enabledSearch
         temperature: 0.7,
-        enabledSearch: true,
-        tools: [{ type: 'function' as const, function: { name: 'test', description: 'test' } }],
+        tools: [{ function: { description: 'test', name: 'test' }, type: 'function' as const }],
       };

       await instance.chat(payload);

       const createCall = (instance['client'].responses.create as Mock).mock.calls[0][0];
       expect(createCall.tools).toEqual([
-        { type: 'function', name: 'test', description: 'test' },
+        { description: 'test', name: 'test', type: 'function' },
        { type: 'web_search' },
       ]);
     });
@@ -339,10 +340,10 @@ describe('LobeOpenAI', () => {
       // Note: oaiSearchContextSize is read at module load time, not runtime
       // This test verifies the tool structure is correct when the env var would be set
       const payload = {
+        enabledSearch: true,
         messages: [{ content: 'Hello', role: 'user' as const }],
model: 'gpt-4o', temperature: 0.7, - enabledSearch: true, }; await instance.chat(payload); @@ -358,8 +359,8 @@ describe('LobeOpenAI', () => { const payload = { messages: [{ content: 'Hello', role: 'user' as const }], model: 'computer-use-preview', - temperature: 0.7, reasoning: { effort: 'medium' }, + temperature: 0.7, }; await instance.chat(payload); @@ -393,7 +394,7 @@ describe('LobeOpenAI', () => { await instance.chat(payload); const createCall = (instance['client'].responses.create as Mock).mock.calls[0][0]; - expect(createCall.reasoning).toEqual({ summary: 'auto', effort: 'high' }); + expect(createCall.reasoning).toEqual({ effort: 'high', summary: 'auto' }); }); it('should set reasoning.effort to high for gpt-5-pro-2025-10-06 models', async () => { @@ -406,7 +407,7 @@ describe('LobeOpenAI', () => { await instance.chat(payload); const createCall = (instance['client'].responses.create as Mock).mock.calls[0][0]; - expect(createCall.reasoning).toEqual({ summary: 'auto', effort: 'high' }); + expect(createCall.reasoning).toEqual({ effort: 'high', summary: 'auto' }); }); }); diff --git a/packages/model-runtime/src/types/structureOutput.ts b/packages/model-runtime/src/types/structureOutput.ts index 40c7fa46795..f63008c2b69 100644 --- a/packages/model-runtime/src/types/structureOutput.ts +++ b/packages/model-runtime/src/types/structureOutput.ts @@ -1,4 +1,4 @@ -import { ChatCompletionFunctions } from './chat'; +import { ChatCompletionTool } from './chat'; interface GenerateObjectMessage { content: string; @@ -6,7 +6,7 @@ interface GenerateObjectMessage { role: 'user' | 'system' | 'assistant'; } -interface GenerateObjectSchema { +export interface GenerateObjectSchema { description?: string; name: string; schema: { @@ -23,8 +23,7 @@ export interface GenerateObjectPayload { model: string; responseApi?: boolean; schema?: GenerateObjectSchema; - systemRole?: string; - tools?: ChatCompletionFunctions[]; + tools?: ChatCompletionTool[]; } export interface GenerateObjectOptions { diff --git a/packages/types/src/aiChat.ts b/packages/types/src/aiChat.ts index b482c61b2f8..bb3462d67b3 100644 --- a/packages/types/src/aiChat.ts +++ b/packages/types/src/aiChat.ts @@ -75,7 +75,6 @@ export const StructureOutputSchema = z.object({ model: z.string(), provider: z.string(), schema: StructureSchema.optional(), - systemRole: z.string().optional(), tools: z .array(z.object({ function: LobeUniformToolSchema, type: z.literal('function') })) .optional(), diff --git a/src/server/routers/lambda/aiChat.ts b/src/server/routers/lambda/aiChat.ts index 4e194ca9b65..1abbea979c4 100644 --- a/src/server/routers/lambda/aiChat.ts +++ b/src/server/routers/lambda/aiChat.ts @@ -60,8 +60,7 @@ export const aiChatRouter = router({ messages: input.messages, model: input.model, schema: input.schema, - systemRole: input.systemRole, - tools: input.tools?.map((item) => item.function), + tools: input.tools, }); log('generateObject completed, result: %O', result);
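
For reference, the Google changes above give `generateObject` two dispatch paths: a payload with `tools` is routed through the new `createGoogleGenerateObjectWithTools`, which forces a function call via `FunctionCallingConfigMode.ANY` and resolves to an array of `{ name, arguments }` pairs, while a payload with `schema` keeps the existing JSON-mode path through `responseSchema`. The sketch below is a minimal illustration, not part of the patch: the import path and constructor options are assumptions, while the payload shapes mirror the tests above.

```ts
import { LobeGoogleAI } from '@lobechat/model-runtime'; // hypothetical import path; adjust to your build

// constructor options assumed; the runtime only needs a Google API key here
const runtime = new LobeGoogleAI({ apiKey: process.env.GOOGLE_API_KEY });

// Tools-based structured output: the model is forced to call the declared
// function, and the runtime extracts [{ arguments, name }] from the response parts.
const calls = await runtime.generateObject({
  messages: [{ content: 'What is the weather in New York?', role: 'user' }],
  model: 'gemini-2.5-flash',
  tools: [
    {
      function: {
        description: 'Get weather information',
        name: 'get_weather',
        parameters: {
          properties: { city: { type: 'string' }, unit: { type: 'string' } },
          required: ['city'],
          type: 'object',
        },
      },
      type: 'function',
    },
  ],
});
// e.g. [{ arguments: { city: 'New York', unit: 'celsius' }, name: 'get_weather' }]

// Schema-based structured output: the unchanged JSON-mode path via responseSchema.
const object = await runtime.generateObject({
  messages: [{ content: 'Generate status', role: 'user' }],
  model: 'gemini-2.5-flash',
  schema: {
    name: 'status',
    schema: { properties: { status: { type: 'string' } }, type: 'object' },
  },
});
// e.g. { status: 'ok' }
```

With neither `tools` nor `schema` supplied, the method now returns `undefined` instead of invoking the schema path with an empty schema, which matches the `return undefined;` fallback added in the diff.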