Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 185 additions & 1 deletion apps/worker/src/__tests__/test-runner.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { mkdtempSync, writeFileSync, mkdirSync, rmSync } from 'fs'
import { join } from 'path'
import { tmpdir } from 'os'
import { existsSync, readFileSync } from 'fs'
import { detectTestRunner, detectPackageManager, detectMonorepo, runTests, stripLocalUvSources } from '../test-runner.js'
import { detectTestRunner, detectPackageManager, detectMonorepo, runTests, stripLocalUvSources, stripHeavyPyDeps } from '../test-runner.js'

// Helper: create a temp directory with specific files
function createTempDir(): string {
Expand Down Expand Up @@ -351,6 +351,190 @@ pkg-c = { path = "/absolute/path/to/pkg-c" }
})
})

describe('stripHeavyPyDeps', () => {
let tempDir: string

beforeEach(() => {
tempDir = createTempDir()
})

afterEach(() => {
rmSync(tempDir, { recursive: true, force: true })
})

it('strips heavy packages from base dependencies', () => {
const pyprojectContent = `[project]
name = "test-project"
dependencies = [
"requests>=2.28.0",
"torch>=2.8.0",
"torchvision>=0.24.1",
"pillow>=10.0.0",
"open-clip-torch>=2.20.0",
"transformers>=4.57.3",
"bitsandbytes>=0.41.0",
"peft>=0.18.0",
]

[tool.hatch.build.targets.wheel]
packages = ["my_package"]
`
touchFile(tempDir, 'pyproject.toml', pyprojectContent)
touchFile(tempDir, 'uv.lock', 'some lock content')

stripHeavyPyDeps(tempDir)

const result = readFileSync(join(tempDir, 'pyproject.toml'), 'utf-8')
// Heavy packages should be removed
expect(result).not.toContain('torch>=2.8.0')
expect(result).not.toContain('torchvision')
expect(result).not.toContain('open-clip-torch')
expect(result).not.toContain('transformers')
expect(result).not.toContain('bitsandbytes')
expect(result).not.toContain('peft')
// Lightweight packages should be preserved
expect(result).toContain('requests>=2.28.0')
expect(result).toContain('pillow>=10.0.0')
// Other sections should be preserved
expect(result).toContain('[project]')
expect(result).toContain('[tool.hatch.build.targets.wheel]')
// uv.lock should be deleted
expect(existsSync(join(tempDir, 'uv.lock'))).toBe(false)
})

it('strips entire [project.optional-dependencies] section', () => {
const pyprojectContent = `[project]
name = "test-project"
dependencies = [
"requests>=2.28.0",
]

[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
]
training = [
"torch>=2.8.0",
"trl>=0.12.0",
]
azure = [
"azure-ai-ml>=1.12.0",
]

[tool.hatch.build.targets.wheel]
packages = ["my_package"]
`
touchFile(tempDir, 'pyproject.toml', pyprojectContent)

stripHeavyPyDeps(tempDir)

const result = readFileSync(join(tempDir, 'pyproject.toml'), 'utf-8')
// Entire optional-dependencies section should be removed
expect(result).not.toContain('[project.optional-dependencies]')
expect(result).not.toContain('pytest>=8.0.0')
expect(result).not.toContain('trl>=0.12.0')
expect(result).not.toContain('azure-ai-ml')
// Base deps and other sections should be preserved
expect(result).toContain('requests>=2.28.0')
expect(result).toContain('[tool.hatch.build.targets.wheel]')
})

it('handles nvidia-* prefix packages', () => {
const pyprojectContent = `[project]
name = "test-project"
dependencies = [
"requests>=2.28.0",
"nvidia-cublas-cu12>=12.1.0",
"nvidia-cuda-runtime-cu12>=12.0",
]
`
touchFile(tempDir, 'pyproject.toml', pyprojectContent)

stripHeavyPyDeps(tempDir)

const result = readFileSync(join(tempDir, 'pyproject.toml'), 'utf-8')
expect(result).not.toContain('nvidia-cublas')
expect(result).not.toContain('nvidia-cuda-runtime')
expect(result).toContain('requests>=2.28.0')
})

it('preserves pyproject.toml with no heavy deps', () => {
const pyprojectContent = `[project]
name = "test-project"
dependencies = [
"requests>=2.28.0",
"pillow>=10.0.0",
]
`
touchFile(tempDir, 'pyproject.toml', pyprojectContent)
touchFile(tempDir, 'uv.lock', 'some lock content')

stripHeavyPyDeps(tempDir)

const result = readFileSync(join(tempDir, 'pyproject.toml'), 'utf-8')
expect(result).toBe(pyprojectContent)
// uv.lock should NOT be deleted when no changes were made
expect(existsSync(join(tempDir, 'uv.lock'))).toBe(true)
})

it('handles openadapt-evals-like pyproject.toml', () => {
const pyprojectContent = `[project]
name = "openadapt-evals"
version = "0.46.0"
dependencies = [
"open-clip-torch>=2.20.0",
"pillow>=10.0.0",
"pydantic-settings>=2.0.0",
"requests>=2.28.0",
"openai>=1.0.0",
"anthropic>=0.76.0",
"openadapt-ml>=0.11.0",
]

[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"ruff>=0.1.0",
]
training = [
"imagehash>=4.3.0",
]
verl = [
"verl>=0.3.0",
]

[tool.uv.sources]
openadapt-ml = { path = "../openadapt-ml", editable = true }

[tool.hatch.build.targets.wheel]
packages = ["openadapt_evals"]
`
touchFile(tempDir, 'pyproject.toml', pyprojectContent)

stripHeavyPyDeps(tempDir)

const result = readFileSync(join(tempDir, 'pyproject.toml'), 'utf-8')
// Heavy base dep should be stripped
expect(result).not.toContain('open-clip-torch')
// Optional deps section entirely removed
expect(result).not.toContain('[project.optional-dependencies]')
expect(result).not.toContain('verl')
// Lightweight base deps preserved
expect(result).toContain('pillow>=10.0.0')
expect(result).toContain('requests>=2.28.0')
expect(result).toContain('openai>=1.0.0')
expect(result).toContain('anthropic>=0.76.0')
expect(result).toContain('openadapt-ml>=0.11.0')
// Other sections preserved
expect(result).toContain('[tool.uv.sources]')
expect(result).toContain('[tool.hatch.build.targets.wheel]')
})

it('does nothing when pyproject.toml does not exist', () => {
expect(() => stripHeavyPyDeps(tempDir)).not.toThrow()
})
})

