|
17 | 17 | from functools import cached_property
|
18 | 18 | from importlib.util import find_spec
|
19 | 19 | from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
|
20 |
| - Protocol, TypeVar, Union, cast, get_args) |
| 20 | + Protocol, TypeVar, Union, cast, get_args, List, Set) |
| 21 | +from re import Pattern |
21 | 22 |
|
22 | 23 | import regex as re
|
23 | 24 | import torch
|
@@ -3952,6 +3953,119 @@ class KVEventsConfig:
|
3952 | 3953 | """
|
3953 | 3954 |
|
3954 | 3955 |
|
@config
@dataclass
class IntermediateLoggingConfig:
    """Configuration for intermediate tensor logging.

    Derived state (compiled module filters, the step-id set and the
    per-run output directory) is computed in `__post_init__`.
    """

    output_dir: str = "/tmp/vllm_intermediates"
    """Directory where to save the intermediate tensors."""

    module_call_match: Optional[List[str]] = None
    """Module filters, each written as `regex` or `regex:call_idx`.

    `regex` is a module name pattern and `call_idx` selects one call of the
    module within a step (a module can be called multiple times in a step).
    `call_idx` defaults to -1, meaning all calls. The call index is the text
    after the last `:` when that text is an integer, so the regex itself may
    contain `:`. `None` means no filter is configured."""

    log_step_ids: List[int] = field(default_factory=lambda: [0])
    """List of step IDs to log (empty list means log all steps)."""

    log_post_fwd_inputs: bool = False
    """Whether to log the inputs of each module after its forward pass."""

    max_tensor_size: Optional[int] = None
    """Maximum number of elements in tensors to log (None = no limit)."""

    enabled: bool = True
    """Whether logging is enabled."""

    device_names: List[str] = field(default_factory=list)
    """List of device names to log (empty list means log all devices)."""

    _compiled_module_calls: dict[Pattern, int] = field(default_factory=dict,
                                                       init=False)
    """Compiled regex patterns for module filtering, mapped to call index."""

    # NOTE(review): nothing in this class reads or writes `_module_call`;
    # presumably it is bookkeeping for the logging code -- confirm.
    _module_call: dict[str, int] = field(default_factory=dict, init=False)
    _step_id_set: Set[int] = field(default_factory=set, init=False)
    """Set of step IDs for faster lookup."""
    _output_run_dir: str = "/tmp/vllm_intermediates"
    """Unique directory to save single run/serve logging result."""

    def __post_init__(self):
        """Initialize derived fields after instance creation.

        Raises:
            ValueError: If an entry of `module_call_match` is invalid.
        """
        self._compile_regex_patterns()
        # Every instance gets its own sub-directory so that separate runs do
        # not overwrite each other's output.
        self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
        self._step_id_set = set(self.log_step_ids)

    @staticmethod
    def _split_module_call_match(entry: str) -> tuple[str, int]:
        """Split a `regex[:call_idx]` entry into `(regex, call_idx)`.

        Only a trailing `:<integer>` is treated as a call index; any other
        colon belongs to the regex (e.g. a `(?:...)` group). Without a call
        index the result uses -1.
        """
        pattern, sep, suffix = entry.rpartition(":")
        if sep:
            try:
                return pattern, int(suffix)
            except ValueError:
                # The text after the last colon is not an integer, so the
                # colon is part of the regex itself.
                pass
        return entry, -1

    def _compile_regex_patterns(self):
        """Compile `module_call_match` into `_compiled_module_calls`.

        Raises:
            ValueError: If an entry is not a string or its regex does not
                compile.
        """
        from vllm.logger import init_logger
        logger = init_logger(__name__)

        # Start from an empty mapping so repeated calls do not accumulate
        # stale patterns.
        self._compiled_module_calls = {}

        if self.module_call_match is None:
            logger.info(
                "No module name regex patterns provided, will log all modules"
            )
            return

        for entry in self.module_call_match:
            if not isinstance(entry, str):
                logger.error("Failed to parse module_call_match entry %r",
                             entry)
                raise ValueError(
                    f"Invalid module_call_match entry {entry!r}: expected a "
                    "string of the form 'regex' or 'regex:call_idx'")

            regex_pattern, call_idx = self._split_module_call_match(entry)
            try:
                compiled_pattern = re.compile(regex_pattern)
            except re.error as e:
                logger.error("Invalid regex pattern from '%s': %s", entry, e)
                raise ValueError(
                    f"Invalid regex pattern '{entry}': {e}") from e

            self._compiled_module_calls[compiled_pattern] = call_idx
            logger.info("Successfully compiled regex pattern: '%s'",
                        regex_pattern)

        logger.info("Compiled %d regex patterns",
                    len(self._compiled_module_calls))

    def to_dict(self) -> dict:
        """Convert the config to a dictionary for serialization.

        The result holds every user-facing option plus `output_run_dir`, and
        can be passed back to `from_dict`.
        """
        return {
            "output_dir": self.output_dir,
            "output_run_dir": self.output_run_dir,
            "module_call_match": self.module_call_match,
            "log_step_ids": self.log_step_ids,
            "log_post_fwd_inputs": self.log_post_fwd_inputs,
            "max_tensor_size": self.max_tensor_size,
            "enabled": self.enabled,
            "device_names": self.device_names,
        }

    @classmethod
    def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
        """Build an intermediate logging config from a dictionary.

        Args:
            dict_value: Mapping of field names to values, e.g. parsed from a
                JSON file or produced by `to_dict`. It is not modified.

        Returns:
            The new config. If `dict_value` carries an `output_run_dir`
            entry, that run directory is kept instead of a fresh one.

        Raises:
            TypeError: If `dict_value` contains an unknown field name.
            ValueError: If `module_call_match` contains an invalid entry.
        """
        kwargs = dict(dict_value)
        # `output_run_dir` is derived state written by `to_dict`, not a
        # constructor argument, so it has to be handled separately.
        output_run_dir = kwargs.pop("output_run_dir", None)
        instance = cls(**kwargs)
        if output_run_dir is not None:
            instance._output_run_dir = output_run_dir
        return instance

    @property
    def output_run_dir(self) -> str:
        """Directory that holds the output of this particular run."""
        return self._output_run_dir

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # Intermediate logging doesn't affect the computation graph
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str
| 4066 | + |
| 4067 | + |
| 4068 | + |
3955 | 4069 | class CompilationLevel:
|
3956 | 4070 | # constants for the levels of the compilation process
|
3957 | 4071 | NO_COMPILATION = 0
|
@@ -4409,6 +4523,10 @@ class VllmConfig:
|
4409 | 4523 | """The configurations for distributed KV cache transfer."""
|
4410 | 4524 | kv_events_config: Optional[KVEventsConfig] = None
|
4411 | 4525 | """The configurations for event publishing."""
|
| 4526 | + il_config: Optional[IntermediateLoggingConfig] = None |
| 4527 | + """Configuration for intermediate tensor logging.""" |
| 4528 | + il_config_path: Optional[str] = None |
| 4529 | + """Path to a JSON file containing intermediate logging configuration.""" |
4412 | 4530 | # some opaque config, only used to provide additional information
|
4413 | 4531 | # for the hash computation, mainly used for testing, debugging or out of
|
4414 | 4532 | # tree config registration.
|
@@ -4497,6 +4615,10 @@ def compute_hash(self) -> str:
|
4497 | 4615 | vllm_factors.append(self.kv_transfer_config.compute_hash())
|
4498 | 4616 | else:
|
4499 | 4617 | vllm_factors.append("None")
|
| 4618 | + if self.il_config: |
| 4619 | + vllm_factors.append(self.il_config.compute_hash()) |
| 4620 | + else: |
| 4621 | + vllm_factors.append("None") |
4500 | 4622 | if self.additional_config:
|
4501 | 4623 | if isinstance(additional_config := self.additional_config, dict):
|
4502 | 4624 | additional_config_hash = hashlib.md5(
|
@@ -4705,6 +4827,18 @@ def __post_init__(self):
|
4705 | 4827 | if self.kv_events_config is not None:
|
4706 | 4828 | # Hybrid KV cache manager is not compatible with KV events.
|
4707 | 4829 | self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
| 4830 | + # Load intermediate logging config from file if provided |
| 4831 | + if self.il_config_path is not None: |
| 4832 | + try: |
| 4833 | + logger.info(f"Loading intermediate logging config from {self.il_config_path}") |
| 4834 | + with open(self.il_config_path, 'r') as f: |
| 4835 | + il_config_dict = json.load(f) |
| 4836 | + self.il_config = IntermediateLoggingConfig.from_dict(il_config_dict) |
| 4837 | + logger.info(f"Successfully loaded intermediate logging config: {self.il_config.to_dict()}") |
| 4838 | + except Exception as e: |
| 4839 | + logger.error(f"Failed to load intermediate logging config from {self.il_config_path}: {e}") |
| 4840 | + raise ValueError(f"Failed to load intermediate logging config from {self.il_config_path}: {e}") from e |
| 4841 | + |
4708 | 4842 |
|
4709 | 4843 | def update_sizes_for_sequence_parallelism(self,
|
4710 | 4844 | possible_sizes: list) -> list:
|
|
0 commit comments