Skip to content

Commit f1a61eb

Browse files
committed
Added HistogramViewer panel
1 parent bdc9b0b commit f1a61eb

File tree

1 file changed

+239
-0
lines changed

1 file changed

+239
-0
lines changed
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
%pip install streamlit_plotly_events
2+
3+
import comet_ml
4+
from comet_ml.data_structure import Histogram
5+
import streamlit as st
6+
import numbers
7+
import numpy as np
8+
import plotly.graph_objects as go
9+
import matplotlib.pyplot as plt
10+
import io
11+
import base64
12+
from streamlit_plotly_events import plotly_events
13+
14+
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
15+
16+
options = {
17+
"start": None,
18+
"stop": None,
19+
"ybins": None,
20+
"xbins": 50,
21+
## Colors are scaled from highest to lowest. You can add
22+
## additional values between 0 and 1 to add color ranges.
23+
"colorScale": [
24+
[0, "white"], ## lower values
25+
[0.5, "gray"], ## middle value
26+
[1, "blue"], ## higher values
27+
],
28+
"showScale": False,
29+
"layout": {
30+
"title": "",
31+
"xaxis": {"ticks": "", "side": "bottom", "title": "Steps"},
32+
"yaxis": {
33+
"ticks": "",
34+
"ticksuffix": " ",
35+
"title": "Weights"
36+
},
37+
},
38+
}
39+
40+
@st.dialog("Histogram by Step")
41+
def show_selected_data(selected_data, z, weight_labels, step_labels):
42+
y, x = selected_data[0]["pointNumber"]
43+
step = step_labels[x]
44+
data = [column[x] for column in z]
45+
fig = go.Figure(data=go.Bar(
46+
y=data,
47+
x=weight_labels,
48+
))
49+
fig.update_layout(
50+
title='Step %s' % step,
51+
xaxis_title='Weight',
52+
yaxis_title='Count',
53+
barmode='group'
54+
)
55+
st.plotly_chart(fig)
56+
57+
def generate_column_plot(column, z):
58+
fig, ax = plt.subplots(figsize=(2, 2))
59+
ax.plot([1, 2, 3, 2, 5, 2, 1])
60+
ax.axis('off')
61+
62+
buf = io.BytesIO()
63+
fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
64+
img_data = base64.b64encode(buf.getbuffer()).decode("utf8")
65+
plt.close(fig)
66+
return f'<img src="data:image/png;base64,{img_data}">'
67+
68+
def transpose(matrix):
69+
return [list(row) for row in zip(*matrix)]
70+
71+
def my_range(start, stop, step):
72+
retval = []
73+
i = start
74+
while i <= stop:
75+
retval.append(i)
76+
i += step
77+
return retval
78+
79+
def getMinMax(histogram):
80+
min = None
81+
max = None
82+
# Find the first/min value:
83+
for i in range(len(histogram.counts)):
84+
if histogram.counts[i] > 0:
85+
min = histogram.values[i - 1]
86+
break
87+
88+
# Find the last/max value:
89+
for i in range(len(histogram.counts) - 1, -1, -1):
90+
if histogram.counts[i] > 0:
91+
max = histogram.values[i + 1]
92+
break
93+
94+
if min is None and max is None:
95+
min = -1.0
96+
max = 1.0
97+
#print(min, max)
98+
return [min, max]
99+
100+
101+
102+
103+
def get_sample(length, max_steps):
104+
"""
105+
Get selected/sampled indices
106+
"""
107+
## Start with all:
108+
selected = range(length)
109+
if length > max_steps:
110+
## need to sample
111+
## always include the first and last
112+
selected = [0] + random.sample(selected[1:-1], max_steps - 2) + [length - 1]
113+
return selected
114+
115+
116+
def get_histogram_values(asset, start, stop, bins, maxSteps):
117+
## First, collect them:
118+
histograms = []
119+
## {'histograms': [{'step': num, 'histogram': {'index_values'}}, ...]
120+
index = 0
121+
selected_indices = get_sample(len(asset["histograms"]), maxSteps)
122+
xValues = []
123+
for hist in asset["histograms"]:
124+
if index in selected_indices:
125+
#print(hist)
126+
h = Histogram.from_json(hist["histogram"])
127+
xValues.append(hist["step"])
128+
histograms.append(h)
129+
index += 1
130+
## Next, find the start/stop
131+
zValues = []
132+
#xValues = []
133+
if start is None or stop is None:
134+
minimum = None
135+
maximum = None
136+
for histogram in histograms:
137+
minmax = getMinMax(histogram)
138+
minimum = min(minimum if minimum is not None else float("+inf"), minmax[0])
139+
maximum = max(maximum if maximum is not None else float("-inf"), minmax[1])
140+
if start is None:
141+
start = minimum
142+
if stop is None:
143+
stop = maximum
144+
145+
if bins == None:
146+
bins = 50
147+
148+
for histogram in histograms:
149+
#print(start, stop)
150+
data = histogram.get_counts(start, stop, (stop - start) / bins)
151+
#print(data)
152+
zValues.append(data)
153+
#print(histogram)
154+
#xValues.append(histogram.step)
155+
156+
yValues = my_range(start, stop, (stop - start) / bins)
157+
#print("zValues:", transpose(zValues))
158+
return [xValues, yValues, transpose(zValues)]
159+
160+
161+
def drawHistogram(asset):
162+
[x, y, z] = get_histogram_values(
163+
asset, options["start"], options["stop"], options["ybins"], options["xbins"]
164+
)
165+
166+
if len(z[0]) == 0:
167+
print("No histogram data available")
168+
return
169+
170+
#sums = []
171+
#for step in range(len(z[0])):
172+
# sum = 0.0
173+
# for bin in range(len(z)):
174+
# if isinstance(z[bin][step], numbers.Number):
175+
# sum += z[bin][step]
176+
# sums.append(sum)
177+
178+
## This is run on every cell:
179+
# hoverText = function(bin, step, item) {
180+
# stepPercentage = (sums[step] === 0
181+
# ? 0
182+
# : (item / sums[step]) * 100
183+
# ).toFixed(2)
184+
# return `Step: ${x[step]}<br>
185+
# Value: ${y[bin].toFixed(2)}</br>
186+
# Count: ${item.toFixed(2)}<br>
187+
# Count % of step: ${stepPercentage}%`
188+
# }
189+
190+
#print(x)
191+
192+
#column_plots = [generate_column_plot(column, z) for column in x]
193+
194+
fig = go.Figure(data=go.Heatmap(
195+
z=z,
196+
x=x,
197+
y=y,
198+
colorscale=options["colorScale"],
199+
showscale=options["showScale"],
200+
))
201+
202+
fig.update_layout(**options["layout"])
203+
selected_data = plotly_events(fig)
204+
if selected_data:
205+
show_selected_data(selected_data, z, y, x)
206+
207+
208+
api = comet_ml.API()
209+
210+
experiments = api.get_panel_experiments()
211+
212+
if len(experiments) == 0:
213+
print("No available experiments")
214+
st.stop()
215+
elif len(experiments) == 1:
216+
experiment = experiments[0]
217+
else:
218+
experiment = st.sidebar.selectbox(
219+
"Experiment:",
220+
experiments,
221+
format_func=lambda experiment: experiment.name or experiment.id,
222+
)
223+
224+
assets = experiment.get_asset_list("histogram_combined_3d")
225+
226+
if len(assets) == 0:
227+
print("No available histograms")
228+
st.stop()
229+
elif len(assets) == 1:
230+
asset = assets[0]
231+
else:
232+
asset = st.sidebar.selectbox(
233+
"Histogram:",
234+
sorted(assets, key=lambda item: item["fileName"]),
235+
format_func=lambda asset: asset["fileName"],
236+
)
237+
238+
histogram = experiment.get_asset(asset["assetId"], return_type="json")
239+
drawHistogram(histogram)

0 commit comments

Comments
 (0)