2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -11,7 +11,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.12'

- name: Install pre-commit hook
run: |
12 changes: 8 additions & 4 deletions .pre-commit-config.yaml
@@ -25,16 +25,20 @@ repos:
args: ["--profile=black"]
exclude: ^python/sgl_jax/test/run_eval\.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
rev: v0.13.3
hooks:
- id: ruff
args: [--select=F401, --fixable=F401]
files: ^(benchmark/|docs/|examples/)
# Ruff is lint-only; formatting is handled by Black.
# Do not add ruff-format to avoid conflicts with Black.
- id: ruff-check
args: [--output-format, github, --fix]
files: ^(python/|benchmark/|docs/|examples/)
exclude: \.ipynb$
- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
- id: black-jupyter
args: ["--config", "python/pyproject.toml"]
# Black is the only formatter; keep Ruff formatting disabled.
# exclude: >
# (?x)^(
# python/sgl_jax/srt/entrypoints/openai/serving_rerank\.py|
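For context on the hook scoping above: pre-commit applies the `files` and `exclude` patterns as regular expressions against each staged path, so the new `ruff-check` hook now covers `python/` in addition to the benchmark, docs, and examples trees while skipping notebooks. A minimal sketch of that filtering, with hypothetical paths and a helper that only approximates pre-commit's matching:

import re

# Patterns copied from the ruff-check hook above.
FILES = re.compile(r"^(python/|benchmark/|docs/|examples/)")
EXCLUDE = re.compile(r"\.ipynb$")

def hook_runs_on(path: str) -> bool:
    # Rough approximation of pre-commit's include/exclude filtering (regex search on the path).
    return bool(FILES.search(path)) and not EXCLUDE.search(path)

# Hypothetical paths, for illustration only.
assert hook_runs_on("python/sgl_jax/srt/utils.py")
assert hook_runs_on("benchmark/kernels/flash_attention/utils.py")
assert not hook_runs_on("docs/examples/quickstart.ipynb")  # notebooks are excluded
assert not hook_runs_on("scripts/release.py")              # outside the files pattern
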
7 changes: 2 additions & 5 deletions benchmark/kernels/flash_attention/bench_flashattention.py
@@ -181,10 +181,7 @@ def main():
)

for max_num_batched_tokens in max_num_batched_tokens_config:
if (
q_head_num < kv_head_num
or q_head_num % kv_head_num != 0
):
if q_head_num < kv_head_num or q_head_num % kv_head_num != 0:
continue
all_combinations.append(
(
@@ -226,7 +223,7 @@ def main():
except Exception as e:
raise ValueError(f"run failed: {e=}")

print(f"cost: {flash_time*1000}ms")
print(f"cost: {flash_time * 1000}ms")


if __name__ == "__main__":
13 changes: 4 additions & 9 deletions benchmark/kernels/flash_attention/get_block_spec_config.py
@@ -177,10 +177,7 @@ def main():
for page_size in page_size_config:
for max_kv_cache_tokens in max_kv_cache_tokens_config:
for max_num_batched_tokens in max_num_batched_tokens_config:
if (
q_head_num < kv_head_num
or q_head_num % kv_head_num != 0
):
if q_head_num < kv_head_num or q_head_num % kv_head_num != 0:
continue
all_combinations.append(
(
@@ -202,7 +199,7 @@
block_spec_configs.append((num_kv_pages_per_blk, num_queries_per_block))

print(
f"(q_dtype, kv_dtype, num_q_heads_per_blk, num_kv_heads_per_blk, head_dim, page_size, max_num_batched_tokens): (num_kv_pages_per_block, num_queries_per_block)"
"(q_dtype, kv_dtype, num_q_heads_per_blk, num_kv_heads_per_blk, head_dim, page_size, max_num_batched_tokens): (num_kv_pages_per_block, num_queries_per_block)"
)

for i, (
@@ -215,9 +212,7 @@ def main():
) in enumerate(all_combinations):
best_output = inf
best_config = None
for i, (num_kv_pages_per_blk, num_queries_per_block) in enumerate(
block_spec_configs
):
for i, (num_kv_pages_per_blk, num_queries_per_block) in enumerate(block_spec_configs):
try:
(
flash_time,
@@ -237,7 +232,7 @@ def main():
if flash_time < best_output:
best_output = flash_time
best_config = (num_kv_pages_per_blk, num_queries_per_block)
except Exception as e:
except Exception:
pass

print(
20 changes: 5 additions & 15 deletions benchmark/kernels/flash_attention/utils.py
@@ -16,9 +16,7 @@ def create_kv_cache_data(
return kv_cache


def create_qkv_data(
total_tokens, q_head_num, kv_head_num, head_dim, dtype=jnp.bfloat16, seed=42
):
def create_qkv_data(total_tokens, q_head_num, kv_head_num, head_dim, dtype=jnp.bfloat16, seed=42):
key = jax.random.PRNGKey(seed)
keys = jax.random.split(key, 3)
q = jax.random.normal(keys[0], (total_tokens, q_head_num, head_dim), dtype=dtype)
@@ -27,14 +25,10 @@ def create_qkv_data(
return q, k, v


def create_page_indices_data(
num_seqs, total_kv_tokens, seq_lens, max_context_len, page_size=128
):
def create_page_indices_data(num_seqs, total_kv_tokens, seq_lens, max_context_len, page_size=128):
cache_loc = jnp.arange(0, total_kv_tokens, dtype=jnp.int32)

cache_start_idx = jnp.concatenate(
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_lens)]
)
cache_start_idx = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_lens)])

cache_loc_list = []
for i in range(num_seqs):
@@ -130,9 +124,7 @@ def create_decode_uniform_data(
):
batch_size = max_num_batched_tokens
# hack: give each sequence a random prefix length of 1024-2048 tokens for single-token decode
random_prefix_lens = jax.random.randint(
jax.random.PRNGKey(42), (batch_size,), 1024, 2048
)
random_prefix_lens = jax.random.randint(jax.random.PRNGKey(42), (batch_size,), 1024, 2048)
seq_lens = random_prefix_lens + 1
cu_q_lens = jnp.concatenate(
[
@@ -146,9 +138,7 @@ def create_decode_uniform_data(
jnp.cumsum(seq_lens),
]
)
q, k, v = create_qkv_data(
batch_size, q_head_num, kv_head_num, head_dim, dtype, seed
)
q, k, v = create_qkv_data(batch_size, q_head_num, kv_head_num, head_dim, dtype, seed)
kv_cache = create_kv_cache_data(
max_kv_cache_tokens,
kv_head_num,
14 changes: 4 additions & 10 deletions benchmark/kernels/megablox_gmm/bench_megablox_gmm.py
@@ -138,7 +138,7 @@ def main():
)

print(
f"Config {valid_config_count}: m={adjusted_m}, k={k}, n={n}, groups={num_groups}, group_size={adjusted_m//num_groups}"
f"Config {valid_config_count}: m={adjusted_m}, k={k}, n={n}, groups={num_groups}, group_size={adjusted_m // num_groups}"
)

try:
@@ -186,15 +186,9 @@ def main():
worst_config = max(results, key=lambda x: x["megablox_ms"])

print("-" * 80)
print(
f"Best performance: {best_config['config']} - {best_config['megablox_ms']:.2f} ms"
)
print(
f"Worst performance: {worst_config['config']} - {worst_config['megablox_ms']:.2f} ms"
)
print(
f"Speedup ratio: {worst_config['megablox_ms'] / best_config['megablox_ms']:.2f}x"
)
print(f"Best performance: {best_config['config']} - {best_config['megablox_ms']:.2f} ms")
print(f"Worst performance: {worst_config['config']} - {worst_config['megablox_ms']:.2f} ms")
print(f"Speedup ratio: {worst_config['megablox_ms'] / best_config['megablox_ms']:.2f}x")


if __name__ == "__main__":
14 changes: 5 additions & 9 deletions benchmark/kernels/update_kv_cache/bench_update_kv_cache.py
@@ -131,11 +131,9 @@ def main():
head_num,
head_dim,
)
max_num_slices_per_block_config = get_num_slices_per_block(
new_value, cache, page_size
)
random_cache_loc, slice_lens, new_value_start_loc, update_slices_num = (
create_input_params(max_cache_len, new_value_len, page_size=page_size)
max_num_slices_per_block_config = get_num_slices_per_block(new_value, cache, page_size)
random_cache_loc, slice_lens, new_value_start_loc, update_slices_num = create_input_params(
max_cache_len, new_value_len, page_size=page_size
)

print(
@@ -160,12 +158,10 @@ def main():
if cost < min_cost:
min_cost = cost
fastest_num_slices_per_block = num_slices_per_block
print(
f"[num_slices_per_block={num_slices_per_block}] avg cost: {cost*1000} ms"
)
print(f"[num_slices_per_block={num_slices_per_block}] avg cost: {cost * 1000} ms")

print(
f"Fastest [num_slices_per_block={fastest_num_slices_per_block}] costs: {min_cost*1000} ms"
f"Fastest [num_slices_per_block={fastest_num_slices_per_block}] costs: {min_cost * 1000} ms"
)


11 changes: 3 additions & 8 deletions benchmark/kernels/update_kv_cache/utils.py
@@ -13,12 +13,8 @@ def create_bench_data(
):
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 3)
new_value = jax.random.normal(
keys[1], (new_kv_len, kv_head_num, head_dim), dtype=dtype
)
cache = jax.random.normal(
keys[2], (cache_max_tokens, kv_head_num, head_dim), dtype=dtype
)
new_value = jax.random.normal(keys[1], (new_kv_len, kv_head_num, head_dim), dtype=dtype)
cache = jax.random.normal(keys[2], (cache_max_tokens, kv_head_num, head_dim), dtype=dtype)
return new_value, cache


@@ -37,8 +33,7 @@ def create_random_cache_start_loc(cache_max_tokens, new_kv_len, page_size=128):
new_value_page_num = cdiv(new_kv_len, page_size)
max_cache_page_num = cdiv(cache_max_tokens, page_size)
cache_start_loc = (
jax.random.randint(key, (new_value_page_num,), 0, max_cache_page_num - 1)
* page_size
jax.random.randint(key, (new_value_page_num,), 0, max_cache_page_num - 1) * page_size
)
return cache_start_loc

39 changes: 39 additions & 0 deletions python/pyproject.toml
@@ -66,3 +66,42 @@ exclude = [
"scripts*",
"tests*",
]

[tool.black]
line-length = 100

[tool.ruff]
line-length = 100

[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort (handled by pre-commit isort hook)
# "I",
# flake8-logging-format
"G",
]
ignore = [
# line too long handled by Black
"E501",
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# zip without `strict=`
"B905",
# Loop control variable not used within loop body
"B007",
]

[tool.ruff.format]
docstring-code-format = true
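
As a rough illustration of what the new shared configuration enforces, the hypothetical snippet below (not taken from the repository) is annotated with the Ruff rules that would fire under the `select` list above, along with patterns that the `ignore` list deliberately permits:

# Illustrative only; the rule codes in the comments are assumptions about how Ruff classifies each pattern.
import logging
import os  # F401 (Pyflakes): unused import


def greet(name, seen=[]):  # B006 (flake8-bugbear): mutable default argument
    msg = "Hello, {}".format(name)  # UP032 (pyupgrade): prefer an f-string
    if name in seen:  # SIM108 (flake8-simplify): collapse this if/else into a ternary
        suffix = " again"
    else:
        suffix = ""
    logging.info(f"{msg}{suffix}")  # G004 (flake8-logging-format): f-string in a logging call


# Deliberately tolerated by the ignore list above:
square = lambda x: x * x         # E731: lambda assignment is allowed
pairs = list(zip("ab", [1, 2]))  # B905: zip() without strict= is allowed
# E501 is ignored because Black (line-length = 100) owns line length.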