diff --git a/database.ts b/database.ts index 5c45a3f..64f5109 100644 --- a/database.ts +++ b/database.ts @@ -19,7 +19,7 @@ import { export class DatabaseManager { private static columnCache: WeakMap = new WeakMap(); - static async initDatabase(config: SourceConfig, parentLogger: Logger): Promise { + static async initDatabase(config: SourceConfig, parentLogger: Logger, embeddingDimension: number = 3072): Promise { const logger = parentLogger.child('database'); const dbConfig = config.database_config; @@ -32,10 +32,10 @@ export class DatabaseManager { const db = new BetterSqlite3(dbPath, { allowExtension: true } as any); sqliteVec.load(db); - logger.debug(`Creating vec_items table if it doesn't exist`); + logger.debug(`Creating vec_items table if it doesn't exist (dimension: ${embeddingDimension})`); db.exec(` CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0( - embedding FLOAT[3072], + embedding FLOAT[${embeddingDimension}], product_name TEXT, version TEXT, branch TEXT, @@ -61,7 +61,7 @@ export class DatabaseManager { logger.info(`Connecting to Qdrant at ${qdrantUrl}:${qdrantPort}, collection: ${collectionName}`); const qdrantClient = new QdrantClient({ url: qdrantUrl, apiKey: process.env.QDRANT_API_KEY, port: qdrantPort }); - await this.createCollectionQdrant(qdrantClient, collectionName, logger); + await this.createCollectionQdrant(qdrantClient, collectionName, logger, embeddingDimension); logger.info(`Qdrant connection established successfully`); return { client: qdrantClient, collectionName, type: 'qdrant' }; } else { @@ -71,7 +71,7 @@ export class DatabaseManager { } } - static async createCollectionQdrant(qdrantClient: QdrantClient, collectionName: string, logger: Logger) { + static async createCollectionQdrant(qdrantClient: QdrantClient, collectionName: string, logger: Logger, embeddingDimension: number = 3072) { try { logger.debug(`Checking if collection ${collectionName} exists`); const collections = await qdrantClient.getCollections(); @@ -84,10 +84,10 @@ export class DatabaseManager { return; } - logger.info(`Creating new collection ${collectionName}`); + logger.info(`Creating new collection ${collectionName} with dimension ${embeddingDimension}`); await qdrantClient.createCollection(collectionName, { vectors: { - size: 3072, + size: embeddingDimension, distance: "Cosine", }, }); @@ -177,7 +177,8 @@ export class DatabaseManager { dbConnection: DatabaseConnection, key: string, value: string, - logger: Logger + logger: Logger, + embeddingDimension: number = 3072 ): Promise { try { if (dbConnection.type === 'sqlite') { @@ -189,8 +190,7 @@ export class DatabaseManager { logger.debug(`Updated metadata value for ${key}`); } else if (dbConnection.type === 'qdrant') { const metadataUUID = Utils.generateMetadataUUID(key); - const dummyEmbeddingSize = 3072; - const dummyEmbedding = new Array(dummyEmbeddingSize).fill(0); + const dummyEmbedding = new Array(embeddingDimension).fill(0); const metadataPoint = { id: metadataUUID, vector: dummyEmbedding, @@ -278,9 +278,19 @@ export class DatabaseManager { logger.debug(`Using UUID: ${metadataUUID} for metadata`); + // Get the embedding dimension from the collection info + let embeddingDimension = 3072; // Default + try { + const collectionInfo = await dbConnection.client.getCollection(dbConnection.collectionName); + if (collectionInfo.config?.params?.vectors && 'size' in collectionInfo.config.params.vectors) { + embeddingDimension = collectionInfo.config.params.vectors.size; + } + } catch (error) { + logger.warn('Could not get collection info, using default dimension'); + } + // Generate a dummy embedding (all zeros) - const dummyEmbeddingSize = 3072; // Same size as your content embeddings - const dummyEmbedding = new Array(dummyEmbeddingSize).fill(0); + const dummyEmbedding = new Array(embeddingDimension).fill(0); // Create a point with special metadata payload const metadataPoint = { diff --git a/doc2vec.ts b/doc2vec.ts index 8507593..0237725 100644 --- a/doc2vec.ts +++ b/doc2vec.ts @@ -37,6 +37,7 @@ export class Doc2Vec { private config: Config; private openai: OpenAI | AzureOpenAI; private embeddingModel: string; + private embeddingDimension: number; private contentProcessor: ContentProcessor; private logger: Logger; private configDir: string; @@ -77,7 +78,8 @@ export class Doc2Vec { apiVersion: azureApiVersion, }); this.embeddingModel = azureDeploymentName; - this.logger.info(`Using Azure OpenAI with deployment: ${azureDeploymentName}`); + this.embeddingDimension = Utils.getEmbeddingDimension(azureDeploymentName); + this.logger.info(`Using Azure OpenAI with deployment: ${azureDeploymentName} (${this.embeddingDimension} dimensions)`); } else { const openaiApiKey = embeddingConfig.openai?.api_key || process.env.OPENAI_API_KEY; const openaiModel = embeddingConfig.openai?.model || process.env.OPENAI_MODEL || 'text-embedding-3-large'; @@ -89,7 +91,8 @@ export class Doc2Vec { this.openai = new OpenAI({ apiKey: openaiApiKey }); this.embeddingModel = openaiModel; - this.logger.info(`Using OpenAI with model: ${openaiModel}`); + this.embeddingDimension = Utils.getEmbeddingDimension(openaiModel); + this.logger.info(`Using OpenAI with model: ${openaiModel} (${this.embeddingDimension} dimensions)`); } this.contentProcessor = new ContentProcessor(this.logger); @@ -397,7 +400,7 @@ export class Doc2Vec { const logger = parentLogger.child('process'); logger.info(`Starting processing for GitHub repo: ${config.repo}`); - const dbConnection = await DatabaseManager.initDatabase(config, logger); + const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension); // Initialize metadata storage await DatabaseManager.initDatabaseMetadata(dbConnection, logger); @@ -414,7 +417,7 @@ export class Doc2Vec { const logger = parentLogger.child('process'); logger.info(`Starting processing for website: ${config.url}`); - const dbConnection = await DatabaseManager.initDatabase(config, logger); + const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension); const validChunkIds: Set = new Set(); const visitedUrls: Set = new Set(); const urlPrefix = Utils.getUrlPrefix(config.url); @@ -565,7 +568,7 @@ export class Doc2Vec { const logger = parentLogger.child('process'); logger.info(`Starting processing for local directory: ${config.path}`); - const dbConnection = await DatabaseManager.initDatabase(config, logger); + const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension); const validChunkIds: Set = new Set(); const processedFiles: Set = new Set(); @@ -701,7 +704,7 @@ export class Doc2Vec { const logger = parentLogger.child('process'); logger.info(`Starting processing for code source (${config.source})`); - const dbConnection = await DatabaseManager.initDatabase(config, logger); + const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension); const validChunkIds: Set = new Set(); const processedFiles: Set = new Set(); @@ -918,10 +921,10 @@ export class Doc2Vec { } } - await DatabaseManager.setMetadataValue(dbConnection, fileListKey, JSON.stringify(currentList), logger); + await DatabaseManager.setMetadataValue(dbConnection, fileListKey, JSON.stringify(currentList), logger, this.embeddingDimension); if (lastMtimeKey) { const nextMtime = maxObservedMtime > 0 ? maxObservedMtime : Date.now(); - await DatabaseManager.setMetadataValue(dbConnection, lastMtimeKey, `${nextMtime}`, logger); + await DatabaseManager.setMetadataValue(dbConnection, lastMtimeKey, `${nextMtime}`, logger, this.embeddingDimension); } } } else { @@ -938,7 +941,7 @@ export class Doc2Vec { const headSha = await this.getRepoHeadSha(basePath, logger); if (headSha) { const shaKey = this.buildCodeShaMetadataKey(config.repo as string, repoBranch); - await DatabaseManager.setMetadataValue(dbConnection, shaKey, headSha, logger); + await DatabaseManager.setMetadataValue(dbConnection, shaKey, headSha, logger, this.embeddingDimension); } } @@ -1127,7 +1130,7 @@ export class Doc2Vec { const logger = parentLogger.child('process'); logger.info(`Starting processing for Zendesk: ${config.zendesk_subdomain}.zendesk.com`); - const dbConnection = await DatabaseManager.initDatabase(config, logger); + const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension); // Initialize metadata storage await DatabaseManager.initDatabaseMetadata(dbConnection, logger); diff --git a/package-lock.json b/package-lock.json index 38d6ac8..745263e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "doc2vec", - "version": "2.2.0", + "version": "2.4.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "doc2vec", - "version": "2.2.0", + "version": "2.4.0", "license": "ISC", "dependencies": { "@chonkiejs/core": "^0.0.7", diff --git a/package.json b/package.json index 27d5693..b42711c 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "doc2vec", - "version": "2.3.0", + "version": "2.4.0", "type": "commonjs", "description": "", "main": "dist/doc2vec.js", diff --git a/tests/database.test.ts b/tests/database.test.ts index 36ac539..e72db13 100644 --- a/tests/database.test.ts +++ b/tests/database.test.ts @@ -9,14 +9,17 @@ import * as path from 'path'; const testLogger = new Logger('test', { level: LogLevel.NONE }); +// Default embedding dimension for tests +const TEST_EMBEDDING_DIMENSION = 3072; + // Helper to create an in-memory SQLite database matching the app schema -function createTestDb(): BetterSqlite3.Database { +function createTestDb(embeddingDimension: number = TEST_EMBEDDING_DIMENSION): BetterSqlite3.Database { const db = new BetterSqlite3(':memory:', { allowExtension: true } as any); sqliteVec.load(db); db.exec(` CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0( - embedding FLOAT[3072], + embedding FLOAT[${embeddingDimension}], product_name TEXT, version TEXT, branch TEXT, @@ -60,8 +63,8 @@ function createTestChunk(overrides: Partial { @@ -115,14 +118,14 @@ describe('DatabaseManager', () => { }); it('should set and get metadata values', async () => { - await DatabaseManager.setMetadataValue(conn, 'mykey', 'myvalue', testLogger); + await DatabaseManager.setMetadataValue(conn, 'mykey', 'myvalue', testLogger, TEST_EMBEDDING_DIMENSION); const value = await DatabaseManager.getMetadataValue(conn, 'mykey', undefined, testLogger); expect(value).toBe('myvalue'); }); it('should upsert metadata values', async () => { - await DatabaseManager.setMetadataValue(conn, 'key1', 'value1', testLogger); - await DatabaseManager.setMetadataValue(conn, 'key1', 'value2', testLogger); + await DatabaseManager.setMetadataValue(conn, 'key1', 'value1', testLogger, TEST_EMBEDDING_DIMENSION); + await DatabaseManager.setMetadataValue(conn, 'key1', 'value2', testLogger, TEST_EMBEDDING_DIMENSION); const value = await DatabaseManager.getMetadataValue(conn, 'key1', undefined, testLogger); expect(value).toBe('value2'); }); @@ -595,12 +598,12 @@ describe('DatabaseManager', () => { createCollection: vi.fn().mockResolvedValue({}), }; - await DatabaseManager.createCollectionQdrant(mockClient as any, 'test_col', testLogger); + await DatabaseManager.createCollectionQdrant(mockClient as any, 'test_col', testLogger, TEST_EMBEDDING_DIMENSION); expect(mockClient.createCollection).toHaveBeenCalledOnce(); expect(mockClient.createCollection).toHaveBeenCalledWith('test_col', expect.objectContaining({ vectors: expect.objectContaining({ - size: 3072, + size: TEST_EMBEDDING_DIMENSION, distance: 'Cosine', }), })); @@ -707,7 +710,7 @@ describe('DatabaseManager', () => { type: 'qdrant', }; - await DatabaseManager.setMetadataValue(qdrantDb, 'test_key', 'test_value', testLogger); + await DatabaseManager.setMetadataValue(qdrantDb, 'test_key', 'test_value', testLogger, TEST_EMBEDDING_DIMENSION); expect(mockClient.upsert).toHaveBeenCalledOnce(); const call = mockClient.upsert.mock.calls[0]; @@ -806,7 +809,7 @@ describe('DatabaseManager', () => { sqliteVec.load(db); db.exec(` CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0( - embedding FLOAT[3072], + embedding FLOAT[${TEST_EMBEDDING_DIMENSION}], product_name TEXT, version TEXT, branch TEXT, @@ -904,7 +907,7 @@ describe('DatabaseManager', () => { } as SourceConfig; await expect( - DatabaseManager.initDatabase(config, testLogger) + DatabaseManager.initDatabase(config, testLogger, TEST_EMBEDDING_DIMENSION) ).rejects.toThrow('Unsupported database type: mongodb'); }); }); @@ -993,7 +996,7 @@ describe('DatabaseManager', () => { expect(point.payload.is_metadata).toBe(true); expect(point.payload.metadata_key).toBe('last_run_owner_repo'); expect(point.payload.metadata_value).toMatch(/^\d{4}-\d{2}-\d{2}T/); - expect(point.vector).toHaveLength(3072); + expect(point.vector).toHaveLength(TEST_EMBEDDING_DIMENSION); }); it('should handle upsert error gracefully', async () => { @@ -1214,7 +1217,7 @@ describe('DatabaseManager', () => { const conn: SqliteDB = { db: mockDb, type: 'sqlite' }; // Should not throw - error is caught internally - await DatabaseManager.setMetadataValue(conn, 'key', 'value', testLogger); + await DatabaseManager.setMetadataValue(conn, 'key', 'value', testLogger, TEST_EMBEDDING_DIMENSION); }); it('should handle Qdrant upsert error gracefully', async () => { @@ -1228,7 +1231,7 @@ describe('DatabaseManager', () => { }; // Should not throw - error is caught internally - await DatabaseManager.setMetadataValue(qdrantDb, 'key', 'value', testLogger); + await DatabaseManager.setMetadataValue(qdrantDb, 'key', 'value', testLogger, TEST_EMBEDDING_DIMENSION); }); }); diff --git a/tests/doc2vec.test.ts b/tests/doc2vec.test.ts index 9be0425..5a11b3c 100644 --- a/tests/doc2vec.test.ts +++ b/tests/doc2vec.test.ts @@ -132,6 +132,7 @@ vi.mock('../utils', () => ({ isValidUuid: vi.fn().mockReturnValue(false), hashToUuid: vi.fn().mockReturnValue('00000000-0000-0000-0000-000000000000'), getUrlPrefix: vi.fn().mockReturnValue('https://example.com'), + getEmbeddingDimension: vi.fn().mockReturnValue(3072), }, })); @@ -196,6 +197,9 @@ describe('Doc2Vec class', () => { // Provide a dummy API key so the constructor validation doesn't call process.exit process.env.OPENAI_API_KEY = 'test-key-for-tests'; + // Force OpenAI provider for tests (override any system default) + process.env.EMBEDDING_PROVIDER = 'openai'; + process.env.OPENAI_MODEL = 'text-embedding-3-large'; // Ensure test config directory exists if (!fs.existsSync(testConfigDir)) { @@ -215,6 +219,8 @@ describe('Doc2Vec class', () => { delete process.env.TEST_DOC2VEC_URL; delete process.env.TEST_DOC2VEC_API_KEY; delete process.env.OPENAI_API_KEY; + delete process.env.EMBEDDING_PROVIDER; + delete process.env.OPENAI_MODEL; }); // ───────────────────────────────────────────────────────────────────────── diff --git a/tests/mcp-server.test.ts b/tests/mcp-server.test.ts index 0a7f5c5..f5df9d3 100644 --- a/tests/mcp-server.test.ts +++ b/tests/mcp-server.test.ts @@ -18,6 +18,8 @@ import { DatabaseManager } from '../database'; import { Logger, LogLevel } from '../logger'; import type { WebsiteSourceConfig } from '../types'; +const TEST_EMBEDDING_DIMENSION = 3072; + describe('MCP server helpers', () => { it('normalizes extensions to lowercase and dot-prefixed', () => { expect(normalizeExtensions(['ts', '.JS', 'Md'])).toEqual(['.ts', '.js', '.md']); @@ -310,7 +312,7 @@ describe('MCP server end-to-end', () => { sqliteVec.load(db); db.exec(` CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0( - embedding FLOAT[3072], + embedding FLOAT[${TEST_EMBEDDING_DIMENSION}], product_name TEXT, version TEXT, branch TEXT, @@ -341,7 +343,7 @@ describe('MCP server end-to-end', () => { }; const chunks = await processor.chunkMarkdown(markdown, sourceConfig, baseUrl); - const embedding = new Array(3072).fill(0.1); + const embedding = new Array(TEST_EMBEDDING_DIMENSION).fill(0.1); for (const chunk of chunks) { chunk.metadata.branch = ''; chunk.metadata.repo = ''; diff --git a/tests/utils.test.ts b/tests/utils.test.ts index 49d595e..bb0b0e1 100644 --- a/tests/utils.test.ts +++ b/tests/utils.test.ts @@ -337,6 +337,34 @@ describe('Utils', () => { }); }); + // ─── getEmbeddingDimension ────────────────────────────────────── + describe('getEmbeddingDimension', () => { + it('should return 1536 for text-embedding-3-small', () => { + expect(Utils.getEmbeddingDimension('text-embedding-3-small')).toBe(1536); + }); + + it('should return 3072 for text-embedding-3-large', () => { + expect(Utils.getEmbeddingDimension('text-embedding-3-large')).toBe(3072); + }); + + it('should return 1536 for text-embedding-ada-002', () => { + expect(Utils.getEmbeddingDimension('text-embedding-ada-002')).toBe(1536); + }); + + it('should return 3072 for gemini models', () => { + expect(Utils.getEmbeddingDimension('gemini-embedding-001')).toBe(3072); + }); + + it('should be case-insensitive', () => { + expect(Utils.getEmbeddingDimension('TEXT-EMBEDDING-3-SMALL')).toBe(1536); + expect(Utils.getEmbeddingDimension('Text-Embedding-3-Large')).toBe(3072); + }); + + it('should return 1536 for unknown models', () => { + expect(Utils.getEmbeddingDimension('unknown-model')).toBe(1536); + }); + }); + // ─── shouldProcessUrl - invalid URL ───────────────────────────── describe('shouldProcessUrl - invalid URL', () => { it('should throw on invalid URL', () => { diff --git a/utils.ts b/utils.ts index 7419118..574391e 100644 --- a/utils.ts +++ b/utils.ts @@ -82,4 +82,36 @@ export class Utils { static tokenize(text: string): string[] { return text.split(/(\s+)/).filter(token => token.length > 0); } + + /** + * Get the embedding dimension for a given model name + * @param modelName The embedding model name (e.g., 'text-embedding-3-small', 'text-embedding-3-large') + * @returns The dimension size for the model + */ + static getEmbeddingDimension(modelName: string): number { + const modelLower = modelName.toLowerCase(); + + // OpenAI text-embedding-3-small produces 1536 dimensions + if (modelLower.includes('text-embedding-3-small')) { + return 1536; + } + + // OpenAI text-embedding-3-large and text-embedding-ada-002 produce 3072 and 1536 respectively + if (modelLower.includes('text-embedding-3-large')) { + return 3072; + } + + if (modelLower.includes('text-embedding-ada-002')) { + return 1536; + } + + // Gemini embedding models default to 3072 dimensions + if (modelLower.includes('gemini')) { + return 3072; + } + + // Default to 1536 for unknown models (most common) + console.warn(`Unknown embedding model: ${modelName}, defaulting to 1536 dimensions`); + return 1536; + } } \ No newline at end of file