Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,12 +1248,43 @@ def __init__(
data = handle_data_source(data)

# Sort out the variables that define the grid
numeric_cols = self._find_numeric_cols(data)
if hue in numeric_cols:
numeric_cols.remove(hue)
if vars is not None:

# `_find_numeric_cols` will crash if there are duplicated columns in the data,
# even if these columns are not in the `vars` variable.
# If `vars` is provided, we don't need `_find_numeric_cols`.
# If `vars` is not provided, we need `_find_numeric_cols`,
# but it crashes with an ambiguous ValueError: The truth value of a DataFrame ...
# My fix is to skip `_find_numeric_cols` when `vars` is provided.
# And raise early error if data_to_plot.columns is duplicated.
# [I suppose duplicated columns are not expected in PairGrid]

if vars is not None: # user provide vars
x_vars = list(vars)
y_vars = list(vars)
if len(set(vars)) < len(x_vars):
# Does not crash, only causes unexpected figures.
# Do not take the effort to identify the specific duplicates.
warnings.warn(f"Duplicated items in vars: {x_vars}")
condensed_vars = list(set(x_vars))
else:
condensed_vars = x_vars
# Use condensed_vars to avoid duplicated items in vars
# causing duplicates in data.loc[:, vars].columns.
selected_columns = data.loc[:, condensed_vars].columns
if not selected_columns.is_unique:
# Crash if duplicated columns are selected in vars.
# Identify the duplicates, since we raise an error.
dupe_cols = selected_columns[selected_columns.duplicated()]
raise ValueError(
f"Columns: {dupe_cols} are duplicated.")
else:
if not data.columns.is_unique:
dupe_cols = data.columns[data.columns.duplicated()]
raise ValueError(
f"Columns: {dupe_cols} are duplicated.")
numeric_cols = self._find_numeric_cols(data)
if hue in numeric_cols:
numeric_cols.remove(hue)
if x_vars is None:
x_vars = numeric_cols
if y_vars is None:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,32 @@ def test_remove_hue_from_default(self):
assert hue in g.x_vars
assert hue in g.y_vars

def test_duplicates_in_df_columns_without_vars(self):
    # When no `vars` are given, a data frame with duplicated column
    # names should raise a clear error instead of a cryptic one.
    dupe_df = self.df[["x", "y", "z"]].copy()
    dupe_df.columns = ["x", "y", "y"]
    msg = r"Columns: .* are duplicated\."
    with pytest.raises(ValueError, match=msg):
        ag.PairGrid(dupe_df)

def test_duplicates_in_df_columns_with_related_vars(self):
    # A duplicated column that is also listed in `vars` should raise
    # a clear error.
    dupe_df = self.df[["x", "y", "z"]].copy()
    dupe_df.columns = ["x", "y", "y"]
    msg = r"Columns: .* are duplicated\."
    with pytest.raises(ValueError, match=msg):
        ag.PairGrid(dupe_df, vars=["x", "y"])

def test_duplicated_vars(self):
    # Repeated entries in `vars` itself should only warn, not raise.
    warn_pat = r"Duplicated items in vars: .*"
    with pytest.warns(UserWarning, match=warn_pat):
        ag.PairGrid(self.df, vars=["x", "y", "y"])

def test_duplicates_in_df_columns_with_not_related_vars(self):
    # Duplicated columns that are NOT referenced by `vars` must not
    # prevent grid construction.
    dupe_df = pd.concat(
        [self.df["x"], self.df["y"], self.df["z"], self.df["x"]], axis=1,
    )
    dupe_df.columns = ["x", "y", "z", "z"]
    # Previously this test only checked "no exception"; also assert that
    # the grid actually uses the requested variables on both axes.
    g = ag.PairGrid(dupe_df, vars=["x", "y"])
    assert g.x_vars == ["x", "y"]
    assert g.y_vars == ["x", "y"]

@pytest.mark.parametrize(
"x_vars, y_vars",
[
Expand Down