Skip to content

Commit 6f64901

Browse files
authored
Improve sql.functions (#71)
1 parent 97986c3 commit 6f64901

File tree

1 file changed

+236
-29
lines changed

1 file changed

+236
-29
lines changed

sqlalchemy-stubs/sql/functions.pyi

Lines changed: 236 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ from typing import Type
66
from typing import TypeVar
77
from typing import Union
88

9+
from typing_extensions import Protocol
10+
911
from . import sqltypes
1012
from . import type_api
1113
from .base import ColumnCollection
@@ -29,6 +31,7 @@ from .selectable import Select
2931
from .selectable import TableValuedAlias
3032
from .visitors import TraversibleType
3133

34+
_T_co = TypeVar("_T_co", covariant=True)
3235
_TE = TypeVar("_TE", bound=type_api.TypeEngine[Any])
3336
_FE = TypeVar("_FE", bound=FunctionElement[Any])
3437

@@ -74,7 +77,7 @@ class FunctionElement( # type: ignore[misc]
7477
) -> FunctionFilter[_TE]: ...
7578
def as_comparison(
7679
self, left_index: int, right_index: int
77-
) -> FunctionAsBinary[_TE]: ...
80+
) -> FunctionAsBinary: ...
7881
def within_group_type(
7982
self, within_group: Any
8083
) -> Optional[type_api.TypeEngine[Any]]: ...
@@ -86,18 +89,18 @@ class FunctionElement( # type: ignore[misc]
8689
self: _FE, against: Optional[Any] = ...
8790
) -> Union[_FE, Grouping[_TE], AsBoolean[_FE]]: ...
8891

89-
class FunctionAsBinary(BinaryExpression[_TE]):
90-
sql_function: FunctionElement[_TE] = ...
92+
class FunctionAsBinary(BinaryExpression[sqltypes.Boolean]):
93+
sql_function: FunctionElement[Any] = ...
9194
left_index: int = ...
9295
right_index: int = ...
9396
operator: Any = ...
94-
type: Any = ...
97+
type: sqltypes.Boolean = ...
9598
negate: Any = ...
9699
modifiers: Any = ...
97100
left: ClauseElement = ...
98101
right: ClauseElement = ...
99102
def __init__(
100-
self, fn: FunctionElement[_TE], left_index: int, right_index: int
103+
self, fn: FunctionElement[Any], left_index: int, right_index: int
101104
) -> None: ...
102105

103106
class ScalarFunctionColumn(NamedColumn[_TE]):
@@ -135,9 +138,15 @@ class _FunctionGenerator:
135138
func: _FunctionGenerator
136139
modifier: _FunctionGenerator
137140

141+
class _TypeDescriptor(Protocol[_T_co]):
142+
@overload
143+
def __get__(self, instance: None, owner: Any) -> _T_co: ...
144+
@overload
145+
def __get__(self, instance: GenericFunction[_TE], owner: Any) -> _TE: ...
146+
138147
class Function(FunctionElement[_TE]):
139148
__visit_name__: str = ...
140-
type: Any = ...
149+
type: _TypeDescriptor[sqltypes.NullType] = ... # type: ignore[assignment]
141150
packagenames: Any = ...
142151
name: Any = ...
143152
@overload
@@ -161,7 +170,7 @@ class GenericFunction(Function[_TE], metaclass=_GenericMeta):
161170
inherit_cache: bool = ...
162171
packagenames: Any = ...
163172
clause_expr: Any = ...
164-
type: _TE = ...
173+
type: _TypeDescriptor[sqltypes.NullType] = ...
165174
@overload
166175
def __init__(
167176
self: GenericFunction[sqltypes.NullType],
@@ -175,7 +184,7 @@ class GenericFunction(Function[_TE], metaclass=_GenericMeta):
175184
) -> None: ...
176185

177186
class next_value(GenericFunction[_TE]):
178-
type: _TE = ...
187+
type: _TypeDescriptor[sqltypes.Integer] = ... # type: ignore[assignment]
179188
name: str = ...
180189
sequence: Sequence[_TE] = ...
181190
def __init__(self, seq: Sequence[_TE], **kw: Any) -> None: ...
@@ -201,68 +210,222 @@ class sum(ReturnTypeFromArgs[_TE]):
201210
inherit_cache: bool = ...
202211

203212
class now(GenericFunction[_TE]):
204-
type: Any = ...
213+
type: _TypeDescriptor[Type[sqltypes.DateTime]] = ... # type: ignore[assignment]
205214
inherit_cache: bool = ...
215+
@overload
216+
def __init__(
217+
self: now[sqltypes.DateTime],
218+
*args: Any,
219+
type_: None = ...,
220+
**kwargs: Any,
221+
) -> None: ...
222+
@overload
223+
def __init__(
224+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
225+
) -> None: ...
206226