describe('runTests with real commands', () => {
let tempDir: string

Expand Down
162 changes: 159 additions & 3 deletions apps/worker/src/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,159 @@ export function stripLocalUvSources(workDir: string): void {
}
}

/**
* Known heavy Python packages that should be stripped from dependencies
* on the worker. These packages (CUDA, PyTorch, large ML frameworks) are
* 100MB-2GB each and are not needed to run tests.
*
* Patterns are matched against the package name portion of dependency lines
* (before any version specifier). Case-insensitive.
*/
const HEAVY_PY_PACKAGES = [
'torch',
'torchvision',
'torchaudio',
'open-clip-torch',
'bitsandbytes',
'triton',
'nvidia-', // nvidia-cublas-cu12, nvidia-cuda-runtime-cu12, etc.
'cu12', // standalone CUDA 12 packages
'cu11', // standalone CUDA 11 packages
'xformers',
'flash-attn',
'deepspeed',
'apex',
'vllm',
'transformers',
'accelerate',
'peft',
'safetensors',
'sentencepiece',
'tokenizers',
]

/**
* Check if a dependency line references a heavy package.
*
* Dependency lines look like:
* "torch>=2.8.0"
* "open-clip-torch>=2.20.0"
* "nvidia-cublas-cu12>=12.1.0"
*
* We match the package name (everything before `>=`, `==`, `~=`, `<`, `>`, `[`, etc.)
* against the HEAVY_PY_PACKAGES patterns.
*/
function isHeavyDep(depLine: string): boolean {
const trimmed = depLine.trim().replace(/^["']|["'],?\s*$/g, '')
if (!trimmed || trimmed.startsWith('#')) return false

// Extract package name (before version specifier)
const pkgName = trimmed.split(/[>=<!~\[;]/)[0].trim().toLowerCase()
if (!pkgName) return false

return HEAVY_PY_PACKAGES.some(pattern => {
const p = pattern.toLowerCase()
// If pattern ends with '-', match as prefix
if (p.endsWith('-')) {
return pkgName.startsWith(p)
}
return pkgName === p
})
}

/**
* Strip heavy ML/CUDA dependencies from pyproject.toml to keep installs lightweight.
*
* The Wright worker runs tests, not training. Heavy packages like PyTorch (2GB+),
* CUDA libraries, and large ML frameworks slow down installs, eat disk space, and
* may fail entirely on non-GPU containers.
*
* This function:
* 1. Removes known heavy packages from `[project] dependencies = [...]`
* 2. Removes the entire `[project.optional-dependencies]` section (all groups).
* Optional deps are already skipped by `uv sync --no-dev`, but some groups
* may be pulled in transitively or via `all = [...]` meta-groups.
*
* Combined with `--inexact`, missing transitive deps from stripped packages are
* tolerated — uv will install what it can and skip the rest.
*/
export function stripHeavyPyDeps(workDir: string): void {
const pyprojectPath = join(workDir, 'pyproject.toml')
if (!existsSync(pyprojectPath)) return

const content = readFileSync(pyprojectPath, 'utf-8')
const lines = content.split('\n')
const outputLines: string[] = []
let inDeps = false
let inOptDeps = false
let inOptGroup = false
let strippedAny = false
let bracketDepth = 0

for (let i = 0; i < lines.length; i++) {
const line = lines[i]
const trimmed = line.trim()

// Track [project.optional-dependencies] section — skip it entirely
if (trimmed === '[project.optional-dependencies]') {
inOptDeps = true
inOptGroup = false
strippedAny = true
console.log('[test-runner] Stripping [project.optional-dependencies] section')
continue
}

// If we're in optional-dependencies, skip until we hit a new top-level section
if (inOptDeps) {
if (trimmed.startsWith('[') && trimmed.endsWith(']') && trimmed !== '[project.optional-dependencies]') {
inOptDeps = false
// Fall through to process this line normally
} else {
continue
}
}

// Track `dependencies = [` in [project] section
if (/^dependencies\s*=\s*\[/.test(trimmed)) {
inDeps = true
bracketDepth = (line.match(/\[/g) || []).length - (line.match(/\]/g) || []).length
outputLines.push(line)
continue
}

if (inDeps) {
// Track bracket depth (handles multi-line arrays)
bracketDepth += (line.match(/\[/g) || []).length - (line.match(/\]/g) || []).length
if (bracketDepth <= 0) {
inDeps = false
outputLines.push(line)
continue
}

// Check if this dependency line is heavy
if (isHeavyDep(trimmed)) {
strippedAny = true
console.log(`[test-runner] Stripped heavy dependency: ${trimmed}`)
continue
}
}

outputLines.push(line)
}

if (strippedAny) {
writeFileSync(pyprojectPath, outputLines.join('\n'))
console.log('[test-runner] Removed heavy dependencies from pyproject.toml')

// Delete uv.lock so uv regenerates it without the heavy deps
const lockPath = join(workDir, 'uv.lock')
if (existsSync(lockPath)) {
unlinkSync(lockPath)
console.log('[test-runner] Removed stale uv.lock (will regenerate)')
}
}
}

/**
* Install dependencies using the detected package manager.
*/
Expand All @@ -215,7 +368,7 @@ export function installDependencies(workDir: string, pm: PackageManager): void {
pnpm: 'pnpm install',
yarn: 'yarn install',
pip: 'pip install -e .',
uv: 'uv sync',
uv: 'uv sync --no-dev --inexact',
poetry: 'poetry install',
cargo: 'cargo build',
go: 'go mod download',
Expand All @@ -225,10 +378,13 @@ export function installDependencies(workDir: string, pm: PackageManager): void {
const cmd = commands[pm]
if (!cmd) return

// For uv projects, strip local path sources from pyproject.toml that reference
// sibling directories (e.g. "../openadapt-ml") which won't exist on the worker.
// For uv projects, strip local path sources and heavy ML dependencies from
// pyproject.toml. Local paths reference sibling directories (e.g. "../openadapt-ml")
// that won't exist on the worker. Heavy deps (PyTorch, CUDA, etc.) are 100MB-2GB
// each and not needed to run tests.
if (pm === 'uv') {
stripLocalUvSources(workDir)
stripHeavyPyDeps(workDir)
}

console.log(`[test-runner] Installing dependencies with ${pm}: ${cmd}`)
Expand Down
Loading