[Installation]: TPU-Inference requires torch with cuda #921

@BabyChouSr

Your current environment

Collecting environment information...
==============================
        System Info
==============================
OS                           : Ubuntu 22.04.2 LTS (x86_64)
GCC version                  : (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version                : Could not collect
CMake version                : version 4.1.0
Libc version                 : glibc-2.35

==============================
       PyTorch Info
==============================
PyTorch version              : 2.8.0+cpu
Is debug build               : False
CUDA used to build PyTorch   : None
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.12 (main, Oct 14 2025, 21:25:31) [Clang 20.1.4 ] (64-bit runtime)
Python platform              : Linux-5.19.0-1022-gcp-x86_64-with-glibc2.35

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : False
CUDA runtime version         : No CUDA
CUDA_MODULE_LOADING set to   : N/A
GPU models and configuration : No CUDA
Nvidia driver version        : No CUDA
cuDNN version                : No CUDA
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   48 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          240
On-line CPU(s) list:             0-239
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 7B12
CPU family:                      23
Model:                           49
Thread(s) per core:              2
Core(s) per socket:              60
Socket(s):                       2
Stepping:                        0
BogoMIPS:                        4499.99
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       3.8 MiB (120 instances)
L1i cache:                       3.8 MiB (120 instances)
L2 cache:                        60 MiB (120 instances)
L3 cache:                        480 MiB (30 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-59,120-179
NUMA node1 CPU(s):               60-119,180-239
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

==============================
Versions of relevant libraries
==============================
[pip3] mypy==1.18.2
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.3.4
[pip3] pyzmq==27.1.0
[pip3] torch==2.8.0+cpu
[pip3] torchax==0.0.7
[pip3] torchvision==0.23.0
[pip3] transformers==4.55.4
[pip3] triton==3.5.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.11.1
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
  Could not collect

==============================
     Environment Variables
==============================
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1

christopherchou@t1v-n-92fc5012-w-0:~/vllm-test$ uv run python -c "import jax; jax.print_environment_info()"
jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.3.4
python: 3.12.12 (main, Oct 14 2025, 21:25:31) [Clang 20.1.4 ]
device info: TPU v4-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-92fc5012-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')

How are you installing TPU inference?

Note: everything works. I'm just curious whether we can slim down the torch dependency by switching from the CUDA build of torch to the CPU build.

uv pip install vllm-tpu

Why does the vllm-tpu wheel require the CUDA builds of torch and torchvision? As a preliminary test, I swapped torch and torchvision for their CPU versions, but importing vllm then fails with a torchvision error (traceback below). I would assume the CUDA builds are unnecessary on a TPU VM, but because of this error I can't drop them. The CPU builds would be preferable because they are smaller, which matters for startup time and for building the runtime environment.
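One way to check which builds actually ended up in the environment, without importing torchvision (which is what fails here), is to read the installed package metadata. The sketch below is illustrative only: the helper names `local_tag` and `installed_builds` are made up for this example, and it assumes the build variant is encoded as a PEP 440 local version label such as `+cpu`. A version with no label usually means the wheel did not come from the CPU wheel index, so a mix of labeled and unlabeled versions may point to mismatched torch and torchvision builds.

```python
from importlib import metadata


def local_tag(version: str) -> str:
    """Return the PEP 440 local version label ("cpu", "cu128", ...), or "" if absent."""
    _, _, tag = version.partition("+")
    return tag


def installed_builds(packages=("torch", "torchvision")) -> dict:
    """Map each package name to (version, local_tag), or (None, None) if not installed.

    Reads distribution metadata only, so it works even when importing the
    package itself raises an error.
    """
    result = {}
    for name in packages:
        try:
            version = metadata.version(name)
        except metadata.PackageNotFoundError:
            result[name] = (None, None)
        else:
            result[name] = (version, local_tag(version))
    return result


if __name__ == "__main__":
    for name, (version, tag) in installed_builds().items():
        print(f"{name}: version={version!r} local_tag={tag!r}")
```

If the two packages report different labels (for example `cpu` for torch and an empty label for torchvision), that would be worth ruling out before concluding that the CUDA builds are required.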

christopherchou@t1v-n-92fc5012-w-0:~/marin$ uv pip show torch torchvision
Name: torch
Version: 2.8.0+cpu
Location: /home/christopherchou/marin/.venv/lib/python3.12/site-packages
Requires: filelock, fsspec, jinja2, networkx, setuptools, sympy, typing-extensions
Required-by: compressed-tensors, marin, nixl, torchvision, xgrammar
---
Name: torchvision
Version: 0.23.0
Location: /home/christopherchou/marin/.venv/lib/python3.12/site-packages
Requires: numpy, pillow, torch
Required-by: tpu-inference
christopherchou@t1v-n-92fc5012-w-0:~/marin$ uv run --extra cpu python
Python 3.12.12 (main, Oct 14 2025, 21:25:31) [Clang 20.1.4 ] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from vllm import LLM
INFO 10-22 18:33:47 [__init__.py:224] Automatically detected platform tpu.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<frozen importlib._bootstrap>", line 1412, in _handle_fromlist
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/__init__.py", line 74, in __getattr__
    module = import_module(module_name, __package__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christopherchou/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 20, in <module>
    from vllm.config import (
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/config/__init__.py", line 5, in <module>
    from vllm.config.compilation import (
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/config/compilation.py", line 18, in <module>
    from vllm.platforms import current_platform
  File "<frozen importlib._bootstrap>", line 1412, in _handle_fromlist
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/platforms/__init__.py", line 254, in __getattr__
    _current_platform = resolve_obj_by_qualname(platform_cls_qualname)()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/utils/__init__.py", line 2504, in resolve_obj_by_qualname
    module = importlib.import_module(module_name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christopherchou/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/platforms/tpu.py", line 12, in <module>
    from vllm.sampling_params import SamplingParams, SamplingType
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/sampling_params.py", line 16, in <module>
    from vllm.logits_process import LogitsProcessor
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/logits_process.py", line 8, in <module>
    from vllm.transformers_utils.tokenizer import AnyTokenizer
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/transformers_utils/tokenizer.py", line 18, in <module>
    from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/vllm/transformers_utils/config.py", line 28, in <module>
    from transformers.models.auto.image_processing_auto import get_image_processor_config
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/transformers/models/auto/image_processing_auto.py", line 27, in <module>
    from ...image_processing_utils import ImageProcessingMixin
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/transformers/image_processing_utils.py", line 22, in <module>
    from .image_transforms import center_crop, normalize, rescale
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/transformers/image_transforms.py", line 22, in <module>
    from .image_utils import (
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/transformers/image_utils.py", line 59, in <module>
    from torchvision.transforms import InterpolationMode
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/torchvision/__init__.py", line 10, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/torchvision/_meta_registrations.py", line 163, in <module>
    @torch.library.register_fake("torchvision::nms")
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/torch/library.py", line 1069, in register
    use_lib._register_fake(
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/torch/library.py", line 219, in _register_fake
    handle = entry.fake_impl.register(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/christopherchou/marin/.venv/lib/python3.12/site-packages/torch/_library/fake_impl.py", line 50, in register
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: operator torchvision::nms does not exist
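
The last frames of the traceback are inside `torchvision/__init__.py`, so the failure can probably be reproduced without vLLM or transformers in the picture. A minimal sketch for isolating it follows; `try_import` is a hypothetical helper written for this example, not part of any of the libraries involved.

```python
import importlib


def try_import(module_name: str):
    """Import a module by name and return (ok, message) instead of raising."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # broad on purpose: any import-time failure is of interest
        return False, f"{type(exc).__name__}: {exc}"
    return True, "ok"


if __name__ == "__main__":
    # Import torch first, then torchvision, so the failing package is obvious.
    for name in ("torch", "torchvision"):
        ok, message = try_import(name)
        print(f"{name}: {'OK' if ok else 'FAILED'} ({message})")
```

If `torchvision` alone fails with the same `RuntimeError`, the problem lies in the torch/torchvision pairing in the environment rather than in vLLM's TPU code path.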
