Skip to content

Commit effcf41

Browse files
committed
Add async support
1 parent 1bfec18 commit effcf41

File tree

2 files changed

+84
-30
lines changed

2 files changed

+84
-30
lines changed

singlestoredb/functions/decorator.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import functools
23
import inspect
34
from typing import Any
@@ -19,6 +20,7 @@
1920
]
2021

2122
ReturnType = ParameterType
23+
UDFType = Callable[..., Any]
2224

2325

2426
def is_valid_type(obj: Any) -> bool:
@@ -101,7 +103,7 @@ def _func(
101103
args: Optional[ParameterType] = None,
102104
returns: Optional[ReturnType] = None,
103105
timeout: Optional[int] = None,
104-
) -> Callable[..., Any]:
106+
) -> UDFType:
105107
"""Generic wrapper for UDF and TVF decorators."""
106108

107109
_singlestoredb_attrs = { # type: ignore
@@ -117,23 +119,33 @@ def _func(
117119
# called later, so the wrapper much be created with the func passed
118120
# in at that time.
119121
if func is None:
120-
def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
122+
def decorate(func: UDFType) -> UDFType:
121123

122-
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
123-
return func(*args, **kwargs) # type: ignore
124+
if asyncio.iscoroutinefunction(func):
125+
async def async_wrapper(*args: Any, **kwargs: Any) -> UDFType:
126+
return await func(*args, **kwargs) # type: ignore
127+
async_wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
128+
return functools.wraps(func)(async_wrapper)
124129

125-
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
126-
127-
return functools.wraps(func)(wrapper)
130+
else:
131+
def wrapper(*args: Any, **kwargs: Any) -> UDFType:
132+
return func(*args, **kwargs) # type: ignore
133+
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
134+
return functools.wraps(func)(wrapper)
128135

129136
return decorate
130137

131-
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
132-
return func(*args, **kwargs) # type: ignore
133-
134-
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
138+
if asyncio.iscoroutinefunction(func):
139+
async def async_wrapper(*args: Any, **kwargs: Any) -> UDFType:
140+
return await func(*args, **kwargs) # type: ignore
141+
async_wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
142+
return functools.wraps(func)(async_wrapper)
135143

136-
return functools.wraps(func)(wrapper)
144+
else:
145+
def wrapper(*args: Any, **kwargs: Any) -> UDFType:
146+
return func(*args, **kwargs) # type: ignore
147+
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
148+
return functools.wraps(func)(wrapper)
137149

138150

139151
def udf(
@@ -143,7 +155,7 @@ def udf(
143155
args: Optional[ParameterType] = None,
144156
returns: Optional[ReturnType] = None,
145157
timeout: Optional[int] = None,
146-
) -> Callable[..., Any]:
158+
) -> UDFType:
147159
"""
148160
Define a user-defined function (UDF).
149161

singlestoredb/functions/ext/asgi.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,8 @@ def build_udf_endpoint(
285285
"""
286286
if returns_data_format in ['scalar', 'list']:
287287

288+
is_async = asyncio.iscoroutinefunction(func)
289+
288290
async def do_func(
289291
cancel_event: threading.Event,
290292
row_ids: Sequence[int],
@@ -297,7 +299,10 @@ async def do_func(
297299
raise asyncio.CancelledError(
298300
'Function call was cancelled',
299301
)
300-
out.append(func(*row))
302+
if is_async:
303+
out.append(await func(*row))
304+
else:
305+
out.append(func(*row))
301306
return row_ids, list(zip(out))
302307

303308
return do_func
@@ -327,6 +332,7 @@ def build_vector_udf_endpoint(
327332
"""
328333
masks = get_masked_params(func)
329334
array_cls = get_array_class(returns_data_format)
335+
is_async = asyncio.iscoroutinefunction(func)
330336

331337
async def do_func(
332338
cancel_event: threading.Event,
@@ -341,9 +347,15 @@ async def do_func(
341347

342348
# Call the function with `cols` as the function parameters
343349
if cols and cols[0]:
344-
out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
350+
if is_async:
351+
out = await func(*[x if m else x[0] for x, m in zip(cols, masks)])
352+
else:
353+
out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
345354
else:
346-
out = func()
355+
if is_async:
356+
out = await func()
357+
else:
358+
out = func()
347359

348360
# Single masked value
349361
if isinstance(out, Masked):
@@ -381,6 +393,8 @@ def build_tvf_endpoint(
381393
"""
382394
if returns_data_format in ['scalar', 'list']:
383395

396+
is_async = asyncio.iscoroutinefunction(func)
397+
384398
async def do_func(
385399
cancel_event: threading.Event,
386400
row_ids: Sequence[int],
@@ -390,11 +404,15 @@ async def do_func(
390404
out_ids: List[int] = []
391405
out = []
392406
# Call function on each row of data
393-
for i, res in zip(row_ids, func_map(func, rows)):
407+
for i, row in zip(row_ids, rows):
394408
if cancel_event.is_set():
395409
raise asyncio.CancelledError(
396410
'Function call was cancelled',
397411
)
412+
if is_async:
413+
res = await func(*row)
414+
else:
415+
res = func(*row)
398416
out.extend(as_list_of_tuples(res))
399417
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
400418
return out_ids, out
@@ -440,13 +458,23 @@ async def do_func(
440458
# each result row, so we just have to use the same
441459
# row ID for all rows in the result.
442460

461+
is_async = asyncio.iscoroutinefunction(func)
462+
443463
# Call function on each column of data
444464
if cols and cols[0]:
445-
res = get_dataframe_columns(
446-
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
447-
)
465+
if is_async:
466+
res = get_dataframe_columns(
467+
await func(*[x if m else x[0] for x, m in zip(cols, masks)]),
468+
)
469+
else:
470+
res = get_dataframe_columns(
471+
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
472+
)
448473
else:
449-
res = get_dataframe_columns(func())
474+
if is_async:
475+
res = get_dataframe_columns(await func())
476+
else:
477+
res = get_dataframe_columns(func())
450478

451479
# Generate row IDs
452480
if isinstance(res[0], Masked):
@@ -508,6 +536,9 @@ def make_func(
508536
# Set timeout
509537
info['timeout'] = max(timeout, 1)
510538

539+
# Set async flag
540+
info['is_async'] = asyncio.iscoroutinefunction(func)
541+
511542
# Setup argument types for rowdat_1 parser
512543
colspec = []
513544
for x in sig['args']:
@@ -927,18 +958,28 @@ async def __call__(
927958

928959
cancel_event = threading.Event()
929960

930-
func_task = asyncio.create_task(
931-
to_thread(
932-
lambda: asyncio.run(
933-
func(
934-
cancel_event,
935-
*input_handler['load']( # type: ignore
936-
func_info['colspec'], b''.join(data),
961+
if func_info['is_async']:
962+
func_task = asyncio.create_task(
963+
func(
964+
cancel_event,
965+
*input_handler['load']( # type: ignore
966+
func_info['colspec'], b''.join(data),
967+
),
968+
),
969+
)
970+
else:
971+
func_task = asyncio.create_task(
972+
to_thread(
973+
lambda: asyncio.run(
974+
func(
975+
cancel_event,
976+
*input_handler['load']( # type: ignore
977+
func_info['colspec'], b''.join(data),
978+
),
937979
),
938980
),
939981
),
940-
),
941-
)
982+
)
942983
disconnect_task = asyncio.create_task(
943984
cancel_on_disconnect(receive),
944985
)
@@ -970,6 +1011,7 @@ async def __call__(
9701011
elif task is func_task:
9711012
result.extend(task.result())
9721013

1014+
print(result)
9731015
body = output_handler['dump'](
9741016
[x[1] for x in func_info['returns']], *result, # type: ignore
9751017
)

0 commit comments

Comments
 (0)