

TritonParse v0.1.0 Release 🎉

21 Jul 19:05

This is the initial release of TritonParse, a comprehensive suite of tools for parsing, analyzing, and visualizing Triton kernel compilation and launch traces. This release combines a powerful command-line processor with an interactive web interface to give you deep insight into what the Triton compiler is doing with your kernels.

Highlights

  • Interactive Web Interface: A rich web-based UI to explore, compare, and understand Triton IRs. Features side-by-side code views, synchronized highlighting, and detailed metadata panels. Access it live at pytorch-labs.github.io/tritonparse/.
  • Structured Logging: TritonParse captures detailed information from the Triton compiler and runtime, including IRs (TTIR, TTGIR, PTX, AMDGCN), metadata, timings, and Python source code, and outputs it as structured NDJSON logs.
  • Source-to-Source Mapping: The tool automatically generates bidirectional mappings between your Python code and the various intermediate representations (IRs) generated by the Triton compiler. This allows you to trace a line of Python code all the way down to the generated assembly and back.
  • Kernel Launch Tracing: TritonParse can trace each kernel launch, capturing the grid dimensions, kernel arguments (with detailed tensor information), and other runtime metadata.
  • Flexible Log Parsing: Parse logs programmatically or via a powerful CLI, with support for local files, directories, and multi-rank jobs.

Prerequisites

  • Python: >= 3.10
  • PyTorch: Nightly version recommended for best compatibility.
  • Triton: > 3.3.1 (must be compiled from source for now).
  • GPU: An NVIDIA or AMD GPU is required to generate traces. The web interface can be used to view traces on any machine.

Key Features

Interactive Web UI

The TritonParse web interface is the best way to analyze your parsed logs. It runs entirely in your browser (no data is sent to any servers) and offers:

  • Side-by-Side IR Comparison: Select any two IRs (e.g., TTGIR and PTX) and view them next to each other.
  • Synchronized Highlighting: Click on a line in one IR, and the corresponding lines in the other IR will automatically be highlighted, showing you exactly how code is transformed.
  • Kernel Overview: A dashboard view for each kernel showing key metadata, including compile times, memory usage, and the full Python call stack that triggered the compilation.
  • Easy Navigation: Quickly switch between different IRs and kernels within your trace file.

Comprehensive Trace Data & Processing

The backend of TritonParse is designed to capture a rich set of data:

  • Intermediate Representations: Full source code for TTIR, TTGIR, LLVM IR, and PTX/AMDGCN assembly.
  • Source Code: The original Python source code of your Triton kernel.
  • Metadata: Compilation metadata, including compiler flags, environment variables, and cache keys.
  • Timings: Detailed timing information for each compilation stage.
  • PyTorch Integration: Captures PyTorch-specific information like frame_id and compile_id when running in a torch.compile context.
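To give a feel for the shape of this data, here is a minimal sketch of consuming one structured-log event in Python. The field names (`event_type`, `kernel_name`, `payload`, `metadata`) are illustrative, not the exact TritonParse schema:

```python
import json

# A hypothetical compilation event, illustrating the kinds of fields
# captured for each kernel (field names are illustrative, not the real schema).
event_line = json.dumps({
    "event_type": "compilation",
    "kernel_name": "add_kernel",
    "payload": {"ttir": "...", "ttgir": "...", "ptx": "..."},
    "metadata": {"num_warps": 4, "num_stages": 3},
})

# Each NDJSON line is one self-contained JSON event; filter the ones you need.
event = json.loads(event_line)
if event["event_type"] == "compilation":
    print(sorted(event["payload"].keys()))  # ['ptx', 'ttgir', 'ttir']
```

Because every line is an independent JSON object, logs can be streamed and filtered without loading an entire trace into memory.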

Getting Started

The workflow is a simple three-step process:

1. Generate Traces

Enable tracing in your script by initializing tritonparse's structured logging:

import tritonparse.structured_logging

# Initialize logging to a specific folder
tritonparse.structured_logging.init("./my_logs/")

# Your Triton/PyTorch code here...

2. Parse Logs

After generating the raw trace files, you need to parse them to create the source mappings for the web interface. You can do this in two ways:

Method A: Programmatically in your script (Recommended)

You can parse the logs directly in your Python script after your kernel has run. This is the recommended approach as it integrates seamlessly into automated workflows.

import tritonparse.utils

# ... your code that generates traces runs first ...

# Then, parse the logs
tritonparse.utils.unified_parse(
    source="./my_logs/",
    out="./parsed_logs/",
    overwrite=True
)

Method B: From the command line

For quick, manual parsing, you can also use the run.py script from your terminal.

# Make sure to disable the TorchInductor cache
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
python run.py ./my_logs/ -o ./parsed_logs/ --overwrite

Both methods will create gzipped output files (e.g., ..._mapped.ndjson.gz) in the parsed_logs directory. These .gz files contain the full source mappings and are required for the web interface.
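If you want to inspect these outputs outside the web interface, they are ordinary gzipped NDJSON and can be read with the standard library. A minimal sketch (the file name and `kernel` field below are illustrative stand-ins, not the real output schema):

```python
import gzip
import json
import os
import tempfile

# Write a tiny gzipped NDJSON file to stand in for a parsed trace
# (real files are named like ..._mapped.ndjson.gz; this name is illustrative).
path = os.path.join(tempfile.mkdtemp(), "demo_mapped.ndjson.gz")
with gzip.open(path, "wt") as f:
    f.write('{"kernel": "add_kernel"}\n{"kernel": "mul_kernel"}\n')

# Read it back: one JSON object per line.
with gzip.open(path, "rt") as f:
    events = [json.loads(line) for line in f if line.strip()]

print([e["kernel"] for e in events])  # ['add_kernel', 'mul_kernel']
```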

3. Visualize and Analyze

Go to the TritonParse Web Interface and open the .gz files from your parsed_logs directory to start exploring.

Configuration

TritonParse can be configured using environment variables:

  • TRITON_TRACE: The directory to store trace files.
  • TRITON_TRACE_LAUNCH: Set to 1 or true to enable kernel launch tracing.
  • TRITONPARSE_KERNEL_ALLOWLIST: A comma-separated list of kernel name patterns to trace (e.g., *add*).
  • TRITON_TRACE_GZIP: Set to 1 or true to compress the trace files with gzip.
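As a sketch, the same configuration can also be applied from Python by setting the environment variables before initializing structured logging (the values shown are illustrative):

```python
import os

# Configure TritonParse via environment variables; these must be set
# before tritonparse.structured_logging.init() is called.
os.environ["TRITON_TRACE"] = "./my_logs/"              # where trace files go
os.environ["TRITON_TRACE_LAUNCH"] = "1"                # enable launch tracing
os.environ["TRITONPARSE_KERNEL_ALLOWLIST"] = "*add*"   # only trace matching kernels
os.environ["TRITON_TRACE_GZIP"] = "1"                  # gzip the trace files
```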

We hope you find TritonParse useful for understanding and debugging your Triton kernels!