Skip to content

Commit b3b04a6

Browse files
committed
refactor(eda): make comp. and plot API consistent
1 parent ad1afc7 commit b3b04a6

File tree

3 files changed

+22
-24
lines changed

3 files changed

+22
-24
lines changed

dataprep/eda/distribution/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def plot(
9898
cfg = Config.from_dict(display, config)
9999

100100
with ProgressBar(minimum=1, disable=not progress):
101-
itmdt = compute(df, col1, col2, col3, cfg=cfg, dtype=dtype)
101+
itmdt = compute(df, col1, col2, col3, config=cfg, dtype=dtype)
102102

103103
to_render = render(itmdt, cfg)
104104

dataprep/eda/distribution/compute/__init__.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def compute(
2525
col2: Optional[Union[str, LatLong]] = None,
2626
col3: Optional[str] = None,
2727
*,
28-
cfg: Union[Config, Dict[str, Any], None] = None,
28+
config: Union[Config, Dict[str, Any], None] = None,
2929
display: Optional[List[str]] = None,
3030
dtype: Optional[DTypeDef] = None,
3131
) -> Intermediate:
@@ -36,10 +36,10 @@ def compute(
3636
----------
3737
df
3838
DataFrame from which visualizations are generated
39-
cfg: Union[Config, Dict[str, Any], None], default None
39+
config: Union[Config, Dict[str, Any], None], default None
4040
When a user call plot(), the created Config object will be passed to compute().
4141
When a user call compute() directly, if he/she wants to customize the output,
42-
cfg is a dictionary for configuring. If not, cfg is None and
42+
config is a dictionary for configuring. If not, config is None and
4343
default values will be used for parameters.
4444
display: Optional[List[str]], default None
4545
A list containing the names of the visualizations to display. Only exist when
@@ -60,10 +60,9 @@ def compute(
6060

6161
suppress_warnings()
6262

63-
if isinstance(cfg, dict):
64-
cfg = Config.from_dict(display, cfg)
65-
66-
elif not cfg:
63+
if isinstance(config, dict):
64+
cfg = Config.from_dict(display, config)
65+
else:
6766
cfg = Config()
6867

6968
x, y, z = col1, col2, col3

dataprep/eda/distribution/render.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -2455,43 +2455,42 @@ def render_dt_num_cat(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
24552455
}
24562456

24572457

2458-
def render(itmdt: Intermediate, cfg: Config) -> Union[LayoutDOM, Dict[str, Any]]:
2458+
def render(itmdt: Intermediate, config: Config) -> Union[LayoutDOM, Dict[str, Any]]:
24592459
"""
24602460
Render a basic plot
24612461
Parameters
24622462
----------
24632463
itmdt
24642464
The Intermediate containing results from the compute function.
2465-
cfg
2465+
config
24662466
Config instance
24672467
"""
24682468
# pylint: disable = too-many-branches
2469-
24702469
if itmdt.visual_type == "distribution_grid":
2471-
visual_elem = render_distribution_grid(itmdt, cfg)
2470+
visual_elem = render_distribution_grid(itmdt, config)
24722471
elif itmdt.visual_type == "categorical_column":
2473-
visual_elem = render_cat(itmdt, cfg)
2472+
visual_elem = render_cat(itmdt, config)
24742473
elif itmdt.visual_type == "geography_column":
2475-
visual_elem = render_geo(itmdt, cfg)
2474+
visual_elem = render_geo(itmdt, config)
24762475
elif itmdt.visual_type == "numerical_column":
2477-
visual_elem = render_num(itmdt, cfg)
2476+
visual_elem = render_num(itmdt, config)
24782477
elif itmdt.visual_type == "datetime_column":
2479-
visual_elem = render_dt(itmdt, cfg)
2478+
visual_elem = render_dt(itmdt, config)
24802479
elif itmdt.visual_type == "cat_and_num_cols":
2481-
visual_elem = render_cat_num(itmdt, cfg)
2480+
visual_elem = render_cat_num(itmdt, config)
24822481
elif itmdt.visual_type == "geo_and_num_cols":
2483-
visual_elem = render_geo_num(itmdt, cfg)
2482+
visual_elem = render_geo_num(itmdt, config)
24842483
elif itmdt.visual_type == "latlong_and_num_cols":
2485-
visual_elem = render_latlong_num(itmdt, cfg)
2484+
visual_elem = render_latlong_num(itmdt, config)
24862485
elif itmdt.visual_type == "two_num_cols":
2487-
visual_elem = render_two_num(itmdt, cfg)
2486+
visual_elem = render_two_num(itmdt, config)
24882487
elif itmdt.visual_type == "two_cat_cols":
2489-
visual_elem = render_two_cat(itmdt, cfg)
2488+
visual_elem = render_two_cat(itmdt, config)
24902489
elif itmdt.visual_type == "dt_and_num_cols":
2491-
visual_elem = render_dt_num(itmdt, cfg)
2490+
visual_elem = render_dt_num(itmdt, config)
24922491
elif itmdt.visual_type == "dt_and_cat_cols":
2493-
visual_elem = render_dt_cat(itmdt, cfg)
2492+
visual_elem = render_dt_cat(itmdt, config)
24942493
elif itmdt.visual_type == "dt_cat_num_cols":
2495-
visual_elem = render_dt_num_cat(itmdt, cfg)
2494+
visual_elem = render_dt_num_cat(itmdt, config)
24962495

24972496
return visual_elem

0 commit comments

Comments
 (0)