"""Deploys Pathways service to a Kubernetes cluster using a JobSet template."""

from collections.abc import Sequence
import logging
import math
import os
import string

from absl import app
from absl import flags
from kubernetes import client
from kubernetes import config
import yaml

_logger = logging.getLogger(__name__)

# Flag definitions
FLAGS = flags.FLAGS
_JOBSET_NAME = flags.DEFINE_string(
    "jobset_name", "pathways-service", "Name of the JobSet"
)
_JAX_VERSION = flags.DEFINE_string(
    "jax_version", "0.9.0", "JAX version (e.g., 0.9.0)"
)
_TPU_TYPE = flags.DEFINE_enum(
    "tpu_type", "v6e", ["v5e", "v5p", "v6e", "tpu7x"], "TPU type"
)
_TOPOLOGY = flags.DEFINE_string(
    "topology", "2x2", "TPU topology (e.g., 4x8, 2x2x2)"
)
_NUM_SLICES = flags.DEFINE_integer(
    "num_slices", 2, "Number of TPU slices"
)
_GCS_BUCKET = flags.DEFINE_string(
    "gcs_bucket",
    "gs://pathways-test-bucket",
    "GCS bucket name for scratch space",
)
_TEMPLATE_FILE = flags.DEFINE_string(
    "template_file",
    os.path.join(
        os.path.dirname(__file__), "yamls/pw-service-example.yaml",
    ),
    "Path to the JobSet YAML template file",
)
_DRY_RUN = flags.DEFINE_boolean(
    "dry_run",
    False,
    "If true, only print the generated YAML without deploying.",
)

# Static per-TPU-generation settings used to fill in the JobSet template.
# Hoisted to module level so the dict is not rebuilt on every call.
_TPU_CONFIGS = {
    "v5e": {
        "machine_type": "ct5lp-hightpu-4t",
        "chips_per_vm": 4,
        "accelerator_label": "tpu-v5-lite-podslice",
        "instance_prefix": "tpuv5e",
    },
    "v5p": {
        "machine_type": "ct5p-hightpu-4t",
        "chips_per_vm": 4,
        "accelerator_label": "tpu-v5p-slice",
        "instance_prefix": "tpuv5p",
    },
    "v6e": {
        "machine_type": "ct6e-standard-4t",
        "chips_per_vm": 4,
        "accelerator_label": "tpu-v6e-slice",
        "instance_prefix": "tpuv6e",
    },
    "tpu7x": {
        "machine_type": "tpu7x-standard-4t",
        "chips_per_vm": 4,
        "accelerator_label": "tpu-v7-slice",
        "instance_prefix": "tpu7x",
    },
}


def get_tpu_config(tpu_type: str) -> dict:
  """Returns a dictionary containing TPU configuration details.

  Args:
    tpu_type: One of the supported TPU types ("v5e", "v5p", "v6e", "tpu7x").

  Returns:
    A dict with keys "machine_type", "chips_per_vm", "accelerator_label",
    and "instance_prefix".

  Raises:
    ValueError: If tpu_type is not a supported TPU type.
  """
  if tpu_type not in _TPU_CONFIGS:
    raise ValueError(
        f"Unsupported TPU type: {tpu_type}. Supported types are:"
        f" {list(_TPU_CONFIGS.keys())}"
    )
  return _TPU_CONFIGS[tpu_type]


def calculate_vms_per_slice(topology: str, chips_per_vm: int) -> int:
  """Calculates the number of VMs per slice based on the topology.

  Args:
    topology: TPU topology string such as "2x2" or "2x2x2".
    chips_per_vm: Number of TPU chips attached to each VM.

  Returns:
    The number of VMs needed for one slice of the given topology.

  Raises:
    ValueError: If the topology string is malformed, or if the total chip
      count is not divisible by chips_per_vm.
  """
  # Only the int() parsing lives inside the try-block: the original code also
  # raised the divisibility ValueError inside it, so that error was caught by
  # its own handler and masked by the "Invalid topology format" message.
  try:
    dims = [int(d) for d in topology.split("x")]
  except ValueError as e:
    raise ValueError(
        f"Invalid topology format: {topology}. Expected format like 'AxB' or"
        f" 'AxBxC'. {e}"
    ) from e
  total_chips = math.prod(dims)
  if total_chips % chips_per_vm != 0:
    raise ValueError(
        f"Total chips ({total_chips}) in topology {topology} is not divisible"
        f" by chips_per_vm ({chips_per_vm})"
    )
  return total_chips // chips_per_vm


