vllm: arch-aware LayerAccessor + HF↔vLLM equivalence tests #47
Replaces the misleading `_is_vllm_layers_output` predicate with explicit `InputLayout` / `OutputLayout` enums detected per accessor. Input layout comes from forward-signature inspection; output layout from a runtime probe of layer 0; source inspection of the parent decoder layer disambiguates positional vs kwargs sub-module calls (vLLM Llama calls `self.self_attn(positions=..., hidden_states=...)` but `self.mlp(hidden_states)` positionally).

Fixes & changes:

- `layers_input[i]` for vLLM Llama now returns hidden_states (was returning the int64 positions tensor, the first positional arg of the decoder forward).
- `attentions_input[i]` for vLLM Llama: same fix via kwargs path.
- Layer N>0 residual handling: `_read_input` returns `hidden_states + residual` to recover the combined stream; layer 0 returns hidden_states alone (residual is None there). Numerical sanity: `layers_output[i] == layers_input[i+1]` exactly.
- Dual-stream shape-mismatch check moved into `_infer_output_layout` so it raises `RenamingError` early with a clear message instead of silently broadcasting at the use site.
- Setter for vLLM dual-stream output uses in-place index assignment per nnsight VLLM_GUIDE (whole-tuple replacement crashes the engine).
- vLLM-only `.clone()` on single-stream reads in `_read_input` / `_read_output` and on `token_embeddings`: vLLM reuses inference-mode buffers across layers (layer N+1's fused RMSNorm mutates layer N's output buffer in-place); the saved reference would surface the post-mutation value. nnsight has clone-on-save for inference-mode tensors but it doesn't catch every path exercised here.

Tests:

- `tests/test_layout_detection.py`: CPU-only unit tests for the layout inference helpers (12 cases incl. shape-mismatch raise).
- `tests/test_vllm_hf_equivalence.py`: load SmolLM2 under both backends, assert each LayerAccessor returns the same hidden states (within bf16 tolerance), plus an HF setter smoke test.
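To make the signature-inspection idea concrete, here is a minimal sketch of how an input layout could be inferred from a forward signature. It is not the PR's implementation: the enum member names, `infer_input_layout`, and the two toy forward functions are hypothetical, and the toy signatures only mimic the parameter orders described above.

```python
import inspect
from enum import Enum, auto


class InputLayout(Enum):
    # Hypothetical member names; the PR's actual enum members are not shown.
    HIDDEN_FIRST = auto()           # forward(hidden_states, ...)
    POSITIONS_THEN_HIDDEN = auto()  # forward(positions, hidden_states, ...)


def infer_input_layout(forward):
    """Return (layout, positional index of `hidden_states`) for a forward callable."""
    # Drop `self` so bound and unbound methods give the same answer.
    names = [n for n in inspect.signature(forward).parameters if n != "self"]
    if "hidden_states" not in names:
        raise ValueError(f"no 'hidden_states' parameter in {names}")
    index = names.index("hidden_states")
    if index == 0:
        return InputLayout.HIDDEN_FIRST, index
    return InputLayout.POSITIONS_THEN_HIDDEN, index


# Toy stand-ins for the two conventions discussed in this PR.
def hf_style_forward(self, hidden_states, attention_mask=None):
    return hidden_states


def vllm_llama_style_forward(self, positions, hidden_states, residual):
    return hidden_states, residual


hf_layout = infer_input_layout(hf_style_forward)
vllm_layout = infer_input_layout(vllm_llama_style_forward)
```

Returning the index alongside the layout lets a reader pick `args[index]` instead of assuming the first positional argument, which is the mistake that surfaced the positions tensor.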
Bumps `nnsight` to >=0.7 (pinned via `[tool.uv] exclude-newer-package`).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
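The positional-vs-kwargs disambiguation mentioned in the description can be illustrated with the standard `ast` module. This is a sketch under assumptions: `calls_child_with_kwargs` and the toy `PARENT_SRC` are hypothetical, and the PR's own `_parent_calls_with_kwargs` helper may inspect the source differently.

```python
import ast
import textwrap


def calls_child_with_kwargs(parent_forward_src, child_name, kwarg="hidden_states"):
    """True if the parent's source calls `self.<child_name>(...)` passing `kwarg` by keyword."""
    tree = ast.parse(textwrap.dedent(parent_forward_src))
    for node in ast.walk(tree):
        is_self_child_call = (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == child_name
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "self"
        )
        if is_self_child_call:
            return any(kw.arg == kwarg for kw in node.keywords)
    return False  # child never called: treat as positional


# Toy parent forward mirroring the call styles quoted in the description.
PARENT_SRC = '''
def forward(self, positions, hidden_states, residual):
    hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
    hidden_states = self.mlp(hidden_states)
    return hidden_states, residual
'''

attn_uses_kwargs = calls_child_with_kwargs(PARENT_SRC, "self_attn")
mlp_uses_kwargs = calls_child_with_kwargs(PARENT_SRC, "mlp")
```

For a live module, the source string could come from `inspect.getsource(type(layer).forward)`; that only works when the source file is available, so a fallback to the positional path seems prudent.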
```
return self.embed_tokens.output
"""Returns the token embeddings. Equivalent to self.embed_tokens.output.

Clones for vLLM: the embed_tokens output buffer is reused across
```
@JadenFiotto-Kaufman I'm pretty convinced this is real. Is this a bug that should be fixed upstream in nnsight instead?
Repro using pure nnsight (no nnterp): saves the same proxy twice, once as a bare reference and once with `.clone()`.

```python
# scratch/repro_clone_needed.py
import gc

import torch as th
from nnsight.modeling.vllm import VLLM

MODEL = "HuggingFaceTB/SmolLM2-135M-Instruct"  # Llama-arch
PROMPT = "Hello world"


def main():
    model = VLLM(MODEL, tensor_parallel_size=1, gpu_memory_utilization=0.3,
                 dispatch=True, dtype="bfloat16")
    try:
        # Save each activation twice: bare reference vs. explicit clone.
        with model.trace(PROMPT):
            embed_ref = model.model.embed_tokens.output.save()
            embed_clone = model.model.embed_tokens.output.clone().save()
            l0_args, _ = model.model.layers[0].inputs
            l0_hs_ref = l0_args[1].save()
            l0_hs_clone = l0_args[1].clone().save()
            attn_ref = model.model.layers[0].self_attn.output.save()
            attn_clone = model.model.layers[0].self_attn.output.clone().save()
            mlp_ref = model.model.layers[0].mlp.output.save()
            mlp_clone = model.model.layers[0].mlp.output.clone().save()

        def diff(name, a, b):
            # A nonzero diff means the bare reference changed after it was saved.
            d = (a.float() - b.float()).abs().max().item()
            tag = "OK" if d < 1e-3 else "MUTATED"
            print(f" {name:30s} max|ref - clone| = {d:.4f} [{tag}]")
            print(f" ref stats: min={a.min().item():.4f} max={a.max().item():.4f} std={a.std().item():.4f}")
            print(f" clone stats: min={b.min().item():.4f} max={b.max().item():.4f} std={b.std().item():.4f}")

        diff("embed_tokens.output", embed_ref, embed_clone)
        diff("layers[0].inputs args[1]", l0_hs_ref, l0_hs_clone)
        diff("layers[0].self_attn.output", attn_ref, attn_clone)
        diff("layers[0].mlp.output", mlp_ref, mlp_clone)
    finally:
        # Shut down the engine and release GPU memory.
        if getattr(model, "vllm_entrypoint", None) is not None:
            model.vllm_entrypoint.llm_engine.engine_core.shutdown()
        VLLM._cleanup_distributed()
        del model
        gc.collect()


if __name__ == "__main__":
    main()
```

Output (single L40, torch 2.9.0, vllm 0.15.1, nnsight 0.7.0):

The … For … The VLLM_GUIDE notes that …

🤖 Generated with Claude Code
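The hazard the repro probes (a saved reference aliasing a buffer that a later step overwrites in place) can be shown without a GPU. The following is a pure-Python analogy only: `array.array` stands in for a tensor buffer and `fused_norm_inplace` is a hypothetical stand-in for an in-place kernel, not vLLM's fused RMSNorm.

```python
from array import array


def fused_norm_inplace(buf):
    """Hypothetical in-place step: rescale `buf` by its largest magnitude."""
    scale = max(abs(x) for x in buf) or 1.0
    for i, x in enumerate(buf):
        buf[i] = x / scale


layer_out = array("f", [2.0, -4.0, 1.0])  # buffer written by "layer N"
saved_ref = layer_out                      # bare reference: aliases the buffer
saved_clone = array("f", layer_out)        # independent copy, like .clone()

fused_norm_inplace(layer_out)              # "layer N+1" reuses the buffer

# The bare reference now shows post-mutation values; the clone does not.
ref_vs_clone = max(abs(a - b) for a, b in zip(saved_ref, saved_clone))
```

The same comparison is what the repro's `diff` helper reports as `max|ref - clone|`: a nonzero value indicates that the bare reference no longer holds what was there at save time.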
Summary
- Refactors `LayerAccessor` to be architecture-generic for vLLM (not just Llama). Replaces the misleading `_is_vllm_layers_output` predicate with explicit `InputLayout` / `OutputLayout` enums detected from forward signature + runtime structure + parent-source inspection.
- Fixes `layers_input[i]` and `attentions_input[i]` on vLLM Llama: they were returning the int64 `positions` tensor (first positional arg of the decoder/attention forward) instead of `hidden_states`. Now correctly routed via `module.inputs[0][1]` (positional) or `kwargs["hidden_states"]` (vLLM Llama calls `self.self_attn(positions=..., hidden_states=...)` with kwargs).
- Layer N>0 residual handling: `_read_input` returns `hidden_states + residual` to recover the combined stream. Numerical sanity: `layers_output[i] == layers_input[i+1]` exactly on vLLM Llama.
- Dual-stream shape-mismatch check moved into `_infer_output_layout`: raises `RenamingError` early instead of being dead code in `check_io` (the user pointed out that `layers_output[i]` already crashes on shape mismatch before the `check_io` block runs).
- vLLM-only `.clone()` on single-stream reads in `_read_input` / `_read_output` and on `token_embeddings`. vLLM reuses inference-mode buffers across layers (layer N+1's fused RMSNorm mutates layer N's output buffer in-place); the saved reference surfaces the post-mutation value otherwise. nnsight's clone-on-save for inference-mode tensors doesn't reach every path here.

Test plan

- `tests/test_layout_detection.py`: CPU-only unit tests for `_infer_input_layout` / `_infer_output_layout` / `_parent_calls_with_kwargs` (12 cases). All pass.
- `tests/test_vllm_hf_equivalence.py`: loads SmolLM2-135M under HF and vLLM, verifies all `LayerAccessor` accessors match within bf16 tolerance. All pass.
- `scratch/llama_residual_check.py` (manual): verified `layers_output[i] == layers_input[i+1]` to 0.0 max diff for i ∈ {1, 2, 5, 15} on vLLM Llama.
- `test_vllm.py`: 9/10 still pass (the one failure, `test_vllm_logits`, is a pre-existing silent-trace-fail pattern unrelated to this refactor).

Notes

- Bumps to `nnsight>=0.7` (with `[tool.uv] exclude-newer-package` to bypass the global age gate for nnsight).
- `ln_final.output` is intentionally excluded from the equivalence test: vLLM Llama's fused RMSNorm returns `(normalized, residual)` while HF returns a single tensor; the standardization for `ln_final` is a separate piece of work.
- Layout detection applies to any `nn.Module` whose forward params are conventionally named (`positions`, `hidden_states`, `residual`). Verified on Llama (dual-stream) and GPT-2 (single-stream); should extend to Qwen2 etc. without code changes.

🤖 Generated with Claude Code
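The residual recombination and the early shape check from the summary can be sketched together. This is an illustration under assumptions, not nnterp's code: `ToyTensor`, `check_dual_stream_shapes`, and `read_input` are hypothetical names, and `RenamingError` here is a local stand-in for the library's exception.

```python
from dataclasses import dataclass


class RenamingError(Exception):
    """Local stand-in for the library's RenamingError."""


@dataclass
class ToyTensor:
    """Tiny tensor stand-in: a flat list of values plus a shape tuple."""
    data: list
    shape: tuple

    def __add__(self, other):
        return ToyTensor([a + b for a, b in zip(self.data, other.data)], self.shape)


def check_dual_stream_shapes(hidden_states, residual):
    """Fail early with a clear message rather than letting `+` broadcast silently."""
    if tuple(hidden_states.shape) != tuple(residual.shape):
        raise RenamingError(
            f"dual-stream shape mismatch: hidden_states {tuple(hidden_states.shape)} "
            f"vs residual {tuple(residual.shape)}"
        )


def read_input(hidden_states, residual):
    """Recover the combined stream: layer 0 has no residual, later layers add it."""
    if residual is None:
        return hidden_states
    check_dual_stream_shapes(hidden_states, residual)
    return hidden_states + residual


h = ToyTensor([1.0, 2.0], (1, 2))
r = ToyTensor([0.5, 0.5], (1, 2))
layer0_input = read_input(h, None)  # layer 0: hidden_states alone
layerN_input = read_input(h, r)     # layer N>0: hidden_states + residual
```

With real tensors, mismatched but broadcastable shapes would add without error, which is why checking shapes before the sum matters more there than in this toy.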