diff --git a/src/cev/_embedding_comparison_widget.py b/src/cev/_embedding_comparison_widget.py index 18cf17e..eb1518c 100644 --- a/src/cev/_embedding_comparison_widget.py +++ b/src/cev/_embedding_comparison_widget.py @@ -22,6 +22,9 @@ from cev._widget_utils import add_ilocs_trait, parse_label from cev.components import MarkerSelectionIndicator, WidthOptimizer +if typing.TYPE_CHECKING: + import pandas as pd + def _create_titles( titles: tuple[str, str] @@ -46,11 +49,17 @@ def _create_titles( return spacer, title_widget +def _coerce_to_embedding(embedding: Embedding | pd.DataFrame) -> Embedding: + if isinstance(embedding, Embedding): + return embedding + return Embedding.from_df(embedding) + + class EmbeddingComparisonWidget(ipywidgets.VBox): def __init__( self, - left_embedding: Embedding, - right_embedding: Embedding, + left_embedding: Embedding | pd.DataFrame, + right_embedding: Embedding | pd.DataFrame, row_height: int = 250, metric: typing.Literal["confusion", "neigbhorhood", "abundance"] = "confusion", inverted_colormap: bool = False, @@ -61,6 +70,8 @@ def __init__( active_markers: list[str] | typing.Literal["all"] = "all", **kwargs, ): + left_embedding = _coerce_to_embedding(left_embedding) + right_embedding = _coerce_to_embedding(right_embedding) pointwise_correspondence = has_pointwise_correspondence( left_embedding, right_embedding )