"""
Configuration for intermediate tensor logging.

This module defines the configuration data class for intermediate tensor logging,
which controls how intermediate tensors are captured and saved during model execution.
"""

import dataclasses
import re
from pathlib import Path
from typing import Optional, Pattern, List, Set, Union


@dataclasses.dataclass
class IntermediateLoggingConfig:
    """Configuration for intermediate tensor logging."""

    # Directory in which to save the intermediate tensors
    output_dir: str = "/tmp/vllm_intermediates"

    # Regex patterns to filter modules by name (None or empty list means log all modules).
    # Can be a single string or a list of strings.
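    # Illustrative patterns only (actual module names depend on the model being run),
    # e.g. module_name_regex=r"layers\.\d+\.mlp" or [r"embed_tokens", r"lm_head"].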
    module_name_regex: Optional[Union[str, List[str]]] = None

    # List of step IDs to log (empty list means log all steps)
    log_step_ids: List[int] = dataclasses.field(default_factory=lambda: [0, 1])

    # Maximum number of elements in tensors to log (None = no limit)
    max_tensor_size: Optional[int] = None

    # Whether logging is enabled
    enabled: bool = True

    # Current step counter (incremented after each forward pass)
    current_step: int = 0

    # List of device names to log (empty list means log all devices)
    device_names: List[str] = dataclasses.field(default_factory=list)

    # Compiled regex patterns for module filtering
    _module_name_patterns: List[Pattern] = dataclasses.field(default_factory=list)

    # Set of step IDs for faster lookup
    _step_id_set: Set[int] = dataclasses.field(default_factory=set)

    def __post_init__(self):
        """Initialize derived fields after instance creation."""
        self._compile_regex_patterns()
        self._step_id_set = set(self.log_step_ids)
        Path(self.output_dir).mkdir(exist_ok=True, parents=True)

    def _compile_regex_patterns(self):
        """Compile regex patterns for module name filtering."""
        from vllm.logger import init_logger
        logger = init_logger(__name__)

        self._module_name_patterns = []

        if self.module_name_regex is None:
            logger.info("No module name regex patterns provided, will log all modules")
            return

        # Convert single string to list for uniform handling
        patterns = self.module_name_regex
        if isinstance(patterns, str):
            patterns = [patterns]
            logger.info(f"Converting single regex pattern to list: [{patterns[0]}]")
        else:
            logger.info(f"Using list of regex patterns: {patterns}")

        # Compile all patterns
        for pattern in patterns:
            try:
                compiled_pattern = re.compile(pattern)
                self._module_name_patterns.append(compiled_pattern)
                logger.info(f"Successfully compiled regex pattern: '{pattern}'")
            except re.error as e:
                logger.error(f"Invalid regex pattern '{pattern}': {e}")
                raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from e

        logger.info(f"Compiled {len(self._module_name_patterns)} regex patterns")

    def should_log_step(self) -> bool:
        """Check if the current step should be logged based on the step IDs."""
        if not self.enabled:
            return False

        # If log_step_ids is empty, log all steps
        if not self.log_step_ids:
            return True

        # Otherwise, check if current step is in the set of step IDs to log
        return self.current_step in self._step_id_set

    def should_log_device(self, device_name: str) -> bool:
        """Check if a device should be logged based on the device names.

        Args:
            device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').

        Returns:
            True if the device should be logged, False otherwise.
            If device_names is empty, all devices are logged.
        """
        # If device_names is empty, log all devices
        if not self.device_names:
            return True

        # Otherwise, check if device_name is in the list of device names to log
        return device_name in self.device_names

    def should_log_module(self, module_name: str) -> bool:
        """Check if a module should be logged based on the name regex patterns.

        Args:
            module_name: The name of the module to check.

        Returns:
            True if the module should be logged, False otherwise.
            If no patterns are defined, all modules are logged.
            If patterns are defined, the module is logged if it matches ANY pattern.
        """
        from vllm.logger import init_logger
        logger = init_logger(__name__)

        # If no patterns are defined, log all modules
        if not self._module_name_patterns:
            logger.debug(f"No patterns defined, will log module: {module_name}")
            return True

        # Check if the module name matches any of the patterns
        for i, pattern in enumerate(self._module_name_patterns):
            match = pattern.search(module_name)
            if match:
                logger.info(f"Module {module_name} matches pattern {i}: '{pattern.pattern}'")
                return True

        # For debugging, log at a higher level when we're checking layer modules
        if "layer" in module_name or "embed" in module_name:
            logger.info(
                f"Module {module_name} doesn't match any patterns: "
                f"{[p.pattern for p in self._module_name_patterns]}")
        else:
            logger.debug(f"Module {module_name} doesn't match any patterns")
        return False

    def increment_step(self) -> None:
        """Increment the current step counter."""
        self.current_step += 1

    def reset_step(self) -> None:
        """Reset the current step counter to zero."""
        self.current_step = 0

    def to_dict(self) -> dict:
        """Convert the config to a dictionary for serialization."""
        return {
            "output_dir": self.output_dir,
            "module_name_regex": self.module_name_regex,
            "log_step_ids": self.log_step_ids,
            "max_tensor_size": self.max_tensor_size,
            "enabled": self.enabled,
            "current_step": self.current_step,
            "device_names": self.device_names
        }

    @classmethod
    def from_dict(cls, config_dict: dict) -> "IntermediateLoggingConfig":
        """Create a config instance from a dictionary.

        Args:
            config_dict: Dictionary containing configuration parameters.

        Returns:
            An IntermediateLoggingConfig instance.
        """
        # Filter out unknown parameters
        known_params = {"output_dir", "module_name_regex", "log_step_ids",
                        "max_tensor_size", "enabled", "current_step", "device_names"}
        filtered_dict = {k: v for k, v in config_dict.items() if k in known_params}

        # Handle backward compatibility with log_step_interval
        if "log_step_interval" in config_dict and "log_step_ids" not in filtered_dict:
            interval = config_dict["log_step_interval"]
            if interval > 0:
                # Convert the interval to an explicit list of the first 10 logged step IDs
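                # (e.g. a legacy log_step_interval of 5 maps to log_step_ids=[0, 5, 10, ..., 45])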
                filtered_dict["log_step_ids"] = list(range(0, 10 * interval, interval))

        return cls(**filtered_dict)
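

# Example usage (an illustrative sketch only; the surrounding vLLM wiring is
# assumed, and the module names below are hypothetical):
#
#     config = IntermediateLoggingConfig(
#         output_dir="/tmp/vllm_intermediates",
#         module_name_regex=[r"layers\.\d+\.mlp", r"embed_tokens"],
#         log_step_ids=[0, 1, 2],
#     )
#     for name in ("model.embed_tokens", "model.layers.0.mlp", "lm_head"):
#         if config.should_log_step() and config.should_log_module(name):
#             ...  # capture and save the module's output tensor here
#     config.increment_step()  # advance the step counter after the forward pass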