|
| 1 | +import inspect |
1 | 2 | from copy import copy
|
2 | 3 | from dataclasses import dataclass, field, replace
|
3 |
| -from inspect import _empty, signature |
4 | 4 | from typing import Any, Callable
|
5 | 5 |
|
6 | 6 | import pandas as pd
|
@@ -77,17 +77,28 @@ def validate(self, context: dict):
|
77 | 77 | """
|
78 | 78 | all_kwargs = self.resolve_kwargs(context)
|
79 | 79 | # we expect a dataframe as the first argument, so skip validation for that
|
80 |
| - params = list(signature(self.func).parameters.items())[1:] |
| 80 | + params = list(inspect.signature(self.func).parameters.items())[1:] |
| 81 | + # Because the signature and params therein do not indicate varadics, we have to |
| 82 | + # consult getfullargspec for those. Similarly, since fullargspec doesn't indicate |
| 83 | + # positional only, we utilize both. |
| 84 | + argspec = inspect.getfullargspec(self.func) |
| 85 | + if argspec.args[0] in all_kwargs: |
| 86 | + raise ValueError("first position argument reserved for input dataframe but found in kwargs") |
81 | 87 | for key, param in params:
|
82 |
| - if param.default == _empty: |
| 88 | + if param.name in [argspec.varargs, argspec.varkw]: |
| 89 | + continue |
| 90 | + if param.kind == inspect.Parameter.POSITIONAL_ONLY: |
| 91 | + raise ValueError(f"only one positional only argument supported, found also {param.name}") |
| 92 | + if param.default == inspect._empty: |
83 | 93 | if key not in all_kwargs:
|
84 | 94 | msg = f"missing required argument {key} for task {self.name}"
|
85 | 95 | if key in self.context_kwargs:
|
86 | 96 | msg = f"{msg} - expected from context {self.context_kwargs[key]}"
|
87 | 97 | raise ValueError(msg)
|
88 |
| - extra = set(all_kwargs.keys()) - {k for k, v in params} |
89 |
| - if len(extra) > 0: |
90 |
| - raise ValueError(f"unkown arguments {extra} for task {self.name}") |
| 98 | + if not argspec.varkw: |
| 99 | + extra = set(all_kwargs.keys()) - {k for k, v in params} |
| 100 | + if len(extra) > 0: |
| 101 | + raise ValueError(f"unkown arguments {extra} for task {self.name}") |
91 | 102 |
|
92 | 103 | def __hash__(self):
|
93 | 104 | return hash(self.name)
|
0 commit comments