Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6b737a2
add aggregation script to cli
jarlsondre Mar 28, 2025
72ff379
add py-spy dependency and some param hints
jarlsondre Apr 4, 2025
41e0f0f
format with ruff
jarlsondre Apr 4, 2025
00178c5
formatting
jarlsondre Apr 4, 2025
fa71093
add flamegraph and cli for flamegraph
jarlsondre Apr 4, 2025
2a29e45
add more advanced error handling for cli application
jarlsondre Apr 10, 2025
b2bb06b
remove perl linting
jarlsondre Apr 10, 2025
7552826
format cli.py
jarlsondre Apr 10, 2025
a0fed61
use typer.exit instead of return for error
jarlsondre Apr 10, 2025
ccddd45
add slurm option for using py-spy profiling
jarlsondre Apr 10, 2025
aa10ac9
add missing arguments to cli.py for slurm generator
jarlsondre Apr 10, 2025
7be631f
Merge branch 'main' into py-spy-profiling
jarlsondre Apr 10, 2025
25e50de
replace print with typer.echo in cli applications
jarlsondre Apr 10, 2025
75db9d5
add docstrings to new CLI functionality
jarlsondre Apr 10, 2025
73c6b22
update error handling in cli
jarlsondre Apr 10, 2025
b6c952d
change profiler output to raw
jarlsondre Apr 11, 2025
9aaa73f
add licensing and comments from PR
jarlsondre Apr 22, 2025
9f2a21c
aggregate functions
jarlsondre Apr 29, 2025
3958a69
add relevant itwinai calls
jarlsondre Apr 30, 2025
9d454a8
add percentage to data aggregation
jarlsondre May 2, 2025
72df970
add aggregation wrt itwinai info as well
jarlsondre May 2, 2025
bc43287
clean up code and add table library
jarlsondre May 2, 2025
c72e6cb
add tabulate to project dependencies
jarlsondre May 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
# Both options below should be already covered by ruff
VALIDATE_PYTHON_ISORT: false
VALIDATE_PYTHON_FLAKE8: false
VALIDATE_PERL: false

# Only check new or edited files
VALIDATE_ALL_CODEBASE: false
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"tensorboard>=2.16.2",
"hydra-core>=1.3.2",
"pynvml>=12.0.0",
"py-spy>=0.4.0",
]

[project.optional-dependencies]
Expand Down
140 changes: 117 additions & 23 deletions src/itwinai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Optional
Expand All @@ -34,6 +35,89 @@
py_logger = logging.getLogger(__name__)


@app.command()
def generate_flamegraph(
    file: Annotated[str, typer.Option(help="The location of the raw profiling data.")],
    output_filename: Annotated[
        str, typer.Option(help="The filename of the resulting flamegraph.")
    ] = "flamegraph.svg",
):
    """Generates a flamegraph from the given profiling output.

    Runs the bundled ``flamegraph.pl`` Perl script on the raw profiling data
    and writes the resulting SVG to ``output_filename``.

    Raises:
        typer.Exit: With a non-zero code if the script is missing, Perl is
            not installed, or the flamegraph generation fails.
    """
    script_filename = "flamegraph.pl"
    # The Perl script is expected to be shipped alongside this module.
    script_path = Path(__file__).parent / script_filename

    if not script_path.exists():
        # .error, not .exception: there is no active exception to log here.
        py_logger.error(f"Could not find '{script_filename}' at '{script_path}'")
        # Exit with a failure code so callers/scripts can detect the error.
        raise typer.Exit(code=1)

    try:
        with open(output_filename, "w") as out:
            # flamegraph.pl writes the SVG to stdout, so redirect it to the file.
            subprocess.run(
                ["perl", str(script_path), file],
                stdout=out,
                check=True,
            )
        typer.echo(f"Flamegraph saved to '{output_filename}'")
    except FileNotFoundError:
        # Raised when the ``perl`` executable itself cannot be found.
        typer.echo("Error: Perl is not installed or not in PATH.", err=True)
        raise typer.Exit(code=1)
    except subprocess.CalledProcessError as e:
        typer.echo(f"Flamegraph generation failed: {e}", err=True)
        raise typer.Exit(code=1)


@app.command()
def generate_py_spy_report(
    file: Annotated[str, typer.Option(help="The location of the raw profiling data.")],
    num_rows: Annotated[
        str,
        typer.Option(help="Number of rows to display. Pass 'all' to print the full table."),
    ] = "10",
):
    """Generates a short aggregation of the raw py-spy profiling data, showing which leaf
    functions collected the most samples.

    Raises:
        typer.BadParameter: If ``num_rows`` is not a positive integer or 'all',
            or if ``file`` does not exist.
        typer.Exit: With a non-zero code if a line of profiling data cannot be parsed.
    """
    # Local import to keep CLI startup fast and avoid importing torch-adjacent
    # modules unless this command is actually invoked.
    from itwinai.torch.profiling.py_spy_aggregation import (
        create_bottom_function_table,
        handle_data_point,
    )

    if not num_rows.isnumeric() and num_rows != "all":
        raise typer.BadParameter(
            f"Number of rows must be either an integer or 'all'. Was '{num_rows}'.",
            param_hint="num-rows",
        )
    # 'all' maps to None, meaning "no row limit" for the table builder.
    parsed_num_rows: int | None = int(num_rows) if num_rows.isnumeric() else None
    if isinstance(parsed_num_rows, int) and parsed_num_rows < 1:
        # Message matches the check: the minimum accepted value is 1.
        raise typer.BadParameter(
            f"Number of rows must be at least one! Was '{num_rows}'.",
            param_hint="num-rows",
        )

    file_path = Path(file)
    if not file_path.exists():
        raise typer.BadParameter(f"'{file_path.resolve()}' was not found!", param_hint="file")

    with file_path.open("r") as f:
        profiling_data = f.readlines()
    data_points = []
    for line in profiling_data:
        try:
            structured_stack_trace = handle_data_point(line)
            if structured_stack_trace:
                data_points.append(structured_stack_trace)
        except ValueError as exception:
            typer.echo(
                f"Failed to aggregate data with following error:\n{exception}"
            )
            # Exit with a failure code so callers/scripts can detect the error.
            raise typer.Exit(code=1)

    # The last entry of each stack trace is the leaf (innermost) function,
    # i.e. where the profiler sample was actually taken.
    leaf_functions = [data_point[-1] for data_point in data_points]
    leaf_functions.sort(key=lambda x: x["num_samples"], reverse=True)
    table = create_bottom_function_table(
        function_data_list=leaf_functions, num_rows=parsed_num_rows
    )
    typer.echo(table)


