Skip to content

Commit bad7307

Browse files
committed
more changes
1 parent 1bee8a1 commit bad7307

File tree

10 files changed

+476
-477
lines changed

10 files changed

+476
-477
lines changed

debug_logging.sh

Lines changed: 0 additions & 8 deletions
This file was deleted.
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Intermediate Tensor Logging
2+
3+
This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution.
4+
5+
## Overview
6+
7+
The intermediate tensor logging feature enables you to:
8+
9+
- Log input and output tensors from a configured set of filters
10+
- Filter modules by name using regex patterns
11+
- Filter module fwd call index (e.g. dump 2nd call of forward pass on same module)
12+
- Filter tensors by device
13+
- Filter whole model fwd step id
14+
15+
16+
This is manily useful for debugging model accucacy gaps with 2 runs
17+
18+
19+
## Usage
20+
21+
### Enabling via Configuration File
22+
23+
The easiest way to enable intermediate logging is by providing a configuration file:
24+
25+
```bash
26+
python -m vllm.entrypoints.openai.api_server \
27+
--model <your_model> \
28+
--il-config-path /path/to/config.json
29+
```
30+
31+
### Configuration Options
32+
33+
The configuration file should be a JSON file with the following structure:
34+
35+
```json
36+
{
37+
"output_dir": "/tmp/vllm_intermediates",
38+
"module_call_match": ["layers\\.0\\.(?!.*rotary_emb).*", "rotary_emb:0", "embed_tokens", "model\\.norm"],
39+
"log_step_ids": [0, 1],
40+
"device_names": ["cuda:0"]
41+
}
42+
```
43+
44+
#### Configuration Parameters
45+
46+
| Parameter | Type | Description | Default |
47+
|-----------|------|-------------|---------|
48+
| `output_dir` | string | Directory where to save the intermediate tensors | `/tmp/vllm_intermediates` |
49+
| `module_call_match` | array | Regex patterns to filter module names, if limti to ith call only, add `:i` | `null` (log all modules) |
50+
| `log_step_ids` | array | List of step IDs to log | `[0]` |
51+
| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) |
52+
| `device_names` | array | List of device names to log | `[]` (log all devices) |
53+
54+
### Output Directory Structure
55+
56+
When you enable intermediate logging, the system creates a timestamped directory under your specified `output_dir`. This helps organize multiple logging sessions:
57+
58+
```
59+
/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/
60+
└── step_0
61+
├── model.embed_tokens
62+
│ ├── inputs_0_cuda_0.pt
63+
│ ├── inputs.json
64+
│ ├── outputs_cuda_0.pt
65+
│ └── outputs.json
66+
├── model.layers.0.input_layernorm
67+
│ ├── inputs_0_cuda_0.pt
68+
│ ├── inputs.json
69+
│ ├── outputs_cuda_0.pt
70+
│ └── outputs.json
71+
└── step_1/
72+
└── ...
73+
```
74+
75+
Each tensor is saved in two formats:
76+
1. `.json` files containing metadata and small tensor values
77+
2. `.pt` files containing the full PyTorch tensors (can be loaded with `torch.load()`)

examples/offline_inference/llm_engine_example.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ def create_test_prompts() -> list[tuple[str, SamplingParams]]:
1717
(
1818
"A robot may not injure a human being",
1919
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1),
20+
),
21+
(
22+
"To be or not to be,",
23+
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2),
24+
),
25+
(
26+
"What is the meaning of life?",
27+
SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1),
2028
)
2129
]
2230

vllm/config.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from functools import cached_property
1818
from importlib.util import find_spec
1919
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
20-
Protocol, TypeVar, Union, cast, get_args)
20+
Protocol, TypeVar, Union, cast, get_args, List, Set)
21+
from re import Pattern
2122

