Commit c166e7e

[Bugfix] Allow ScalarType to be compiled with pytorch 2.3 and add checks for registering FakeScalarType and dynamo support. (#7886)
1 parent bc6e42a commit c166e7e

File tree: 4 files changed, 84 additions & 67 deletions


csrc/core/scalar_type.hpp

Lines changed: 2 additions & 1 deletion
@@ -387,7 +387,8 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   // This needs to be implemented and throw a TypeError in order for
   // PyTorch's opcheck to work on ops that use ScalarTypes.
   int64_t len() const {
-    throw c10::TypeError("__len__ not implemented");
+    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+                         "__len__ not implemented");
     return 0;
   }

vllm/_core_ext.py

Lines changed: 70 additions & 64 deletions
@@ -181,92 +181,98 @@ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,

     ScalarType = torch.classes._core_C.ScalarType

-    # Needed for dynamo support of ScalarType.
-    @torch._library.register_fake_class("_core_C::ScalarType")
-    class FakeScalarType:
+    if (hasattr(torch, "_library")
+            and hasattr(torch._library, "register_fake_class")):
+        # Needed for dynamo support of ScalarType.
+        @torch._library.register_fake_class("_core_C::ScalarType")
+        class FakeScalarType:

-        def __init__(self, scalar_type):
-            self.ScalarType = scalar_type
+            def __init__(self, scalar_type):
+                self.ScalarType = scalar_type

-        def bias_getter(self) -> int:
-            return self.ScalarType.bias
+            def bias_getter(self) -> int:
+                return self.ScalarType.bias

-        def exponent_getter(self) -> int:
-            return self.ScalarType.exponent
+            def exponent_getter(self) -> int:
+                return self.ScalarType.exponent

-        def mantissa_getter(self) -> int:
-            return self.ScalarType.mantissa
+            def mantissa_getter(self) -> int:
+                return self.ScalarType.mantissa

-        def signed_getter(self) -> bool:
-            return self.ScalarType.signed
+            def signed_getter(self) -> bool:
+                return self.ScalarType.signed

-        def size_bits_getter(self) -> int:
-            return self.ScalarType.size_bits
+            def size_bits_getter(self) -> int:
+                return self.ScalarType.size_bits

-        @property
-        def size_bits(self) -> int:
-            return self.ScalarType.size_bits
+            @property
+            def size_bits(self) -> int:
+                return self.ScalarType.size_bits

-        def min(self) -> Union[int, float]:
-            return self.ScalarType.min()
+            def min(self) -> Union[int, float]:
+                return self.ScalarType.min()

-        def max(self) -> Union[int, float]:
-            return self.ScalarType.max()
+            def max(self) -> Union[int, float]:
+                return self.ScalarType.max()

-        def is_signed(self) -> bool:
-            return self.ScalarType.is_signed()
+            def is_signed(self) -> bool:
+                return self.ScalarType.is_signed()

-        def is_floating_point(self) -> bool:
-            return self.ScalarType.is_floating_point()
+            def is_floating_point(self) -> bool:
+                return self.ScalarType.is_floating_point()

-        def is_integer(self) -> bool:
-            return self.ScalarType.is_integer()
+            def is_integer(self) -> bool:
+                return self.ScalarType.is_integer()

-        def has_bias(self) -> bool:
-            return self.ScalarType.has_bias()
+            def has_bias(self) -> bool:
+                return self.ScalarType.has_bias()

-        def has_infs(self) -> bool:
-            return self.ScalarType.has_infs()
+            def has_infs(self) -> bool:
+                return self.ScalarType.has_infs()

-        def has_nans(self) -> bool:
-            return self.ScalarType.has_nans()
+            def has_nans(self) -> bool:
+                return self.ScalarType.has_nans()

-        def is_ieee_754(self) -> bool:
-            return self.ScalarType.is_ieee_754()
+            def is_ieee_754(self) -> bool:
+                return self.ScalarType.is_ieee_754()

-        def __str__(self) -> str:
-            return self.ScalarType.__str__()
+            def __str__(self) -> str:
+                return self.ScalarType.__str__()

-        def __repr__(self) -> str:
-            return self.ScalarType.__repr__()
+            def __repr__(self) -> str:
+                return self.ScalarType.__repr__()

-        def __len__(self) -> int:
-            return self.ScalarType.__len__()
+            def __len__(self) -> int:
+                return self.ScalarType.__len__()

-        def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
-            return torch.classes._core_C.ScalarType.__obj_flatten__(
-                self.ScalarType)
+            def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
+                return torch.classes._core_C.ScalarType.__obj_flatten__(
+                    self.ScalarType)

-        @classmethod
-        def __obj_unflatten__(
-                cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
-            return cls(
-                torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
+            @classmethod
+            def __obj_unflatten__(
+                    cls, flat_type: Tuple[Tuple[str, Any],
+                                          ...]) -> 'ScalarType':
+                return cls(
+                    torch.classes._core_C.ScalarType.__obj_unflatten__(
+                        flat_type))

-        @classmethod
-        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            return ScalarType.int_(size_bits, bias)
+            @classmethod
+            def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+                return ScalarType.int_(size_bits, bias)

-        @classmethod
-        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            return ScalarType.uint(size_bits, bias)
+            @classmethod
+            def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+                return ScalarType.uint(size_bits, bias)

-        @classmethod
-        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
-            return ScalarType.float_IEEE754(exponent, mantissa)
+            @classmethod
+            def float_IEEE754(cls, exponent: int,
+                              mantissa: int) -> 'ScalarType':
+                return ScalarType.float_IEEE754(exponent, mantissa)

-        @classmethod
-        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
-                   nan_repr: int) -> 'ScalarType':
-            return ScalarType.float_(exponent, mantissa, finite_values_only,
-                                     nan_repr)
+            @classmethod
+            def float_(cls, exponent: int, mantissa: int,
+                       finite_values_only: bool,
+                       nan_repr: int) -> 'ScalarType':
+                return ScalarType.float_(exponent, mantissa,
+                                         finite_values_only, nan_repr)
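
Note on the guarded registration above: `torch._library.register_fake_class` only exists in newer PyTorch builds (roughly 2.4 and later), so the `hasattr` checks let this module import cleanly on PyTorch 2.3 instead of raising an AttributeError; on older builds the fake class is simply not registered and dynamo support is gated off elsewhere. A minimal, self-contained sketch of the same feature-detection pattern follows — the helper name `try_register_fake_class` is illustrative and not part of vLLM's API:

import torch

def try_register_fake_class(qualname: str, fake_cls) -> bool:
    # Hypothetical helper (not in vLLM): register `fake_cls` as the fake
    # implementation of the torchbind class `qualname` only when the running
    # PyTorch build exposes the registration API, mirroring the guard above.
    if (hasattr(torch, "_library")
            and hasattr(torch._library, "register_fake_class")):
        torch._library.register_fake_class(qualname)(fake_cls)
        return True
    # Older PyTorch (e.g. 2.3): skip silently; module import still succeeds.
    return False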

vllm/utils.py

Lines changed: 9 additions & 0 deletions
@@ -25,6 +25,7 @@
 import psutil
 import torch
 import torch.types
+from packaging.version import Version
 from typing_extensions import ParamSpec, TypeIs, assert_never

 import vllm.envs as envs
@@ -1114,3 +1115,11 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
+
+
+# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
+# In particular, the FakeScalarType is not supported for earlier versions of
+# PyTorch which breaks dynamo for any ops registered using ScalarType.
+def supports_dynamo() -> bool:
+    base_torch_version = Version(Version(torch.__version__).base_version)
+    return base_torch_version >= Version("2.4.0")
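
For reference, `Version(torch.__version__).base_version` strips pre-release and local suffixes (such as `.devYYYYMMDD` or `+cu121`) before the comparison, so the check keys purely off the numeric release. A small illustrative sketch of the same logic, parameterized on the version string so it can be run without importing torch; the example version strings below are made up:

from packaging.version import Version

def _supports_dynamo(torch_version: str) -> bool:
    # Same comparison as supports_dynamo(), but taking the version string as
    # an argument purely for illustration.
    base_torch_version = Version(Version(torch_version).base_version)
    return base_torch_version >= Version("2.4.0")

assert _supports_dynamo("2.4.0")                        # plain release
assert _supports_dynamo("2.4.0.dev20240601+cu121")      # base version is 2.4.0
assert not _supports_dynamo("2.3.1+cu118")              # base version is 2.3.1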

vllm/worker/model_runner.py

Lines changed: 3 additions & 2 deletions
@@ -44,7 +44,8 @@
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
-                        flatten_2d_lists, is_hip, is_pin_memory_available)
+                        flatten_2d_lists, is_hip, is_pin_memory_available,
+                        supports_dynamo)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -946,7 +947,7 @@ def load_model(self) -> None:
                     "provided. Defaulting to scaling factors of 1.0. "
                     "This may lead to less accurate results!")

-        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
             self.model = torch.compile(self.model,
                                        fullgraph=True,
                                        backend="eager")
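
Taken together, the graph-capture path is now double-gated: the `VLLM_TEST_DYNAMO_GRAPH_CAPTURE` environment flag must be set and the installed PyTorch must be new enough. A condensed sketch of that control flow; `maybe_compile` is a hypothetical wrapper, not a function in the vLLM codebase:

import torch

import vllm.envs as envs
from vllm.utils import supports_dynamo

def maybe_compile(model: torch.nn.Module) -> torch.nn.Module:
    # Only hand the model to torch.compile when the test flag is set AND the
    # installed PyTorch is >= 2.4.0; earlier releases lack FakeScalarType,
    # which dynamo needs for ops that take a ScalarType argument.
    if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
        return torch.compile(model, fullgraph=True, backend="eager")
    return model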

0 commit comments
