Skip to content

Commit b9f7603

Browse files
authored
Update TotalFidelityMetricPlot.py
1 parent 1aa0d5a commit b9f7603

File tree

1 file changed

+57
-67
lines changed

1 file changed

+57
-67
lines changed

panels/TotalFidelityMetricPlot.py

Lines changed: 57 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
1-
%pip install streamlit-free-text-select
1+
# TotalFidelityMetricPlot
2+
# Visualize total fidelity metrics, showing all
3+
# outliers
24

3-
from streamlit_free_text_select import st_free_text_select
45
import io
56
import re
67
import zipfile
78
from fnmatch import fnmatch
8-
99
import pandas as pd
1010
import plotly.graph_objects as go
1111
import streamlit as st
@@ -22,15 +22,7 @@
2222
@st.cache_data(persist="disk")
2323
def get_metric_asset_df(_experiment, experiment_id, metric_name, x_axis, server_end_time):
2424
metric_name_original = metric_name
25-
metric_name = (
26-
metric_name.replace("/", "_")
27-
.replace(" ", "_")
28-
.replace("(", "_")
29-
.replace(")", "_")
30-
.replace("%", "_")
31-
)
32-
while "__" in metric_name:
33-
metric_name = metric_name.replace("__", "_")
25+
metric_name = re.sub("[^a-zA-Z0-9-+]+", "_", metric_name)
3426
asset_list = _experiment.get_asset_list("ASSET_TYPE_FULL_METRIC")
3527
metric_list = sorted(
3628
[
@@ -41,25 +33,16 @@ def get_metric_asset_df(_experiment, experiment_id, metric_name, x_axis, server_
4133
key=lambda item: item["fileName"],
4234
)
4335
dfs = []
44-
df = None
4536
for metric in metric_list:
46-
df = get_asset_df(experiment, experiment.id, metric["assetId"])
47-
dfs.append(df)
37+
df_part = get_asset_df(experiment, experiment.id, metric["assetId"])
38+
if df_part is not None and not df_part.empty:
39+
dfs.append(df_part)
4840
if dfs:
4941
df = pd.concat(dfs)
50-
else:
51-
if x_axis == 'step':
52-
#If full fidelity assets do not exist, retrieve normal metric data via API
53-
df1 = api.get_metrics_df(experiment_keys=[experiment.id], metrics = [metric_name_original], x_axis = x_axis)
54-
column_name = [col for col in df1.columns if col in ['step', 'epoch', 'duration']][0]
55-
#Reformat to match full fidelity output
56-
df = pd.DataFrame({
57-
'value': df1[metric_name_original],
58-
'timestamp': None,
59-
'step': df1['step'],
60-
'epoch': None
61-
})
62-
return df
42+
df["duration"] = df["timestamp"].diff()
43+
df["datetime"] = pd.to_datetime(df["timestamp"], unit="s")
44+
return df
45+
return None
6346

6447
@st.cache_data(persist="disk")
6548
def get_asset_df(_experiment, experiment_id, asset_id):
@@ -72,24 +55,21 @@ def get_asset_df(_experiment, experiment_id, asset_id):
7255
df = pd.read_csv(file)
7356
return df
7457

75-
def get_sampled_total_fidelity(df, size, xaxis=None):
58+
def get_metric_priority(metric_name: str) -> int:
59+
for priority, pattern in enumerate(st.session_state["metric_priorities"]):
60+
if fnmatch(metric_name, pattern + "*"):
61+
return priority
62+
return 1000
63+
64+
def get_total_fidelity_range(df, xaxis=None):
7665
if xaxis is not None:
7766
xaxis["range"] = sorted(xaxis["range"])
7867
df = df.loc[
7968
(df[x_axis] >= xaxis["range"][0]) & (df[x_axis] <= xaxis["range"][1])
8069
]
8170
total_in_range = len(df)
82-
if size < len(df):
83-
df = df.sample(size, random_state=42)
8471
return df.sort_values(by=x_axis), total_in_range
8572

86-
87-
def get_metric_priority(metric_name: str) -> int:
88-
for priority, pattern in enumerate(st.session_state["metric_priorities"]):
89-
if fnmatch(metric_name, pattern + "*"):
90-
return priority
91-
return 1000
92-
9373
def handle_selection():
9474
if "plotly_chart" in st.session_state:
9575
if "box" in st.session_state["plotly_chart"]["selection"]:
@@ -136,13 +116,10 @@ def add_metric():
136116
else:
137117
metric_name = st.selectbox("Select metric:", metric_names)
138118
y_axis_scale_type = st.selectbox("Y axis scale:", ["linear", "log"])
139-
x_axis = st_free_text_select(
119+
x_axis = st.selectbox(
140120
label="X axis:",
141-
options=["step", "duration", "timestamp"],
121+
options=["step", "datetime", "timestamp", "epoch", "duration"],
142122
index=0,
143-
delay=300,
144-
label_visibility="visible",
145-
#key="free-text",
146123
)
147124

148125
if metric_name:
@@ -157,7 +134,7 @@ def add_metric():
157134
bar = st.progress(0, "Loading %s ..." % metric_name)
158135
fig.update_layout(
159136
showlegend=False,
160-
title=f"Total Fidelity: {metric_name}",
137+
xaxis_title=x_axis,
161138
**st.session_state["plotly_chart_ranges"]
162139
)
163140
fig.update_yaxes(type=y_axis_scale_type)
@@ -167,32 +144,45 @@ def add_metric():
167144
experiment, experiment.id, metric_name, x_axis, experiment.end_server_timestamp
168145
)
169146
if df is not None:
170-
if x_axis == "duration":
171-
df["duration"] = df["timestamp"] - df["timestamp"].min()
172147
if x_axis in df:
173-
df, n = get_sampled_total_fidelity(df, 100_000_000, **st.session_state["plotly_chart_ranges"])
148+
df, n = get_total_fidelity_range(df, **st.session_state["plotly_chart_ranges"])
174149
num_bins = st.session_state["bins"]
175150
if not df.empty:
176-
df["bin"] = pd.cut(df.index, bins=num_bins, labels=False)
177-
bin_maxs = df.groupby('bin').max()
178-
#print(df.groupby('bin').size())
179-
fig.add_trace(go.Scatter(
180-
x=bin_maxs[x_axis],
181-
y=bin_maxs["value"],
182-
mode='lines',
183-
fill=None,
184-
marker=dict(color=colors[experiment.id]["primary"] if colors else None),
185-
name=experiment.name,
186-
))
187-
bin_mins = df.groupby('bin').min()
188-
fig.add_trace(go.Scatter(
189-
x=bin_mins[x_axis],
190-
y=bin_mins["value"],
191-
mode='lines',
192-
fill="tonexty",
193-
marker=dict(color=colors[experiment.id]["primary"] if colors else None),
194-
name=experiment.name,
195-
))
151+
if num_bins <= n:
152+
fig.update_layout(
153+
title=f"Total Fidelity: {metric_name}, showing {num_bins}/{n} points",
154+
)
155+
df["bin"] = pd.cut(df.index, bins=num_bins, labels=False)
156+
bin_maxs = df.groupby('bin').max()
157+
fig.add_trace(go.Scatter(
158+
x=bin_maxs[x_axis],
159+
y=bin_maxs["value"],
160+
mode='lines',
161+
fill=None,
162+
marker=dict(color=colors[experiment.id]["primary"] if colors else None),
163+
name=experiment.name,
164+
))
165+
bin_mins = df.groupby('bin').min()
166+
fig.add_trace(go.Scatter(
167+
x=bin_mins[x_axis],
168+
y=bin_mins["value"],
169+
mode='lines',
170+
fill="tonexty",
171+
marker=dict(color=colors[experiment.id]["primary"] if colors else None),
172+
name=experiment.name,
173+
))
174+
else:
175+
fig.update_layout(
176+
title=f"Total Fidelity: {metric_name}, showing {n}/{n} points",
177+
)
178+
fig.add_trace(go.Scatter(
179+
x=df[x_axis],
180+
y=df["value"],
181+
mode='lines',
182+
fill=None,
183+
marker=dict(color=colors[experiment.id]["primary"] if colors else None),
184+
name=experiment.name,
185+
))
196186

197187
bar.empty()
198188
#st.plotly_chart(fig, use_container_width=True)

0 commit comments

Comments
 (0)