Skip to content

Commit 96dfbd9

Browse files
committed
fix: bump rotary and adjust top level images
1 parent 1670026 commit 96dfbd9

File tree

5 files changed

+35
-16
lines changed

5 files changed

+35
-16
lines changed

benches/index.md

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
## [Layer Norm](layer_norm/)
55

66
<div class="artifact-preview">
7-
<object data="layer_norm/results/artifacts/combine/latency.svg" type="image/svg+xml" width="800">
8-
</object>
7+
<img src="layer_norm/results/artifacts/combine/latency.svg" alt="Layer Norm Latency" width="800">
98
</div>
109

1110
| Implementation | Description |
@@ -16,8 +15,7 @@
1615
## [Rotary Position Embeddings](rotary/)
1716

1817
<div class="artifact-preview">
19-
<object data="rotary/results/artifacts/combine/latency.svg" type="image/svg+xml" width="800">
20-
</object>
18+
<img src="rotary/results/artifacts/combine/latency.svg" alt="Rotary Position Embeddings Latency" width="800">
2119
</div>
2220

2321
| Implementation | Description |
@@ -28,8 +26,7 @@
2826
## [Flash Attention](flash_attn/)
2927

3028
<div class="artifact-preview">
31-
<object data="flash_attn/results/artifacts/combine/latency.svg" type="image/svg+xml" width="800">
32-
</object>
29+
<img src="flash_attn/results/artifacts/combine/latency.svg" alt="Flash Attention Latency" width="800">
3330
</div>
3431

3532
| Implementation | Description |
@@ -44,8 +41,7 @@
4441
## [Causal Conv1D](causal_conv1d/)
4542

4643
<div class="artifact-preview">
47-
<object data="causal_conv1d/results/artifacts/combine/latency.svg" type="image/svg+xml" width="800">
48-
</object>
44+
<img src="causal_conv1d/results/artifacts/combine/latency.svg" alt="Causal Conv1D Latency" width="800">
4945
</div>
5046

5147
| Implementation | Description |
@@ -56,8 +52,7 @@
5652
## [Activation](activation/)
5753

5854
<div class="artifact-preview">
59-
<object data="activation/results/artifacts/combine/latency.svg" type="image/svg+xml" width="800">
60-
</object>
55+
<img src="activation/results/artifacts/combine/latency.svg" alt="Activation Latency" width="800">
6156
</div>
6257

6358
| Implementation | Description |
@@ -68,8 +63,7 @@
6863
## [ReLU](relu/)
6964

7065
<div class="artifact-preview">
71-
<object data="relu/results/artifacts/combine/latency.svg" type="image/svg+xml" width="800">
72-
</object>
66+
<img src="relu/results/artifacts/combine/latency.svg" alt="ReLU Latency" width="800">
7367
</div>
7468

7569
| Implementation | Description |

benches/rotary/impls/hf_kernels_rotary.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,6 @@ run_benchmark(
6060
impl_name="hf_kernels_rotary",
6161
impl_tags={"family": "hf-kernels", "backend": "cuda"},
6262
impl_func=hf_kernels_rotary,
63+
dtype="float32",
6364
)
6465
```

tools/kernels_benchmark_tools/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,15 @@ def run_benchmark(
5050
impl_name: str | None = None,
5151
impl_tags: dict | None = None,
5252
impl_func=None,
53+
reps: int = 5,
54+
warmup: int = 2,
55+
dtype: str | None = None,
56+
device: str | None = None,
5357
**kwargs,
5458
):
5559
# Determine device and dtype (TODO: allow user override)
56-
device = "cuda" if torch.cuda.is_available() else "cpu"
57-
dtype = "float32" if device == "cpu" else "bfloat16"
60+
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
61+
dtype = dtype or ("float32" if device == "cpu" else "bfloat16")
5862

5963
# Get the kernel module based on type (TODO: handle invalid type)
6064
kernel_module = KERNEL_MODULES[kernel_type]
@@ -77,8 +81,8 @@ def run_benchmark(
7781
run(
7882
wl,
7983
jsonl=f"{kernel_type.value}.jsonl",
80-
reps=5,
81-
warmup=2,
84+
reps=reps,
85+
warmup=warmup,
8286
gen=kernel_module.gen_inputs,
8387
ref=kernel_module.ref_impl,
8488
cmp=kernel_module.cmp_allclose,

tools/kernels_benchmark_tools/core/harness.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def run(
9292
env = _env_block()
9393
now = lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
9494

95+
# clear old results
96+
if os.path.exists(jsonl):
97+
os.remove(jsonl)
98+
9599
for wl in workloads:
96100
inputs = gen(wl)
97101
ref_out = ref(inputs)

tools/kernels_benchmark_tools/rotary.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,19 @@ def workloads(dtype="float32", device="cuda") -> Iterable[dict]:
130130
"device": device,
131131
"seed": 0,
132132
}
133+
134+
135+
# single workload for quick testing
136+
def _workloads(dtype="float32", device="cuda") -> Iterable[dict]:
137+
print("✅ Using single workload for quick testing.")
138+
yield {
139+
"name": f"{device}_B1_S128_H8_D64_R32",
140+
"batch": 1,
141+
"seqlen": 128,
142+
"num_heads": 8,
143+
"head_dim": 64,
144+
"rotary_dim": 32,
145+
"dtype": dtype,
146+
"device": device,
147+
"seed": 0,
148+
}

0 commit comments

Comments
 (0)