Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions MaxCode/agents/migration/model_conversion_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Agent for converting a model from PyTorch to JAX."""

import re
from typing import Any

from agents import base
Expand Down Expand Up @@ -35,12 +36,22 @@ def run(self, pytorch_model_code: str) -> str:
Returns:
The converted JAX code.
"""
rag_context_list = self._rag_agent.retrieve_context(pytorch_model_code)
rag_context_list = self._rag_agent.retrieve_context(pytorch_model_code, top_k=7)
rag_context = "\n\n".join([
f"File: {c['file']}\n```python\n{c['text']}\n```"
for c in rag_context_list
])
return self.generate(
generated_code = self.generate(
prompts.MODEL_CONVERSION_PROMPT,
{"pytorch_model_code": pytorch_model_code, "rag_context": rag_context},
)
return self._strip_markdown_formatting(generated_code)

def _strip_markdown_formatting(self, text: str) -> str:
"""Strips markdown and returns only the first python code block."""
code_block_match = re.search(
r"```(?:python)?\n?(.*?)\n?```", text, re.DOTALL
)
if code_block_match:
return code_block_match.group(1).strip()
return text
74 changes: 64 additions & 10 deletions MaxCode/agents/migration/primary_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Primary orchestration agent for repository migration."""

import ast
import os
from typing import Any, Dict
from collections import deque
from typing import Any, Dict, List, Set

import models
from agents import base
Expand All @@ -11,6 +13,54 @@
from rag import rag_agent


def _is_model_file(code: str) -> bool:
"""Detects whether code contains a torch.nn.Module subclass definition."""
try:
tree = ast.parse(code)
except SyntaxError:
return False
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
for base_node in node.bases:
# Match nn.Module, torch.nn.Module, Module
if isinstance(base_node, ast.Attribute):
if base_node.attr == "Module":
return True
elif isinstance(base_node, ast.Name):
if base_node.id == "Module":
return True
return False


def _topological_sort(graph: Dict[str, Set[str]]) -> List[str]:
"""Returns files in dependency order (dependencies first) using Kahn's algorithm."""
in_degree = {node: 0 for node in graph}
for node, deps in graph.items():
for dep in deps:
if dep in in_degree:
in_degree[node] += 1

queue = deque(node for node, deg in in_degree.items() if deg == 0)
result = []

while queue:
node = queue.popleft()
result.append(node)
# Find nodes that depend on this one and decrement their in-degree
for other, deps in graph.items():
if node in deps:
in_degree[other] -= 1
if in_degree[other] == 0:
queue.append(other)

# Append any remaining nodes (cycles) to avoid dropping files
for node in graph:
if node not in result:
result.append(node)

return result


class PrimaryAgent(base.Agent):
"""Primary orchestration agent for repository migration."""

Expand All @@ -23,7 +73,7 @@ def __init__(self, model: Any, api_key: str | None = None):
)
self._rag_agent = rag_agent.RAGAgent(
model,
embedding_model_name=models.EmbeddingModel.TEXT_EMBEDDING_004,
embedding_model_name=models.EmbeddingModel.GEMINI_EMBEDDING_001,
api_key=api_key,
)
self._single_file_agent = single_file_agent.PytorchToJaxSingleFileAgent(
Expand All @@ -33,6 +83,12 @@ def __init__(self, model: Any, api_key: str | None = None):
model, self._rag_agent
)

def _convert_file(self, pytorch_code: str) -> str:
    """Routes a file to the appropriate conversion agent."""
    # Model definitions (nn.Module subclasses) go to the specialised model
    # conversion agent; everything else uses the generic single-file agent.
    agent = (
        self._model_conversion_agent
        if _is_model_file(pytorch_code)
        else self._single_file_agent
    )
    return agent.run(pytorch_code)

def run(self, repo_path: str) -> Dict[str, str]:
"""Orchestrates the migration of a repository from PyTorch to JAX.

Expand All @@ -43,26 +99,24 @@ def run(self, repo_path: str) -> Dict[str, str]:
A dictionary mapping original file paths to converted JAX code.
"""
if os.path.isfile(repo_path):
with open(repo_path, "r") as f:
with open(repo_path, "r", encoding="utf-8", errors="replace") as f:
pytorch_code = f.read()
converted_code = self._single_file_agent.run(pytorch_code)
converted_code = self._convert_file(pytorch_code)
return {repo_path: converted_code}
elif not os.path.isdir(repo_path):
return {
repo_path: f"# Error: path {repo_path} is not a file or directory."
}

