Roar forward stats. #25

Open · wants to merge 1 commit into base: main
123 changes: 123 additions & 0 deletions tests/v1/engine/test_engine_core.py
@@ -47,11 +47,134 @@
data_parallel_rank=None,
)

import pandas as pd
import os

Check failure (GitHub Actions / pre-commit), Ruff: tests/v1/engine/test_engine_core.py:51:1: E402 Module level import not at top of file
import time

Check failure (GitHub Actions / pre-commit), Ruff: tests/v1/engine/test_engine_core.py:52:1: E402 Module level import not at top of file
import torch

Check failure (GitHub Actions / pre-commit), Ruff: tests/v1/engine/test_engine_core.py:53:1: E402 Module level import not at top of file

def time_prefill_single_request(engine_core: EngineCore, prefix_length: int):
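    """Time one prefill pass over prefix_length tokens, in seconds."""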
request = make_request()
request.prompt_token_ids = [13] * prefix_length

# warm up
engine_core.add_request(request)
scheduler_output = engine_core.scheduler.schedule()
assert scheduler_output.total_num_scheduled_tokens == prefix_length
engine_core.execute_model(scheduler_output)
engine_core.abort_requests([request.request_id])

torch.cuda.synchronize()

# time prefill
request = make_request()
request.prompt_token_ids = [14] * prefix_length # different token
engine_core.add_request(request)
scheduler_output = engine_core.scheduler.schedule()
engine_core.execute_model(scheduler_output)

# reset prefix cache
engine_core.abort_requests([request.request_id])
assert engine_core.scheduler.reset_prefix_cache()
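    # FORWARD_TIME is set by the model runner (see gpu_model_runner.py below).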
return float(os.environ.get("FORWARD_TIME"))


def time_decoding_single_request(engine_core: EngineCore, batch_size: int):
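    """Time one decode step for a batch of batch_size requests, in seconds."""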
request = make_request()
context_length = 8064
request.prompt_token_ids = [13] * context_length

    # generate kv cache
engine_core.add_request(request)
scheduler_output = engine_core.scheduler.schedule()
engine_core.execute_model(scheduler_output)
engine_core.abort_requests([request.request_id])

# warmup decoding
request_ids = []
for i in range(batch_size):
request = make_request()
request.prompt_token_ids = [13] * context_length + [i]
request_ids.append(request.request_id)
engine_core.add_request(request)

scheduler_output = engine_core.scheduler.schedule()
assert scheduler_output.total_num_scheduled_tokens == batch_size
engine_core.execute_model(scheduler_output)
for request_id in request_ids:
engine_core.abort_requests([request_id])

torch.cuda.synchronize()

# time decoding
request_ids = []
for i in range(batch_size):
request = make_request()
        # add batch_size as a token offset to avoid cache hits
        request.prompt_token_ids = [13] * context_length + [batch_size + i]
request_ids.append(request.request_id)
engine_core.add_request(request)

scheduler_output = engine_core.scheduler.schedule()
engine_core.execute_model(scheduler_output)

# reset prefix cache
for request_id in request_ids:
engine_core.abort_requests([request_id])
assert engine_core.scheduler.reset_prefix_cache()
return float(os.environ.get("FORWARD_TIME"))


@create_new_process_for_each_test()
def test_engine_core_roar(monkeypatch: pytest.MonkeyPatch):
    prefill_lens = [
        8192, 7168, 6144, 5120, 4096, 3072, 2048, 1024, 896, 768, 640, 512,
        384, 256
    ]
    batch_sizes = [
        128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 4
    ]

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("FORWARD_TIME", "-1")
"""Setup the EngineCore."""
model_path = "/kaiju-oss-models/prod_hf/roar_sft_adventure_safe.05-01-22-00-25.uc33.classi_only.tp1.minimal"
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_num_batched_tokens=8192,
quantization="fp8",
) # make sure we can do 8064 + 128
vllm_config = engine_args.create_engine_config()

Check failure (GitHub Actions / pre-commit), Ruff: tests/v1/engine/test_engine_core.py:144:81: E501 Line too long (116 > 80)
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)

# Run benchmarks for all prefill lengths
print("============Prefill============")
warmup_stats = {}
for prefill_len in prefill_lens:
execution_time = time_prefill_single_request(engine_core, prefill_len)
warmup_stats[("prefill", prefill_len)] = execution_time
print(f"EngineCore execute_model execution time for {prefill_len} tokens: {execution_time:.6f} seconds")

# Run benchmarks for all decoding batch sizes
print("============Decoding============")
for batch_size in batch_sizes:
execution_time = time_decoding_single_request(engine_core, batch_size)
warmup_stats[("decoding", batch_size)] = execution_time
print(f"EngineCore execute_model execution time for {batch_size} requests: {execution_time:.6f} seconds")

# Generate pandas DataFrame table

Check failure (GitHub Actions / pre-commit), Ruff: tests/v1/engine/test_engine_core.py:167:81: E501 Line too long (113 > 80)
        modified_warmup_stats: list[tuple[str, int, float]] = [
            (k1, k2, v * 1000) for (k1, k2), v in warmup_stats.items()
        ]
        tab = pd.DataFrame(
            modified_warmup_stats,
            columns=["method", "size", "MS"]).assign(tag="WARMUPSTAT")

print("\nBenchmark Results Table:")
print(tab.to_string(index=False))

@create_new_process_for_each_test()
def test_engine_core(monkeypatch: pytest.MonkeyPatch):

with monkeypatch.context() as m:

Check failure (GitHub Actions / pre-commit), Ruff: tests/v1/engine/test_engine_core.py:177:81: E501 Line too long (114 > 80)
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME)
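
A hypothetical way to run the new benchmark locally, assuming a GPU host where the model path above is reachable: pytest tests/v1/engine/test_engine_core.py::test_engine_core_roar -s (the -s flag keeps the printed timings and the WARMUPSTAT table visible).
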
8 changes: 8 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1363,6 +1363,7 @@ def execute_model(

# Run the model.
# Use persistent buffers for CUDA graphs.
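        # Start timing the forward pass for the FORWARD_TIME hook below.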
start_time = time.perf_counter()
with set_forward_context(
attn_metadata,
self.vllm_config,
@@ -1378,11 +1379,18 @@
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
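            # Block until GPU work finishes so forward_time is accurate.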
torch.cuda.synchronize()
forward_time = time.perf_counter() - start_time

self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))

import os
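            # Benchmark hook: when FORWARD_TIME is set by the benchmark test,
            # report the forward time and skip the rest of execute_model.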
if "FORWARD_TIME" in os.environ:
os.environ["FORWARD_TIME"] = str(forward_time)
return EMPTY_MODEL_RUNNER_OUTPUT

if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else: