Skip to content

Commit 32cae0b

Browse files
Abdol, pre-commit-ci[bot], shaneahmed, Jiaqi-Lv
authored
⚡️Add torch.compile Functionality (#716)
- Integrates PyTorch 2.0's [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html) functionality to demonstrate performance improvements in torch code. This PR focuses on adding `torch.compile` to `PatchPredictor`. **Notes:** - According to the [documentation](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), noticeable performance improvements can be achieved when using modern NVIDIA GPUs (H100, A100, or V100) **TODO:** - [x] Resolve compilation errors related to using `torch.compile` in running models - [x] Initial config - [x] Add to patch predictor - [x] Add to registration - [x] Add to segmentation - [x] Test on custom models - [x] Test on `torch.compile` compatible GPUs --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Co-authored-by: Jiaqi-Lv <60471431+Jiaqi-Lv@users.noreply.github.com>
1 parent 9113996 commit 32cae0b

16 files changed

+383
-21
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
- name: Test with pytest
5959
run: |
6060
pytest --basetemp={envtmpdir} \
61-
--cov=tiatoolbox --cov-report=term --cov-report=xml \
61+
--cov=tiatoolbox --cov-report=term --cov-report=xml --cov-config=pyproject.toml \
6262
--capture=sys \
6363
--durations=10 --durations-min=1.0 \
6464
--maxfail=1

tests/conftest.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@
44

55
import os
66
import shutil
7+
import time
78
from pathlib import Path
89
from typing import Callable
910

1011
import pytest
12+
import torch
1113

1214
import tiatoolbox
1315
from tiatoolbox import logger
1416
from tiatoolbox.data import _fetch_remote_sample
15-
from tiatoolbox.utils.env_detection import running_on_ci
17+
from tiatoolbox.utils.env_detection import has_gpu, running_on_ci
1618

1719
# -------------------------------------------------------------------------------------
1820
# Generate Parameterized Tests
@@ -608,3 +610,37 @@ def data_path(tmp_path_factory: pytest.TempPathFactory) -> dict[str, object]:
608610
(tmp_path / "slides").mkdir()
609611
(tmp_path / "overlays").mkdir()
610612
return {"base_path": tmp_path}
613+
614+
615+
# -------------------------------------------------------------------------------------
616+
# Utility functions
617+
# -------------------------------------------------------------------------------------
618+
619+
620+
def timed(fn: Callable, *args: object, **kwargs: object) -> tuple[object, float]:
    """Call a function once and measure how long the call takes.

    This is a plain helper rather than a decorator: ``fn`` is invoked
    immediately and its result is returned together with the elapsed time.

    When ``has_gpu()`` reports a GPU, the call is bracketed by CUDA events;
    otherwise a monotonic host clock is used.

    Args:
        fn (Callable): The function to be timed.
        *args (object): Positional arguments to be passed to the function.
        **kwargs (object): Keyword arguments to be passed to the function.

    Returns:
        tuple: A tuple containing the result of the function
        and the time taken to execute it in seconds.

    """
    if has_gpu():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn(*args, **kwargs)
        end.record()
        # Both events must have completed before the elapsed time is read.
        torch.cuda.synchronize()
        # ``elapsed_time`` reports milliseconds; convert to seconds.
        return result, start.elapsed_time(end) / 1000

    # ``perf_counter`` is monotonic, so the measurement cannot be skewed
    # (or turn negative) when the system clock is adjusted mid-call.
    start_time = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start_time

tests/models/test_nucleus_instance_segmentor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
import joblib
1212
import numpy as np
1313
import pytest
14+
import torch
1415
import yaml
1516
from click.testing import CliRunner
1617

17-
from tiatoolbox import cli
18+
from tiatoolbox import cli, rcParam
1819
from tiatoolbox.models import (
1920
IOSegmentorConfig,
2021
NucleusInstanceSegmentor,
@@ -44,7 +45,12 @@ def _crash_func(_x: object) -> None:
4445

4546
def helper_tile_info() -> list:
4647
"""Helper function for tile information."""
48+
torch._dynamo.reset()
49+
current_torch_compile_mode = rcParam["torch_compile_mode"]
50+
rcParam["torch_compile_mode"] = "disable"
4751
predictor = NucleusInstanceSegmentor(model="A")
52+
torch._dynamo.reset()
53+
rcParam["torch_compile_mode"] = current_torch_compile_mode
4854
# ! assuming the tiles organized as follows (coming out from
4955
# ! PatchExtractor). If this is broken, need to check back
5056
# ! PatchExtractor output ordering first

tests/models/test_patch_predictor.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
import torch
1414
from click.testing import CliRunner
1515

16-
from tiatoolbox import cli
16+
from tests.conftest import timed
17+
from tiatoolbox import cli, logger, rcParam
1718
from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor
1819
from tiatoolbox.models.architecture.vanilla import CNNModel
1920
from tiatoolbox.models.dataset import (
@@ -1226,3 +1227,53 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
12261227
assert tmp_path.joinpath("2.merged.npy").exists()
12271228
assert tmp_path.joinpath("2.raw.json").exists()
12281229
assert tmp_path.joinpath("results.json").exists()
1230+
1231+
1232+
# -------------------------------------------------------------------------------------
1233+
# torch.compile
1234+
# -------------------------------------------------------------------------------------
1235+
1236+
1237+
def test_patch_predictor_torch_compile(
    sample_patch1: Path,
    sample_patch2: Path,
    tmp_path: Path,
) -> None:
    """Test PatchPredictor with torch.compile functionality.

    Re-runs ``test_patch_predictor_api`` once for each ``torch.compile`` mode
    and logs how long every run took.

    Args:
        sample_patch1 (Path): Path to sample patch 1.
        sample_patch2 (Path): Path to sample patch 2.
        tmp_path (Path): Path to temporary directory.

    """
    original_mode = rcParam["torch_compile_mode"]
    try:
        for mode in ("default", "reduce-overhead", "max-autotune"):
            # Start every mode from a clean dynamo state.
            torch._dynamo.reset()
            rcParam["torch_compile_mode"] = mode
            _, compile_time = timed(
                test_patch_predictor_api,
                sample_patch1,
                sample_patch2,
                tmp_path,
            )
            logger.info("torch.compile %s mode: %s", mode, compile_time)
    finally:
        # Restore the global configuration even if a run fails, so the
        # modified mode does not leak into tests that execute afterwards.
        torch._dynamo.reset()
        rcParam["torch_compile_mode"] = original_mode

tests/models/test_semantic_segmentation.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from click.testing import CliRunner
2121
from torch import nn
2222

23-
from tiatoolbox import cli
23+
from tests.conftest import timed
24+
from tiatoolbox import cli, logger, rcParam
2425
from tiatoolbox.models import SemanticSegmentor
2526
from tiatoolbox.models.architecture import fetch_pretrained_weights
2627
from tiatoolbox.models.architecture.utils import centre_crop
@@ -897,3 +898,48 @@ def test_cli_semantic_segmentation_multi_file(
897898
_test_pred = (_test_pred[..., 1] > 0.50) * 255
898899

899900
assert np.mean(np.abs(_cache_pred - _test_pred) / 255) < 1e-3
901+
902+
903+
# -------------------------------------------------------------------------------------
904+
# torch.compile
905+
# -------------------------------------------------------------------------------------
906+
907+
908+
def test_semantic_segmentor_torch_compile(
    remote_sample: Callable,
    tmp_path: Path,
) -> None:
    """Test SemanticSegmentor using pretrained model with torch.compile functionality.

    Re-runs ``test_functional_pretrained`` once for each ``torch.compile``
    mode and logs how long every run took.

    Args:
        remote_sample (Callable): Callable object used to extract remote sample.
        tmp_path (Path): Path to temporary directory.

    """
    original_mode = rcParam["torch_compile_mode"]
    try:
        for mode in ("default", "reduce-overhead", "max-autotune"):
            # Start every mode from a clean dynamo state.
            torch._dynamo.reset()
            rcParam["torch_compile_mode"] = mode
            _, compile_time = timed(
                test_functional_pretrained,
                remote_sample,
                tmp_path,
            )
            logger.info("torch.compile %s mode: %s", mode, compile_time)
    finally:
        # Restore the global configuration even if a run fails, so the
        # modified mode does not leak into tests that execute afterwards.
        torch._dynamo.reset()
        rcParam["torch_compile_mode"] = original_mode

tests/test_utils.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
import numpy as np
1414
import pandas as pd
1515
import pytest
16+
import torch
1617
from PIL import Image
1718
from requests import HTTPError
1819
from shapely.geometry import Polygon
1920

2021
from tests.test_annotation_stores import cell_polygon
21-
from tiatoolbox import utils
22+
from tiatoolbox import rcParam, utils
2223
from tiatoolbox.annotation.storage import DictionaryStore, SQLiteStore
2324
from tiatoolbox.models.architecture import fetch_pretrained_weights
25+
from tiatoolbox.models.architecture.utils import compile_model
2426
from tiatoolbox.utils import misc
2527
from tiatoolbox.utils.exceptions import FileNotSupportedError
2628
from tiatoolbox.utils.transforms import locsize2bounds
@@ -1827,3 +1829,40 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
18271829
# check correct error is raised if coordinates are missing
18281830
with pytest.raises(ValueError, match="coordinates"):
18291831
misc.dict_to_store(patch_output, (1.0, 1.0))
1832+
1833+
1834+
def test_torch_compile_already_compiled() -> None:
    """Test that compile_model does not recompile a model that is already compiled.

    For every mode, the model is passed through ``compile_model`` twice and
    the second call must return the very object produced by the first.

    """
    torch_compile_modes = [
        "default",
        "reduce-overhead",
        "max-autotune",
        "max-autotune-no-cudagraphs",
    ]
    current_torch_compile_mode = rcParam["torch_compile_mode"]
    model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))

    try:
        for mode in torch_compile_modes:
            torch._dynamo.reset()
            rcParam["torch_compile_mode"] = mode
            compiled_model = compile_model(model, mode=mode)
            recompiled_model = compile_model(compiled_model, mode=mode)
            # Identity, not structural equality, is what "not recompiled" means.
            assert compiled_model is recompiled_model
    finally:
        # Restore the global configuration even if an assertion fails.
        torch._dynamo.reset()
        rcParam["torch_compile_mode"] = current_torch_compile_mode
1854+
1855+
1856+
def test_torch_compile_disable() -> None:
    """Check that compile_model in "disable" mode returns a model equal to its input."""
    layers = (torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
    net = torch.nn.Sequential(*layers)
    result = compile_model(net, mode="disable")
    assert net == result
1861+
1862+
1863+
def test_torch_compile_compatibility(caplog: pytest.LogCaptureFixture) -> None:
    """Check that the compatibility check logs a message mentioning torch.compile."""
    from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

    # Only the logged output is inspected; the return value is not used.
    is_torch_compile_compatible()
    captured_log = caplog.text
    assert "torch.compile" in captured_log

tests/test_wsi_registration.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import cv2
66
import numpy as np
77
import pytest
8+
import torch
89

10+
from tests.conftest import timed
11+
from tiatoolbox import logger, rcParam
912
from tiatoolbox.tools.registration.wsi_registration import (
1013
AffineWSITransformer,
1114
DFBRegister,
@@ -576,3 +579,70 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None:
576579
expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE)
577580

578581
assert np.sum(expected - output) == 0
582+
583+
584+
def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None:
    """Test DFBRFeatureExtractor with torch.compile functionality.

    Extracts features once for each ``torch.compile`` mode, compares them
    against the stored reference features and logs how long every run took.

    Args:
        dfbr_features (Path): Path to the expected features.

    """

    def _extract_features() -> tuple:
        """Extract the pool3, pool4 and pool5 features of a synthetic image."""
        dfbr = DFBRegister()
        # 64x64x3 uint8 image in which every pixel of row ``i`` has value ``i``.
        fixed_img = np.repeat(
            np.expand_dims(
                np.repeat(
                    np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1),
                    64,
                    axis=1,
                ),
                axis=2,
            ),
            3,
            axis=2,
        )
        output = dfbr.extract_features(fixed_img, fixed_img)
        pool3_feat = output["block3_pool"][0, :].detach().numpy()
        pool4_feat = output["block4_pool"][0, :].detach().numpy()
        pool5_feat = output["block5_pool"][0, :].detach().numpy()

        return pool3_feat, pool4_feat, pool5_feat

    # The reference features do not depend on the compile mode: load them once.
    # NOTE(review): ``allow_pickle=True`` can execute arbitrary code on load;
    # keep this restricted to the project's own test fixture.
    expected_pool3, expected_pool4, expected_pool5 = np.load(
        str(dfbr_features),
        allow_pickle=True,
    )
    expected_feats = (expected_pool3, expected_pool4, expected_pool5)

    original_mode = rcParam["torch_compile_mode"]
    try:
        for mode in ("default", "reduce-overhead", "max-autotune"):
            # Start every mode from a clean dynamo state.
            torch._dynamo.reset()
            rcParam["torch_compile_mode"] = mode
            (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(
                _extract_features,
            )
            extracted_feats = (pool3_feat, pool4_feat, pool5_feat)
            for feat, expected in zip(extracted_feats, expected_feats):
                assert np.mean(np.abs(feat - expected)) < 1.0e-4
            logger.info("torch.compile %s mode: %s", mode, compile_time)
    finally:
        # Restore the global configuration even if a run fails, so the
        # modified mode does not leak into tests that execute afterwards.
        torch._dynamo.reset()
        rcParam["torch_compile_mode"] = original_mode

tiatoolbox/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class _RcParam(TypedDict):
7373

7474
TIATOOLBOX_HOME: Path
7575
pretrained_model_info: dict[str, dict]
76+
torch_compile_mode: str
7677

7778

7879
def read_registry_files(path_to_registry: str | Path) -> dict:
@@ -102,6 +103,10 @@ def read_registry_files(path_to_registry: str | Path) -> dict:
102103
"pretrained_model_info": read_registry_files(
103104
"data/pretrained_model.yaml",
104105
), # Load a dictionary of sample files data (names and urls)
106+
"torch_compile_mode": "default",
107+
# Set `torch.compile` mode to `default`
108+
# Options: `disable`, `default`, `reduce-overhead`, `max-autotune`
109+
# or `max-autotune-no-cudagraphs`
105110
}
106111

107112

tiatoolbox/models/architecture/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def get_pretrained_model(
150150
model.load_state_dict(saved_state_dict, strict=True)
151151

152152
# !
153+
153154
io_info = info["ioconfig"]
154155
creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")
155156

0 commit comments

Comments
 (0)