graph = utils.build_dependency_graph(repo_path)
ordered_files = _topological_sort(graph)
converted_files: Dict[str, str] = {}
# conversion order.
# model_conversion_agent for model files, single_file_agent for others).
# For now, convert files individually using single_file_agent.

for file_rel_path in graph:
for file_rel_path in ordered_files:
file_path = os.path.join(repo_path, file_rel_path)
with open(file_path, "r") as f:
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
pytorch_code = f.read()
converted_code = self._single_file_agent.run(pytorch_code)
converted_code = self._convert_file(pytorch_code)
converted_files[file_path] = converted_code

return converted_files
123 changes: 115 additions & 8 deletions MaxCode/agents/migration/prompts/prompts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,66 @@
"""Prompts for code migration."""

# Shared best-practices appendix concatenated (via `+`) onto each conversion
# prompt constant in this module. Runtime string: do not edit casually --
# `{{-1}}` below is a deliberately escaped brace for str.format.
JAX_BEST_PRACTICES = """
## JAX/Flax Best Practices (MUST follow)

1. **Use Flax Linen with @nn.compact**: Define all submodules inline inside
`@nn.compact def __call__`. Do NOT use a separate `setup()` method or NNX.
2. **KV Cache**: Use pre-allocated fixed-size caches updated via
`jax.lax.dynamic_update_slice`. NEVER grow the cache with `jnp.concatenate`
or Python list appends -- that breaks XLA compilation.
3. **Causal Conv1d**: Use `padding="VALID"` with explicit left-padding
(`jnp.pad` with `pad_width` on the time axis). Do NOT use `padding="SAME"`
as it is non-causal and leaks future information. Implement as a standalone
function with both prefill (full sequence) and decode (single-step with
conv_state) paths. Use `jax.lax.conv_general_dilated` with
`feature_group_count=channels` for depthwise convolution.
4. **Standalone Imports**: Only import from `jax`, `jax.numpy`, `flax.linen`,
`numpy`, and `math`. Do NOT import from `torch`, `transformers`, or any
PyTorch library in the output.
5. **Static Shapes**: All tensor shapes must be determinable at trace time for
`jax.jit` compatibility. Avoid data-dependent shapes, Python loops over
dynamic lengths, and boolean indexing with data-dependent masks.
CRITICAL: Never use `array[..., :i]` where `i` is dynamic inside
`jax.lax.scan` -- this creates dynamically-sized slices that break JIT.
6. **Variable Ordering**: Define every variable before its first use. No forward
references -- JAX traces sequentially and undefined names cause errors.
7. **Attention Masking**: Use additive masking: add 0.0 where attention is
allowed and a large negative value (e.g., -1e9 or `jnp.finfo(dtype).min`)
where it should be blocked. Do NOT use multiplicative boolean masks.
8. **RMS Norm**: Implement as `x * jax.lax.rsqrt(mean(x^2) + eps) * weight`.
Do NOT call `torch.nn.functional` or leave PyTorch API calls.
9. **Activation Functions**: Use `jax.nn.silu`, `jax.nn.gelu`, etc. Map
`F.silu` -> `jax.nn.silu`, `F.gelu` -> `jax.nn.gelu`.
10. **Rotary Embeddings**: Precompute `cos` and `sin` tables. Apply as
`(x * cos) + (rotate_half(x) * sin)`. Shapes must broadcast correctly.
11. **Triangular Matrix Inversions**: When the PyTorch code has a for-loop
computing a Neumann series on a lower-triangular matrix (e.g., the WY
representation in chunk-parallel delta rules), convert it to
`jax.scipy.linalg.solve_triangular(I - W, I, lower=True)`. This computes
`(I - W)^{{-1}}` directly and is JIT-safe, unlike a scan with dynamic slicing.
12. **Interleaved Weight Ordering**: When the source code has a
`fix_query_key_value_ordering` function or groups projections by key heads
(e.g., when num_key_heads != num_value_heads), you MUST preserve this
ordering exactly. Reshape to [B, T, num_k_heads, per_head_size] and split
within each group. NEVER flatten to a single dimension and do a flat split
-- this produces wrong tensors when num_k_heads != num_v_heads.
13. **Weight Initialization**: Match PyTorch initialization exactly.
MoE router: `nn.initializers.zeros_init()` (NOT normal).
RMSNorm (1+w): `nn.initializers.zeros_init()`.
RMSNorm (w): `nn.initializers.ones_init()`.
Dense projections: `nn.initializers.normal(stddev=config.initializer_range)`.
Check each nn.Parameter in the source and match its init.

## CRITICAL: Faithfulness to Source Code

NEVER simplify complex tensor reshaping, reordering, or algorithmic patterns
from the source code. If the PyTorch code uses a specific interleaved weight
layout, chunk-parallel algorithm, or multi-step computation, convert it
faithfully to JAX. The RAG context shows EXAMPLES of similar patterns -- use
them as guidance for JAX idioms, but always follow the ACTUAL source code's
logic and structure.
"""

