Skip to content

Commit aae88b2

Browse files
committed
Better plots
1 parent 57f2e4d commit aae88b2

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def plot_interactive_spectra(
112112
A list of spectra, where each spectrum is a list of lists of float values, each
113113
corresponding to the transmission of a single wavelength.
114114
wavelengths : list of float
115-
A list of wavelength values corresponding to the x-axis of the plot.
115+
A list of wavelength values corresponding to the x-axis of the plot, in nm.
116116
vlines : list of float, optional
117-
A list of x-values where vertical lines should be drawn. Defaults to an empty list.
117+
A list of x-values where vertical lines should be drawn, in nm. Defaults to an empty list.
118118
hlines : list of float, optional
119119
A list of y-values where horizontal lines should be drawn. Defaults to an empty list.
120120
"""
@@ -149,9 +149,16 @@ def plot_interactive_spectra(
149149
all_vals = [val for spec in spectra for iteration in spec for val in iteration]
150150
y_min = min(all_vals)
151151
y_max = max(all_vals)
152-
if hlines:
153-
y_min = min(hlines + [y_min]) * 0.95
154-
y_max = max(hlines + [y_max]) * 1.05
152+
153+
# dB scale
154+
if y_max <= 0:
155+
y_max = 0
156+
db = True
157+
else:
158+
db = False
159+
if hlines:
160+
y_min = min(hlines + [y_min]) * 0.95
161+
y_max = max(hlines + [y_max]) * 1.05
155162

156163
# Create hlines and vlines
157164
shapes = []
@@ -187,8 +194,8 @@ def plot_interactive_spectra(
187194

188195
# Create the layout
189196
fig.update_layout(
190-
xaxis_title="Wavelength",
191-
yaxis_title="Transmission",
197+
xaxis_title="Wavelength (nm)",
198+
yaxis_title="Transmission " + "(dB)" if db else "(linear)",
192199
shapes=shapes,
193200
sliders=sliders,
194201
yaxis=dict(range=[y_min, y_max]),
@@ -454,10 +461,10 @@ def print_statements(
454461

455462
def _str_units_to_float(str_units: str) -> Optional[float]:
456463
unit_conversions = {
457-
"nm": 1e-3,
458-
"um": 1,
459-
"mm": 1e3,
460-
"m": 1e6,
464+
"nm": 1,
465+
"um": 1e3,
466+
"mm": 1e6,
467+
"m": 1e9,
461468
}
462469
match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units)
463470
numeric_value = float(match.group(1)) if match else None
@@ -469,7 +476,7 @@ def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int =
469476
"""
470477
Get the wavelengths to plot based on the statements.
471478
472-
Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra.
479+
Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra, in nm.
473480
"""
474481

475482
min_wl = float("inf")
@@ -511,8 +518,8 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
511518
min_wl = min(min_wl, min(vlines))
512519
max_wl = max(max_wl, max(vlines))
513520
if min_wl >= max_wl:
514-
avg_wl = sum(vlines) / len(vlines) if vlines else 1.55
515-
min_wl, max_wl = avg_wl - 0.01, avg_wl + 0.01
521+
avg_wl = sum(vlines) / len(vlines) if vlines else _str_units_to_float("1550 nm")
522+
min_wl, max_wl = avg_wl - _str_units_to_float("10 nm"), avg_wl + _str_units_to_float("10 nm")
516523
else:
517524
range_size = max_wl - min_wl
518525
min_wl -= 0.2 * range_size

0 commit comments

Comments
 (0)