Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions MaxCode/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,25 @@
This extension provides development tools for the MaxCode project,
including tools for AI-powered code migration between ML frameworks.

## Quick Start

Want to try MaxCode without the full Gemini CLI setup? The standalone demo
converts a PyTorch repo to JAX in three commands:

```bash
cd MaxCode/examples/demo
pip install -r requirements.txt
export GOOGLE_API_KEY=<your-key> # Windows CMD: set GOOGLE_API_KEY=<your-key>

python step1_clone_repo.py # Clone a PyTorch repo from GitHub
python step2_populate_rag.py # Build the RAG reference database
python step3_merge.py # Auto-detect and merge model files
python step4_convert.py # Convert to JAX with validation + repair
```

See [examples/demo/README.md](examples/demo/README.md) for full setup
instructions and details on what each step does.

## Prerequisites

This extension uses the Google AI API, which requires an API key. You can get
Expand Down
72 changes: 70 additions & 2 deletions MaxCode/agents/migration/primary_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Primary orchestration agent for repository migration."""
import logging
import os
from typing import Any

Expand All @@ -7,19 +8,26 @@
from agents import utils
from agents.migration import model_conversion_agent
from agents.migration import single_file_agent
from agents.migration import validation_agent
from rag import rag_agent

logger = logging.getLogger(__name__)


class PrimaryAgent(base.Agent):
"""Primary orchestration agent for repository migration."""

def __init__(self, model: Any, api_key: str | None = None):
def __init__(self, model: Any, api_key: str | None = None,
validate: bool = True):
"""Initializes the agent."""
super().__init__(
model=model,
agent_domain=utils.AgentDomain.MIGRATION,
agent_type=utils.AgentType.PRIMARY,
)
self._model_ref = model
self._validate = validate
self._validation_results: dict[str, dict] = {}
self._rag_agent = rag_agent.RAGAgent(
model,
embedding_model_name=models.EmbeddingModel.GEMINI_EMBEDDING_001,
Expand All @@ -38,6 +46,55 @@ def _convert_file(self, pytorch_code: str, file_path: str) -> str:
return self._model_conversion_agent.run(pytorch_code)
return self._single_file_agent.run(pytorch_code)

def _validate_and_repair(self, pytorch_code: str, converted_code: str,
                         file_path: str) -> str:
    """Checks converted code against the source and repairs any deviations.

    Runs one validation pass; when deviations are reported, asks the
    validator to repair the code and re-validates the repaired output so the
    stored record reflects what is still unresolved.

    Args:
      pytorch_code: The original PyTorch source code.
      converted_code: The converted JAX/Flax code.
      file_path: The file path (used as key for storing results).

    Returns:
      The final code (repaired if deviations were found, original otherwise).
    """
    validator = validation_agent.ValidationAgent(self._model_ref)
    initial_deviations = validator.validate(pytorch_code, converted_code)
    logger.info("Validation of %s: found %d deviations",
                file_path, len(initial_deviations))

    # Record starts assuming a clean pass; updated below if repair runs.
    record = {
        "deviations_found": len(initial_deviations),
        "deviations": initial_deviations,
        "remaining_deviations_count": 0,
        "remaining_deviations": [],
    }

    final_code = converted_code
    if initial_deviations:
        final_code = validator.repair(
            converted_code, initial_deviations, pytorch_code=pytorch_code
        )
        # Re-validate the repaired output to capture what is still open.
        still_open = validator.validate(pytorch_code, final_code)
        logger.info("Re-validation of %s: %d remaining deviations",
                    file_path, len(still_open))
        record["remaining_deviations_count"] = len(still_open)
        record["remaining_deviations"] = still_open

    self._validation_results[file_path] = record
    return final_code

def get_validation_results(self) -> dict[str, dict]:
    """Returns the per-file validation records accumulated so far.

    Returns:
      A dict keyed by file path. Each value holds the keys
      ``deviations_found``, ``deviations``, ``remaining_deviations_count``,
      and ``remaining_deviations`` for that file's conversion.
    """
    return self._validation_results

def run(self, repo_path: str) -> dict[str, str]:
"""Orchestrates the migration of a repository from PyTorch to JAX.

Expand All @@ -50,7 +107,12 @@ def run(self, repo_path: str) -> dict[str, str]:
try:
with open(repo_path, "r", encoding="utf-8", errors="replace") as f:
pytorch_code = f.read()
logger.info("Converting %s ...", repo_path)
converted_code = self._convert_file(pytorch_code, repo_path)
if self._validate:
converted_code = self._validate_and_repair(
pytorch_code, converted_code, repo_path
)
return {repo_path: converted_code}
except OSError:
# If opening as a file fails, check if it's a directory.
Expand All @@ -68,11 +130,17 @@ def run(self, repo_path: str) -> dict[str, str]:
ordered_files = utils.topological_sort(graph)
converted_files: dict[str, str] = {}

