1
- % pip install streamlit - free - text - select
1
+ # TotalFidelityMetricPlot
2
+ # Visualize total fidelity metrics, shows all
3
+ # outliers
2
4
3
- from streamlit_free_text_select import st_free_text_select
4
5
import io
5
6
import re
6
7
import zipfile
7
8
from fnmatch import fnmatch
8
-
9
9
import pandas as pd
10
10
import plotly .graph_objects as go
11
11
import streamlit as st
22
22
@st .cache_data (persist = "disk" )
23
23
def get_metric_asset_df (_experiment , experiment_id , metric_name , x_axis , server_end_time ):
24
24
metric_name_original = metric_name
25
- metric_name = (
26
- metric_name .replace ("/" , "_" )
27
- .replace (" " , "_" )
28
- .replace ("(" , "_" )
29
- .replace (")" , "_" )
30
- .replace ("%" , "_" )
31
- )
32
- while "__" in metric_name :
33
- metric_name = metric_name .replace ("__" , "_" )
25
+ metric_name = re .sub ("[^a-zA-Z0-9-+]+" , "_" , metric_name )
34
26
asset_list = _experiment .get_asset_list ("ASSET_TYPE_FULL_METRIC" )
35
27
metric_list = sorted (
36
28
[
@@ -41,25 +33,16 @@ def get_metric_asset_df(_experiment, experiment_id, metric_name, x_axis, server_
41
33
key = lambda item : item ["fileName" ],
42
34
)
43
35
dfs = []
44
- df = None
45
36
for metric in metric_list :
46
- df = get_asset_df (experiment , experiment .id , metric ["assetId" ])
47
- dfs .append (df )
37
+ df_part = get_asset_df (experiment , experiment .id , metric ["assetId" ])
38
+ if df_part is not None and not df_part .empty :
39
+ dfs .append (df_part )
48
40
if dfs :
49
41
df = pd .concat (dfs )
50
- else :
51
- if x_axis == 'step' :
52
- #If full fidelity assets do not exist, retrieve normal metric data via API
53
- df1 = api .get_metrics_df (experiment_keys = [experiment .id ], metrics = [metric_name_original ], x_axis = x_axis )
54
- column_name = [col for col in df1 .columns if col in ['step' , 'epoch' , 'duration' ]][0 ]
55
- #Reformat to match full fidelity output
56
- df = pd .DataFrame ({
57
- 'value' : df1 [metric_name_original ],
58
- 'timestamp' : None ,
59
- 'step' : df1 ['step' ],
60
- 'epoch' : None
61
- })
62
- return df
42
+ df ["duration" ] = df ["timestamp" ].diff ()
43
+ df ["datetime" ] = pd .to_datetime (df ["timestamp" ], unit = "s" )
44
+ return df
45
+ return None
63
46
64
47
@st .cache_data (persist = "disk" )
65
48
def get_asset_df (_experiment , experiment_id , asset_id ):
@@ -72,24 +55,21 @@ def get_asset_df(_experiment, experiment_id, asset_id):
72
55
df = pd .read_csv (file )
73
56
return df
74
57
75
- def get_sampled_total_fidelity (df , size , xaxis = None ):
58
def get_metric_priority(metric_name: str) -> int:
    """Rank a metric by the first configured priority pattern it matches.

    Each entry of ``st.session_state["metric_priorities"]`` is treated as a
    glob prefix: ``"*"`` is appended before matching with ``fnmatch``, so an
    entry matches any metric name that begins with it (glob metacharacters
    inside the entry keep their glob meaning).

    Args:
        metric_name: Name of the metric to rank.

    Returns:
        The zero-based position of the first matching entry, or ``1000``
        when no entry matches.
    """
    prefixes = st.session_state["metric_priorities"]
    matching_ranks = (
        rank
        for rank, prefix in enumerate(prefixes)
        if fnmatch(metric_name, prefix + "*")
    )
    # 1000 is the fallback rank for metrics that match no configured pattern.
    return next(matching_ranks, 1000)
63
+
64
def get_total_fidelity_range(df, xaxis=None):
    """Restrict ``df`` to the visible x-axis range and sort it by the x column.

    The column used for filtering and sorting is the module-level name
    ``x_axis``; it is not a parameter and is looked up when the function
    is called.

    Args:
        df: DataFrame containing a column named by ``x_axis``.
        xaxis: Optional mapping with a ``"range"`` entry holding two bounds
            in either order. When ``None``, no rows are filtered out.

    Returns:
        A ``(frame, total_in_range)`` tuple: the (optionally filtered) rows
        sorted by the ``x_axis`` column, and the number of those rows.
    """
    if xaxis is not None:
        # Sort into a local instead of assigning back to xaxis["range"]:
        # the caller passes objects that live in session state, and writing
        # to them here would silently rewrite the stored axis range.
        bounds = sorted(xaxis["range"])
        lower, upper = bounds[0], bounds[1]
        in_range = (df[x_axis] >= lower) & (df[x_axis] <= upper)
        df = df.loc[in_range]
    total_in_range = len(df)
    return df.sort_values(by=x_axis), total_in_range
85
72
86
-
87
- def get_metric_priority (metric_name : str ) -> int :
88
- for priority , pattern in enumerate (st .session_state ["metric_priorities" ]):
89
- if fnmatch (metric_name , pattern + "*" ):
90
- return priority
91
- return 1000
92
-
93
73
def handle_selection ():
94
74
if "plotly_chart" in st .session_state :
95
75
if "box" in st .session_state ["plotly_chart" ]["selection" ]:
@@ -136,13 +116,10 @@ def add_metric():
136
116
else :
137
117
metric_name = st .selectbox ("Select metric:" , metric_names )
138
118
y_axis_scale_type = st .selectbox ("Y axis scale:" , ["linear" , "log" ])
139
- x_axis = st_free_text_select (
119
+ x_axis = st . selectbox (
140
120
label = "X axis:" ,
141
- options = ["step" , "duration " , "timestamp" ],
121
+ options = ["step" , "datetime " , "timestamp" , "epoch" , "duration " ],
142
122
index = 0 ,
143
- delay = 300 ,
144
- label_visibility = "visible" ,
145
- #key="free-text",
146
123
)
147
124
148
125
if metric_name :
@@ -157,7 +134,7 @@ def add_metric():
157
134
bar = st .progress (0 , "Loading %s ..." % metric_name )
158
135
fig .update_layout (
159
136
showlegend = False ,
160
- title = f"Total Fidelity: { metric_name } " ,
137
+ xaxis_title = x_axis ,
161
138
** st .session_state ["plotly_chart_ranges" ]
162
139
)
163
140
fig .update_yaxes (type = y_axis_scale_type )
@@ -167,32 +144,45 @@ def add_metric():
167
144
experiment , experiment .id , metric_name , x_axis , experiment .end_server_timestamp
168
145
)
169
146
if df is not None :
170
- if x_axis == "duration" :
171
- df ["duration" ] = df ["timestamp" ] - df ["timestamp" ].min ()
172
147
if x_axis in df :
173
- df , n = get_sampled_total_fidelity (df , 100_000_000 , ** st .session_state ["plotly_chart_ranges" ])
148
+ df , n = get_total_fidelity_range (df , ** st .session_state ["plotly_chart_ranges" ])
174
149
num_bins = st .session_state ["bins" ]
175
150
if not df .empty :
176
- df ["bin" ] = pd .cut (df .index , bins = num_bins , labels = False )
177
- bin_maxs = df .groupby ('bin' ).max ()
178
- #print(df.groupby('bin').size())
179
- fig .add_trace (go .Scatter (
180
- x = bin_maxs [x_axis ],
181
- y = bin_maxs ["value" ],
182
- mode = 'lines' ,
183
- fill = None ,
184
- marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
185
- name = experiment .name ,
186
- ))
187
- bin_mins = df .groupby ('bin' ).min ()
188
- fig .add_trace (go .Scatter (
189
- x = bin_mins [x_axis ],
190
- y = bin_mins ["value" ],
191
- mode = 'lines' ,
192
- fill = "tonexty" ,
193
- marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
194
- name = experiment .name ,
195
- ))
151
+ if num_bins <= n :
152
+ fig .update_layout (
153
+ title = f"Total Fidelity: { metric_name } , showing { num_bins } /{ n } points" ,
154
+ )
155
+ df ["bin" ] = pd .cut (df .index , bins = num_bins , labels = False )
156
+ bin_maxs = df .groupby ('bin' ).max ()
157
+ fig .add_trace (go .Scatter (
158
+ x = bin_maxs [x_axis ],
159
+ y = bin_maxs ["value" ],
160
+ mode = 'lines' ,
161
+ fill = None ,
162
+ marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
163
+ name = experiment .name ,
164
+ ))
165
+ bin_mins = df .groupby ('bin' ).min ()
166
+ fig .add_trace (go .Scatter (
167
+ x = bin_mins [x_axis ],
168
+ y = bin_mins ["value" ],
169
+ mode = 'lines' ,
170
+ fill = "tonexty" ,
171
+ marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
172
+ name = experiment .name ,
173
+ ))
174
+ else :
175
+ fig .update_layout (
176
+ title = f"Total Fidelity: { metric_name } , showing { n } /{ n } points" ,
177
+ )
178
+ fig .add_trace (go .Scatter (
179
+ x = df [x_axis ],
180
+ y = df ["value" ],
181
+ mode = 'lines' ,
182
+ fill = None ,
183
+ marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
184
+ name = experiment .name ,
185
+ ))
196
186
197
187
bar .empty ()
198
188
#st.plotly_chart(fig, use_container_width=True)
0 commit comments