diff --git a/MaxKernel/tpu_kernel_gen/agents/hitl_kernel_gen_agent/subagents/gpu_to_jax_agent/constants.py b/MaxKernel/tpu_kernel_gen/agents/hitl_kernel_gen_agent/subagents/gpu_to_jax_agent/constants.py index fe484fb..282b477 100644 --- a/MaxKernel/tpu_kernel_gen/agents/hitl_kernel_gen_agent/subagents/gpu_to_jax_agent/constants.py +++ b/MaxKernel/tpu_kernel_gen/agents/hitl_kernel_gen_agent/subagents/gpu_to_jax_agent/constants.py @@ -1,13 +1,15 @@ """Configuration constants for GPU to JAX conversion agent.""" -MODEL_NAME = "gemini-3-pro-preview" -MAX_ITERATIONS = 5 -LLM_GEN_RETRY_COUNT = 3 -TEMPERATURE = 0.1 -TOP_P = 0.9 -TOP_K = 5 +from tpu_kernel_gen.agents.kernel_gen_agent import constants + +MODEL_NAME = constants.MODEL_NAME +MAX_ITERATIONS = constants.MAX_ITERATIONS +LLM_GEN_RETRY_COUNT = constants.LLM_GEN_RETRY_COUNT +TEMPERATURE = constants.TEMPERATURE +TOP_P = constants.TOP_P +TOP_K = constants.TOP_K CONVERSION_TIMEOUT = 180 # Timeout for conversion attempts -EVAL_SERVER_PORT = 1245 +EVAL_SERVER_PORT = constants.EVAL_SERVER_PORT NUMERICAL_TOLERANCE = 1e-5 # Tolerance for numerical correctness checks # Backend selection for evaluation diff --git a/MaxKernel/tpu_kernel_gen/agents/kernel_gen_agent/constants.py b/MaxKernel/tpu_kernel_gen/agents/kernel_gen_agent/constants.py index d900c91..ad1aec0 100644 --- a/MaxKernel/tpu_kernel_gen/agents/kernel_gen_agent/constants.py +++ b/MaxKernel/tpu_kernel_gen/agents/kernel_gen_agent/constants.py @@ -1,4 +1,4 @@ -MODEL_NAME = "gemini-3-pro-preview" +MODEL_NAME = "gemini-3.1-pro-preview" MAX_ITERATIONS = 5 LLM_GEN_RETRY_COUNT = 3 TEMPERATURE = 0.1