90 changes: 89 additions & 1 deletion .github/scripts/filter-matrix.py
@@ -60,6 +60,41 @@ def filter_matrix_item(
return True


def create_distributed_config(item: Dict[str, Any]) -> Dict[str, Any]:
"""Create distributed test configuration from a regular config.

Takes a standard test config and modifies it for distributed testing:
- Changes runner to multi-GPU instance
- Adds num_gpus field
- Adds config marker
"""
import sys

# Create a copy to avoid modifying the original
dist_item = item.copy()

# Debug: Show original config
print(f"[DEBUG] Creating distributed config from:", file=sys.stderr)
print(f"[DEBUG] Python: {item.get('python_version')}", file=sys.stderr)
print(f"[DEBUG] CUDA: {item.get('desired_cuda')}", file=sys.stderr)
print(
f"[DEBUG] Original runner: {item.get('validation_runner')}", file=sys.stderr
)

# Override runner to use multi-GPU instance
dist_item["validation_runner"] = "linux.g4dn.12xlarge.nvidia.gpu"

# Add distributed-specific fields
dist_item["num_gpus"] = 2
dist_item["config"] = "distributed"

# Debug: Show modified config
print(f"[DEBUG] New runner: {dist_item['validation_runner']}", file=sys.stderr)
print(f"[DEBUG] GPUs: {dist_item['num_gpus']}", file=sys.stderr)

return dist_item


def main(args: list[str]) -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -99,16 +134,69 @@ def main(args: list[str]) -> None:

includes = matrix_dict["include"]
filtered_includes = []
distributed_includes = [] # NEW: separate list for distributed configs

print(f"[DEBUG] Processing {len(includes)} input configs", file=sys.stderr)

for item in includes:
py_ver = item.get("python_version", "unknown")
cuda_ver = item.get("desired_cuda", "unknown")

print(f"[DEBUG] Checking config: py={py_ver}, cuda={cuda_ver}", file=sys.stderr)

if filter_matrix_item(
item,
options.jetpack == "true",
options.limit_pr_builds == "true",
):
print(f"[DEBUG] passed filter - adding to build matrix", file=sys.stderr)
filtered_includes.append(item)

filtered_matrix_dict = {"include": filtered_includes}
# NEW: Create distributed variant for specific configs
# Only Python 3.10 + CUDA 13.0 for now
if item["python_version"] == "3.10" and item["desired_cuda"] == "cu130":
print(
f"[DEBUG] Creating distributed config for py3.10+cu130",
file=sys.stderr,
)
distributed_includes.append(create_distributed_config(item))
else:
print(f"[DEBUG] FILTERED OUT", file=sys.stderr)

# Debug: Show summary
print(f"[DEBUG] Final counts:", file=sys.stderr)
print(f"[DEBUG] Regular configs: {len(filtered_includes)}", file=sys.stderr)
print(
f"[DEBUG] Distributed configs: {len(distributed_includes)}", file=sys.stderr
)

# Debug: Show which configs will be built
print(
f"[DEBUG] Configs that will be BUILT (in filtered_includes):", file=sys.stderr
)
for item in filtered_includes:
print(
f"[DEBUG] - py={item.get('python_version')}, cuda={item.get('desired_cuda')}",
file=sys.stderr,
)

print(
f"[DEBUG] Configs for DISTRIBUTED TESTS (in distributed_includes):",
file=sys.stderr,
)
for item in distributed_includes:
print(
f"[DEBUG] - py={item.get('python_version')}, cuda={item.get('desired_cuda')}, gpus={item.get('num_gpus')}",
file=sys.stderr,
)

# NEW: Output both regular and distributed configs
filtered_matrix_dict = {
"include": filtered_includes,
"distributed_include": distributed_includes, # NEW field
}

# Output to stdout (consumed by GitHub Actions)
print(json.dumps(filtered_matrix_dict))


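For reference, a minimal Python sketch of what the filter script now emits. `to_distributed` is a local stand-in for `create_distributed_config` (debug prints omitted), and the `base_item` values are illustrative, not taken from this PR.

```python
import json

def to_distributed(item: dict) -> dict:
    # Local stand-in for create_distributed_config (debug prints omitted):
    # copy the entry, move it onto the multi-GPU runner, and tag it.
    dist_item = item.copy()
    dist_item["validation_runner"] = "linux.g4dn.12xlarge.nvidia.gpu"
    dist_item["num_gpus"] = 2
    dist_item["config"] = "distributed"
    return dist_item

# Hypothetical matrix entry; the values are illustrative.
base_item = {
    "python_version": "3.10",
    "desired_cuda": "cu130",
    "validation_runner": "linux.g5.4xlarge.nvidia.gpu",
}

# Shape of the JSON the script now prints to stdout for GitHub Actions:
# regular configs under "include", distributed variants under "distributed_include".
matrix = {
    "include": [base_item],
    "distributed_include": [to_distributed(base_item)],
}
print(json.dumps(matrix, indent=2))
```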
46 changes: 41 additions & 5 deletions .github/workflows/build-test-linux-x86_64.yml
@@ -68,7 +68,11 @@ jobs:
ref: ""
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.filter-matrix.outputs.matrix }}
# Extract the include array from filter-matrix output
build-matrix: |
{
"include": ${{ toJSON(fromJSON(needs.filter-matrix.outputs.matrix).include) }}
}
pre-script: ${{ matrix.pre-script }}
env-var-script: ${{ matrix.env-var-script }}
post-script: ${{ matrix.post-script }}
@@ -480,18 +484,50 @@ jobs:
ref: ""
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.filter-matrix.outputs.matrix }}
# Extract the distributed_include array from filter-matrix output
build-matrix: |
{
"include": ${{ toJSON(fromJSON(needs.filter-matrix.outputs.matrix).distributed_include) }}
}
pre-script: ${{ matrix.pre-script }}
script: |
set -euo pipefail

# Debug: Show what config we're using
echo "=========================================="
echo "DISTRIBUTED TEST CONFIGURATION"
echo "=========================================="
echo "Python version: ${PYTHON_VERSION}"
echo "CUDA version: ${CU_VERSION}"
echo "Runner: ${{ matrix.validation_runner }}"
echo "Num GPUs: ${{ matrix.num_gpus }}"
echo "Config: ${{ matrix.config }}"
echo "=========================================="

# Verify GPUs are available
echo "Checking GPU availability:"
nvidia-smi
echo "GPU count: $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)"
echo "=========================================="

export USE_HOST_DEPS=1
export CI_BUILD=1
export USE_TRTLLM_PLUGINS=1

# Install MPI (required for TensorRT-LLM plugins)
echo "Installing MPI..."
dnf install -y mpich mpich-devel openmpi openmpi-devel

# Run distributed tests
pushd .
cd tests/py
cd dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml distributed/test_nccl_ops.py
cd tests/py/dynamo

echo "Running distributed tests with mpirun..."
mpirun --allow-run-as-root -n ${{ matrix.num_gpus }} \
python -m pytest -ra \
--junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml \
distributed/test_nccl_ops.py

popd

concurrency:
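To make the workflow expression concrete, here is a Python mirror of what the `fromJSON(...).distributed_include` / `toJSON(...)` chain does with the filter-matrix output; the matrix contents below are illustrative.

```python
import json

# Illustrative stand-in for needs.filter-matrix.outputs.matrix.
matrix_output = json.dumps({
    "include": [
        {"python_version": "3.11", "desired_cuda": "cu130"},
    ],
    "distributed_include": [
        {
            "python_version": "3.10",
            "desired_cuda": "cu130",
            "validation_runner": "linux.g4dn.12xlarge.nvidia.gpu",
            "num_gpus": 2,
            "config": "distributed",
        }
    ],
})

parsed = json.loads(matrix_output)           # fromJSON(...)
distributed = parsed["distributed_include"]  # .distributed_include
build_matrix = {"include": distributed}      # re-wrapped as the build-matrix input
print(json.dumps(build_matrix))              # toJSON(...) passed to the reusable workflow
```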
11 changes: 11 additions & 0 deletions .github/workflows/build_linux.yml
@@ -160,6 +160,17 @@ jobs:
options: ${{ matrix.gpu_arch_type == 'cuda' && '--gpus all' || ' ' }}
timeout-minutes: ${{ inputs.timeout }}
steps:
- name: Debug matrix configuration
shell: bash
run: |
echo "=========================================="
echo "BUILD MATRIX DEBUG"
echo "=========================================="
echo "Python version: ${{ matrix.python_version }}"
echo "CUDA version: ${{ matrix.desired_cuda }}"
echo "GPU arch type: ${{ matrix.gpu_arch_type }}"
echo "Runner: ${{ matrix.validation_runner }}"
echo "=========================================="
- name: Clean workspace
shell: bash -l {0}
run: |
67 changes: 39 additions & 28 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -14,32 +14,13 @@
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh

logger = logging.getLogger(__name__)

def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
for i in range(max_depth):
files = os.listdir(dir_path)
if "MODULE.bazel" in files:
return dir_path
else:
dir_path = os.path.dirname(dir_path)

raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
# this is kept at the application level, when mpirun is used to run the application
def initialize_distributed_env(rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
@@ -50,9 +31,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
@@ -66,16 +44,49 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
return device_mesh, world_size, rank


def cleanup_distributed_env():
"""Clean up distributed process group to prevent resource leaks."""
if dist.is_initialized():
dist.destroy_process_group()


def check_tensor_parallel_device_number(world_size: int) -> None:
if world_size % 2 != 0:
raise ValueError(
f"TP examples require even number of GPUs, but got {world_size} gpus"
)


def get_tensor_parallel_device_mesh(
rank: int = 0, world_size: int = 1
) -> tuple[DeviceMesh, int, int]:
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank


def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger
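
A hypothetical end-to-end driver showing how the refactored helpers are intended to compose. The helper names come from this file; the surrounding logic is illustrative.

```python
import torch.distributed as dist

from tensor_parallel_initialize_dist import (
    check_tensor_parallel_device_number,
    cleanup_distributed_env,
    get_tensor_parallel_device_mesh,
    initialize_distributed_env,
    initialize_distributed_logger,
)

# Set up rank/world-size env vars and bind a CUDA device when launched via mpirun.
if not dist.is_initialized():
    initialize_distributed_env()

device_mesh, world_size, rank = get_tensor_parallel_device_mesh()
check_tensor_parallel_device_number(world_size)  # TP examples expect an even GPU count
logger = initialize_distributed_logger(rank, "my_tp_example")  # writes my_tp_example_<rank>.log

try:
    logger.info("rank %d/%d ready on mesh %s", rank, world_size, device_mesh)
    # ... build, parallelize, and run the model here ...
finally:
    cleanup_distributed_env()  # destroy the process group to avoid resource leaks
```

The examples below follow the same pattern and, notably, call `initialize_distributed_env()` before importing `torch_tensorrt`, presumably so the distributed environment is in place before the library is loaded.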
21 changes: 13 additions & 8 deletions examples/distributed_inference/tensor_parallel_rotary_embedding.py
@@ -9,26 +9,31 @@

"""

import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
import torch.distributed as dist
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
get_tensor_parallel_device_mesh,
initialize_distributed_env,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_rotary_embedding"
)
if not dist.is_initialized():
initialize_distributed_env()

import torch_tensorrt

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding")

from rotary_embedding import RotaryAttention, parallel_rotary_block

"""
This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
Command to run with single GPU: USE_TRTLLM_PLUGINS=1 mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
Command to run with 2 GPUs: USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
"""

BATCH = 2
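A small wrapper that mirrors the documented launch commands, in case the example needs to be driven from Python; the script name and flags come from the docstring above, everything else is illustrative.

```python
import os
import subprocess

def run_rotary_example(num_gpus: int = 2) -> None:
    # Mirrors: USE_TRTLLM_PLUGINS=1 mpirun -n <N> --allow-run-as-root \
    #          python tensor_parallel_rotary_embedding.py
    env = dict(os.environ, USE_TRTLLM_PLUGINS="1")
    cmd = [
        "mpirun", "-n", str(num_gpus), "--allow-run-as-root",
        "python", "tensor_parallel_rotary_embedding.py",
    ]
    subprocess.run(cmd, check=True, env=env)

if __name__ == "__main__":
    run_rotary_example(2)
```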
19 changes: 14 additions & 5 deletions examples/distributed_inference/tensor_parallel_simple_example.py
@@ -16,7 +16,7 @@
-----
.. code-block:: bash

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
"""

import time
@@ -25,22 +25,31 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
get_tensor_parallel_device_mesh,
initialize_distributed_env,
initialize_distributed_logger,
)

if not dist.is_initialized():
initialize_distributed_env()
import torch_tensorrt
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_simple_example")


"""
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
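For context, a minimal sketch of the kind of tensor-parallel plan this example builds with the imports above; the toy module and dimensions are illustrative, not the example's actual model.

```python
import torch
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyMLP(nn.Module):
    def __init__(self, hidden: int = 64) -> None:
        super().__init__()
        self.in_proj = nn.Linear(hidden, 4 * hidden)
        self.out_proj = nn.Linear(4 * hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(torch.relu(self.in_proj(x)))

def shard_model(model: nn.Module, device_mesh) -> nn.Module:
    # Column-shard the first projection and row-shard the second so the
    # intermediate activation stays sharded across the mesh.
    return parallelize_module(
        module=model,
        device_mesh=device_mesh,
        parallelize_plan={
            "in_proj": ColwiseParallel(),
            "out_proj": RowwiseParallel(),
        },
    )

# Usage (under mpirun/torchrun, after get_tensor_parallel_device_mesh()):
#   tp_model = shard_model(ToyMLP().cuda(), device_mesh)
#   out = tp_model(torch.randn(2, 64, device="cuda"))
```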