Merged
7 changes: 0 additions & 7 deletions MaxKernel/.gitignore

This file was deleted.

317 changes: 73 additions & 244 deletions MaxKernel/README.md
@@ -1,286 +1,115 @@
**Removed:**

# Human-in-the-Loop (HITL) Kernel Generation Agent

An intelligent, interactive agent system for generating, optimizing, testing, and profiling TPU kernels with JAX/Pallas. This agent orchestrates a multi-stage workflow that keeps you in control at every step, from initial planning through implementation, testing, and performance optimization.

## Overview

The HITL Kernel Gen Agent provides a conversational interface for TPU kernel development with:

- **Plan-Driven Development**: Creates detailed optimization plans before implementation, allowing you to review and refine the approach
- **Automated Testing**: Generates and executes comprehensive pytest test suites with compilation, correctness, and performance validation
- **Performance Profiling**: Identifies bottlenecks and provides data-driven optimization recommendations
- **GPU-to-JAX Conversion**: Automatically converts CUDA, Triton, and PyTorch GPU code to JAX/Pallas
- **RAG-Enhanced**: Leverages documentation retrieval for accurate, context-aware code generation
- **Safety-First**: Scoped file system access with configurable work directories

## Features

### 🎯 Core Capabilities

1. **Interactive Kernel Planning**
   - Creates detailed optimization plans for Pallas kernels
   - Automatic approval workflow with revision support
   - Includes tiling strategies, memory optimization, and performance targets

2. **Kernel Implementation**
   - Implements kernels following approved plans
   - Supports various optimization techniques (tiling, pipelining, memory management)
   - Generates clean, idiomatic JAX/Pallas code

3. **Comprehensive Testing**
   - Automatic pytest test file generation
   - Compilation validation
   - Numerical correctness testing
   - Performance benchmarking
   - Full traceback reporting for debugging

4. **Performance Profiling**
   - DMA and memory transfer analysis
   - Compute vs. memory ratio profiling
   - Bottleneck identification with actionable recommendations

5. **GPU-to-JAX Conversion**
   - Converts CUDA, Triton, and PyTorch CUDA code to JAX
   - Strips hardware-specific optimizations
   - Includes syntax validation and numerical correctness testing
   - See the [GPU-to-JAX Agent README](gpu_to_jax_agent/README.md) for details

**Added:**

# tpu_kernel_gen

## Installation

### As a Package (Recommended)

Install this package locally for use anywhere:

```bash
# From the project directory
pip install -e .
```

This allows you to import and use the modules from anywhere:

```python
from tpu_kernel_gen.kernel_parser import parse_kernels
from tpu_kernel_gen.embed import generate_embeddings
from tpu_kernel_gen.kernel_retrieval import search_similar_kernels
```

### Dependencies Only

Alternatively, install just the required dependencies:

```bash
pip install -r requirements.txt
```
**Removed:**

### 🛡️ Safety & Control

- **Scoped Permissions**: Agent operates only within the designated work directory
- **User Approval Required**: All implementations require explicit plan approval
- **Transparent Operations**: All file operations are logged and visible
- **Session Persistence**: Save and resume your work across sessions

## Architecture

### Agent Hierarchy

```
KernelGenerationOrchestrationAgent (root_agent)
├── ExplanationAgent - Explains TPU/Pallas concepts
├── PlanKernelAgent - Creates/revises optimization plans
├── ImplementKernelAgent - Implements approved plans
├── ValidatedTestGenerationAgent
│   ├── GenerateTestFileAgent - Creates pytest test files
│   ├── TestValidationLoopAgent - Validates test syntax/structure
│   └── ValidationSummaryAgent - Reports validation results
├── UnifiedTestAgent
│   ├── ReadFileForTestingAgent - Locates test files
│   ├── RunTestsAgent - Executes pytest with server management
│   └── SummarizeTestResultsAgent - Analyzes and reports results
├── ProfileAgentOrchestrator
│   ├── ReadFileForProfilingAgent - Locates kernel files
│   ├── GenerateProfilingScriptAgent - Creates profiling scripts
│   ├── EvalProfileAgent - Executes profiling
│   └── SummarizeProfileAgent - Analyzes bottlenecks
└── GpuToJaxAgent - GPU-to-JAX conversion pipeline
    └── (10-step conversion pipeline - see gpu_to_jax_agent/README.md)
```

### Directory Structure

```
hitl_agent/
├── hitl_agent
│   ├── agent.py                 # Main orchestration logic
│   ├── callbacks.py             # Agent callbacks
│   ├── config.py                # Configuration management
│   ├── constants.py             # Agent constants
│   ├── custom_types.py          # Custom types
│   ├── dependency               # Agent dependencies
│   │   ├── adk_cli_patch.py
│   │   ├── agent_requirements.txt
│   │   └── main_requirements.txt
│   ├── isolate_object.py
│   ├── knowledge_base           # Knowledge base for Pallas docs
│   │   ├── pallas_docs.py
│   │   └── pallas_profiling_docs.py
│   ├── prompts                  # Main interactive prompt
│   │   ├── __init__.py
│   │   └── interactive_prompt.py
│   ├── server_utils             # Server management utilities
│   │   ├── __init__.py
│   │   ├── cpu_server.py
│   │   ├── eval_config.yaml
│   │   ├── eval_server.py
│   │   ├── server_manager_mixin.py
│   │   ├── setup.sh
│   │   └── tpu_server.py
│   ├── subagents                # Specialized subagents
│   │   ├── __init__.py
│   │   ├── explanation          # Explanation agent
│   │   │   ├── __init__.py
│   │   │   ├── agent.py
│   │   │   └── prompts
│   │   ├── gpu_to_jax_agent     # GPU-to-JAX conversion subagent
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── agent.py
│   │   │   ├── constants.py
│   │   │   ├── evaluators
│   │   │   │   ├── __init__.py
│   │   │   │   ├── compilation_checker.py
│   │   │   │   ├── correctness_checker.py
│   │   │   │   ├── jax_syntax_checker.py
│   │   │   │   └── shape_validator.py
│   │   │   ├── prompts
│   │   │   └── test_agent.py
│   │   ├── kernel_writing       # Kernel planning & implementation
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── agent.cpython-310.pyc
│   │   │   │   └── kernel_compilation.cpython-310.pyc
│   │   │   ├── agent.py
│   │   │   ├── kernel_compilation.py
│   │   │   └── prompts
│   │   ├── profiling            # Performance profiling
│   │   │   ├── __init__.py
│   │   │   ├── agent.py
│   │   │   ├── kernel_profile.py
│   │   │   ├── offline_tools.py
│   │   │   └── prompts
│   │   └── testing              # Test generation & execution
│   │       ├── __init__.py
│   │       ├── agent.py
│   │       └── prompts
│   ├── tests
│   │   ├── conftest.py
│   │   ├── test_compilation_validation_loop.py
│   │   └── test_validate_kernel_compilation_agent.py
│   ├── tools                    # Agent tools
│   │   ├── __init__.py
│   │   ├── analyze_profile.py
│   │   ├── api_rag
│   │   │   ├── __init__.py
│   │   │   └── get_apis.py
│   │   ├── search_api_tool.py
│   │   └── tools.py
│   └── tpu_specs.json
├── prepare_hitl_agent.sh        # Setup script
├── run_hitl_agent.sh            # Launch script (CLI or UI mode)
└── setup.py
```

**Added:**

## Usage

### As Python Package

After installing the package, you can use it programmatically:

```python
import tpu_kernel_gen.kernel_parser as parser
import tpu_kernel_gen.embed as embedder
import tpu_kernel_gen.kernel_retrieval as retriever

# Parse kernels
kernels = parser.parse_kernels("/path/to/source")

# Generate embeddings
embedder.add_embeddings("kernels.csv")

# Search for similar kernels
results = retriever.search_kernels("matrix multiplication", k=10)
```
**Removed:**

## Getting Started

### Prerequisites

1. **Python Environment**: Python 3.9+ with JAX and dependencies installed
2. **Google Cloud**: Vertex AI access for the agent and RAG retrieval
3. **TPU Access**: Required for actual kernel execution and testing

### Installation

1. **Navigate to the directory containing this README file.**

2. **Run the setup script**:
   ```bash
   bash prepare_hitl_agent.sh
   ```

   This script will:
   - Prompt you to choose a Python environment setup (Miniconda, venv, or your own)
   - Install required dependencies
   - Set up environment variables
   - Configure your work directory
   - Create the `.env` file with your settings

3. **Configure Environment Variables**:

   The setup script creates a `.env` file. You can edit it manually to customize:

   ```bash
   # Required
   GOOGLE_CLOUD_PROJECT=your-project-id
   GOOGLE_GENAI_API_KEY=your-api-key

   # Optional - defaults provided
   WORKDIR=/path/to/your/work/directory  # Default: example_workdir
   TPU_VERSION=v5e                       # Default: v5e
   SESSION_ID=hitl_session               # Default: hitl_session

   # RAG configuration (optional)
   VERTEX_AI_RAG_CORPUS=your-corpus-name
   GOOGLE_CLOUD_REGION=us-central1
   ```

### Running the Agent

#### Option 1: CLI Mode (Recommended for Development)

```bash
bash run_hitl_agent.sh
```

**CLI Features**:
- Interactive command-line interface
- Session stored as JSON files (`*.session.json`)
- Easy to debug and inspect
- Lower overhead

**CLI Options**:
```bash
# Start with default session
bash run_hitl_agent.sh

# Start with a specific session ID
bash run_hitl_agent.sh --session my_session

# Reset and start fresh
bash run_hitl_agent.sh --reset
```

**Added:**

### Command Line Usage

## How to populate the kernel DB

### Step 1: Parse kernels from source code

Use the kernel parser to extract Pallas kernels from Python source files:

```bash
python kernel_parser.py /path/to/source/directory --output kernels.csv
```

This will:
- Recursively scan Python files for JAX Pallas kernels
- Extract kernel definitions and call sites
- Save the results to `kernels.csv`
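The kind of scan Step 1 describes can be illustrated with Python's `ast` module. The sketch below is a simplified, hypothetical stand-in, not the actual `kernel_parser.py` implementation: it treats any function whose body calls `*.pallas_call` as a kernel wrapper and records its name and line number.

```python
import ast

def find_pallas_kernels(source: str):
    """Return (function_name, line) pairs for functions whose body
    calls `*.pallas_call`, a rough heuristic for Pallas kernel wrappers."""
    kernels = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            for inner in ast.walk(node):
                if (isinstance(inner, ast.Call)
                        and isinstance(inner.func, ast.Attribute)
                        and inner.func.attr == "pallas_call"):
                    kernels.append((node.name, node.lineno))
                    break
    return kernels

sample = '''
from jax.experimental import pallas as pl

def matmul(x, y):
    return pl.pallas_call(matmul_kernel, out_shape=None)(x, y)
'''
print(find_pallas_kernels(sample))  # [('matmul', 4)]
```

A real parser would also resolve import aliases and record call sites before writing rows to `kernels.csv`; this heuristic only shows the shape of the traversal.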
**Removed:**

#### Option 2: Web UI Mode

```bash
bash run_hitl_agent.sh --ui
```

**UI Features**:
- Web interface on port 1430
- Session stored in a SQLite database
- Visual conversation history
- Better for demos and non-technical users

**Important**: CLI and UI modes use different session storage mechanisms and **cannot share sessions**.

**Added:**

### Step 2: Generate embeddings

Add code embeddings to the kernel data using UniXcoder:

```bash
python embed.py kernels.csv --code_column code
```

This will:
- Load the UniXcoder model for code embeddings
- Process each kernel's code to generate vector embeddings
- Add embedding columns to the CSV file in place
- Create a backup of the original file
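The CSV-enrichment flow in Step 2 (embed each kernel's code, add embedding columns in place, keep a backup) can be sketched without the UniXcoder model. In this hypothetical illustration `toy_embed` is a stand-in for the neural encoder, and the function and column names are assumptions, not the real `embed.py` API:

```python
import csv
import math
import shutil

def toy_embed(code: str, dim: int = 4):
    """Hypothetical stand-in for a UniXcoder embedding: a normalized
    bag-of-characters vector (the real pipeline uses a neural encoder)."""
    vec = [0.0] * dim
    for i, ch in enumerate(code):
        vec[i % dim] += ord(ch)
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]

def add_embeddings(csv_path: str, code_column: str = "code", dim: int = 4):
    """Add embedding_0..embedding_{dim-1} columns in place, keeping a backup."""
    shutil.copy(csv_path, csv_path + ".bak")  # backup of the original file
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    emb_cols = [f"embedding_{i}" for i in range(dim)]
    for row in rows:
        for col, val in zip(emb_cols, toy_embed(row[code_column], dim)):
            row[col] = val
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
```

Writing the backup before touching the original means a crash mid-rewrite cannot lose the source data, which is presumably why the real tool does the same.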
**Removed:**

### Resuming Sessions

**CLI Mode**:
```bash
# Sessions are saved as JSON files
bash run_hitl_agent.sh --session my_session
```

**UI Mode**:
```bash
# Sessions are managed through the web interface
bash run_hitl_agent.sh --ui
# Browse to http://localhost:1430 and select a session
```

**Added:**

### Step 3: Upload to BigQuery

Upload the enriched kernel data to BigQuery:

```bash
python bq_upload.py --csv-file kernels.csv --table-name your_dataset.kernels --project-id your-project-id
```

This will:
- Upload the CSV data to the specified BigQuery table
- Auto-generate incremental UUIDs for new entries
- Apply the proper schema for the kernel database
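Running `bq_upload.py` itself needs BigQuery credentials, but the "auto-generate incremental UUIDs for new entries" bookkeeping can be sketched offline. The function below is a hypothetical illustration: the field name `kernel_id` and the integer-counter scheme are assumptions, the real script's current maximum would come from querying the target table, and the actual upload would go through the `google-cloud-bigquery` client.

```python
def assign_incremental_ids(rows, existing_max=0, id_field="kernel_id"):
    """Give each new row the next id after the table's current maximum,
    leaving rows that already carry an id untouched."""
    next_id = existing_max + 1
    for row in rows:
        if not row.get(id_field):
            row[id_field] = next_id
            next_id += 1
    return rows
```

Passing `existing_max` in (rather than querying inside the function) keeps the id logic testable without a live table.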
**Removed:**

## Work Directory Configuration

### What is a Work Directory?

The work directory is where the agent:
- Reads your input files (kernels, GPU code, etc.)
- Writes generated files (plans, implementations, tests, profiles)
- Executes operations (testing, profiling)

**Security**: The agent's file access is **scoped** to this directory; it cannot access files outside it.

### Setting Up Your Work Directory

1. **During Initial Setup**:
   ```bash
   bash prepare_hitl_agent.sh
   # Follow prompts to set WORKDIR
   ```

2. **Manual Configuration**:
   Edit the generated `.env` file:
   ```bash
   WORKDIR=/absolute/path/to/your/work/directory
   ```

**Added:**

## How to retrieve from the kernel DB

Use the kernel retrieval tool to search for similar kernels in the BigQuery vector database:

```bash
python kernel_retrieval.py --project-id your-project-id --dataset-name your_dataset --table-name kernels --query "matrix multiplication kernel" --k 10
```

This will:
- Connect to your BigQuery vector store using UniXcoder embeddings
- Search for kernels similar to your query using cosine similarity
- Return the top k most similar results with metadata and similarity scores
- Display operation names, frameworks, hardware targets, and file locations

Optional flags:
- `--verbose`: Enable detailed output during the search
- `--k`: Number of similar kernels to retrieve (default: 5)

The results show ranked kernels with their similarity scores, operation metadata, and source file information to help you find relevant kernel implementations.
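The cosine-similarity ranking described above can be sketched with a small in-memory stand-in for the BigQuery vector store. The toy two-dimensional embeddings and the `top_k` helper below are illustrative assumptions, not the tool's actual API:

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

def top_k(query_vec, kernels, k=5):
    """Rank stored kernels by cosine similarity to the query embedding.
    `kernels` is a list of (metadata, embedding) pairs; in the real tool
    the embeddings live in BigQuery and come from UniXcoder."""
    scored = [(cosine_similarity(query_vec, emb), meta) for meta, emb in kernels]
    scored.sort(key=lambda s: s[0], reverse=True)
    return scored[:k]

db = [("matmul", [1.0, 0.0]), ("softmax", [0.0, 1.0]), ("matmul_tiled", [0.9, 0.1])]
print(top_k([1.0, 0.0], db, k=2))  # matmul first, then matmul_tiled
```

Because cosine similarity ignores vector magnitude, it compares only the direction of the embeddings, which is the usual choice for code-similarity search.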
1 change: 0 additions & 1 deletion MaxKernel/hitl_agent/__init__.py

This file was deleted.