PYTORCH_TO_JAX_SINGLE_FILE_PROMPT = """You are an expert in JAX and PyTorch.
Your task is to convert the following PyTorch code to JAX.
If it is helpful, you can use the following JAX code snippets as context for
Expand All @@ -17,7 +78,7 @@
Ensure that the JAX code is idiomatic and follows best practices, such as using
pure functions and handling random number generation correctly with JAX's PRNG
keys. Only return the Python code block for the JAX implementation.
"""
""" + JAX_BEST_PRACTICES

PYTORCH_TO_JAX_REPO_PROMPT = """You are an expert in JAX and PyTorch. Your task
is to convert a repository from PyTorch to JAX. You will be given a file path
Expand All @@ -41,7 +102,7 @@
keys. The conversion should maintain compatibility with other files in the
repository, assuming they will also be converted to JAX.
Only return the Python code block for the JAX implementation.
"""
""" + JAX_BEST_PRACTICES

HF_TO_JAX_SINGLE_FILE_PROMPT = """You are an expert in JAX and PyTorch, with
special expertise in HuggingFace Transformers. Your task is to convert the
Expand All @@ -62,25 +123,71 @@
idiomatic and follows best practices, such as using pure functions and handling
random number generation correctly with JAX's PRNG keys.
Only return the Python code block for the JAX implementation.
"""
""" + JAX_BEST_PRACTICES

MODEL_CONVERSION_PROMPT = """You are an expert in JAX and PyTorch model
architectures. Your task is to convert the following PyTorch model definition
to a JAX-based equivalent, using libraries such as Flax.
architectures. Your task is to convert the ENTIRE PyTorch file below to a
single JAX/Flax file. You MUST convert ALL classes, helper functions,
constants, and configuration dataclasses -- not just one class.

If it is helpful, you can use the following JAX code snippets as context for
functionality that might be similar to your conversion task:
---
{rag_context}
---
PyTorch model:
PyTorch model file:
```python
{pytorch_model_code}
```

IMPORTANT CONVERSION RULES:
1. Convert EVERY class and function in the file above. The output must include
JAX equivalents for all nn.Module subclasses, all helper functions (rotary
embeddings, attention masking, loss functions, etc.), and all supporting code.
2. If the source has a `fix_query_key_value_ordering` method or groups QKVZ
projections by key heads, convert it FAITHFULLY. Reshape to
[B, T, num_k_heads, ...] and split within each key-head group. Do NOT
replace it with a flat split -- that produces wrong tensors when
num_k_heads != num_v_heads.
3. If the source has a chunk-parallel delta rule with a for-loop computing a
Neumann series (WY representation), convert it using
`jax.scipy.linalg.solve_triangular(I - W, I, lower=True)` instead of
jax.lax.scan with dynamic slicing. See the RAG context for the pattern.
4. If the source has both a chunk (prefill) and recurrent (decode) mode for
linear attention, implement BOTH modes and dispatch based on sequence length.
5. Implement causal_conv1d as a standalone function with both prefill and
single-step decode paths.
6. For causal operations with decode-time state (causal conv1d, linear
attention), implement SEPARATE prefill and decode functions. Do NOT use
a single unified function with conditional branching.
7. ALWAYS include a `@dataclasses.dataclass` Config class at the top of the
output file. Mirror ALL fields from the PyTorch configuration class with
their types and default values. Use `dataclasses.field(default_factory=...)`
for mutable defaults. Use the Config type (not `Any`) in module annotations.
8. The `load_balancing_loss` function MUST accept an optional `attention_mask`
parameter. When the mask is provided, broadcast it to match the concatenated
router logits shape and use it to exclude padding tokens from mean/sum
statistics. See the RAG context for the full pattern.
9. **MoE Experts: Capacity-Based Dispatch (MANDATORY)**. The Experts class MUST
use capacity-based dispatch with dispatch/combine tensors -- NOT per-token
gather of expert weights. The correct pattern is:
a) Compute per-expert capacity: `capacity = ceil(T * K / E) * 1.5`
b) Build dispatch tensor via `one_hot(selected_experts) -> cumsum -> positions
-> one_hot(positions, capacity)` to get `dispatch: [T, E, C]`
c) Build combine tensor: `combine = dispatch * routing_weights`
d) Route tokens to expert buffers: `expert_in = einsum('tec,th->ech', dispatch, x)`
e) Batched expert matmul: `expert_out = einsum('ech,ehi->eci', expert_in, W)`
f) Scatter back: `output = einsum('tec,ech->th', combine, expert_out)`
Do NOT use `weight[flat_indices]` gather or `jax.vmap` over individual experts.
Do NOT use `jnp.einsum('td,edh->teh')` computing all experts for all tokens.
The capacity-based approach is 10-50x more efficient for large E (e.g. E=64).
See the RAG context file `targeted_moe_capacity_routing_jax.py` for the full
implementation with WRONG/CORRECT examples.

