From a5f209c3ef9dcb640e0078a9d31d947209fce1d4 Mon Sep 17 00:00:00 2001 From: Pathways-on-Cloud Team Date: Fri, 29 Aug 2025 14:39:17 -0700 Subject: [PATCH] Add `managed_pathways_service` for Pathways-on-Cloud This change introduces Managed Pathways Service for GKE. It includes: `tpu_manager.py` uses `kubectl` to deploy a Pathways proxy JobSet on a GKE cluster, sets up port forwarding, and configures JAX environment variables to connect to the proxy. `run_connect_example.py` is an example script to start the proxy. Prerequisite: A Pathways cluster is up and running with Resource Manager and worker pods deployed successfully, e.g., using pw-cluster.yaml. TESTED: on proxy 2 clients each requesting `2 x v5e-32` simultaneously. PiperOrigin-RevId: 801034430 --- .../managed_pathways_service/__init__.py | 1 + .../managed_pathways_service/pw-cluster.yaml | 152 +++++++++ .../managed_pathways_service/pw-proxy.yaml | 56 ++++ .../run_connect_example.py | 50 +++ .../managed_pathways_service/tpu_manager.py | 308 ++++++++++++++++++ 5 files changed, 567 insertions(+) create mode 100644 pathwaysutils/managed_pathways_service/__init__.py create mode 100644 pathwaysutils/managed_pathways_service/pw-cluster.yaml create mode 100644 pathwaysutils/managed_pathways_service/pw-proxy.yaml create mode 100644 pathwaysutils/managed_pathways_service/run_connect_example.py create mode 100644 pathwaysutils/managed_pathways_service/tpu_manager.py diff --git a/pathwaysutils/managed_pathways_service/__init__.py b/pathwaysutils/managed_pathways_service/__init__.py new file mode 100644 index 0000000..cad5e6f --- /dev/null +++ b/pathwaysutils/managed_pathways_service/__init__.py @@ -0,0 +1 @@ +# This file marks this directory as a Python package. 
diff --git a/pathwaysutils/managed_pathways_service/pw-cluster.yaml b/pathwaysutils/managed_pathways_service/pw-cluster.yaml new file mode 100644 index 0000000..d299528 --- /dev/null +++ b/pathwaysutils/managed_pathways_service/pw-cluster.yaml @@ -0,0 +1,152 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: pathways-akshu-s4-rw7 +spec: + coordinator: + replicatedJob: pathways-head + failurePolicy: + maxRestarts: 1 + restartStrategy: Recreate + network: + enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: pathways-head + replicas: 1 + template: + metadata: + annotations: + alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname + spec: + backoffLimit: 0 + completionMode: Indexed + completions: 1 + parallelism: 1 + template: + metadata: + labels: + kueue.x-k8s.io/podset: pathways-head + spec: + containers: + - args: + - --server_port=29001 + - --gcs_scratch_location=gs://akshu-v5e + - --node_type=resource_manager + - --instance_count=4 + - --instance_type=tpuv5e:4x8 + - --xla_tpu_use_enhanced_launch_barrier=true + - --logtostderr + - --stderrthreshold=0 + - --v=1 + env: + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: HOST_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: TPU_SKIP_MDS_QUERY + value: "true" + image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_server:latest + imagePullPolicy: Always + name: pathways-rm + ports: + - containerPort: 29001 + protocol: TCP + - containerPort: 29002 + protocol: TCP + resources: + limits: + cpu: "8" + memory: 16G + nodeSelector: + cloud.google.com/gke-nodepool: cpu-np + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + restartPolicy: OnFailure + - name: worker + 
replicas: 4 + template: + metadata: {} + spec: + backoffLimit: 64 + completionMode: Indexed + completions: 8 + parallelism: 8 + template: + metadata: + annotations: + alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool + labels: + kueue.x-k8s.io/podset: worker + spec: + containers: + - args: + - --server_port=29005 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://akshu-v5e + - --xla_tpu_use_enhanced_launch_barrier=true + - --logtostderr + - --stderrthreshold=0 + - --v=1 + env: + - name: TPU_MIN_LOG_LEVEL + value: "0" + - name: TF_CPP_MIN_LOG_LEVEL + value: "0" + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER + value: "false" + - name: MEGASCALE_NUM_SLICES + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: MEGASCALE_SLICE_ID + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index'] + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: MEGASCALE_COORDINATOR_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_server:latest + imagePullPolicy: Always + name: pathways-worker + ports: + - containerPort: 29005 + protocol: TCP + - containerPort: 29006 + protocol: TCP + - containerPort: 8471 + protocol: TCP + - containerPort: 8080 + protocol: TCP + resources: + limits: + google.com/tpu: "4" + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice + cloud.google.com/gke-tpu-topology: 4x8 + 
restartPolicy: OnFailure diff --git a/pathwaysutils/managed_pathways_service/pw-proxy.yaml b/pathwaysutils/managed_pathways_service/pw-proxy.yaml new file mode 100644 index 0000000..0e47507 --- /dev/null +++ b/pathwaysutils/managed_pathways_service/pw-proxy.yaml @@ -0,0 +1,56 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: ${PROXY_NAME} +spec: + coordinator: + replicatedJob: pathways-head + failurePolicy: + maxRestarts: 1 + restartStrategy: Recreate + network: + enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: pathways-head + replicas: 1 + template: + metadata: + annotations: + alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname + spec: + backoffLimit: 0 + completionMode: Indexed + completions: 1 + parallelism: 1 + template: + metadata: + labels: + kueue.x-k8s.io/podset: pathways-head + spec: + containers: + - args: + - --server_port=29000 + - --resource_manager_address=${PATHWAYS_HEAD}:${PATHWAYS_HEAD_PORT} + - --gcs_scratch_location=${GCS_BUCKET} + - --virtual_slices=${EXPECTED_INSTANCES} + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_proxy_server:latest + imagePullPolicy: Always + name: pathways-proxy + ports: + - containerPort: 29000 + protocol: TCP + resources: + limits: + cpu: "16" + memory: 100G + nodeSelector: + cloud.google.com/gke-nodepool: cpu-np + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + restartPolicy: OnFailure diff --git a/pathwaysutils/managed_pathways_service/run_connect_example.py b/pathwaysutils/managed_pathways_service/run_connect_example.py new file mode 100644 index 0000000..1737565 --- /dev/null +++ b/pathwaysutils/managed_pathways_service/run_connect_example.py @@ -0,0 +1,50 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this 
# file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to run JAX code on TPU with the Managed Pathways service."""

