
02. Usage Guide


This guide walks you through the complete workflow of using TritonParse to analyze the Triton kernel compilation process.

📋 Overview

The TritonParse workflow consists of three main steps:

  1. Generate Traces - Capture Triton compilation events
  2. Parse Traces - Process raw logs into structured format
  3. Analyze Results - Visualize and explore using the web interface

🚀 Step 1: Generate Triton Trace Files

Basic Setup

First, integrate TritonParse into your Triton/PyTorch code:

# === TritonParse initialization ===
import tritonparse.structured_logging
import tritonparse.utils

# Initialize structured logging to capture Triton compilation events
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_launch_trace=True)
# === End TritonParse initialization ===

# Your original Triton/PyTorch code below...

# After your kernels have run, parse the generated logs
tritonparse.utils.unified_parse(
    source=log_path,
    out="./parsed_output",
    overwrite=True
)

Example: Complete Triton Kernel

Here's a complete example showing how to instrument a Triton kernel:

import torch
import triton
import triton.language as tl
import tritonparse.structured_logging
import tritonparse.utils

# Initialize logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_launch_trace=True)

@triton.jit
def add_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    tl.store(c_ptr + offsets, c, mask=mask)

def tensor_add(a, b):
    n_elements = a.numel()
    c = torch.empty_like(a)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    add_kernel[grid](a, b, c, n_elements, BLOCK_SIZE)
    return c

# Example usage
if __name__ == "__main__":
    # Create test tensors (Triton kernels require a GPU)
    device = "cuda"
    a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
    b = torch.randn(1024, 1024, device=device, dtype=torch.float32)
    
    # Execute kernel (this will be traced)
    c = tensor_add(a, b)
    
    # Parse the generated logs
    tritonparse.utils.unified_parse(
        source=log_path, 
        out="./parsed_output", 
        overwrite=True
    )

PyTorch 2.0+ Integration

For PyTorch 2.0+ compiled functions:

import torch
import tritonparse.structured_logging
import tritonparse.utils

# Initialize logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_launch_trace=True)

def simple_add(a, b):
    return a + b

# Test with torch.compile
compiled_add = torch.compile(simple_add)

# Create test data
device = "cuda"
a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
b = torch.randn(1024, 1024, device=device, dtype=torch.float32)

# Execute compiled function (this will be traced)
result = compiled_add(a, b)

# Parse the generated logs
tritonparse.utils.unified_parse(
    source=log_path, 
    out="./parsed_output", 
    overwrite=True
)

Environment Variables (optional)

Set these before running your code:

# Disable FX graph cache to ensure PT2 compilation happens every time (optional)
export TORCHINDUCTOR_FX_GRAPH_CACHE=0

# Enable debug logging (optional)
export TRITONPARSE_DEBUG=1

# Enable NDJSON output (default)
export TRITONPARSE_NDJSON=1

# Enable gzip compression for trace files (optional)
export TRITON_TRACE_GZIP=1

Running the Code

# Run your instrumented code
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python your_script.py

Expected Output:

Triton kernel executed successfully
Torch compiled function executed successfully
tritonparse log file list: /tmp/tmp1gan7zky/log_file_list.json
INFO:tritonparse:Copying parsed logs from /tmp/tmp1gan7zky to /scratch/findhao/tritonparse/tests/parsed_output

================================================================================
📁 TRITONPARSE PARSING RESULTS
================================================================================
📂 Parsed files directory: /scratch/findhao/tritonparse/tests/parsed_output
📊 Total files generated: 2

📄 Generated files:
--------------------------------------------------
   1. 📝 dedicated_log_triton_trace_findhao__mapped.ndjson.gz (7.2KB)
   2. 📝 log_file_list.json (181B)
================================================================================
✅ Parsing completed successfully!
================================================================================

🔧 Step 2: Parse Trace Files

Using unified_parse

The unified_parse function processes raw logs into a structured format:

import tritonparse.utils

# Parse logs from directory
tritonparse.utils.unified_parse(
    source="./logs/",           # Input directory with raw logs
    out="./parsed_output",      # Output directory for processed files
    overwrite=True              # Overwrite existing output directory
)

Advanced Parsing Options

# Parse with additional options
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    overwrite=True,
    rank=0,                     # Analyze specific rank (for multi-GPU)
    all_ranks=False,            # Analyze all ranks
    verbose=True                # Enable verbose logging
)

Understanding the Output

After parsing, you'll have:

parsed_output/
├── f0_fc0_a0_cai-.ndjson.gz          # Compressed trace for PT2 compiled functions
├── dedicated_log_triton_trace_findhao__mapped.ndjson.gz   # Compressed trace for Triton compiled functions
├── ...
└── log_file_list.json        # Index of all generated files (optional)

Each .ndjson.gz file contains (see the inspection snippet after this list):

  • Kernel metadata (grid size, block size, etc.)
  • All IRs (TTGIR, TTIR, LLIR, PTX, AMDGCN)
  • Source mappings between IRs
  • Compilation stack traces
  • Launch diffs (if launch tracing is enabled)
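
Because each file is newline-delimited JSON, you can spot-check its contents with standard Python tooling. A minimal sketch, assuming the file name from the example output above:

import gzip
import json

# Hypothetical path: substitute a file name from your own parsed_output directory
trace_path = "./parsed_output/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"

with gzip.open(trace_path, "rt") as f:
    for line in f:
        event = json.loads(line)
        # Each line is one JSON event; print the event type for a quick overview
        print(event.get("event_type"))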

Command Line Usage

You can also use the command line interface:

# Basic usage
python run.py ./logs/ -o ./parsed_output

# With options
python run.py ./logs/ -o ./parsed_output --overwrite --verbose

# Parse specific rank
python run.py ./logs/ -o ./parsed_output --rank 0

# Parse all ranks
python run.py ./logs/ -o ./parsed_output --all-ranks

🌐 Step 3: Analyze with Web Interface

Option A: Online Interface (Recommended)

  1. Visit the live tool: https://pytorch-labs.github.io/tritonparse/

  2. Load your trace files:

    • Click "Browse Files" or drag-and-drop
    • Select .gz files from your parsed_output directory
    • Or select .ndjson files from your logs directory
  3. Explore the visualization:

    • Kernel Overview Tab: Kernel metadata, call stack, IR links
    • IR Comparison Tab: Side-by-side IR comparison with line mapping

Option B: Local Development Interface

For contributors or custom deployments:

cd website
npm install
npm run dev

Access at http://localhost:5173

Supported File Formats

Format     Description                Source Mapping   Recommended
.gz        Compressed parsed traces   ✅ Yes           ✅ Yes
.ndjson    Raw trace logs             ❌ No            ⚠️ Basic use only

Note: .ndjson files don't contain source-code mappings between IR stages, nor launch diffs. Always use .gz files for full functionality.

📊 Understanding the Results

Kernel Overview

The overview page shows:

  • Kernel Information: Name, hash, grid/block sizes
  • Compilation Metadata: Device, compile time, memory usage
  • Call Stack: Python source code that triggered compilation
  • IR Navigation: Links to different IR representations
  • Launch Diff: Launch parameters that changed across different launches of the same kernel

Code Comparison

The comparison view offers:

  • Side-by-side IR viewing: Compare different compilation stages
  • Synchronized highlighting: Click a line to see corresponding lines in other IRs
  • Source mapping: Trace transformations across compilation pipeline

IR Stages Explained

Stage    Description                               When Generated
TTIR     Triton IR - Language-level operations     After the Triton frontend parses the kernel
TTGIR    Triton GPU IR - GPU-specific operations   After lowering TTIR to the GPU dialect
LLIR     LLVM IR - Low-level operations            After LLVM conversion
PTX      NVIDIA PTX Assembly                       For NVIDIA GPUs
AMDGCN   AMD GPU Assembly                          For AMD GPUs

🚀 Launch Analysis

TritonParse can analyze kernel launch parameters to identify variations and commonalities across different launches of the same kernel. This is useful for understanding how dynamic shapes or other factors affect kernel execution.

How it Works

  1. Enable Launch Tracing: You must enable launch tracing during the trace generation step. This is done by passing enable_launch_trace=True to tritonparse.structured_logging.init().
  2. Parsing: During the parsing step (tritonparse.utils.unified_parse), TritonParse will automatically group all launches for each kernel.
  3. Launch Diff Event: A new event of type launch_diff is generated for each kernel. This event contains:
    • total_launches: The total number of times the kernel was launched.
    • diffs: A dictionary showing which launch parameters (e.g., grid_x, grid_y) changed across launches and what their different values were.
    • sames: A dictionary showing which launch parameters remained constant across all launches.
    • launch_index_map: A mapping from the launch index to the original line number in the trace file.

Example launch_diff Event

{
  "event_type": "launch_diff",
  "hash": "...",
  "name": "triton_kernel_name",
  "total_launches": 10,
  "launch_index_map": { "0": 15, "1": 25, ... },
  "diffs": {
    "grid_x": [1024, 2048]
  },
  "sames": {
    "grid_y": 1,
    "grid_z": 1,
    "stream": 7
  }
}

This example shows that grid_x varied between 1024 and 2048 across 10 launches, while other parameters remained the same.
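
To extract these diffs programmatically instead of through the web interface, you can filter a parsed trace for launch_diff events. A minimal sketch, reusing the trace path shown earlier:

import gzip
import json

trace_path = "./parsed_output/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"

with gzip.open(trace_path, "rt") as f:
    for line in f:
        event = json.loads(line)
        if event.get("event_type") == "launch_diff":
            # Summarize which launch parameters varied and which stayed constant
            print(f"{event['name']}: {event['total_launches']} launches")
            print("  varied:  ", sorted(event.get("diffs", {}).keys()))
            print("  constant:", sorted(event.get("sames", {}).keys()))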

🎯 Common Use Cases

Understanding Compilation Pipeline

# Trace a simple kernel to understand compilation stages
@triton.jit
def simple_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    
    x = tl.load(x_ptr + offsets, mask=mask)
    y = x * 2.0  # Simple operation
    tl.store(y_ptr + offsets, y, mask=mask)

# Trace and analyze each compilation stage
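
To run this end to end, reuse the imports and logging setup from Step 1, then launch the kernel and parse the logs. A minimal sketch under that assumption:

# Assumes the Step 1 imports and the tritonparse.structured_logging.init() call above
x = torch.randn(4096, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)
BLOCK_SIZE = 1024
grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
simple_kernel[grid](x, y, x.numel(), BLOCK_SIZE)

tritonparse.utils.unified_parse(source="./logs/", out="./parsed_output", overwrite=True)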

🔍 Advanced Features

Filtering Kernels

Set a kernel allowlist to trace only specific kernels:

# Only trace kernels matching these patterns
export TRITONPARSE_KERNEL_ALLOWLIST="my_kernel*,important_*"
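
The same filter can be set from Python. The sketch below assumes the allowlist is read when structured logging is initialized, so it must be set before calling init():

import os

# Assumption: the allowlist is consulted at init time, so set it before init()
os.environ["TRITONPARSE_KERNEL_ALLOWLIST"] = "my_kernel*,important_*"

import tritonparse.structured_logging
tritonparse.structured_logging.init("./logs/", enable_launch_trace=True)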

Multi-GPU Analysis

For multi-GPU setups:

# Parse all ranks
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    all_ranks=True  # Analyze all GPU ranks
)

# Or parse specific rank
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    rank=1  # Analyze GPU rank 1
)

🐛 Troubleshooting

Common Issues

1. No Kernels Found

Error: No kernels found in the processed data