207227
class concat(GenericFunction[_TE]):
208-
type: Any = ...
228+
type: _TypeDescriptor[Type[sqltypes.String]] = ... # type: ignore[assignment]
209229
inherit_cache: bool = ...
230+
@overload
231+
def __init__(
232+
self: concat[sqltypes.String],
233+
*args: Any,
234+
type_: None = ...,
235+
**kwargs: Any,
236+
) -> None: ...
237+
@overload
238+
def __init__(
239+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
240+
) -> None: ...
210241

211242
class char_length(GenericFunction[_TE]):
212-
type: Any = ...
243+
type: _TypeDescriptor[Type[sqltypes.Integer]] = ... # type: ignore[assignment]
213244
inherit_cache: bool = ...
214-
def __init__(self, arg: Any, **kwargs: Any) -> None: ...
245+
@overload
246+
def __init__(
247+
self: char_length[sqltypes.Integer],
248+
arg: Any,
249+
type_: None = ...,
250+
**kwargs: Any,
251+
) -> None: ...
252+
@overload
253+
def __init__(
254+
self, arg: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
255+
) -> None: ...
215256

216257
class random(GenericFunction[_TE]):
217258
inherit_cache: bool = ...
218259

219260
class count(GenericFunction[_TE]):
220-
type: Any = ...
261+
type: _TypeDescriptor[Type[sqltypes.Integer]] = ... # type: ignore[assignment]
221262
inherit_cache: bool = ...
263+
@overload
222264
def __init__(
223-
self, expression: Optional[Any] = ..., **kwargs: Any
265+
self: count[sqltypes.Integer],
266+
expression: Optional[Any] = ...,
267+
*,
268+
type_: None = ...,
269+
**kwargs: Any,
270+
) -> None: ...
271+
@overload
272+
def __init__(
273+
self,
274+
expression: Optional[Any] = ...,
275+
*,
276+
type_: Union[_TE, Type[_TE]],
277+
**kwargs: Any,
224278
) -> None: ...
225279

226280
class current_date(AnsiFunction[_TE]):
227-
type: Any = ...
281+
type: _TypeDescriptor[Type[sqltypes.Date]] = ... # type: ignore[assignment]
228282
inherit_cache: bool = ...
283+
@overload
284+
def __init__(
285+
self: current_date[sqltypes.Date],
286+
*args: Any,
287+
type_: None = ...,
288+
**kwargs: Any,
289+
) -> None: ...
290+
@overload
291+
def __init__(
292+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
293+
) -> None: ...
229294

230295
class current_time(AnsiFunction[_TE]):
231-
type: Any = ...
296+
type: _TypeDescriptor[Type[sqltypes.Time]] = ... # type: ignore[assignment]
232297
inherit_cache: bool = ...
298+
@overload
299+
def __init__(
300+
self: current_time[sqltypes.Time],
301+
*args: Any,
302+
type_: None = ...,
303+
**kwargs: Any,
304+
) -> None: ...
305+
@overload
306+
def __init__(
307+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
308+
) -> None: ...
233309

234310
class current_timestamp(AnsiFunction[_TE]):
235-
type: Any = ...
311+
type: _TypeDescriptor[Type[sqltypes.DateTime]] = ... # type: ignore[assignment]
236312
inherit_cache: bool = ...
313+
@overload
314+
def __init__(
315+
self: current_timestamp[sqltypes.DateTime],
316+
*args: Any,
317+
type_: None = ...,
318+
**kwargs: Any,
319+
) -> None: ...
320+
@overload
321+
def __init__(
322+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
323+
) -> None: ...
237324

238325
class current_user(AnsiFunction[_TE]):
239-
type: Any = ...
326+
type: _TypeDescriptor[Type[sqltypes.String]] = ... # type: ignore[assignment]
240327
inherit_cache: bool = ...
328+
@overload
329+
def __init__(
330+
self: current_user[sqltypes.String],
331+
*args: Any,
332+
type_: None = ...,
333+
**kwargs: Any,
334+
) -> None: ...
335+
@overload
336+
def __init__(
337+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
338+
) -> None: ...
241339

242340
class localtime(AnsiFunction[_TE]):
243-
type: Any = ...
341+
type: _TypeDescriptor[Type[sqltypes.DateTime]] = ... # type: ignore[assignment]
244342
inherit_cache: bool = ...
343+
@overload
344+
def __init__(
345+
self: localtime[sqltypes.DateTime],
346+
*args: Any,
347+
type_: None = ...,
348+
**kwargs: Any,
349+
) -> None: ...
350+
@overload
351+
def __init__(
352+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
353+
) -> None: ...
245354

246355
class localtimestamp(AnsiFunction[_TE]):
247-
type: Any = ...
356+
type: _TypeDescriptor[Type[sqltypes.DateTime]] = ... # type: ignore[assignment]
248357
inherit_cache: bool = ...
358+
@overload
359+
def __init__(
360+
self: localtimestamp[sqltypes.DateTime],
361+
*args: Any,
362+
type_: None = ...,
363+
**kwargs: Any,
364+
) -> None: ...
365+
@overload
366+
def __init__(
367+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
368+
) -> None: ...
249369