Please think step by step about the conversion process before generating the code.
Then, provide the JAX equivalent of the model definition above.
Then, provide the complete JAX equivalent of the entire file above.
Ensure that the JAX code is idiomatic and follows best practices for defining
models in JAX, such as using pure functions and handling random number
generation correctly with JAX's PRNG keys.
Only return the Python code block for the JAX implementation.
"""
""" + JAX_BEST_PRACTICES
2 changes: 1 addition & 1 deletion MaxCode/agents/migration/single_file_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def run(self, pytorch_code: str) -> str:
Returns:
The converted JAX code.
"""
rag_context_list = self._rag_agent.retrieve_context(pytorch_code)
rag_context_list = self._rag_agent.retrieve_context(pytorch_code, top_k=7)
rag_context = "\n\n".join([
f"File: {c['file']}\n```python\n{c['text']}\n```"
for c in rag_context_list
Expand Down
38 changes: 26 additions & 12 deletions MaxCode/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
import logging
import random
import time
import requests

Expand All @@ -13,12 +14,14 @@ class GeminiModel(enum.Enum):
GEMINI_2_5_FLASH = "gemini-2.5-flash"
GEMINI_3_0_PRO = "gemini-3.0-pro"
GEMINI_3_0_FLASH = "gemini-3.0-flash"
GEMINI_3_1_PRO_PREVIEW = "gemini-3.1-pro-preview"


class EmbeddingModel(enum.Enum):
    """Enum for Embedding model names."""

    # Values are the model identifier strings passed to the embeddings API.
    EMBEDDING_001 = "models/embedding-001"
    GEMINI_EMBEDDING_001 = "models/gemini-embedding-001"
    TEXT_EMBEDDING_004 = "models/text-embedding-004"


Expand Down Expand Up @@ -72,18 +75,29 @@ def __call__(self, user_prompt):
"parts": [{"text": self.system_instruction}]
}

try:
time.sleep(2)
response = requests.post(self.endpoint, headers=headers, json=payload)
response.raise_for_status() # Raise HTTPError for bad responses
json_response = response.json()
return json_response["candidates"][0]["content"]["parts"][0]["text"]
except requests.exceptions.RequestException as e:
logging.error("Error calling Gemini API: %s", e)
raise
except (KeyError, IndexError) as e:
logging.error("Error parsing Gemini API response: %s", e)
raise ValueError("Could not parse response from Gemini API.") from e
max_retries = 5
for attempt in range(max_retries):
try:
time.sleep(2 if attempt == 0 else min(30, 2 ** (attempt + 1)))
response = requests.post(
self.endpoint, headers=headers, json=payload, timeout=600
)
response.raise_for_status() # Raise HTTPError for bad responses
json_response = response.json()
return json_response["candidates"][0]["content"]["parts"][0]["text"]
except requests.exceptions.RequestException as e:
status = getattr(getattr(e, "response", None), "status_code", None)
if status in (429, 500, 503) and attempt < max_retries - 1:
wait = min(60, 2 ** (attempt + 2)) + random.uniform(0, 2)
logging.warning("Gemini API %s, retrying in %.1fs (attempt %d/%d)...",
status, wait, attempt + 1, max_retries)
time.sleep(wait)
continue
logging.error("Error calling Gemini API: %s", e)
raise
except (KeyError, IndexError) as e:
logging.error("Error parsing Gemini API response: %s", e)
raise ValueError("Could not parse response from Gemini API.") from e

def generate(self, user_prompt: str) -> str:
"""Alias for __call__ to support agents expecting a generate method."""
Expand Down
Loading