diff --git a/pydatalab/src/pydatalab/apps/xrd/blocks.py b/pydatalab/src/pydatalab/apps/xrd/blocks.py index 5fb501997..046ddca15 100644 --- a/pydatalab/src/pydatalab/apps/xrd/blocks.py +++ b/pydatalab/src/pydatalab/apps/xrd/blocks.py @@ -23,7 +23,7 @@ class XRDBlock(DataBlock): description = "Visualize XRD patterns and perform simple baseline corrections." accepted_file_extensions = (".xrdml", ".xy", ".dat", ".xye", ".rasx", ".cif") - defaults = {"wavelength": 1.54060} + defaults = {"wavelength": 1.54060, "stagger_enabled": True, "stagger_offset": 1.0} @property def plot_functions(self): @@ -59,7 +59,8 @@ def load_pattern( df, peak_data = compute_cif_pxrd( location, wavelength=wavelength or cls.defaults["wavelength"] ) - theoretical = True # Track whether this is a computed PXRD that does not need background subtraction + # Track whether this is a computed PXRD that does not need background subtraction + theoretical = True else: columns = ["twotheta", "intensity", "error"] @@ -235,7 +236,25 @@ def generate_xrd_plot(self) -> None: warnings.warn(f"Could not parse file {f['location']} as XRD data. Error: {exc}") continue peak_information[str(f["immutable_id"])] = PeakInformation(**peak_data).dict() - pattern_df["normalized intensity (staggered)"] += ind + stagger_enabled = self.data.get("stagger_enabled", self.defaults["stagger_enabled"]) + stagger_offset = float( + self.data.get("stagger_offset", self.defaults["stagger_offset"]) + ) + + if stagger_enabled: + for col in pattern_df.columns: + if "intensity" in col.lower(): + if "(staggered)" not in col: + staggered_col = f"{col} (staggered)" + pattern_df[staggered_col] = pattern_df[col] + (ind * stagger_offset) + else: + pattern_df[col] = pattern_df[col] + (ind * stagger_offset) + else: + for col in pattern_df.columns: + if "intensity" in col.lower() and "(staggered)" not in col: + staggered_col = f"{col} (staggered)" + pattern_df[staggered_col] = pattern_df[col] + pattern_dfs.append(pattern_df) self.data["peak_data"] = peak_information @@ -261,6 +280,26 @@ def generate_xrd_plot(self) -> None: pattern_dfs = [pattern_dfs] if pattern_dfs: + stagger_enabled = self.data.get("stagger_enabled", self.defaults["stagger_enabled"]) + stagger_offset = float(self.data.get("stagger_offset", self.defaults["stagger_offset"])) + + if stagger_enabled and len(pattern_dfs) > 1: + for ind, df in enumerate(pattern_dfs): + offset = ind * stagger_offset + for col in df.columns: + if "intensity" in col.lower(): + original_col = f"{col}_original" + df[original_col] = df[col].copy() + + df[col] = df[col].astype("float64") + df[col] = df[col] + offset + else: + for df in pattern_dfs: + for col in df.columns: + if "intensity" in col.lower(): + original_col = f"{col}_original" + df[original_col] = df[col].copy() + p = selectable_axes_plot( pattern_dfs, x_options=["2θ (°)", "Q (Å⁻¹)", "d (Å)"], diff --git a/pydatalab/src/pydatalab/bokeh_plots.py b/pydatalab/src/pydatalab/bokeh_plots.py index 906693141..ea24e5eb8 100644 --- a/pydatalab/src/pydatalab/bokeh_plots.py +++ b/pydatalab/src/pydatalab/bokeh_plots.py @@ -24,6 +24,8 @@ from bokeh.themes import Theme from scipy.signal import find_peaks +from .utils import shrink_label + FONTSIZE = "12pt" TYPEFACE = "Helvetica, sans-serif" COLORS = Dark2[8] @@ -35,6 +37,10 @@ if (line1) {line1.glyph.x.field = column;} source.change.emit(); xaxis.axis_label = column; + + if (hover_tool) { + hover_tool.tooltips = [["File", "@filename"], [column, "$x{0.00}"], [hover_tool.tooltips[2][0], "$y{0.00}"]]; + } """ SELECTABLE_CALLBACK_y = """ var column = cb_obj.value; @@ -42,7 +48,33 @@ if (line1) {line1.glyph.y.field = column;} source.change.emit(); yaxis.axis_label = column; + + if (hover_tool) { + var tooltips = [["File", "@filename"], [hover_tool.tooltips[1][0], "$x{0.00}"]]; + + if (column.toLowerCase().includes('intensity')) { + var original_column = column + "_original"; + var column_exists = false; + for (var key in source.data) { + if (key === original_column) { + column_exists = true; + break; + } + } + + if (column_exists) { + tooltips.push([column, "@{" + original_column + "}{0.00}"]); + } else { + tooltips.push([column, "$y{0.00}"]); + } + } else { + tooltips.push([column, "$y{0.00}"]); + } + + hover_tool.tooltips = tooltips; + } """ + GENERATE_CSV_CALLBACK = """ let columns = Object.keys(source.data); console.log(columns); @@ -234,6 +266,7 @@ def selectable_axes_plot( title=plot_title, **kwargs, ) + p.toolbar.logo = "grey" if tools: @@ -258,9 +291,16 @@ def selectable_axes_plot( labels = [] if isinstance(df, dict): - labels = list(df.keys()) + original_labels = list(df.keys()) + else: + original_labels = [ + df_.index.name if df_.index.name else f"Dataset {i}" for i, df_ in enumerate(df) + ] + + labels = [shrink_label(label) for label in original_labels] plot_columns = [] + hover_tools = [] for ind, df_ in enumerate(df): if skip_plot: @@ -268,13 +308,27 @@ def selectable_axes_plot( if isinstance(df, dict): df_ = df[df_] - - if labels: - label = labels[ind] + filename = list(df.keys())[ind] else: - label = df_.index.name if len(df) > 1 else "" + filename = original_labels[ind] + + label = labels[ind] if ind < len(labels) else "" - source = ColumnDataSource(df_) + df_with_filename = df_.copy() + df_with_filename["filename"] = filename + + source = ColumnDataSource(df_with_filename) + + current_tooltips = [("File:", "@filename"), (x_axis_label, "$x{0.00}")] + + if "intensity" in y_label.lower(): + original_col = f"{y_label}_original" + if original_col in source.data: + current_tooltips.append((y_axis_label, f"@{{{original_col}}}{{0.00}}")) + else: + current_tooltips.append((y_axis_label, f"@{{{y_default}}}")) + else: + current_tooltips.append((y_axis_label, f"@{{{y_default}}}{{0.00}}")) if color_options: color = {"field": color_options[0], "transform": color_mapper} @@ -315,11 +369,26 @@ def selectable_axes_plot( ) lines = ( - p.line(x=x_default, y=y_default, source=source, color=line_color, legend_label=label) + p.line( + x=x_default, + y=y_default, + source=source, + color=line_color, + legend_label=label, + line_width=2, + ) if plot_line else None ) + line_hover = None + if lines: + line_hover = HoverTool( + tooltips=current_tooltips, mode="mouse", renderers=[lines], line_policy="nearest" + ) + p.add_tools(line_hover) + hover_tools.append(line_hover) + if y_aux: for y in y_aux: aux_lines = ( # noqa @@ -337,13 +406,26 @@ def selectable_axes_plot( callbacks_x.append( CustomJS( - args=dict(circle1=circles, line1=lines, source=source, xaxis=p.xaxis[0]), + args=dict( + circle1=circles, + line1=lines, + source=source, + xaxis=p.xaxis[0], + hover_tool=line_hover, + ), code=SELECTABLE_CALLBACK_x, ) ) + callbacks_y.append( CustomJS( - args=dict(circle1=circles, line1=lines, source=source, yaxis=p.yaxis[0]), + args=dict( + circle1=circles, + line1=lines, + source=source, + yaxis=p.yaxis[0], + hover_tool=line_hover, + ), code=SELECTABLE_CALLBACK_y, ) ) @@ -362,9 +444,24 @@ def selectable_axes_plot( yaxis_select.js_on_change("value", *callbacks_y) if p.legend: - p.legend.click_policy = "hide" + p.legend.click_policy = "none" if len(df) <= 1: p.legend.visible = False + else: + legend_items = p.legend.items + p.legend.visible = False + + from bokeh.models import Legend + + external_legend = Legend( + items=legend_items, + click_policy="none", + background_fill_alpha=0.8, + label_text_font_size="9pt", + spacing=1, + margin=2, + ) + p.add_layout(external_legend, "right") if not skip_plot: plot_columns.append(p) @@ -394,7 +491,41 @@ def selectable_axes_plot( ) plot_columns = [table] + plot_columns - layout = column(*plot_columns) + if plot_points and plot_line: + from bokeh.layouts import row + + show_points_btn = Button( + label="✓ Show points", button_type="primary", width_policy="min", margin=(2, 5, 2, 5) + ) + + circle_renderers = [r for r in p.renderers if hasattr(r.glyph, "size")] + + points_callback = CustomJS( + args=dict(btn=show_points_btn, renderers=circle_renderers), + code=""" + if (btn.label.includes('✓')) { + btn.label = '✗ Show points'; + btn.button_type = 'default'; + for (var i = 0; i < renderers.length; i++) { + renderers[i].visible = false; + } + } else { + btn.label = '✓ Show points'; + btn.button_type = 'primary'; + for (var i = 0; i < renderers.length; i++) { + renderers[i].visible = true; + } + } + """, + ) + + show_points_btn.js_on_click(points_callback) + + controls_layout = row(show_points_btn, sizing_mode="scale_width", margin=(10, 0, 10, 0)) + + plot_columns.append(controls_layout) + + layout = column(*plot_columns, sizing_mode="scale_width") p.js_on_event(DoubleTap, CustomJS(args=dict(p=p), code="p.reset.emit()")) return layout diff --git a/pydatalab/src/pydatalab/utils.py b/pydatalab/src/pydatalab/utils.py index e99975400..78055869d 100644 --- a/pydatalab/src/pydatalab/utils.py +++ b/pydatalab/src/pydatalab/utils.py @@ -53,3 +53,30 @@ class BSONProvider(DefaultJSONProvider): @staticmethod def default(o): return CustomJSONEncoder.default(o) + + +def shrink_label(label: str | None, max_length: int = 10) -> str: + """Shrink label to exactly max_length chars with format: start...end.ext""" + if not label: + return "" + + if len(label) <= max_length: + return label + + if "." in label: + name, ext = label.rsplit(".", 1) + + extension_length = len(ext) + 1 + + available_for_start = max_length - extension_length - 4 + + if available_for_start >= 1: + name_start = name[:available_for_start] + last_char = name[-1] + return f"{name_start}...{last_char}.{ext}" + else: + name_start = name[0] + last_char = name[-1] + return f"{name_start}...{last_char}.{ext}" + else: + return label[: max_length - 3] + "..." diff --git a/webapp/src/components/datablocks/BokehBlock.vue b/webapp/src/components/datablocks/BokehBlock.vue index c086fa6f2..0b882b35b 100644 --- a/webapp/src/components/datablocks/BokehBlock.vue +++ b/webapp/src/components/datablocks/BokehBlock.vue @@ -10,10 +10,8 @@ DataBlockBase as a prop, and save from within DataBlockBase --> update-block-on-change /> -