for file_rel_path in ordered_files:
for i, file_rel_path in enumerate(ordered_files, 1):
file_path = os.path.join(repo_path, file_rel_path)
logger.info("Converting file %d/%d: %s ...", i, len(ordered_files),
file_rel_path)
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
pytorch_code = f.read()
converted_code = self._convert_file(pytorch_code, file_path)
if self._validate:
converted_code = self._validate_and_repair(
pytorch_code, converted_code, file_path
)
converted_files[file_path] = converted_code

return converted_files
111 changes: 111 additions & 0 deletions MaxCode/agents/migration/prompts/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,59 @@
ordering exactly. Reshape to [B, T, num_k_heads, per_head_size] and split
within each group. NEVER flatten to a single dimension and do a flat split
-- this produces wrong tensors when num_k_heads != num_v_heads.
13. **Weight Initialization**: Match PyTorch initialization exactly.
When the source explicitly calls `nn.init.zeros_` on a layer, use
`nn.initializers.zeros_init()`. When the source uses bare `nn.Linear()`
with no explicit init, use the Flax default (lecun_normal) or
`nn.initializers.normal(stddev=config.initializer_range)` -- do NOT use
zeros_init unless the source explicitly initializes to zeros.
RMSNorm (1+w): `nn.initializers.zeros_init()`.
RMSNorm (w): `nn.initializers.ones_init()`.
Check each nn.Parameter in the source and match its init.
14. **Train/Eval Mode**: Flax modules do NOT have a `.train` attribute or
`.eval()` / `.train()` methods. NEVER write `model.train = True` or
`model.train = False` -- this does nothing in Flax and silently produces
incorrect behavior. Instead, pass `deterministic=False` for training and
`deterministic=True` for evaluation as an argument to `__call__` /
`model.apply()`. All stochastic layers (Dropout, router noise) must
check the `deterministic` flag.
15. **Preserve ALL Source Components**: Convert EVERY class, function, and
method from the source. Do NOT merge base classes into subclasses, do NOT
drop utility classes or metric functions, and do NOT omit `get_config()`
or serialization methods. If the source has `ExpertBase` and `FFNExpert`,
convert both. If the source has a `MoEMetrics` class, convert it.
16. **Preserve Default Values Exactly**: All default parameter values in the
JAX output must match the PyTorch source EXACTLY. Do NOT change any numeric
default -- not capacity factors, not dropout rates, not epsilon values, not
learning rates, not layer counts. Even if you believe a different value is
"better" or "more stable", use the source value. Changed defaults silently
alter model behavior and break reproducibility.
17. **Preserve Exact Reduction Operations**: When the source uses `.mean()`,
use `jnp.mean()`. When the source uses `.sum()`, use `jnp.sum()`. NEVER
substitute one reduction for another. `torch.mean(x, dim=N)` maps to
`jnp.mean(x, axis=N)`. `torch.sum(x, dim=N)` maps to `jnp.sum(x, axis=N)`.
The dim/axis integer stays the same.
18. **Preserve Method Placement**: If the source defines a method or attribute
on a specific class, keep it on that class in the JAX output. Do NOT
relocate methods between classes or replace instance methods with
standalone functions unless the JAX idiom requires it.

## CRITICAL: Faithfulness to Source Code

This is a TRANSLATION, not a redesign. The converted code must produce
IDENTICAL behavior to the source for the same inputs and weights.

NEVER simplify complex tensor reshaping, reordering, or algorithmic patterns
from the source code. If the PyTorch code uses a specific interleaved weight
layout, chunk-parallel algorithm, or multi-step computation, convert it
faithfully to JAX. The RAG context shows EXAMPLES of similar patterns -- use
them as guidance for JAX idioms, but always follow the ACTUAL source code's
logic and structure.

