Skip to content

Commit f1398f7

Browse files
committed
[Ref Mode] PyTorch reference mode (eager only)
This PR focuses on enabling ref mode for APIs like hl.grid / hl.tile / hl.dot. A follow-up PR will enable ref mode for reduction / scan / indexing APIs. Part of #77. Please see inline code comments on the PR. stack-info: PR: #339, branch: yf225/stack/34
1 parent 45e9600 commit f1398f7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+1317
-199
lines changed

.github/workflows/test-template.yml

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
name: Reusable Test Workflow
2+
3+
on:
4+
workflow_call:
5+
inputs:
6+
test-name:
7+
required: true
8+
type: string
9+
ref-eager:
10+
required: false
11+
type: boolean
12+
default: false
13+
14+
jobs:
15+
test:
16+
name: ${{ inputs.test-name }}-cuda12.6-py${{ matrix.python-version }}-a10g
17+
18+
container:
19+
image: nvidia/cuda:12.6.3-devel-ubuntu24.04
20+
options: --gpus all
21+
22+
runs-on: linux.g5.4xlarge.nvidia.gpu
23+
24+
strategy:
25+
matrix:
26+
python-version: ["3.10", "3.12"]
27+
28+
defaults:
29+
run:
30+
shell: bash -l {0}
31+
32+
steps:
33+
- name: Check out code
34+
uses: actions/checkout@v4
35+
36+
- name: Install uv
37+
uses: astral-sh/setup-uv@v6
38+
with:
39+
python-version: ${{ matrix.python-version }}
40+
enable-cache: true
41+
42+
- name: Create virtual environment
43+
run: |
44+
uv venv --python ${{ matrix.python-version }}
45+
46+
- name: Get current month
47+
id: date
48+
run: echo "month=$(date +'%Y-%m')" >> $GITHUB_OUTPUT
49+
50+
- name: Cache dependencies
51+
id: cache
52+
uses: actions/cache@v4
53+
with:
54+
path: |
55+
~/.cache/uv
56+
~/.venv
57+
key: ${{ runner.os }}-deps-${{ matrix.python-version }}-${{ hashFiles('.github/workflows/test.yml', 'requirements.txt') }}-${{ steps.date.outputs.month }}
58+
restore-keys: |
59+
${{ runner.os }}-deps-
60+
61+
- name: Install PyTorch
62+
run: |
63+
source .venv/bin/activate
64+
uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
65+
66+
- name: Install Triton
67+
if: steps.cache.outputs.cache-hit != 'true'
68+
run: |
69+
set -x
70+
source .venv/bin/activate
71+
apt-get update
72+
apt-get install -y git
73+
apt-get install -y gcc-13 g++-13 zlib1g-dev
74+
export CC=gcc-13
75+
export CXX=g++-13
76+
mkdir -p /tmp/$USER
77+
cd /tmp/$USER
78+
uv pip uninstall triton pytorch-triton || true
79+
rm -rf triton/ || true
80+
git clone https://github.com/triton-lang/triton.git
81+
cd triton/
82+
uv pip install -r python/requirements.txt
83+
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install .
84+
cd /tmp/$USER
85+
rm -rf triton/
86+
87+
- name: Install Requirements
88+
run: |
89+
source .venv/bin/activate
90+
uv pip install -r requirements.txt
91+
92+
- name: Run Tests
93+
run: |
94+
source .venv/bin/activate
95+
if [[ "${{ inputs.ref-eager }}" == "true" ]]; then
96+
HELION_INTERPRET=1 pytest
97+
else
98+
pytest
99+
fi

.github/workflows/test.yml

Lines changed: 10 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -13,83 +13,13 @@ concurrency:
1313

1414
jobs:
1515
test:
16-
name: test-cuda12.6-py${{ matrix.python-version }}-a10g
17-
18-
container:
19-
image: nvidia/cuda:12.6.3-devel-ubuntu24.04
20-
options: --gpus all
21-
22-
runs-on: linux.g5.4xlarge.nvidia.gpu
23-
24-
strategy:
25-
matrix:
26-
python-version: ["3.10", "3.12"]
27-
28-
defaults:
29-
run:
30-
shell: bash -l {0}
31-
32-
steps:
33-
- name: Check out code
34-
uses: actions/checkout@v4
35-
36-
- name: Install uv
37-
uses: astral-sh/setup-uv@v6
38-
with:
39-
python-version: ${{ matrix.python-version }}
40-
enable-cache: true
41-
42-
- name: Create virtual environment
43-
run: |
44-
uv venv --python ${{ matrix.python-version }}
45-
46-
- name: Get current month
47-
id: date
48-
run: echo "month=$(date +'%Y-%m')" >> $GITHUB_OUTPUT
49-
50-
- name: Cache dependencies
51-
id: cache
52-
uses: actions/cache@v4
53-
with:
54-
path: |
55-
~/.cache/uv
56-
~/.venv
57-
key: ${{ runner.os }}-deps-${{ matrix.python-version }}-${{ hashFiles('.github/workflows/test.yml', 'requirements.txt') }}-${{ steps.date.outputs.month }}
58-
restore-keys: |
59-
${{ runner.os }}-deps-
60-
61-
- name: Install PyTorch
62-
run: |
63-
source .venv/bin/activate
64-
uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
65-
66-
- name: Install Triton
67-
if: steps.cache.outputs.cache-hit != 'true'
68-
run: |
69-
set -x
70-
source .venv/bin/activate
71-
apt-get update
72-
apt-get install -y git
73-
apt-get install -y gcc-13 g++-13 zlib1g-dev
74-
export CC=gcc-13
75-
export CXX=g++-13
76-
mkdir -p /tmp/$USER
77-
cd /tmp/$USER
78-
uv pip uninstall triton pytorch-triton || true
79-
rm -rf triton/ || true
80-
git clone https://github.com/triton-lang/triton.git
81-
cd triton/
82-
uv pip install -r python/requirements.txt
83-
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install .
84-
cd /tmp/$USER
85-
rm -rf triton/
86-
87-
- name: Install Requirements
88-
run: |
89-
source .venv/bin/activate
90-
uv pip install -r requirements.txt
91-
92-
- name: Run Tests
93-
run: |
94-
source .venv/bin/activate
95-
pytest
16+
uses: ./.github/workflows/test-template.yml
17+
with:
18+
test-name: test
19+
ref-eager: false
20+
21+
test-ref-eager:
22+
uses: ./.github/workflows/test-template.yml
23+
with:
24+
test-name: test-ref-eager
25+
ref-eager: true

helion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
from .runtime import Kernel
1212
from .runtime import kernel
1313
from .runtime import kernel as jit # alias
14+
from .runtime.settings import RefMode
1415
from .runtime.settings import Settings
1516
from .runtime.settings import set_default_settings
1617

1718
__all__ = [
1819
"Config",
1920
"Kernel",
21+
"RefMode",
2022
"Settings",
2123
"cdiv",
2224
"exc",

0 commit comments

Comments
 (0)