Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""Deploys Pathways service to a Kubernetes cluster using a JobSet template."""

from collections.abc import Sequence
import logging
import math
import os
import string
from absl import app
from absl import flags
from kubernetes import client
from kubernetes import config
import yaml

# Module-level logger; configured by the application (absl) at startup.
_logger = logging.getLogger(__name__)

# Flag definitions
FLAGS = flags.FLAGS
# Name given to the generated JobSet custom object.
_JOBSET_NAME = flags.DEFINE_string(
    "jobset_name", "pathways-service", "Name of the JobSet"
)
# JAX version tag used to select the Pathways server/proxy container images.
_JAX_VERSION = flags.DEFINE_string(
    "jax_version", "0.9.0", "JAX version (e.g., 0.9.0)"
)
# TPU generation; must be one of the keys handled by get_tpu_config().
_TPU_TYPE = flags.DEFINE_enum(
    "tpu_type", "v6e", ["v5e", "v5p", "v6e", "tpu7x"], "TPU type"
)
# Chip topology of one slice, "AxB" or "AxBxC" (parsed by
# calculate_vms_per_slice()).
_TOPOLOGY = flags.DEFINE_string(
    "topology", "2x2", "TPU topology (e.g., 4x8, 2x2x2)"
)
_NUM_SLICES = flags.DEFINE_integer(
    "num_slices", 2, "Number of TPU slices"
)
# Pathways scratch location; the bucket must already exist.
_GCS_BUCKET = flags.DEFINE_string(
    "gcs_bucket",
    "gs://pathways-test-bucket",
    "GCS bucket name for scratch space",
)
# Defaults to the example template shipped next to this script.
_TEMPLATE_FILE = flags.DEFINE_string(
    "template_file",
    os.path.join(
        os.path.dirname(__file__), "yamls/pw-service-example.yaml",
    ),
    "Path to the JobSet YAML template file",
)
_DRY_RUN = flags.DEFINE_boolean(
    "dry_run",
    False,
    "If true, only print the generated YAML without deploying.",
)


def get_tpu_config(tpu_type):
  """Return deployment parameters for a supported TPU generation.

  Args:
    tpu_type: TPU generation string ("v5e", "v5p", "v6e", or "tpu7x").

  Returns:
    Dict with keys "machine_type", "chips_per_vm", "accelerator_label",
    and "instance_prefix".

  Raises:
    ValueError: If tpu_type is not one of the supported generations.
  """
  # (machine_type, accelerator_label, instance_prefix) per generation;
  # every supported type exposes 4 chips per VM.
  known_types = {
      "v5e": ("ct5lp-hightpu-4t", "tpu-v5-lite-podslice", "tpuv5e"),
      "v5p": ("ct5p-hightpu-4t", "tpu-v5p-slice", "tpuv5p"),
      "v6e": ("ct6e-standard-4t", "tpu-v6e-slice", "tpuv6e"),
      "tpu7x": ("tpu7x-standard-4t", "tpu-v7-slice", "tpu7x"),
  }
  if tpu_type not in known_types:
    raise ValueError(
        f"Unsupported TPU type: {tpu_type}. Supported types are:"
        f" {list(known_types.keys())}"
    )
  machine_type, accelerator_label, instance_prefix = known_types[tpu_type]
  return {
      "machine_type": machine_type,
      "chips_per_vm": 4,
      "accelerator_label": accelerator_label,
      "instance_prefix": instance_prefix,
  }


def calculate_vms_per_slice(topology, chips_per_vm):
  """Calculate the number of VMs per slice from a topology string.

  Args:
    topology: Chip topology such as "2x2" or "2x2x4"; the product of the
      dimensions is the total chip count of one slice.
    chips_per_vm: Number of TPU chips attached to each VM.

  Returns:
    The number of VMs needed to host one slice.

  Raises:
    ValueError: If the topology string is malformed, or if the total chip
      count is not divisible by chips_per_vm.
  """
  # Keep the try narrow: previously the divisibility ValueError raised below
  # was caught by this same handler and misreported as a format error.
  try:
    dims = [int(d) for d in topology.split("x")]
  except ValueError as e:
    raise ValueError(
        f"Invalid topology format: {topology}. Expected format like 'AxB' or"
        f" 'AxBxC'. {e}"
    ) from e
  total_chips = math.prod(dims)
  if total_chips % chips_per_vm != 0:
    raise ValueError(
        f"Total chips ({total_chips}) in topology {topology} is not divisible"
        f" by chips_per_vm ({chips_per_vm})"
    )
  return total_chips // chips_per_vm


def load_and_substitute_template(template_path, context):
  """Load a string.Template YAML file and substitute ${VAR} placeholders.

  Args:
    template_path: Path to the YAML template file.
    context: Mapping of placeholder names to substitution values.

  Returns:
    The parsed YAML document (output of yaml.safe_load) after substitution.

  Raises:
    ValueError: If the file cannot be read, or if the template references a
      placeholder that is missing from context.
  """
  try:
    # Read as UTF-8 explicitly so behavior does not depend on the locale.
    with open(template_path, "r", encoding="utf-8") as f:
      template_str = f.read()
  except OSError as err:
    raise ValueError(
        f"Could not read template file: {template_path}: {err}"
    ) from err

  _logger.info("Template file: %s", template_path)
  _logger.info("Context: %s", context)
  template = string.Template(template_str)
  try:
    # substitute() (not safe_substitute) so a placeholder missing from
    # `context` fails loudly instead of leaking "${VAR}" into the manifest.
    # Note: a literal "$$" in the template collapses to a single "$".
    substituted_yaml = template.substitute(context)
  except KeyError as err:
    # Template.substitute raises a bare KeyError; rewrap it with context so
    # the failing placeholder is obvious to the operator.
    raise ValueError(
        f"Template placeholder {err} has no value in the substitution context"
    ) from err
  _logger.info("Substituted YAML: %s", substituted_yaml)
  return yaml.safe_load(substituted_yaml)


def deploy_jobset(jobset_yaml):
  """Create the JobSet custom object in the current Kubernetes cluster.

  Uses the local kubeconfig context. Both API errors and kubeconfig-loading
  errors are logged and swallowed (the function returns normally), so callers
  cannot observe failure programmatically.

  Args:
    jobset_yaml: Parsed JobSet manifest (dict) containing at least
      metadata.name and metadata.namespace.
  """
  try:
    config.load_kube_config()
    metadata = jobset_yaml["metadata"]
    client.CustomObjectsApi().create_namespaced_custom_object(
        group="jobset.x-k8s.io",
        version="v1alpha2",
        plural="jobsets",
        namespace=metadata["namespace"],
        body=jobset_yaml,
    )
    _logger.info("JobSet '%s' created successfully.", metadata["name"])
  except client.rest.ApiException as api_err:
    _logger.error("Error creating JobSet: %s", api_err)
  except config.ConfigException as cfg_err:
    _logger.error("Error loading Kubernetes configuration: %s", cfg_err)
  # TODO idea -- keep checking until up -- surface logs.

def run_deployment(
    tpu_type,
    topology,
    num_slices,
    jobset_name,
    gcs_bucket,
    jax_version,
    template_file,
    dry_run,
    deploy_func=deploy_jobset,
):
  """Render the JobSet template with derived TPU settings and deploy it.

  Args:
    tpu_type: TPU generation, e.g. "v6e".
    topology: Slice topology string, e.g. "2x2".
    num_slices: Number of TPU slices.
    jobset_name: Name for the generated JobSet.
    gcs_bucket: GCS location used as Pathways scratch space.
    jax_version: JAX version tag for the Pathways container images.
    template_file: Path to the string.Template YAML file.
    dry_run: If True, only log the rendered manifest; do not deploy.
    deploy_func: Callable that receives the rendered manifest; injectable
      for testing, defaults to deploy_jobset.

  Raises:
    ValueError: Propagated from TPU-config lookup, topology parsing, or
      template rendering.
  """
  tpu_config = get_tpu_config(tpu_type)
  chips_per_vm = tpu_config["chips_per_vm"]
  vms_per_slice = calculate_vms_per_slice(topology, chips_per_vm)

  # Placeholder values matching the ${VAR} names used in the YAML template.
  substitutions = {
      "JOBSET_NAME": jobset_name,
      "JAX_VERSION": jax_version,
      "GCS_SCRATCH_LOCATION": gcs_bucket,
      "NUM_SLICES": num_slices,
      "INSTANCE_TYPE": f"{tpu_config['instance_prefix']}:{topology}",
      "VMS_PER_SLICE": vms_per_slice,
      "CHIPS_PER_VM": chips_per_vm,
      "ACCELERATOR_LABEL": tpu_config["accelerator_label"],
      "TOPOLOGY": topology,
  }

  jobset_config = load_and_substitute_template(template_file, substitutions)

  _logger.info("--- Generated JobSet YAML ---")
  _logger.info("\n%s", yaml.dump(jobset_config))
  _logger.info("---")

  if dry_run:
    _logger.info("Dry run mode, not deploying.")
  else:
    deploy_func(jobset_config)


def main(argv: Sequence[str]) -> None:
  """CLI entry point: validate argv and run the deployment from flag values."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  try:
    run_deployment(
        jobset_name=_JOBSET_NAME.value,
        jax_version=_JAX_VERSION.value,
        tpu_type=_TPU_TYPE.value,
        topology=_TOPOLOGY.value,
        num_slices=_NUM_SLICES.value,
        gcs_bucket=_GCS_BUCKET.value,
        template_file=_TEMPLATE_FILE.value,
        dry_run=_DRY_RUN.value,
    )
  except ValueError as exc:
    _logger.exception("Error: %s", exc)
  except FileNotFoundError:
    # NOTE(review): load_and_substitute_template rewraps OSError into
    # ValueError, so this branch appears unreachable for the template path —
    # confirm before removing.
    _logger.exception(
        "Error: Template file not found at %s", _TEMPLATE_FILE.value
    )


