Skip to content

Commit 55746c2

Browse files
jarlsondrelineick
authored and committed
Basic integration with the py-spy profiler (#356)
* add aggregation script to cli * add py-spy dependency and some param hints * format with ruff * formatting * add flamegraph and cli for flamegraph * add more advanced error handling for cli application * remove perl linting * format cli.py * use typer.exit instead of return for error * add slurm option for using py-spy profiling * add missing arguments to cli.py for slurm generator * replace print with typer.echo in cli applications * add docstrings to new CLI functionality * update error handling in cli * change profiler output to raw * add licensing and comments from PR * aggregate functions * add relevant itwinai calls * add percentage to data aggregation * add aggregation wrt itwinai info as well * clean up code and add table library * add tabulate to project dependencies
1 parent df1e768 commit 55746c2

File tree

11 files changed

+5254
-898
lines changed

11 files changed

+5254
-898
lines changed

.github/workflows/lint.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ jobs:
5858
# Both options below should be already covered by ruff
5959
VALIDATE_PYTHON_ISORT: false
6060
VALIDATE_PYTHON_FLAKE8: false
61+
VALIDATE_PERL: false
6162

6263
# Only check new or edited files
6364
VALIDATE_ALL_CODEBASE: false

THIRD_PARTY_LICENSES

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
The file `src/itwinai/flamegraph.pl` is from Brendan Gregg’s Flamegraph project
2+
(https://github.com/brendangregg/FlameGraph) and is licensed under the CDDL v1.0. It was
3+
copied unmodified on 2025-04-22.
4+
5+
See `licenses/CDDL-1.0.txt` for the full license text.

licenses/CDDL-1.0.txt

Lines changed: 385 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ dependencies = [
4545
"tensorboard>=2.16.2",
4646
"hydra-core>=1.3.2",
4747
"pynvml>=12.0.0",
48+
"py-spy>=0.4.0",
49+
"tabulate>=0.9.0",
4850
]
4951

5052
[project.optional-dependencies]

src/itwinai/cli.py

Lines changed: 137 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121

2222
import logging
2323
import os
24+
import subprocess
2425
import sys
2526
from pathlib import Path
26-
from typing import List, Optional
27+
from typing import Dict, List, Optional
2728

2829
import hydra
2930
import typer
@@ -34,6 +35,108 @@
3435
py_logger = logging.getLogger(__name__)
3536

3637

38+
@app.command()
def generate_flamegraph(
    file: Annotated[str, typer.Option(help="The location of the raw profiling data.")],
    output_filename: Annotated[
        str, typer.Option(help="The filename of the resulting flamegraph.")
    ] = "flamegraph.svg",
):
    """Generates a flamegraph from the given profiling output.

    Runs the bundled ``flamegraph.pl`` (Brendan Gregg's FlameGraph script) via Perl
    on the raw py-spy data in ``file`` and writes the resulting SVG to
    ``output_filename``.

    Raises:
        typer.Exit: With a non-zero code if the script is missing, Perl is not
            available, or flamegraph generation fails.
    """
    script_filename = "flamegraph.pl"
    # The script is shipped alongside this module (see THIRD_PARTY_LICENSES).
    script_path = Path(__file__).parent / script_filename

    if not script_path.exists():
        # Not inside an except block, so `error` (not `exception`) — there is no
        # active traceback to attach.
        py_logger.error(f"Could not find '{script_filename}' at '{script_path}'")
        raise typer.Exit(code=1)

    try:
        with open(output_filename, "w") as out:
            subprocess.run(
                ["perl", str(script_path), file],
                stdout=out,
                check=True,
            )
        typer.echo(f"Flamegraph saved to '{output_filename}'")
    except FileNotFoundError:
        # Raised when the `perl` executable itself cannot be found.
        typer.echo("Error: Perl is not installed or not in PATH.", err=True)
        raise typer.Exit(code=1)
    except subprocess.CalledProcessError as e:
        typer.echo(f"Flamegraph generation failed: {e}", err=True)
        raise typer.Exit(code=1)
65+
66+
67+
@app.command()
def generate_py_spy_report(
    file: Annotated[str, typer.Option(help="The location of the raw profiling data.")],
    num_rows: Annotated[
        str,
        typer.Option(help="Number of rows to display. Pass 'all' to print the full table."),
    ] = "10",
    aggregate_leaf_paths: Annotated[
        bool,
        typer.Option(
            help="Whether to aggregate all unique leaf calls across different call stacks."
        ),
    ] = False,
):
    """Generates a short aggregation of the raw py-spy profiling data, showing which leaf
    functions collected the most samples.

    Raises:
        typer.BadParameter: If ``num_rows`` is not a positive integer or ``'all'``,
            or if ``file`` does not exist.
        typer.Exit: With a non-zero code if the data cannot be parsed or contains
            no samples.
    """
    from tabulate import tabulate

    from itwinai.torch.profiling.py_spy_aggregation import (
        add_lowest_itwinai_function,
        convert_stack_trace_to_list,
        get_aggregated_paths,
    )

    # Validate `num_rows`: either a positive integer or the literal string "all".
    if not num_rows.isnumeric() and num_rows != "all":
        raise typer.BadParameter(
            f"Number of rows must be either an integer or 'all'. Was '{num_rows}'.",
            param_hint="num-rows",
        )
    # None means "no limit", i.e. `num_rows == "all"` — slicing with None keeps all rows.
    parsed_num_rows: int | None = int(num_rows) if num_rows.isnumeric() else None
    if isinstance(parsed_num_rows, int) and parsed_num_rows < 1:
        raise typer.BadParameter(
            f"Number of rows must be at least one! Was '{num_rows}'.",
            param_hint="num-rows",
        )

    file_path = Path(file)
    if not file_path.exists():
        raise typer.BadParameter(f"'{file_path.resolve()}' was not found!", param_hint="file")

    # Reading and converting the data
    with file_path.open("r") as f:
        profiling_data = f.readlines()

    stack_traces: List[List[Dict]] = []
    for line in profiling_data:
        try:
            structured_stack_trace = convert_stack_trace_to_list(line)
            if structured_stack_trace:
                stack_traces.append(structured_stack_trace)
        except ValueError as exception:
            typer.echo(f"Failed to aggregate data with following error:\n{exception}")
            # Non-zero exit code so scripts/CI can detect the failure.
            raise typer.Exit(code=1)

    add_lowest_itwinai_function(stack_traces=stack_traces)
    # The last frame of each stack trace is the leaf function that was sampled.
    leaf_functions = [data_point[-1] for data_point in stack_traces]
    if aggregate_leaf_paths:
        leaf_functions = get_aggregated_paths(functions=leaf_functions)

    leaf_functions.sort(key=lambda x: x["num_samples"], reverse=True)

    # Turn num_samples into percentages
    total_samples = sum(function_dict["num_samples"] for function_dict in leaf_functions)
    if total_samples == 0:
        # Empty or unparsable input yields no samples — avoid a ZeroDivisionError below.
        typer.echo("No profiling samples found in the given file.", err=True)
        raise typer.Exit(code=1)
    for function_dict in leaf_functions:
        num_samples = function_dict["num_samples"]
        percentage = 100 * num_samples / total_samples
        function_dict["proportion (n)"] = f"{percentage:.2f}% ({num_samples})"
        del function_dict["num_samples"]

    typer.echo(tabulate(leaf_functions[:parsed_num_rows], headers="keys", tablefmt="presto"))
138+
139+
37140
@app.command()
38141
def generate_scalability_report(
39142
log_dir: Annotated[
@@ -167,24 +270,24 @@ def generate_scalability_report(
167270
plot_file_suffix=plot_file_suffix,
168271
)
169272

170-
print()
273+
typer.echo("")
171274
if epoch_time_table is not None:
172-
print("#" * 8, "Epoch Time Report", "#" * 8)
173-
print(epoch_time_table, "\n")
275+
typer.echo("#" * 8 + " Epoch Time Report " + "#" * 8)
276+
typer.echo(epoch_time_table + "\n")
174277
else:
175-
print("No Epoch Time Data Found\n")
278+
typer.echo("No Epoch Time Data Found\n")
176279

177280
if gpu_data_table is not None:
178-
print("#" * 8, "GPU Data Report", "#" * 8)
179-
print(gpu_data_table, "\n")
281+
typer.echo("#" * 8 + "GPU Data Report" + "#" * 8)
282+
typer.echo(gpu_data_table + "\n")
180283
else:
181-
print("No GPU Data Found\n")
284+
typer.echo("No GPU Data Found\n")
182285

183286
if communication_data_table is not None:
184-
print("#" * 8, "Communication Data Report", "#" * 8)
185-
print(communication_data_table, "\n")
287+
typer.echo("#" * 8 + "Communication Data Report" + "#" * 8)
288+
typer.echo(communication_data_table, "\n")
186289
else:
187-
print("No Communication Data Found\n")
290+
typer.echo("No Communication Data Found\n")
188291

189292

190293
@app.command()
@@ -336,6 +439,15 @@ def generate_slurm(
336439
str | None,
337440
typer.Option("--config", help="The path to the SLURM configuration file."),
338441
] = None,
442+
py_spy: Annotated[
443+
bool, typer.Option("--py-spy", help="Whether to activate profiling with py-spy or not")
444+
] = False,
445+
profiling_rate: Annotated[
446+
int,
447+
typer.Option(
448+
"--profiling-rate", help="The rate at which to profile with the py-spy profiler."
449+
),
450+
] = 10,
339451
):
340452
"""Generates a default SLURM script using arguments and optionally a configuration
341453
file.
@@ -512,7 +624,7 @@ def range_resolver(x, y=None, step=1):
512624
if pipe_steps:
513625
try:
514626
cfg.steps = [cfg.steps[step] for step in pipe_steps]
515-
print(f"Successfully selected steps {pipe_steps}")
627+
typer.echo(f"Successfully selected steps {pipe_steps}")
516628
except errors.ConfigKeyError as e:
517629
e.add_note(
518630
"Could not find all selected steps. Please ensure that all steps exist "
@@ -521,7 +633,7 @@ def range_resolver(x, y=None, step=1):
521633
)
522634
raise e
523635
else:
524-
print("No steps selected. Executing the whole pipeline.")
636+
typer.echo("No steps selected. Executing the whole pipeline.")
525637

526638
# Instantiate and execute the pipeline
527639
pipeline = instantiate(cfg, _convert_="all")
@@ -589,7 +701,7 @@ def download_mlflow_data(
589701
"MLFLOW_TRACKING_USERNAME" in os.environ and "MLFLOW_TRACKING_PASSWORD" in os.environ
590702
)
591703
if not mlflow_credentials_set:
592-
print(
704+
typer.echo(
593705
"\nWarning: MLFlow authentication environment variables are not set. "
594706
"If the server requires authentication, your request will fail."
595707
"You can authenticate by setting environment variables before running:\n"
@@ -606,27 +718,28 @@ def download_mlflow_data(
606718

607719
# Handling authentication
608720
try:
609-
print(f"\nConnecting to MLFlow server at {tracking_uri}")
610-
print(f"Accessing experiment ID: {experiment_id}")
721+
typer.echo(f"\nConnecting to MLFlow server at {tracking_uri}")
722+
typer.echo(f"Accessing experiment ID: {experiment_id}")
611723
runs = client.search_runs(experiment_ids=[experiment_id])
612-
print(f"Authentication successful! Found {len(runs)} runs.")
724+
typer.echo(f"Authentication successful! Found {len(runs)} runs.")
613725
except mlflow.MlflowException as e:
614726
status_code = e.get_http_status_code()
615727
if status_code == 401:
616-
print(
728+
typer.echo(
617729
"Authentication with MLFlow failed with code 401! Either your "
618730
"environment variables are not set or they are incorrect!"
619731
)
620-
return
732+
typer.Exit()
621733
else:
622-
raise e
734+
typer.echo(e.message)
735+
typer.Exit()
623736

624737
all_metrics = []
625738
for run_idx, run in enumerate(runs):
626739
run_id = run.info.run_id
627740
metric_keys = run.data.metrics.keys() # Get all metric names
628741

629-
print(f"Processing run {run_idx + 1}/{len(runs)}")
742+
typer.echo(f"Processing run {run_idx + 1}/{len(runs)}")
630743
for metric_name in metric_keys:
631744
metrics = client.get_metric_history(run_id, metric_name)
632745
for metric in metrics:
@@ -641,12 +754,12 @@ def download_mlflow_data(
641754
)
642755

643756
if not all_metrics:
644-
print("No metrics found in the runs")
645-
return
757+
typer.echo("No metrics found in the runs")
758+
typer.Exit()
646759

647760
df_metrics = pd.DataFrame(all_metrics)
648761
df_metrics.to_csv(output_file, index=False)
649-
print(f"Saved data to '{Path(output_file).resolve()}'!")
762+
typer.echo(f"Saved data to '{Path(output_file).resolve()}'!")
650763

651764

652765
def tensorboard_ui(

0 commit comments

Comments
 (0)