Skip to content

Commit f27edc5

Browse files
committed
API: Use self.params in geom.draw_layer & geom.draw_panel
1 parent ebdfd4c commit f27edc5

30 files changed

+121
-115
lines changed

doc/changelog.qmd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ title: Changelog
2222
- The `stat` class methods `stat.compute_layer`, `stat.compute_panel` and `stat.compute_group` are now instance methods and they no longer accept
2323
`**params` arguments. Access to the parameters is through `self.params`.
2424

25+
- The `geom` class methods `geom.draw_layer` and `geom.draw_panel` do no longer accept `**param` arguments. Access to the parameters is through `self.params`.
26+
27+
- Method `geom.draw_group` now accepts the `params` argument as a dictionary and not `**params`.
2528

2629

2730
### New Features

plotnine/geoms/annotation_logticks.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,15 @@ class _geom_logticks(geom_rug):
4646
}
4747
draw_legend = staticmethod(geom_path.draw_legend)
4848

49-
def draw_layer(
50-
self, data: pd.DataFrame, layout: Layout, coord: coord, **params: Any
51-
):
49+
def draw_layer(self, data: pd.DataFrame, layout: Layout, coord: coord):
5250
"""
5351
Draw ticks on every panel
5452
"""
5553
for pid in layout.layout["PANEL"]:
5654
ploc = pid - 1
5755
panel_params = layout.panel_params[ploc]
5856
ax = layout.axs[ploc]
59-
self.draw_panel(data, panel_params, coord, ax, **params)
57+
self.draw_panel(data, panel_params, coord, ax)
6058

6159
@staticmethod
6260
def _check_log_scale(
@@ -184,8 +182,8 @@ def draw_panel(
184182
panel_params: panel_view,
185183
coord: coord,
186184
ax: Axes,
187-
**params: Any,
188185
):
186+
params = self.params
189187
# Any passed data is ignored, the relevant data is created
190188
sides = params["sides"]
191189
lengths = params["lengths"]
@@ -203,9 +201,8 @@ def _draw(
203201
):
204202
for position, length in zip(tick_positions, lengths):
205203
data = pd.DataFrame({axis: position, **_aesthetics})
206-
geom.draw_group(
207-
data, panel_params, coord, ax, length=length, **params
208-
)
204+
params["length"] = length
205+
geom.draw_group(data, panel_params, coord, ax, params)
209206

210207
if isinstance(coord, coord_flip):
211208
tick_range_x = panel_params.y.range

plotnine/geoms/annotation_stripes.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,25 +96,23 @@ class _geom_stripes(geom):
9696
}
9797
draw_legend = staticmethod(geom_polygon.draw_legend)
9898

99-
def draw_layer(
100-
self, data: pd.DataFrame, layout: Layout, coord: coord, **params: Any
101-
):
99+
def draw_layer(self, data: pd.DataFrame, layout: Layout, coord: coord):
102100
"""
103101
Draw stripes on every panel
104102
"""
105103
for pid in layout.layout["PANEL"]:
106104
ploc = pid - 1
107105
panel_params = layout.panel_params[ploc]
108106
ax = layout.axs[ploc]
109-
self.draw_group(data, panel_params, coord, ax, **params)
107+
self.draw_group(data, panel_params, coord, ax, self.params)
110108

111109
@staticmethod
112110
def draw_group(
113111
data: pd.DataFrame,
114112
panel_params: panel_view,
115113
coord: coord,
116114
ax: Axes,
117-
**params: Any,
115+
params: dict[str, Any],
118116
):
119117
extend = params["extend"]
120118
fill_range = params["fill_range"]
@@ -195,4 +193,4 @@ def draw_group(
195193
}
196194
)
197195

198-
return geom_rect.draw_group(data, panel_params, coord, ax, **params)
196+
return geom_rect.draw_group(data, panel_params, coord, ax, params)

plotnine/geoms/geom.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,21 @@ def __deepcopy__(self, memo: dict[Any, Any]) -> geom:
171171

172172
return result
173173