if __name__ == "__main__":
  # absl parses the flags defined above, then invokes main(argv).
  app.run(main)
Original file line number Diff line number Diff line change
@@ -1,18 +1,171 @@
# JobSet template for a Pathways service, rendered by the deploy script via
# string.Template: ${VAR} placeholders are substituted, and "$$" escapes a
# literal "$" (used below for shell-style $(PATHWAYS_HEAD) expansion).
# NOTE: this replaces the interleaved old/new diff residue (the removed
# PathwaysJob manifest lines), which made the document invalid YAML with
# duplicate apiVersion/kind keys.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: ${JOBSET_NAME}
  namespace: default
spec:
  coordinator:
    replicatedJob: pathways-head
  failurePolicy:
    maxRestarts: 1
    restartStrategy: Recreate
  network:
    enableDNSHostnames: true
    publishNotReadyAddresses: true
  replicatedJobs:
  # Head job: runs the Pathways resource manager and the proxy.
  - name: pathways-head
    replicas: 1
    template:
      metadata:
        annotations:
          alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
      spec:
        backoffLimit: 3
        completionMode: Indexed
        completions: 1
        parallelism: 1
        template:
          metadata:
            annotations:
              alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
          spec:
            containers:
            - name: pathways-rm
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-${JAX_VERSION}
              imagePullPolicy: Always
              args:
              - --server_port=29001
              - --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
              - --node_type=resource_manager
              - --instance_count=${NUM_SLICES}
              - --instance_type=${INSTANCE_TYPE}
              env:
              - name: REPLICATED_JOB_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
              - name: JOBSET_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
              - name: HOST_ADDRESS
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              ports:
              - containerPort: 29001
                protocol: TCP
              - containerPort: 29002
                protocol: TCP
              resources:
                limits:
                  cpu: "8"
                  memory: 32G
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-${JAX_VERSION}
              imagePullPolicy: Always
              args:
              - --server_port=29000
              # "$$" renders as "$" so the container sees $(PATHWAYS_HEAD),
              # expanded by Kubernetes from the env var below.
              - --resource_manager_address=$$(PATHWAYS_HEAD):29001
              - --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
              env:
              - name: PATHWAYS_HEAD
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
              ports:
              - containerPort: 29000
                protocol: TCP
              resources:
                limits:
                  cpu: "16"
                  memory: 100G
            dnsPolicy: ClusterFirstWithHostNet
            hostNetwork: true
            restartPolicy: OnFailure
  # Worker jobs: one replicated job per slice, one pod per TPU VM.
  - name: worker
    replicas: ${NUM_SLICES}
    template:
      spec:
        backoffLimit: 1000000
        completionMode: Indexed
        completions: ${VMS_PER_SLICE}
        parallelism: ${VMS_PER_SLICE}
        template:
          metadata:
            annotations:
              alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
          spec:
            containers:
            - name: pathways-worker
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-${JAX_VERSION}
              imagePullPolicy: Always
              args:
              - --server_port=29005
              - --resource_manager_address=$$(PATHWAYS_HEAD):29001
              - --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
              env:
              - name: TPU_MIN_LOG_LEVEL
                value: "0"
              - name: TF_CPP_MIN_LOG_LEVEL
                value: "0"
              - name: XCLOUD_ENVIRONMENT
                value: GCP
              - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER
                value: "false"
              - name: MEGASCALE_NUM_SLICES
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']
              - name: JOBSET_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
              - name: REPLICATED_JOB_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
              - name: MEGASCALE_SLICE_ID
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index']
              - name: PATHWAYS_HEAD
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
              - name: MEGASCALE_COORDINATOR_ADDRESS
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
              ports:
              - containerPort: 29005
                protocol: TCP
              - containerPort: 29006
                protocol: TCP
              - containerPort: 8471
                protocol: TCP
              - containerPort: 8080
                protocol: TCP
              resources:
                limits:
                  google.com/tpu: "${CHIPS_PER_VM}"
              volumeMounts:
              - mountPath: /tmp
                name: shared-tmp
            dnsPolicy: ClusterFirstWithHostNet
            hostNetwork: true
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: ${ACCELERATOR_LABEL}
              cloud.google.com/gke-tpu-topology: ${TOPOLOGY}
            restartPolicy: OnFailure
            volumes:
            - name: shared-tmp
              hostPath:
                path: /tmp
                type: DirectoryOrCreate
  startupPolicy:
    startupPolicyOrder: InOrder
  successPolicy:
    operator: All
Loading