from collections.abc import Sequence

from absl import app

from . import tpu_manager


def main(argv: Sequence[str]) -> None:
    """Opens a managed Pathways connection and closes it again.

    The body of the `with` statement is intentionally empty; the comment
    inside it sketches the kind of JAX code a user would place there.

    Args:
        argv: Command-line arguments as passed by `absl.app`; only the
            program name is accepted.

    Raises:
        app.UsageError: If any extra command-line argument is supplied.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    connection = tpu_manager.connect(
        cluster="pw-scale-test-v5e-32",
        project="cloud-tpu-multipod-dev",
        region="us-south1",
        bucket="gs://akshu-v5e",
        pathways_service=(
            "pathways-akshu-s4-rw7-pathways-head-0-0.pathways-akshu-s4-rw7:29001"
        ),
        expected_instances={"tpuv5e:4x8": 2},
    )
    with connection:
        # Example workload (left disabled):
        #
        #   import jax.numpy as jnp
        #   import pathwaysutils
        #   import pprint
        #
        #   pathwaysutils.initialize()
        #
        #   orig_matrix = jnp.zeros(5)
        #
        #   print("start")
        #   result_matrix = orig_matrix + 1
        #   print("Original Random Matrix:")
        #   pprint.pprint(orig_matrix)
        #   print("\nMatrix after adding 1:")
        #   pprint.pprint(result_matrix)
        pass


if __name__ == "__main__":
    app.run(main)

# diff --git a/pathwaysutils/managed_pathways_service/tpu_manager.py b/pathwaysutils/managed_pathways_service/tpu_manager.py
# new file mode 100644
# index 0000000..08aa251
# --- /dev/null
# +++ b/pathwaysutils/managed_pathways_service/tpu_manager.py
# @@ -0,0 +1,308 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for connecting to a Pathways server."""

