diff --git a/.gitignore b/.gitignore index e1e8a564..0be39a29 100644 --- a/.gitignore +++ b/.gitignore @@ -92,3 +92,4 @@ benchmarks/tritonbench site generated uv.lock +docs/examples/ \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index 731da962..a275d84c 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -9,8 +9,10 @@ html: clean livehtml: clean sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) --open-browser --port 0 + clean: rm -rf $(BUILDDIR)/* + rm -rf examples/* # Catch-all target: route all unknown targets to Sphinx-Build using the # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). diff --git a/docs/conf.py b/docs/conf.py index 1f50ac4d..0a5b68fd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,6 +27,7 @@ "sphinx.ext.intersphinx", "myst_parser", "sphinx_autodoc_typehints", + "sphinx_gallery.gen_gallery", ] # MyST parser configuration @@ -44,6 +45,16 @@ "tasklist", ] +sphinx_gallery_conf = { + "examples_dirs": [ + "../examples", + ], # path to your example scripts + "gallery_dirs": "examples", # path to where to save gallery generated output + "filename_pattern": r".*\.py$", # Include all Python files + "ignore_pattern": r"__init__\.py", # Exclude __init__.py files + "plot_gallery": "False", # Don't run the examples +} + # Templates path templates_path = ["_templates"] diff --git a/docs/helion_puzzles.rst b/docs/helion_puzzles.rst new file mode 100644 index 00000000..0a25d5ee --- /dev/null +++ b/docs/helion_puzzles.rst @@ -0,0 +1,736 @@ +Helion Puzzles +============== + +Programming for accelerators such as GPUs is critical for modern AI systems. This often means programming directly in proprietary low-level languages such as CUDA. Helion is a Python-embedded domain-specific language (DSL) for authoring machine learning kernels, designed to compile down to Triton, a performant backend for programming GPUs and other devices. + +Helion aims to raise the level of abstraction compared to Triton, making it easier to write correct and efficient kernels while enabling more automation in the autotuning process. + +This set of puzzles is meant to teach you how to use Helion from first principles in an interactive fashion. You will start with trivial examples and build your way up to real algorithms like Flash Attention and Quantized neural networks. + +Setup +----- + +First, let's install the necessary dependencies. Helion requires a recent version of PyTorch and a development version of Triton. + +.. code-block:: python + + import logging + + import helion + import helion.language as hl + import torch + from torch import Tensor + + # If you set this to info you will see the output Triton Code + logging.getLogger().setLevel(logging.WARNING) + +Let's also create a simple testing function to verify our implementations. + +.. code-block:: python + + from triton.testing import do_bench + def test_kernel(kernel_fn, spec_fn, *args): + """Test a Helion kernel against a reference implementation.""" + # Run our implementation + result = kernel_fn(*args) + # Run reference implementation + expected = spec_fn(*args) + + # Check if results match + torch.testing.assert_close(result, expected) + print("✅ Results Match ✅") + + def benchmark_kernel(kernel_fn, *args, **kwargs): + """Benchmark a Helion kernel.""" + no_args = lambda: kernel_fn(*args, **kwargs) + time_in_ms = do_bench(no_args) + print(f"⏱ Time: {time_in_ms} ms") + + def compare_implementations(kernel_fn, spec_fn, *args, **kwargs): + """Benchmark a Helion kernel and its reference implementation.""" + kernel_no_args = lambda: kernel_fn(*args, **kwargs) + spec_no_args = lambda: spec_fn(*args, **kwargs) + kernel_time = do_bench(kernel_no_args) + spec_time = do_bench(spec_no_args) + print(f"⏱ Helion Kernel Time: {kernel_time:.3f} ms, PyTorch Reference Time: {spec_time:.3f} ms, Speedup: {spec_time/kernel_time:.3f}x") + +Basic Structure of a Helion Kernel +--------------------------------- + +Helion allows you to write GPU kernels using familiar PyTorch syntax. + +A Helion kernel has three main sections: + +1. **Host Section** (CPU) + This is standard PyTorch code executed on the CPU. Memory allocation, and shape computations are done here. Like with `Triton` and `Cuda` you need to setup your output buffers on the host before launching your kernel. + +2. **Device Loop** (GPU Grid) + `for tile in hl.tile(sizes)` - defines parallel execution across GPU thread blocks + +3. **Device Operations** (GPU Kernel) + PyTorch operations inside the loop - automatically compiled and fused + +Example: + +.. code-block:: python + + @helion.kernel(config=helion.Config(block_sizes = [128, 128])) # The @helion.kernel decorator marks this function for compilation + def example_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Host code: Standard PyTorch operations + m, n = x.size() + out = torch.empty_like(x) # Allocate output tensor + + # The hl.tile loop defines the parallel execution structure + for tile_m, tile_n in hl.tile([m, n]): + # Device code: Everything inside the hl.tile loop runs on GPU + out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m, tile_n] # Simple element-wise addition expressed w/ pytorch ops + + return out # Return the result back to the host + + # Create some sample data + x = torch.randn(10, 10, device="cuda") + y = torch.randn(10, 10, device="cuda") + + # Run the kernel + result = example_add(x, y) + + # Verify result + expected = x + y + torch.testing.assert_close(result, expected) + print("✅ Results Match ✅") + benchmark_kernel(example_add, x, y) + compare_implementations(example_add, torch.add, x, y) + +Autotuning in Helion +-------------------- + +In the previous example, we explicitly specified a configuration using `config=helion.Config(block_sizes=[128, 128])`. This bypasses Helion's autotuning mechanism and uses our predefined settings. While this is quick to run, manually choosing optimal parameters can be challenging and hardware-dependent. + +### What is Autotuning? + +Autotuning is Helion's process of automatically finding the best configuration parameters for your specific: + +- Hardware (GPU model) +- Problem size +- Operation patterns + +When you omit the `config` parameter, Helion will automatically search for the optimal configuration: + +.. code-block:: python + + @helion.kernel() # No config = automatic tuning + def autotuned_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, n = x.size() + out = torch.empty_like(x) + for tile_m, tile_n in hl.tile([m, n]): + out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m, tile_n] + +Feel free to run the above code to see how much more performant it is than the original, although be warned it might take some time 😃 + +Now let's move on to our puzzles! + +Puzzle 1: Constant Add +---------------------- + +Add a constant to a vector. + +.. code-block:: python + + def add_spec(x: Tensor) -> Tensor: + """This is the spec that you should implement in the helion kernel below.""" + return x + 10. + + # ---- ✨ Is this the best block size? ---- + @helion.kernel(config = helion.Config(block_sizes = [1,])) + def add_kernel(x: torch.Tensor) -> torch.Tensor: + # ---- ✨ Your Code Here ✨---- + # Set up the output buffer which you will return + + # Use Helion to tile the computation + for tile_n in hl.tile(TILE_RANGE): + # ---- ✨ Your Code Here ✨---- + + return out + + # Test the kernel + x = torch.randn(8192, device="cuda") + test_kernel(add_kernel, add_spec, x) + benchmark_kernel(add_kernel, x) + compare_implementations(add_kernel, add_spec, x) + +.. code-block:: python + + def add_spec(x: Tensor) -> Tensor: + """This is the spec that you should implement.""" + return x + 10. + + # ---- ✨ Is this the best block size? ---- + @helion.kernel(config = helion.Config(block_sizes = [32,])) + def add_kernel(x: torch.Tensor) -> torch.Tensor: + # ---- ✨ Your Code Here ✨---- + # Set up the output buffer which you will return + TILE_RANGE = x.size() + out = torch.empty_like(x) + # ---- End of Code ---- + + # Use Helion to tile the computation + for tile_n in hl.tile(TILE_RANGE): + # ---- ✨ Your Code Here ✨---- + x_tile = x[tile_n] + out[tile_n] = x_tile + 10.0 + + return out + + # Test the kernel + x = torch.randn(8192, device="cuda") + test_kernel(add_kernel, add_spec, x) + benchmark_kernel(add_kernel, x) + compare_implementations(add_kernel, add_spec, x) + +Puzzle 2: Outer Vector Add +-------------------------- + +Add two vectors using an outer product pattern. + +.. code-block:: python + + def broadcast_add_spec(x: Tensor, y: Tensor) -> Tensor: + return x[None, :] + y[:, None] + + # ---- ✨ Is this the best block size? ---- + @helion.kernel(config = helion.Config(block_sizes = [32, 32])) + def broadcast_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + # ---- ✨ Your Code Here ✨---- + n0 = x.size(0) + n1 = y.size(0) + out = x.new_empty(n1, n0) + + # Use Helion to tile the computation + for tile_i, tile_j in hl.tile([n1, n0]): + # Get tiles from x and y + y_tile = y[tile_i] + x_tile = x[tile_j] + # Compute outer sum + out[tile_i, tile_j] = y_tile[:, None] + x_tile[None, :] + + return out + + # Test the kernel + x = torch.randn(1142, device="cuda") + y = torch.randn(512, device="cuda") + test_kernel(broadcast_add_kernel, broadcast_add_spec, x, y) + benchmark_kernel(broadcast_add_kernel, x, y) + compare_implementations(broadcast_add_kernel, broadcast_add_spec, x, y) + +Puzzle 3: Fused Outer Multiplication +----------------------------------- + +Multiply a row vector to a column vector and take a relu. + +.. code-block:: python + + def mul_relu_block_spec(x: Tensor, y: Tensor) -> Tensor: + return torch.relu(x[None, :] * y[:, None]) + + # ---- ✨ Is this the best block size? ---- + @helion.kernel(config = helion.Config(block_sizes = [32, 32])) + def mul_relu_block_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + n0 = x.size(0) + n1 = y.size(0) + # Create output tensor + out = torch.empty([n1, n0], dtype=x.dtype, device=x.device) + + # Use Helion to tile the computation + for tile_i, tile_j in hl.tile([n1, n0]): + # Get tiles from x and y + y_tile = y[tile_i] + x_tile = x[tile_j] + # Compute outer product followed by ReLU + out[tile_i, tile_j] = torch.relu(y_tile[:, None] * x_tile[None, :]) + + return out + + # Test the kernel + x = torch.randn(512, device="cuda") + y = torch.randn(512, device="cuda") + test_kernel(mul_relu_block_kernel, mul_relu_block_spec, x, y) + compare_implementations(mul_relu_block_kernel, mul_relu_block_spec, x, y) + +Puzzle 4: Fused Outer Multiplication - Backwards +------------------------------------------------ + +While PyTorch and torch.compile automatically generates the backwards pass for your Tensor Operations, Helion does not. So lets practice by writing the backwards function for a fused mul_relu kernel + +.. code-block:: python + + def mul_relu_block_back_spec(x: Tensor, y: Tensor, dz: Tensor) -> Tensor: + x = x.clone() + y = y.clone() + x = x.requires_grad_(True) + z = torch.relu(x * y[:, None]) + grad_x, grad_y = torch.autograd.grad(z, [x, y], dz, retain_graph=True) + return grad_x + + @helion.kernel(config=helion.Config(block_sizes=[32, 32])) + def mul_relu_block_back_kernel( + x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor + ) -> torch.Tensor: + # Get tensor sizes + n0 = x.size(1) + n1 = x.size(0) + # Create output tensor for gradients + dx = torch.empty_like(x) + dy = torch.empty_like(y) + + # Use Helion to tile the computation + for tile_i, tile_j in hl.tile([n1, n0]): + # Get input tiles + x_tile = x[tile_i, tile_j] + y_tile = y[tile_i] + dz_tile = dz[tile_i, tile_j] + + # Compute gradients for ReLU * multiplication backward + # For ReLU, gradient is 1 where input > 0, 0 otherwise + relu_mask = (x_tile * y_tile[:, None]) > 0 + # Chain rule: dx = dz * relu_grad * y + dx[tile_i, tile_j] = dz_tile * relu_mask * y_tile[:, None] + + return dx, dy + + # Test the kernel + x = torch.randn(512, 1024, device="cuda") + y = torch.randn(512, device="cuda") + dz = torch.randn(512, 1024, device="cuda") + test_kernel(mul_relu_block_back_kernel, mul_relu_block_back_spec, x, y, dz) + +Puzzle 7: Long Sum +----------------- + +Sum of a batch of numbers. + +.. code-block:: python + + def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]: + return x.sum(1) + + @helion.kernel() + def sum_kernel(x: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + batch, seq_len = x.size() + # Create output tensor + out = torch.empty(batch, dtype=x.dtype, device=x.device) + + # Use Helion to tile the batch dimension + for tile_batch in hl.tile(batch): + # Initialize accumulator for each batch element + acc = torch.zeros_like(tile_batch, dtype=torch.float32) + + # Process the sequence in chunks + for tile_seq in hl.tile(seq_len): + # Get the current chunk + chunk = x[tile_batch, tile_seq] + # Accumulate sum + acc += torch.sum(chunk, dim=1) + + # Store result + out[tile_batch] = acc + + return out + + # Test the kernel + x = torch.randn(4, 200, device="cuda") + test_kernel(sum_kernel, sum_spec, x) + +Puzzle 8: Long Softmax +--------------------- + +Softmax of a batch of logits. + +.. code-block:: python + + def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]: + x_max = x.max(1, keepdim=True)[0] + x = x - x_max + x_exp = x.exp() + return x_exp / x_exp.sum(1, keepdim=True) + + @helion.kernel() + def softmax_kernel(x: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + batch, seq_len = x.size() + # Create output tensor + out = torch.empty_like(x) + + # Use Helion to tile the batch dimension + for tile_batch in hl.tile(batch): + # First pass: find max value for each sequence + max_vals = torch.full_like(tile_batch, float('-inf'), dtype=torch.float32) + + for tile_seq in hl.tile(seq_len): + chunk = x[tile_batch, tile_seq] + max_vals = torch.maximum(max_vals, torch.max(chunk, dim=1)[0]) + + # Second pass: compute sum of exp(x - max) + sum_exp = torch.zeros_like(tile_batch, dtype=torch.float32) + + for tile_seq in hl.tile(seq_len): + chunk = x[tile_batch, tile_seq] + exp_vals = torch.exp(chunk - max_vals[:, None]) + sum_exp += torch.sum(exp_vals, dim=1) + + # Third pass: compute softmax + for tile_seq in hl.tile(seq_len): + chunk = x[tile_batch, tile_seq] + exp_vals = torch.exp(chunk - max_vals[:, None]) + out[tile_batch, tile_seq] = exp_vals / sum_exp[:, None] + + return out + + # Test the kernel + x = torch.randn(4, 200, device="cuda") + test_kernel(softmax_kernel, softmax_spec, x) + +Puzzle 9: Simple FlashAttention +------------------------------- + +A scalar version of FlashAttention. + +.. code-block:: python + + def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]: + x = q[:, None] * k[None, :] + x_max = x.max(1, keepdim=True)[0] + x = x - x_max + x_exp = x.exp() + soft = x_exp / x_exp.sum(1, keepdim=True) + return (v[None, :] * soft).sum(1) + + @helion.kernel() + def flashatt_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # Get tensor size + seq_len = q.size(0) + # Create output tensor + out = torch.empty_like(q) + + # Process each query position + for tile_q in hl.tile(seq_len): + q_tile = q[tile_q] + + # Initialize tracking variables for stable softmax + max_val = torch.full_like(q_tile, float('-inf')) + sum_exp = torch.zeros_like(q_tile) + weighted_sum = torch.zeros_like(q_tile) + + # Process in tiles for better cache efficiency + for tile_kv in hl.tile(seq_len): + k_tile = k[tile_kv] + v_tile = v[tile_kv] + + # Compute attention scores + scores = q_tile[:, None] * k_tile[None, :] + + # Find max for numerical stability + batch_max = torch.max(scores, dim=1)[0] + new_max = torch.maximum(max_val, batch_max) + + # Scale old accumulations + scale_factor = torch.exp(max_val - new_max) + sum_exp = sum_exp * scale_factor + weighted_sum = weighted_sum * scale_factor + + # Update with new values + exp_scores = torch.exp(scores - new_max[:, None]) + sum_exp = sum_exp + torch.sum(exp_scores, dim=1) + weighted_sum = weighted_sum + torch.sum(exp_scores * v_tile[None, :], dim=1) + + # Update max_val + max_val = new_max + + # Compute final output + out[tile_q] = weighted_sum / sum_exp + + return out + + # Test the kernel + q = torch.randn(200, device="cuda") + k = torch.randn(200, device="cuda") + v = torch.randn(200, device="cuda") + test_kernel(flashatt_kernel, flashatt_spec, q, k, v) + +Puzzle 10: Two Dimensional Convolution +-------------------------------------- + +A batched 2D convolution. + +.. code-block:: python + + def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]: + z = torch.zeros(4, 8, 8) + x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0) + for i in range(8): + for j in range(8): + z[:, i, j] = (k[None, :, :] * x[:, i: i+4, j: j + 4]).sum(1).sum(1) + return z + + @helion.kernel() + def conv2d_kernel(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + batch, h, w = x.size() + kh, kw = k.size()[1:] + + # Create output tensor + out = torch.empty_like(x) + + # Pad the input + x_padded = torch.nn.functional.pad(x, (0, kw, 0, kh, 0, 0), value=0.0) + + # Use Helion to tile the computation + for tile_batch in hl.tile(batch): + # Process each output position + for i in range(h): + for j in range(w): + # Extract the patch + patch = x_padded[tile_batch, i:i+kh, j:j+kw] + # Apply the kernel + out[tile_batch, i, j] = (k[tile_batch] * patch).sum([1, 2]) + + return out + + # Test the kernel + x = torch.randn(4, 8, 8, device="cuda") + k = torch.randn(4, 4, 4, device="cuda") + test_kernel(conv2d_kernel, conv2d_spec, x, k) + +Puzzle 11: Matrix Multiplication +------------------------------- + +A blocked matrix multiplication. + +.. code-block:: python + + def dot_spec(x: Float32[Tensor, "4 32 32"], y: Float32[Tensor, "4 32 32"]) -> Float32[Tensor, "4 32 32"]: + return x @ y + + @helion.kernel() + def dot_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + batch, m, k = x.size() + _, k, n = y.size() + + # Create output tensor + out = torch.empty([batch, m, n], dtype=x.dtype, device=x.device) + + # Use Helion to tile the computation + for tile_batch in hl.tile(batch): + for tile_m, tile_n in hl.tile([m, n]): + # Initialize accumulator + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + + # Process the reduction dimension in tiles + for tile_k in hl.tile(k): + # Get tiles + x_tile = x[tile_batch, tile_m, tile_k] + y_tile = y[tile_batch, tile_k, tile_n] + + # Accumulate matrix multiplication + acc = acc + torch.matmul(x_tile, y_tile) + + # Store result + out[tile_batch, tile_m, tile_n] = acc + + return out + + # Test the kernel + x = torch.randn(4, 32, 32, device="cuda") + y = torch.randn(4, 32, 32, device="cuda") + test_kernel(dot_kernel, dot_spec, x, y) + +Puzzle 12: Quantized Matrix Multiplication +------------------------------------------ + +When doing matrix multiplication with quantized neural networks, a common strategy is to store the weight matrix in lower precision, with a shift and scale term. + +.. code-block:: python + + FPINT = 32 // 4 + GROUP = 8 + + def quant_dot_spec(scale: Float32[Tensor, "32 8"], + offset: Int32[Tensor, "32"], + weight: Int32[Tensor, "32 8"], + activation: Float32[Tensor, "64 32"]) -> Float32[Tensor, "32 32"]: + offset = offset.view(32, 1) + def extract(x): + over = torch.arange(8, device=x.device) * 4 + mask = 2**4 - 1 + return (x[..., None] >> over) & mask + scale = scale[..., None].expand(-1, 8, GROUP).contiguous().view(-1, 64) + offset = extract(offset)[..., None].expand(-1, 1, 8, GROUP).contiguous().view(-1, 64) + return (scale * (extract(weight).view(-1, 64) - offset)) @ activation + + @helion.kernel() + def quant_dot_kernel(scale: torch.Tensor, offset: torch.Tensor, weight: torch.Tensor, activation: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + n_out, n_groups = scale.size() + mid, n_in = activation.size() + + # Create output tensor + out = torch.empty([n_out, n_in], dtype=scale.dtype, device=scale.device) + + # Helper function to extract 4-bit values + def extract_4bit(x, bit_positions): + mask = 2**4 - 1 + shifted = x[..., None] >> (bit_positions * 4) + return shifted & mask + + # Bit positions for extraction + bit_positions = torch.arange(8, device=scale.device) + + # Use Helion to tile the computation + for tile_out in hl.tile(n_out): + for tile_in in hl.tile(n_in): + # Initialize accumulator + acc = hl.zeros([tile_out, tile_in], dtype=torch.float32) + + # Get the offset values for this tile + offset_tile = offset[tile_out] + # Extract 4-bit values from offsets + offset_extracted = extract_4bit(offset_tile, bit_positions) + + # Process in chunks across the middle dimension + for group_idx in range(n_groups): + # Get scale for this group + scale_group = scale[tile_out, group_idx] + + # Get weights for this group + weight_group = weight[tile_out, group_idx] + + # Extract 4-bit values from weights + weight_extracted = extract_4bit(weight_group, bit_positions) + + # Compute dequantized weights: scale * (weight - offset) + offset_group = offset_extracted[:, group_idx:group_idx+1] # Shape: [tile_out, 1, 8] + dequant_weights = scale_group[:, None, None] * (weight_extracted - offset_group) + + # Reshape dequantized weights for matrix multiplication + dequant_weights = dequant_weights.reshape(tile_out.size(0), 8) + + # Get activations for this group + acts_idx = group_idx * 8 + torch.arange(8, device=scale.device) + act_group = activation[acts_idx][:, tile_in] + + # Accumulate to result + acc = acc + torch.matmul(dequant_weights, act_group) + + # Store result + out[tile_out, tile_in] = acc + + return out + + # Test the kernel with smaller inputs for quicker testing + scale = torch.randn(32, 8, device="cuda") + offset = torch.randint(-10, 10, (32,), device="cuda") + weight = torch.randint(0, 16, (32, 8), device="cuda", dtype=torch.int32) + activation = torch.randn(64, 32, device="cuda") + test_kernel(quant_dot_kernel, quant_dot_spec, scale, offset, weight, activation) + +Autotuning in Helion +-------------------- + +One of the major advantages of Helion is its sophisticated autotuning capability. Let's see how we can leverage this for our matrix multiplication kernel: + +.. code-block:: python + + import torch + import helion + import helion.language as hl + import time + + # Define a matrix multiplication kernel + @helion.kernel() # No config means autotuning will be used + def matmul_autotune(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + k, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = acc + torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc + + return out + + # Create larger tensors for better autotuning results + x = torch.randn(1024, 1024, device="cuda") + y = torch.randn(1024, 1024, device="cuda") + + # First run will trigger autotuning + print("Running with autotuning (this might take a while)...") + start = time.time() + result = matmul_autotune(x, y) + end = time.time() + print(f"First run time (including autotuning): {end - start:.2f}s") + + # Second run will use the tuned configuration + start = time.time() + result = matmul_autotune(x, y) + end = time.time() + print(f"Second run time (using tuned config): {end - start:.2f}s") + + # Verify correctness + expected = x @ y + print(f"Result is correct: {torch.allclose(result, expected, rtol=1e-2, atol=1e-2)}") + +Hardcoding Configurations +------------------------- + +After autotuning, you might want to hardcode the best configuration: + +.. code-block:: python + + # Example of hardcoding a configuration after autotuning + @helion.kernel(config=helion.Config( + block_sizes=[[64, 128], [16]], + loop_orders=[[1, 0]], + num_warps=4, + num_stages=3, + indexing='block_ptr', + l2_grouping=32 + )) + def matmul_fixed_config(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + k, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = acc + torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc + + return out + + # Run with fixed configuration (no autotuning) + start = time.time() + result = matmul_fixed_config(x, y) + end = time.time() + print(f"Run time with fixed config: {end - start:.2f}s") + + # Verify correctness + expected = x @ y + print(f"Result is correct: {torch.allclose(result, expected, rtol=1e-2, atol=1e-2)}") + +Conclusion +---------- + +In this notebook, we've explored how to use Helion to write efficient GPU kernels using a high-level, PyTorch-like syntax. The key advantages of Helion include: + +1. **Higher-level abstraction** than raw Triton, making it easier to write correct kernels +2. **Automatic tiling and memory management**, eliminating a common source of bugs +3. **Powerful autotuning** that can explore a wide range of implementations automatically +4. **Familiar PyTorch syntax** that builds on existing knowledge + +These puzzles should give you a good foundation for writing your own Helion kernels for a variety of applications. diff --git a/docs/index.md b/docs/index.md index 21879d58..039149f8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,7 +1,21 @@ # Helion Documentation -> ⚠️ **Early Development Warning** -> Helion is currently in an experimental stage. You should expect bugs, incomplete features, and APIs that may change in future versions. Feedback and bug reports are welcome and appreciated! + + +```{toctree} +:maxdepth: 2 +:caption: Contents +:hidden: + +installation +./examples/index +helion_puzzles +api/index + +``` + +⚠️ **Early Development Warning** +Helion is currently in an experimental stage. You should expect bugs, incomplete features, and APIs that may change in future versions. Feedback and bug reports are welcome and appreciated! **Helion** is a Python-embedded domain-specific language (DSL) for authoring machine learning kernels, designed to compile down to [Triton], @@ -234,14 +248,3 @@ variable will be ignored. Enable logging by setting the environment variable `HELION_LOGS=all` for INFO-level logs, or `HELION_LOGS=+all` for DEBUG-level logs. Alternatively, you can specify logging for specific modules using a comma-separated list (e.g., `HELION_LOGS=+helion.runtime.kernel`). - - -## Table of Contents - -```{toctree} -:maxdepth: 1 -:caption: Contents: - -installation -api/index -``` diff --git a/docs/installation.md b/docs/installation.md index 41b52954..112b9353 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -156,4 +156,4 @@ Matches the requirements of [Triton](https://github.com/triton-lang/triton). At Once installation is complete: 1. **Check out the {doc}`api/index` for complete API documentation** -2. **Explore the [examples/](https://github.com/pytorch-labs/helion/tree/main/examples) folder for real-world patterns** +2. **Explore the [examples](examples/) and [Helion Puzzles](helion_puzzles) pages for real-world patterns** diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..e73fae30 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,5 @@ +sphinx>=7.0.0 +myst-parser>=2.0.0 +sphinx-autodoc-typehints>=1.24.0 +sphinx-rtd-theme>=1.3.0 +sphinx_gallery.gen_gallery diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 00000000..e50eba55 --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,80 @@ +Helion Examples +============== + +This directory contains examples demonstrating how to use Helion for high-performance tensor operations. +The examples are organized into the following categories: + +Basic Operations +~~~~~~~~~~~~~~~ + +- ``add.py``: Element-wise addition with broadcasting support +- ``exp.py``: Element-wise exponential function +- ``sum.py``: Sum reduction along the last dimension +- ``long_sum.py``: Efficient sum reduction along a long dimension +- ``softmax.py``: Different implementations of the softmax function + +Matrix Multiplication Operations +~~~~~~~~~~~~~~~~ + +- ``matmul.py``: Basic matrix multiplication +- ``bmm.py``: Batch matrix multiplication +- ``matmul_split_k.py``: Matrix multiplication using split-K algorithm for better parallelism +- ``matmul_layernorm.py``: Fused matrix multiplication and layer normalization +- ``fp8_gemm.py``: Matrix multiplication using FP8 precision + +Attention Operations +~~~~~~~~~~~~~~~~~~~ + +- ``attention.py``: Scaled dot-product attention mechanism +- ``fp8_attention.py``: Attention mechanism using FP8 precision + +Normalization +~~~~~~~~~~~~ + +- ``rms_norm.py``: Root Mean Square (RMS) normalization + +Sparse and Jagged Tensors +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``jagged_dense_add.py``: Addition between a jagged tensor and a dense tensor +- ``jagged_mean.py``: Computing the mean of each row in a jagged tensor +- ``segment_reduction.py``: Segmented reduction operation +- ``moe_matmul_ogs.py``: Mixture-of-Experts matrix multiplication using Outer-Gather-Scatter + +Other Operations +~~~~~~~~~~~~~~~ + +- ``concatenate.py``: Tensor concatenation along a dimension +- ``cross_entropy.py``: Cross entropy loss function +- ``embedding.py``: Embedding lookup operation +- ``all_gather_matmul.py``: All-gather operation followed by matrix multiplication +- ``template_via_closure.py``: Templated matrix multiplication with customizable epilogue function + + +.. toctree:: + :maxdepth: 2 + :caption: Contents + :hidden: + + add + all_gather_matmul + attention + bmm + concatenate + cross_entropy + embedding + exp + fp8_attention + fp8_gemm + jagged_dense_add + jagged_mean + long_sum + matmul + matmul_layernorm + matmul_split_k + moe_matmul_ogs + rms_norm + segment_reduction + softmax + sum + template_via_closure diff --git a/examples/add.py b/examples/add.py index c940a626..897224df 100644 --- a/examples/add.py +++ b/examples/add.py @@ -1,3 +1,13 @@ +""" +Element-wise Addition Example +=========================== + +This example demonstrates how to implement an element-wise addition kernel using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,8 +17,21 @@ import helion.language as hl +# %% +# Addition Kernel +# -------------- @helion.kernel() def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Add two tensors element-wise with broadcasting support. + + Args: + x: First input tensor + y: Second input tensor + + Returns: + A new tensor containing the element-wise sum of x and y + """ # match pytorch broadcasting rules x, y = torch.broadcast_tensors(x, y) out = torch.empty( @@ -23,13 +46,29 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: + """ + Verify the add kernel implementation against PyTorch's native add function. + + Args: + m: First dimension of the test tensors + n: Second dimension of the test tensors + """ x = torch.randn([m, n], device="cuda", dtype=torch.float16) y = torch.randn([m, n], device="cuda", dtype=torch.float16) run_example(add, torch.add, (x, y)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the add kernel verification with 1024x1024 tensors. + """ check(1024, 1024) diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py index e93f28bd..8e3deba0 100644 --- a/examples/all_gather_matmul.py +++ b/examples/all_gather_matmul.py @@ -1,3 +1,14 @@ +""" +All-Gather Matrix Multiplication Example +===============================>>>>>>> REPLACE + +This example demonstrates how to implement an all-gather operation followed by matrix multiplication +using Helion and PyTorch's distributed capabilities. +""" + +# %% +# Imports +# ------- from __future__ import annotations import os @@ -17,6 +28,19 @@ def copy_engine_all_gather_w_progress( splits_per_rank: int, backend_stream: torch.cuda.Stream | None = None, ) -> torch.cuda.Stream: + """ + Performs an all-gather operation with progress tracking using symmetric memory. + + Args: + output: The output tensor to store the gathered results + inp: The input tensor to be gathered (must be a symmetric tensor) + progress: Tensor used to track progress of the operation + splits_per_rank: Number of splits per rank + backend_stream: CUDA stream for backend operations (optional) + + Returns: + The CUDA stream used for the operation + """ backend_stream = symm_mem._get_backend_stream(priority=-1) assert inp.is_contiguous() symm_mem_group = dist.group.WORLD @@ -78,6 +102,20 @@ def helion_matmul_w_progress( SPLITS_PER_RANK: int, RANK: int, ) -> torch.Tensor: + """ + Performs matrix multiplication with progress tracking. + + Args: + a: First input tensor for matrix multiplication + a_shared: Shared tensor across ranks + b: Second input tensor for matrix multiplication + progress: Tensor used to track progress of the operation + SPLITS_PER_RANK: Number of splits per rank + RANK: Current process rank + + Returns: + The result of the matrix multiplication + """ M, K = a.size() K2, N = b.size() assert K2 == K, f"size mismatch {K2} != {K}" @@ -119,6 +157,21 @@ def helion_all_gather_matmul( progress: torch.Tensor | None = None, **kwargs: int, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Combines all-gather and matrix multiplication operations. + + Args: + a_shared: Shared tensor across ranks to be gathered + b: Second input tensor for matrix multiplication + a_out: Optional output tensor for the gathered results + progress: Optional tensor used to track progress of the operation + **kwargs: Additional keyword arguments including splits_per_rank + + Returns: + A tuple containing: + - The gathered tensor + - The result of the matrix multiplication + """ configs = { "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1), } @@ -169,6 +222,16 @@ def helion_all_gather_matmul( def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None: + """ + Tests the helion_all_gather_matmul function against PyTorch's implementation. + + Args: + M: First dimension of the matrix + N: Second dimension of the matrix + K: Third dimension of the matrix + world_size: Number of processes + device: Device to run the test on + """ a_shared = symm_mem.empty( M // world_size, K, dtype=torch.bfloat16, device=device ).normal_() @@ -188,6 +251,10 @@ def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None: def main() -> None: + """ + Main entry point that initializes the distributed environment and runs the test. + Sets up the distributed process group, runs the test, and then cleans up. + """ rank = int(os.environ["LOCAL_RANK"]) world_size = int(os.environ["WORLD_SIZE"]) torch.manual_seed(42 + rank) diff --git a/examples/attention.py b/examples/attention.py index 3d0b0149..1afbdb79 100644 --- a/examples/attention.py +++ b/examples/attention.py @@ -1,3 +1,13 @@ +""" +Attention Mechanism Example +======================== + +This example demonstrates how to implement a scaled dot-product attention mechanism using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import math @@ -12,6 +22,9 @@ import helion.language as hl +# %% +# Attention Kernel Implementation +# ---------------------------- @helion.kernel( # Static shapes provides a speedup for attention static_shapes=True, @@ -21,6 +34,19 @@ def attention( k_in: torch.Tensor, v_in: torch.Tensor, ) -> torch.Tensor: + """ + Computes scaled dot-product attention. + + Implements the attention mechanism: Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V + + Args: + q_in: Query tensor of shape [..., seq_len_q, head_dim] + k_in: Key tensor of shape [..., seq_len_k, head_dim] + v_in: Value tensor of shape [..., seq_len_k, head_dim] + + Returns: + Output tensor of shape [..., seq_len_q, head_dim] + """ m_dim = q_in.size(-2) n_dim = k_in.size(-2) assert n_dim == v_in.size(-2) @@ -57,13 +83,23 @@ def attention( return out.view(q_in.size()) +# %% +# Dynamic Shape Version +# ------------------ attention_dynamic: object = helion.kernel( # pyright: ignore[reportCallIssue] attention.fn, configs=attention.configs, # pyright: ignore[reportArgumentType] static_shapes=False, ) +""" +Dynamic shape version of the attention kernel. +This version allows for variable input shapes at runtime. +""" +# %% +# Testing Function +# ------------- def test( z: int, h: int, @@ -72,6 +108,17 @@ def test( dtype: torch.dtype = torch.float32, device: torch.device | str = "cuda", ) -> None: + """ + Test the attention kernel implementation against PyTorch's native attention functions. + + Args: + z: Batch size + h: Number of attention heads + n_ctx: Sequence length (context size) + head_dim: Dimension of each attention head + dtype: Data type for the tensors + device: Device to run the test on + """ q, k, v = [ torch.randn((z, h, n_ctx, head_dim), dtype=dtype, device=device) for _ in range(3) @@ -97,7 +144,14 @@ def ref_attention( run_example(attention, baselines, (q, k, v)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the attention kernel test with specific parameters. + Tests with batch size 2, 32 heads, 1024 sequence length, and 64-dimensional heads using float16. + """ test(2, 32, 1024, 64, torch.float16) diff --git a/examples/bmm.py b/examples/bmm.py index bdae21b3..c5007345 100644 --- a/examples/bmm.py +++ b/examples/bmm.py @@ -1,3 +1,13 @@ +""" +Batch Matrix Multiplication Example +=============================== + +This example demonstrates how to implement a batch matrix multiplication kernel using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,9 +17,22 @@ import helion.language as hl +# %% +# Batch Matrix Multiplication Kernel +# ------------------------------- # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """ + Performs batch matrix multiplication. + + Args: + A: Input tensor of shape [B, M, K] + B: Input tensor of shape [B, K, N] + + Returns: + Output tensor of shape [B, M, N] containing the result of batch matrix multiplication + """ # A: [B, M, K], B: [B, K, N], Out: [B, M, N] # dense bmm b, m, k = A.size() b, k, n = B.size() @@ -26,13 +49,33 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(b: int, m: int, k: int, n: int) -> None: + """ + Verify the bmm kernel implementation against PyTorch's native bmm function. + + Args: + b: Batch size + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ x = torch.randn([b, m, k], device="cuda", dtype=torch.float16) y = torch.randn([b, k, n], device="cuda", dtype=torch.float16) run_example(bmm, torch.bmm, (x, y)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the bmm kernel verification with specific parameters. + Tests with batch size 16, and matrices of dimensions 512x768 and 768x1024. + Ensures torch version is at least 2.8 for 16-bit tensor support in baddbmm. + """ # torch.baddbmm support for 16-bit tensors requires torch 2.8+ assert torch.__version__.split(".")[:2] >= ["2", "8"], "Requires torch 2.8+" check(16, 512, 768, 1024) diff --git a/examples/concatenate.py b/examples/concatenate.py index cb72cf72..34035249 100644 --- a/examples/concatenate.py +++ b/examples/concatenate.py @@ -1,3 +1,13 @@ +""" +Tensor Concatenation Example +======================== + +This example demonstrates how to implement a tensor concatenation operation using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,8 +17,21 @@ import helion.language as hl +# %% +# Concatenation Kernel +# ----------------- @helion.kernel() def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Concatenates two 2D tensors along dimension 1 (columns). + + Args: + x: First input tensor of shape [M, N1] + y: Second input tensor of shape [M, N2] with same first dimension as x + + Returns: + Output tensor of shape [M, N1+N2] containing the concatenation of x and y along dimension 1 + """ assert x.size(0) == y.size(0) out = torch.empty( [x.size(0), x.size(1) + y.size(1)], dtype=x.dtype, device=x.device @@ -29,7 +52,14 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the concatenation kernel verification. + Tests with two tensors of shapes [1500, 400] and [1500, 600]. + """ x = torch.randn([1500, 400], device="cuda") y = torch.randn([1500, 600], device="cuda") run_example(concat2d_dim1, lambda x, y: torch.cat([x, y], dim=1), (x, y)) diff --git a/examples/cross_entropy.py b/examples/cross_entropy.py index 28f36cd1..b91acb92 100644 --- a/examples/cross_entropy.py +++ b/examples/cross_entropy.py @@ -1,3 +1,13 @@ +""" +Cross Entropy Loss Example +====================== + +This example demonstrates how to implement a cross entropy loss function using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import os @@ -8,17 +18,37 @@ from helion._testing import run_example import helion.language as hl +# %% +# Configuration +# ----------- # TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": # Low memory configuration TRITONBENCH_ARGS = {"B": 4, "T": 512, "v_range": "10,15"} +# %% +# Cross Entropy Kernel +# ----------------- @helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper]) def cross_entropy( logits: torch.Tensor, # [N, V] input logits labels: torch.Tensor, # [N] target labels ) -> torch.Tensor: + """ + Computes the cross entropy loss between logits and target labels. + + Implements the cross entropy loss function commonly used in classification tasks. + The function computes the log softmax of the logits and then calculates the negative + log likelihood of the true labels. + + Args: + logits: Input logits tensor of shape [N, V] where N is batch size and V is vocabulary size + labels: Target labels tensor of shape [N] containing class indices + + Returns: + A scalar tensor containing the mean cross entropy loss + """ n, v = logits.shape losses = torch.zeros([n], dtype=logits.dtype, device=logits.device) @@ -53,8 +83,14 @@ def cross_entropy( return losses.mean() +# %% +# Main Function +# ----------- def main() -> None: - """Run cross entropy benchmark with different input sizes.""" + """ + Main entry point that runs the cross entropy kernel verification. + Tests with a batch size of 128 and vocabulary size of 1000. + """ # Test with moderate size n, v = 128, 1000 logits = torch.randn(n, v, device="cuda", dtype=torch.float32) diff --git a/examples/embedding.py b/examples/embedding.py index e9e99f84..66c8261e 100644 --- a/examples/embedding.py +++ b/examples/embedding.py @@ -1,3 +1,13 @@ +""" +Embedding Lookup Example +==================== + +This example demonstrates how to implement an embedding lookup operation using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,8 +17,23 @@ import helion.language as hl +# %% +# Embedding Kernel +# ------------- @helion.kernel() def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Performs embedding lookup for input indices. + + Maps indices in the input tensor to vectors from the embedding weight matrix. + + Args: + x: Input tensor of indices of any shape + weight: Embedding weight matrix of shape [num_embeddings, embedding_dim] + + Returns: + Output tensor of shape [*x.shape, embedding_dim] containing the embedding vectors + """ x_flat = x.reshape(-1) # collapse x into a single dimension _, embedding_dim = weight.size() out = torch.empty( @@ -20,14 +45,35 @@ def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: return out.view(*x.size(), embedding_dim) +# %% +# Benchmark Wrapper +# -------------- def embedding_tritonbench( V: int, D: int, inp: torch.Tensor, shared_weight: torch.Tensor ) -> torch.Tensor: - """Wrapper for tritonbench that matches its interface.""" + """ + Wrapper for tritonbench that matches its interface. + + Args: + V: Vocabulary size (unused, provided for compatibility) + D: Embedding dimension (unused, provided for compatibility) + inp: Input tensor of indices + shared_weight: Embedding weight matrix + + Returns: + Output tensor containing the embedding vectors + """ return embedding(inp, shared_weight) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the embedding kernel verification. + Tests with a batch of indices and an embedding table of size 16x64. + """ num_embeddings, embedding_dim = 16, 64 x = torch.randint(0, num_embeddings, [256, 32], device="cuda", dtype=torch.int32) weight = torch.randn([num_embeddings, embedding_dim], device="cuda") diff --git a/examples/exp.py b/examples/exp.py index 357f4862..305a0562 100644 --- a/examples/exp.py +++ b/examples/exp.py @@ -1,3 +1,13 @@ +""" +Exponential Function Example +======================== + +This example demonstrates how to implement an element-wise exponential function using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,25 +17,63 @@ import helion.language as hl +# %% +# Exponential Kernel +# --------------- @helion.kernel() def exp(x: torch.Tensor) -> torch.Tensor: + """ + Computes the exponential of all elements in the input tensor. + + Args: + x: Input tensor + + Returns: + Output tensor with the exponential of each element in the input + """ out = torch.empty_like(x) for tile in hl.tile(x.size()): out[tile] = torch.exp(x[tile]) return out +# %% +# Benchmark Wrapper +# -------------- def exp_tritonbench(x: torch.Tensor) -> dict[str, torch.Tensor]: - """Wrapper for tritonbench that returns output in expected format.""" + """ + Wrapper for tritonbench that returns output in expected format. + + Args: + x: Input tensor + + Returns: + Dictionary containing the output tensor + """ return {"output": exp(x)} +# %% +# Verification Function +# ------------------- def check(n: int) -> None: + """ + Verify the exp kernel implementation against PyTorch's native exp function. + + Args: + n: Size of the test tensor + """ x = torch.randn(n, device="cuda", dtype=torch.float32) run_example(exp, torch.exp, (x,)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the exp kernel verification with a tensor of size 1M elements. + """ check(1024 * 1024) diff --git a/examples/fp8_attention.py b/examples/fp8_attention.py index f9c5153b..69a44d0e 100644 --- a/examples/fp8_attention.py +++ b/examples/fp8_attention.py @@ -1,3 +1,13 @@ +""" +FP8 Attention Mechanism Example +====================>>>>>>> REPLACE + +This example demonstrates how to implement a scaled dot-product attention mechanism using FP8 precision in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import math @@ -17,6 +27,21 @@ def fp8_attention_kernel( batch: int, heads: int, ) -> torch.Tensor: + """ + Computes scaled dot-product attention using FP8 precision. + + Implements the attention mechanism with FP8 tensors for improved performance and memory efficiency. + + Args: + q: Query tensor of shape [batch*heads, seq, dim] in FP8 format + k: Key tensor of shape [batch*heads, seq, dim] in FP8 format + v: Value tensor of shape [batch*heads, dim, seq] (pre-transposed) in FP8 format + batch: Number of batches + heads: Number of attention heads + + Returns: + Output tensor of shape [batch, heads, seq_len, head_dim] in FP8 format + """ batch_heads = q.size(0) seq_len = q.size(1) head_dim = q.size(2) @@ -108,6 +133,20 @@ def fp8_attention_kernel( def preprocess_fp8_attention_inputs( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preprocesses attention inputs by converting them to FP8 format and reshaping. + + Args: + q: Query tensor of shape [batch, heads, seq_len, head_dim] + k: Key tensor of shape [batch, heads, seq_len, head_dim] + v: Value tensor of shape [batch, heads, seq_len, head_dim] + + Returns: + Tuple of (q_fp8, k_fp8, v_fp8) where: + - q_fp8: Query tensor in FP8 format with shape [batch*heads, seq_len, head_dim] + - k_fp8: Key tensor in FP8 format with shape [batch*heads, seq_len, head_dim] + - v_fp8: Value tensor in FP8 format with shape [batch*heads, head_dim, seq_len] (pre-transposed) + """ q_fp8 = q.to(torch.float8_e5m2) k_fp8 = k.to(torch.float8_e5m2) v = v.permute(0, 1, 3, 2) @@ -122,6 +161,19 @@ def preprocess_fp8_attention_inputs( def fp8_attention_tritonbench( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> Callable[[], torch.Tensor]: + """ + Creates a callable function for benchmarking FP8 attention with tritonbench. + + Preprocesses inputs and returns a lambda function that calls the FP8 attention kernel. + + Args: + q: Query tensor of shape [batch, heads, seq_len, head_dim] + k: Key tensor of shape [batch, heads, seq_len, head_dim] + v: Value tensor of shape [batch, heads, seq_len, head_dim] + + Returns: + A callable function that executes the FP8 attention kernel + """ batch, heads, seq_len, head_dim = q.shape q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v) # Return lambda that calls the kernel - preprocessing is done outside. @@ -138,6 +190,21 @@ def _fp8_attention_pytorch_impl( seq_len: int, head_dim: int, ) -> torch.Tensor: + """ + PyTorch implementation of FP8 attention for comparison with the kernel version. + + Args: + q_fp8: Query tensor in FP8 format with shape [batch*heads, seq_len, head_dim] + k_fp8: Key tensor in FP8 format with shape [batch*heads, seq_len, head_dim] + v_fp8: Value tensor in FP8 format with shape [batch*heads, head_dim, seq_len] (pre-transposed) + batch: Number of batches + heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + + Returns: + Output tensor of shape [batch, heads, seq_len, head_dim] in FP8 format + """ sm_scale = 1.0 / math.sqrt(float(head_dim)) outputs = [] @@ -204,6 +271,15 @@ def fp8_attention_pytorch( def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None: + """ + Verifies the FP8 attention kernel implementation against the PyTorch reference implementation. + + Args: + batch: Number of batches + heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + """ torch.manual_seed(42) q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") @@ -223,6 +299,10 @@ def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None: def main() -> None: + """ + Main entry point that runs the FP8 attention kernel verification with different configurations. + Tests with small, medium, and large attention configurations. + """ check(1, 2, 128, 64) check(2, 4, 256, 64) check(4, 8, 512, 128) diff --git a/examples/fp8_gemm.py b/examples/fp8_gemm.py index 81cc6815..7b8ce9e3 100644 --- a/examples/fp8_gemm.py +++ b/examples/fp8_gemm.py @@ -1,3 +1,13 @@ +""" +FP8 Matrix Multiplication Example +============================ + +This example demonstrates how to implement a matrix multiplication kernel using FP8 precision in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,9 +17,13 @@ import helion.language as hl +# %% +# FP8 GEMM Kernel +# ------------ @helion.kernel(static_shapes=True) def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """FP8 General Matrix Multiplication (GEMM). + """ + FP8 General Matrix Multiplication (GEMM). This kernel demonstrates FP8 computation in Helion. When lowered to Triton, the tl.dot operation will handle @@ -47,10 +61,22 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Reference Implementation +# -------------------- def reference_fp8_gemm_pytorch( x_fp8: torch.Tensor, y_fp8: torch.Tensor ) -> torch.Tensor: - """Reference implementation using torch._scaled_mm.""" + """ + Reference implementation using torch._scaled_mm. + + Args: + x_fp8: Input tensor in FP8 format + y_fp8: Input tensor in FP8 format + + Returns: + Output tensor in FP16 format + """ # torch._scaled_mm requires column-major for second operand y_fp8_t = y_fp8.T.contiguous().T scale_a = torch.tensor(1.0, device=x_fp8.device) @@ -60,13 +86,35 @@ def reference_fp8_gemm_pytorch( ) +# %% +# Benchmark Wrapper +# -------------- def fp8_gemm_tritonbench(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """Wrapper for TritonBench compatibility.""" + """ + Wrapper for TritonBench compatibility. + + Args: + a: First input tensor in FP8 format + b: Second input tensor in FP8 format + + Returns: + Output tensor from the fp8_gemm kernel + """ return fp8_gemm(a, b) +# %% +# Verification Function +# ------------------- def check(m: int, k: int, n: int) -> None: - """Test the FP8 GEMM implementation.""" + """ + Test the FP8 GEMM implementation against the PyTorch reference implementation. + + Args: + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ # Create FP8 tensors x = torch.randn([m, k], device="cuda", dtype=torch.float32) y = torch.randn([k, n], device="cuda", dtype=torch.float32) @@ -78,7 +126,14 @@ def check(m: int, k: int, n: int) -> None: run_example(fp8_gemm, reference_fp8_gemm_pytorch, (x_fp8, y_fp8)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the FP8 GEMM kernel verification with different matrix sizes. + Tests with small (256x256), medium (512x512), and large (1024x1024) matrices. + """ # Test with different sizes check(256, 256, 256) check(512, 512, 512) diff --git a/examples/jagged_dense_add.py b/examples/jagged_dense_add.py index d5fb91e8..201e5556 100644 --- a/examples/jagged_dense_add.py +++ b/examples/jagged_dense_add.py @@ -1,3 +1,14 @@ +""" +Jagged Dense Addition Example +========================= + +This example demonstrates how to implement an addition operation between a jagged tensor +and a dense tensor using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -6,6 +17,9 @@ from helion._testing import run_example import helion.language as hl +# %% +# Jagged Tensor Format +# ----------------- """ A tensor x is stored in a jagged-row, prefix-sparse layout that packs only the non-zero elements of each row. All non-zeros are concatenated into a one-dimensional buffer @@ -14,12 +28,12 @@ contains exactly the first K_i non-zero entries of that row (with K_i = x_offsets[i+1] − x_offsets[i]). Elements beyond column K_i − 1 are implicitly zero and therefore omitted from storage. - -This example implements a kernel that adds a dense matrix y to a -jagged matrix x. It is intended to illustrate how to work with jagged tensors. """ +# %% +# Jagged Dense Addition Kernel +# ------------------------ @helion.kernel() def jagged_dense_add_2d( x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor @@ -28,16 +42,14 @@ def jagged_dense_add_2d( Add a jagged-prefix sparse tensor (x_data, x_offsets) to a dense matrix y and return the dense result. - Args - ---- - x_data : 1-D tensor holding all non-zero elements row-by-row. - x_offsets : (num_rows + 1) tensor. Row i is the slice - x_data[x_offsets[i] : x_offsets[i+1]] (length K_i). - y: (num_rows, N) tensor, N >= max(K_i). + Args: + x_data: 1-D tensor holding all non-zero elements row-by-row + x_offsets: (num_rows + 1) tensor. Row i is the slice + x_data[x_offsets[i] : x_offsets[i+1]] (length K_i) + y: (num_rows, N) tensor, N >= max(K_i) - Returns - ------- - result : dense + jagged, shape (num_rows, N). + Returns: + Dense tensor of shape (num_rows, N) containing the sum of the jagged and dense tensors """ num_rows = y.size(0) assert x_offsets.size(0) == num_rows + 1 @@ -63,12 +75,25 @@ def jagged_dense_add_2d( return out +# %% +# Reference Implementation +# -------------------- def jagged_dense_add_2d_reference( x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: - """The same as the above, but implemented in pure PyTorch.""" + """ + Reference implementation of jagged dense addition in pure PyTorch. + + Args: + x_data: 1-D tensor holding all non-zero elements row-by-row + x_offsets: (num_rows + 1) tensor with offsets for each row + y: Dense tensor to add to the jagged tensor + + Returns: + Dense tensor containing the sum of the jagged and dense tensors + """ num_rows = x_offsets.numel() - 1 assert y.shape[0] == num_rows out = y.clone() @@ -79,6 +104,9 @@ def jagged_dense_add_2d_reference( return out +# %% +# Utility Function +# ------------- def random_jagged_2d( num_rows: int, max_cols: int, @@ -87,10 +115,18 @@ def random_jagged_2d( device: torch.device | str = "cuda", ) -> tuple[torch.Tensor, torch.Tensor]: """ - Produces: - x_data – 1-D tensor holding all non-zeros row-by-row - x_offsets – (num_rows+1) tensor; x_data[x_offsets[i]:x_offsets[i+1]] is row i - Each row i has a random non-zero prefix length K_i in [1, max_cols]. + Generate random jagged 2D tensor data. + + Args: + num_rows: Number of rows in the jagged tensor + max_cols: Maximum number of columns per row + dtype: Data type for the tensor values + device: Device to create the tensors on + + Returns: + Tuple of (x_data, x_offsets) where: + - x_data: 1-D tensor holding all non-zeros row-by-row + - x_offsets: (num_rows+1) tensor with offsets for each row """ # random positive K_i for each row lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device) @@ -105,7 +141,16 @@ def random_jagged_2d( return x_data, x_offsets +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the jagged dense add kernel verification. + + Creates random jagged 2D data and a dense tensor, then compares the kernel + implementation against the PyTorch reference implementation. + """ rows, cols = 256, 5000 x_data, x_offsets = random_jagged_2d(rows, cols, device="cuda") y = torch.randn([rows, cols], device="cuda") diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py index cbc6e99d..5494e7d0 100644 --- a/examples/jagged_mean.py +++ b/examples/jagged_mean.py @@ -1,3 +1,14 @@ +""" +Jagged Mean Example +=============== + +This example demonstrates how to compute the mean of each row in a jagged tensor +with variable features per row using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import os @@ -8,12 +19,18 @@ from helion._testing import run_example import helion.language as hl +# %% +# Configuration +# ----------- # TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": # Low memory configuration TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64} +# %% +# Jagged Mean Kernel +# --------------- @helion.kernel() def jagged_mean_kernel( x_data: torch.Tensor, @@ -24,18 +41,16 @@ def jagged_mean_kernel( """ Compute the mean of each row in a jagged tensor with variable features per row. - Args - ---- - x_data : 2-D tensor of shape (total_elements, max_M) holding all elements. - x_offsets : (num_rows + 1) tensor. Row i is the slice - x_data[x_offsets[i] : x_offsets[i+1], :]. - x_feature_counts: (num_rows) tensor. Number of valid features for each row. - max_M_tensor : Dummy tensor whose numel() gives max number of features. - - Returns - ------- - result : 2-D tensor of shape (num_rows, max_M) containing the mean of each row. - Invalid features (beyond x_feature_counts[i]) are set to 0. + Args: + x_data: 2-D tensor of shape (total_elements, max_M) holding all elements + x_offsets: (num_rows + 1) tensor. Row i is the slice + x_data[x_offsets[i] : x_offsets[i+1], :] + x_feature_counts: (num_rows) tensor. Number of valid features for each row + max_M_tensor: Dummy tensor whose numel() gives max number of features + + Returns: + 2-D tensor of shape (num_rows, max_M) containing the mean of each row. + Invalid features (beyond x_feature_counts[i]) are set to 0. """ num_rows = x_offsets.size(0) - 1 max_M = max_M_tensor.numel() # Extract max features from dummy tensor @@ -96,13 +111,27 @@ def jagged_mean_kernel( return out +# %% +# Reference Implementation +# -------------------- def reference_jagged_mean_kernel_pytorch( x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M: int, ) -> torch.Tensor: - """PyTorch reference implementation for jagged mean with variable features.""" + """ + PyTorch reference implementation for jagged mean with variable features. + + Args: + x_data: 2-D tensor holding all elements + x_offsets: Offsets tensor for row indexing + x_feature_counts: Number of valid features per row + max_M: Maximum number of features + + Returns: + Tensor containing the mean of each row + """ num_rows = x_offsets.numel() - 1 out = torch.zeros((num_rows, max_M), dtype=x_data.dtype, device=x_data.device) for i in range(num_rows): @@ -114,6 +143,9 @@ def reference_jagged_mean_kernel_pytorch( return out +# %% +# Benchmark Wrapper +# -------------- def jagged_mean_tritonbench( x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float ) -> torch.Tensor: @@ -144,7 +176,16 @@ def jagged_mean_tritonbench( return jagged_mean_kernel(x_values, x_offsets, feature_counts, max_M_tensor) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the jagged mean kernel verification. + + Creates test data with random jagged tensors and feature counts, then compares + the kernel implementation against the PyTorch reference implementation. + """ num_rows, max_cols = 32, 64 device = "cuda" diff --git a/examples/long_sum.py b/examples/long_sum.py index 543869da..19bd20b4 100644 --- a/examples/long_sum.py +++ b/examples/long_sum.py @@ -1,3 +1,13 @@ +""" +Long Dimension Sum Example +====================== + +This example demonstrates how to implement efficient sum reduction along a long dimension using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,11 +17,25 @@ import helion.language as hl +# %% +# Baseline Implementation +# ------------------- def baseline_sum(x: torch.Tensor) -> torch.Tensor: + """ + PyTorch baseline implementation of sum reduction along the last dimension. + + Args: + x: Input tensor + + Returns: + Tensor with sum of elements along the last dimension + """ return x.sum(-1) -# Naive Reduction: Load the entire reduction dim at once, and reduce in reg. +# %% +# Naive Reduction Kernel +# ------------------ @helion.kernel( config=helion.Config( block_sizes=[1], @@ -22,6 +46,17 @@ def baseline_sum(x: torch.Tensor) -> torch.Tensor: ) ) def longsum(x: torch.Tensor) -> torch.Tensor: + """ + Naive reduction kernel that sums elements along the last dimension. + + Loads the entire reduction dimension at once and reduces in registers. + + Args: + x: Input tensor of shape [m, n] + + Returns: + Output tensor of shape [m] containing the sum of each row + """ m, _ = x.size() out = torch.empty([m], dtype=x.dtype, device=x.device) @@ -30,7 +65,9 @@ def longsum(x: torch.Tensor) -> torch.Tensor: return out -# Looped reduction +# %% +# Looped Reduction Kernel +# ------------------- @helion.kernel( config=helion.Config( block_sizes=[1], @@ -43,6 +80,17 @@ def longsum(x: torch.Tensor) -> torch.Tensor: ) ) def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor: + """ + Looped reduction kernel that sums elements along the last dimension. + + Uses a reduction loop with a specified tile size to handle large dimensions efficiently. + + Args: + x: Input tensor of shape [m, n] + + Returns: + Output tensor of shape [m] containing the sum of each row + """ m, _ = x.size() out = torch.empty([m], dtype=x.dtype, device=x.device) @@ -51,13 +99,26 @@ def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor: return out -# This generates the same code as above, but manually implements looped reduction. +# %% +# Manual Looped Reduction Kernel +# -------------------------- @helion.kernel( config=helion.Config( block_sizes=[32768, 1], num_warps=16, num_stages=5, indexing="pointer" ) ) def longsum_manual(x: torch.Tensor) -> torch.Tensor: + """ + Manual implementation of looped reduction for summing elements along the last dimension. + + Manually implements the reduction loop with explicit accumulation and final reduction. + + Args: + x: Input tensor of shape [m, n] + + Returns: + Output tensor of shape [m] containing the sum of each row + """ m, n = x.size() out = torch.empty([m], dtype=x.dtype, device=x.device) @@ -72,7 +133,19 @@ def longsum_manual(x: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: + """ + Verify the sum kernel implementations against PyTorch's native sum function. + + Tests all three kernel variants (naive, looped, manual) against the baseline. + + Args: + m: First dimension of the test tensor + n: Second dimension of the test tensor (reduction dimension) + """ x = torch.randn([m, n], device="cuda", dtype=torch.float32) # Test all three kernel variants against the baseline @@ -85,7 +158,15 @@ def check(m: int, n: int) -> None: run_example(kernels, baseline_sum, (x,)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the sum kernel verification with a large tensor. + + Tests with a tensor of shape [4, 130000] to demonstrate handling of long reduction dimensions. + """ check(4, 130000) # seq_len = 128k diff --git a/examples/matmul.py b/examples/matmul.py index 1f6ad675..1441d756 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -1,3 +1,13 @@ +""" +Matrix Multiplication Example +============================ + +This example demonstrates how to implement a basic matrix multiplication kernel using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,9 +17,22 @@ import helion.language as hl +# %% +# Matrix Multiplication Kernel +# --------------------------- # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Performs matrix multiplication between two tensors. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + + Returns: + Output tensor of shape [M, N] containing the result of matrix multiplication + """ m, k = x.size() k2, n = y.size() assert k == k2, f"size mismatch {k} != {k2}" @@ -24,13 +47,30 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, k: int, n: int) -> None: + """ + Verify the matmul kernel implementation against PyTorch's native matmul function. + + Args: + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ x = torch.randn([m, k], device="cuda", dtype=torch.float16) y = torch.randn([k, n], device="cuda", dtype=torch.float16) run_example(matmul, torch.matmul, (x, y)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the matmul kernel verification with 1024x1024 matrices. + """ check(1024, 1024, 1024) diff --git a/examples/matmul_layernorm.py b/examples/matmul_layernorm.py index 4e5ecc35..59d45e52 100644 --- a/examples/matmul_layernorm.py +++ b/examples/matmul_layernorm.py @@ -1,3 +1,14 @@ +""" +Matrix Multiplication with Layer Normalization Example +============================================== + +This example demonstrates how to implement a fused matrix multiplication and layer normalization +operation using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -8,11 +19,26 @@ import helion.language as hl +# %% +# MatMul-LayerNorm Kernel +# -------------------- # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def matmul_layernorm( x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: + """ + Performs matrix multiplication followed by layer normalization. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + weight: Layer normalization weight parameter of shape [N] + bias: Layer normalization bias parameter of shape [N] + + Returns: + Output tensor of shape [M, N] containing the result of matrix multiplication followed by layer normalization + """ m, k = x.size() k2 = y.size(0) n = hl.register_reduction_dim(y.size(1)) @@ -35,9 +61,24 @@ def matmul_layernorm( return out +# %% +# Reference Implementation +# -------------------- def matmul_layernorm_pytorch( x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: + """ + PyTorch reference implementation of matrix multiplication followed by layer normalization. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + weight: Layer normalization weight parameter of shape [N] + bias: Layer normalization bias parameter of shape [N] + + Returns: + Output tensor of shape [M, N] containing the result of matrix multiplication followed by layer normalization + """ matmul_out = torch.matmul(x, y) ln_out = F.layer_norm( @@ -50,7 +91,18 @@ def matmul_layernorm_pytorch( return ln_out.to(torch.promote_types(x.dtype, y.dtype)) +# %% +# Verification Function +# ------------------- def check(m: int, k: int, n: int) -> None: + """ + Verify the matmul_layernorm kernel implementation against the PyTorch reference implementation. + + Args: + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ x = torch.randn([m, k], device="cuda", dtype=torch.float16) y = torch.randn([k, n], device="cuda", dtype=torch.float16) weight = torch.randn([n], device="cuda", dtype=torch.float16) @@ -58,7 +110,17 @@ def check(m: int, k: int, n: int) -> None: run_example(matmul_layernorm, matmul_layernorm_pytorch, (x, y, weight, bias)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the matmul_layernorm kernel verification with different matrix sizes. + + Tests with two configurations: + - 32x64 * 64x200 + - 128x256 * 256x400 + """ # TODO(yf225): n=64 or 128 throws error, need to investigate # check(32, 64, 64) # check(32, 64, 128) diff --git a/examples/matmul_split_k.py b/examples/matmul_split_k.py index 66f87449..efe7fb88 100644 --- a/examples/matmul_split_k.py +++ b/examples/matmul_split_k.py @@ -1,3 +1,14 @@ +""" +Split-K Matrix Multiplication Example +================================ + +This example demonstrates how to implement a matrix multiplication kernel using the split-K +algorithm for better parallelism in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -8,9 +19,25 @@ import helion.language as hl +# %% +# Split-K Matrix Multiplication Kernel +# -------------------------------- # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Performs matrix multiplication using split-K algorithm for better parallelism. + + Split-K divides the reduction dimension (K) into multiple chunks that can be processed + in parallel, with results atomically accumulated at the end. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + + Returns: + Output tensor of shape [M, N] containing the result of matrix multiplication + """ m, k = x.size() k2, n = y.size() assert k == k2, f"size mismatch {k} != {k2}" @@ -27,13 +54,33 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, k: int, n: int) -> None: + """ + Verify the split-K matmul kernel implementation against PyTorch's native matmul function. + + Args: + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ x = torch.randn([m, k], device="cuda", dtype=torch.float16) y = torch.randn([k, n], device="cuda", dtype=torch.float16) run_example(matmul_split_k, torch.matmul, (x, y), atol=1) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the split-K matmul kernel verification. + + Tests with matrices of shape 64x32768 and 32768x64, which benefits from the split-K approach + due to the large reduction dimension. + """ check(64, 32768, 64) diff --git a/examples/moe_matmul_ogs.py b/examples/moe_matmul_ogs.py index 66b9af24..a32232b4 100644 --- a/examples/moe_matmul_ogs.py +++ b/examples/moe_matmul_ogs.py @@ -1,7 +1,14 @@ """ -Mixture-of-Experts (MoE) matmul with Outer-Gather-Scatter (OGS) +Mixture-of-Experts Matrix Multiplication Example +========================================= + +This example demonstrates how to implement a Mixture-of-Experts (MoE) matrix multiplication +using the Outer-Gather-Scatter (OGS) approach in Helion. """ +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -11,6 +18,9 @@ import helion.language as hl +# %% +# MoE MatMul OGS Kernel +# ------------------ @helion.kernel(static_shapes=False) def moe_matmul_ogs( A: torch.Tensor, # [T, K] - Input activations (T tokens, K features) @@ -20,6 +30,23 @@ def moe_matmul_ogs( sorted_to_orig_token_idx: torch.Tensor, # [T] - Maps sorted token positions back to original positions max_T_per_expert_tensor: torch.Tensor, # [max_T_per_expert] - Dummy tensor whose size indicates max tokens per expert ) -> torch.Tensor: # [T, N] - Output activations + """ + Performs Mixture-of-Experts (MoE) matrix multiplication using the Outer-Gather-Scatter approach. + + This kernel efficiently handles sparse expert routing by grouping tokens by their assigned expert, + performing matrix multiplications for each expert, and scattering results back to the original token order. + + Args: + A: Input activations tensor of shape [T, K] (T tokens, K features) + W: Expert weights tensor of shape [E, K, N] (E experts, K input features, N output features) + expert_token_counts: Number of tokens assigned to each expert, shape [E] + expert_token_offsets: Starting position of each expert's tokens in sorted order, shape [E+1] + sorted_to_orig_token_idx: Maps sorted token positions back to original positions, shape [T] + max_T_per_expert_tensor: Dummy tensor whose size indicates max tokens per expert + + Returns: + Output activations tensor of shape [T, N] + """ # Extract dimensions from input tensors T, K = A.shape E, _, N = W.shape @@ -89,6 +116,9 @@ def moe_matmul_ogs( return C +# %% +# Helper Function for Kernel Arguments +# -------------------------------- def moe_matmul_ogs_helion_kernel_args_gen( A: torch.Tensor, # [T, K] - Input activations W: torch.Tensor, # [E, K, N] - Expert weights @@ -96,6 +126,19 @@ def moe_matmul_ogs_helion_kernel_args_gen( ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor ]: + """ + Generates the arguments needed for the MoE MatMul OGS kernel. + + Prepares the data structures needed for efficient token routing and processing. + + Args: + A: Input activations tensor of shape [T, K] + W: Expert weights tensor of shape [E, K, N] + top1_expert_per_token: Expert assignment for each token, shape [T] + + Returns: + Tuple of tensors needed for the MoE MatMul OGS kernel + """ E = W.size(0) # Number of experts device = A.device @@ -131,9 +174,23 @@ def moe_matmul_ogs_helion_kernel_args_gen( ) +# %% +# Reference Implementation +# -------------------- def moe_matmul_ogs_reference( A: torch.Tensor, W: torch.Tensor, top1_expert_per_token: torch.Tensor ) -> torch.Tensor: + """ + PyTorch reference implementation of MoE matrix multiplication. + + Args: + A: Input activations tensor of shape [T, K] + W: Expert weights tensor of shape [E, K, N] + top1_expert_per_token: Expert assignment for each token, shape [T] + + Returns: + Output activations tensor of shape [T, N] + """ T, K = A.shape N = W.size(2) device, dtype = A.device, torch.promote_types(A.dtype, W.dtype) @@ -150,7 +207,19 @@ def moe_matmul_ogs_reference( return C +# %% +# Verification Function +# ------------------- def check(T: int, K: int, N: int, n_experts: int) -> None: + """ + Verify the MoE matmul OGS kernel implementation against the reference implementation. + + Args: + T: Number of tokens + K: Input feature dimension + N: Output feature dimension + n_experts: Number of experts + """ dtype = torch.float16 device = "cuda" if torch.cuda.is_available() else "cpu" @@ -172,7 +241,15 @@ def reference_fn() -> torch.Tensor: run_example(helion_fn, reference_fn, ()) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the MoE matmul OGS kernel verification. + + Tests with 1000 tokens, 500 input features, 200 output features, and 30 experts. + """ check(1000, 500, 200, 30) diff --git a/examples/rms_norm.py b/examples/rms_norm.py index c1b46841..678d7f86 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -1,3 +1,14 @@ +""" +Root Mean Square Normalization Example +================================= + +This example demonstrates how to implement a Root Mean Square (RMS) normalization +operation using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -6,13 +17,33 @@ from helion._testing import run_example import helion.language as hl +# %% +# Configuration +# ----------- # TritonBench configuration # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg. TRITONBENCH_ARGS = {"num_inputs": 3} +# %% +# RMS Normalization Kernel +# --------------------- @helion.kernel(static_shapes=True) def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + """ + Performs Root Mean Square (RMS) normalization on the input tensor. + + RMS normalization normalizes by the root mean square of the elements: + output = x / sqrt(mean(x^2) + eps) * weight + + Args: + x: Input tensor of shape [M, N] + weight: Scale parameter of shape [N] + eps: Small constant for numerical stability + + Returns: + Output tensor of shape [M, N] with RMS normalization applied + """ m, n = x.size() assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}" @@ -33,15 +64,41 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch. return out +# %% +# Benchmark Wrapper +# -------------- def rms_norm_tritonbench(H: int, inp: torch.Tensor) -> torch.Tensor: - """Wrapper for tritonbench that matches expected interface.""" + """ + Wrapper for tritonbench that matches expected interface. + + Args: + H: Hidden dimension size + inp: Input tensor + + Returns: + Normalized tensor + """ weight = torch.ones(H, device=inp.device, dtype=inp.dtype) return rms_norm(inp, weight, eps=1e-6) +# %% +# Reference Implementation +# -------------------- def rms_norm_pytorch( x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5 ) -> torch.Tensor: + """ + PyTorch reference implementation of RMS normalization. + + Args: + x: Input tensor + weight: Scale parameter + eps: Small constant for numerical stability + + Returns: + Normalized tensor + """ input_dtype = x.dtype hidden_states = x.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) @@ -49,13 +106,34 @@ def rms_norm_pytorch( return weight * hidden_states.to(input_dtype) +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: + """ + Verify the RMS norm kernel implementation against the PyTorch reference implementation. + + Args: + m: First dimension of the test tensor + n: Second dimension of the test tensor + """ x = torch.randn([m, n], device="cuda", dtype=torch.float16) weight = torch.randn([n], device="cuda", dtype=torch.float16) run_example(rms_norm, rms_norm_pytorch, (x, weight, 1e-5)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the RMS norm kernel verification with different tensor sizes. + + Tests with three configurations: + - 32x64 + - 128x256 + - 1024x1024 + """ check(32, 64) check(128, 256) check(1024, 1024) diff --git a/examples/segment_reduction.py b/examples/segment_reduction.py index 32792de3..e8ce4cff 100644 --- a/examples/segment_reduction.py +++ b/examples/segment_reduction.py @@ -1,4 +1,14 @@ -# Code based on https://github.com/pytorch-labs/helion/issues/237 +""" +Segmented Reduction Example +======================= + +This example demonstrates how to implement a segmented reduction operation using Helion, +comparing it with Triton and PyTorch implementations. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -11,12 +21,29 @@ import helion.language as hl +# %% +# Helion Implementation +# ----------------- def combine_fn_helion( left_values: torch.Tensor, left_indices: torch.Tensor, right_values: torch.Tensor, right_indices: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Combine function for associative scan in Helion implementation. + + Adds values when indices match (same segment), otherwise takes the right value. + + Args: + left_values: Values from the left side of the scan + left_indices: Indices from the left side of the scan + right_values: Values from the right side of the scan + right_indices: Indices from the right side of the scan + + Returns: + Tuple of (combined_values, right_indices) + """ combined_values = torch.where( left_indices == right_indices, left_values + right_values, right_values ) @@ -27,6 +54,19 @@ def combine_fn_helion( def segmented_reduction_helion( indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int ) -> torch.Tensor: + """ + Performs segmented reduction using Helion. + + Reduces input data by summing values with the same index. + + Args: + indices: Tensor of segment indices for each element + input_data: Input tensor of shape [num_elements, num_features] + num_nodes: Number of output nodes/segments + + Returns: + Output tensor of shape [num_nodes, num_features] with reduced values + """ num_elements, num_features = input_data.shape output = torch.zeros( (num_nodes, num_features), dtype=input_data.dtype, device=input_data.device @@ -47,6 +87,9 @@ def segmented_reduction_helion( return output +# %% +# Triton Implementation +# ----------------- @triton.jit def combine_fn_triton( left_values: tl.tensor, @@ -54,6 +97,20 @@ def combine_fn_triton( right_values: tl.tensor, right_indices: tl.tensor, ) -> tuple[tl.tensor, tl.tensor]: + """ + Combine function for associative scan in Triton implementation. + + Adds values when indices match (same segment), otherwise takes the right value. + + Args: + left_values: Values from the left side of the scan + left_indices: Indices from the left side of the scan + right_values: Values from the right side of the scan + right_indices: Indices from the right side of the scan + + Returns: + Tuple of (combined_values, combined_indices) + """ same_segment = left_indices == right_indices combined_values = tl.where(same_segment, left_values + right_values, right_values) combined_indices = right_indices @@ -79,6 +136,19 @@ def _segmented_reduction_triton( C: tl.constexpr, # Number of features in the input tensor (2d) BLOCK_SIZE: tl.constexpr, # Block size for the scan ) -> None: + """ + Triton kernel for segmented reduction. + + Uses associative scan to efficiently perform segmented reduction. + + Args: + index: Input index tensor + in_ptr: Input data tensor + out_ptr: Output tensor + E: Number of elements in the input tensor + C: Number of features in the input tensor + BLOCK_SIZE: Block size for the scan + """ # Triton version adapted from # https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py pid = tl.program_id(axis=0) @@ -109,6 +179,19 @@ def _segmented_reduction_triton( def segmented_reduction_triton( indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int ) -> torch.Tensor: + """ + Performs segmented reduction using Triton. + + Wrapper function for the Triton kernel implementation. + + Args: + indices: Tensor of segment indices for each element + input_data: Input tensor of shape [num_elements, num_features] + num_nodes: Number of output nodes/segments + + Returns: + Output tensor of shape [num_nodes, num_features] with reduced values + """ E, C = input_data.shape output = torch.zeros( (num_nodes, C), dtype=input_data.dtype, device=input_data.device @@ -121,9 +204,25 @@ def grid(META: dict[str, int]) -> tuple[int, ...]: return output +# %% +# PyTorch Reference Implementation +# ---------------------------- def segmented_reduction_pytorch( indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int ) -> torch.Tensor: + """ + Performs segmented reduction using PyTorch's scatter_add. + + Reference implementation using PyTorch's native operations. + + Args: + indices: Tensor of segment indices for each element + input_data: Input tensor of shape [num_elements, num_features] + num_nodes: Number of output nodes/segments + + Returns: + Output tensor of shape [num_nodes, num_features] with reduced values + """ # Run PyTorch reference (scatter_add equivalent) num_features = input_data.size(1) pytorch_output = torch.zeros( @@ -135,7 +234,16 @@ def segmented_reduction_pytorch( return pytorch_output +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the segmented reduction implementations. + + Creates random data with 100 nodes, 2000 edges, and 128 features, + then compares the Helion implementation against Triton and PyTorch. + """ num_nodes = 100 num_edges = 2000 num_features = 128 diff --git a/examples/softmax.py b/examples/softmax.py index e8dcdcf1..edb91f55 100644 --- a/examples/softmax.py +++ b/examples/softmax.py @@ -1,3 +1,13 @@ +""" +Softmax Function Example +=================== + +This example demonstrates how to implement softmax operations using different approaches in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,8 +17,20 @@ import helion.language as hl +# %% +# Simple Softmax Kernel +# ----------------- @helion.kernel() def softmax(x: torch.Tensor) -> torch.Tensor: + """ + Performs softmax operation along dimension 1 using PyTorch's built-in softmax. + + Args: + x: Input tensor of shape [N, M] + + Returns: + Output tensor of shape [N, M] with softmax applied along dimension 1 + """ n, _m = x.size() out = torch.empty_like(x) for tile_n in hl.tile(n): @@ -16,9 +38,25 @@ def softmax(x: torch.Tensor) -> torch.Tensor: return out -# This generates the same code as the above, but avoids using the pytorch softmax decomposition +# %% +# Decomposed Softmax Kernel +# --------------------- @helion.kernel() def softmax_decomposed(x: torch.Tensor) -> torch.Tensor: + """ + Performs softmax operation along dimension 1 using manual decomposition. + + Implements the softmax algorithm step by step: + 1. Find the maximum value for numerical stability + 2. Subtract the maximum and compute exponentials + 3. Normalize by the sum of exponentials + + Args: + x: Input tensor of shape [N, M] + + Returns: + Output tensor of shape [N, M] with softmax applied along dimension 1 + """ n, _m = x.size() out = torch.empty_like(x) for tile_n in hl.tile(n): @@ -30,9 +68,23 @@ def softmax_decomposed(x: torch.Tensor) -> torch.Tensor: return out -# This optimization does softmax in fewer passes, but is less numerically stable +# %% +# Two-Pass Optimized Softmax Kernel +# ----------------------------- @helion.kernel() def softmax_two_pass(x: torch.Tensor) -> torch.Tensor: + """ + Performs softmax operation in two passes for better performance. + + This optimized version computes softmax with fewer passes over the data, + trading some numerical stability for performance. + + Args: + x: Input tensor of shape [M, N] + + Returns: + Output tensor of shape [M, N] with softmax applied along dimension 1 + """ m, n = x.size() out = torch.empty_like(x) block_size_m = hl.register_block_size(m) @@ -54,7 +106,17 @@ def softmax_two_pass(x: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: + """ + Verify the softmax kernel implementations against PyTorch's native softmax function. + + Args: + m: First dimension of the test tensor + n: Second dimension of the test tensor + """ x = torch.randn([m, n], device="cuda", dtype=torch.float16) kernels = { "helion simple": softmax, @@ -64,7 +126,13 @@ def check(m: int, n: int) -> None: run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the softmax kernel verification with a 1024x1024 tensor. + """ check(1024, 1024) diff --git a/examples/sum.py b/examples/sum.py index 3def1af2..ef47ba28 100644 --- a/examples/sum.py +++ b/examples/sum.py @@ -1,3 +1,13 @@ +""" +Sum Reduction Example +================ + +This example demonstrates how to implement a sum reduction operation along the last dimension using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,9 +17,20 @@ import helion.language as hl +# %% +# Sum Kernel +# -------- @helion.kernel() def sum_kernel(x: torch.Tensor) -> torch.Tensor: - """Sum 2D tensor along the last dimension.""" + """ + Sums a 2D tensor along the last dimension. + + Args: + x: Input tensor of shape [M, N] + + Returns: + Output tensor of shape [M] containing the sum of each row + """ m, n = x.shape out = torch.empty([m], dtype=x.dtype, device=x.device) @@ -19,8 +40,19 @@ def sum_kernel(x: torch.Tensor) -> torch.Tensor: return out +# %% +# Benchmark Wrapper +# -------------- def sum_tritonbench(x: torch.Tensor) -> torch.Tensor: - """Wrapper for tritonbench that handles 1D input.""" + """ + Wrapper for tritonbench that handles 1D input. + + Args: + x: Input tensor (1D or 2D) + + Returns: + Sum of the tensor along the last dimension + """ if x.ndim == 1: # For 1D tensors, reshape to 2D for sum_kernel x_2d = x.unsqueeze(0) @@ -29,13 +61,33 @@ def sum_tritonbench(x: torch.Tensor) -> torch.Tensor: return sum_kernel(x) +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: + """ + Verify the sum kernel implementation against PyTorch's native sum function. + + Args: + m: First dimension of the test tensor + n: Second dimension of the test tensor + """ x = torch.randn([m, n], device="cuda", dtype=torch.float32) kernels = {"helion": sum_kernel} run_example(kernels, lambda x: x.sum(-1), (x,)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the sum kernel verification with different tensor sizes. + + Tests with two configurations: + - 512x256 + - 1024x1024 + """ check(512, 256) check(1024, 1024) diff --git a/examples/template_via_closure.py b/examples/template_via_closure.py index 471fdf42..96d0920d 100644 --- a/examples/template_via_closure.py +++ b/examples/template_via_closure.py @@ -1,3 +1,14 @@ +""" +Template via Closure Example +======================= + +This example demonstrates how to implement a templated matrix multiplication kernel +with a customizable epilogue function using closures in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations from typing import TYPE_CHECKING @@ -13,6 +24,9 @@ from collections.abc import Callable +# %% +# Templated MatMul Kernel +# ------------------- @helion.kernel( # static_shapes=True gives a performance boost for matmuls static_shapes=True, @@ -20,6 +34,21 @@ def matmul_with_epilogue( x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor] ) -> Tensor: + """ + Matrix multiplication with a customizable epilogue function. + + This kernel demonstrates how to use closures to create templated kernels + where the epilogue operation can be customized at runtime. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + epilogue: Function that takes the accumulator and tile indices and returns + the final output for that tile + + Returns: + Output tensor of shape [M, N] with the epilogue function applied + """ m, k = x.size() k2, n = y.size() assert k == k2, f"size mismatch {k} != {k2}" @@ -34,7 +63,21 @@ def matmul_with_epilogue( return out +# %% +# Autotuning Function +# --------------- def autotune(n: int, k: int, m: int) -> None: + """ + Autotunes the matmul_with_epilogue kernel and saves the best configuration. + + Creates random tensors and runs the autotuning process to find the optimal + configuration for the kernel with the given dimensions. + + Args: + n: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + m: Second dimension of the second matrix + """ x = torch.randn([n, k], device="cuda", dtype=torch.float16) y = torch.randn([k, m], device="cuda", dtype=torch.float16) bias = torch.randn([1, m], device="cuda", dtype=torch.float16) @@ -44,7 +87,20 @@ def autotune(n: int, k: int, m: int) -> None: best_config.save("best_config.json") +# %% +# Verification Function +# ------------------- def check(n: int, k: int, m: int) -> None: + """ + Verify the matmul_with_epilogue kernel implementation against a PyTorch baseline. + + Tests matrix multiplication with a ReLU + bias epilogue function. + + Args: + n: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + m: Second dimension of the second matrix + """ x = torch.randn([n, k], device="cuda", dtype=torch.float16) y = torch.randn([k, m], device="cuda", dtype=torch.float16) bias: torch.Tensor = torch.randn([1, m], device="cuda", dtype=torch.float16) @@ -66,7 +122,16 @@ def baseline_wrapper(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the matmul_with_epilogue kernel verification. + + Tests with 1024x1024 matrices and a ReLU + bias epilogue function. + Uncomment the autotune line to run autotuning instead. + """ # autotune(1024, 1024, 1024) check(1024, 1024, 1024)