def load_and_substitute_template(template_path, context):
  """Loads and substitutes the string.Template from the given path.

  Args:
    template_path: Path to a YAML file containing ${VAR} placeholders.
    context: Mapping of placeholder names to replacement values.

  Returns:
    The substituted YAML parsed into Python objects via yaml.safe_load.

  Raises:
    ValueError: If the template file cannot be read, or if the template
      references a placeholder missing from context.
  """
  try:
    with open(template_path, "r") as f:
      template_str = f.read()
  except OSError as err:
    raise ValueError(
        f"Could not read template file: {template_path}: {err}"
    ) from err

  _logger.info("Template file: %s", template_path)
  _logger.info("Context: %s", context)
  template = string.Template(template_str)
  try:
    substituted_yaml = template.substitute(context)
  except KeyError as err:
    # substitute() raises a bare KeyError for any placeholder absent from
    # context; surface it as the ValueError callers already handle.
    raise ValueError(
        f"Template placeholder missing from context: {err}"
    ) from err
  _logger.info("Substituted YAML: %s", substituted_yaml)
  return yaml.safe_load(substituted_yaml)


def deploy_jobset(jobset_yaml):
  """Deploys the JobSet to the current Kubernetes cluster.

  Errors are logged rather than raised: deployment is best-effort and the
  caller has no recovery path beyond reporting the failure.

  Args:
    jobset_yaml: The JobSet manifest as a Python mapping (parsed YAML);
      must contain metadata.namespace and metadata.name.
  """
  try:
    config.load_kube_config()
    api = client.CustomObjectsApi()
    api.create_namespaced_custom_object(
        group="jobset.x-k8s.io",
        version="v1alpha2",
        namespace=jobset_yaml["metadata"]["namespace"],
        body=jobset_yaml,
        plural="jobsets",
    )
    _logger.info(
        "JobSet '%s' created successfully.", jobset_yaml["metadata"]["name"]
    )
  except client.rest.ApiException as e:
    _logger.error("Error creating JobSet: %s", e)
  except config.ConfigException as e:
    _logger.error("Error loading Kubernetes configuration: %s", e)
# TODO: Poll the JobSet status until it is up and surface container logs.
def run_deployment(
    tpu_type,
    topology,
    num_slices,
    jobset_name,
    gcs_bucket,
    jax_version,
    template_file,
    dry_run,
    deploy_func=deploy_jobset,
):
  """Builds the template context and deploys (or dry-runs) the JobSet.

  Args:
    tpu_type: TPU type string, e.g. "v6e".
    topology: TPU topology string, e.g. "2x2".
    num_slices: Number of TPU slices.
    jobset_name: Name for the generated JobSet.
    gcs_bucket: GCS location used for scratch space.
    jax_version: JAX version used to pick container images.
    template_file: Path to the string.Template YAML file.
    dry_run: If True, only log the generated YAML without deploying.
    deploy_func: Callable that performs the actual deployment; injectable
      for testing.
  """
  tpu_config = get_tpu_config(tpu_type)
  chips_per_vm = tpu_config["chips_per_vm"]
  vms_per_slice = calculate_vms_per_slice(topology, chips_per_vm)

  # Placeholder values for the ${VAR} references in the YAML template.
  context = dict(
      JOBSET_NAME=jobset_name,
      JAX_VERSION=jax_version,
      GCS_SCRATCH_LOCATION=gcs_bucket,
      NUM_SLICES=num_slices,
      INSTANCE_TYPE=f"{tpu_config['instance_prefix']}:{topology}",
      VMS_PER_SLICE=vms_per_slice,
      CHIPS_PER_VM=chips_per_vm,
      ACCELERATOR_LABEL=tpu_config["accelerator_label"],
      TOPOLOGY=topology,
  )

  jobset_config = load_and_substitute_template(template_file, context)

  _logger.info("--- Generated JobSet YAML ---")
  _logger.info("\n%s", yaml.dump(jobset_config))
  _logger.info("---")

  if dry_run:
    _logger.info("Dry run mode, not deploying.")
  else:
    deploy_func(jobset_config)