from collections.abc import Iterator
import contextlib
import logging
import os
import random
import socket
import string
import subprocess

PROXY_PORT = 29000

# Highest valid TCP port; the search for a free local port stops here.
_MAX_PORT = 65535
# Seconds to wait for `kubectl port-forward` to exit after terminate() before
# escalating to kill().
_PORT_FORWARD_SHUTDOWN_TIMEOUT_S = 10

logger = logging.getLogger(__name__)
# NOTE(review): this configures the root logger as an import side effect. It is
# kept because existing users of the module may rely on the INFO output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)


def fetch_cluster_credentials(
    cluster_name: str, project_id: str, location: str
) -> None:
    """Fetches kubectl credentials for the GKE cluster via gcloud.

    Args:
        cluster_name: Name of the GKE cluster.
        project_id: Google Cloud project that owns the cluster.
        location: Compute zone or region of the cluster.

    Raises:
        subprocess.CalledProcessError: If the gcloud command fails.
    """
    # Always ensure we have fresh credentials for kubectl.
    logger.info("Fetching credentials for '%s'.", cluster_name)
    command = [
        "gcloud",
        "container",
        "clusters",
        "get-credentials",
        cluster_name,
        # `--location` accepts a zone or a region; callers in this package
        # pass a region, which `--zone` does not describe.
        "--location",
        location,
        "--project",
        project_id,
    ]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        logger.error(
            "Failed to get cluster credentials. gcloud output:\n%s", e.stderr
        )
        raise


def _split_service_address(pathways_service: str) -> tuple[str, str]:
    """Splits '<host>:<port>' into (host, port), validating both parts."""
    # rpartition keeps any colons that belong to the host part intact.
    host, separator, port = pathways_service.rpartition(":")
    if not separator or not host or not port.isdigit():
        raise ValueError(
            "pathways_service must have the form '<host>:<port>', got "
            f"{pathways_service!r}."
        )
    return host, port


def _format_virtual_slices(expected_instances: dict[str, int]) -> str:
    """Returns the first machine type repeated `count` times, comma-joined."""
    if not expected_instances:
        raise ValueError("No instances found.")
    # Only the first entry is used; validate_instance_list() enforces that
    # callers going through connect() supply exactly one machine type.
    machine_type, count = next(iter(expected_instances.items()))
    return ",".join([machine_type] * count)