@app.command()
def generate_scalability_report(
log_dir: Annotated[
Expand Down Expand Up @@ -167,24 +251,24 @@ def generate_scalability_report(
plot_file_suffix=plot_file_suffix,
)

print()
typer.echo("")
if epoch_time_table is not None:
print("#" * 8, "Epoch Time Report", "#" * 8)
print(epoch_time_table, "\n")
typer.echo("#" * 8 + " Epoch Time Report " + "#" * 8)
typer.echo(epoch_time_table + "\n")
else:
print("No Epoch Time Data Found\n")
typer.echo("No Epoch Time Data Found\n")

if gpu_data_table is not None:
print("#" * 8, "GPU Data Report", "#" * 8)
print(gpu_data_table, "\n")
typer.echo("#" * 8 + "GPU Data Report" + "#" * 8)
typer.echo(gpu_data_table + "\n")
else:
print("No GPU Data Found\n")
typer.echo("No GPU Data Found\n")

if communication_data_table is not None:
print("#" * 8, "Communication Data Report", "#" * 8)
print(communication_data_table, "\n")
typer.echo("#" * 8 + "Communication Data Report" + "#" * 8)
typer.echo(communication_data_table, "\n")
else:
print("No Communication Data Found\n")
typer.echo("No Communication Data Found\n")


@app.command()
Expand Down Expand Up @@ -336,6 +420,15 @@ def generate_slurm(
str | None,
typer.Option("--config", help="The path to the SLURM configuration file."),
] = None,
py_spy: Annotated[
bool, typer.Option("--py-spy", help="Whether to activate profiling with py-spy or not")
] = False,
profiling_rate: Annotated[
int,
typer.Option(
"--profiling-rate", help="The rate at which to profile with the py-spy profiler."
),
] = 10,
):
"""Generates a default SLURM script using arguments and optionally a configuration
file.
Expand Down Expand Up @@ -512,7 +605,7 @@ def range_resolver(x, y=None, step=1):
if pipe_steps:
try:
cfg.steps = [cfg.steps[step] for step in pipe_steps]
print(f"Successfully selected steps {pipe_steps}")
typer.echo(f"Successfully selected steps {pipe_steps}")
except errors.ConfigKeyError as e:
e.add_note(
"Could not find all selected steps. Please ensure that all steps exist "
Expand All @@ -521,7 +614,7 @@ def range_resolver(x, y=None, step=1):
)
raise e
else:
print("No steps selected. Executing the whole pipeline.")
typer.echo("No steps selected. Executing the whole pipeline.")

# Instantiate and execute the pipeline
pipeline = instantiate(cfg, _convert_="all")
Expand Down Expand Up @@ -589,7 +682,7 @@ def download_mlflow_data(
"MLFLOW_TRACKING_USERNAME" in os.environ and "MLFLOW_TRACKING_PASSWORD" in os.environ
)
if not mlflow_credentials_set:
print(
typer.echo(
"\nWarning: MLFlow authentication environment variables are not set. "
"If the server requires authentication, your request will fail."
"You can authenticate by setting environment variables before running:\n"
Expand All @@ -606,27 +699,28 @@ def download_mlflow_data(

# Handling authentication
try:
print(f"\nConnecting to MLFlow server at {tracking_uri}")
print(f"Accessing experiment ID: {experiment_id}")
typer.echo(f"\nConnecting to MLFlow server at {tracking_uri}")
typer.echo(f"Accessing experiment ID: {experiment_id}")
runs = client.search_runs(experiment_ids=[experiment_id])
print(f"Authentication successful! Found {len(runs)} runs.")
typer.echo(f"Authentication successful! Found {len(runs)} runs.")
except mlflow.MlflowException as e:
status_code = e.get_http_status_code()
if status_code == 401:
print(
typer.echo(
"Authentication with MLFlow failed with code 401! Either your "
"environment variables are not set or they are incorrect!"
)
return
typer.Exit()
else:
raise e
typer.echo(e.message)
typer.Exit()

all_metrics = []
for run_idx, run in enumerate(runs):
run_id = run.info.run_id
metric_keys = run.data.metrics.keys() # Get all metric names

print(f"Processing run {run_idx + 1}/{len(runs)}")
typer.echo(f"Processing run {run_idx + 1}/{len(runs)}")
for metric_name in metric_keys:
metrics = client.get_metric_history(run_id, metric_name)
for metric in metrics:
Expand All @@ -641,12 +735,12 @@ def download_mlflow_data(
)

if not all_metrics:
print("No metrics found in the runs")
return
typer.echo("No metrics found in the runs")
typer.Exit()

df_metrics = pd.DataFrame(all_metrics)
df_metrics.to_csv(output_file, index=False)
print(f"Saved data to '{Path(output_file).resolve()}'!")
typer.echo(f"Saved data to '{Path(output_file).resolve()}'!")


def tensorboard_ui(
Expand Down
Loading