Commit 211fe91

[TPU] Correctly profile peak memory usage & Upgrade PyTorch XLA (#9438)

1 parent 6aa6020

File tree

3 files changed: +11 -10 lines changed

Dockerfile.tpu

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240828"
+ARG NIGHTLY_DATE="20241017"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE

docs/source/getting_started/tpu-installation.rst

Lines changed: 2 additions & 2 deletions

@@ -56,8 +56,8 @@ First, install the dependencies:
    $ pip uninstall torch torch-xla -y
 
    $ # Install PyTorch and PyTorch XLA.
-   $ export DATE="20240828"
-   $ export TORCH_VERSION="2.5.0"
+   $ export DATE="20241017"
+   $ export TORCH_VERSION="2.6.0"
    $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
    $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
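The two wheel URLs in the docs above are composed from the `DATE` and `TORCH_VERSION` variables bumped in this commit. A quick Python sketch of the naming scheme (the `wheel_url` helper is illustrative, not part of the docs):

```python
# How the install docs build the nightly wheel URLs from the two
# environment variables bumped in this commit.
DATE = "20241017"
TORCH_VERSION = "2.6.0"
BASE = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"

def wheel_url(package: str) -> str:
    # Nightly wheels are versioned as <version>.dev<date> for CPython 3.10
    # on linux_x86_64, matching the pip install lines in the docs.
    return (f"{BASE}/{package}-{TORCH_VERSION}.dev{DATE}"
            "-cp310-cp310-linux_x86_64.whl")

print(wheel_url("torch"))
print(wheel_url("torch_xla"))
```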

vllm/worker/tpu_worker.py

Lines changed: 8 additions & 7 deletions

@@ -133,18 +133,19 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
 
-        dtype_btyes = get_dtype_size(self.cache_dtype)
-        block_size = self.cache_config.block_size
-        block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
-                            head_size * num_kv_heads)
-
-        # Calculate the TPU KV cache size based on profiling.
+        # Get the maximum amount of memory used by the model weights and
+        # intermediate activations.
         m = xm.get_memory_info(self.device)
         total_memory_size = m["bytes_limit"]
+        profiled = m["peak_bytes_used"]  # Weights + intermediate activations.
+
+        # Calculate the TPU KV cache size based on profiling.
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
-        profiled = m["bytes_used"]  # Weights + intermediate activations.
         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        dtype_btyes = get_dtype_size(self.cache_dtype)
+        block_size_bytes = (dtype_btyes * self.cache_config.block_size *
+                            num_layers * 2 * head_size * num_kv_heads)
         num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
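The KV cache sizing logic this hunk changes can be sketched standalone. All the values below (memory sizes, model shape, dtype width) are hypothetical stand-ins for what `xm.get_memory_info` and the model config would actually report; the arithmetic mirrors the new code path, where the profiled figure is the *peak* usage rather than the post-profiling residual:

```python
# Sketch of the KV cache block-count arithmetic from
# determine_num_available_blocks, with hypothetical inputs.

def num_kv_cache_blocks(total_memory_size: int,
                        peak_bytes_used: int,
                        gpu_memory_utilization: float,
                        dtype_bytes: int,
                        block_size: int,
                        num_layers: int,
                        head_size: int,
                        num_kv_heads: int) -> int:
    # Fraction of device memory the engine is allowed to use.
    usable_memory_size = int(total_memory_size * gpu_memory_utilization)
    # Subtract the profiled peak (weights + intermediate activations);
    # clamp at zero so an over-budget model yields no cache, not a
    # negative size.
    tpu_kv_cache_bytes = max(usable_memory_size - peak_bytes_used, 0)
    # Bytes per cache block: 2 tensors (key and value) per layer.
    block_size_bytes = (dtype_bytes * block_size * num_layers * 2 *
                        head_size * num_kv_heads)
    num_blocks = tpu_kv_cache_bytes // block_size_bytes
    return (num_blocks // 8) * 8  # Round down to a multiple of 8.

# Hypothetical 16 GiB TPU with a 4 GiB profiled peak, a 2-byte cache
# dtype, 16-token blocks, 32 layers, head dim 128, 8 KV heads.
print(num_kv_cache_blocks(16 * 1024**3, 4 * 1024**3, 0.9,
                          2, 16, 32, 128, 8))
```

Using `peak_bytes_used` instead of `bytes_used` matters here: activations allocated during the profiling run may already be freed by the time memory is measured, so the non-peak counter understates the headroom the model actually needs and could oversize the cache.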

0 commit comments