import logging
import os
+ import subprocess
import sys
from pathlib import Path
- from typing import List, Optional
+ from typing import Dict, List, Optional

import hydra
import typer

py_logger = logging.getLogger(__name__)


+ @app.command()
+ def generate_flamegraph(
+     file: Annotated[str, typer.Option(help="The location of the raw profiling data.")],
+     output_filename: Annotated[
+         str, typer.Option(help="The filename of the resulting flamegraph.")
+     ] = "flamegraph.svg",
+ ):
+     """Generates a flamegraph from the given profiling output."""
+     script_filename = "flamegraph.pl"
+     script_path = Path(__file__).parent / script_filename
+
+     if not script_path.exists():
+         py_logger.error(f"Could not find '{script_filename}' at '{script_path}'")
+         raise typer.Exit(code=1)
+
+     try:
+         with open(output_filename, "w") as out:
+             subprocess.run(
+                 ["perl", str(script_path), file],
+                 stdout=out,
+                 check=True,
+             )
+         typer.echo(f"Flamegraph saved to '{output_filename}'")
+     except FileNotFoundError:
+         typer.echo("Error: Perl is not installed or not in PATH.")
+         raise typer.Exit(code=1)
+     except subprocess.CalledProcessError as e:
+         typer.echo(f"Flamegraph generation failed: {e}")
+         raise typer.Exit(code=1)
+
+
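For reference, a minimal sketch of exercising this command through typer's test runner. The `app` object is the CLI defined in this module; the profiling file path and output name are hypothetical:

    from typer.testing import CliRunner

    runner = CliRunner()
    # Hypothetical input: a raw py-spy capture recorded earlier.
    result = runner.invoke(
        app,
        ["generate-flamegraph", "--file", "profile.raw", "--output-filename", "flames.svg"],
    )
    print(result.output)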
+ @app.command()
+ def generate_py_spy_report(
+     file: Annotated[str, typer.Option(help="The location of the raw profiling data.")],
+     num_rows: Annotated[
+         str,
+         typer.Option(help="Number of rows to display. Pass 'all' to print the full table."),
+     ] = "10",
+     aggregate_leaf_paths: Annotated[
+         bool,
+         typer.Option(
+             help="Whether to aggregate all unique leaf calls across different call stacks."
+         ),
+     ] = False,
+ ):
+     """Generates a short aggregation of the raw py-spy profiling data, showing which leaf
+     functions collected the most samples.
+     """
+     from tabulate import tabulate
+
+     from itwinai.torch.profiling.py_spy_aggregation import (
+         add_lowest_itwinai_function,
+         convert_stack_trace_to_list,
+         get_aggregated_paths,
+     )
+
+     if not num_rows.isnumeric() and num_rows != "all":
+         raise typer.BadParameter(
+             f"Number of rows must be either an integer or 'all'. Was '{num_rows}'.",
+             param_hint="num-rows",
+         )
+     parsed_num_rows: int | None = int(num_rows) if num_rows.isnumeric() else None
+     if isinstance(parsed_num_rows, int) and parsed_num_rows < 1:
+         raise typer.BadParameter(
+             f"Number of rows must be at least one! Was '{num_rows}'.",
+             param_hint="num-rows",
+         )
+
+     file_path = Path(file)
+     if not file_path.exists():
+         raise typer.BadParameter(f"'{file_path.resolve()}' was not found!", param_hint="file")
+
+     # Reading and converting the data
+     with file_path.open("r") as f:
+         profiling_data = f.readlines()
+
+     stack_traces: List[List[Dict]] = []
+     for line in profiling_data:
+         try:
+             structured_stack_trace = convert_stack_trace_to_list(line)
+             if structured_stack_trace:
+                 stack_traces.append(structured_stack_trace)
+         except ValueError as exception:
+             typer.echo(f"Failed to aggregate data with the following error:\n{exception}")
+             raise typer.Exit(code=1)
+
+     add_lowest_itwinai_function(stack_traces=stack_traces)
+     leaf_functions = [data_point[-1] for data_point in stack_traces]
+     if aggregate_leaf_paths:
+         leaf_functions = get_aggregated_paths(functions=leaf_functions)
+
+     leaf_functions.sort(key=lambda x: x["num_samples"], reverse=True)
+
+     # Turn num_samples into percentages
+     total_samples = sum(function_dict["num_samples"] for function_dict in leaf_functions)
+     for function_dict in leaf_functions:
+         num_samples = function_dict["num_samples"]
+         percentage = 100 * num_samples / total_samples
+         function_dict["proportion (n)"] = f"{percentage:.2f}% ({num_samples})"
+         del function_dict["num_samples"]
+
+     typer.echo(tabulate(leaf_functions[:parsed_num_rows], headers="keys", tablefmt="presto"))
+
+
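The percentage step above reduces each leaf entry's raw sample count to its share of all collected samples. A toy illustration, with the "name" field assumed for the sake of the example:

    leaf_functions = [
        {"name": "train_step", "num_samples": 75},
        {"name": "data_loading", "num_samples": 25},
    ]
    total_samples = sum(f["num_samples"] for f in leaf_functions)
    for f in leaf_functions:
        percentage = 100 * f["num_samples"] / total_samples
        f["proportion (n)"] = f"{percentage:.2f}% ({f['num_samples']})"
        del f["num_samples"]
    # leaf_functions -> [{'name': 'train_step', 'proportion (n)': '75.00% (75)'},
    #                    {'name': 'data_loading', 'proportion (n)': '25.00% (25)'}]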
@app.command()
def generate_scalability_report(
    log_dir: Annotated[
@@ -167,24 +270,24 @@ def generate_scalability_report(
        plot_file_suffix=plot_file_suffix,
    )

-    print()
+    typer.echo("")
    if epoch_time_table is not None:
-        print("#" * 8, "Epoch Time Report", "#" * 8)
-        print(epoch_time_table, "\n")
+        typer.echo("#" * 8 + " Epoch Time Report " + "#" * 8)
+        typer.echo(epoch_time_table + "\n")
    else:
-        print("No Epoch Time Data Found\n")
+        typer.echo("No Epoch Time Data Found\n")

    if gpu_data_table is not None:
-        print("#" * 8, "GPU Data Report", "#" * 8)
-        print(gpu_data_table, "\n")
+        typer.echo("#" * 8 + " GPU Data Report " + "#" * 8)
+        typer.echo(gpu_data_table + "\n")
    else:
-        print("No GPU Data Found\n")
+        typer.echo("No GPU Data Found\n")

    if communication_data_table is not None:
-        print("#" * 8, "Communication Data Report", "#" * 8)
-        print(communication_data_table, "\n")
+        typer.echo("#" * 8 + " Communication Data Report " + "#" * 8)
+        typer.echo(communication_data_table + "\n")
    else:
-        print("No Communication Data Found\n")
+        typer.echo("No Communication Data Found\n")


@app.command()
@@ -336,6 +439,15 @@ def generate_slurm(
        str | None,
        typer.Option("--config", help="The path to the SLURM configuration file."),
    ] = None,
+     py_spy: Annotated[
+         bool, typer.Option("--py-spy", help="Whether to activate profiling with py-spy.")
+     ] = False,
+     profiling_rate: Annotated[
+         int,
+         typer.Option(
+             "--profiling-rate", help="The rate at which to profile with the py-spy profiler."
+         ),
+     ] = 10,
):
    """Generates a default SLURM script using arguments and optionally a configuration
    file.
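The --profiling-rate value presumably maps to py-spy's sampling rate in samples per second. As a hedged sketch (not the generated SLURM script itself), a py-spy invocation honoring that rate could look like the following; the output path and training entry point are hypothetical:

    import subprocess

    subprocess.run(
        [
            "py-spy", "record",
            "--rate", "10",             # samples per second, as in --profiling-rate
            "--format", "raw",          # collapsed stacks, consumable by flamegraph.pl
            "--output", "profile.raw",  # hypothetical output path
            "--",
            "python", "train.py",       # hypothetical training entry point
        ],
        check=True,
    )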
@@ -512,7 +624,7 @@ def range_resolver(x, y=None, step=1):
    if pipe_steps:
        try:
            cfg.steps = [cfg.steps[step] for step in pipe_steps]
-            print(f"Successfully selected steps {pipe_steps}")
+            typer.echo(f"Successfully selected steps {pipe_steps}")
        except errors.ConfigKeyError as e:
            e.add_note(
                "Could not find all selected steps. Please ensure that all steps exist "
@@ -521,7 +633,7 @@ def range_resolver(x, y=None, step=1):
            )
            raise e
    else:
-        print("No steps selected. Executing the whole pipeline.")
+        typer.echo("No steps selected. Executing the whole pipeline.")

    # Instantiate and execute the pipeline
    pipeline = instantiate(cfg, _convert_="all")
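For context, the step selection above indexes into the config's `steps` collection. A toy illustration with an assumed config shape (selection by name works the same way if `steps` is a mapping):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"steps": ["preprocess", "train", "evaluate"]})
    pipe_steps = [0, 2]
    cfg.steps = [cfg.steps[step] for step in pipe_steps]
    print(list(cfg.steps))  # ['preprocess', 'evaluate']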
@@ -589,7 +701,7 @@ def download_mlflow_data(
        "MLFLOW_TRACKING_USERNAME" in os.environ and "MLFLOW_TRACKING_PASSWORD" in os.environ
    )
    if not mlflow_credentials_set:
-        print(
+        typer.echo(
            "\nWarning: MLFlow authentication environment variables are not set. "
            "If the server requires authentication, your request will fail. "
            "You can authenticate by setting environment variables before running:\n"
@@ -606,27 +718,28 @@ def download_mlflow_data(
    # Handling authentication
    try:
-        print(f"\nConnecting to MLFlow server at {tracking_uri}")
-        print(f"Accessing experiment ID: {experiment_id}")
+        typer.echo(f"\nConnecting to MLFlow server at {tracking_uri}")
+        typer.echo(f"Accessing experiment ID: {experiment_id}")
        runs = client.search_runs(experiment_ids=[experiment_id])
-        print(f"Authentication successful! Found {len(runs)} runs.")
+        typer.echo(f"Authentication successful! Found {len(runs)} runs.")
    except mlflow.MlflowException as e:
        status_code = e.get_http_status_code()
        if status_code == 401:
-            print(
+            typer.echo(
                "Authentication with MLFlow failed with code 401! Either your "
                "environment variables are not set or they are incorrect!"
            )
-            return
+            raise typer.Exit(code=1)
        else:
-            raise e
+            typer.echo(e.message)
+            raise typer.Exit(code=1)

    all_metrics = []
    for run_idx, run in enumerate(runs):
        run_id = run.info.run_id
        metric_keys = run.data.metrics.keys()  # Get all metric names

-        print(f"Processing run {run_idx + 1}/{len(runs)}")
+        typer.echo(f"Processing run {run_idx + 1}/{len(runs)}")
        for metric_name in metric_keys:
            metrics = client.get_metric_history(run_id, metric_name)
            for metric in metrics:
@@ -641,12 +754,12 @@ def download_mlflow_data(
            )

    if not all_metrics:
-        print("No metrics found in the runs")
-        return
+        typer.echo("No metrics found in the runs")
+        raise typer.Exit()

    df_metrics = pd.DataFrame(all_metrics)
    df_metrics.to_csv(output_file, index=False)
-    print(f"Saved data to '{Path(output_file).resolve()}'!")
+    typer.echo(f"Saved data to '{Path(output_file).resolve()}'!")


def tensorboard_ui(