def deploy_pathways_cluster_pods(
    pathways_service: str,
    proxy_name: str,
    expected_instances: dict[str, int],
    gcs_bucket: str,
):
    """Deploys the Pathways proxy pods to the GKE cluster.

    Renders the `pw-proxy.yaml` template that sits next to this module and
    pipes the result to `kubectl apply -f -`.

    Args:
        pathways_service: The service name and port of the Pathways head, as
            '<host>:<port>'.
        proxy_name: The name to use for the deployed proxy.
        expected_instances: Mapping of machine type to instance count; the
            first entry is expanded into the template's EXPECTED_INSTANCES.
        gcs_bucket: Value substituted for the template's GCS_BUCKET.

    Raises:
        ValueError: If `pathways_service` is malformed or `expected_instances`
            is empty.
        subprocess.CalledProcessError: If the kubectl command fails.
    """
    logger.info("Deploying Pathways proxy")
    # Validate the inputs before touching the filesystem or the cluster.
    pathways_head, pathways_head_port = _split_service_address(
        pathways_service
    )
    instances_str = _format_virtual_slices(expected_instances)

    yaml_path = os.path.join(os.path.dirname(__file__), "pw-proxy.yaml")
    with open(yaml_path, "r", encoding="utf-8") as f:
        yaml_template = f.read()

    substituted_yaml = string.Template(yaml_template).substitute(
        PROXY_NAME=proxy_name,
        PATHWAYS_HEAD=pathways_head,
        PATHWAYS_HEAD_PORT=pathways_head_port,
        EXPECTED_INSTANCES=instances_str,
        GCS_BUCKET=gcs_bucket,
    )

    logger.info("Proxy name: %s", proxy_name)
    try:
        proxy_result = subprocess.run(
            ["kubectl", "apply", "-f", "-"],
            input=substituted_yaml,
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(
            "Failed to deploy Pathways proxy. kubectl output:\n%s", e.stderr
        )
        raise
    logger.info(
        "Successfully deployed Pathways proxy. %s", proxy_result.stdout
    )


def _find_free_local_port(starting_port: int) -> int:
    """Finds a free local port, starting from the given port.

    Args:
        starting_port: First port number to try.

    Returns:
        The first port at or above `starting_port` that could be bound.

    Raises:
        OSError: If no port up to 65535 could be bound.
    """
    for port in range(starting_port, _MAX_PORT + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("", port))
            except OSError:
                logger.info("Port %d is in use, trying next port.", port)
                continue
        # The probe socket is closed again here, so another process could
        # still grab the port before the caller uses it.
        logger.info("Port binding successful at port: %d", port)
        return port
    raise OSError(
        f"No free local port found in range {starting_port}-{_MAX_PORT}."
    )


class TPUManager:
    """Context manager that deploys a Pathways proxy and points JAX at it.

    On entry it deploys a proxy JobSet, waits for its pod, starts
    `kubectl port-forward` to it and sets JAX_PLATFORMS / JAX_BACKEND_TARGET.
    On exit it stops the port forward, restores both environment variables
    and deletes the proxy JobSet.
    """

    def __init__(
        self,
        cluster: str,
        project: str,
        region: str,
        bucket: str,
        pathways_service: str,
        expected_instances: dict[str, int],
    ):
        """Initializes the TPU manager.

        Args:
            cluster: Name of the GKE cluster (used for logging here).
            project: Google Cloud project of the cluster.
            region: Location of the cluster.
            bucket: GCS location handed to the proxy as scratch space.
            pathways_service: '<host>:<port>' of the Pathways head.
            expected_instances: Mapping of machine type to instance count.
        """
        self.cluster = cluster
        self.project = project
        self.region = region
        self.bucket = bucket
        self.pathways_service = pathways_service
        self.expected_instances = expected_instances
        # A short random suffix keeps concurrently created proxies apart.
        suffix = "".join(
            random.choices(string.ascii_lowercase + string.digits, k=5)
        )
        self.proxy_name = f"akshu-s4-{suffix}"
        # Save the original JAX environment variables so they can be restored
        # when the context manager exits.
        self.original_jax_platforms = None
        self.original_jax_backend_target = None
        self.port_forward_process = None
        self.proxy_port = None

    @property
    def _pod_selector(self) -> str:
        """Label selector matching the pods of this manager's proxy JobSet."""
        return f"jobset.sigs.k8s.io/jobset-name={self.proxy_name}"

    def __enter__(self):
        """Deploys the proxy, starts port forwarding and configures JAX.

        Returns:
            This manager.

        Raises:
            RuntimeError: If the proxy pod does not become ready or cannot be
                found.
            subprocess.CalledProcessError: If a kubectl command fails.
        """
        self.original_jax_platforms = os.environ.get("JAX_PLATFORMS")
        self.original_jax_backend_target = os.environ.get("JAX_BACKEND_TARGET")
        try:
            deploy_pathways_cluster_pods(
                self.pathways_service,
                self.proxy_name,
                self.expected_instances,
                self.bucket,
            )
            self._wait_for_proxy_ready()
            proxy_pod = self._get_proxy_pod_name()
            logger.info("Proxy pod ready: %s", proxy_pod)

            self.proxy_port = _find_free_local_port(PROXY_PORT)

            # Start port forwarding in the background.
            logger.info(
                "Starting port forwarding from local port %d to %s",
                self.proxy_port,
                proxy_pod,
            )
            self.port_forward_process = subprocess.Popen([
                "kubectl",
                "port-forward",
                proxy_pod,
                f"{self.proxy_port}:{PROXY_PORT}",
            ])

            os.environ["JAX_PLATFORMS"] = "proxy"
            os.environ["JAX_BACKEND_TARGET"] = (
                f"grpc://127.0.0.1:{self.proxy_port}"
            )
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release
            # everything that was set up so far before propagating.
            self._abort_setup()
            raise
        logger.info("TPU manager ready for cluster '%s'.", self.cluster)
        return self

    def _wait_for_proxy_ready(self) -> None:
        """Blocks until the proxy pod reports ready; logs pod output on failure."""
        logger.info("Waiting for proxy pod to be ready...")
        # NOTE(review): `kubectl wait` may fail right away if no pod matches
        # the selector yet (pod not created so soon after apply) - confirm.
        wait_command = [
            "kubectl",
            "wait",
            "--for=condition=ready",
            "pod",
            "-l",
            self._pod_selector,
            "--timeout=30s",
        ]
        try:
            subprocess.run(
                wait_command, check=True, capture_output=True, text=True
            )
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
            logger.error(
                "Error waiting for proxy pod to become ready: %s", e.stderr
            )
            self._log_proxy_output()
            raise RuntimeError("Proxy pod did not become ready.") from e

    def _log_proxy_output(self) -> None:
        """Best-effort dump of the proxy pod logs to the error log."""
        # Uses the same label selector as the wait and lookup commands rather
        # than relying on `kubectl logs` resolving the JobSet custom resource
        # to its pods. `--tail=-1` asks for the full log; with a selector
        # kubectl otherwise shows only the last few lines.
        log_command = [
            "kubectl",
            "logs",
            "-l",
            self._pod_selector,
            "--tail=-1",
        ]
        try:
            logs_result = subprocess.run(
                log_command, check=True, capture_output=True, text=True
            )
        except subprocess.CalledProcessError as log_e:
            logger.error(
                "Could not retrieve logs for jobset/%s: %s",
                self.proxy_name,
                log_e.stderr,
            )
            return
        logger.error(
            "Logs from jobset/%s:\n%s", self.proxy_name, logs_result.stdout
        )

    def _get_proxy_pod_name(self) -> str:
        """Returns the name of the first pod belonging to the proxy JobSet."""
        get_pod_command = [
            "kubectl",
            "get",
            "pods",
            "-l",
            self._pod_selector,
            "--no-headers",
            "-o",
            "custom-columns=:metadata.name",
        ]
        pod_result = subprocess.run(
            get_pod_command, check=True, capture_output=True, text=True
        )
        proxy_pod = pod_result.stdout.strip().split("\n")[0].strip()
        if not proxy_pod:
            raise RuntimeError(
                f"No pod found for proxy JobSet {self.proxy_name!r}."
            )
        return proxy_pod

    def get_pathways_service(self):
        """Returns the Pathways service."""
        return self.pathways_service

    def _stop_port_forward(self) -> None:
        """Stops the background `kubectl port-forward` process, if any."""
        process = self.port_forward_process
        if process is None:
            return
        process.terminate()
        try:
            process.wait(timeout=_PORT_FORWARD_SHUTDOWN_TIMEOUT_S)
        except subprocess.TimeoutExpired:
            # Do not hang the caller forever on an unresponsive kubectl.
            logger.warning("Port forwarding did not stop in time; killing it.")
            process.kill()
            process.wait()
        self.port_forward_process = None

    @staticmethod
    def _restore_env_var(name, original_value) -> None:
        """Resets `name` to `original_value`, unsetting it when that is None."""
        if original_value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = original_value

    def _restore_jax_env(self) -> None:
        """Restores the JAX environment variables captured in __enter__."""
        self._restore_env_var("JAX_PLATFORMS", self.original_jax_platforms)
        self._restore_env_var(
            "JAX_BACKEND_TARGET", self.original_jax_backend_target
        )

    def _delete_proxy(self) -> str:
        """Deletes the proxy JobSet and returns kubectl's standard output."""
        result = subprocess.run(
            [
                "kubectl",
                "delete",
                "jobset",
                self.proxy_name,
                "--ignore-not-found",
            ],
            check=True,
            capture_output=True,
            text=True,
        )
        return result.stdout

    def _abort_setup(self) -> None:
        """Releases partially acquired resources after a failed __enter__."""
        self._stop_port_forward()
        self._restore_jax_env()
        try:
            self._delete_proxy()
        except (subprocess.CalledProcessError, OSError) as cleanup_error:
            # Only log: the setup error that triggered the cleanup is the one
            # the caller needs to see.
            logger.error(
                "Failed to delete Pathways proxy after setup error: %s",
                cleanup_error,
            )

    def __exit__(self, exc_type, exc_value, traceback):
        """Stops port forwarding, restores JAX env vars, deletes the proxy.

        Raises:
            subprocess.CalledProcessError: If deleting the proxy JobSet fails.
        """
        self._stop_port_forward()
        self._restore_jax_env()
        logger.info("Deleting Pathways proxy")
        try:
            deleted_output = self._delete_proxy()
        except subprocess.CalledProcessError as e:
            logger.error(
                "Failed to delete Pathways proxy. kubectl output:\n%s",
                e.stderr,
            )
            raise
        logger.info("Successfully deleted Pathways proxy. %s", deleted_output)
        logger.info("Exiting TPUManager context.")


def validate_instance_list(expected_instances: dict[str, int]):
    """Validates the instance list.

    Args:
        expected_instances: Mapping of machine type to instance count.

    Raises:
        ValueError: If the mapping is empty, contains a blank machine type,
            contains a count that is not a positive integer, or names more
            than one machine type.
    """
    if not expected_instances:
        logger.error("No instances found.")
        raise ValueError("No instances found.")
    for machine_type, count in expected_instances.items():
        if not machine_type.strip():
            logger.error("Instance list contains empty string.")
            raise ValueError("Instance list contains empty string.")
        if not isinstance(count, int) or count < 1:
            logger.error(
                "Invalid instance count for %r: %r", machine_type, count
            )
            raise ValueError(
                f"Instance count for {machine_type!r} must be a positive "
                f"integer, got {count!r}."
            )
    # Raised explicitly rather than asserted so the check survives `python -O`.
    if len(expected_instances) != 1:
        logger.error("Only one machine type is supported at this time.")
        raise ValueError("Only one machine type is supported at this time.")


@contextlib.contextmanager
def connect(
    cluster: str,
    project: str,
    region: str,
    bucket: str,
    pathways_service: str,
    expected_instances: dict[str, int],
) -> Iterator[TPUManager]:
    """Connects to a Pathways server through a temporary proxy.

    Validates `expected_instances`, fetches cluster credentials and then
    yields an entered `TPUManager`; the manager is exited when the `with`
    block ends. The cluster itself is not created here.

    Args:
        cluster: Name of the GKE cluster.
        project: Google Cloud project of the cluster.
        region: Location of the cluster.
        bucket: GCS location handed to the proxy.
        pathways_service: '<host>:<port>' of the Pathways head.
        expected_instances: Mapping of machine type to instance count.

    Yields:
        The active `TPUManager`.
    """
    validate_instance_list(expected_instances)
    fetch_cluster_credentials(cluster, project, region)
    with TPUManager(
        cluster, project, region, bucket, pathways_service, expected_instances
    ) as t:
        # TPUManager.__exit__ releases the proxy when the caller's block ends.
        yield t


def run():
    """Placeholder entry point; currently does nothing."""
    pass