Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
export class DatabaseManager {
private static columnCache: WeakMap<Database, { hasBranch: boolean; hasRepo: boolean }> = new WeakMap();

static async initDatabase(config: SourceConfig, parentLogger: Logger): Promise<DatabaseConnection> {
static async initDatabase(config: SourceConfig, parentLogger: Logger, embeddingDimension: number = 3072): Promise<DatabaseConnection> {
const logger = parentLogger.child('database');
const dbConfig = config.database_config;

Expand All @@ -32,10 +32,10 @@ export class DatabaseManager {
const db = new BetterSqlite3(dbPath, { allowExtension: true } as any);
sqliteVec.load(db);

logger.debug(`Creating vec_items table if it doesn't exist`);
logger.debug(`Creating vec_items table if it doesn't exist (dimension: ${embeddingDimension})`);
db.exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0(
embedding FLOAT[3072],
embedding FLOAT[${embeddingDimension}],
product_name TEXT,
version TEXT,
branch TEXT,
Expand All @@ -61,7 +61,7 @@ export class DatabaseManager {
logger.info(`Connecting to Qdrant at ${qdrantUrl}:${qdrantPort}, collection: ${collectionName}`);
const qdrantClient = new QdrantClient({ url: qdrantUrl, apiKey: process.env.QDRANT_API_KEY, port: qdrantPort });

await this.createCollectionQdrant(qdrantClient, collectionName, logger);
await this.createCollectionQdrant(qdrantClient, collectionName, logger, embeddingDimension);
logger.info(`Qdrant connection established successfully`);
return { client: qdrantClient, collectionName, type: 'qdrant' };
} else {
Expand All @@ -71,7 +71,7 @@ export class DatabaseManager {
}
}

static async createCollectionQdrant(qdrantClient: QdrantClient, collectionName: string, logger: Logger) {
static async createCollectionQdrant(qdrantClient: QdrantClient, collectionName: string, logger: Logger, embeddingDimension: number = 3072) {
try {
logger.debug(`Checking if collection ${collectionName} exists`);
const collections = await qdrantClient.getCollections();
Expand All @@ -84,10 +84,10 @@ export class DatabaseManager {
return;
}

logger.info(`Creating new collection ${collectionName}`);
logger.info(`Creating new collection ${collectionName} with dimension ${embeddingDimension}`);
await qdrantClient.createCollection(collectionName, {
vectors: {
size: 3072,
size: embeddingDimension,
distance: "Cosine",
},
});
Expand Down Expand Up @@ -177,7 +177,8 @@ export class DatabaseManager {
dbConnection: DatabaseConnection,
key: string,
value: string,
logger: Logger
logger: Logger,
embeddingDimension: number = 3072
): Promise<void> {
try {
if (dbConnection.type === 'sqlite') {
Expand All @@ -189,8 +190,7 @@ export class DatabaseManager {
logger.debug(`Updated metadata value for ${key}`);
} else if (dbConnection.type === 'qdrant') {
const metadataUUID = Utils.generateMetadataUUID(key);
const dummyEmbeddingSize = 3072;
const dummyEmbedding = new Array(dummyEmbeddingSize).fill(0);
const dummyEmbedding = new Array(embeddingDimension).fill(0);
const metadataPoint = {
id: metadataUUID,
vector: dummyEmbedding,
Expand Down Expand Up @@ -278,9 +278,19 @@ export class DatabaseManager {

logger.debug(`Using UUID: ${metadataUUID} for metadata`);

// Get the embedding dimension from the collection info
let embeddingDimension = 3072; // Default
try {
const collectionInfo = await dbConnection.client.getCollection(dbConnection.collectionName);
if (collectionInfo.config?.params?.vectors && 'size' in collectionInfo.config.params.vectors) {
embeddingDimension = collectionInfo.config.params.vectors.size;
}
} catch (error) {
logger.warn('Could not get collection info, using default dimension');
}

// Generate a dummy embedding (all zeros)
const dummyEmbeddingSize = 3072; // Same size as your content embeddings
const dummyEmbedding = new Array(dummyEmbeddingSize).fill(0);
const dummyEmbedding = new Array(embeddingDimension).fill(0);

// Create a point with special metadata payload
const metadataPoint = {
Expand Down
23 changes: 13 additions & 10 deletions doc2vec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export class Doc2Vec {
private config: Config;
private openai: OpenAI | AzureOpenAI;
private embeddingModel: string;
private embeddingDimension: number;
private contentProcessor: ContentProcessor;
private logger: Logger;
private configDir: string;
Expand Down Expand Up @@ -77,7 +78,8 @@ export class Doc2Vec {
apiVersion: azureApiVersion,
});
this.embeddingModel = azureDeploymentName;
this.logger.info(`Using Azure OpenAI with deployment: ${azureDeploymentName}`);
this.embeddingDimension = Utils.getEmbeddingDimension(azureDeploymentName);
this.logger.info(`Using Azure OpenAI with deployment: ${azureDeploymentName} (${this.embeddingDimension} dimensions)`);
} else {
const openaiApiKey = embeddingConfig.openai?.api_key || process.env.OPENAI_API_KEY;
const openaiModel = embeddingConfig.openai?.model || process.env.OPENAI_MODEL || 'text-embedding-3-large';
Expand All @@ -89,7 +91,8 @@ export class Doc2Vec {

this.openai = new OpenAI({ apiKey: openaiApiKey });
this.embeddingModel = openaiModel;
this.logger.info(`Using OpenAI with model: ${openaiModel}`);
this.embeddingDimension = Utils.getEmbeddingDimension(openaiModel);
this.logger.info(`Using OpenAI with model: ${openaiModel} (${this.embeddingDimension} dimensions)`);
}

this.contentProcessor = new ContentProcessor(this.logger);
Expand Down Expand Up @@ -397,7 +400,7 @@ export class Doc2Vec {
const logger = parentLogger.child('process');
logger.info(`Starting processing for GitHub repo: ${config.repo}`);

const dbConnection = await DatabaseManager.initDatabase(config, logger);
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);

// Initialize metadata storage
await DatabaseManager.initDatabaseMetadata(dbConnection, logger);
Expand All @@ -414,7 +417,7 @@ export class Doc2Vec {
const logger = parentLogger.child('process');
logger.info(`Starting processing for website: ${config.url}`);

const dbConnection = await DatabaseManager.initDatabase(config, logger);
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);
const validChunkIds: Set<string> = new Set();
const visitedUrls: Set<string> = new Set();
const urlPrefix = Utils.getUrlPrefix(config.url);
Expand Down Expand Up @@ -565,7 +568,7 @@ export class Doc2Vec {
const logger = parentLogger.child('process');
logger.info(`Starting processing for local directory: ${config.path}`);

const dbConnection = await DatabaseManager.initDatabase(config, logger);
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);
const validChunkIds: Set<string> = new Set();
const processedFiles: Set<string> = new Set();

Expand Down Expand Up @@ -701,7 +704,7 @@ export class Doc2Vec {
const logger = parentLogger.child('process');
logger.info(`Starting processing for code source (${config.source})`);

const dbConnection = await DatabaseManager.initDatabase(config, logger);
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);
const validChunkIds: Set<string> = new Set();
const processedFiles: Set<string> = new Set();

Expand Down Expand Up @@ -918,10 +921,10 @@ export class Doc2Vec {
}
}

await DatabaseManager.setMetadataValue(dbConnection, fileListKey, JSON.stringify(currentList), logger);
await DatabaseManager.setMetadataValue(dbConnection, fileListKey, JSON.stringify(currentList), logger, this.embeddingDimension);
if (lastMtimeKey) {
const nextMtime = maxObservedMtime > 0 ? maxObservedMtime : Date.now();
await DatabaseManager.setMetadataValue(dbConnection, lastMtimeKey, `${nextMtime}`, logger);
await DatabaseManager.setMetadataValue(dbConnection, lastMtimeKey, `${nextMtime}`, logger, this.embeddingDimension);
}
}
} else {
Expand All @@ -938,7 +941,7 @@ export class Doc2Vec {
const headSha = await this.getRepoHeadSha(basePath, logger);
if (headSha) {
const shaKey = this.buildCodeShaMetadataKey(config.repo as string, repoBranch);
await DatabaseManager.setMetadataValue(dbConnection, shaKey, headSha, logger);
await DatabaseManager.setMetadataValue(dbConnection, shaKey, headSha, logger, this.embeddingDimension);
}
}

Expand Down Expand Up @@ -1127,7 +1130,7 @@ export class Doc2Vec {
const logger = parentLogger.child('process');
logger.info(`Starting processing for Zendesk: ${config.zendesk_subdomain}.zendesk.com`);

const dbConnection = await DatabaseManager.initDatabase(config, logger);
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);

// Initialize metadata storage
await DatabaseManager.initDatabaseMetadata(dbConnection, logger);
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "doc2vec",
"version": "2.3.0",
"version": "2.4.0",
"type": "commonjs",
"description": "",
"main": "dist/doc2vec.js",
Expand Down
33 changes: 18 additions & 15 deletions tests/database.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ import * as path from 'path';

const testLogger = new Logger('test', { level: LogLevel.NONE });

// Default embedding dimension for tests
const TEST_EMBEDDING_DIMENSION = 3072;

// Helper to create an in-memory SQLite database matching the app schema
function createTestDb(): BetterSqlite3.Database {
function createTestDb(embeddingDimension: number = TEST_EMBEDDING_DIMENSION): BetterSqlite3.Database {
const db = new BetterSqlite3(':memory:', { allowExtension: true } as any);
sqliteVec.load(db);

db.exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0(
embedding FLOAT[3072],
embedding FLOAT[${embeddingDimension}],
product_name TEXT,
version TEXT,
branch TEXT,
Expand Down Expand Up @@ -60,8 +63,8 @@ function createTestChunk(overrides: Partial<DocumentChunk & { metadata?: Partial
};
}

function createTestEmbedding(): number[] {
return new Array(3072).fill(0.1);
function createTestEmbedding(embeddingDimension: number = TEST_EMBEDDING_DIMENSION): number[] {
return new Array(embeddingDimension).fill(0.1);
}

describe('DatabaseManager', () => {
Expand Down Expand Up @@ -115,14 +118,14 @@ describe('DatabaseManager', () => {
});

it('should set and get metadata values', async () => {
await DatabaseManager.setMetadataValue(conn, 'mykey', 'myvalue', testLogger);
await DatabaseManager.setMetadataValue(conn, 'mykey', 'myvalue', testLogger, TEST_EMBEDDING_DIMENSION);
const value = await DatabaseManager.getMetadataValue(conn, 'mykey', undefined, testLogger);
expect(value).toBe('myvalue');
});

it('should upsert metadata values', async () => {
await DatabaseManager.setMetadataValue(conn, 'key1', 'value1', testLogger);
await DatabaseManager.setMetadataValue(conn, 'key1', 'value2', testLogger);
await DatabaseManager.setMetadataValue(conn, 'key1', 'value1', testLogger, TEST_EMBEDDING_DIMENSION);
await DatabaseManager.setMetadataValue(conn, 'key1', 'value2', testLogger, TEST_EMBEDDING_DIMENSION);
const value = await DatabaseManager.getMetadataValue(conn, 'key1', undefined, testLogger);
expect(value).toBe('value2');
});
Expand Down Expand Up @@ -595,12 +598,12 @@ describe('DatabaseManager', () => {
createCollection: vi.fn().mockResolvedValue({}),
};

await DatabaseManager.createCollectionQdrant(mockClient as any, 'test_col', testLogger);
await DatabaseManager.createCollectionQdrant(mockClient as any, 'test_col', testLogger, TEST_EMBEDDING_DIMENSION);

expect(mockClient.createCollection).toHaveBeenCalledOnce();
expect(mockClient.createCollection).toHaveBeenCalledWith('test_col', expect.objectContaining({
vectors: expect.objectContaining({
size: 3072,
size: TEST_EMBEDDING_DIMENSION,
distance: 'Cosine',
}),
}));
Expand Down Expand Up @@ -707,7 +710,7 @@ describe('DatabaseManager', () => {
type: 'qdrant',
};

await DatabaseManager.setMetadataValue(qdrantDb, 'test_key', 'test_value', testLogger);
await DatabaseManager.setMetadataValue(qdrantDb, 'test_key', 'test_value', testLogger, TEST_EMBEDDING_DIMENSION);

expect(mockClient.upsert).toHaveBeenCalledOnce();
const call = mockClient.upsert.mock.calls[0];
Expand Down Expand Up @@ -806,7 +809,7 @@ describe('DatabaseManager', () => {
sqliteVec.load(db);
db.exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0(
embedding FLOAT[3072],
embedding FLOAT[${TEST_EMBEDDING_DIMENSION}],
product_name TEXT,
version TEXT,
branch TEXT,
Expand Down Expand Up @@ -904,7 +907,7 @@ describe('DatabaseManager', () => {
} as SourceConfig;

await expect(
DatabaseManager.initDatabase(config, testLogger)
DatabaseManager.initDatabase(config, testLogger, TEST_EMBEDDING_DIMENSION)
).rejects.toThrow('Unsupported database type: mongodb');
});
});
Expand Down Expand Up @@ -993,7 +996,7 @@ describe('DatabaseManager', () => {
expect(point.payload.is_metadata).toBe(true);
expect(point.payload.metadata_key).toBe('last_run_owner_repo');
expect(point.payload.metadata_value).toMatch(/^\d{4}-\d{2}-\d{2}T/);
expect(point.vector).toHaveLength(3072);
expect(point.vector).toHaveLength(TEST_EMBEDDING_DIMENSION);
});

it('should handle upsert error gracefully', async () => {
Expand Down Expand Up @@ -1214,7 +1217,7 @@ describe('DatabaseManager', () => {
const conn: SqliteDB = { db: mockDb, type: 'sqlite' };

// Should not throw - error is caught internally
await DatabaseManager.setMetadataValue(conn, 'key', 'value', testLogger);
await DatabaseManager.setMetadataValue(conn, 'key', 'value', testLogger, TEST_EMBEDDING_DIMENSION);
});

it('should handle Qdrant upsert error gracefully', async () => {
Expand All @@ -1228,7 +1231,7 @@ describe('DatabaseManager', () => {
};

// Should not throw - error is caught internally
await DatabaseManager.setMetadataValue(qdrantDb, 'key', 'value', testLogger);
await DatabaseManager.setMetadataValue(qdrantDb, 'key', 'value', testLogger, TEST_EMBEDDING_DIMENSION);
});
});

Expand Down
6 changes: 6 additions & 0 deletions tests/doc2vec.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ vi.mock('../utils', () => ({
isValidUuid: vi.fn().mockReturnValue(false),
hashToUuid: vi.fn().mockReturnValue('00000000-0000-0000-0000-000000000000'),
getUrlPrefix: vi.fn().mockReturnValue('https://example.com'),
getEmbeddingDimension: vi.fn().mockReturnValue(3072),
},
}));

Expand Down Expand Up @@ -196,6 +197,9 @@ describe('Doc2Vec class', () => {

// Provide a dummy API key so the constructor validation doesn't call process.exit
process.env.OPENAI_API_KEY = 'test-key-for-tests';
// Force OpenAI provider for tests (override any system default)
process.env.EMBEDDING_PROVIDER = 'openai';
process.env.OPENAI_MODEL = 'text-embedding-3-large';

// Ensure test config directory exists
if (!fs.existsSync(testConfigDir)) {
Expand All @@ -215,6 +219,8 @@ describe('Doc2Vec class', () => {
delete process.env.TEST_DOC2VEC_URL;
delete process.env.TEST_DOC2VEC_API_KEY;
delete process.env.OPENAI_API_KEY;
delete process.env.EMBEDDING_PROVIDER;
delete process.env.OPENAI_MODEL;
});

// ─────────────────────────────────────────────────────────────────────────
Expand Down
6 changes: 4 additions & 2 deletions tests/mcp-server.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import { DatabaseManager } from '../database';
import { Logger, LogLevel } from '../logger';
import type { WebsiteSourceConfig } from '../types';

const TEST_EMBEDDING_DIMENSION = 3072;

describe('MCP server helpers', () => {
it('normalizes extensions to lowercase and dot-prefixed', () => {
expect(normalizeExtensions(['ts', '.JS', 'Md'])).toEqual(['.ts', '.js', '.md']);
Expand Down Expand Up @@ -310,7 +312,7 @@ describe('MCP server end-to-end', () => {
sqliteVec.load(db);
db.exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0(
embedding FLOAT[3072],
embedding FLOAT[${TEST_EMBEDDING_DIMENSION}],
product_name TEXT,
version TEXT,
branch TEXT,
Expand Down Expand Up @@ -341,7 +343,7 @@ describe('MCP server end-to-end', () => {
};

const chunks = await processor.chunkMarkdown(markdown, sourceConfig, baseUrl);
const embedding = new Array(3072).fill(0.1);
const embedding = new Array(TEST_EMBEDDING_DIMENSION).fill(0.1);
for (const chunk of chunks) {
chunk.metadata.branch = '';
chunk.metadata.repo = '';
Expand Down
Loading