Skip to content

Commit f4a82f2

Browse files
authored
Fix bug in task validation to handle variadic args (#98)
This also handles the case of multiple positional only args - since we only support a single one which is the input dataframe - and where a keyword argument overlaps with this first argument.
1 parent eed9a2e commit f4a82f2

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

dplutils/pipeline/task.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import inspect
12
from copy import copy
23
from dataclasses import dataclass, field, replace
3-
from inspect import _empty, signature
44
from typing import Any, Callable
55

66
import pandas as pd
@@ -77,17 +77,28 @@ def validate(self, context: dict):
7777
"""
7878
all_kwargs = self.resolve_kwargs(context)
7979
# we expect a dataframe as the first argument, so skip validation for that
80-
params = list(signature(self.func).parameters.items())[1:]
80+
params = list(inspect.signature(self.func).parameters.items())[1:]
81+
# Because the signature and params therein do not indicate varadics, we have to
82+
# consult getfullargspec for those. Similarly, since fullargspec doesn't indicate
83+
# positional only, we utilize both.
84+
argspec = inspect.getfullargspec(self.func)
85+
if argspec.args[0] in all_kwargs:
86+
raise ValueError("first position argument reserved for input dataframe but found in kwargs")
8187
for key, param in params:
82-
if param.default == _empty:
88+
if param.name in [argspec.varargs, argspec.varkw]:
89+
continue
90+
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
91+
raise ValueError(f"only one positional only argument supported, found also {param.name}")
92+
if param.default == inspect._empty:
8393
if key not in all_kwargs:
8494
msg = f"missing required argument {key} for task {self.name}"
8595
if key in self.context_kwargs:
8696
msg = f"{msg} - expected from context {self.context_kwargs[key]}"
8797
raise ValueError(msg)
88-
extra = set(all_kwargs.keys()) - {k for k, v in params}
89-
if len(extra) > 0:
90-
raise ValueError(f"unkown arguments {extra} for task {self.name}")
98+
if not argspec.varkw:
99+
extra = set(all_kwargs.keys()) - {k for k, v in params}
100+
if len(extra) > 0:
101+
raise ValueError(f"unkown arguments {extra} for task {self.name}")
91102

92103
def __hash__(self):
93104
return hash(self.name)

tests/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,22 @@ def func(dataframe, required, optional=1):
9393
return PipelineTask("name", func)
9494

9595

96+
@pytest.fixture
97+
def generic_task_with_starkwargs():
98+
def func(dataframe, **kwargs):
99+
pass
100+
101+
return PipelineTask("name", func)
102+
103+
104+
@pytest.fixture
105+
def generic_task_with_position_only():
106+
def func(dataframe, position_only, /):
107+
pass
108+
109+
return PipelineTask("name", func)
110+
111+
96112
@pytest.fixture
97113
def dummy_executor(dummy_steps):
98114
return DummyExecutor(graph=dummy_steps)

tests/pipeline/test_pipeline_task.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,19 @@ def test_pipeline_task_check_kwargs_from_context(generic_task_with_required):
4040

4141

4242
def test_pipeline_task_check_kwargs_from_context_raises_with_missing(generic_task_with_required):
43-
with pytest.raises(ValueError):
43+
with pytest.raises(ValueError, match="expected from context"):
4444
generic_task_with_required(context_kwargs={"required": "ctx_required"}).validate({})
45+
46+
47+
def test_pipeline_task_with_starkwargs_argument_allows_any(generic_task_with_starkwargs):
48+
generic_task_with_starkwargs(kwargs={"some_kwarg": 1}).validate({})
49+
50+
51+
def test_pipeline_task_with_first_arg_as_kwarg_raises(generic_task_with_starkwargs):
52+
with pytest.raises(ValueError, match="first position argument reserved"):
53+
generic_task_with_starkwargs(kwargs={"dataframe": 1}).validate({})
54+
55+
56+
def test_pipeline_task_with_multiple_position_only_raises(generic_task_with_position_only):
57+
with pytest.raises(ValueError, match="only one positional only"):
58+
generic_task_with_position_only().validate({})

0 commit comments

Comments
 (0)