def main(argv: Sequence[str]) -> None:
  """CLI entry point: validates argv and runs the deployment from flags."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  try:
    run_deployment(
        tpu_type=_TPU_TYPE.value,
        topology=_TOPOLOGY.value,
        num_slices=_NUM_SLICES.value,
        jobset_name=_JOBSET_NAME.value,
        gcs_bucket=_GCS_BUCKET.value,
        jax_version=_JAX_VERSION.value,
        template_file=_TEMPLATE_FILE.value,
        dry_run=_DRY_RUN.value,
    )
  except ValueError as err:
    _logger.exception("Error: %s", err)
  except FileNotFoundError:
    # NOTE(review): load_and_substitute_template wraps OSError (including
    # FileNotFoundError) into ValueError, so this handler looks unreachable
    # for template errors — confirm whether another call path can raise it.
    _logger.exception(
        "Error: Template file not found at %s", _TEMPLATE_FILE.value
    )


if __name__ == "__main__":
  app.run(main)
b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service-example.yaml @@ -1,18 +1,171 @@ -apiVersion: pathways-job.pathways.domain/v1 -kind: PathwaysJob +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet metadata: - name: pathways-cluster # jobset name + name: ${JOBSET_NAME} + namespace: default spec: - maxRestarts: 1 - customComponents: - - componentType: pathways_server - image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.8.0@sha256:ccbdf86d185654f8fb749f51ca7dcc8178377b583d75f74180eb936a8f808050 - - componentType: worker - image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.8.0@sha256:ccbdf86d185654f8fb749f51ca7dcc8178377b583d75f74180eb936a8f808050 - workers: # Modify this section to use your TPU type, topology, number of slices and the GCS bucket. - - type: ct6e-standard-4t - topology: 2x2 - numSlices: 2 - pathwaysDir: "gs://pathways-bucket" # Pre-create this bucket. - controller: - deploymentMode: default + coordinator: + replicatedJob: pathways-head + failurePolicy: + maxRestarts: 1 + restartStrategy: Recreate + network: + enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: pathways-head + replicas: 1 + template: + metadata: + annotations: + alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname + spec: + backoffLimit: 3 + completionMode: Indexed + completions: 1 + parallelism: 1 + template: + metadata: + annotations: + alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname + spec: + containers: + - name: pathways-rm + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-${JAX_VERSION} + imagePullPolicy: Always + args: + - --server_port=29001 + - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} + - --node_type=resource_manager + - --instance_count=${NUM_SLICES} + - --instance_type=${INSTANCE_TYPE} + env: + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: 
JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: HOST_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: TPU_SKIP_MDS_QUERY + value: "true" + ports: + - containerPort: 29001 + protocol: TCP + - containerPort: 29002 + protocol: TCP + resources: + limits: + cpu: "8" + memory: 32G + - name: pathways-proxy + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-${JAX_VERSION} + imagePullPolicy: Always + args: + - --server_port=29000 + - --resource_manager_address=$$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + ports: + - containerPort: 29000 + protocol: TCP + resources: + limits: + cpu: "16" + memory: 100G + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + restartPolicy: OnFailure + - name: worker + replicas: ${NUM_SLICES} + template: + spec: + backoffLimit: 1000000 + completionMode: Indexed + completions: ${VMS_PER_SLICE} + parallelism: ${VMS_PER_SLICE} + template: + metadata: + annotations: + alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool + spec: + containers: + - name: pathways-worker + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-${JAX_VERSION} + imagePullPolicy: Always + args: + - --server_port=29005 + - --resource_manager_address=$$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} + env: + - name: TPU_MIN_LOG_LEVEL + value: "0" + - name: TF_CPP_MIN_LOG_LEVEL + value: "0" + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER + value: "false" + - name: MEGASCALE_NUM_SLICES + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: 
metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: MEGASCALE_SLICE_ID + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index'] + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: MEGASCALE_COORDINATOR_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + ports: + - containerPort: 29005 + protocol: TCP + - containerPort: 29006 + protocol: TCP + - containerPort: 8471 + protocol: TCP + - containerPort: 8080 + protocol: TCP + resources: + limits: + google.com/tpu: "${CHIPS_PER_VM}" + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + nodeSelector: + cloud.google.com/gke-tpu-accelerator: ${ACCELERATOR_LABEL} + cloud.google.com/gke-tpu-topology: ${TOPOLOGY} + restartPolicy: OnFailure + volumes: + - name: shared-tmp + hostPath: + path: /tmp + type: DirectoryOrCreate + startupPolicy: + startupPolicyOrder: InOrder + successPolicy: + operator: All