Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Configuration constants for GPU to JAX conversion agent."""

MODEL_NAME = "gemini-3-pro-preview"
MAX_ITERATIONS = 5
LLM_GEN_RETRY_COUNT = 3
TEMPERATURE = 0.1
TOP_P = 0.9
TOP_K = 5
from tpu_kernel_gen.agents.kernel_gen_agent import constants

MODEL_NAME = constants.MODEL_NAME
MAX_ITERATIONS = constants.MAX_ITERATIONS
LLM_GEN_RETRY_COUNT = constants.LLM_GEN_RETRY_COUNT
TEMPERATURE = constants.TEMPERATURE
TOP_P = constants.TOP_P
TOP_K = constants.TOP_K
CONVERSION_TIMEOUT = 180 # Timeout for conversion attempts
EVAL_SERVER_PORT = 1245
EVAL_SERVER_PORT = constants.EVAL_SERVER_PORT
NUMERICAL_TOLERANCE = 1e-5 # Tolerance for numerical correctness checks

# Backend selection for evaluation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
MODEL_NAME = "gemini-3-pro-preview"
MODEL_NAME = "gemini-3.1-pro-preview"
MAX_ITERATIONS = 5
LLM_GEN_RETRY_COUNT = 3
TEMPERATURE = 0.1
Expand Down