2223
import regex as re
2324
import torch
@@ -3952,6 +3953,119 @@ class KVEventsConfig:
39523953
"""
39533954

39543955

3956+
@config
3957+
@dataclass
3958+
class IntermediateLoggingConfig:
3959+
"""Configuration for intermediate tensor logging."""
3960+
3961+
output_dir: str = "/tmp/vllm_intermediates"
3962+
"""Directory where to save the intermediate tensors."""
3963+
3964+
module_call_match: Optional[List[str]] = None
3965+
"""Match modules by name regex and call index (
3966+
a module can be called multiple times in a step)
3967+
List of regex:call_idx, call_idx is -1 for default for all calls """
3968+
3969+
log_step_ids: List[int] = field(default_factory=lambda: [0])
3970+
"""List of step IDs to log (empty list means log all steps)."""
3971+
3972+
log_post_fwd_inputs: bool = False
3973+
"""Whether logging inputs after forwards for each module"""
3974+
3975+
max_tensor_size: Optional[int] = None
3976+
"""Maximum number of elements in tensors to log (None = no limit)."""
3977+
3978+
enabled: bool = True
3979+
"""Whether logging is enabled."""
3980+
device_names: List[str] = field(default_factory=list)
3981+
"""List of device names to log (empty list means log all devices)."""
3982+
3983+
_compiled_module_calls: dict[Pattern,int] = field(default_factory=dict, init=False)
3984+
"""Compiled regex patterns for module filtering."""
3985+
3986+
_module_call: dict[str, int] = field(default_factory=dict, init=False)
3987+
_step_id_set: Set[int] = field(default_factory=set, init=False)
3988+
"""Set of step IDs for faster lookup."""
3989+
_output_run_dir: str = "/tmp/vllm_intermediates"
3990+
"""Unique directory to save single run/serve logging result."""
3991+
3992+
def __post_init__(self):
3993+
"""Initialize derived fields after instance creation."""
3994+
self._compile_regex_patterns()
3995+
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
3996+
self._step_id_set = set(self.log_step_ids)
3997+
3998+
def _compile_regex_patterns(self):
3999+
"""Compile regex patterns for module name filtering."""
4000+
from vllm.logger import init_logger
4001+
logger = init_logger(__name__)
4002+
4003+
self._compiled_module_matches = []
4004+
4005+
if self.module_call_match is None:
4006+
logger.info("No module name regex patterns provided, will log all modules")
4007+
return
4008+
4009+
# Compile all patterns
4010+
for regex_pattern_call_idx in self.module_call_match:
4011+
try:
4012+
splits = regex_pattern_call_idx.split(":", 2)
4013+
regex_pattern = splits[0]
4014+
call_idx = -1
4015+
if len(splits) > 1:
4016+
call_idx = int(splits[1])
4017+
compiled_pattern: Pattern[str] = re.compile(regex_pattern)
4018+
self._compiled_module_calls[compiled_pattern] = call_idx
4019+
logger.info(f"Successfully compiled regex pattern: '{regex_pattern}'")
4020+
except re.error as e:
4021+
logger.error(f"Invalid regex pattern from '{regex_pattern_call_idx}': {e}")
4022+
raise ValueError(f"Invalid regex pattern '{regex_pattern_call_idx}': {e}")
4023+
except Exception as e:
4024+
logger.error(f"Failed to parse module_call_match")
4025+
4026+
4027+
logger.info(f"Compiled {len(self._compiled_module_calls)} regex patterns")
4028+
4029+
def to_dict(self) -> dict:
4030+
"""Convert the config to a dictionary for serialization."""
4031+
return {
4032+
"output_run_dir": self.output_run_dir,
4033+
"module_call_match": self.module_call_match,
4034+
"log_step_ids": self.log_step_ids,
4035+
"max_tensor_size": self.max_tensor_size,
4036+
"enabled": self.enabled,
4037+
"device_names": self.device_names
4038+
}
4039+
4040+
@classmethod
4041+
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
4042+
"""Parse the CLI value for the speculative config."""
4043+
return cls(**dict_value)
4044+
4045+
@property
4046+
def output_run_dir(self) -> str:
4047+
return self._output_run_dir
4048+
4049+
def compute_hash(self) -> str:
4050+
"""
4051+
WARNING: Whenever a new field is added to this config,
4052+
ensure that it is included in the factors list if
4053+
it affects the computation graph.
4054+
4055+
Provide a hash that uniquely identifies all the configs
4056+
that affect the structure of the computation
4057+
graph from input ids/embeddings to the final hidden states,
4058+
excluding anything before input ids/embeddings and after
4059+
the final hidden states.
4060+
"""
4061+
# Intermediate logging doesn't affect the computation graph
4062+
factors: list[Any] = []
4063+
hash_str = hashlib.md5(str(factors).encode(),
4064+
usedforsecurity=False).hexdigest()
4065+
return hash_str
4066+
4067+
4068+
39554069
class CompilationLevel:
39564070
# constants for the levels of the compilation process
39574071
NO_COMPILATION = 0
@@ -4409,6 +4523,10 @@ class VllmConfig:
44094523
"""The configurations for distributed KV cache transfer."""
44104524
kv_events_config: Optional[KVEventsConfig] = None
44114525
"""The configurations for event publishing."""
4526+
il_config: Optional[IntermediateLoggingConfig] = None
4527+
"""Configuration for intermediate tensor logging."""
4528+
il_config_path: Optional[str] = None
4529+
"""Path to a JSON file containing intermediate logging configuration."""
44124530
# some opaque config, only used to provide additional information
44134531
# for the hash computation, mainly used for testing, debugging or out of
44144532
# tree config registration.
@@ -4497,6 +4615,10 @@ def compute_hash(self) -> str:
44974615
vllm_factors.append(self.kv_transfer_config.compute_hash())
44984616
else:
44994617
vllm_factors.append("None")
4618+
if self.il_config:
4619+
vllm_factors.append(self.il_config.compute_hash())
4620+
else:
4621+
vllm_factors.append("None")
45004622
if self.additional_config:
45014623
if isinstance(additional_config := self.additional_config, dict):
45024624
additional_config_hash = hashlib.md5(
@@ -4705,6 +4827,18 @@ def __post_init__(self):
47054827
if self.kv_events_config is not None:
47064828
# Hybrid KV cache manager is not compatible with KV events.
47074829
self.scheduler_config.disable_hybrid_kv_cache_manager = True
4830+
# Load intermediate logging config from file if provided
4831+
if self.il_config_path is not None:
4832+
try:
4833+
logger.info(f"Loading intermediate logging config from {self.il_config_path}")
4834+
with open(self.il_config_path, 'r') as f:
4835+
il_config_dict = json.load(f)
4836+
self.il_config = IntermediateLoggingConfig.from_dict(il_config_dict)
4837+
logger.info(f"Successfully loaded intermediate logging config: {self.il_config.to_dict()}")
4838+
except Exception as e:
4839+
logger.error(f"Failed to load intermediate logging config from {self.il_config_path}: {e}")
4840+
raise ValueError(f"Failed to load intermediate logging config from {self.il_config_path}: {e}") from e
4841+
47084842

47094843
def update_sizes_for_sequence_parallelism(self,
47104844
possible_sizes: list) -> list:

vllm/engine/arg_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ class EngineArgs:
440440
ParallelConfig.enable_multimodal_encoder_data_parallel
441441

442442
async_scheduling: bool = SchedulerConfig.async_scheduling
443+
il_config_path: Optional[str] = None
443444

444445
def __post_init__(self):
445446
# support `EngineArgs(compilation_config={...})`
@@ -858,6 +859,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
858859
**vllm_kwargs["compilation_config"])
859860
vllm_group.add_argument("--additional-config",
860861
**vllm_kwargs["additional_config"])
862+
vllm_group.add_argument("--il-config-path",
863+
**vllm_kwargs["il_config_path"])
861864

862865
# Other arguments
863866
parser.add_argument('--use-v2-block-manager',
@@ -1276,7 +1279,6 @@ def create_engine_config(
12761279
otlp_traces_endpoint=self.otlp_traces_endpoint,
12771280
collect_detailed_traces=self.collect_detailed_traces,
12781281
)
1279-
12801282
config = VllmConfig(
12811283
model_config=model_config,
12821284
cache_config=cache_config,
@@ -1292,6 +1294,7 @@ def create_engine_config(
12921294
compilation_config=self.compilation_config,
12931295
kv_transfer_config=self.kv_transfer_config,
12941296
kv_events_config=self.kv_events_config,
1297+
il_config_path=self.il_config_path,
12951298
additional_config=self.additional_config,
12961299
)
12971300

vllm/v1/engine/core.py

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
4444
from vllm.v1.structured_output import StructuredOutputManager
4545
from vllm.version import __version__ as VLLM_VERSION
46+
from vllm.v1.worker.intermediates_logging import (
47+
intermediate_logging
48+
)
4649

4750
logger = init_logger(__name__)
4851

@@ -73,33 +76,7 @@ def __init__(self,
7376

7477
# Setup Model.
7578
self.model_executor = executor_class(vllm_config)
76-
77-
78-
# Hook the dumping logic
79-
from vllm.v1.worker.il_config import IntermediateLoggingConfig
80-
81-
# Create a configuration for intermediate logging
82-
logger.info("Setting up intermediate tensor logging")
83-
84-
# Define the regex patterns to match module names
85-
# These patterns will match any module with "layers.0" or "embed_tokens" in its name
86-
module_patterns = ["layers\\.0", "embed_tokens"]
87-
logger.info(f"Using module name regex patterns: {module_patterns}")
88-
89-
il_config = IntermediateLoggingConfig(
90-
output_dir="/tmp/vllm_intermediates", # Directory to save intermediates
91-
module_name_regex=module_patterns, # Log layer 0 and embedding modules
92-
log_step_ids=[0, 1, 2, 3, 4, 5], # Log steps 0-5
93-
max_tensor_size=1000000, # Limit to 1M elements
94-
enabled=True # Enable logging
95-
)
96-
97-
logger.info(f"Intermediate logging config: {il_config.to_dict()}")
98-
99-
# Register hooks for intermediate tensor logging
100-
logger.info("Calling register_intermediate_hooks via collective_rpc")
101-
self.collective_rpc("register_intermediate_hooks", args=(il_config,))
102-
logger.info("Finished setting up intermediate tensor logging")
79+
self.collective_rpc("register_intermediate_hooks", args=(vllm_config.il_config,))
10380

10481
if executor_fail_callback is not None:
10582
self.model_executor.register_failure_callback(
@@ -243,19 +220,9 @@ def abort_requests(self, request_ids: list[str]):
243220

244221
def execute_model(self, scheduler_output: SchedulerOutput):
245222
try:
246-
# Increment the step counter for intermediate logging
247-
try:
248-
from vllm.v1.worker.intermediates_logging import increment_step
249-
logger.info("Incrementing intermediate logging step counter before model execution")
250-
increment_step()
251-
except Exception as e:
252-
logger.warning(f"Failed to increment intermediate logging step counter: {e}")
253-
254223
# Execute the model
255-
result = self.model_executor.execute_model(scheduler_output)
256-
257-
logger.info(f"Model execution completed for step with {scheduler_output.total_num_scheduled_tokens} tokens")
258-
return result
224+
with intermediate_logging(self.vllm_config.il_config):
225+
return self.model_executor.execute_model(scheduler_output)
259226
except Exception as err:
260227
# We do not want to catch BaseException here since we're only
261228
# interested in dumping info when the exception is due to an

0 commit comments

Comments
 (0)