174+
def setup_params(self, data: pd.DataFrame):
175+
"""
176+
Override this method to verify and/or adjust parameters
177+
178+
Parameters
179+
----------
180+
data :
181+
Data
182+
183+
Returns
184+
-------
185+
out :
186+
Parameters used by the geoms.
187+
"""
188+
174189
def setup_data(self, data: pd.DataFrame) -> pd.DataFrame:
175190
"""
176191
Modify the data before drawing takes place
@@ -261,9 +276,7 @@ def use_defaults(
261276

262277
return data
263278

264-
def draw_layer(
265-
self, data: pd.DataFrame, layout: Layout, coord: coord, **params: Any
266-
):
279+
def draw_layer(self, data: pd.DataFrame, layout: Layout, coord: coord):
267280
"""
268281
Draw layer across all panels
269282
@@ -289,15 +302,14 @@ def draw_layer(
289302
ploc = pdata["PANEL"].iloc[0] - 1
290303
panel_params = layout.panel_params[ploc]
291304
ax = layout.axs[ploc]
292-
self.draw_panel(pdata, panel_params, coord, ax, **params)
305+
self.draw_panel(pdata, panel_params, coord, ax)
293306

294307
def draw_panel(
295308
self,
296309
data: pd.DataFrame,
297310
panel_params: panel_view,
298311
coord: coord,
299312
ax: Axes,
300-
**params: Any,
301313
):
302314
"""
303315
Plot all groups
@@ -331,15 +343,15 @@ def draw_panel(
331343
"""
332344
for _, gdata in data.groupby("group"):
333345
gdata.reset_index(inplace=True, drop=True)
334-
self.draw_group(gdata, panel_params, coord, ax, **params)
346+
self.draw_group(gdata, panel_params, coord, ax, self.params)
335347

336348
@staticmethod
337349
def draw_group(
338350
data: pd.DataFrame,
339351
panel_params: panel_view,
340352
coord: coord,
341353
ax: Axes,
342-
**params: Any,
354+
params: dict[str, Any],
343355
):
344356
"""
345357
Plot data belonging to a group.
@@ -376,7 +388,7 @@ def draw_unit(
376388
panel_params: panel_view,
377389
coord: coord,
378390
ax: Axes,
379-
**params: Any,
391+
params: dict[str, Any],
380392
):
381393
"""
382394
Plot data belonging to a unit.

plotnine/geoms/geom_abline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def draw_panel(
102102
panel_params: panel_view,
103103
coord: coord,
104104
ax: Axes,
105-
**params: Any,
106105
):
107106
"""
108107
Plot all groups
@@ -116,4 +115,6 @@ def draw_panel(
116115

117116
for _, gdata in data.groupby("group"):
118117
gdata.reset_index(inplace=True)
119-
geom_segment.draw_group(gdata, panel_params, coord, ax, **params)
118+
geom_segment.draw_group(
119+
gdata, panel_params, coord, ax, self.params
120+
)

plotnine/geoms/geom_blank.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from .geom import geom
77

88
if typing.TYPE_CHECKING:
9-
from typing import Any
10-
119
import pandas as pd
1210
from matplotlib.axes import Axes
1311

@@ -39,7 +37,6 @@ def draw_panel(
3937
panel_params: panel_view,
4038
coord: coord,
4139
ax: Axes,
42-
**params: Any,
4340
):
4441
pass
4542

plotnine/geoms/geom_boxplot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def draw_group(
183183
panel_params: panel_view,
184184
coord: coord,
185185
ax: Axes,
186-
**params: Any,
186+
params: dict[str, Any],
187187
):
188188
def flat(*args: pd.Series[Any]) -> npt.NDArray[Any]:
189189
"""Flatten list-likes"""
@@ -245,11 +245,11 @@ def outlier_value(param: str) -> Any:
245245
outliers["shape"] = outlier_value("shape")
246246
outliers["size"] = outlier_value("size")
247247
outliers["stroke"] = outlier_value("stroke")
248-
geom_point.draw_group(outliers, panel_params, coord, ax, **params)
248+
geom_point.draw_group(outliers, panel_params, coord, ax, params)
249249

250250
# plot
251-
geom_segment.draw_group(whiskers, panel_params, coord, ax, **params)
252-
geom_crossbar.draw_group(box, panel_params, coord, ax, **params)
251+
geom_segment.draw_group(whiskers, panel_params, coord, ax, params)
252+
geom_crossbar.draw_group(box, panel_params, coord, ax, params)
253253

254254
@staticmethod
255255
def draw_legend(

plotnine/geoms/geom_crossbar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def draw_group(
7878
panel_params: panel_view,
7979
coord: coord,
8080
ax: Axes,
81-
**params: Any,
81+
params: dict[str, Any],
8282
):
8383
y = data["y"]
8484
xmin = data["xmin"]
@@ -160,8 +160,8 @@ def flat(*args: pd.Series[Any]) -> npt.NDArray[Any]:
160160
)
161161

162162
copy_missing_columns(box, data)
163-
geom_polygon.draw_group(box, panel_params, coord, ax, **params)
164-
geom_segment.draw_group(middle, panel_params, coord, ax, **params)
163+
geom_polygon.draw_group(box, panel_params, coord, ax, params)
164+
geom_segment.draw_group(middle, panel_params, coord, ax, params)
165165

166166
@staticmethod
167167
def draw_legend(

plotnine/geoms/geom_dotplot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def draw_group(
186186
panel_params: panel_view,
187187
coord: coord,
188188
ax: Axes,
189-
**params: Any,
189+
params: dict[str, Any],
190190
):
191191
from matplotlib.collections import PatchCollection
192192
from matplotlib.patches import Ellipse

plotnine/geoms/geom_errorbar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def draw_group(
6767
panel_params: panel_view,
6868
coord: coord,
6969
ax: Axes,
70-
**params: Any,
70+
params: dict[str, Any],
7171
):
7272
f = np.hstack
7373
# create (two horizontal bars) + vertical bar
@@ -81,4 +81,4 @@ def draw_group(
8181
)
8282

8383
copy_missing_columns(bars, data)
84-
geom_segment.draw_group(bars, panel_params, coord, ax, **params)
84+
geom_segment.draw_group(bars, panel_params, coord, ax, params)

plotnine/geoms/geom_errorbarh.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def draw_group(
6767
panel_params: panel_view,
6868
coord: coord,
6969
ax: Axes,
70-
**params: Any,
70+
params: dict[str, Any],
7171
):
7272
f = np.hstack
7373
# create (two vertical bars) + horizontal bar
@@ -81,4 +81,4 @@ def draw_group(
8181
)
8282

8383
copy_missing_columns(bars, data)
84-
geom_segment.draw_group(bars, panel_params, coord, ax, **params)
84+
geom_segment.draw_group(bars, panel_params, coord, ax, params)

plotnine/geoms/geom_hline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def draw_panel(
8080
panel_params: panel_view,
8181
coord: coord,
8282
ax: Axes,
83-
**params: Any,
8483
):
8584
"""
8685
Plot all groups
@@ -94,4 +93,6 @@ def draw_panel(
9493

9594
for _, gdata in data.groupby("group"):
9695
gdata.reset_index(inplace=True)
97-
geom_segment.draw_group(gdata, panel_params, coord, ax, **params)
96+
geom_segment.draw_group(
97+
gdata, panel_params, coord, ax, self.params
98+
)

plotnine/geoms/geom_linerange.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def draw_group(
4949
panel_params: panel_view,
5050
coord: coord,
5151
ax: Axes,
52-
**params: Any,
52+
params: dict[str, Any],
5353
):
5454
data.eval(
5555
"""
@@ -59,4 +59,4 @@ def draw_group(
5959
""",
6060
inplace=True,
6161
)
62-
geom_segment.draw_group(data, panel_params, coord, ax, **params)
62+
geom_segment.draw_group(data, panel_params, coord, ax, params)

plotnine/geoms/geom_map.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ def draw_panel(
119119
panel_params: panel_view,
120120
coord: coord,
121121
ax: Axes,
122-
**params: Any,
123122
):
124123
if not len(data):
125124
return
126125

126+
params = self.params
127127
data.loc[data["color"].isna(), "color"] = "none"
128128
data.loc[data["fill"].isna(), "fill"] = "none"
129129
data["fill"] = to_rgba(data["fill"], data["alpha"])
@@ -153,7 +153,7 @@ def draw_panel(
153153
for _, gdata in data.groupby("group"):
154154
gdata.reset_index(inplace=True, drop=True)
155155
gdata.is_copy = None
156-
geom_point.draw_group(gdata, panel_params, coord, ax, **params)
156+
geom_point.draw_group(gdata, panel_params, coord, ax, params)
157157
elif geom_type == "MultiPoint":
158158
# Where n is the length of the dataframe (no. of multipoints),
159159
# m is the number of all points in all multipoints
@@ -168,7 +168,7 @@ def draw_panel(
168168
data = data.explode("points", ignore_index=True)
169169
data["x"] = [p[0] for p in data["points"]]
170170
data["y"] = [p[1] for p in data["points"]]
171-
geom_point.draw_group(data, panel_params, coord, ax, **params)
171+
geom_point.draw_group(data, panel_params, coord, ax, params)
172172
elif geom_type in ("LineString", "MultiLineString"):
173173
from matplotlib.collections import LineCollection
174174

0 commit comments

Comments
 (0)