From c3868b3635bfb2e9a8c802701b6cf84a298cb993 Mon Sep 17 00:00:00 2001 From: "gvanica@gmail.com" Date: Thu, 2 Apr 2026 09:26:09 -0700 Subject: [PATCH 1/3] Fix maxcode with improved logic and rag --- .../migration/model_conversion_agent.py | 2 +- MaxCode/agents/migration/primary_agent.py | 74 ++++++++++++++++--- MaxCode/agents/migration/prompts/prompts.py | 38 +++++++++- MaxCode/agents/migration/single_file_agent.py | 2 +- MaxCode/models.py | 37 +++++++--- MaxCode/rag/rag_agent.py | 6 +- 6 files changed, 129 insertions(+), 30 deletions(-) diff --git a/MaxCode/agents/migration/model_conversion_agent.py b/MaxCode/agents/migration/model_conversion_agent.py index 4013f6e..1cd49b2 100644 --- a/MaxCode/agents/migration/model_conversion_agent.py +++ b/MaxCode/agents/migration/model_conversion_agent.py @@ -35,7 +35,7 @@ def run(self, pytorch_model_code: str) -> str: Returns: The converted JAX code. """ - rag_context_list = self._rag_agent.retrieve_context(pytorch_model_code) + rag_context_list = self._rag_agent.retrieve_context(pytorch_model_code, top_k=7) rag_context = "\n\n".join([ f"File: {c['file']}\n```python\n{c['text']}\n```" for c in rag_context_list diff --git a/MaxCode/agents/migration/primary_agent.py b/MaxCode/agents/migration/primary_agent.py index 3bc0398..e428647 100644 --- a/MaxCode/agents/migration/primary_agent.py +++ b/MaxCode/agents/migration/primary_agent.py @@ -1,7 +1,9 @@ """Primary orchestration agent for repository migration.""" +import ast import os -from typing import Any, Dict +from collections import deque +from typing import Any, Dict, List, Set import models from agents import base @@ -11,6 +13,54 @@ from rag import rag_agent +def _is_model_file(code: str) -> bool: + """Detects whether code contains a torch.nn.Module subclass definition.""" + try: + tree = ast.parse(code) + except SyntaxError: + return False + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + for base_node in node.bases: + # Match nn.Module, torch.nn.Module, Module + if isinstance(base_node, ast.Attribute): + if base_node.attr == "Module": + return True + elif isinstance(base_node, ast.Name): + if base_node.id == "Module": + return True + return False + + +def _topological_sort(graph: Dict[str, Set[str]]) -> List[str]: + """Returns files in dependency order (dependencies first) using Kahn's algorithm.""" + in_degree = {node: 0 for node in graph} + for node, deps in graph.items(): + for dep in deps: + if dep in in_degree: + in_degree[node] += 1 + + queue = deque(node for node, deg in in_degree.items() if deg == 0) + result = [] + + while queue: + node = queue.popleft() + result.append(node) + # Find nodes that depend on this one and decrement their in-degree + for other, deps in graph.items(): + if node in deps: + in_degree[other] -= 1 + if in_degree[other] == 0: + queue.append(other) + + # Append any remaining nodes (cycles) to avoid dropping files + for node in graph: + if node not in result: + result.append(node) + + return result + + class PrimaryAgent(base.Agent): """Primary orchestration agent for repository migration.""" @@ -23,7 +73,7 @@ def __init__(self, model: Any, api_key: str | None = None): ) self._rag_agent = rag_agent.RAGAgent( model, - embedding_model_name=models.EmbeddingModel.TEXT_EMBEDDING_004, + embedding_model_name=models.EmbeddingModel.GEMINI_EMBEDDING_001, api_key=api_key, ) self._single_file_agent = single_file_agent.PytorchToJaxSingleFileAgent( @@ -33,6 +83,12 @@ def __init__(self, model: Any, api_key: str | None = None): model, self._rag_agent ) + def _convert_file(self, pytorch_code: str) -> str: + """Routes a file to the appropriate conversion agent.""" + if _is_model_file(pytorch_code): + return self._model_conversion_agent.run(pytorch_code) + return self._single_file_agent.run(pytorch_code) + def run(self, repo_path: str) -> Dict[str, str]: """Orchestrates the migration of a repository from PyTorch to JAX. @@ -43,9 +99,9 @@ def run(self, repo_path: str) -> Dict[str, str]: A dictionary mapping original file paths to converted JAX code. """ if os.path.isfile(repo_path): - with open(repo_path, "r") as f: + with open(repo_path, "r", encoding="utf-8", errors="replace") as f: pytorch_code = f.read() - converted_code = self._single_file_agent.run(pytorch_code) + converted_code = self._convert_file(pytorch_code) return {repo_path: converted_code} elif not os.path.isdir(repo_path): return { @@ -53,16 +109,14 @@ def run(self, repo_path: str) -> Dict[str, str]: } graph = utils.build_dependency_graph(repo_path) + ordered_files = _topological_sort(graph) converted_files: Dict[str, str] = {} - # conversion order. - # model_conversion_agent for model files, single_file_agent for others). - # For now, convert files individually using single_file_agent. - for file_rel_path in graph: + for file_rel_path in ordered_files: file_path = os.path.join(repo_path, file_rel_path) - with open(file_path, "r") as f: + with open(file_path, "r", encoding="utf-8", errors="replace") as f: pytorch_code = f.read() - converted_code = self._single_file_agent.run(pytorch_code) + converted_code = self._convert_file(pytorch_code) converted_files[file_path] = converted_code return converted_files diff --git a/MaxCode/agents/migration/prompts/prompts.py b/MaxCode/agents/migration/prompts/prompts.py index e8db050..94afec6 100644 --- a/MaxCode/agents/migration/prompts/prompts.py +++ b/MaxCode/agents/migration/prompts/prompts.py @@ -1,5 +1,35 @@ """Prompts for code migration.""" +JAX_BEST_PRACTICES = """ +## JAX/Flax Best Practices (MUST follow) + +1. **Use Flax Linen with @nn.compact**: Define all submodules inline inside + `@nn.compact def __call__`. Do NOT use a separate `setup()` method or NNX. +2. **KV Cache**: Use pre-allocated fixed-size caches updated via + `jax.lax.dynamic_update_slice`. NEVER grow the cache with `jnp.concatenate` + or Python list appends -- that breaks XLA compilation. +3. **Causal Conv1d**: Use `padding="VALID"` with explicit left-padding + (`jnp.pad` with `pad_width` on the time axis). Do NOT use `padding="SAME"` + as it is non-causal and leaks future information. +4. **Standalone Imports**: Only import from `jax`, `jax.numpy`, `flax.linen`, + `numpy`, and `math`. Do NOT import from `torch`, `transformers`, or any + PyTorch library in the output. +5. **Static Shapes**: All tensor shapes must be determinable at trace time for + `jax.jit` compatibility. Avoid data-dependent shapes, Python loops over + dynamic lengths, and boolean indexing with data-dependent masks. +6. **Variable Ordering**: Define every variable before its first use. No forward + references -- JAX traces sequentially and undefined names cause errors. +7. **Attention Masking**: Use additive masking: add 0.0 where attention is + allowed and a large negative value (e.g., -1e9 or `jnp.finfo(dtype).min`) + where it should be blocked. Do NOT use multiplicative boolean masks. +8. **RMS Norm**: Implement as `x * jax.lax.rsqrt(mean(x^2) + eps) * weight`. + Do NOT call `torch.nn.functional` or leave PyTorch API calls. +9. **Activation Functions**: Use `jax.nn.silu`, `jax.nn.gelu`, etc. Map + `F.silu` -> `jax.nn.silu`, `F.gelu` -> `jax.nn.gelu`. +10. **Rotary Embeddings**: Precompute `cos` and `sin` tables. Apply as + `(x * cos) + (rotate_half(x) * sin)`. Shapes must broadcast correctly. +""" + PYTORCH_TO_JAX_SINGLE_FILE_PROMPT = """You are an expert in JAX and PyTorch. Your task is to convert the following PyTorch code to JAX. If it is helpful, you can use the following JAX code snippets as context for @@ -17,7 +47,7 @@ Ensure that the JAX code is idiomatic and follows best practices, such as using pure functions and handling random number generation correctly with JAX's PRNG keys. Only return the Python code block for the JAX implementation. -""" +""" + JAX_BEST_PRACTICES PYTORCH_TO_JAX_REPO_PROMPT = """You are an expert in JAX and PyTorch. Your task is to convert a repository from PyTorch to JAX. You will be given a file path @@ -41,7 +71,7 @@ keys. The conversion should maintain compatibility with other files in the repository, assuming they will also be converted to JAX. Only return the Python code block for the JAX implementation. -""" +""" + JAX_BEST_PRACTICES HF_TO_JAX_SINGLE_FILE_PROMPT = """You are an expert in JAX and PyTorch, with special expertise in HuggingFace Transformers. Your task is to convert the @@ -62,7 +92,7 @@ idiomatic and follows best practices, such as using pure functions and handling random number generation correctly with JAX's PRNG keys. Only return the Python code block for the JAX implementation. -""" +""" + JAX_BEST_PRACTICES MODEL_CONVERSION_PROMPT = """You are an expert in JAX and PyTorch model architectures. Your task is to convert the following PyTorch model definition @@ -83,4 +113,4 @@ models in JAX, such as using pure functions and handling random number generation correctly with JAX's PRNG keys. Only return the Python code block for the JAX implementation. -""" +""" + JAX_BEST_PRACTICES diff --git a/MaxCode/agents/migration/single_file_agent.py b/MaxCode/agents/migration/single_file_agent.py index 8e0346a..733c027 100644 --- a/MaxCode/agents/migration/single_file_agent.py +++ b/MaxCode/agents/migration/single_file_agent.py @@ -46,7 +46,7 @@ def run(self, pytorch_code: str) -> str: Returns: The converted JAX code. """ - rag_context_list = self._rag_agent.retrieve_context(pytorch_code) + rag_context_list = self._rag_agent.retrieve_context(pytorch_code, top_k=7) rag_context = "\n\n".join([ f"File: {c['file']}\n```python\n{c['text']}\n```" for c in rag_context_list diff --git a/MaxCode/models.py b/MaxCode/models.py index fcff03b..46c5949 100644 --- a/MaxCode/models.py +++ b/MaxCode/models.py @@ -2,6 +2,7 @@ import enum import logging +import random import time import requests @@ -19,6 +20,7 @@ class EmbeddingModel(enum.Enum): """Enum for Embedding model names.""" EMBEDDING_001 = "models/embedding-001" + GEMINI_EMBEDDING_001 = "models/gemini-embedding-001" TEXT_EMBEDDING_004 = "models/text-embedding-004" @@ -72,18 +74,29 @@ def __call__(self, user_prompt): "parts": [{"text": self.system_instruction}] } - try: - time.sleep(2) - response = requests.post(self.endpoint, headers=headers, json=payload) - response.raise_for_status() # Raise HTTPError for bad responses - json_response = response.json() - return json_response["candidates"][0]["content"]["parts"][0]["text"] - except requests.exceptions.RequestException as e: - logging.error("Error calling Gemini API: %s", e) - raise - except (KeyError, IndexError) as e: - logging.error("Error parsing Gemini API response: %s", e) - raise ValueError("Could not parse response from Gemini API.") from e + max_retries = 5 + for attempt in range(max_retries): + try: + time.sleep(2 if attempt == 0 else min(30, 2 ** (attempt + 1))) + response = requests.post( + self.endpoint, headers=headers, json=payload, timeout=300 + ) + response.raise_for_status() # Raise HTTPError for bad responses + json_response = response.json() + return json_response["candidates"][0]["content"]["parts"][0]["text"] + except requests.exceptions.RequestException as e: + status = getattr(getattr(e, "response", None), "status_code", None) + if status in (429, 500, 503) and attempt < max_retries - 1: + wait = min(60, 2 ** (attempt + 2)) + random.uniform(0, 2) + logging.warning("Gemini API %s, retrying in %.1fs (attempt %d/%d)...", + status, wait, attempt + 1, max_retries) + time.sleep(wait) + continue + logging.error("Error calling Gemini API: %s", e) + raise + except (KeyError, IndexError) as e: + logging.error("Error parsing Gemini API response: %s", e) + raise ValueError("Could not parse response from Gemini API.") from e def generate(self, user_prompt: str) -> str: """Alias for __call__ to support agents expecting a generate method.""" diff --git a/MaxCode/rag/rag_agent.py b/MaxCode/rag/rag_agent.py index 7f03b61..e788ea1 100644 --- a/MaxCode/rag/rag_agent.py +++ b/MaxCode/rag/rag_agent.py @@ -17,7 +17,7 @@ # way to get the max context length in characters, 20000 characters # (roughly 5000-7000 tokens) is a safe limit for models with 32k token limits, # when considering that the prompt sends file content in two fields. -_MAX_CONTEXT_LENGTH = 20000 +_MAX_CONTEXT_LENGTH = 100000 class RAGAgent(base.Agent): @@ -59,7 +59,7 @@ def build_from_directory(self, source_path: str): if filename.endswith(".py"): file_path = os.path.join(root, filename) try: - with open(file_path, "r") as f: + with open(file_path, "r", encoding="utf-8", errors="replace") as f: content = f.read() doc_name = os.path.relpath(file_path, source_path) print(f"Adding {doc_name} to RAG database...") @@ -100,6 +100,8 @@ def retrieve_context( A list of dictionaries, each containing 'name', 'text', 'file', and 'distance' for a retrieved document. """ + if self._index is None: + return [] query_embedding = self._embedding_agent.embed(query) results = vector_db.search_embedding( np.array(query_embedding), self._index, self._texts, top_k=top_k From ceae8fb88a3d9b5ebaa96d893bb0e6118e027aa3 Mon Sep 17 00:00:00 2001 From: "gvanica@gmail.com" Date: Thu, 2 Apr 2026 09:26:33 -0700 Subject: [PATCH 2/3] add new prompts, update model --- .../migration/model_conversion_agent.py | 13 ++++- MaxCode/agents/migration/prompts/prompts.py | 55 +++++++++++++++++-- MaxCode/models.py | 3 +- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/MaxCode/agents/migration/model_conversion_agent.py b/MaxCode/agents/migration/model_conversion_agent.py index 1cd49b2..33d70ff 100644 --- a/MaxCode/agents/migration/model_conversion_agent.py +++ b/MaxCode/agents/migration/model_conversion_agent.py @@ -1,5 +1,6 @@ """Agent for converting a model from PyTorch to JAX.""" +import re from typing import Any from agents import base @@ -40,7 +41,17 @@ def run(self, pytorch_model_code: str) -> str: f"File: {c['file']}\n```python\n{c['text']}\n```" for c in rag_context_list ]) - return self.generate( + generated_code = self.generate( prompts.MODEL_CONVERSION_PROMPT, {"pytorch_model_code": pytorch_model_code, "rag_context": rag_context}, ) + return self._strip_markdown_formatting(generated_code) + + def _strip_markdown_formatting(self, text: str) -> str: + """Strips markdown and returns only the first python code block.""" + code_block_match = re.search( + r"```(?:python)?\n?(.*?)\n?```", text, re.DOTALL + ) + if code_block_match: + return code_block_match.group(1).strip() + return text diff --git a/MaxCode/agents/migration/prompts/prompts.py b/MaxCode/agents/migration/prompts/prompts.py index 94afec6..32dd2f7 100644 --- a/MaxCode/agents/migration/prompts/prompts.py +++ b/MaxCode/agents/migration/prompts/prompts.py @@ -10,13 +10,18 @@ or Python list appends -- that breaks XLA compilation. 3. **Causal Conv1d**: Use `padding="VALID"` with explicit left-padding (`jnp.pad` with `pad_width` on the time axis). Do NOT use `padding="SAME"` - as it is non-causal and leaks future information. + as it is non-causal and leaks future information. Implement as a standalone + function with both prefill (full sequence) and decode (single-step with + conv_state) paths. Use `jax.lax.conv_general_dilated` with + `feature_group_count=channels` for depthwise convolution. 4. **Standalone Imports**: Only import from `jax`, `jax.numpy`, `flax.linen`, `numpy`, and `math`. Do NOT import from `torch`, `transformers`, or any PyTorch library in the output. 5. **Static Shapes**: All tensor shapes must be determinable at trace time for `jax.jit` compatibility. Avoid data-dependent shapes, Python loops over dynamic lengths, and boolean indexing with data-dependent masks. + CRITICAL: Never use `array[..., :i]` where `i` is dynamic inside + `jax.lax.scan` -- this creates dynamically-sized slices that break JIT. 6. **Variable Ordering**: Define every variable before its first use. No forward references -- JAX traces sequentially and undefined names cause errors. 7. **Attention Masking**: Use additive masking: add 0.0 where attention is @@ -28,6 +33,26 @@ `F.silu` -> `jax.nn.silu`, `F.gelu` -> `jax.nn.gelu`. 10. **Rotary Embeddings**: Precompute `cos` and `sin` tables. Apply as `(x * cos) + (rotate_half(x) * sin)`. Shapes must broadcast correctly. +11. **Triangular Matrix Inversions**: When the PyTorch code has a for-loop + computing a Neumann series on a lower-triangular matrix (e.g., the WY + representation in chunk-parallel delta rules), convert it to + `jax.scipy.linalg.solve_triangular(I - W, I, lower=True)`. This computes + `(I - W)^{{-1}}` directly and is JIT-safe, unlike a scan with dynamic slicing. +12. **Interleaved Weight Ordering**: When the source code has a + `fix_query_key_value_ordering` function or groups projections by key heads + (e.g., when num_key_heads != num_value_heads), you MUST preserve this + ordering exactly. Reshape to [B, T, num_k_heads, per_head_size] and split + within each group. NEVER flatten to a single dimension and do a flat split + -- this produces wrong tensors when num_k_heads != num_v_heads. + +## CRITICAL: Faithfulness to Source Code + +NEVER simplify complex tensor reshaping, reordering, or algorithmic patterns +from the source code. If the PyTorch code uses a specific interleaved weight +layout, chunk-parallel algorithm, or multi-step computation, convert it +faithfully to JAX. The RAG context shows EXAMPLES of similar patterns -- use +them as guidance for JAX idioms, but always follow the ACTUAL source code's +logic and structure. """ PYTORCH_TO_JAX_SINGLE_FILE_PROMPT = """You are an expert in JAX and PyTorch. @@ -95,20 +120,40 @@ """ + JAX_BEST_PRACTICES MODEL_CONVERSION_PROMPT = """You are an expert in JAX and PyTorch model -architectures. Your task is to convert the following PyTorch model definition -to a JAX-based equivalent, using libraries such as Flax. +architectures. Your task is to convert the ENTIRE PyTorch file below to a +single JAX/Flax file. You MUST convert ALL classes, helper functions, +constants, and configuration dataclasses -- not just one class. + If it is helpful, you can use the following JAX code snippets as context for functionality that might be similar to your conversion task: --- {rag_context} --- -PyTorch model: +PyTorch model file: ```python {pytorch_model_code} ``` +IMPORTANT CONVERSION RULES: +1. Convert EVERY class and function in the file above. The output must include + JAX equivalents for all nn.Module subclasses, all helper functions (rotary + embeddings, attention masking, loss functions, etc.), and all supporting code. +2. If the source has a `fix_query_key_value_ordering` method or groups QKVZ + projections by key heads, convert it FAITHFULLY. Reshape to + [B, T, num_k_heads, ...] and split within each key-head group. Do NOT + replace it with a flat split -- that produces wrong tensors when + num_k_heads != num_v_heads. +3. If the source has a chunk-parallel delta rule with a for-loop computing a + Neumann series (WY representation), convert it using + `jax.scipy.linalg.solve_triangular(I - W, I, lower=True)` instead of + jax.lax.scan with dynamic slicing. See the RAG context for the pattern. +4. If the source has both a chunk (prefill) and recurrent (decode) mode for + linear attention, implement BOTH modes and dispatch based on sequence length. +5. Implement causal_conv1d as a standalone function with both prefill and + single-step decode paths. + Please think step by step about the conversion process before generating the code. -Then, provide the JAX equivalent of the model definition above. +Then, provide the complete JAX equivalent of the entire file above. Ensure that the JAX code is idiomatic and follows best practices for defining models in JAX, such as using pure functions and handling random number generation correctly with JAX's PRNG keys. diff --git a/MaxCode/models.py b/MaxCode/models.py index 46c5949..b2d0738 100644 --- a/MaxCode/models.py +++ b/MaxCode/models.py @@ -14,6 +14,7 @@ class GeminiModel(enum.Enum): GEMINI_2_5_FLASH = "gemini-2.5-flash" GEMINI_3_0_PRO = "gemini-3.0-pro" GEMINI_3_0_FLASH = "gemini-3.0-flash" + GEMINI_3_1_PRO_PREVIEW = "gemini-3.1-pro-preview" class EmbeddingModel(enum.Enum): @@ -79,7 +80,7 @@ def __call__(self, user_prompt): try: time.sleep(2 if attempt == 0 else min(30, 2 ** (attempt + 1))) response = requests.post( - self.endpoint, headers=headers, json=payload, timeout=300 + self.endpoint, headers=headers, json=payload, timeout=600 ) response.raise_for_status() # Raise HTTPError for bad responses json_response = response.json() From a0cb18777d12e7098bdaa4daa9ae0e04aadb4c71 Mon Sep 17 00:00:00 2001 From: "gvanica@gmail.com" Date: Thu, 2 Apr 2026 21:30:56 -0700 Subject: [PATCH 3/3] improve moe dispatch --- MaxCode/agents/migration/prompts/prompts.py | 32 +++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/MaxCode/agents/migration/prompts/prompts.py b/MaxCode/agents/migration/prompts/prompts.py index 32dd2f7..3542887 100644 --- a/MaxCode/agents/migration/prompts/prompts.py +++ b/MaxCode/agents/migration/prompts/prompts.py @@ -44,6 +44,12 @@ ordering exactly. Reshape to [B, T, num_k_heads, per_head_size] and split within each group. NEVER flatten to a single dimension and do a flat split -- this produces wrong tensors when num_k_heads != num_v_heads. +13. **Weight Initialization**: Match PyTorch initialization exactly. + MoE router: `nn.initializers.zeros_init()` (NOT normal). + RMSNorm (1+w): `nn.initializers.zeros_init()`. + RMSNorm (w): `nn.initializers.ones_init()`. + Dense projections: `nn.initializers.normal(stddev=config.initializer_range)`. + Check each nn.Parameter in the source and match its init. ## CRITICAL: Faithfulness to Source Code @@ -151,6 +157,32 @@ linear attention, implement BOTH modes and dispatch based on sequence length. 5. Implement causal_conv1d as a standalone function with both prefill and single-step decode paths. +6. For causal operations with decode-time state (causal conv1d, linear + attention), implement SEPARATE prefill and decode functions. Do NOT use + a single unified function with conditional branching. +7. ALWAYS include a `@dataclasses.dataclass` Config class at the top of the + output file. Mirror ALL fields from the PyTorch configuration class with + their types and default values. Use `dataclasses.field(default_factory=...)` + for mutable defaults. Use the Config type (not `Any`) in module annotations. +8. The `load_balancing_loss` function MUST accept an optional `attention_mask` + parameter. When the mask is provided, broadcast it to match the concatenated + router logits shape and use it to exclude padding tokens from mean/sum + statistics. See the RAG context for the full pattern. +9. **MoE Experts: Capacity-Based Dispatch (MANDATORY)**. The Experts class MUST + use capacity-based dispatch with dispatch/combine tensors -- NOT per-token + gather of expert weights. The correct pattern is: + a) Compute per-expert capacity: `capacity = ceil(T * K / E) * 1.5` + b) Build dispatch tensor via `one_hot(selected_experts) -> cumsum -> positions + -> one_hot(positions, capacity)` to get `dispatch: [T, E, C]` + c) Build combine tensor: `combine = dispatch * routing_weights` + d) Route tokens to expert buffers: `expert_in = einsum('tec,th->ech', dispatch, x)` + e) Batched expert matmul: `expert_out = einsum('ech,ehi->eci', expert_in, W)` + f) Scatter back: `output = einsum('tec,ech->th', combine, expert_out)` + Do NOT use `weight[flat_indices]` gather or `jax.vmap` over individual experts. + Do NOT use `jnp.einsum('td,edh->teh')` computing all experts for all tokens. + The capacity-based approach is 10-50x more efficient for large E (e.g. E=64). + See the RAG context file `targeted_moe_capacity_routing_jax.py` for the full + implementation with WRONG/CORRECT examples. Please think step by step about the conversion process before generating the code. Then, provide the complete JAX equivalent of the entire file above.