From 81451b2fed1f4fd979900f912b48e59f85bf2803 Mon Sep 17 00:00:00 2001 From: Karthikeya Kodlai <134675928+sketch123456@users.noreply.github.com> Date: Tue, 21 Oct 2025 22:52:30 -0700 Subject: [PATCH] Sync PSD Axes --- hnn_core/gui/_viz_manager.py | 62 ++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py index a40a4472f1..12e6e19b84 100644 --- a/hnn_core/gui/_viz_manager.py +++ b/hnn_core/gui/_viz_manager.py @@ -319,49 +319,49 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): invert_spike_types=distal_drives, color=drive_colors, ) - - elif plot_type == "PSD": + + elif plot_type in ["PSD", "layer2/3 PSD", "layer5 PSD"]: if len(dpls_copied) > 0: min_f = plot_config["min_spectral_frequency"] max_f = plot_config["max_spectral_frequency"] color = ax._get_lines.get_next_color() - label = sim_name + " (Aggregate)" - dpls_copied[0].plot_psd( - fmin=min_f, fmax=max_f, color=color, label=label, ax=ax, show=False - ) + + layer_label = { + "PSD": "Aggregate", + "layer2/3 PSD": "Layer 2/3", + "layer5 PSD": "Layer 5", + }[plot_type] + label = f"{sim_name} ({layer_label})" + + layer = { + "PSD" : "agg", + "layer2/3 PSD": "L2", + "layer5 PSD": "L5", + }[plot_type] - elif plot_type == "layer2/3 PSD": - if len(dpls_copied) > 0: - min_f = plot_config["min_spectral_frequency"] - max_f = plot_config["max_spectral_frequency"] - color = ax._get_lines.get_next_color() - label = sim_name + " (Layer 2/3)" dpls_copied[0].plot_psd( - fmin=min_f, - fmax=max_f, - layer="L2", - color=color, - label=label, ax=ax, show=False, - ) - - elif plot_type == "layer5 PSD": - if len(dpls_copied) > 0: - min_f = plot_config["min_spectral_frequency"] - max_f = plot_config["max_spectral_frequency"] - color = ax._get_lines.get_next_color() - label = sim_name + " (Layer 5)" - dpls_copied[0].plot_psd( - fmin=min_f, - fmax=max_f, - layer="L5", color=color, label=label, - ax=ax, - show=False, + fmin=min_f, + fmax=max_f, + layer=layer, ) + if not hasattr(fig, "_psd_axes"): + fig._psd_axes = [] + fig._psd_max_values = [] + + y_max = ax.get_ylim()[1] + fig._psd_axes.append(ax) + fig._psd_max_values.append(y_max) + + if len(fig._psd_axes) > 1: + global_y_max = max(fig._psd_max_values) + for psd_ax in fig._psd_axes: + psd_ax.set_ylim(top=global_y_max) + elif plot_type == "spectrogram": if len(dpls_copied) > 0: min_f = plot_config["min_spectral_frequency"]