250370
class session_user(AnsiFunction[_TE]):
251-
type: Any = ...
371+
type: _TypeDescriptor[Type[sqltypes.String]] = ... # type: ignore[assignment]
252372
inherit_cache: bool = ...
373+
@overload
374+
def __init__(
375+
self: session_user[sqltypes.String],
376+
*args: Any,
377+
type_: None = ...,
378+
**kwargs: Any,
379+
) -> None: ...
380+
@overload
381+
def __init__(
382+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
383+
) -> None: ...
253384

254385
class sysdate(AnsiFunction[_TE]):
255-
type: Any = ...
386+
type: _TypeDescriptor[Type[sqltypes.DateTime]] = ... # type: ignore[assignment]
256387
inherit_cache: bool = ...
388+
@overload
389+
def __init__(
390+
self: sysdate[sqltypes.DateTime],
391+
*args: Any,
392+
type_: None = ...,
393+
**kwargs: Any,
394+
) -> None: ...
395+
@overload
396+
def __init__(
397+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
398+
) -> None: ...
257399

258400
class user(AnsiFunction[_TE]):
259-
type: Any = ...
401+
type: _TypeDescriptor[Type[sqltypes.String]] = ... # type: ignore[assignment]
260402
inherit_cache: bool = ...
403+
@overload
404+
def __init__(
405+
self: user[sqltypes.String],
406+
*args: Any,
407+
type_: None = ...,
408+
**kwargs: Any,
409+
) -> None: ...
410+
@overload
411+
def __init__(
412+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
413+
) -> None: ...
261414

262415
class array_agg(GenericFunction[_TE]):
263-
type: Any = ...
416+
type: _TypeDescriptor[Type[sqltypes.ARRAY[Any]]] = ... # type: ignore[assignment]
264417
inherit_cache: bool = ...
265-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
418+
@overload
419+
def __init__(
420+
self: array_agg[sqltypes.ARRAY[Any]],
421+
*args: Any,
422+
type_: None = ...,
423+
**kwargs: Any,
424+
) -> None: ...
425+
@overload
426+
def __init__(
427+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
428+
) -> None: ...
266429

267430
class OrderedSetAgg(GenericFunction[_TE]):
268431
array_for_multi_clause: bool = ...
@@ -283,20 +446,64 @@ class percentile_disc(OrderedSetAgg[_TE]):
283446
inherit_cache: bool = ...
284447

285448
class rank(GenericFunction[_TE]):
286-
type: Any = ...
449+
type: _TypeDescriptor[sqltypes.Integer] = ... # type: ignore[assignment]
287450
inherit_cache: bool = ...
451+
@overload
452+
def __init__(
453+
self: rank[sqltypes.Integer],
454+
*args: Any,
455+
type_: None = ...,
456+
**kwargs: Any,
457+
) -> None: ...
458+
@overload
459+
def __init__(
460+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
461+
) -> None: ...
288462

289463
class dense_rank(GenericFunction[_TE]):
290-
type: Any = ...
464+
type: _TypeDescriptor[sqltypes.Integer] = ... # type: ignore[assignment]
291465
inherit_cache: bool = ...
466+
@overload
467+
def __init__(
468+
self: dense_rank[sqltypes.Integer],
469+
*args: Any,
470+
type_: None = ...,
471+
**kwargs: Any,
472+
) -> None: ...
473+
@overload
474+
def __init__(
475+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
476+
) -> None: ...
292477

293478
class percent_rank(GenericFunction[_TE]):
294-
type: Any = ...
479+
type: _TypeDescriptor[sqltypes.Numeric] = ... # type: ignore[assignment]
295480
inherit_cache: bool = ...
481+
@overload
482+
def __init__(
483+
self: percent_rank[sqltypes.Numeric],
484+
*args: Any,
485+
type_: None = ...,
486+
**kwargs: Any,
487+
) -> None: ...
488+
@overload
489+
def __init__(
490+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
491+
) -> None: ...
296492

297493
class cume_dist(GenericFunction[_TE]):
298-
type: Any = ...
494+
type: _TypeDescriptor[sqltypes.Numeric] = ... # type: ignore[assignment]
299495
inherit_cache: bool = ...
496+
@overload
497+
def __init__(
498+
self: cume_dist[sqltypes.Numeric],
499+
*args: Any,
500+
type_: None = ...,
501+
**kwargs: Any,
502+
) -> None: ...
503+
@overload
504+
def __init__(
505+
self, *args: Any, type_: Union[_TE, Type[_TE]], **kwargs: Any
506+
) -> None: ...
300507

301508
class cube(GenericFunction[_TE]):
302509
inherit_cache: bool = ...

0 commit comments

Comments
 (0)