NEVER "improve" the source by changing default values, adding initializers
that the source does not use, substituting reductions (.sum vs .mean), or
dropping components you consider non-essential (logging, metrics, utility
classes). If the source has it, the output must have it.
"""

PYTORCH_TO_JAX_SINGLE_FILE_PROMPT = """You are an expert in JAX and PyTorch.
Expand Down Expand Up @@ -151,6 +195,73 @@
linear attention, implement BOTH modes and dispatch based on sequence length.
5. Implement causal_conv1d as a standalone function with both prefill and
single-step decode paths.
6. For causal operations with decode-time state (causal conv1d, linear
attention), implement SEPARATE prefill and decode functions. Do NOT use
a single unified function with conditional branching.
7. ALWAYS include a `@dataclasses.dataclass` Config class at the top of the
output file. Mirror ALL fields from the PyTorch configuration class with
their types and default values. Use `dataclasses.field(default_factory=...)`
for mutable defaults. Use the Config type (not `Any`) in module annotations.
8. The `load_balancing_loss` function MUST accept an optional `attention_mask`
parameter. When the mask is provided, broadcast it to match the concatenated
router logits shape and use it to exclude padding tokens from mean/sum
statistics. See the RAG context for the full pattern.
9. **MoE Experts: Capacity-Based Dispatch (MANDATORY)**. The Experts class MUST
use capacity-based dispatch with dispatch/combine tensors -- NOT per-token
gather of expert weights. The correct pattern is:
a) Compute per-expert capacity: `capacity = ceil(T * K / E) * 1.5`
b) Build dispatch tensor via `one_hot(selected_experts) -> cumsum -> positions
-> one_hot(positions, capacity)` to get `dispatch: [T, E, C]`
c) Build combine tensor: `combine = dispatch * routing_weights`
d) Route tokens to expert buffers: `expert_in = einsum('tec,th->ech', dispatch, x)`
e) Batched expert matmul: `expert_out = einsum('ech,ehi->eci', expert_in, W)`
f) Scatter back: `output = einsum('tec,ech->th', combine, expert_out)`
Do NOT use `weight[flat_indices]` gather or `jax.vmap` over individual experts.
Do NOT use `jnp.einsum('td,edh->teh')` computing all experts for all tokens.
The capacity-based approach is 10-50x more efficient for large E (e.g. E=64).
See the RAG context file `targeted_moe_capacity_routing_jax.py` for the full
implementation with WRONG/CORRECT examples.
10. **KV Cache: Pure Functional NamedTuple (MANDATORY)**. All KV caches MUST be
NamedTuple objects passed as function arguments and returned as outputs.
Do NOT use Flax mutable variables (`self.variable('cache', ...)`).
Do NOT use config dicts with init flags.
For encoder-decoder models: use SEPARATE self_attn_cache and cross_attn_cache
arguments per layer. Cross-attention caches are populated once from encoder
output and passed through unchanged on subsequent decode steps.
Provide an `init_kv_caches()` helper function that pre-allocates all layer
caches. This replaces PyTorch's `install_kv_cache_hooks()`.
See the RAG context for the full encoder-decoder cache pattern.
11. **Tied Output Projection**: When the PyTorch source computes logits via
`x @ self.token_embedding.weight.T`, convert it to
`(x @ token_embedding.embedding.T).astype(jnp.float32)`.
Do NOT use `token_embedding.attend(x)` -- that is for embedding lookup,
not linear projection, and may produce different results.
12. **Fused QKV Projection**: When the PyTorch source uses a single
`in_proj_weight` of shape [3*embed_dim, embed_dim] with sliced projection
methods (in_proj_qkv, in_proj_q, in_proj_kv), preserve this as a SINGLE
parameter with sliced access in JAX. Do NOT split into 3 separate nn.Dense
layers. Use `self.param('in_proj_weight', init, (3*D, D))` and slice it
for Q [0:D], K [D:2D], V [2D:3D]. Provide in_proj_qkv(), in_proj_q(),
in_proj_kv() methods matching the PyTorch API.
13. **Float32 Softmax Upcast (MANDATORY)**: When the PyTorch source uses
`.float()` or `dtype=torch.float32` before softmax, you MUST preserve this
in JAX: `jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1)` then
cast back with `.astype(q.dtype)`. This is critical for numerical stability
in bfloat16/float16. NEVER omit this upcast.
14. **Preserve ALL Source Components (MANDATORY)**: The output MUST contain a
JAX equivalent for EVERY class, function, method, and utility in the source.
Do NOT merge base classes into subclasses. Do NOT drop get_config() or
serialization methods. Do NOT omit utility classes (e.g., metrics classes)
or standalone functions (e.g., metric computation functions). If the source
has N classes and M functions, the output must have N classes and M functions.
15. **Preserve Default Values Exactly**: All constructor defaults, config
defaults, and hyperparameter defaults MUST match the PyTorch source exactly.
Do NOT change capacity_factor, dropout rates, noise epsilon, num_layers,
or any other default value -- even if you think a different value is better.
16. **Train/Eval Mode in Flax**: NEVER set `model.train = True/False` or call
`model.eval()` / `model.train()` in training loops. Flax has no such
attributes. Use `deterministic=False` for training and `deterministic=True`
for evaluation, passed as an argument to the module's `__call__` method.

Please think step by step about the conversion process before generating the code.
Then, provide the complete JAX equivalent of the entire file above.
Expand Down
Loading