7 changes: 7 additions & 0 deletions MaxKernel/.gitignore
@@ -0,0 +1,7 @@
__pycache__/
session_info/
.env
*egg-info
*.txt
.adk
eval_config.yaml
317 changes: 244 additions & 73 deletions MaxKernel/README.md
@@ -1,115 +1,286 @@
# Human-in-the-Loop (HITL) Kernel Generation Agent

An intelligent, interactive agent system for generating, optimizing, testing, and profiling TPU kernels with JAX/Pallas. This agent orchestrates a multi-stage workflow that keeps you in control at every step, from initial planning through implementation, testing, and performance optimization.

## Overview

The HITL Kernel Gen Agent provides a conversational interface for TPU kernel development with:

- **Plan-Driven Development**: Creates detailed optimization plans before implementation, allowing you to review and refine the approach
- **Automated Testing**: Generates and executes comprehensive pytest test suites with compilation, correctness, and performance validation
- **Performance Profiling**: Identifies bottlenecks and provides data-driven optimization recommendations
- **GPU-to-JAX Conversion**: Automatically converts CUDA, Triton, and PyTorch GPU code to JAX/Pallas
- **RAG-Enhanced**: Leverages documentation retrieval for accurate, context-aware code generation
- **Safety-First**: Scoped file system access with configurable work directories

## Features

### 🎯 Core Capabilities

1. **Interactive Kernel Planning**
- Creates detailed optimization plans for Pallas kernels
- Automatic approval workflow with revision support
- Includes tiling strategies, memory optimization, and performance targets

2. **Kernel Implementation**
- Implements kernels following approved plans
- Supports various optimization techniques (tiling, pipelining, memory management)
- Generates clean, idiomatic JAX/Pallas code

3. **Comprehensive Testing**
- Automatic pytest test file generation
- Compilation validation
- Numerical correctness testing
- Performance benchmarking
- Full traceback reporting for debugging

4. **Performance Profiling**
- DMA and memory transfer analysis
- Compute vs memory ratio profiling
- Bottleneck identification with actionable recommendations

5. **GPU-to-JAX Conversion**
- Converts CUDA, Triton, PyTorch CUDA code to JAX
- Strips hardware-specific optimizations
- Includes syntax validation and numerical correctness testing
- See [GPU-to-JAX Agent README](gpu_to_jax_agent/README.md) for details
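
The compute-vs-memory profiling in item 4 starts from arithmetic intensity (FLOPs per byte moved). A minimal sketch for a matmul, assuming fp32 operands and that each tensor crosses the memory boundary exactly once (illustrative, not the profiler's actual code):

```python
def arithmetic_intensity(m, k, n, bytes_per_elem=4):
    """FLOPs per byte for an (m, k) x (k, n) matmul, assuming each
    operand and the output is moved exactly once."""
    flops = 2 * m * k * n  # one multiply + one add per MAC
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved
```

If this ratio is low relative to the TPU's peak FLOP/byte ratio, the kernel is likely memory-bound, which is exactly where the tiling and pipelining strategies in the generated plans pay off.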

### 🛡️ Safety & Control

- **Scoped Permissions**: Agent operates only within designated work directory
- **User Approval Required**: All implementations require explicit plan approval
- **Transparent Operations**: All file operations are logged and visible
- **Session Persistence**: Save and resume your work across sessions

## Architecture

### Agent Hierarchy

```
KernelGenerationOrchestrationAgent (root_agent)
├── ExplanationAgent - Explains TPU/Pallas concepts
├── PlanKernelAgent - Creates/revises optimization plans
├── ImplementKernelAgent - Implements approved plans
├── ValidatedTestGenerationAgent
│ ├── GenerateTestFileAgent - Creates pytest test files
│ ├── TestValidationLoopAgent - Validates test syntax/structure
│ └── ValidationSummaryAgent - Reports validation results
├── UnifiedTestAgent
│ ├── ReadFileForTestingAgent - Locates test files
│ ├── RunTestsAgent - Executes pytest with server management
│ └── SummarizeTestResultsAgent - Analyzes and reports results
├── ProfileAgentOrchestrator
│ ├── ReadFileForProfilingAgent - Locates kernel files
│ ├── GenerateProfilingScriptAgent - Creates profiling scripts
│ ├── EvalProfileAgent - Executes profiling
│ └── SummarizeProfileAgent - Analyzes bottlenecks
└── GpuToJaxAgent - GPU-to-JAX conversion pipeline
└── (10-step conversion pipeline - see gpu_to_jax_agent/README.md)
```
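
The plan → review → implement flow at the top of this hierarchy can be sketched as a simple loop (illustrative only; the real orchestration is built from the ADK agents above, and these function names are hypothetical):

```python
def hitl_generate(request, plan_agent, implement_agent, ask_user):
    """Hypothetical HITL loop: replan until the user approves, then implement."""
    plan = plan_agent(request, feedback=None)
    while True:
        verdict = ask_user(plan)  # "approve" or free-form revision feedback
        if verdict == "approve":
            return implement_agent(plan)
        plan = plan_agent(request, feedback=verdict)
```

The key property is that `implement_agent` is never reached without an explicit approval from the user.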

### Directory Structure

```
hitl_agent/
├── hitl_agent
│ ├── agent.py # Main orchestration logic
│ ├── callbacks.py # Agent callbacks
│ ├── config.py # Configuration management
│ ├── constants.py # Agent constants
│ ├── custom_types.py # Custom types
│ ├── dependency # Agent dependencies
│ │ ├── adk_cli_patch.py
│ │ ├── agent_requirements.txt
│ │ └── main_requirements.txt
│ ├── isolate_object.py
│ ├── knowledge_base # Knowledge base for Pallas docs
│ │ ├── pallas_docs.py
│ │ └── pallas_profiling_docs.py
│ ├── prompts # Main interactive prompt
│ │ ├── __init__.py
│ │ └── interactive_prompt.py
│ ├── server_utils # Server management utilities
│ │ ├── __init__.py
│ │ ├── cpu_server.py
│ │ ├── eval_config.yaml
│ │ ├── eval_server.py
│ │ ├── server_manager_mixin.py
│ │ ├── setup.sh
│ │ └── tpu_server.py
│ ├── subagents # Specialized subagents
│ │ ├── __init__.py
│ │ ├── explanation # Explanation agent
│ │ │ ├── __init__.py
│ │ │ ├── agent.py
│ │ │ └── prompts
│ │ ├── gpu_to_jax_agent # GPU-to-JAX conversion subagent
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── agent.py
│ │ │ ├── constants.py
│ │ │ ├── evaluators
│ │ │ │ ├── __init__.py
│ │ │ │ ├── compilation_checker.py
│ │ │ │ ├── correctness_checker.py
│ │ │ │ ├── jax_syntax_checker.py
│ │ │ │ └── shape_validator.py
│ │ │ ├── prompts
│ │ │ └── test_agent.py
│ │ ├── kernel_writing # Kernel planning & implementation
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-310.pyc
│ │ │ │ ├── agent.cpython-310.pyc
│ │ │ │ └── kernel_compilation.cpython-310.pyc
│ │ │ ├── agent.py
│ │ │ ├── kernel_compilation.py
│ │ │ └── prompts
│ │ ├── profiling # Performance profiling
│ │ │ ├── __init__.py
│ │ │ ├── agent.py
│ │ │ ├── kernel_profile.py
│ │ │ ├── offline_tools.py
│ │ │ └── prompts
│ │ └── testing # Test generation & execution
│ │ ├── __init__.py
│ │ ├── agent.py
│ │ └── prompts
│ ├── tests
│ │ ├── conftest.py
│ │ ├── test_compilation_validation_loop.py
│ │ └── test_validate_kernel_compilation_agent.py
│ ├── tools # Agent tools
│ │ ├── __init__.py
│ │ ├── analyze_profile.py
│ │ ├── api_rag
│ │ │ ├── __init__.py
│ │ │ └── get_apis.py
│ │ ├── search_api_tool.py
│ │ └── tools.py
│ └── tpu_specs.json
├── prepare_hitl_agent.sh # Setup script
├── run_hitl_agent.sh # Launch script (CLI or UI mode)
└── setup.py
```

## Getting Started

### Prerequisites

1. **Python Environment**: Python 3.9+ with JAX and dependencies installed
2. **Google Cloud**: Vertex AI access for the agent and RAG retrieval
3. **TPU Access**: For actual kernel execution and testing

### Installation

1. **Navigate to the directory containing this README**

2. **Run the setup script**:
```bash
bash prepare_hitl_agent.sh
```

This script will:
- Prompt you to choose Python environment setup (Miniconda, venv, or your own)
- Install required dependencies
- Set up environment variables
- Configure your work directory
- Create the `.env` file with your settings

3. **Configure Environment Variables**:

The setup script creates a `.env` file. You can edit it manually to customize:

```bash
# Required
GOOGLE_CLOUD_PROJECT=your-project-id
GOOGLE_GENAI_API_KEY=your-api-key

# Optional - defaults provided
WORKDIR=/path/to/your/work/directory # Default: example_workdir
TPU_VERSION=v5e # Default: v5e
SESSION_ID=hitl_session # Default: hitl_session

# RAG Configuration (optional)
VERTEX_AI_RAG_CORPUS=your-corpus-name
GOOGLE_CLOUD_REGION=us-central1
```
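
A loader for these variables might look roughly like the following (a sketch mirroring the documented defaults; the project's actual `config.py` may differ):

```python
import os

def load_config():
    """Read agent settings from the environment, with the documented defaults."""
    return {
        "project": os.environ["GOOGLE_CLOUD_PROJECT"],  # required
        "api_key": os.environ["GOOGLE_GENAI_API_KEY"],  # required
        "workdir": os.environ.get("WORKDIR", "example_workdir"),
        "tpu_version": os.environ.get("TPU_VERSION", "v5e"),
        "session_id": os.environ.get("SESSION_ID", "hitl_session"),
    }
```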

### Running the Agent

#### Option 1: CLI Mode (Recommended for Development)

```bash
bash run_hitl_agent.sh
```

**CLI Features**:
- Interactive command-line interface
- Session stored as JSON files (`*.session.json`)
- Easy to debug and inspect
- Lower overhead

**CLI Options**:
```bash
# Start with default session
bash run_hitl_agent.sh

# Start with specific session ID
bash run_hitl_agent.sh --session my_session

# Reset and start fresh
bash run_hitl_agent.sh --reset
```

#### Option 2: Web UI Mode

```bash
bash run_hitl_agent.sh --ui
```

**UI Features**:
- Web interface on port 1430
- Session stored in SQLite database
- Visual conversation history
- Better for demos and non-technical users

**Important**: CLI and UI modes use different session storage mechanisms and **cannot share sessions**.

### Resuming Sessions

**CLI Mode**:
```bash
# Sessions are saved as JSON files
bash run_hitl_agent.sh --session my_session
```
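
Because CLI sessions are plain JSON files, they are easy to inspect with stdlib tooling. A sketch that assumes only a top-level object with an `events` array (the actual ADK session schema may differ):

```python
import json

def count_events(session_path):
    """Return the number of recorded events in a CLI session file.
    Assumes a top-level JSON object with an 'events' array."""
    with open(session_path) as f:
        session = json.load(f)
    return len(session.get("events", []))
```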

**UI Mode**:
```bash
# Sessions managed through web interface
bash run_hitl_agent.sh --ui
# Browse to http://localhost:1430 and select session
```

## Work Directory Configuration

### What is a Work Directory?

The work directory is where the agent:
- Reads your input files (kernels, GPU code, etc.)
- Writes generated files (plans, implementations, tests, profiles)
- Executes operations (testing, profiling)

**Security**: The agent's file access is **scoped** to this directory - it cannot access files outside it.
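
One common way to enforce this kind of scoping is to resolve every requested path against the work directory and reject escapes; a minimal sketch (illustrative, not the agent's actual implementation; `WORKDIR` here is a placeholder):

```python
from pathlib import Path

WORKDIR = Path("/path/to/workdir").resolve()  # placeholder; set from the WORKDIR env var

def resolve_scoped(relative_path: str) -> Path:
    """Resolve a path inside WORKDIR, rejecting `..` escapes and absolute paths."""
    candidate = (WORKDIR / relative_path).resolve()
    if not candidate.is_relative_to(WORKDIR):  # Python 3.9+
        raise PermissionError(f"{relative_path!r} escapes the work directory")
    return candidate
```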

### Setting Up Your Work Directory

1. **During Initial Setup**:
```bash
bash prepare_hitl_agent.sh
# Follow prompts to set WORKDIR
```

2. **Manual Configuration**:
Edit the generated `.env` file:
```bash
WORKDIR=/absolute/path/to/your/work/directory
```
1 change: 1 addition & 0 deletions MaxKernel/hitl_agent/__init__.py
@@ -0,0 +1 @@
